Instructions to use kernels-community/flash-attn4 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/flash-attn4 with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("kernels-community/flash-attn4")
- Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. | |
| # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need to install nvidia-cutlass-dsl==4.2.0. | |
| # Supported features: | |
| # - BF16 & FP16 dtype | |
| # - noncausal & causal attention | |
| # - MHA, GQA, MQA | |
| # - hdim 64, 96, 128. | |
| # - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape) | |
| # - varlen | |
| # - sliding window | |
| # - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow) | |
| # Features not supported yet: | |
| # - split (i.e. FlashDecoding) | |
| # - tuned block sizes | |
| # - paged KV | |
| # - append KV to existing KV cache | |
| # - FP8 | |
| # - bwd pass optimized for Hopper/Blackwell | |
| import os | |
| import math | |
| from functools import lru_cache | |
| from typing import Optional, Tuple, Callable | |
| import torch | |
| import cuda.bindings.driver as cuda | |
| import cutlass | |
| import cutlass.cute as cute | |
| from .cache_utils import get_jit_cache | |
| from .testing import is_fake_mode | |
| if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: | |
| from . import cute_dsl_ptxas # noqa: F401 | |
| # Patch to dump ptx and then use system ptxas to compile to cubin | |
| cute_dsl_ptxas.patch() | |
| from . import utils | |
| from .cute_dsl_utils import ( | |
| to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims, | |
| ) | |
| from .flash_fwd import FlashAttentionForwardSm90 | |
| from .flash_fwd_sm100 import FlashAttentionForwardSm100 | |
| from .flash_bwd_preprocess import FlashAttentionBackwardPreprocess | |
| from .flash_bwd import FlashAttentionBackwardSm80 | |
| from .flash_bwd_sm90 import FlashAttentionBackwardSm90 | |
| from .flash_bwd_sm100 import FlashAttentionBackwardSm100 | |
| from .flash_bwd_postprocess import FlashAttentionBackwardPostprocess | |
| from .flash_fwd_combine import FlashAttentionForwardCombine | |
| from .block_sparsity import ( | |
| BlockSparseTensorsTorch, | |
| to_cute_block_sparse_tensors, | |
| normalize_block_sparse_config, | |
| normalize_block_sparse_config_bwd, | |
| ) | |
| def _get_device_arch(): | |
| """Cached device arch check.""" | |
| major, minor = torch.cuda.get_device_capability() | |
| return major * 10 + minor | |
def maybe_contiguous(x):
    """Return ``x`` with a unit stride in its last dimension.

    ``None`` and tensors whose last-dimension stride is already 1 are returned
    unchanged; anything else is replaced by ``x.contiguous()``.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
| def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device): | |
| assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}" | |
| assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}" | |
| assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}" | |
| assert t.is_cuda, f"{name} must be on CUDA" | |
# Lookup table from torch dtypes to the matching cutlass element types. The attention
# entry points index it with ``q.dtype`` to choose the dtype passed to the kernels.
torch2cute_dtype_map = {
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
    torch.float32: cutlass.Float32,
}
def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits):
    """Choose how many splits of the KV sequence to use.

    Args:
        total_mblocks: Number of query tiles to schedule (across batch and KV heads).
        num_SMs: Number of streaming multiprocessors on the device.
        num_n_blocks: Number of KV tiles each query tile has to visit.
        max_splits: Upper bound on the returned value.

    Returns:
        An int that is always at least 1. The result is 1 when there are too few KV
        tiles to be worth splitting, or when the query tiles already cover the SMs.
    """
    # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
    if num_n_blocks <= 4:
        return 1
    # Nothing to schedule (e.g. empty batch): avoid dividing by zero below.
    if total_mblocks <= 0:
        return 1
    # NOTE: We should revisit this heuristic after persistence is supported for split KV.
    # Sometimes, it's ideal to over-schedule splits for better efficiency.
    # Clamp to 1: when total_mblocks > num_SMs the integer division yields 0, which is
    # not a meaningful number of splits.
    return max(1, min(num_SMs // total_mblocks, max_splits, num_n_blocks))
def _flash_attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    softcap: Optional[float] = None,
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    learnable_sink: Optional[torch.Tensor] = None,
    m_block_size: int = 128,
    n_block_size: int = 128,
    num_threads: int = 384,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    _arch: Optional[int] = None,
    score_mod: Optional[Callable] = None,
    mask_mod: Optional[Callable] = None,
    block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
    return_lse: bool = False,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    aux_tensors: Optional[list[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward pass for FlashAttention.

    Validates the inputs, allocates ``out`` / ``lse`` when they are not supplied, chooses
    tile sizes and the number of KV splits, JIT-compiles the architecture-specific kernel
    on a compile-cache miss, and launches it (the launch is skipped in fake mode).

    Args:
        q: ``(batch, seqlen_q, num_head, head_dim)``, or ``(total_q, num_head, head_dim)``
            when ``cu_seqlens_q`` is given. float16 or bfloat16.
        k: ``(batch, seqlen_k, num_head_kv, head_dim)``; ``(total_k, num_head_kv, head_dim)``
            when ``cu_seqlens_k`` is given; ``(num_pages, page_size, num_head_kv, head_dim)``
            when ``page_table`` is given. Same dtype as ``q``.
        v: Same layout as ``k`` with last dimension ``head_dim_v``. Same dtype as ``q``.
        cu_seqlens_q, cu_seqlens_k: Optional int32 tensors of shape ``(batch + 1,)``.
        seqused_q, seqused_k: Optional int32 tensors of shape ``(batch,)``.
        max_seqlen_q, max_seqlen_k: Optional sequence-length bounds used by the tiling
            heuristics; derived from the tensor shapes when omitted.
        page_table: Optional int32 tensor ``(batch, max_num_pages_per_seq)``; cannot be
            combined with ``cu_seqlens_k``.
        softmax_scale: Scale applied to the scores; defaults to ``1 / sqrt(head_dim)``.
        causal: Apply causal masking. Ignored when ``mask_mod`` is given.
        softcap: Optional soft-capping value; ``0.0`` is treated as disabled. Implemented
            as a ``score_mod``, so it cannot be combined with a user ``score_mod``.
        window_size_left, window_size_right: Optional sliding-window bounds. Ignored for
            the causal/local decision when ``mask_mod`` is given.
        learnable_sink: Optional bfloat16 tensor of shape ``(num_head,)``.
        m_block_size, n_block_size, num_threads: Kernel tiling parameters.
        num_splits: Number of KV splits; a value below 1 selects it heuristically.
        pack_gqa: Whether to pack query heads that share a KV head; defaults to
            ``num_head // num_head_kv > 1`` and may be turned off for unsupported
            configurations.
        _arch: Override for the detected compute capability (``major * 10 + minor``).
        score_mod: A callable that takes the attention scores and applies a modification.
        mask_mod: A callable that takes token position information and selectively masks
            scores. Not supported together with varlen inputs.
        block_sparse_tensors: A tuple of tensors used for block sparsity. Not supported
            with varlen inputs or SplitKV.
        return_lse: Whether to return the log-sum-exp of the attention scores. The LSE is
            also computed whenever any of q/k/v requires grad.
            Note: the returned LSE currently does not support taking gradient.
        out: Optional pre-allocated output tensor. If None, will be allocated internally.
        lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed.
        aux_tensors: Some score_mods will want to read from global aux_tensors. This is how
            we thread them through to the inner kernel.

    Returns:
        ``(out, lse)``. ``lse`` is ``None`` when it was neither supplied nor needed.
    """
    q, k, v = [maybe_contiguous(t) for t in (q, k, v)]
    num_head, head_dim = q.shape[-2:]
    if cu_seqlens_q is None:
        batch_size, seqlen_q = q.shape[:2]
        total_q = batch_size * seqlen_q
    else:
        batch_size = cu_seqlens_q.shape[0] - 1
        seqlen_q = None
        total_q = q.shape[0]
    if page_table is not None:
        assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k"
        assert page_table.dtype == torch.int32, "page_table must be int32"
        assert page_table.stride(-1) == 1, "page_table must be contiguous in the last dimension"
        max_num_pages_per_seq = page_table.shape[1]
        assert page_table.shape == (batch_size, max_num_pages_per_seq)
        num_pages, page_size = k.shape[:2]
        seqlen_k = num_pages * page_size
    else:
        num_pages, page_size = None, None
        seqlen_k = k.shape[-3]
    num_head_kv = k.shape[-2]
    head_dim_v = v.shape[-1]
    if cu_seqlens_k is None:
        if page_table is None:
            assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
            assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
        else:
            assert k.shape == (num_pages, page_size, num_head_kv, head_dim)
            assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v)
    else:
        assert k.shape == (seqlen_k, num_head_kv, head_dim)
        assert v.shape == (seqlen_k, num_head_kv, head_dim_v)
        assert cu_seqlens_k.shape == (batch_size + 1,), (
            "cu_seqlens_k must have shape (batch_size + 1,)"
        )
    if cu_seqlens_q is not None:
        assert cu_seqlens_q.shape == (batch_size + 1,), (
            "cu_seqlens_q must have shape (batch_size + 1,)"
        )
    assert seqused_q is None or seqused_q.shape == (batch_size,), (
        "seqused_q must have shape (batch_size,)"
    )
    assert seqused_k is None or seqused_k.shape == (batch_size,), (
        "seqused_k must have shape (batch_size,)"
    )
    assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16"
    assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype"
    for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]:
        if t is not None:
            assert t.dtype == torch.int32, (
                "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32"
            )
            assert t.stride(0) == 1, (
                "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous"
            )
    if learnable_sink is not None:
        assert learnable_sink.shape == (num_head,)
        assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"
    assert all(
        t is None or t.is_cuda
        for t in (
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            page_table,
            learnable_sink,
        )
    ), "inputs must be on CUDA device"
    assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
    assert head_dim <= 256, "head_dim must be less than or equal to 256"
    # Head dims must be a multiple of 16 bytes worth of elements.
    alignment = 16 // q.element_size()
    assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
    assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)
    if softcap == 0.0:
        softcap = None
    qhead_per_kvhead = num_head // num_head_kv
    if pack_gqa is None:
        pack_gqa = qhead_per_kvhead > 1

    out_torch_dtype = q.dtype
    device = q.device
    q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,)
    lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q)
    requires_grad = q.requires_grad or k.requires_grad or v.requires_grad
    if out is None:
        out = torch.empty(
            *q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device
        )
    else:
        _validate_tensor(out, "out", (*q_batch_seqlen_shape, num_head, head_dim_v), out_torch_dtype, device)
    if lse is None:
        # Only allocate the LSE buffer when somebody will consume it.
        lse = (
            torch.empty(lse_shape, dtype=torch.float32, device=device)
            if requires_grad or return_lse
            else None
        )
    else:
        _validate_tensor(lse, "lse", lse_shape, torch.float32, device)

    dtype = torch2cute_dtype_map[q.dtype]
    arch = _get_device_arch() if _arch is None else _arch
    assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
    use_block_sparsity = block_sparse_tensors is not None

    # Normalize (causal, window_size_left, window_size_right) into exactly one of:
    # full attention, causal, or local (sliding window). A user mask_mod overrides both.
    if mask_mod is None:
        if causal:
            window_size_right = 0
        if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0:
            window_size_left = None
            window_size_right = None
        local = window_size_left is not None or window_size_right is not None
        if window_size_left is not None or window_size_right is not None:
            if window_size_left is None and window_size_right == 0:
                # Unbounded left + zero right is plain causal attention.
                causal, local = True, False
                window_size_right = None
            else:
                causal, local = False, True
    else:
        causal, local = False, False

    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    if arch // 10 == 9:  # TODO: tune block size according to hdim.
        if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity:
            n_block_size = 192

    if arch // 10 in [10, 11]:
        if (
            pack_gqa
            and (128 % qhead_per_kvhead != 0)
        ):
            pack_gqa = False
        # TODO: fix GQA + SplitKV + non-varlen
        if pack_gqa and num_splits != 1 and cu_seqlens_q is None:
            pack_gqa = False

    if max_seqlen_q is None:
        max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q
    if max_seqlen_k is None:
        max_seqlen_k = seqlen_k
    seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead
    if arch // 10 == 10:
        q_stage = 2 if seqlen_q_packgqa > m_block_size else 1
    else:
        q_stage = 1

    if num_splits < 1:
        # Heuristic mode: estimate the tile counts and let num_splits_heuristic decide.
        m_block_size_effective = q_stage * m_block_size
        if local and window_size_left is not None and window_size_right is not None:
            seqlen_k_loaded = max(
                0, min(max_seqlen_k, window_size_right + window_size_left + 1 + m_block_size)
            )
        else:
            # Non-local attention, or a window that is unbounded on one side: assume the
            # whole KV sequence may be visited.
            seqlen_k_loaded = max_seqlen_k
        num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size
        num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective
        total_mblocks = batch_size * num_head_kv * num_m_blocks
        num_splits = num_splits_heuristic(
            total_mblocks,
            torch.cuda.get_device_properties(device).multi_processor_count,
            num_n_blocks,
            128,
        )

    is_split_kv = num_splits > 1
    if is_split_kv:
        # Per-split partial results in fp32; reduced into out/lse by the combine kernel.
        out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device)
        lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device)

    # hash score and mask mods for compile cache
    score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
    mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False

    if softcap is not None:
        assert score_mod is None, "softcap and score_mod cannot be used together"
        score_mod = utils.create_softcap_scoremod(softcap)

    is_varlen = (
        cu_seqlens_q is not None
        or cu_seqlens_k is not None
        or seqused_q is not None
        or seqused_k is not None
    )

    if mask_mod is not None:
        if is_varlen:
            raise NotImplementedError(
                "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR."
            )

    if use_block_sparsity:
        if is_varlen:
            raise NotImplementedError(
                "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR."
            )
        # NB: pack_gqa requires block sparse head dim == 1 (broadcasted)
        if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1:
            pack_gqa = False
        if is_split_kv:
            raise NotImplementedError(
                "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split."
            )

    # See get_broadcast_dims for why this is needed in compile key
    block_sparse_broadcast_pattern = None
    normalized_block_sparse_tensors = None
    q_subtile_factor = None
    if block_sparse_tensors is not None:
        if seqlen_q is None:
            raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).")
        (
            normalized_block_sparse_tensors,
            block_sparse_broadcast_pattern,
            q_subtile_factor,
        ) = normalize_block_sparse_config(
            block_sparse_tensors,
            batch_size=batch_size,
            num_head=num_head,
            seqlen_q=seqlen_q,
            seqlen_k=seqlen_k,
            block_size=(m_block_size, n_block_size),
            q_stage=q_stage,
        )

    if aux_tensors is not None:
        aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors)
    else:
        aux_tensor_metadata = None

    compile_key = (
        dtype,
        head_dim,
        head_dim_v,
        qhead_per_kvhead,
        causal,
        score_mod_hash,
        mask_mod_hash,
        use_block_sparsity,
        block_sparse_broadcast_pattern,
        aux_tensor_metadata,
        lse is None,
        cu_seqlens_q is None,
        cu_seqlens_k is None,
        seqused_q is None,
        seqused_k is None,
        page_table is not None,
        window_size_left is not None,
        window_size_right is not None,
        learnable_sink is not None,
        m_block_size,
        n_block_size,
        q_stage,
        num_threads,
        is_split_kv,
        pack_gqa,
        arch,
        page_size not in [None, 128],  # paged KV non-TMA
        q_subtile_factor,
        # softcap is turned into a score_mod after score_mod_hash was computed, so it has
        # to be part of the key itself; otherwise kernels built for different (or no)
        # softcap values would share one cache entry.
        softcap,
    )
    if compile_key not in _flash_attn_fwd.compile_cache:
        (
            cu_seqlens_q_tensor,
            cu_seqlens_k_tensor,
            seqused_q_tensor,
            seqused_k_tensor,
            learnable_sink_tensor,
        ) = [
            to_cute_tensor(t, assumed_align=4, leading_dim=0)
            if t is not None
            else None
            for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
        ]
        page_table_tensor = (
            to_cute_tensor(page_table, assumed_align=4, leading_dim=1)
            if page_table is not None
            else None
        )
        q_tensor, k_tensor, v_tensor, o_tensor = [
            to_cute_tensor(t) for t in (q, k, v, out if not is_split_kv else out_partial)
        ]
        if is_split_kv:
            lse_tensor = to_cute_tensor(lse_partial, assumed_align=4)
        elif lse is not None:
            lse_tensor = to_cute_tensor(lse, assumed_align=4)
        else:
            lse_tensor = None

        sparse_tensors = None
        if normalized_block_sparse_tensors is not None:
            sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)

        cute_aux_tensors = None
        if aux_tensors is not None:
            cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors]

        if arch // 10 == 9:
            assert page_table is None, "paged KV not supported on SM 9.0"
            assert not is_split_kv, "SplitKV not supported on SM 9.0"
            fa_fwd = FlashAttentionForwardSm90(
                dtype,
                head_dim,
                head_dim_v,
                qhead_per_kvhead,
                is_causal=causal,
                is_local=local,
                pack_gqa=pack_gqa,
                tile_m=m_block_size,
                tile_n=n_block_size,
                num_stages=2,
                num_threads=num_threads,
                Q_in_regs=False,
                intra_wg_overlap=True,
                mma_pv_is_rs=True,
                mask_mod=mask_mod,
                score_mod=score_mod,
                has_aux_tensors=aux_tensors is not None,
                q_subtile_factor=q_subtile_factor,
            )
        elif arch // 10 in [10, 11]:
            head_dim_padded = int(math.ceil(head_dim / 16) * 16)
            # Pad the V head dim from head_dim_v (not head_dim) so the 2-CTA path is only
            # selected when both padded head dims really are 128.
            head_dim_v_padded = int(math.ceil(head_dim_v / 16) * 16)
            use_2cta_instrs = (
                not causal
                and not local
                and not is_split_kv
                and cu_seqlens_q is None
                and seqused_q is None
                and not use_block_sparsity
                and page_size in [None, 128]
                and head_dim_padded == 128
                and head_dim_v_padded == 128
            )
            fa_fwd = FlashAttentionForwardSm100(
                head_dim,
                head_dim_v,
                qhead_per_kvhead=qhead_per_kvhead,
                is_causal=causal,
                is_local=local,
                is_split_kv=is_split_kv,
                pack_gqa=pack_gqa,
                m_block_size=m_block_size,
                n_block_size=n_block_size,
                q_stage=q_stage,
                is_persistent=not causal
                and not local
                and cu_seqlens_q is None
                and seqused_q is None
                and not is_split_kv,
                score_mod=score_mod,
                mask_mod=mask_mod,
                has_aux_tensors=aux_tensors is not None,
                paged_kv_non_tma=page_size not in [None, 128],
                is_varlen_q=cu_seqlens_q is not None or seqused_q is not None,
                q_subtile_factor=q_subtile_factor,
                use_2cta_instrs=use_2cta_instrs,
            )
        else:
            raise ValueError(
                f"Unsupported compute capability: {arch}. Supported: 9.x, 10.x, 11.x"
            )
        # TODO: check @can_implement
        _flash_attn_fwd.compile_cache[compile_key] = cute.compile(
            fa_fwd,
            q_tensor,
            k_tensor,
            v_tensor,
            o_tensor,
            lse_tensor,
            softmax_scale,
            current_stream,
            cu_seqlens_q_tensor,
            cu_seqlens_k_tensor,
            seqused_q_tensor,
            seqused_k_tensor,
            page_table_tensor,
            window_size_left,
            window_size_right,
            learnable_sink_tensor,
            sparse_tensors,
            cute_aux_tensors,
            options="--enable-tvm-ffi",
        )

    # In "fake mode", we will take torch fake tensors as input and the expected behaviors are:
    # - Use those fake metadata to populate compilation cache
    # - Return "fake" output tensors, which could be needed in follow-up fake operations
    # Thus, we skip the actual kernel invocation here.
    if not is_fake_mode():
        _flash_attn_fwd.compile_cache[compile_key](
            q.detach(),
            k.detach(),
            v.detach(),
            out.detach() if not is_split_kv else out_partial,
            lse_partial if is_split_kv else lse,
            softmax_scale,
            current_stream,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            page_table,
            window_size_left,
            window_size_right,
            learnable_sink,
            normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None,
            aux_tensors,
        )
    if is_split_kv:
        # Reduce the per-split partial outputs/LSEs into the final out/lse tensors.
        _flash_attn_fwd_combine(
            out_partial,
            lse_partial.transpose(-1, -2),
            out,
            lse.transpose(-1, -2) if lse is not None else None,
            cu_seqlens_q,
            seqused_q,
        )
    return out, lse
# Cache of compiled forward kernels, stored as an attribute on the function itself and
# indexed by the compile_key tuple that _flash_attn_fwd builds for each call.
_flash_attn_fwd.compile_cache = get_jit_cache("fwd")
| def _flash_attn_bwd( | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| v: torch.Tensor, | |
| out: torch.Tensor, | |
| dout: torch.Tensor, | |
| lse: torch.Tensor, | |
| softmax_scale: Optional[float] = None, | |
| causal: bool = False, | |
| softcap: float = 0.0, | |
| window_size_left: Optional[int] = None, | |
| window_size_right: Optional[int] = None, | |
| m_block_size: int = 64, | |
| n_block_size: int = 128, | |
| num_threads: int = 256, | |
| pack_gqa: bool = False, | |
| num_stages_Q: int = 2, | |
| num_stages_dO: int = 2, | |
| SdP_swapAB: bool = False, | |
| dKV_swapAB: bool = False, | |
| dQ_swapAB: bool = False, | |
| AtomLayoutMSdP: int = 2, | |
| AtomLayoutNdKV: int = 2, | |
| AtomLayoutMdQ: int = 2, | |
| V_in_regs: bool = False, | |
| cu_seqlens_q: Optional[torch.Tensor] = None, | |
| cu_seqlens_k: Optional[torch.Tensor] = None, | |
| seqused_q: Optional[torch.Tensor] = None, | |
| seqused_k: Optional[torch.Tensor] = None, | |
| max_seqlen_q: Optional[int] = None, | |
| max_seqlen_k: Optional[int] = None, | |
| deterministic: bool = False, | |
| dq: Optional[torch.Tensor] = None, | |
| dk: Optional[torch.Tensor] = None, | |
| dv: Optional[torch.Tensor] = None, | |
| score_mod: Optional[Callable] = None, | |
| score_mod_bwd: Optional[Callable] = None, | |
| mask_mod: Optional[Callable] = None, | |
| aux_tensors: Optional[list[torch.Tensor]] = None, | |
| block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, | |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| arch = _get_device_arch() | |
| assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x" | |
| num_head, head_dim = q.shape[-2:] | |
| if causal: | |
| window_size_right = 0 | |
| if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0: | |
| window_size_left = None | |
| window_size_right = None | |
| local = window_size_left is not None or window_size_right is not None | |
| if local: | |
| if window_size_left is None and window_size_right == 0: | |
| causal, local = True, False | |
| window_size_right = None | |
| else: | |
| causal, local = False, True | |
| if arch // 10 == 9: | |
| m_block_size = 80 if not causal else 64 | |
| n_block_size = 128 | |
| num_stages_Q = 2 | |
| num_stages_dO = 2 | |
| num_stages_PdS = 2 | |
| SdP_swapAB = True | |
| dKV_swapAB = False | |
| dQ_swapAB = not causal | |
| AtomLayoutMSdP = 1 | |
| AtomLayoutNdKV = 2 | |
| AtomLayoutMdQ = 1 | |
| cluster_size = 1 | |
| use_2cta_instrs = False | |
| assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x" | |
| is_varlen = ( | |
| cu_seqlens_q is not None | |
| or cu_seqlens_k is not None | |
| or seqused_q is not None | |
| or seqused_k is not None | |
| ) | |
| assert not is_varlen, "varlen backward is not yet supported on sm90" | |
| else: | |
| m_block_size = 128 | |
| n_block_size = 128 | |
| dQ_swapAB = False | |
| dKV_swapAB = False | |
| AtomLayoutMdQ = 1 | |
| AtomLayoutNdKV = 1 | |
| disable_2cta = ( | |
| local | |
| or score_mod is not None | |
| or score_mod_bwd is not None | |
| or mask_mod is not None | |
| ) | |
| cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1 | |
| use_2cta_instrs = cluster_size==2 | |
| q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ | |
| maybe_contiguous(t) | |
| for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) | |
| ] | |
| if cu_seqlens_q is None: | |
| batch_size, seqlen_q = q.shape[:2] | |
| total_q = batch_size * seqlen_q | |
| else: | |
| batch_size = cu_seqlens_q.shape[0] - 1 | |
| total_q = q.shape[0] | |
| seqlen_q = max_seqlen_q if max_seqlen_q is not None else total_q | |
| if cu_seqlens_k is None: | |
| batch_size, seqlen_k = k.shape[:2] | |
| total_k = batch_size * seqlen_k | |
| else: | |
| batch_size = cu_seqlens_k.shape[0] - 1 | |
| total_k = k.shape[0] | |
| seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k | |
| num_head_kv = k.shape[-2] | |
| head_dim_v = v.shape[-1] | |
| use_block_sparsity = block_sparse_tensors is not None | |
| # SM90 block-sparse backward: tile_m=64 is the GCD between a m_block_size that fits, | |
| # the base block_m of 128 from forward, and block-sparse size for subtiling. | |
| if arch // 10 == 9 and use_block_sparsity: | |
| m_block_size = 64 | |
| # dQ_swapAB tuning: use False when m_block_size=64 (same as causal case) | |
| dQ_swapAB = False | |
| # NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2 | |
| subtile_factor = 2 | |
| seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size | |
| seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size | |
| num_n_blocks = seqlen_k_rounded // n_block_size | |
| if cluster_size == 2 and num_n_blocks % cluster_size != 0: | |
| seqlen_k_rounded = seqlen_k_rounded + n_block_size | |
| if cu_seqlens_k is None: | |
| assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) | |
| assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) | |
| else: | |
| assert k.shape == (total_k, num_head_kv, head_dim) | |
| assert v.shape == (total_k, num_head_kv, head_dim_v) | |
| assert cu_seqlens_k.shape == (batch_size + 1,), ( | |
| "cu_seqlens_k must have shape (batch_size + 1,)" | |
| ) | |
| if cu_seqlens_q is not None: | |
| assert cu_seqlens_q.shape == (batch_size + 1,), ( | |
| "cu_seqlens_q must have shape (batch_size + 1,)" | |
| ) | |
| assert out.shape == (total_q, num_head, head_dim_v) | |
| assert dout.shape == (total_q, num_head, head_dim_v) | |
| assert lse.shape == (num_head, total_q), "lse must have shape (num_head, total_q)" | |
| else: | |
| assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v) | |
| assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v) | |
| assert lse.shape == (batch_size, num_head, seqlen_q), ( | |
| "lse must have shape (batch_size, num_head, seqlen_q)" | |
| ) | |
| assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" | |
| assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, ( | |
| "inputs must have the same dtype" | |
| ) | |
| for t in [cu_seqlens_q, cu_seqlens_k]: | |
| if t is not None: | |
| assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32" | |
| assert lse.dtype == torch.float32, "lse must be float32" | |
| assert all( | |
| t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k) | |
| ), "inputs must be on CUDA device" | |
| assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" | |
| assert head_dim <= 256, "head_dim must be less than or equal to 256" | |
| alignment = 16 // q.element_size() | |
| assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" | |
| assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" | |
| if softmax_scale is None: | |
| softmax_scale = 1.0 / math.sqrt(head_dim) | |
| qhead_per_kvhead = num_head // num_head_kv | |
| if pack_gqa is None: | |
| pack_gqa = qhead_per_kvhead > 1 | |
| # pack_gqa backward not yet supported in bwd | |
| pack_gqa = False | |
| if arch // 10 not in [10, 11]: | |
| assert deterministic is False, "bwd deterministic only supported for sm100/sm110 for now" | |
| if score_mod is not None: | |
| assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided" | |
| assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)" | |
| assert cu_seqlens_q is None and cu_seqlens_k is None, ( | |
| "varlen + score_mod not supported in bwd yet" | |
| ) | |
| device = q.device | |
| out_torch_dtype = q.dtype | |
| if dq is None: | |
| dq = torch.empty_like(q) | |
| else: | |
| _validate_tensor(dq, "dq", q.shape, out_torch_dtype, device) | |
| if dk is None: | |
| dk = torch.empty_like(k) | |
| else: | |
| _validate_tensor(dk, "dk", k.shape, out_torch_dtype, device) | |
| if dv is None: | |
| dv = torch.empty_like(v) | |
| else: | |
| _validate_tensor(dv, "dv", v.shape, out_torch_dtype, device) | |
| head_dim_rounded = (head_dim + 32 - 1) // 32 * 32 | |
| if cu_seqlens_q is None: | |
| dq_accum = torch.empty( | |
| batch_size, | |
| num_head, | |
| seqlen_q_rounded * head_dim_rounded, | |
| dtype=torch.float32, | |
| device=device, | |
| ) | |
| dpsum = torch.empty( | |
| batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device | |
| ) | |
| lse_log2 = torch.empty( | |
| batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device | |
| ) | |
| else: | |
| total_q_rounded_padded = ( | |
| (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size | |
| ) | |
| dq_accum = torch.empty( | |
| num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device | |
| ) | |
| dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) | |
| lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) | |
| dKV_postprocess = qhead_per_kvhead > 1 | |
| if dKV_postprocess: | |
| head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 | |
| if cu_seqlens_k is None: | |
| dk_accum = torch.zeros( | |
| batch_size, | |
| num_head_kv, | |
| seqlen_k_rounded * head_dim_rounded, | |
| dtype=torch.float32, | |
| device=device, | |
| ) | |
| dv_accum = torch.zeros( | |
| batch_size, | |
| num_head_kv, | |
| seqlen_k_rounded * head_dim_v_rounded, | |
| dtype=torch.float32, | |
| device=device, | |
| ) | |
| else: | |
| cluster_tile_n = cluster_size * n_block_size | |
| total_k_rounded_padded = ( | |
| (total_k + cu_seqlens_k.shape[0] * cluster_tile_n - 1) // cluster_tile_n * cluster_tile_n | |
| ) | |
| dk_accum = torch.zeros( | |
| num_head_kv, | |
| total_k_rounded_padded * head_dim_rounded, | |
| dtype=torch.float32, | |
| device=device, | |
| ) | |
| dv_accum = torch.zeros( | |
| num_head_kv, | |
| total_k_rounded_padded * head_dim_v_rounded, | |
| dtype=torch.float32, | |
| device=device, | |
| ) | |
| dtype = torch2cute_dtype_map[q.dtype] | |
| current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) | |
| if deterministic: | |
| dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, cluster_size, dtype=torch.int32, device="cuda") | |
| else: | |
| dQ_semaphore = None | |
| if deterministic and qhead_per_kvhead > 1: | |
| dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") | |
| dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") | |
| else: | |
| dK_semaphore = None | |
| dV_semaphore = None | |
| # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum. | |
| compile_key_pre = ( | |
| arch, | |
| dtype, | |
| head_dim, | |
| head_dim_v, | |
| m_block_size, | |
| num_threads, | |
| cu_seqlens_q is None, | |
| seqused_q is None, | |
| get_broadcast_dims(out), | |
| get_broadcast_dims(dout), | |
| ) | |
| if compile_key_pre not in _flash_attn_bwd.compile_cache_pre: | |
| o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)] | |
| dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ | |
| to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2) | |
| ] | |
| lse_tensor = to_cute_tensor(lse, assumed_align=4) | |
| cu_seqlens_q_tensor, seqused_q_tensor = [ | |
| to_cute_tensor(t, assumed_align=4) if t is not None else None | |
| for t in (cu_seqlens_q, seqused_q) | |
| ] | |
| fa_bwd_pre = FlashAttentionBackwardPreprocess( | |
| dtype, | |
| head_dim, | |
| head_dim_v, | |
| arch, | |
| m_block_size, | |
| num_threads=num_threads, | |
| ) | |
| # TODO: check @can_implement | |
| _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile( | |
| fa_bwd_pre, | |
| o_tensor, | |
| do_tensor, | |
| dpsum_tensor, | |
| lse_tensor, | |
| lse_log2_tensor, | |
| dq_accum_tensor, | |
| cu_seqlens_q_tensor, | |
| seqused_q_tensor, | |
| current_stream, | |
| options="--enable-tvm-ffi", | |
| ) | |
| if not is_fake_mode(): | |
| _flash_attn_bwd.compile_cache_pre[compile_key_pre]( | |
| out, | |
| dout, | |
| dpsum, | |
| lse, | |
| lse_log2, | |
| dq_accum, | |
| cu_seqlens_q, | |
| seqused_q, | |
| current_stream, | |
| ) | |
| # NB: how num_threads applies to the three kernels (pre-, main, and post-processing). | |
| # Currently the incoming num_threads value is only honored by the preprocess kernel. | |
| # For the main kernel it is hard-coded to 384 here, before the main cache key is generated; | |
| # it is overridden again (256 or 128, depending on arch) just before the postprocess kernels. | |
| num_threads = 384 | |
| # Backward kernel: compute dk, dv, dq_accum. | |
| score_mod_hash = utils.hash_callable(score_mod) if score_mod else False | |
| score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False | |
| mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod else False | |
| num_aux_tensors = len(aux_tensors) if aux_tensors else 0 | |
| cute_aux_tensors = None | |
| if aux_tensors is not None: | |
| cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors] | |
| block_sparse_broadcast_pattern = None | |
| normalized_block_sparse_tensors = None | |
| if block_sparse_tensors is not None: | |
| ( | |
| normalized_block_sparse_tensors, | |
| block_sparse_broadcast_pattern, | |
| ) = normalize_block_sparse_config_bwd( | |
| block_sparse_tensors, | |
| batch_size=batch_size, | |
| num_head=num_head, | |
| seqlen_q=seqlen_q, | |
| seqlen_k=seqlen_k, | |
| block_size=(m_block_size, n_block_size), | |
| subtile_factor=subtile_factor, | |
| ) | |
| if arch // 10 == 9: | |
| compile_key = ( | |
| arch, | |
| dtype, | |
| head_dim, | |
| head_dim_v, | |
| qhead_per_kvhead, | |
| causal, | |
| softcap != 0.0, | |
| m_block_size, | |
| n_block_size, | |
| num_threads, | |
| pack_gqa, | |
| num_stages_Q, | |
| num_stages_dO, | |
| SdP_swapAB, | |
| dKV_swapAB, | |
| dQ_swapAB, | |
| AtomLayoutMSdP, | |
| AtomLayoutNdKV, | |
| AtomLayoutMdQ, | |
| V_in_regs, | |
| cu_seqlens_q is None, | |
| cu_seqlens_k is None, | |
| seqused_q is None, | |
| seqused_k is None, | |
| score_mod_hash, | |
| score_mod_bwd_hash, | |
| mask_mod_hash, | |
| num_aux_tensors, | |
| use_block_sparsity, | |
| block_sparse_broadcast_pattern, | |
| get_broadcast_dims(q), | |
| get_broadcast_dims(k), | |
| get_broadcast_dims(v), | |
| get_broadcast_dims(dout), | |
| ) | |
| else: | |
| compile_key = ( | |
| arch, | |
| dtype, | |
| head_dim, | |
| head_dim_v, | |
| qhead_per_kvhead, | |
| causal, | |
| window_size_left is not None, | |
| window_size_right is not None, | |
| softcap != 0.0, | |
| m_block_size, | |
| n_block_size, | |
| num_threads, | |
| pack_gqa, | |
| cluster_size, | |
| use_2cta_instrs, | |
| deterministic, | |
| score_mod_hash, | |
| score_mod_bwd_hash, | |
| mask_mod_hash, | |
| num_aux_tensors, | |
| use_block_sparsity, | |
| block_sparse_broadcast_pattern, | |
| cu_seqlens_q is None, | |
| cu_seqlens_k is None, | |
| seqused_q is None, | |
| seqused_k is None, | |
| get_broadcast_dims(q), | |
| get_broadcast_dims(k), | |
| get_broadcast_dims(v), | |
| get_broadcast_dims(dout), | |
| ) | |
| if compile_key not in _flash_attn_bwd.compile_cache: | |
| q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ | |
| to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv) | |
| ] | |
| dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ | |
| to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2) | |
| ] | |
| if dKV_postprocess: | |
| dk_accum_tensor, dv_accum_tensor = [ | |
| to_cute_tensor(t) for t in (dk_accum, dv_accum) | |
| ] | |
| cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ | |
| to_cute_tensor(t, assumed_align=4) if t is not None else None | |
| for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) | |
| ] | |
| dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [ | |
| utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order()) | |
| if t is not None else None | |
| for t in (dQ_semaphore, dK_semaphore, dV_semaphore) | |
| ] | |
| fa_bwd_sm80 = FlashAttentionBackwardSm80( | |
| dtype, | |
| head_dim, | |
| head_dim_v, | |
| qhead_per_kvhead, | |
| m_block_size, | |
| n_block_size, | |
| num_stages_Q, | |
| num_stages_dO, | |
| num_threads, | |
| pack_gqa, | |
| causal, | |
| SdP_swapAB, | |
| dKV_swapAB, | |
| dQ_swapAB, | |
| AtomLayoutMSdP, | |
| AtomLayoutNdKV, | |
| AtomLayoutMdQ, | |
| V_in_regs=V_in_regs, | |
| ) | |
| if arch // 10 == 9: | |
| fa_bwd_obj = FlashAttentionBackwardSm90( | |
| dtype, | |
| head_dim, | |
| head_dim_v, | |
| qhead_per_kvhead, | |
| causal, | |
| m_block_size, | |
| n_block_size, | |
| num_stages_Q, | |
| num_stages_dO, | |
| num_stages_PdS, | |
| SdP_swapAB, | |
| dKV_swapAB, | |
| dQ_swapAB, | |
| AtomLayoutMSdP, | |
| AtomLayoutNdKV, | |
| AtomLayoutMdQ, | |
| num_threads, | |
| V_in_regs=V_in_regs, | |
| score_mod=score_mod, | |
| score_mod_bwd=score_mod_bwd, | |
| mask_mod=mask_mod, | |
| has_aux_tensors=aux_tensors is not None, | |
| subtile_factor=subtile_factor, | |
| ) | |
| else: | |
| fa_bwd_obj = FlashAttentionBackwardSm100( | |
| head_dim, | |
| head_dim_v, | |
| is_causal=causal, | |
| is_local=local, | |
| qhead_per_kvhead=qhead_per_kvhead, | |
| tile_m=m_block_size, | |
| tile_n=n_block_size, | |
| cluster_size=cluster_size, | |
| use_2cta_instrs=use_2cta_instrs, | |
| deterministic=deterministic, | |
| score_mod=score_mod, | |
| score_mod_bwd=score_mod_bwd, | |
| mask_mod=mask_mod, | |
| has_aux_tensors=aux_tensors is not None, | |
| subtile_factor=subtile_factor, | |
| ) | |
| # Block sparse tensors for backward use Q-direction indexing (transposed from forward). | |
| sparse_tensors_compile = None | |
| if normalized_block_sparse_tensors is not None: | |
| sparse_tensors_compile = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) | |
| # TODO: check @can_implement | |
| _flash_attn_bwd.compile_cache[compile_key] = cute.compile( | |
| fa_bwd_obj, | |
| q_tensor, | |
| k_tensor, | |
| v_tensor, | |
| do_tensor, | |
| lse_log2_tensor, | |
| dpsum_tensor, | |
| dq_accum_tensor, | |
| dk_tensor if not dKV_postprocess else dk_accum_tensor, | |
| dv_tensor if not dKV_postprocess else dv_accum_tensor, | |
| softmax_scale, | |
| current_stream, | |
| cu_seqlens_q_tensor, | |
| cu_seqlens_k_tensor, | |
| seqused_q_tensor, | |
| seqused_k_tensor, | |
| None, # softcap - not yet supported in backward | |
| window_size_left, | |
| window_size_right, | |
| dQ_semaphore_tensor, | |
| dK_semaphore_tensor, | |
| dV_semaphore_tensor, | |
| cute_aux_tensors, | |
| sparse_tensors_compile, | |
| options="--enable-tvm-ffi", | |
| ) | |
| if not is_fake_mode(): | |
| _flash_attn_bwd.compile_cache[compile_key]( | |
| q.detach(), | |
| k.detach(), | |
| v.detach(), | |
| dout, | |
| lse_log2, | |
| dpsum, | |
| dq_accum, | |
| dk if not dKV_postprocess else dk_accum, | |
| dv if not dKV_postprocess else dv_accum, | |
| softmax_scale, | |
| current_stream, | |
| cu_seqlens_q, | |
| cu_seqlens_k, | |
| seqused_q, | |
| seqused_k, | |
| None, # softcap - not yet supported in backward | |
| window_size_left, | |
| window_size_right, | |
| dQ_semaphore, | |
| dK_semaphore, | |
| dV_semaphore, | |
| aux_tensors, | |
| normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, | |
| ) | |
| num_threads = 256 if arch // 10 == 9 else 128 | |
| # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16 | |
| compile_key_post = ( | |
| arch, | |
| dtype, | |
| head_dim, | |
| m_block_size, | |
| num_threads, | |
| AtomLayoutMdQ, | |
| dQ_swapAB, | |
| cu_seqlens_q is None, | |
| seqused_q is None, | |
| use_2cta_instrs, | |
| 1, # no cluster for tile_m | |
| get_broadcast_dims(dq_accum), | |
| get_broadcast_dims(dq), | |
| ) | |
| if compile_key_post not in _flash_attn_bwd.compile_cache_post: | |
| dq_accum_tensor = to_cute_tensor(dq_accum) | |
| dq_tensor = to_cute_tensor(dq) | |
| cu_seqlens_q_tensor, seqused_q_tensor = [ | |
| to_cute_tensor(t, assumed_align=4) if t is not None else None | |
| for t in (cu_seqlens_q, seqused_q) | |
| ] | |
| fa_bwd_post = FlashAttentionBackwardPostprocess( | |
| dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB, | |
| use_2cta_instrs=use_2cta_instrs, | |
| ) | |
| # TODO: check @can_implement | |
| _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( | |
| fa_bwd_post, | |
| dq_accum_tensor, | |
| dq_tensor, | |
| softmax_scale, | |
| cu_seqlens_q_tensor, | |
| seqused_q_tensor, | |
| current_stream, | |
| options="--enable-tvm-ffi", | |
| ) | |
| if not is_fake_mode(): | |
| _flash_attn_bwd.compile_cache_post[compile_key_post]( | |
| dq_accum, | |
| dq, | |
| softmax_scale, | |
| cu_seqlens_q, | |
| seqused_q, | |
| current_stream, | |
| ) | |
| if dKV_postprocess: | |
| # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16 | |
| compile_key_post = ( | |
| arch, | |
| dtype, | |
| head_dim, | |
| n_block_size, | |
| num_threads, | |
| AtomLayoutNdKV, | |
| dKV_swapAB, | |
| cu_seqlens_k is None, | |
| seqused_k is None, | |
| False, # even for 2cta, is split along hdim, so always False | |
| cluster_size, # cluster is for tile_n | |
| get_broadcast_dims(dk_accum), | |
| get_broadcast_dims(dk), | |
| ) | |
| if compile_key_post not in _flash_attn_bwd.compile_cache_post: | |
| dk_accum_tensor = to_cute_tensor(dk_accum) | |
| dk_tensor = to_cute_tensor(dk) | |
| cu_seqlens_k_tensor, seqused_k_tensor = [ | |
| to_cute_tensor(t, assumed_align=4) if t is not None else None | |
| for t in (cu_seqlens_k, seqused_k) | |
| ] | |
| fa_bwd_post = FlashAttentionBackwardPostprocess( | |
| dtype, head_dim, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB, | |
| cluster_size=cluster_size, | |
| ) | |
| # TODO: check @can_implement | |
| _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( | |
| fa_bwd_post, | |
| dk_accum_tensor, | |
| dk_tensor, | |
| softmax_scale, | |
| cu_seqlens_k_tensor, | |
| seqused_k_tensor, | |
| current_stream, | |
| options="--enable-tvm-ffi", | |
| ) | |
| if not is_fake_mode(): | |
| _flash_attn_bwd.compile_cache_post[compile_key_post]( | |
| dk_accum, | |
| dk, | |
| softmax_scale, | |
| cu_seqlens_k, | |
| seqused_k, | |
| current_stream, | |
| ) | |
| compile_key_post = ( | |
| arch, | |
| dtype, | |
| head_dim_v, | |
| n_block_size, | |
| num_threads, | |
| AtomLayoutNdKV, | |
| dKV_swapAB, | |
| cu_seqlens_k is None, | |
| seqused_k is None, | |
| False, | |
| cluster_size, | |
| get_broadcast_dims(dv_accum), | |
| get_broadcast_dims(dv), | |
| ) | |
| if compile_key_post not in _flash_attn_bwd.compile_cache_post: | |
| dv_accum_tensor = to_cute_tensor(dv_accum) | |
| dv_tensor = to_cute_tensor(dv) | |
| cu_seqlens_k_tensor, seqused_k_tensor = [ | |
| to_cute_tensor(t, assumed_align=4) if t is not None else None | |
| for t in (cu_seqlens_k, seqused_k) | |
| ] | |
| fa_bwd_post = FlashAttentionBackwardPostprocess( | |
| dtype, head_dim_v, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB, | |
| cluster_size=cluster_size, | |
| ) | |
| # TODO: check @can_implement | |
| _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( | |
| fa_bwd_post, | |
| dv_accum_tensor, | |
| dv_tensor, | |
| cutlass.Float32(1.0), | |
| cu_seqlens_k_tensor, | |
| seqused_k_tensor, | |
| current_stream, | |
| options="--enable-tvm-ffi", | |
| ) | |
| if not is_fake_mode(): | |
| _flash_attn_bwd.compile_cache_post[compile_key_post]( | |
| dv_accum, | |
| dv, | |
| 1.0, | |
| cu_seqlens_k, | |
| seqused_k, | |
| current_stream, | |
| ) | |
| return dq, dk, dv | |
# Compile caches for the backward pass, stored as attributes on `_flash_attn_bwd` and
# looked up there by compile key: one for the preprocess kernel, one for the main
# backward kernel, and one shared by the dQ / dK / dV postprocess kernels.
_flash_attn_bwd.compile_cache_pre = get_jit_cache("bwd_pre")
_flash_attn_bwd.compile_cache = get_jit_cache("bwd")
_flash_attn_bwd.compile_cache_post = get_jit_cache("bwd_post")
class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for flash attention on fixed-length (non-varlen) inputs.

    ``forward`` dispatches to ``_flash_attn_fwd`` and ``backward`` to ``_flash_attn_bwd``.
    Gradients are produced for ``q``, ``k`` and ``v`` only; every other forward input
    receives ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[Optional[int], Optional[int]] = (None, None),
        learnable_sink: Optional[torch.Tensor] = None,
        softcap: float = 0.0,
        num_splits: int = 1,
        pack_gqa: Optional[bool] = None,
        deterministic: bool = False,
        mask_mod: Optional[Callable] = None,
        full_block_cnt: Optional[torch.Tensor] = None,
        full_block_idx: Optional[torch.Tensor] = None,
        mask_block_cnt: Optional[torch.Tensor] = None,
        mask_block_idx: Optional[torch.Tensor] = None,
        block_size: Optional[Tuple[int, int]] = None,
        return_lse: bool = False,
    ):
        """Run the forward kernel and record on ``ctx`` what ``backward`` needs.

        ``window_size`` is split into its (left, right) components before being passed
        to ``_flash_attn_fwd``. ``deterministic`` is not used by the forward call; it is
        only stored for the backward pass. The four ``*_block_cnt`` / ``*_block_idx``
        tensors and ``block_size`` are bundled into a ``BlockSparseTensorsTorch`` when
        at least one of the four tensors is given.

        Returns:
            The ``(out, lse)`` pair returned by ``_flash_attn_fwd``. ``lse`` is marked
            non-differentiable whenever it is not ``None``.
        """
        # Only create block sparse tensors if at least one block sparse parameter is provided
        block_sparse_tensors = None
        if any(t is not None for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]):
            block_sparse_tensors = BlockSparseTensorsTorch(
                full_block_cnt=full_block_cnt,
                full_block_idx=full_block_idx,
                mask_block_cnt=mask_block_cnt,
                mask_block_idx=mask_block_idx,
                block_size=block_size,
            )
        out, lse = _flash_attn_fwd(
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            learnable_sink=learnable_sink,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            mask_mod=mask_mod,
            block_sparse_tensors=block_sparse_tensors,
            return_lse=return_lse,
        )
        ctx.save_for_backward(q, k, v, out, lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        # LSE gradient is not supported yet
        if lse is not None:
            ctx.mark_non_differentiable(lse)
        return out, lse

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute ``dq``, ``dk``, ``dv`` from ``dout`` via ``_flash_attn_bwd``.

        ``*args`` absorbs the incoming gradient slot for ``lse``, which is ignored.
        """
        q, k, v, out, lse = ctx.saved_tensors
        # Same guard as FlashAttnVarlenFunc.backward: the backward kernel is invoked
        # without a softcap value, so a non-zero softcap would not be reflected in the
        # gradients.
        assert ctx.softcap == 0.0, "softcap is not yet supported in the backward pass"
        # NOTE(review): mask_mod and the block-sparse tensors used in forward are not
        # forwarded to the backward call — confirm whether gradients are meant to be
        # supported when they are set.
        dq, dk, dv = _flash_attn_bwd(
            q,
            k,
            v,
            out,
            dout,
            lse,
            ctx.softmax_scale,
            ctx.causal,
            ctx.softcap,
            window_size_left=ctx.window_size[0],
            window_size_right=ctx.window_size[1],
            deterministic=ctx.deterministic,
        )
        return dq, dk, dv, *((None,) * 20)  # Extra Nones is fine
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for flash attention on variable-length (packed) inputs.

    ``forward`` dispatches to ``_flash_attn_fwd`` and ``backward`` to ``_flash_attn_bwd``,
    both receiving the ``cu_seqlens_*`` / ``seqused_*`` tensors. Gradients are produced
    for ``q``, ``k`` and ``v`` only; every other forward input receives ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        cu_seqlens_k: Optional[torch.Tensor],
        seqused_q: Optional[torch.Tensor] = None,
        seqused_k: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_k: Optional[int] = None,
        page_table: Optional[torch.Tensor] = None,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[Optional[int], Optional[int]] = (None, None),
        learnable_sink: Optional[torch.Tensor] = None,
        softcap: float = 0.0,
        num_splits: int = 1,
        pack_gqa: Optional[bool] = None,
        deterministic: bool = False,
        score_mod: Optional[Callable] = None,
        aux_tensors: Optional[list] = None,
        return_lse: bool = False,
    ):
        """Run the forward kernel and record on ``ctx`` what ``backward`` needs.

        ``window_size`` is split into its (left, right) components before being passed
        to ``_flash_attn_fwd``. ``deterministic`` is not used by the forward call; it is
        only stored for the backward pass.

        Returns:
            The ``(out, lse)`` pair returned by ``_flash_attn_fwd``. ``lse`` is marked
            non-differentiable whenever it is not ``None``.
        """
        out, lse = _flash_attn_fwd(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            page_table=page_table,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            learnable_sink=learnable_sink,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            score_mod=score_mod,
            aux_tensors=aux_tensors,
            return_lse=return_lse,
        )
        ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        # LSE gradient is not supported yet
        if lse is not None:
            ctx.mark_non_differentiable(lse)
        return out, lse

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute ``dq``, ``dk``, ``dv`` from ``dout`` via ``_flash_attn_bwd``.

        ``*args`` absorbs the incoming gradient slot for ``lse``, which is ignored.
        Raises ``AssertionError`` if the forward pass used a non-zero ``softcap``.
        """
        q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.softcap == 0.0
        # NOTE(review): score_mod, aux_tensors and page_table from forward are not
        # forwarded to the backward call — confirm whether gradients are meant to be
        # supported when they are set.
        dq, dk, dv = _flash_attn_bwd(
            q,
            k,
            v,
            out,
            dout,
            lse,
            ctx.softmax_scale,
            ctx.causal,
            ctx.softcap,
            window_size_left=ctx.window_size[0],
            window_size_right=ctx.window_size[1],
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_q=seqused_q,
            seqused_k=seqused_k,
            max_seqlen_q=ctx.max_seqlen_q,
            max_seqlen_k=ctx.max_seqlen_k,
            deterministic=ctx.deterministic,
        )
        # More Nones than forward inputs, mirroring FlashAttnFunc.backward.
        return dq, dk, dv, *((None,) * 20)
def flash_attn_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    learnable_sink: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    deterministic: bool = False,
    mask_mod: Optional[Callable] = None,
    full_block_cnt: Optional[torch.Tensor] = None,
    full_block_idx: Optional[torch.Tensor] = None,
    mask_block_cnt: Optional[torch.Tensor] = None,
    mask_block_idx: Optional[torch.Tensor] = None,
    block_size: Optional[Tuple[int, int]] = None,
    return_lse: bool = False,
):
    """Autograd-enabled flash attention; thin wrapper around ``FlashAttnFunc``.

    Every argument is handed to ``FlashAttnFunc.apply`` positionally, in the same
    order as the parameters of ``FlashAttnFunc.forward``.

    Returns:
        Whatever ``FlashAttnFunc.apply`` returns, i.e. the ``(out, lse)`` pair.
    """
    # Order must match FlashAttnFunc.forward (after ctx); apply() is called positionally.
    forward_args = (
        q,
        k,
        v,
        softmax_scale,
        causal,
        window_size,
        learnable_sink,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        mask_mod,
        full_block_cnt,
        full_block_idx,
        mask_block_cnt,
        mask_block_idx,
        block_size,
        return_lse,
    )
    return FlashAttnFunc.apply(*forward_args)
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    learnable_sink: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    deterministic: bool = False,
    score_mod: Optional[Callable] = None,
    aux_tensors: Optional[list] = None,
    return_lse: bool = False,
):
    """Autograd-enabled variable-length flash attention; wrapper around ``FlashAttnVarlenFunc``.

    Every argument is handed to ``FlashAttnVarlenFunc.apply`` positionally.

    Returns:
        Whatever ``FlashAttnVarlenFunc.apply`` returns, i.e. the ``(out, lse)`` pair.
    """
    # This signature lists max_seqlen_* before seqused_*, whereas
    # FlashAttnVarlenFunc.forward takes seqused_* first; the tuple below follows the
    # forward() order because apply() is called positionally.
    forward_args = (
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        softmax_scale,
        causal,
        window_size,
        learnable_sink,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        score_mod,
        aux_tensors,
        return_lse,
    )
    return FlashAttnVarlenFunc.apply(*forward_args)
def _flash_attn_fwd_combine(
    out_partial: torch.Tensor,
    lse_partial: torch.Tensor,
    out: torch.Tensor,
    lse: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    seqused: Optional[torch.Tensor] = None,
    num_splits_dynamic_ptr: Optional[torch.Tensor] = None,
    semaphore_to_reset: Optional[torch.Tensor] = None,
) -> None:
    """Forward combine kernel for split attention computation.

    Combines partial outputs and log-sum-exp values from multiple splits
    of attention computation into final outputs. Results are written into the
    caller-provided ``out`` (and ``lse`` when given). The compiled kernel is cached
    on ``_flash_attn_fwd_combine.compile_cache``; the launch is skipped in fake mode.

    Args:
        out_partial: Partial outputs tensor (num_splits, batch, seqlen, nheads, headdim) or
            (num_splits, total_q, nheads, headdim) if there's cu_seqlens
        lse_partial: Partial LSE tensor (num_splits, batch, seqlen, nheads) or
            (num_splits, total_q, nheads) if there's cu_seqlens
        out: Output tensor (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim) if there's cu_seqlens
        lse: Output LSE tensor (batch, seqlen, nheads) or (total_q, nheads) if there's cu_seqlens.
        cu_seqlens: Cumulative sequence lengths for variable length sequences
        seqused: Used sequence lengths for each batch
        num_splits_dynamic_ptr: Dynamic number of splits per batch
        semaphore_to_reset: Semaphore for synchronization

    Returns:
        None

    Raises:
        AssertionError: If any shape / dtype / device / layout check fails.
        RuntimeError: If the combine kernel reports it cannot be implemented for the
            derived configuration.
    """
    # Input validation
    assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
    assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
    assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], (
        "out_partial must be fp16, bf16, or fp32"
    )
    assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"
    assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device"
    assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension"
    assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension"
    assert lse_partial.shape == out_partial.shape[:-1]
    # Determine if this is variable length based on dimensions
    is_varlen = out_partial.dim() == 4
    # Validate output tensor shapes and types
    assert out.shape == out_partial.shape[1:], "out shape mismatch"
    if lse is not None:
        assert lse.shape == lse_partial.shape[1:], "lse shape mismatch"
        assert lse.dtype == torch.float32, "lse must be fp32"
    # Validate optional tensors
    for t, name in [
        (cu_seqlens, "cu_seqlens"),
        (seqused, "seqused"),
        (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
    ]:
        if t is not None:
            assert t.dtype == torch.int32, f"{name} must be int32"
            assert t.is_cuda, f"{name} must be on CUDA device"
            assert t.is_contiguous(), f"{name} must be contiguous"
    head_dim = out_partial.shape[-1]
    num_splits = out_partial.shape[0]
    # math.log2 below is undefined for 0, so reject an empty split dimension explicitly.
    assert 1 <= num_splits <= 256, "num_splits must be between 1 and 256"
    # If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
    # so that kBlockM is smaller and we have more parallelism.
    k_block_size = 64 if head_dim <= 64 else 128
    # We want kBlockM to be as small as possible to maximize parallelism.
    # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
    m_block_size = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32)
    log_max_splits = max(math.ceil(math.log2(num_splits)), 4)
    if m_block_size == 8:
        # If kBlockM == 8 then the minimum number of splits is 32.
        # TODO: we can deal w this by using 128 threads instead
        log_max_splits = max(log_max_splits, 5)
    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    # Create combine kernel configuration
    dtype = torch2cute_dtype_map[out.dtype]
    dtype_partial = torch2cute_dtype_map[out_partial.dtype]
    compile_key = (
        dtype,
        dtype_partial,
        head_dim,
        m_block_size,
        k_block_size,
        log_max_splits,
        cu_seqlens is not None,
        seqused is not None,
        lse is not None,
        # Everything below also changes what is handed to cute.compile (tensor rank,
        # or None vs. a tensor), so it must be part of the key; otherwise a kernel
        # compiled for one configuration could be reused for an incompatible call.
        is_varlen,
        num_splits_dynamic_ptr is not None,
        semaphore_to_reset is not None,
    )
    if compile_key not in _flash_attn_fwd_combine.compile_cache:
        out_partial_tensor = to_cute_tensor(
            out_partial, leading_dim=4 if not is_varlen else 3
        )
        lse_partial_tensor = to_cute_tensor(
            lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2
        )
        out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2)
        lse_tensor = (
            to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2)
            if lse is not None
            else None
        )
        cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = [
            to_cute_tensor(t, assumed_align=4, leading_dim=0) if t is not None else None
            for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset)
        ]
        fa_combine = FlashAttentionForwardCombine(
            dtype=dtype,
            dtype_partial=dtype_partial,
            head_dim=head_dim,
            m_block_size=m_block_size,
            k_block_size=k_block_size,
            log_max_splits=log_max_splits,
        )
        # Check if implementation is supported
        if not fa_combine.can_implement(
            dtype,
            dtype_partial,
            head_dim,
            m_block_size,
            k_block_size,
            log_max_splits,
            num_threads=256,
        ):
            raise RuntimeError(
                "FlashAttention combine kernel cannot be implemented with given parameters"
            )
        _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile(
            fa_combine,
            out_partial_tensor,
            lse_partial_tensor,
            out_tensor,
            lse_tensor,
            cu_seqlens_tensor,
            seqused_tensor,
            num_splits_dynamic_tensor,
            semaphore_tensor,
            current_stream,
            options="--enable-tvm-ffi",
        )
    if not is_fake_mode():
        _flash_attn_fwd_combine.compile_cache[compile_key](
            out_partial,
            lse_partial,
            out,
            lse,
            cu_seqlens,
            seqused,
            num_splits_dynamic_ptr,
            semaphore_to_reset,
            current_stream,
        )
# Compile cache for the combine kernel, stored as an attribute on
# `_flash_attn_fwd_combine` and looked up there by compile key.
_flash_attn_fwd_combine.compile_cache = get_jit_cache("fwd_combine")
def flash_attn_combine(
    out_partial: torch.Tensor,
    lse_partial: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    seqused: Optional[torch.Tensor] = None,
    return_lse: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Flash Attention combine function for split attention computation.

    Combines partial outputs and log-sum-exp values from multiple splits
    of attention computation into final outputs. This is the main user-facing
    interface for the combine kernel.

    Args:
        out_partial: Partial outputs tensor (must be fp32) with shape:
            - (num_splits, batch_size, seqlen, num_heads, head_size) for regular batched input
            - (num_splits, total_q, num_heads, head_size) for variable length input
        lse_partial: Partial LSE tensor (must be fp32) with shape:
            - (num_splits, batch_size, seqlen, num_heads) for regular batched input
            - (num_splits, total_q, num_heads) for variable length input
        out: Optional output tensor. If None, will be created automatically.
        out_dtype: Dtype of the automatically created output. If None, the dtype of
            ``out_partial`` (fp32) is used. Ignored when ``out`` is provided.
        cu_seqlens: Cumulative sequence lengths for variable length sequences
        seqused: Used sequence lengths for each batch
        return_lse: Whether to return the combined LSE tensor. Default is True.

    Returns:
        Tuple of (out, lse) where:
            - out: Combined output tensor with shape (batch_size, seqlen, num_heads, head_size)
              or (total_q, num_heads, head_size) for varlen
            - lse: Combined log-sum-exp tensor with shape (batch_size, seqlen, num_heads)
              or (total_q, num_heads) for varlen. None if return_lse=False

    Note:
        This function expects the input tensors to be in the format produced by
        split attention computation, where the first dimension is num_splits.
        The permuting from user format to kernel format is now done inside the kernel.
    """
    # Input validation
    assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
    assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
    assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)"
    assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"
    # Determine if this is variable length based on dimensions
    is_varlen = out_partial.dim() == 4
    if is_varlen:
        # Variable length: (num_splits, total_q, num_heads, head_size)
        num_splits, total_q, num_heads, head_size = out_partial.shape
        assert lse_partial.shape == (num_splits, total_q, num_heads), (
            "lse_partial shape mismatch for varlen"
        )
    else:
        # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size)
        num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape
        assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), (
            "lse_partial shape mismatch"
        )
    # Determine output dtype
    if out_dtype is None:
        out_dtype = out_partial.dtype
    # Create output if not provided
    device = out_partial.device
    if out is None:
        if is_varlen:
            out = torch.empty(total_q, num_heads, head_size, dtype=out_dtype, device=device)
        else:
            out = torch.empty(
                batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device
            )
    # Create lse output only if requested. It is allocated with the sequence axis
    # innermost in memory and then transposed, so the returned view has the documented
    # shape while staying unit-stride along the sequence dimension.
    if return_lse:
        if is_varlen:
            lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose(
                0, 1
            )
        else:
            lse = torch.empty(
                batch_size, num_heads, seqlen, dtype=torch.float32, device=device
            ).transpose(1, 2)
    else:
        lse = None
    _flash_attn_fwd_combine(
        out_partial,
        lse_partial,
        out,
        lse,
        cu_seqlens,
        seqused,
    )
    return out, lse