File size: 1,427 Bytes
a2f57c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
def is_hopper_gpu():
    """Return True if the first visible CUDA device is a Hopper GPU.

    Hopper corresponds to compute capability major version 9 (SM 9.x).
    Returns False when CUDA is unavailable or the device is any other
    architecture.

    Returns:
        bool: True iff device 0 reports compute capability major == 9.
    """
    if torch.cuda.is_available():
        # Only the major version matters; discard the minor revision.
        major, _ = torch.cuda.get_device_capability(0)
        return major == 9
    return False
def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
    """Return recommended (num_warps, num_stages) for a sparse-attention Triton kernel.

    The heuristic keys off whether the head dimension and/or the attention
    block size exceed 64, with separate tables for Hopper and Ampere GPUs.

    Args:
        head_dim (int): Size of the head dimension.
        block_size (int): Size of the block in the attention matrix.
        is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU.
            NOTE: this parameter shadows the module-level ``is_hopper_gpu``
            function; the name is kept for backward compatibility with
            keyword callers.

    Returns:
        tuple: (num_warps, num_stages) recommended values.
    """
    # A dimension is "large" once it exceeds 64.
    head_large = head_dim > 64
    block_large = block_size > 64

    if is_hopper_gpu:
        # Hopper: scale warps with how many dimensions are large.
        if head_large and block_large:
            num_warps, num_stages = 8, 3
        elif head_large or block_large:
            num_warps, num_stages = 4, 3
        else:
            num_warps, num_stages = 2, 2
    else:
        # Ampere: the original "and" and "or" arms both produced (8, 3),
        # so a single "or" condition covers both cases.
        if head_large or block_large:
            num_warps, num_stages = 8, 3
        else:
            num_warps, num_stages = 2, 2

    return num_warps, num_stages
|