# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch

from ...utils import log

# Flash Attention imports
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except Exception:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except Exception:
    FLASH_ATTN_2_AVAILABLE = False

# Dtypes passed straight through to the flash / sage kernels without casting.
_HALF_DTYPES = (torch.float16, torch.bfloat16)


def _sage_cast(q, k, v):
    """Cast q/k/v to one common half-precision dtype for SageAttention.

    k and v are cast to q's dtype; if q's dtype is not float16/bfloat16
    (e.g. float32), all three are cast to float16 instead.

    Returns:
        (q, k, v, out_dtype): the cast tensors plus q's original dtype, so the
        caller can cast the attention output back.
    """
    out_dtype = q.dtype
    target = out_dtype if out_dtype in _HALF_DTYPES else torch.float16
    # Tensor.to() returns the tensor itself when the dtype already matches.
    return q.to(target), k.to(target), v.to(target), out_dtype


def _cu_seqlens(lens, device):
    """Return int32 cumulative offsets [0, l0, l0+l1, ...] on `device`.

    `lens` may be a 1-D tensor (on any device) or a Python sequence of ints.
    """
    lens = torch.as_tensor(lens).to(device, non_blocking=True)
    return torch.cat([lens.new_zeros([1]), lens]).cumsum(0, dtype=torch.int32)


# Sage Attention imports
try:
    from sageattention import sageattn

    @torch.compiler.disable()
    def sageattn_func(q, k, v, attn_mask=None, dropout_p=0, is_causal=False, tensor_layout="HND"):
        """Call `sageattn` after unifying q/k/v to a half dtype.

        The output is cast back to q's original dtype.
        """
        q, k, v, out_dtype = _sage_cast(q, k, v)
        out = sageattn(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p,
                       is_causal=is_causal, tensor_layout=tensor_layout)
        return out.to(out_dtype)
except Exception as e:
    log.warning(f"Warning: Could not load sageattention: {str(e)}")
    if isinstance(e, ModuleNotFoundError):
        log.warning("sageattention package is not installed, sageattention will not be available")
    elif isinstance(e, ImportError) and "DLL" in str(e):
        log.warning("sageattention DLL loading error, sageattention will not be available")
    sageattn_func = None

# SageAttention 3 (Blackwell) is optional. Both names stay defined either way,
# so callers can test availability instead of hitting a NameError.
SAGE3_AVAILABLE = False
sageattn_blackwell = None
try:
    from sageattn3 import sageattn3_blackwell as sageattn_blackwell
    SAGE3_AVAILABLE = True
except Exception:
    try:
        from sageattn import sageattn_blackwell
        SAGE3_AVAILABLE = True
    except Exception:
        pass

try:
    from sageattention import sageattn_varlen

    @torch.compiler.disable()
    def sageattn_varlen_func(q, k, v, q_lens, k_lens, max_seqlen_q, max_seqlen_k, dropout_p=0, is_causal=False):
        """Call `sageattn_varlen` using per-sample lengths instead of cu_seqlens.

        q_lens / k_lens: per-sample lengths, as a 1-D tensor or a list of ints.
        max_seqlen_q / max_seqlen_k: if None, derived as max(q_lens) / max(k_lens).
        The output is cast back to q's original dtype.
        """
        if q_lens is None or k_lens is None:
            raise ValueError("sageattn_varlen requires both q_lens and k_lens")
        cu_seqlens_q = _cu_seqlens(q_lens, q.device)
        cu_seqlens_k = _cu_seqlens(k_lens, q.device)
        if max_seqlen_q is None:
            max_seqlen_q = int(torch.as_tensor(q_lens).max())
        if max_seqlen_k is None:
            max_seqlen_k = int(torch.as_tensor(k_lens).max())
        q, k, v, out_dtype = _sage_cast(q, k, v)
        out = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                              dropout_p=dropout_p, is_causal=is_causal)
        return out.to(out_dtype)
except Exception:
    sageattn_varlen_func = None

__all__ = [
    'flash_attention',
    'attention',
]


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    q: [B, Lq, Nq, C1].
    k: [B, Lk, Nk, C1].
    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens: [B]. Optional per-sample query lengths (defaults to Lq for every sample).
    k_lens: [B]. Optional per-sample key lengths (defaults to Lk for every sample).
    dropout_p: float. Dropout probability (FA2 only).
    softmax_scale: float. The scaling of QK^T before applying softmax.
    q_scale: optional. If given, q is multiplied by it before attention.
    causal: bool. Whether to apply causal attention mask.
    window_size: (left right). If not (-1, -1), apply sliding window local attention (FA2 only).
    deterministic: bool. If True, slightly slower and uses more memory.
    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
    version: None to use FlashAttention-3 when available (else FA2), 3 to request
        FA3 (falls back to FA2 with a warning), 2 to force FA2.

    Returns the kernel output reshaped to [B, Lq, ...], cast to q's original dtype.

    Raises RuntimeError if FA2 is needed but flash_attn could not be imported.
    """
    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        return x if x.dtype in _HALF_DTYPES else x.to(dtype)

    # preprocess query: pack the batch into one varlen tensor along dim 0
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [lq] * b, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([u[:n] for u, n in zip(q, q_lens)]))

    # preprocess key, value
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [lk] * b, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:n] for u, n in zip(k, k_lens)]))
        v = half(torch.cat([u[:n] for u, n in zip(v, k_lens)]))

    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version == 3 and not FLASH_ATTN_3_AVAILABLE:
        log.warning('Flash attention 3 is not available, use flash attention 2 instead.')

    cu_seqlens_q = _cu_seqlens(q_lens, q.device)
    cu_seqlens_k = _cu_seqlens(k_lens, q.device)

    # apply attention
    # NOTE: unflatten(0, (b, lq)) below assumes sum(q_lens) == b * lq, i.e. no
    # query padding was actually trimmed.
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        out = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic)
        # NOTE(review): some flash_attn_interface releases return (out, lse),
        # others return just out; accept both.
        if isinstance(out, (tuple, list)):
            out = out[0]
        x = out.unflatten(0, (b, lq))
    else:
        if not FLASH_ATTN_2_AVAILABLE:
            raise RuntimeError(
                "flash_attn (FlashAttention-2) is not available; install it or use another attention_mode")
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic).unflatten(0, (b, lq))

    # output
    return x.type(out_dtype)


def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    max_seqlen_q=None,
    max_seqlen_k=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    attention_mode='sdpa',
    attn_mask=None,
):
    """Dispatch attention to the backend selected by `attention_mode`.

    - any mode containing "flash": flash_attention(). 'flash_attn_2' forces FA2,
      'flash_attn_3' requests FA3, and any other flash mode auto-selects
      (FA3 if importable, else FA2).
    - 'sdpa': torch.nn.functional.scaled_dot_product_attention on the
      (1, 2)-transposed inputs; k/v are cast to q's dtype. Only attn_mask is
      forwarded (dropout_p, softmax_scale, q_scale, causal are not used here).
    - 'sageattn_3': sageattn_blackwell on the (1, 2)-transposed inputs.
    - 'sageattn_varlen': sageattn_varlen_func with q_lens / k_lens.
    - any other mode: sageattn_func with tensor_layout="NHD".

    Inputs are presumably [B, L, N, C] as documented for flash_attention;
    TODO confirm for the sage varlen path, which passes q/k/v through unchanged.

    Raises RuntimeError if the selected sage backend could not be imported.
    """
    if "flash" in attention_mode:
        if attention_mode == 'flash_attn_2':
            fa_version = 2
        elif attention_mode == 'flash_attn_3':
            fa_version = 3
        else:
            # Any other flash mode: let flash_attention pick FA3 if available.
            fa_version = None
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )
    elif attention_mode == 'sdpa':
        # .to(q.dtype) is a no-op when k/v already match q.
        out = torch.nn.functional.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2).to(q.dtype),
            v.transpose(1, 2).to(q.dtype),
            attn_mask=attn_mask)
        return out.transpose(1, 2).contiguous()
    elif attention_mode == 'sageattn_3':
        if sageattn_blackwell is None:
            raise RuntimeError(
                "attention_mode 'sageattn_3' requested but sageattn_blackwell could not be imported")
        return sageattn_blackwell(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            per_block_mean=False  # seems necessary for reasonable VRAM usage, not sure of other implications
        ).transpose(1, 2).contiguous()
    elif attention_mode == 'sageattn_varlen':
        if sageattn_varlen_func is None:
            raise RuntimeError(
                "attention_mode 'sageattn_varlen' requested but sageattention.sageattn_varlen could not be imported")
        return sageattn_varlen_func(
            q, k, v,
            q_lens=q_lens,
            k_lens=k_lens,
            max_seqlen_k=max_seqlen_k,
            max_seqlen_q=max_seqlen_q,
        )
    else:
        if sageattn_func is None:
            raise RuntimeError(
                f"attention_mode '{attention_mode}' requires sageattention, which could not be imported")
        return sageattn_func(q, k, v, tensor_layout="NHD").contiguous()