import torch


def is_hopper_gpu():
    """Return True if the current CUDA device is a Hopper-generation GPU (compute capability 9.x)."""
    if torch.cuda.is_available():
        major, _ = torch.cuda.get_device_capability(0)
        # Hopper GPUs (e.g. H100) report compute capability 9.x; Ampere reports 8.x.
        return major == 9
    return False


def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
    """
    Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton.

    Args:
        head_dim (int): Size of the head dimension.
        block_size (int): Size of the block in the attention matrix.
        is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU.

    Returns:
        tuple: (num_warps, num_stages) recommended values.
    """
    # Larger head dimensions and block sizes mean larger tiles per program,
    # so the heuristic assigns more warps and deeper software pipelining.
    head_large = head_dim > 64
    block_large = block_size > 64
    if is_hopper_gpu:
        if head_large and block_large:
            num_warps = 8
            num_stages = 3
        elif head_large or block_large:
            num_warps = 4
            num_stages = 3
        else:
            num_warps = 2
            num_stages = 2
    else:
        # On Ampere, a single large dimension already gets the wider configuration.
        if head_large and block_large:
            num_warps = 8
            num_stages = 3
        elif head_large or block_large:
            num_warps = 8
            num_stages = 3
        else:
            num_warps = 2
            num_stages = 2
    return num_warps, num_stages
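

# Illustrative usage sketch: the head_dim and block_size values below are
# hypothetical placeholders; in practice they would come from the attention
# kernel's configuration before the Triton kernel is launched.
if __name__ == "__main__":
    on_hopper = is_hopper_gpu()
    num_warps, num_stages = get_num_warps_stages(
        head_dim=128, block_size=64, is_hopper_gpu=on_hopper
    )
    print(f"num_warps={num_warps}, num_stages={num_stages}")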