|
|
|
|
|
import torch |
|
|
from ...utils import log |
|
|
|
|
|
|
|
|
# Optional backend probe: flash_attention() takes its version-3 path through
# `flash_attn_interface` only when this flag is True.  Any exception raised
# while importing the module marks the backend as unavailable.
try:
    import flash_attn_interface
except Exception:
    FLASH_ATTN_3_AVAILABLE = False
else:
    FLASH_ATTN_3_AVAILABLE = True
|
|
|
|
|
# Optional backend probe: flash_attention() falls back to `flash_attn` whenever
# the version-3 path is not taken, and checks this flag before doing so.  Any
# exception raised while importing the module marks the backend as unavailable.
try:
    import flash_attn
except Exception:
    FLASH_ATTN_2_AVAILABLE = False
else:
    FLASH_ATTN_2_AVAILABLE = True
|
|
|
|
|
|
|
|
# Optional backend: `sageattn_func` wraps sageattention.sageattn when the package
# can be imported, and is set to None (after logging why) when it cannot.
try:
    from sageattention import sageattn

    @torch.compiler.disable()
    def sageattn_func(q, k, v, attn_mask=None, dropout_p=0, is_causal=False, tensor_layout="HND"):
        """
        Call ``sageattn`` after bringing q, k and v to one common dtype.

        q, k, v:        tensors laid out as described by ``tensor_layout``.
        attn_mask:      forwarded to ``sageattn`` unchanged.
        dropout_p:      forwarded to ``sageattn`` unchanged.
        is_causal:      forwarded to ``sageattn`` unchanged.
        tensor_layout:  forwarded to ``sageattn`` unchanged.

        Returns the attention output in the dtype of ``q``.

        The compute dtype is the dtype of ``q``, except that a float32 query is
        run in float16.  k and v are always cast to the compute dtype, so a
        float32 query combined with half-precision k/v takes the same float16
        path as an all-float32 call instead of reaching the kernel as float32.
        """
        out_dtype = q.dtype
        compute_dtype = torch.float16 if out_dtype == torch.float32 else out_dtype

        # Tensor.to() returns the tensor itself when the dtype already matches,
        # so the homogeneous half-precision case stays copy-free.
        out = sageattn(
            q.to(compute_dtype),
            k.to(compute_dtype),
            v.to(compute_dtype),
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            tensor_layout=tensor_layout,
        )
        return out.to(out_dtype)
except Exception as e:
    log.warning(f"Warning: Could not load sageattention: {str(e)}")
    if isinstance(e, ModuleNotFoundError):
        log.warning("sageattention package is not installed, sageattention will not be available")
    elif isinstance(e, ImportError) and "DLL" in str(e):
        log.warning("sageattention DLL loading error, sageattention will not be available")
    # Callers must check for None before using this backend.
    sageattn_func = None
|
|
|
|
|
# Optional backend: locate `sageattn_blackwell`, preferring the `sageattn3`
# package and falling back to `sageattn`.
#
# Both names are always defined after this block:
#   SAGE3_AVAILABLE     True when one of the imports succeeded, else False.
#   sageattn_blackwell  the imported callable, or None when neither import works.
# `except Exception` (rather than a bare except) keeps KeyboardInterrupt and
# SystemExit from being swallowed during import.
try:
    from sageattn3 import sageattn3_blackwell as sageattn_blackwell
    SAGE3_AVAILABLE = True
except Exception:
    try:
        from sageattn import sageattn_blackwell
        SAGE3_AVAILABLE = True
    except Exception:
        sageattn_blackwell = None
        SAGE3_AVAILABLE = False
|
|
|
|
|
# Optional backend: `sageattn_varlen_func` wraps sageattention.sageattn_varlen
# when the package can be imported, and is None otherwise.
try:
    from sageattention import sageattn_varlen

    @torch.compiler.disable()
    def sageattn_varlen_func(q, k, v, q_lens, k_lens, max_seqlen_q, max_seqlen_k, dropout_p=0, is_causal=False):
        """
        Call ``sageattn_varlen`` with cumulative sequence offsets built from lengths.

        q, k, v:        tensors passed to ``sageattn_varlen`` (after dtype alignment).
        q_lens:         per-sequence query lengths (sequence of ints or a tensor).
        k_lens:         per-sequence key lengths (sequence of ints or a tensor).
        max_seqlen_q:   forwarded to ``sageattn_varlen`` unchanged.
        max_seqlen_k:   forwarded to ``sageattn_varlen`` unchanged.
        dropout_p:      forwarded to ``sageattn_varlen`` unchanged.
        is_causal:      forwarded to ``sageattn_varlen`` unchanged.

        Returns the attention output in the dtype of ``q``.

        The compute dtype is the dtype of ``q``, except that a float32 query is
        run in float16; k and v are always cast to the compute dtype.
        """

        def cumulative(lens):
            # [0, l0, l0 + l1, ...] as int32 on q's device, computed with tensor
            # ops instead of a Python list of 0-d tensors.
            lens = torch.as_tensor(lens, device=q.device)
            return torch.cat([lens.new_zeros([1]), lens]).cumsum(0, dtype=torch.int32)

        cu_seqlens_q = cumulative(q_lens)
        cu_seqlens_k = cumulative(k_lens)

        out_dtype = q.dtype
        compute_dtype = torch.float16 if out_dtype == torch.float32 else out_dtype

        # Tensor.to() is a no-op when the dtype already matches.
        out = sageattn_varlen(
            q.to(compute_dtype),
            k.to(compute_dtype),
            v.to(compute_dtype),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )
        return out.to(out_dtype)
except Exception:
    # Best effort: the backend is simply reported as missing.  A bare `except:`
    # would also have swallowed KeyboardInterrupt / SystemExit.
    sageattn_varlen_func = None
|
|
|
|
|
# Names exported by `from <module> import *`.
__all__ = ['flash_attention', 'attention']
|
|
|
|
|
|
|
|
def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    Variable-length attention through FlashAttention 3 or FlashAttention 2.

    q:              [B, Lq, Nq, C1].
    k:              [B, Lk, Nk, C1].
    v:              [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens:         [B] tensor of valid query lengths, or None to use Lq for every sample.
    k_lens:         [B] tensor of valid key lengths, or None to use Lk for every sample.
    dropout_p:      float. Dropout probability. Only forwarded on the FlashAttention 2 path.
    softmax_scale:  float. The scaling of QK^T before applying softmax.
    q_scale:        optional factor multiplied into q before the kernel call.
    causal:         bool. Whether to apply causal attention mask.
    window_size:    (left right). If not (-1, -1), apply sliding window local attention.
                    Only forwarded on the FlashAttention 2 path.
    deterministic:  bool. If True, slightly slower and uses more memory.
    dtype:          torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
    version:        None, 2 or 3. None or 3 uses FlashAttention 3 when it is available;
                    everything else (including 3 when it is unavailable) uses FlashAttention 2.

    Returns a tensor reshaped to [B, Lq, ...] and cast back to the input dtype of q.

    Raises AssertionError when the FlashAttention 2 path is selected but
    `flash_attn` could not be imported.

    NOTE(review): the result is reshaped with unflatten(0, (B, Lq)); when q_lens is
    given this presumably requires every entry to equal Lq -- confirm with callers
    that pass q_lens.
    """
    half_dtypes = (torch.float16, torch.bfloat16)

    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        # Leave fp16/bf16 tensors untouched; cast anything else to `dtype`.
        return x if x.dtype in half_dtypes else x.to(dtype)

    # Pack the queries of all samples along dim 0.
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(
            device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([u[:n] for u, n in zip(q, q_lens)]))

    # Pack keys and values the same way.
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(
            device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:n] for u, n in zip(k, k_lens)]))
        v = half(torch.cat([u[:n] for u, n in zip(v, k_lens)]))

    # q and k follow the dtype of v so all three inputs agree.
    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version == 3 and not FLASH_ATTN_3_AVAILABLE:
        log.warning('Flash attention 3 is not available, use flash attention 2 instead.')

    use_fa3 = (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE
    if not use_fa3 and not FLASH_ATTN_2_AVAILABLE:
        # Raised explicitly so the check survives `python -O`; the exception type
        # is the one the former `assert` produced.
        raise AssertionError('Flash attention 2 (flash_attn) is not available.')

    # Cumulative sequence offsets [0, l0, l0 + l1, ...], shared by both backends.
    cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
        0, dtype=torch.int32).to(q.device, non_blocking=True)
    cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
        0, dtype=torch.int32).to(q.device, non_blocking=True)

    if use_fa3:
        # dropout_p and window_size are not passed on this path.
        out = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic)
        # Take the first element only when a tuple/list is returned; indexing a
        # bare tensor with [0] would silently select the first packed token.
        if isinstance(out, (tuple, list)):
            out = out[0]
        x = out.unflatten(0, (b, lq))
    else:
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic).unflatten(0, (b, lq))

    # Restore the dtype the caller passed in for q.
    return x.type(out_dtype)
|
|
|
|
|
|
|
|
def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    max_seqlen_q=None,
    max_seqlen_k=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    attention_mode='sdpa',
    attn_mask=None,
):
    """
    Dispatch an attention call to the backend named by ``attention_mode``.

    q, k, v:        input tensors; presumably [B, L, N, C] as documented on
                    flash_attention (dims 1 and 2 are swapped before the
                    'sdpa' and 'sageattn_3' kernels and swapped back after).
    attention_mode: backend selector.
        contains "flash"   -> flash_attention(); 'flash_attn_2' / 'flash_attn_3'
                              pin the version, any other "flash" name passes
                              version=None so flash_attention picks one.
        'sdpa'             -> torch.nn.functional.scaled_dot_product_attention;
                              k and v are cast to q's dtype when dtypes differ.
        'sageattn_3'       -> sageattn_blackwell(..., per_block_mean=False).
        'sageattn_varlen'  -> sageattn_varlen_func with q_lens, k_lens,
                              max_seqlen_q and max_seqlen_k.
        anything else      -> sageattn_func(q, k, v, tensor_layout="NHD").
    attn_mask:      only used by the 'sdpa' path.
    q_lens, k_lens: used by the flash and 'sageattn_varlen' paths.
    max_seqlen_q, max_seqlen_k: only used by the 'sageattn_varlen' path.
    dropout_p, softmax_scale, q_scale, causal, window_size, deterministic, dtype:
                    only forwarded on the flash path.

    Returns the backend's output; the 'sdpa', 'sageattn_3' and default paths
    return a contiguous tensor.
    """
    if "flash" in attention_mode:
        # Unrecognised "flash*" names used to leave fa_version unbound and crash
        # with UnboundLocalError; None hands version selection to flash_attention.
        fa_version = {'flash_attn_2': 2, 'flash_attn_3': 3}.get(attention_mode)
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )

    if attention_mode == 'sdpa':
        q_t, k_t, v_t = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        if not (q.dtype == k.dtype == v.dtype):
            # The query dtype is authoritative when the inputs disagree.
            k_t, v_t = k_t.to(q.dtype), v_t.to(q.dtype)
        out = torch.nn.functional.scaled_dot_product_attention(q_t, k_t, v_t, attn_mask=attn_mask)
        return out.transpose(1, 2).contiguous()

    if attention_mode == 'sageattn_3':
        return sageattn_blackwell(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            per_block_mean=False
        ).transpose(1, 2).contiguous()

    if attention_mode == 'sageattn_varlen':
        return sageattn_varlen_func(
            q, k, v,
            q_lens=q_lens,
            k_lens=k_lens,
            max_seqlen_k=max_seqlen_k,
            max_seqlen_q=max_seqlen_q
        )

    # Default backend for every other mode string.
    return sageattn_func(q, k, v, tensor_layout="NHD").contiguous()
|
|
|