GLM2NSA / utils.py
import torch


def is_hopper_gpu():
    if torch.cuda.is_available():
        device_capability = torch.cuda.get_device_capability(0)
        major, minor = device_capability
        return major == 9
    return False


def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
    """
    Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton.

    Args:
        head_dim (int): Size of the head dimension.
        block_size (int): Size of the block in the attention matrix.
        is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU.

    Returns:
        tuple: (num_warps, num_stages) recommended values.
    """
    # Determine if head_dim and block_size exceed 64
    head_large = head_dim > 64
    block_large = block_size > 64
    if is_hopper_gpu:
        # Hopper GPU recommendations
        if head_large and block_large:
            num_warps = 8
            num_stages = 3
        elif head_large or block_large:
            num_warps = 4
            num_stages = 3
        else:
            num_warps = 2
            num_stages = 2
    else:
        # Ampere GPU recommendations
        if head_large and block_large:
            num_warps = 8
            num_stages = 3
        elif head_large or block_large:
            num_warps = 8
            num_stages = 3
        else:
            num_warps = 2
            num_stages = 2
    return num_warps, num_stages
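

# Usage sketch (illustrative addition, not part of the upstream module):
# detect the current device generation and derive Triton launch parameters
# from it. The head_dim / block_size values below are example inputs; the
# kernel launch that would consume num_warps / num_stages is omitted because
# no kernel is defined in this file.
if __name__ == "__main__":
    head_dim = 128
    block_size = 64
    hopper = is_hopper_gpu()
    num_warps, num_stages = get_num_warps_stages(head_dim, block_size, hopper)
    print(f"hopper={hopper}, num_warps={num_warps}, num_stages={num_stages}")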