import torch


def is_hopper_gpu():
    """Return True if the current CUDA device is a Hopper GPU (compute capability 9.x)."""
    if torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability(0)
        return major == 9
    return False


def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
    """
    Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton.

    Args:
        head_dim (int): Size of the head dimension.
        block_size (int): Size of the block in the attention matrix.
        is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU.

    Returns:
        tuple: (num_warps, num_stages) recommended values.
    """
    # Determine whether head_dim and block_size exceed 64
    head_large = head_dim > 64
    block_large = block_size > 64

    if is_hopper_gpu:
        # Hopper GPU recommendations
        if head_large and block_large:
            num_warps = 8
            num_stages = 3
        elif head_large or block_large:
            num_warps = 4
            num_stages = 3
        else:
            num_warps = 2
            num_stages = 2
    else:
        # Ampere GPU recommendations: the original "both large" and
        # "either large" branches used identical values, so they are
        # collapsed into a single branch (behavior unchanged).
        if head_large or block_large:
            num_warps = 8
            num_stages = 3
        else:
            num_warps = 2
            num_stages = 2
    return num_warps, num_stages
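
# Example usage: a minimal sketch of how these helpers might be combined.
# The head_dim and block_size values below are illustrative assumptions,
# not values taken from the original code.
if __name__ == "__main__":
    on_hopper = is_hopper_gpu()
    num_warps, num_stages = get_num_warps_stages(
        head_dim=128, block_size=64, is_hopper_gpu=on_hopper
    )
    print(f"hopper={on_hopper}, num_warps={num_warps}, num_stages={num_stages}")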