# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. # Supported features: # - BF16 & FP16 dtype # - noncausal & causal attention # - MHA, GQA, MQA # - hdim 64, 96, 128. # - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape) # - varlen # - sliding window # - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow) # Features not supported yet: # - split (i.e. FlashDecoding) # - tuned block sizes # - paged KV # - append KV to existing KV cache # - FP8 # - bwd pass optimized for Hopper/Blackwell import os import math from dataclasses import dataclass from functools import lru_cache from typing import Optional, Tuple, Callable import torch import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute from cutlass import Int32, Float32 from .quack.compile_utils import make_fake_tensor as fake_tensor from .cache_utils import get_jit_cache from .testing import is_fake_mode if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: from . import cute_dsl_ptxas # noqa: F401 # Patch to dump ptx and then use system ptxas to compile to cubin cute_dsl_ptxas.patch() from . import utils from . import fa_logging from .cute_dsl_utils import ( to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims, ) from .flash_fwd import FlashAttentionForwardSm80 from .flash_fwd_sm90 import FlashAttentionForwardSm90 from .flash_fwd_sm100 import FlashAttentionForwardSm100 from .flash_fwd_sm120 import FlashAttentionForwardSm120 from .flash_bwd_preprocess import FlashAttentionBackwardPreprocess from .flash_bwd import FlashAttentionBackwardSm80 from .flash_bwd_sm90 import FlashAttentionBackwardSm90 from .flash_bwd_sm100 import FlashAttentionBackwardSm100 from .flash_bwd_sm120 import FlashAttentionBackwardSm120 from .flash_bwd_postprocess import FlashAttentionBackwardPostprocess from .flash_fwd_combine import FlashAttentionForwardCombine from .block_sparsity import ( BlockSparseTensorsTorch, get_sparse_q_block_size, to_cute_block_sparse_tensors, normalize_block_sparse_config, normalize_block_sparse_config_bwd, ) def _parse_arch_str(arch_str): """Parse arch string (e.g. 'sm_80', 'sm_90a', '80', '100') to int (e.g. 80, 90, 100).""" import re match = re.match(r"^(?:sm_?|SM_?)?(\d+)(\d)([af]?)$", arch_str) if not match: raise ValueError(f"Invalid arch format: {arch_str}") major, minor, _ = match.groups() return int(major) * 10 + int(minor) @lru_cache(maxsize=None) def _get_device_arch(): """Cached device arch check. Override with FLASH_ATTENTION_ARCH (e.g. 'sm_80' or '80') to select which kernel path to use (SM80/SM90/SM100/SM120) independently of the compilation target (CUTE_DSL_ARCH). For CPU-only compilation (no GPU), set both: FLASH_ATTENTION_ARCH=sm_80 (kernel selection) CUTE_DSL_ARCH=sm_80 (compilation target) """ arch_override = os.environ.get("FLASH_ATTENTION_ARCH", None) if arch_override is not None: return _parse_arch_str(arch_override) major, minor = torch.cuda.get_device_capability() return major * 10 + int(minor) def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, alignment: int) -> None: """Validate head dimension constraints based on compute capability.""" is_deepseek_shape = head_dim == 192 and head_dim_v == 128 is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128 is_sm90_range = 8 <= head_dim <= 256 and 8 <= head_dim_v <= 256 if compute_capability == 9: assert is_sm90_range and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM90. " f"head_dim and head_dim_v must be between 8 and 256 and divisible by {alignment}." ) elif compute_capability in [10, 11]: assert (is_standard_range or is_deepseek_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. " f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek." ) @dataclass(frozen=True) class FwdConfig: m_block_size: int n_block_size: int mma_pv_is_rs: bool intra_wg_overlap: bool def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, sparse_block_size_q=None): """Return FwdConfig for SM90 forward. Tile sizes and flags based on tile_size_fwd_sm90 in hopper/tile_size.h, adjusted for the Python kernel's different register/smem tradeoffs (benchmarked on H100 SXM). When sparse_block_size_q is set, tile_m must divide it. For head_dim <= 96 the optimal tile_m=192 is used when compatible, otherwise we fall back to 128. """ if head_dim <= 64: # C++: 192×192 non-causal, 192×128 causal/local. # Python: 192×128 RS+OL is consistently best across seqlens. if sparse_block_size_q is not None and sparse_block_size_q % 192 != 0: return FwdConfig(128, 128, True, True) return FwdConfig(192, 128, True, True) elif head_dim <= 96: # C++: 192×144 noRS+OL for all cases. # Python: RS is catastrophic with 192× tiles (~300 vs ~600 TFLOPS). # noRS+OL is always required. Causal: 192×128 slightly better short seqlen. if sparse_block_size_q is not None and sparse_block_size_q % 192 != 0: return FwdConfig(128, 128, False, True) if is_causal or is_local: return FwdConfig(192, 128, False, True) else: return FwdConfig(192, 144, False, True) elif head_dim <= 128: return FwdConfig(128, 128, True, True) elif head_dim <= 192: tile_n = 96 if is_local else (128 if head_dim_v <= 128 else 112) return FwdConfig(128, tile_n, True, True) else: # hdim 256 tile_n = 64 if is_local else 80 return FwdConfig(128, tile_n, True, True) @dataclass(frozen=True) class BwdConfig: m_block_size: int n_block_size: int num_stages_Q: int num_stages_dO: int num_stages_PdS: int SdP_swapAB: bool dKV_swapAB: bool dQ_swapAB: bool AtomLayoutMSdP: int AtomLayoutNdKV: int AtomLayoutMdQ: int num_wg: int = 2 # MMA warp groups (total threads = (num_wg + 1) * 128) dQ_single_wg: bool = False def _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=None): """Return BwdConfig for SM90. Configs based on C++ FA3 hopper/flash_bwd_launch_template.h, benchmarked on H100 SXM. """ if head_dim <= 64: # C++ FA3: 128, 128, 64, ..., 2, 2, true, false, false, 2, 1, 2, 2 return BwdConfig( m_block_size=128, n_block_size=128, num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2, SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False, AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=2, ) elif head_dim <= 96: # C++ FA3: 64, 128, 96, dQ_swapAB=False return BwdConfig( m_block_size=64, n_block_size=128, num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2, SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False, AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, dQ_single_wg=True, ) elif head_dim <= 128: # C++ FA3: causal/local: 64, 128; non-causal: 80, 128 with dQ_swapAB is_causal_or_local = causal or local m_block_size = 64 if is_causal_or_local else 80 if sparse_block_size_q is not None and sparse_block_size_q % m_block_size != 0: m_block_size = 64 return BwdConfig( m_block_size=m_block_size, n_block_size=128, num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2, SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=m_block_size % 64 != 0, AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, ) elif head_dim <= 192: hdimv128 = head_dim_v <= 128 if hdimv128: return BwdConfig( m_block_size=64, n_block_size=96, num_stages_Q=2, num_stages_dO=2, num_stages_PdS=1, SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False, AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, num_wg=2, ) else: return BwdConfig( m_block_size=64, n_block_size=96, num_stages_Q=2, num_stages_dO=1, num_stages_PdS=1, SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False, AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, num_wg=2, ) else: # hdim 256 return BwdConfig( m_block_size=64, n_block_size=64, num_stages_Q=1, num_stages_dO=1, num_stages_PdS=1, SdP_swapAB=False, dKV_swapAB=False, dQ_swapAB=False, AtomLayoutMSdP=1, AtomLayoutNdKV=1, AtomLayoutMdQ=1, ) def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device): assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}" assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}" assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}" if not is_fake_mode(): assert t.is_cuda, f"{name} must be on CUDA" torch2cute_dtype_map = { torch.float16: cutlass.Float16, torch.bfloat16: cutlass.BFloat16, torch.float32: cutlass.Float32, } def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. if num_n_blocks <= 4: return 1 # NOTE: We should revisit this heuristic after persistence is supported for split KV. # Sometimes, it's ideal to over-schedule splits for better efficiency. return min(num_SMs // total_mblocks, max_splits, num_n_blocks) def _resolve_causal_local_window(causal, window_size_left, window_size_right, mask_mod=None): """Resolve causal/local/window settings into canonical form. Returns (causal, local, window_size_left, window_size_right). """ if mask_mod is not None: return False, False, window_size_left, window_size_right if causal: window_size_right = 0 if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0: window_size_left = None window_size_right = None if window_size_left is not None or window_size_right is not None: if window_size_left is None and window_size_right == 0: causal, local = True, False window_size_right = None else: causal, local = False, True else: local = False return causal, local, window_size_left, window_size_right def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, softcap: Optional[float] = None, window_size_left: Optional[int] = None, window_size_right: Optional[int] = None, learnable_sink: Optional[torch.Tensor] = None, tile_mn: Optional[Tuple[int, int]] = None, mma_pv_is_rs: Optional[bool] = None, intra_wg_overlap: Optional[bool] = None, num_threads: int = 384, num_splits: int = 1, pack_gqa: Optional[bool] = None, _arch: Optional[int] = None, score_mod: Optional[Callable] = None, mask_mod: Optional[Callable] = None, block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, aux_tensors: Optional[list[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. Args: ... score_mod: A callable that takes the attention scores and applies a modification. mask_mod: A callable that takes token position information and selectively masks block_sparse_tensors: A tuple of tensors used for block sparsity. return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate The returned LSE supports taking gradient. out: Optional pre-allocated output tensor. If None, will be allocated internally. lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. """ q, k, v = [maybe_contiguous(t) for t in (q, k, v)] num_head, head_dim = q.shape[-2:] if cu_seqlens_q is None: batch_size, seqlen_q = q.shape[:2] total_q = batch_size * seqlen_q else: batch_size = cu_seqlens_q.shape[0] - 1 seqlen_q = None total_q = q.shape[0] if page_table is not None: assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k" assert page_table.dtype == torch.int32, "page_table must be int32" assert page_table.stride(-1) == 1, "page_table must be contiguous in the last dimension" max_num_pages_per_seq = page_table.shape[1] assert page_table.shape == (batch_size, max_num_pages_per_seq) num_pages, page_size = k.shape[:2] seqlen_k = num_pages * page_size else: num_pages, page_size = None, None seqlen_k = k.shape[-3] num_head_kv = k.shape[-2] head_dim_v = v.shape[-1] if cu_seqlens_k is None: if page_table is None: assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) else: assert k.shape == (num_pages, page_size, num_head_kv, head_dim) assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v) else: assert k.shape == (seqlen_k, num_head_kv, head_dim) assert v.shape == (seqlen_k, num_head_kv, head_dim_v) assert cu_seqlens_k.shape == (batch_size + 1,), ( "cu_seqlens_k must have shape (batch_size + 1,)" ) if cu_seqlens_q is not None: assert cu_seqlens_q.shape == (batch_size + 1,), ( "cu_seqlens_q must have shape (batch_size + 1,)" ) assert seqused_q is None or seqused_q.shape == (batch_size,), ( "seqused_q must have shape (batch_size,)" ) assert seqused_k is None or seqused_k.shape == (batch_size,), ( "seqused_k must have shape (batch_size,)" ) assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: if t is not None: assert t.dtype == torch.int32, ( "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" ) assert t.stride(0) == 1, ( "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" ) if learnable_sink is not None: assert learnable_sink.shape == (num_head,) assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" if not is_fake_mode(): assert all( t is None or t.is_cuda for t in ( q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table, learnable_sink, ) ), "inputs must be on CUDA device" arch = _get_device_arch() if _arch is None else _arch assert arch // 10 in [8, 9, 10, 11, 12], "Unsupported compute capability. Supported: 8.x, 9.x, 10.x, 11.x, 12.x" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" alignment = 16 // q.element_size() if arch // 10 not in [8, 12]: _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment) if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(head_dim) if softcap == 0.0: softcap = None qhead_per_kvhead = num_head // num_head_kv if pack_gqa is None: pack_gqa = qhead_per_kvhead > 1 out_torch_dtype = q.dtype device = q.device q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) requires_grad = q.requires_grad or k.requires_grad or v.requires_grad if out is None: out = torch.empty( *q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device ) else: _validate_tensor(out, "out", (*q_batch_seqlen_shape, num_head, head_dim_v), out_torch_dtype, device) if lse is None: lse = ( torch.empty(lse_shape, dtype=torch.float32, device=device) if requires_grad or return_lse else None ) elif lse is not None: _validate_tensor(lse, "lse", lse_shape, torch.float32, device) dtype = torch2cute_dtype_map[q.dtype] use_block_sparsity = block_sparse_tensors is not None causal, local, window_size_left, window_size_right = _resolve_causal_local_window( causal, window_size_left, window_size_right, mask_mod ) requested_use_clc_scheduler = utils._get_use_clc_scheduler_default() requested_disable_2cta = utils._get_disable_2cta_default() current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) # SM80/SM120: uses SM80 MMA, 128 threads (4 warps) if arch // 10 in [8, 12]: num_threads = 128 fwd_cfg = FwdConfig(128, 128, True, True) # default if tile_mn is None: if arch // 10 == 12: # SM120 tile sizes tuned for 99 KB SMEM capacity: # D<=64: 128x128 → 48 KB (good occupancy) # D>64: 128x64 → 64 KB (128x128 would use 96 KB, hurting occupancy) if head_dim <= 64: fwd_cfg = FwdConfig(128, 128, True, True) else: fwd_cfg = FwdConfig(128, 64, True, True) elif arch // 10 == 8: fwd_cfg = FwdConfig(128, 64, True, True) # SM80, should tune elif arch // 10 == 9: sparse_q = get_sparse_q_block_size(block_sparse_tensors, seqlen_q) fwd_cfg = _tile_size_fwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=sparse_q) else: fwd_cfg = FwdConfig(tile_mn[0], tile_mn[1], fwd_cfg.mma_pv_is_rs, fwd_cfg.intra_wg_overlap) tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size if mma_pv_is_rs is None: mma_pv_is_rs = fwd_cfg.mma_pv_is_rs if intra_wg_overlap is None: intra_wg_overlap = fwd_cfg.intra_wg_overlap # TODO: fix GQA + SplitKV + non-varlen if pack_gqa and num_splits != 1 and cu_seqlens_q is None: pack_gqa = False if max_seqlen_q is None: max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q if max_seqlen_k is None: max_seqlen_k = seqlen_k seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead if arch // 10 == 10: q_stage = 2 if seqlen_q_packgqa > tile_m else 1 else: q_stage = 1 m_block_size_effective = q_stage * tile_m seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, (window_size_right or max_seqlen_k) + (window_size_left or max_seqlen_k) + 1 + tile_m)) num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective total_mblocks = batch_size * num_head_kv * num_m_blocks num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n num_SMs = 132 if is_fake_mode() else torch.cuda.get_device_properties(device).multi_processor_count if num_splits < 1: num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) # SplitKV uses float32 partial output, which doubles the O buffer size # in shared memory, causing OOM for diff-headdim (192, 128) if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1: if num_n_blocks >= 64: tile_n = 64 num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) else: num_splits = 1 is_split_kv = num_splits > 1 if is_split_kv: out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device) lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device) use_2cta_instrs = ( arch // 10 in [10, 11] and not requested_disable_2cta and not causal and not local and not is_split_kv and cu_seqlens_q is None and seqused_q is None and not use_block_sparsity and page_size in [None, 128] and int(math.ceil(head_dim / 16) * 16) in [128, 192] and int(math.ceil(head_dim_v / 16) * 16) == 128 and seqlen_q_packgqa > 2 * tile_m and (tile_m % qhead_per_kvhead == 0 or not pack_gqa) ) # hash score and mask mods for compile cache score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False if softcap is not None: assert score_mod is None, "softcap and score_mod cannot be used together" score_mod = utils.create_softcap_scoremod(softcap) is_varlen = ( cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None ) if mask_mod is not None: if is_varlen: raise NotImplementedError( "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." ) if use_block_sparsity: if is_varlen: raise NotImplementedError( "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR." ) # NB: pack_gqa requires block sparse head dim == 1 (broadcasted) if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1: pack_gqa = False if is_split_kv: raise NotImplementedError( "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split." ) # See get_broadcast_dims for why this is needed in compile key block_sparse_broadcast_pattern = None normalized_block_sparse_tensors = None q_subtile_factor = None if block_sparse_tensors is not None: if seqlen_q is None: raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).") ( normalized_block_sparse_tensors, block_sparse_broadcast_pattern, q_subtile_factor, ) = normalize_block_sparse_config( block_sparse_tensors, batch_size=batch_size, num_head=num_head, seqlen_q=seqlen_q, seqlen_k=seqlen_k, block_size=(tile_m, tile_n), q_stage=q_stage, ) if aux_tensors is not None: aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) else: aux_tensor_metadata = None compile_key = ( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, score_mod_hash, mask_mod_hash, use_block_sparsity, block_sparse_broadcast_pattern, aux_tensor_metadata, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, page_table is not None, window_size_left is not None, window_size_right is not None, learnable_sink is not None, tile_m, tile_n, q_stage, num_threads, is_split_kv, pack_gqa, arch, page_size not in [None, tile_n], # paged KV non-TMA use_2cta_instrs, q_subtile_factor, mma_pv_is_rs, intra_wg_overlap, requested_use_clc_scheduler, fa_logging.get_fa_log_level(), ) if compile_key not in _flash_attn_fwd.compile_cache: ( cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, learnable_sink_tensor, ) = [ to_cute_tensor(t, assumed_align=4, leading_dim=0) if t is not None else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) ] page_table_tensor = ( to_cute_tensor(page_table, assumed_align=4, leading_dim=1) if page_table is not None else None ) q_tensor, k_tensor, v_tensor, o_tensor = [ to_cute_tensor(t) for t in (q, k, v, out if not is_split_kv else out_partial) ] if is_split_kv: lse_tensor = to_cute_tensor(lse_partial, assumed_align=4) elif lse is not None: lse_tensor = to_cute_tensor(lse, assumed_align=4) else: lse_tensor = None sparse_tensors = None if normalized_block_sparse_tensors is not None: sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) cute_aux_tensors = None aux_tensor_metadata = None if aux_tensors is not None: cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] if arch // 10 == 8: assert page_table is None, "paged KV not supported on SM 8.0" assert not is_split_kv, "SplitKV not supported on SM 8.0" fa_fwd = FlashAttentionForwardSm80( dtype, head_dim, head_dim_v, qhead_per_kvhead, is_causal=causal, is_local=local, pack_gqa=pack_gqa, tile_m=tile_m, tile_n=tile_n, num_stages=1, num_threads=num_threads, Q_in_regs=False, score_mod=score_mod, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, ) elif arch // 10 == 9: assert not is_split_kv, "SplitKV not supported on SM 9.0" fa_fwd = FlashAttentionForwardSm90( dtype, head_dim, head_dim_v, qhead_per_kvhead, is_causal=causal, is_local=local, pack_gqa=pack_gqa, tile_m=tile_m, tile_n=tile_n, # num_stages=1, num_stages=2, num_threads=num_threads, Q_in_regs=False, intra_wg_overlap=intra_wg_overlap, mma_pv_is_rs=mma_pv_is_rs, mask_mod=mask_mod, score_mod=score_mod, has_aux_tensors=aux_tensors is not None, q_subtile_factor=q_subtile_factor, paged_kv_non_tma=page_size not in [None, tile_n], ) elif arch // 10 in [10, 11]: fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, qhead_per_kvhead=qhead_per_kvhead, is_causal=causal, is_local=local, is_split_kv=is_split_kv, pack_gqa=pack_gqa, m_block_size=tile_m, n_block_size=tile_n, q_stage=q_stage, is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None and not is_split_kv, score_mod=score_mod, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, paged_kv_non_tma=page_size not in [None, tile_n], is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, q_subtile_factor=q_subtile_factor, use_2cta_instrs=use_2cta_instrs, use_clc_scheduler=requested_use_clc_scheduler, ) elif arch // 10 == 12: # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity assert not use_block_sparsity, "Block sparsity not supported on SM 12.0" assert page_table is None, "Paged KV not supported on SM 12.0 in this PR" assert not is_split_kv, "SplitKV not supported on SM 12.0 in this PR" fa_fwd = FlashAttentionForwardSm120( dtype, head_dim, head_dim_v, qhead_per_kvhead, is_causal=causal, is_local=local, pack_gqa=pack_gqa, tile_m=tile_m, tile_n=tile_n, num_stages=1, num_threads=num_threads, Q_in_regs=False, score_mod=score_mod, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, ) else: raise ValueError( f"Unsupported compute capability: {arch}. Supported: 8.x, 9.x, 10.x, 11.x, 12.x" ) # TODO: check @can_implement _flash_attn_fwd.compile_cache[compile_key] = cute.compile( fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, page_table_tensor, window_size_left, window_size_right, learnable_sink_tensor, sparse_tensors, cute_aux_tensors, current_stream, options="--enable-tvm-ffi", ) # In "fake mode", we will take torch fake tensors as input and the expected behaviors are: # - Use those fake metadata to populate compilation cache # - Return "fake" output tensors, which could be needed in follow-up fake operations # Thus, we skip the actual kernel invocation here. if not is_fake_mode(): _flash_attn_fwd.compile_cache[compile_key]( q.detach(), k.detach(), v.detach(), out.detach() if not is_split_kv else out_partial, lse_partial if is_split_kv else lse, softmax_scale, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table, window_size_left, window_size_right, learnable_sink, normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, aux_tensors, ) if is_split_kv: _flash_attn_fwd_combine( out_partial, lse_partial.transpose(-1, -2), out, lse.transpose(-1, -2) if lse is not None else None, cu_seqlens_q, seqused_q, ) return out, lse _flash_attn_fwd.compile_cache = get_jit_cache("fwd") def make_fake_bwd_tensors(dtype, has_gqa, varlen_q, varlen_k): sym = cute.sym_int # divisibility in elements: assumed_align_bytes = divisibility * dtype.width // 8 # For 16-byte align: fp16/bf16 → divisibility=8, float32 → divisibility=4 div = 128 // dtype.width # 8 for fp16/bf16 # Shared sym_ints for dimensions that must match across tensors b, seqlen_q, seqlen_k, h_q, d, d_v = sym(), sym(), sym(), sym(), sym(), sym() h_kv = h_q if not has_gqa else sym() seqlen_q_rounded, seqlen_k_rounded = sym(), sym() seqlen_q_d_rounded, seqlen_k_d_rounded, seqlen_k_dv_rounded = sym(), sym(), sym() total_q, total_k, total_q_rounded, total_k_rounded = sym(), sym(), sym(), sym() total_q_d_rounded, total_k_d_rounded, total_k_dv_rounded = sym(), sym(), sym() b_seqlenq = (b, seqlen_q) if not varlen_q else (total_q,) b_seqlenk = (b, seqlen_k) if not varlen_k else (total_k,) mQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div) mO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div) mdO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div) mK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div) mV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div) mdQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div) mdK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div) mdV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div) if not varlen_q: mLSE = fake_tensor(Float32, (b, h_q, seqlen_q), divisibility=1) mLSElog2 = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4) mPdPsum = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4) dQaccum = fake_tensor(Float32, (b, h_q, seqlen_q_d_rounded), divisibility=4) else: mLSE = fake_tensor(Float32, (h_q, total_q), divisibility=1) mLSElog2 = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4) mPdPsum = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4) dQaccum = fake_tensor(Float32, (h_q, total_q_d_rounded), divisibility=4) if not has_gqa: mdKaccum, mdVaccum = None, None else: if not varlen_k: mdKaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_rounded), divisibility=4) mdVaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_dv_rounded), divisibility=4) else: mdKaccum = fake_tensor(Float32, (h_kv, total_k_rounded), divisibility=4) mdVaccum = fake_tensor(Float32, (h_kv, total_k_dv_rounded), divisibility=4) return mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, dQaccum, mdKaccum, mdVaccum def _compile_bwd_preprocess( dtype, head_dim, head_dim_v, m_block_size, has_cuseqlens_q, has_seqused_q, has_dlse, ): """Compile bwd preprocess kernel using cute fake tensors (no real GPU tensors needed).""" mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors( dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False ) batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int() batchp1 = cute.sym_int() mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None mSequsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None mdLSE = fake_tensor(Float32, mLSE.shape, divisibility=1) if has_dlse else None fa_bwd_pre = FlashAttentionBackwardPreprocess(dtype, head_dim, head_dim_v, m_block_size) return cute.compile( fa_bwd_pre, mO, mdO, mPdPsum, mLSE, mLSElog2, mdQaccum, mCuSeqlensQ, mSequsedQ, mdLSE, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", ) def _bwd_preprocess( out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse, dtype, head_dim, head_dim_v, m_block_size, ): """Backward preprocess: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum.""" is_varlen = cu_seqlens_q is not None compile_key = ( dtype, head_dim, head_dim_v, m_block_size, is_varlen, seqused_q is not None, dlse is not None, ) if compile_key not in _bwd_preprocess.compile_cache: _bwd_preprocess.compile_cache[compile_key] = _compile_bwd_preprocess(*compile_key) if not is_fake_mode(): _bwd_preprocess.compile_cache[compile_key]( out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse ) _bwd_preprocess.compile_cache = get_jit_cache("bwd_pre") def _compile_bwd_postprocess( dtype, hdim, block_size, num_threads, atom_layout, swap_ab, has_cuseqlens_q, has_seqused_q, use_2cta_instrs, cluster_size, arch, ): """Compile bwd postprocess kernel using cute fake tensors.""" mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors( dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False ) batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int() batchp1 = cute.sym_int() mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None mSeqUsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None fa_bwd_post = FlashAttentionBackwardPostprocess( dtype, hdim, arch, block_size, num_threads, atom_layout, swap_ab, use_2cta_instrs=use_2cta_instrs, cluster_size=cluster_size, ) return cute.compile( fa_bwd_post, mdQaccum, mdQ, Float32(0.0), mCuSeqlensQ, mSeqUsedQ, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", ) def _bwd_postprocess_convert( accum, output, scale, cu_seqlens, seqused, arch, dtype, hdim, block_size, num_threads, atom_layout, swap_ab, use_2cta_instrs=False, cluster_size=1, ): """Backward postprocess: convert float32 accumulator to bf16/fp16 output.""" compile_key = ( dtype, hdim, block_size, num_threads, atom_layout, swap_ab, cu_seqlens is not None, seqused is not None, use_2cta_instrs, cluster_size, arch, ) if compile_key not in _bwd_postprocess_convert.compile_cache: _bwd_postprocess_convert.compile_cache[compile_key] = _compile_bwd_postprocess(*compile_key) if not is_fake_mode(): _bwd_postprocess_convert.compile_cache[compile_key]( accum, output, scale, cu_seqlens, seqused, ) _bwd_postprocess_convert.compile_cache = get_jit_cache("bwd_post") def _flash_attn_bwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor, dout: torch.Tensor, lse: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, softcap: float = 0.0, window_size_left: Optional[int] = None, window_size_right: Optional[int] = None, m_block_size: int = 64, n_block_size: int = 128, num_threads: int = 256, pack_gqa: bool = False, num_stages_Q: int = 2, num_stages_dO: int = 2, SdP_swapAB: bool = False, dKV_swapAB: bool = False, dQ_swapAB: bool = False, AtomLayoutMSdP: int = 2, AtomLayoutNdKV: int = 2, AtomLayoutMdQ: int = 2, V_in_regs: bool = False, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, deterministic: bool = False, dq: Optional[torch.Tensor] = None, dk: Optional[torch.Tensor] = None, dv: Optional[torch.Tensor] = None, score_mod: Optional[Callable] = None, score_mod_bwd: Optional[Callable] = None, mask_mod: Optional[Callable] = None, aux_tensors: Optional[list[torch.Tensor]] = None, block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, dlse: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: arch = _get_device_arch() assert arch // 10 in [9, 10, 11, 12], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x, 12.x" sparse_q = None if block_sparse_tensors is not None and arch // 10 == 9: sparse_q = block_sparse_tensors.block_size[0] if block_sparse_tensors.block_size is not None else 128 num_head, head_dim = q.shape[-2:] head_dim_v = v.shape[-1] causal, local, window_size_left, window_size_right = _resolve_causal_local_window( causal, window_size_left, window_size_right ) if arch // 10 == 12: # SM120: uses SM80 MMA with 99 KB SMEM, 128 threads (4 warps). m_block_size = 64 n_block_size = 64 if head_dim <= 64: num_stages_Q = 2 num_stages_dO = 2 else: num_stages_Q = 1 num_stages_dO = 1 SdP_swapAB = False dKV_swapAB = False dQ_swapAB = False AtomLayoutMSdP = 4 AtomLayoutNdKV = 4 AtomLayoutMdQ = 4 V_in_regs = False cluster_size = 1 use_2cta_instrs = False num_threads = 128 assert not (block_sparse_tensors is not None), "Block sparsity backward not supported on SM 12.0" assert score_mod is None and score_mod_bwd is None, "score_mod backward not supported on SM 12.0" assert mask_mod is None, "mask_mod backward not supported on SM 12.0" assert deterministic is False, "deterministic backward not supported on SM 12.0" elif arch // 10 == 9: cfg = _tile_size_bwd_sm90( head_dim, head_dim_v, causal, local, sparse_block_size_q=sparse_q, ) m_block_size = cfg.m_block_size n_block_size = cfg.n_block_size num_stages_Q = cfg.num_stages_Q num_stages_dO = cfg.num_stages_dO num_stages_PdS = cfg.num_stages_PdS SdP_swapAB = cfg.SdP_swapAB dKV_swapAB = cfg.dKV_swapAB dQ_swapAB = cfg.dQ_swapAB AtomLayoutMSdP = cfg.AtomLayoutMSdP AtomLayoutNdKV = cfg.AtomLayoutNdKV AtomLayoutMdQ = cfg.AtomLayoutMdQ num_threads = (cfg.num_wg + 1) * 128 dQ_single_wg = cfg.dQ_single_wg cluster_size = 1 use_2cta_instrs = False is_varlen = ( cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None ) else: m_block_size = 128 n_block_size = 128 dQ_swapAB = False dKV_swapAB = False AtomLayoutMdQ = 1 AtomLayoutNdKV = 1 requested_disable_2cta = utils._get_disable_2cta_default() disable_2cta = ( requested_disable_2cta or score_mod is not None or score_mod_bwd is not None or mask_mod is not None or block_sparse_tensors is not None ) cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1 use_2cta_instrs = cluster_size==2 q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ maybe_contiguous(t) for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ] if cu_seqlens_q is None: batch_size, seqlen_q = q.shape[:2] total_q = batch_size * seqlen_q else: batch_size = cu_seqlens_q.shape[0] - 1 total_q = q.shape[0] seqlen_q = max_seqlen_q if max_seqlen_q is not None else total_q if cu_seqlens_k is None: batch_size, seqlen_k = k.shape[:2] total_k = batch_size * seqlen_k else: batch_size = cu_seqlens_k.shape[0] - 1 total_k = k.shape[0] seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k num_head_kv = k.shape[-2] use_block_sparsity = block_sparse_tensors is not None subtile_factor = sparse_q // m_block_size if sparse_q is not None else 2 seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size num_n_blocks = seqlen_k_rounded // n_block_size if cluster_size == 2 and num_n_blocks % cluster_size != 0: seqlen_k_rounded = seqlen_k_rounded + n_block_size if cu_seqlens_k is None: assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) else: assert k.shape == (total_k, num_head_kv, head_dim) assert v.shape == (total_k, num_head_kv, head_dim_v) assert cu_seqlens_k.shape == (batch_size + 1,), ( "cu_seqlens_k must have shape (batch_size + 1,)" ) if cu_seqlens_q is not None: assert cu_seqlens_q.shape == (batch_size + 1,), ( "cu_seqlens_q must have shape (batch_size + 1,)" ) assert out.shape == (total_q, num_head, head_dim_v) assert dout.shape == (total_q, num_head, head_dim_v) assert lse.shape == (num_head, total_q), "lse must have shape (num_head, total_q)" else: assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v) assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v) assert lse.shape == (batch_size, num_head, seqlen_q), ( "lse must have shape (batch_size, num_head, seqlen_q)" ) assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, ( "inputs must have the same dtype" ) for t in [cu_seqlens_q, cu_seqlens_k]: if t is not None: assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32" assert lse.dtype == torch.float32, "lse must be float32" if dlse is not None: dlse = maybe_contiguous(dlse) if not is_fake_mode(): assert all( t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k) ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" alignment = 16 // q.element_size() if arch // 10 != 12: _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment) if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(head_dim) qhead_per_kvhead = num_head // num_head_kv if pack_gqa is None: pack_gqa = qhead_per_kvhead > 1 # pack_gqa backward not yet supported in bwd pack_gqa = False if score_mod is not None: assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided" assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)" assert cu_seqlens_q is None and cu_seqlens_k is None, ( "varlen + score_mod not supported in bwd yet" ) device = q.device out_torch_dtype = q.dtype if dq is None: dq = torch.empty_like(q) else: _validate_tensor(dq, "dq", q.shape, out_torch_dtype, device) if dk is None: dk = torch.empty_like(k) else: _validate_tensor(dk, "dk", k.shape, out_torch_dtype, device) if dv is None: dv = torch.empty_like(v) else: _validate_tensor(dv, "dv", v.shape, out_torch_dtype, device) head_dim_rounded = (head_dim + 32 - 1) // 32 * 32 if cu_seqlens_q is None: dq_accum = torch.empty( batch_size, num_head, seqlen_q_rounded * head_dim_rounded, dtype=torch.float32, device=device, ) dpsum = torch.empty( batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device ) lse_log2 = torch.empty( batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device ) else: total_q_rounded_padded = ( (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size ) dq_accum = torch.empty( num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device ) dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) # GQA (qhead_per_kvhead > 1) needs dK/dV accum+postprocess since multiple Q heads # accumulate into the same dK/dV. SM90 varlen_k with qhead_per_kvhead==1 now uses # ragged TMA tensors for direct store, so no longer needs accum+postprocess. dKV_postprocess = qhead_per_kvhead > 1 if dKV_postprocess: head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 if cu_seqlens_k is None: dk_accum = torch.zeros( batch_size, num_head_kv, seqlen_k_rounded * head_dim_rounded, dtype=torch.float32, device=device, ) dv_accum = torch.zeros( batch_size, num_head_kv, seqlen_k_rounded * head_dim_v_rounded, dtype=torch.float32, device=device, ) else: cluster_tile_n = cluster_size * n_block_size total_k_rounded_padded = ( (total_k + cu_seqlens_k.shape[0] * cluster_tile_n - 1) // cluster_tile_n * cluster_tile_n ) dk_accum = torch.zeros( num_head_kv, total_k_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device, ) dv_accum = torch.zeros( num_head_kv, total_k_rounded_padded * head_dim_v_rounded, dtype=torch.float32, device=device, ) dtype = torch2cute_dtype_map[q.dtype] current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) if deterministic: dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, cluster_size, dtype=torch.int32, device=device) else: dQ_semaphore = None if deterministic and qhead_per_kvhead > 1: dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device) dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device) else: dK_semaphore = None dV_semaphore = None # Preprocess kernel: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum. _bwd_preprocess( out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse, dtype, head_dim, head_dim_v, m_block_size, ) # num_threads: SM90 derives from BwdConfig.num_wg, SM120 is set to 128 above, # SM100/SM110 uses default from function signature (384). if arch // 10 not in [9, 12]: num_threads = 384 # Backward kernel: compute dk, dv, dq_accum. score_mod_hash = utils.hash_callable(score_mod) if score_mod else False score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod else False num_aux_tensors = len(aux_tensors) if aux_tensors else 0 cute_aux_tensors = None if aux_tensors is not None: cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors] block_sparse_broadcast_pattern = None normalized_block_sparse_tensors = None if block_sparse_tensors is not None: ( normalized_block_sparse_tensors, block_sparse_broadcast_pattern, ) = normalize_block_sparse_config_bwd( block_sparse_tensors, batch_size=batch_size, num_head=num_head, seqlen_q=seqlen_q, seqlen_k=seqlen_k, block_size=(m_block_size, n_block_size), subtile_factor=subtile_factor, ) if arch // 10 in [8, 9, 12]: compile_key = ( arch, dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, window_size_left is not None, window_size_right is not None, softcap != 0.0, m_block_size, n_block_size, num_threads, pack_gqa, num_stages_Q, num_stages_dO, SdP_swapAB, dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs, dQ_single_wg, deterministic, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, score_mod_hash, score_mod_bwd_hash, mask_mod_hash, num_aux_tensors, use_block_sparsity, block_sparse_broadcast_pattern, get_broadcast_dims(q), get_broadcast_dims(k), get_broadcast_dims(v), get_broadcast_dims(dout), ) else: compile_key = ( arch, dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, window_size_left is not None, window_size_right is not None, softcap != 0.0, m_block_size, n_block_size, num_threads, pack_gqa, cluster_size, use_2cta_instrs, deterministic, score_mod_hash, score_mod_bwd_hash, mask_mod_hash, num_aux_tensors, use_block_sparsity, block_sparse_broadcast_pattern, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, get_broadcast_dims(q), get_broadcast_dims(k), get_broadcast_dims(v), get_broadcast_dims(dout), ) if compile_key not in _flash_attn_bwd.compile_cache: q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv) ] dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2) ] if dKV_postprocess: dk_accum_tensor, dv_accum_tensor = [ to_cute_tensor(t) for t in (dk_accum, dv_accum) ] cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ to_cute_tensor(t, assumed_align=4) if t is not None else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ] dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [ utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order()) if t is not None else None for t in (dQ_semaphore, dK_semaphore, dV_semaphore) ] if arch // 10 in [8, 12]: flash_bwd_obj_cls = FlashAttentionBackwardSm120 if arch // 10 == 12 else FlashAttentionBackwardSm80 fa_bwd_obj = flash_bwd_obj_cls( dtype, head_dim, head_dim_v, qhead_per_kvhead, m_block_size, n_block_size, num_stages_Q, num_stages_dO, num_threads, pack_gqa, causal, SdP_swapAB, dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs=V_in_regs, ) elif arch // 10 == 9: fa_bwd_obj = FlashAttentionBackwardSm90( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, is_local=local, deterministic=deterministic, tile_m=m_block_size, tile_n=n_block_size, Q_stage=num_stages_Q, dO_stage=num_stages_dO, PdS_stage=num_stages_PdS, SdP_swapAB=SdP_swapAB, dKV_swapAB=dKV_swapAB, dQ_swapAB=dQ_swapAB, AtomLayoutMSdP=AtomLayoutMSdP, AtomLayoutNdKV=AtomLayoutNdKV, AtomLayoutMdQ=AtomLayoutMdQ, num_threads=num_threads, V_in_regs=V_in_regs, score_mod=score_mod, score_mod_bwd=score_mod_bwd, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, subtile_factor=subtile_factor, dQ_single_wg=dQ_single_wg, ) else: fa_bwd_obj = FlashAttentionBackwardSm100( head_dim, head_dim_v, is_causal=causal, is_local=local, qhead_per_kvhead=qhead_per_kvhead, tile_m=m_block_size, tile_n=n_block_size, cluster_size=cluster_size, use_2cta_instrs=use_2cta_instrs, deterministic=deterministic, score_mod=score_mod, score_mod_bwd=score_mod_bwd, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, subtile_factor=subtile_factor, ) # Block sparse tensors for backward use Q-direction indexing (transposed from forward). sparse_tensors_compile = None if normalized_block_sparse_tensors is not None: sparse_tensors_compile = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) # TODO: check @can_implement _flash_attn_bwd.compile_cache[compile_key] = cute.compile( fa_bwd_obj, q_tensor, k_tensor, v_tensor, do_tensor, lse_log2_tensor, dpsum_tensor, dq_accum_tensor, dk_tensor if not dKV_postprocess else dk_accum_tensor, dv_tensor if not dKV_postprocess else dv_accum_tensor, softmax_scale, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, None, # softcap - not yet supported in backward window_size_left, window_size_right, dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor, cute_aux_tensors, sparse_tensors_compile, current_stream, options="--enable-tvm-ffi", ) if not is_fake_mode(): _flash_attn_bwd.compile_cache[compile_key]( q.detach(), k.detach(), v.detach(), dout, lse_log2, dpsum, dq_accum, dk if not dKV_postprocess else dk_accum, dv if not dKV_postprocess else dv_accum, softmax_scale, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, None, # softcap - not yet supported in backward window_size_left, window_size_right, dQ_semaphore, dK_semaphore, dV_semaphore, aux_tensors, normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, ) if arch // 10 == 9: # dQ postprocess: match main kernel's MMA WG count, unless dQ_single_wg num_threads_post_dQ = 128 if dQ_single_wg else cfg.num_wg * 128 num_threads_post_dKV = cfg.num_wg * 128 else: num_threads_post_dQ = 128 num_threads_post_dKV = 128 # Postprocess: convert dq_accum from float32 to dq in bf16/fp16 _bwd_postprocess_convert( dq_accum, dq, softmax_scale, cu_seqlens_q, seqused_q, arch, dtype, head_dim, m_block_size, num_threads_post_dQ, AtomLayoutMdQ, dQ_swapAB, use_2cta_instrs=use_2cta_instrs, cluster_size=1, ) if dKV_postprocess: # Postprocess: convert dk_accum from float32 to dk in bf16/fp16 _bwd_postprocess_convert( dk_accum, dk, softmax_scale, cu_seqlens_k, seqused_k, arch, dtype, head_dim, n_block_size, num_threads_post_dKV, AtomLayoutNdKV, dKV_swapAB, cluster_size=cluster_size, ) # Postprocess: convert dv_accum from float32 to dv in bf16/fp16 _bwd_postprocess_convert( dv_accum, dv, 1.0, cu_seqlens_k, seqused_k, arch, dtype, head_dim_v, n_block_size, num_threads_post_dKV, AtomLayoutNdKV, dKV_swapAB, cluster_size=cluster_size, ) return dq, dk, dv _flash_attn_bwd.compile_cache = get_jit_cache("bwd") class FlashAttnFunc(torch.autograd.Function): @staticmethod def forward( ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, num_splits: int = 1, pack_gqa: Optional[bool] = None, deterministic: bool = False, mask_mod: Optional[Callable] = None, full_block_cnt: Optional[torch.Tensor] = None, full_block_idx: Optional[torch.Tensor] = None, mask_block_cnt: Optional[torch.Tensor] = None, mask_block_idx: Optional[torch.Tensor] = None, block_size: Optional[Tuple[int, int]] = None, return_lse: bool = False, ): # Only create block sparse tensors if at least one block sparse parameter is provided block_sparse_tensors = None if any(t is not None for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]): block_sparse_tensors = BlockSparseTensorsTorch( full_block_cnt=full_block_cnt, full_block_idx=full_block_idx, mask_block_cnt=mask_block_cnt, mask_block_idx=mask_block_idx, block_size=block_size, ) out, lse = _flash_attn_fwd( q, k, v, softmax_scale=softmax_scale, causal=causal, window_size_left=window_size[0], window_size_right=window_size[1], learnable_sink=learnable_sink, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, mask_mod=mask_mod, block_sparse_tensors=block_sparse_tensors, return_lse=return_lse, ) ctx.save_for_backward(q, k, v, out, lse) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size ctx.softcap = softcap ctx.deterministic = deterministic ctx.return_lse = return_lse ctx.set_materialize_grads(False) return out, lse @staticmethod def backward(ctx, dout, dlse): q, k, v, out, lse = ctx.saved_tensors if not ctx.return_lse: dlse = None if dout is None: dout = torch.zeros_like(out) dq, dk, dv = _flash_attn_bwd( q, k, v, out, dout, lse, ctx.softmax_scale, ctx.causal, ctx.softcap, window_size_left=ctx.window_size[0], window_size_right=ctx.window_size[1], deterministic=ctx.deterministic, dlse=dlse, ) return dq, dk, dv, *((None,) * 20) # Extra Nones is fine class FlashAttnVarlenFunc(torch.autograd.Function): @staticmethod def forward( ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: Optional[torch.Tensor], cu_seqlens_k: Optional[torch.Tensor], seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, num_splits: int = 1, pack_gqa: Optional[bool] = None, deterministic: bool = False, score_mod: Optional[Callable] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, ): out, lse = _flash_attn_fwd( q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, page_table=page_table, softmax_scale=softmax_scale, causal=causal, window_size_left=window_size[0], window_size_right=window_size[1], learnable_sink=learnable_sink, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, score_mod=score_mod, aux_tensors=aux_tensors, return_lse=return_lse, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size ctx.softcap = softcap ctx.deterministic = deterministic ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k ctx.return_lse = return_lse ctx.set_materialize_grads(False) return out, lse @staticmethod def backward(ctx, dout, dlse): q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors assert ctx.softcap == 0.0 if not ctx.return_lse: dlse = None if dout is None: dout = torch.zeros_like(out) dq, dk, dv = _flash_attn_bwd( q, k, v, out, dout, lse, ctx.softmax_scale, ctx.causal, ctx.softcap, window_size_left=ctx.window_size[0], window_size_right=ctx.window_size[1], cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, seqused_q=seqused_q, seqused_k=seqused_k, max_seqlen_q=ctx.max_seqlen_q, max_seqlen_k=ctx.max_seqlen_k, deterministic=ctx.deterministic, dlse=dlse, ) return dq, dk, dv, *((None,) * 20) def flash_attn_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, num_splits: int = 1, pack_gqa: Optional[bool] = None, deterministic: bool = False, mask_mod: Optional[Callable] = None, full_block_cnt: Optional[torch.Tensor] = None, full_block_idx: Optional[torch.Tensor] = None, mask_block_cnt: Optional[torch.Tensor] = None, mask_block_idx: Optional[torch.Tensor] = None, block_size: Optional[Tuple[int, int]] = None, return_lse: bool = False, ): return FlashAttnFunc.apply( q, k, v, softmax_scale, causal, window_size, learnable_sink, softcap, num_splits, pack_gqa, deterministic, mask_mod, full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx, block_size, return_lse, ) def flash_attn_varlen_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, num_splits: int = 1, pack_gqa: Optional[bool] = None, deterministic: bool = False, score_mod: Optional[Callable] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, ): return FlashAttnVarlenFunc.apply( q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, page_table, softmax_scale, causal, window_size, learnable_sink, softcap, num_splits, pack_gqa, deterministic, score_mod, aux_tensors, return_lse, ) def _compile_fwd_combine( dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits, has_cu_seqlens, has_seqused, has_lse, has_varlen_batch_idx, ): """Compile fwd combine kernel using cute fake tensors (no real GPU tensors needed).""" sym = cute.sym_int div = 128 // dtype_partial.width # 16-byte alignment in elements fa_combine = FlashAttentionForwardCombine( dtype=dtype, dtype_partial=dtype_partial, head_dim=head_dim, tile_m=tile_m, k_block_size=k_block_size, log_max_splits=log_max_splits, ) if not fa_combine.can_implement( dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits, num_threads=256, ): raise RuntimeError( "FlashAttention combine kernel cannot be implemented with given parameters" ) if has_cu_seqlens: # Varlen: (num_splits, total_q, nheads, headdim) num_splits, total_q, nheads = sym(), sym(), sym() mO_partial = fake_tensor(dtype_partial, (num_splits, total_q, nheads, head_dim), divisibility=div) mLSE_partial = fake_tensor(Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=1) mO = fake_tensor(dtype, (total_q, nheads, head_dim), divisibility=div) mLSE = fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=0) if has_lse else None else: # Batched: (num_splits, batch, seqlen, nheads, headdim) num_splits, batch, seqlen, nheads = sym(), sym(), sym(), sym() mO_partial = fake_tensor(dtype_partial, (num_splits, batch, seqlen, nheads, head_dim), divisibility=div) mLSE_partial = fake_tensor(Float32, (num_splits, batch, seqlen, nheads), divisibility=1, leading_dim=2) mO = fake_tensor(dtype, (batch, seqlen, nheads, head_dim), divisibility=div) mLSE = fake_tensor(Float32, (batch, seqlen, nheads), divisibility=1, leading_dim=1) if has_lse else None batch = mO_partial.shape[1] batch_for_1d = batch if not has_cu_seqlens else sym() batchp1 = sym() mCuSeqlens = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_seqlens else None mSeqused = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_seqused else None mNumSplitsDynamic = None # Not parametrized in compile_key mVarlenBatchIdx = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_varlen_batch_idx else None mSemaphore = None # Not parametrized in compile_key return cute.compile( fa_combine, mO_partial, mLSE_partial, mO, mLSE, mCuSeqlens, mSeqused, mNumSplitsDynamic, mVarlenBatchIdx, mSemaphore, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", ) def _flash_attn_fwd_combine( out_partial: torch.Tensor, lse_partial: torch.Tensor, out: torch.Tensor, lse: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, seqused: Optional[torch.Tensor] = None, num_splits_dynamic_ptr: Optional[torch.Tensor] = None, varlen_batch_idx: Optional[torch.Tensor] = None, semaphore_to_reset: Optional[torch.Tensor] = None, ) -> None: """Forward combine kernel for split attention computation. Combines partial outputs and log-sum-exp values from multiple splits of attention computation into final outputs. Args: out_partial: Partial outputs tensor (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim) if there's cu_seqlens lse_partial: Partial LSE tensor (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads) if there's cu_seqlens out: Output tensor (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim) if there's cu_seqlens lse: Output LSE tensor (batch, seqlen, nheads) or (total_q, nheads) if there's cu_seqlens. cu_seqlens: Cumulative sequence lengths for variable length sequences seqused: Used sequence lengths for each batch num_splits_dynamic_ptr: Dynamic number of splits per batch semaphore_to_reset: Semaphore for synchronization k_block_size: Block size for head dimension Returns: None """ assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], ( "out_partial must be fp16, bf16, or fp32" ) if not is_fake_mode(): assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device" # Determine if this is variable length based on dimensions is_varlen = out_partial.dim() == 4 # Validate optional tensors for t, name in [ (cu_seqlens, "cu_seqlens"), (seqused, "seqused"), (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"), ]: if t is not None: if not is_fake_mode(): assert t.is_cuda, f"{name} must be on CUDA device" assert t.is_contiguous(), f"{name} must be contiguous" head_dim = out_partial.shape[-1] num_splits = out_partial.shape[0] assert num_splits <= 256 # If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively # so that kBlockM is smaller and we have more parallelism. k_block_size = 64 if head_dim <= 64 else 128 # We want kBlockM to be as small as possible to maximize parallelism. # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). tile_m = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32) log_max_splits = max(math.ceil(math.log2(num_splits)), 4) if tile_m == 8: # If kBlockM == 8 then the minimum number of splits is 32. # TODO: we can deal w this by using 128 threads instead log_max_splits = max(log_max_splits, 5) # Create combine kernel configuration dtype = torch2cute_dtype_map[out.dtype] dtype_partial = torch2cute_dtype_map[out_partial.dtype] compile_key = ( dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits, cu_seqlens is not None, seqused is not None, lse is not None, varlen_batch_idx is not None, ) if compile_key not in _flash_attn_fwd_combine.compile_cache: _flash_attn_fwd_combine.compile_cache[compile_key] = _compile_fwd_combine( *compile_key ) if not is_fake_mode(): _flash_attn_fwd_combine.compile_cache[compile_key]( out_partial, lse_partial, out, lse, cu_seqlens, seqused, num_splits_dynamic_ptr, varlen_batch_idx, semaphore_to_reset, ) _flash_attn_fwd_combine.compile_cache = get_jit_cache("fwd_combine") def flash_attn_combine( out_partial: torch.Tensor, lse_partial: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = None, cu_seqlens: Optional[torch.Tensor] = None, seqused: Optional[torch.Tensor] = None, varlen_batch_idx: Optional[torch.Tensor] = None, return_lse: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Flash Attention combine function for split attention computation. Combines partial outputs and log-sum-exp values from multiple splits of attention computation into final outputs. This is the main user-facing interface for the combine kernel. Args: out_partial: Partial outputs tensor with shape: - (num_splits, batch_size, seqlen, num_heads, head_size) for regular batched input - (num_splits, total_q, num_heads, head_size) for variable length input lse_partial: Partial LSE tensor with shape: - (num_splits, batch_size, seqlen, num_heads) for regular batched input - (num_splits, total_q, num_heads) for variable length input out: Optional output tensor. If None, will be created automatically. out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input. cu_seqlens: Cumulative sequence lengths for variable length sequences seqused: Used sequence lengths for each batch varlen_batch_idx: Optional mapping from virtual batch index to real batch index (int32 tensor of shape (batch_size,)). Used by persistent tile schedulers that reorder batch processing for load balancing. return_lse: Whether to return the combined LSE tensor. Default is True. Returns: Tuple of (out, lse) where: - out: Combined output tensor with shape (batch_size, seqlen, num_heads, head_size) or (total_q, num_heads, head_size) for varlen - lse: Combined log-sum-exp tensor with shape (batch_size, seqlen, num_heads) or (total_q, num_heads) for varlen. None if return_lse=False Note: This function expects the input tensors to be in the format produced by split attention computation, where the first dimension is num_splits. The permuting from user format to kernel format is now done inside the kernel. """ # Input validation assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" # Determine if this is variable length based on dimensions is_varlen = out_partial.dim() == 4 if is_varlen: # Variable length: (num_splits, total_q, num_heads, head_size) num_splits, total_q, num_heads, head_size = out_partial.shape batch_size = 1 # Treat as single batch for varlen seqlen = total_q else: # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size) num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape # Determine output dtype if out_dtype is None: out_dtype = out_partial.dtype # Create output if not provided device = out_partial.device if out is None: if is_varlen: out = torch.empty(total_q, num_heads, head_size, dtype=out_dtype, device=device) else: out = torch.empty( batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device ) # Create lse output only if requested if return_lse: if is_varlen: lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device) else: lse = torch.empty(batch_size, num_heads, seqlen, dtype=torch.float32, device=device) lse = lse.transpose(-1, -2) else: lse = None _flash_attn_fwd_combine( out_partial, lse_partial, out, lse, cu_seqlens, seqused, varlen_batch_idx=varlen_batch_idx, ) return out, lse