# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

from typing import Any, Optional, Tuple, Union

import torch
import triton
import triton.language as tl

import aiter.ops.triton.utils.types as types
from aiter.ops.triton.attention.mha_onekernel_bwd import flash_attn_onekernel_backward
from aiter.ops.triton.attention.mha_fused_bwd import flash_attn_fused_backward
from aiter.ops.triton.utils.logger import AiterTritonLogger
from aiter.ops.triton.utils.device_info import get_num_xcds
from kernel_jit import _attn_fwd, _get_config
from aiter.ops.triton._triton_kernels.flash_attn_triton_amd import flash_attn_2

_LOGGER = AiterTritonLogger()

# Module-level switches. They are read at call time (inside forward/backward),
# so changing them affects every subsequent call.
_USE_FUSED_BWD_KERNEL = False


def mha_set_use_fused_bwd_kernel(value: bool):
    """
    Set whether to use fused backward kernel (with atomics) or one-kernel backward (without atomics).
    Fused backward is faster but doesn't support positional encoding (or attention sinks).
    """
    global _USE_FUSED_BWD_KERNEL
    _USE_FUSED_BWD_KERNEL = value


_USE_INT64_STRIDES = True


def mha_set_use_int64_strides(value: bool):
    """Use 64-bit integer strides to prevent integer overflows with very large tensors."""
    global _USE_INT64_STRIDES
    _USE_INT64_STRIDES = value


def _pad_qkv_head_dim(q, k, v):
    """Right-pad q, k and v along the last (head) dim so q's head dim is a multiple of 8.

    The pad amount is derived from q's head dim and applied to all three tensors.
    The inputs are returned unchanged when no padding is needed.
    """
    head_size = q.shape[-1]
    if head_size % 8 == 0:
        return q, k, v
    pad = [0, 8 - head_size % 8]
    return (
        torch.nn.functional.pad(q, pad),
        torch.nn.functional.pad(k, pad),
        torch.nn.functional.pad(v, pad),
    )


def _pad_last_dim_to_multiple_of_8(x):
    """Right-pad ``x`` along its last dim to a multiple of 8 (no-op if already aligned)."""
    head_size = x.shape[-1]
    if head_size % 8 == 0:
        return x
    return torch.nn.functional.pad(x, [0, 8 - head_size % 8])


def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    bias: Optional[torch.Tensor],
    alibi_slopes: Optional[torch.Tensor],
    return_lse: bool,  # Not used
    return_softmax: bool,
    max_seqlen_q: int,
    max_seqlen_k: int,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    descale_v: Optional[torch.Tensor] = None,
    sink: Optional[torch.Tensor] = None,
    config: Optional[dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int, int]:
    """Allocate outputs and launch the Triton forward attention kernel.

    Layout is bshd ([batch, seq, heads, dim]) when ``cu_seqlens_q`` is None,
    otherwise packed thd ([total_tokens, heads, dim]) described by cu_seqlens.
    q/k may carry extra positional-encoding (PE) channels beyond v's head dim.

    Returns:
        (o, softmax_lse, s_dmask, philox_seed, philox_offset) where
        - o has q's leading dims and v's head dim (float32 for FP8 inputs),
        - softmax_lse is [batch, num_q_heads, max_seqlen_q] (bshd) or
          [total_q, num_q_heads] (varlen), float32,
        - s_dmask is [batch, num_q_heads, max_seqlen_q, max_seqlen_k] when
          return_softmax or dropout is enabled, else None,
        - philox_seed / philox_offset are Python ints (0 when dropout is off).

    Raises:
        ValueError: if ``bias`` is given or a sliding window is requested.
        AssertionError: on unsupported PE / FP8 / sink configurations.
    """
    if bias is not None:
        raise ValueError("Bias is not supported yet in the Triton Backend")
    if window_size_left != -1 or window_size_right != -1:
        raise ValueError("Sliding Window is not supported yet in the Triton Backend")

    # FP8
    IS_FP8 = types._is_fp8(q)
    FP8_MAX: tl.constexpr = torch.finfo(q.dtype).max
    is_varlen = cu_seqlens_q is not None

    # The kernel writes every (row < seqlen_q, full head_dim) element of the
    # output (causal early-exit rows are explicitly zeroed inside the kernel),
    # so we can skip the redundant memset of torch.zeros and use torch.empty.
    o_dtype = torch.float32 if IS_FP8 else q.dtype
    o = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=o_dtype, device=q.device)

    # Strides are passed to the kernel in (batch, head, seq, dim) order.
    if is_varlen:
        # Layout is thd.
        # q and k are [total_tokens, num_head, head_dim_qk].
        # v is [total_tokens, num_head, head_dim_v].
        # The batch stride is 0 in the packed layout; per-sequence offsets are
        # presumably derived from cu_seqlens inside the kernel.
        batch = len(cu_seqlens_q) - 1
        seqlen_q = max_seqlen_q
        num_q_heads = q.shape[1]
        num_k_heads = k.shape[1]
        q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
        k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
        v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
        o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
    else:
        # Layout is bshd.
        # q and k are [batch, seq_len, num_head, head_dim_qk].
        # v is [batch, seq_len, num_head, head_dim_v].
        batch, seqlen_q, num_q_heads = q.shape[:-1]
        num_k_heads = k.shape[2]
        q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
        k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
        v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
        o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))

    # The public API documents alibi_slopes as (nheads,) or (batch, nheads).
    # The launch below reads two strides, so view the 1-D form as
    # (batch, nheads) with a zero batch stride (no copy).
    if alibi_slopes is not None and alibi_slopes.dim() == 1:
        alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1)

    qk_head_dim = q.shape[-1]
    v_head_dim = v.shape[-1]
    pe_head_dim = qk_head_dim - v_head_dim

    # padding for head_dim. Power of 2 or 16
    BLOCK_DMODEL_POW2 = max(triton.next_power_of_2(v_head_dim), 16)
    BLOCK_DMODEL_PE_POW2 = (
        0 if pe_head_dim == 0 else max(triton.next_power_of_2(pe_head_dim), 16)
    )
    # When pe_head_dim == 0, BLOCK_DMODEL_PE_POW2 is 0 by construction.
    assert pe_head_dim == 0 or (
        v_head_dim == BLOCK_DMODEL_POW2 and pe_head_dim == BLOCK_DMODEL_PE_POW2
    ), "Positional encoding support requires NOPE and PE head sizes to be unpadded powers of 2."
    assert not IS_FP8 or pe_head_dim == 0, "Positional encoding doesn't support FP8."
    assert sink is None or (
        sink.dim() == 1 and sink.shape[0] == num_q_heads
    ), "Sink must be 1D and have one element per query head."

    if is_varlen:
        # softmax_lse [total_q, num_q_heads]
        softmax_lse = torch.zeros(
            (q.shape[0], num_q_heads), device=q.device, dtype=torch.float32
        )
        stride_lse_z, stride_lse_h, stride_lse_m = (
            0,
            softmax_lse.stride(1),
            softmax_lse.stride(0),
        )
    else:
        # softmax_lse [batch, num_q_heads, seqlen_q]
        softmax_lse = torch.zeros(
            (batch, num_q_heads, max_seqlen_q), device=q.device, dtype=torch.float32
        )
        stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride()

    enable_dropout = dropout_p > 0.0
    if enable_dropout:
        # No specific reason to restrict range to 0xffffff.
        # .item() so the kernel receives Python ints, not Tensors.
        philox_seed = torch.randint(0, 0xFFFFFF, (1,))[0].item()
        philox_offset = torch.randint(0, 0xFFFFFF, (1,))[0].item()
    else:
        philox_seed = 0
        philox_offset = 0

    # exp_scores [batch, num_q_heads, seqlen_q, seqlen_k]
    if return_softmax or enable_dropout:
        s_dmask = torch.zeros(
            (batch, num_q_heads, max_seqlen_q, max_seqlen_k),
            device=q.device,
            dtype=torch.float32,
        )
        dropout_mask = torch.zeros(
            (batch, num_q_heads, max_seqlen_q, max_seqlen_k),
            device=q.device,
            dtype=torch.float32,
        )
    else:
        s_dmask = None
        dropout_mask = None
    s_dmask_strides = s_dmask.stride() if s_dmask is not None else (0, 0, 0, 0)

    # Launch parameters (BLOCK_M/BLOCK_N/waves/warps/...) come from _get_config
    # unless the caller supplies an explicit config dict.
    if config is None:
        config = _get_config(enable_dropout, q.dtype, has_pe=pe_head_dim > 0)

    grid = lambda META: (  # noqa: E731
        batch * num_q_heads * triton.cdiv(seqlen_q, META["BLOCK_M"]),
    )

    _attn_fwd[grid](
        q,
        k,
        v,
        descale_q,
        descale_k,
        descale_v,
        o,
        alibi_slopes,
        s_dmask,
        dropout_mask,
        softmax_lse,
        sink,
        *q_strides,
        *k_strides,
        *v_strides,
        descale_q.stride(0) if descale_q is not None else 0,
        descale_k.stride(0) if descale_k is not None else 0,
        descale_v.stride(0) if descale_v is not None else 0,
        *o_strides,
        alibi_slopes.stride(0) if alibi_slopes is not None else 0,
        alibi_slopes.stride(1) if alibi_slopes is not None else 0,
        *s_dmask_strides,
        stride_lse_z,
        stride_lse_h,
        stride_lse_m,
        softmax_scale,
        cu_seqlens_q,
        cu_seqlens_k,
        dropout_p,
        philox_seed,
        philox_offset,
        SEQLEN_Q=max_seqlen_q,
        SEQLEN_K=max_seqlen_k,
        IS_CAUSAL=causal,
        NUM_Q_HEADS=num_q_heads,
        NUM_K_HEADS=num_k_heads,
        BLOCK_DMODEL=v_head_dim,
        BLOCK_DMODEL_POW2=BLOCK_DMODEL_POW2,
        BLOCK_DMODEL_PE=pe_head_dim,
        RETURN_SCORES=return_softmax,
        ENABLE_DROPOUT=enable_dropout,
        IS_FP8=IS_FP8,
        FP8_MAX=FP8_MAX,
        VARLEN=is_varlen,
        BATCH=batch,
        NUM_XCD=get_num_xcds(),
        USE_INT64_STRIDES=_USE_INT64_STRIDES,
        ENABLE_SINK=sink is not None,
        **config,
    )

    return o, softmax_lse, s_dmask, philox_seed, philox_offset


class _FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for batched (bshd) flash attention."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        bias,
        alibi_slopes,
        deterministic,
        return_lse,
        return_softmax,
        sink,
        is_grad_enabled,
        config=None,
    ):
        is_grad = is_grad_enabled and any(
            x is not None and x.requires_grad for x in [q, k, v, sink]
        )
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        # Remember the caller's head dims so backward can strip any padding
        # and return gradients with the same shapes as the inputs.
        head_sizes_og = (q.shape[-1], k.shape[-1], v.shape[-1])
        q, k, v = _pad_qkv_head_dim(q, k, v)

        out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = (
            _flash_attn_forward(
                q,
                k,
                v,
                dropout_p,
                softmax_scale,
                causal=causal,
                window_size_left=int(window_size[0]),
                window_size_right=int(window_size[1]),
                bias=bias,
                alibi_slopes=alibi_slopes,
                return_lse=return_lse,
                return_softmax=return_softmax and dropout_p > 0,
                max_seqlen_q=q.shape[1],
                max_seqlen_k=k.shape[1],
                sink=sink,
                config=config,
            )
        )

        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, sink)
            ctx.philox_seed = philox_seed
            ctx.philox_offset = philox_offset
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.bias = bias
            ctx.window_size = window_size
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
            ctx.head_sizes_og = head_sizes_og

        # The output carries v's head dim; drop any padding added above.
        out = out_padded[..., : head_sizes_og[2]]
        result = [out]
        if return_lse:
            result.append(softmax_lse)
        if return_softmax:
            result.append(S_dmask)

        return result[0] if len(result) == 1 else tuple(result)

    @staticmethod
    def backward(ctx, do, *args):
        q, k, v, out, softmax_lse, sink = ctx.saved_tensors
        bias = ctx.bias
        dbias = torch.empty_like(bias) if bias is not None else None
        dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v)
        # dsink is accumulated in float32.
        dsink = (
            torch.zeros_like(sink, dtype=torch.float32) if sink is not None else None
        )
        # Match the (possibly padded) head dim of the saved output.
        do_padded = _pad_last_dim_to_multiple_of_8(do)

        if _USE_FUSED_BWD_KERNEL:
            assert (
                sink is None and dsink is None
            ), "Fused backward doesn't support sinks."
            flash_attn_fused_backward(
                do_padded,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                dbias,
                ctx.softmax_scale,
                ctx.alibi_slopes,
                ctx.causal,
                None,
                None,
                max_seqlen_q=q.shape[1],
                max_seqlen_k=k.shape[1],
                dropout_p=ctx.dropout_p,
                philox_seed=ctx.philox_seed,
                philox_offset=ctx.philox_offset,
                USE_INT64_STRIDES=_USE_INT64_STRIDES,
            )
        else:
            flash_attn_onekernel_backward(
                do_padded,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                dbias,
                ctx.softmax_scale,
                ctx.alibi_slopes,
                ctx.causal,
                None,
                None,
                max_seqlen_q=q.shape[1],
                max_seqlen_k=k.shape[1],
                dropout_p=ctx.dropout_p,
                philox_seed=ctx.philox_seed,
                philox_offset=ctx.philox_offset,
                USE_INT64_STRIDES=_USE_INT64_STRIDES,
                sink=sink,
                dsink=dsink,
            )

        # The saved q/k/v may have been padded along the head dim; slice back to
        # the caller's original sizes so gradient shapes match the inputs.
        head_size_q_og, head_size_k_og, head_size_v_og = ctx.head_sizes_og
        dq = dq[..., :head_size_q_og]
        dk = dk[..., :head_size_k_og]
        dv = dv[..., :head_size_v_og]
        return (
            dq,
            dk,
            dv,
            None,  # dropout_p
            None,  # softmax_scale
            None,  # causal
            None,  # window_size
            dbias,
            None,  # alibi_slopes
            None,  # deterministic
            None,  # return_lse
            None,  # return_softmax
            dsink,
            None,  # is_grad_enabled
            None,  # config
        )


def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    bias=None,
    alibi_slopes=None,
    deterministic=True,
    return_lse=False,
    return_attn_probs=False,
    sink=None,
    config: Optional[dict[str, Any]] = None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim_v). headdim_v may be smaller than
            headdim when q/k carry extra positional-encoding channels.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim_q).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). Only (-1, -1) is supported by this backend; any other value
            raises ValueError.
        bias: must be None; a bias tensor raises ValueError in this backend.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Accepted for API compatibility; this backend's backward pass
            does not currently consult it.
        return_lse: bool. Whether to also return the softmax log-sum-exp.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        sink: (nheads,), attention sink scores (one per Q head), or None
        config: optional dict of kernel launch parameters; chosen automatically when None.
    Return:
        out: (batch_size, seqlen, nheads, headdim_v).
        softmax_lse [optional, if return_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen),
            or None when dropout_p == 0. The output of softmax (possibly with different
            scaling). It also encodes the dropout pattern (negative means that location was
            dropped, nonnegative means it was kept).
        The values are returned in the order listed; a bare tensor is returned when only
        `out` is requested.
    """
    _LOGGER.info(
        f"FLASH_ATTN: q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}"
    )
    return _FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        bias,
        alibi_slopes,
        deterministic,
        return_lse,
        return_attn_probs,
        sink,
        torch.is_grad_enabled(),
        config,
    )


class _FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length (packed thd) flash attention."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        bias,
        alibi_slopes,
        deterministic,
        return_lse,
        return_softmax,
        block_table,
        out,
        sink,
        is_grad_enabled,
        config=None,
    ):
        # NOTE: block_table and out are accepted for API parity but are not used here.
        is_grad = is_grad_enabled and any(
            x is not None and x.requires_grad for x in [q, k, v, sink]
        )
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        # Remember the caller's head dims so backward can strip any padding
        # and return gradients with the same shapes as the inputs.
        head_sizes_og = (q.shape[-1], k.shape[-1], v.shape[-1])
        q, k, v = _pad_qkv_head_dim(q, k, v)

        out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = (
            _flash_attn_forward(
                q,
                k,
                v,
                dropout_p,
                softmax_scale,
                causal=causal,
                window_size_left=int(window_size[0]),
                window_size_right=int(window_size[1]),
                bias=bias,
                alibi_slopes=alibi_slopes,
                return_lse=return_lse,
                return_softmax=return_softmax and dropout_p > 0.0,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                sink=sink,
                config=config,
            )
        )
        if is_grad:
            ctx.save_for_backward(
                q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, sink
            )
            ctx.max_seqlen_q = max_seqlen_q
            ctx.max_seqlen_k = max_seqlen_k
            ctx.philox_seed = philox_seed
            ctx.philox_offset = philox_offset
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.bias = bias
            ctx.alibi_slopes = alibi_slopes
            ctx.head_sizes_og = head_sizes_og

        # The output carries v's head dim; drop any padding added above.
        out = out_padded[..., : head_sizes_og[2]]
        result = [out]
        if return_lse:
            result.append(softmax_lse)
        if return_softmax:
            result.append(S_dmask)

        return result[0] if len(result) == 1 else tuple(result)

    @staticmethod
    def backward(ctx, do, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, sink = ctx.saved_tensors
        dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v)
        bias = ctx.bias
        dbias = torch.empty_like(bias) if bias is not None else None
        # dsink is accumulated in float32.
        dsink = (
            torch.zeros_like(sink, dtype=torch.float32) if sink is not None else None
        )
        # Match the (possibly padded) head dim of the saved output.
        do_padded = _pad_last_dim_to_multiple_of_8(do)

        if _USE_FUSED_BWD_KERNEL:
            assert (
                sink is None and dsink is None
            ), "Fused backward doesn't support sinks."
            flash_attn_fused_backward(
                do_padded,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                dbias,
                ctx.softmax_scale,
                ctx.alibi_slopes,
                ctx.causal,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q=ctx.max_seqlen_q,
                max_seqlen_k=ctx.max_seqlen_k,
                dropout_p=ctx.dropout_p,
                philox_seed=ctx.philox_seed,
                philox_offset=ctx.philox_offset,
                USE_INT64_STRIDES=_USE_INT64_STRIDES,
            )
        else:
            flash_attn_onekernel_backward(
                do_padded,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                dbias,
                ctx.softmax_scale,
                ctx.alibi_slopes,
                ctx.causal,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q=ctx.max_seqlen_q,
                max_seqlen_k=ctx.max_seqlen_k,
                dropout_p=ctx.dropout_p,
                philox_seed=ctx.philox_seed,
                philox_offset=ctx.philox_offset,
                USE_INT64_STRIDES=_USE_INT64_STRIDES,
                sink=sink,
                dsink=dsink,
            )

        # The saved q/k/v may have been padded along the head dim; slice back to
        # the caller's original sizes so gradient shapes match the inputs.
        head_size_q_og, head_size_k_og, head_size_v_og = ctx.head_sizes_og
        dq = dq[..., :head_size_q_og]
        dk = dk[..., :head_size_k_og]
        dv = dv[..., :head_size_v_og]
        return (
            dq,
            dk,
            dv,
            None,  # cu_seqlens_q,
            None,  # cu_seqlens_k
            None,  # max_seqlen_q
            None,  # max_seqlen_k
            None,  # dropout_p
            None,  # softmax_scale
            None,  # causal
            None,  # window_size
            dbias,
            None,  # alibi_slopes
            None,  # deterministic
            None,  # return_lse
            None,  # return_softmax
            None,  # block_table
            None,  # out
            dsink,
            None,  # is_grad_enabled
            None,  # config
        )


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    bias=None,
    alibi_slopes=None,
    deterministic=False,
    return_lse=False,
    return_attn_probs=False,
    block_table=None,
    out=None,
    sink=None,
    config: Optional[dict[str, Any]] = None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim_v), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). Only (-1, -1) is supported by this backend; any other value
            raises ValueError.
        bias: must be None; a bias tensor raises ValueError in this backend.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Accepted for API compatibility; this backend's backward pass
            does not currently consult it.
        return_lse: bool. Whether to also return the softmax log-sum-exp.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        block_table: accepted for API parity; currently unused.
        out: accepted for API parity; currently unused (a new output tensor is allocated).
        sink: (nheads,), attention sink scores (one per Q head), or None
        config: optional dict of kernel launch parameters; chosen automatically when None.
    Return:
        out: (total_q, nheads, headdim_v).
        softmax_lse [optional, if return_lse=True]: (total_q, nheads), float32. The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen),
            or None when dropout_p == 0. The output of softmax (possibly with different
            scaling). It also encodes the dropout pattern (negative means that location was
            dropped, nonnegative means it was kept).
        The values are returned in the order listed; a bare tensor is returned when only
        `out` is requested.
    """
    _LOGGER.info(
        f"FLASH_ATTN_VARLEN: q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}"
    )
    return _FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        bias,
        alibi_slopes,
        deterministic,
        return_lse,
        return_attn_probs,
        block_table,
        out,
        sink,
        torch.is_grad_enabled(),
        config,
    )


def flash_attn_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    cache_seqlens: Optional[Union[torch.Tensor, int]] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = True,
    window_size: tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    num_splits: int = 0,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    rotary_interleaved: bool = True,
    return_softmax_lse: bool = False,
):
    """
    This mirrors the public flash_attn v2 interface for KV cache using the AMD Triton backend.

    Args:
        q: (batch, seqlen_q, nheads_q, headdim)
        k_cache / v_cache: Either contiguous (batch_cache, seqlen_cache, nheads_k, headdim) or
            paged (num_blocks, page_block_size, nheads_k, headdim) when block_table is provided.
        k, v: Optional incremental tokens to append in-place (appended logically after existing cache).
        cache_seqlens: int or (batch,) current valid lengths per batch entry. An int is broadcast
            to an int32 tensor with one entry per query batch element.
        softmax_scale: Optional override; defaults to 1/sqrt(headdim).
        causal: Apply causal masking.
        window_size: (left, right) local attention window; (-1, -1) = full.
        softcap: must be 0.0; any other value raises NotImplementedError.
        num_splits: must be 0 or 1; any other value raises NotImplementedError.
        rotary_cos / rotary_sin: Optional rotary embedding tables, forwarded to the backend.
        cache_batch_idx / cache_leftpad: Optional indexing / left padding metadata, forwarded
            to the backend.
        block_table: Optional paging table mapping logical blocks for paged KV cache.
        alibi_slopes: (nheads,) or (batch, nheads) bias slopes, forwarded to the backend
            (NOTE(review): backend support for this is not verified here).
        rotary_interleaved: Forwarded to the backend unchanged.
        return_softmax_lse: If True returns (out, softmax_lse) else out.

    Returns:
        out: (batch, seqlen_q, nheads_q, headdim), plus softmax_lse when
        return_softmax_lse is True.

    Raises:
        NotImplementedError: for softcap != 0 or num_splits not in (0, 1).
        AssertionError: if q, k_cache or v_cache is not contiguous in the last dim.
    """
    # Feature guards / normalization
    if softcap != 0.0:
        raise NotImplementedError(
            "softcap != 0 not supported in v2 KV cache backend yet"
        )
    if num_splits not in (0, 1):
        raise NotImplementedError(
            "num_splits > 1 not supported in v2 KV cache backend yet"
        )

    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    if isinstance(cache_seqlens, int):
        # One length per query batch entry. Size by q's batch: for a paged
        # cache, k_cache.shape[0] is the number of blocks, not the batch size.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )

    # Contiguity (align last dim contiguous requirement similar to v3 path assumptions)
    assert (
        q.stride(-1) == 1 and k_cache.stride(-1) == 1 and v_cache.stride(-1) == 1
    ), "q, k_cache and v_cache must be contiguous in the last dimension"

    out, softmax_lse = flash_attn_2.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        cache_leftpad,
        block_table,
        alibi_slopes,
        None,  # out tensor
        softmax_scale,
        causal,
        int(window_size[0]),
        int(window_size[1]),
        0.0,  # softcap (guarded)
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out