| |
| |
|
|
from typing import Any, Optional, Tuple, Union

import torch
import triton
import triton.language as tl

import aiter.ops.triton.utils.types as types
from aiter.ops.triton.attention.mha_onekernel_bwd import flash_attn_onekernel_backward
from aiter.ops.triton.attention.mha_fused_bwd import flash_attn_fused_backward
from aiter.ops.triton.utils.logger import AiterTritonLogger
from aiter.ops.triton.utils.device_info import get_num_xcds
from kernel_jit import _attn_fwd, _get_config
from aiter.ops.triton._triton_kernels.flash_attn_triton_amd import flash_attn_2
|
|
# Module-wide logger; the public entry points below log their input shapes through it.
_LOGGER = AiterTritonLogger()
|
|
# Backward-kernel selector read by the autograd backward passes in this module:
# True -> flash_attn_fused_backward, False -> flash_attn_onekernel_backward.
# Change it through mha_set_use_fused_bwd_kernel(). (A `global` statement is a
# no-op at module scope, so none is needed here.)
_USE_FUSED_BWD_KERNEL = False
|
|
|
|
def mha_set_use_fused_bwd_kernel(value: bool):
    """Choose the backward implementation used by the MHA autograd functions.

    Args:
        value: ``True`` selects the fused backward kernel (with atomics);
            ``False`` selects the one-kernel backward (without atomics).

    Note:
        Fused backward is faster but doesn't support positional encoding.
        The backward passes in this module also assert that no attention
        sink is in use when the fused path is selected.
    """
    global _USE_FUSED_BWD_KERNEL
    _USE_FUSED_BWD_KERNEL = value
|
|
|
|
# Forwarded as USE_INT64_STRIDES to the forward kernel launch and to both
# backward entry points; toggle it with mha_set_use_int64_strides().
_USE_INT64_STRIDES = True
|
|
|
|
def mha_set_use_int64_strides(value: bool):
    """Enable or disable 64-bit integer strides for the attention kernels.

    64-bit strides prevent integer overflows with very large tensors.

    Args:
        value: New setting; stored in the module-level ``_USE_INT64_STRIDES``
            flag that the forward and backward launches read.
    """
    global _USE_INT64_STRIDES
    _USE_INT64_STRIDES = value
|
|
|
|
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    bias: Optional[torch.Tensor],
    alibi_slopes: Optional[torch.Tensor],
    return_lse: bool,
    return_softmax: bool,
    max_seqlen_q: int,
    max_seqlen_k: int,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    descale_v: Optional[torch.Tensor] = None,
    sink: Optional[torch.Tensor] = None,
    config: Optional[dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int, int]:
    """Allocate the output buffers and launch the ``_attn_fwd`` Triton kernel.

    Two input layouts are handled, selected by whether ``cu_seqlens_q`` is
    given:

    * dense: ``q`` is 4-D and its dims are read as (batch, seqlen, heads,
      head_dim);
    * variable-length: ``q`` is 3-D and its dims are read as (total_tokens,
      heads, head_dim); the batch size is ``len(cu_seqlens_q) - 1``.

    Args:
        q, k, v: Query, key and value tensors in one of the layouts above.
            The last dim of ``q`` may exceed the last dim of ``v``; the
            difference is treated as the positional-encoding head size.
        dropout_p: Dropout probability; dropout is enabled when ``> 0``.
        softmax_scale: Forwarded unchanged to the kernel.
        causal: Forwarded to the kernel as ``IS_CAUSAL``.
        window_size_left, window_size_right: Must both be ``-1``.
        bias: Must be ``None``.
        alibi_slopes: Optional. Its ``stride(0)`` and ``stride(1)`` are both
            passed to the kernel, so a 2-D tensor is required when given.
        return_lse: Accepted for interface parity; not read here (the
            log-sum-exp buffer is always allocated and returned).
        return_softmax: Forwarded as ``RETURN_SCORES``; also forces
            allocation of the score and dropout-mask buffers.
        max_seqlen_q, max_seqlen_k: Forwarded as ``SEQLEN_Q`` / ``SEQLEN_K``
            and used to size the LSE (dense) and score buffers.
        cu_seqlens_q, cu_seqlens_k: Forwarded to the kernel; ``cu_seqlens_q``
            additionally selects the variable-length layout.
        descale_q, descale_k, descale_v: Optional tensors forwarded together
            with their ``stride(0)`` (presumably FP8 descale factors --
            TODO confirm against the kernel).
        sink: Optional 1-D tensor with one element per query head.
        config: Keyword arguments splatted into the kernel launch. When
            ``None``, ``_get_config`` chooses them from the dropout flag,
            ``q.dtype`` and whether a positional-encoding head is present.

    Returns:
        ``(o, softmax_lse, s_dmask, philox_seed, philox_offset)`` where

        * ``o`` has shape ``q.shape[:-1] + v.shape[-1:]`` (float32 for FP8
          inputs, otherwise ``q.dtype``);
        * ``softmax_lse`` is float32 with shape (total_tokens, heads) in the
          variable-length layout or (batch, heads, max_seqlen_q) otherwise;
        * ``s_dmask`` is a float32 (batch, heads, max_seqlen_q, max_seqlen_k)
          tensor when ``return_softmax`` or dropout is on, else ``None``;
        * ``philox_seed`` / ``philox_offset`` are random ints when dropout
          is on, else ``0``.

    Raises:
        ValueError: If ``bias`` is given or a sliding window is requested.
        AssertionError: If the positional-encoding head-size constraints or
            the ``sink`` shape constraint are violated.
    """
    if bias is not None:
        raise ValueError("Bias is not supported yet in the Triton Backend")
    if window_size_left != -1 or window_size_right != -1:
        raise ValueError("Sliding Window is not supported yet in the Triton Backend")

    IS_FP8 = types._is_fp8(q)
    FP8_MAX = torch.finfo(q.dtype).max
    is_varlen = cu_seqlens_q is not None

    # FP8 inputs get a float32 output buffer; otherwise the output keeps q's dtype.
    o = torch.empty(
        (q.shape[:-1] + v.shape[-1:]),
        dtype=torch.float32 if IS_FP8 else q.dtype,
        device=q.device,
    )

    # Every stride tuple handed to the kernel is ordered (batch, head, seq, head_dim).
    if is_varlen:
        # Packed layout has no batch dim, so the batch stride is 0.
        batch = len(cu_seqlens_q) - 1
        seqlen_q = max_seqlen_q
        num_q_heads = q.shape[1]
        num_k_heads = k.shape[1]
        q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
        k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
        v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
        o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
    else:
        batch, seqlen_q, num_q_heads = q.shape[:-1]
        num_k_heads = k.shape[2]
        q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
        k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
        v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
        o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))

    qk_head_dim = q.shape[-1]
    v_head_dim = v.shape[-1]
    # Whatever q/k carry beyond v's head size is the positional-encoding part.
    pe_head_dim = qk_head_dim - v_head_dim

    BLOCK_DMODEL_POW2 = max(triton.next_power_of_2(v_head_dim), 16)
    BLOCK_DMODEL_PE_POW2 = (
        0 if pe_head_dim == 0 else max(triton.next_power_of_2(pe_head_dim), 16)
    )
    assert (pe_head_dim == 0 and BLOCK_DMODEL_PE_POW2 == 0) or (
        v_head_dim == BLOCK_DMODEL_POW2 and pe_head_dim == BLOCK_DMODEL_PE_POW2
    ), "Positional encoding support requires NOPE and PE head sizes to be unpadded powers of 2."
    assert (
        not IS_FP8 or pe_head_dim == 0
    ), "Positional encoding doesn't support FP8."

    assert sink is None or (
        sink.dim() == 1 and sink.shape[0] == num_q_heads
    ), "Sink must be 1D and have one element per query head."

    # Log-sum-exp buffer; strides are passed in (batch, head, seq) order.
    if is_varlen:
        softmax_lse = torch.zeros(
            (q.shape[0], num_q_heads), device=q.device, dtype=torch.float32
        )
        stride_lse_z, stride_lse_h, stride_lse_m = (
            0,
            softmax_lse.stride(1),
            softmax_lse.stride(0),
        )
    else:
        softmax_lse = torch.zeros(
            (batch, num_q_heads, max_seqlen_q), device=q.device, dtype=torch.float32
        )
        stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride()

    # Dropout RNG state is drawn on the host and returned so the backward
    # pass can be given the same seed/offset.
    enable_dropout = dropout_p > 0.0
    if enable_dropout:
        philox_seed = torch.randint(0, 0xFFFFFF, (1,)).item()
        philox_offset = torch.randint(0, 0xFFFFFF, (1,)).item()
    else:
        philox_seed = 0
        philox_offset = 0

    if return_softmax or enable_dropout:
        scores_shape = (batch, num_q_heads, max_seqlen_q, max_seqlen_k)
        s_dmask = torch.zeros(scores_shape, device=q.device, dtype=torch.float32)
        dropout_mask = torch.zeros(
            scores_shape, device=q.device, dtype=torch.float32
        )
    else:
        s_dmask = None
        dropout_mask = None

    if config is None:
        config = _get_config(enable_dropout, q.dtype, has_pe=pe_head_dim > 0)

    def grid(META):
        # 1-D launch: one program per (batch, query head, BLOCK_M query tile).
        return (batch * num_q_heads * triton.cdiv(seqlen_q, META["BLOCK_M"]),)

    # softmax_lse and s_dmask strides: the LSE buffer always exists, the
    # score buffer may be None (strides then passed as 0).
    _attn_fwd[grid](
        q,
        k,
        v,
        descale_q,
        descale_k,
        descale_v,
        o,
        alibi_slopes,
        s_dmask,
        dropout_mask,
        softmax_lse,
        sink,
        *q_strides,
        *k_strides,
        *v_strides,
        descale_q.stride(0) if descale_q is not None else 0,
        descale_k.stride(0) if descale_k is not None else 0,
        descale_v.stride(0) if descale_v is not None else 0,
        *o_strides,
        alibi_slopes.stride(0) if alibi_slopes is not None else 0,
        alibi_slopes.stride(1) if alibi_slopes is not None else 0,
        s_dmask.stride(0) if s_dmask is not None else 0,
        s_dmask.stride(1) if s_dmask is not None else 0,
        s_dmask.stride(2) if s_dmask is not None else 0,
        s_dmask.stride(3) if s_dmask is not None else 0,
        stride_lse_z,
        stride_lse_h,
        stride_lse_m,
        softmax_scale,
        cu_seqlens_q,
        cu_seqlens_k,
        dropout_p,
        philox_seed,
        philox_offset,
        SEQLEN_Q=max_seqlen_q,
        SEQLEN_K=max_seqlen_k,
        IS_CAUSAL=causal,
        NUM_Q_HEADS=num_q_heads,
        NUM_K_HEADS=num_k_heads,
        BLOCK_DMODEL=v_head_dim,
        BLOCK_DMODEL_POW2=BLOCK_DMODEL_POW2,
        BLOCK_DMODEL_PE=pe_head_dim,
        RETURN_SCORES=return_softmax,
        ENABLE_DROPOUT=enable_dropout,
        IS_FP8=IS_FP8,
        FP8_MAX=FP8_MAX,
        VARLEN=is_varlen,
        BATCH=batch,
        NUM_XCD=get_num_xcds(),
        USE_INT64_STRIDES=_USE_INT64_STRIDES,
        ENABLE_SINK=sink is not None,
        **config,
    )

    return o, softmax_lse, s_dmask, philox_seed, philox_offset
|
|
|
|
class _FlashAttnFunc(torch.autograd.Function):
    """Autograd function for dense (batch, seqlen, heads, head_dim) attention.

    ``forward`` pads the head dimension of q/k/v up to a multiple of 8 (the
    pad amount is derived from q's head size), runs ``_flash_attn_forward``
    and slices the output back. ``backward`` dispatches to the fused or
    one-kernel backward depending on ``_USE_FUSED_BWD_KERNEL`` and returns
    gradients sliced to the *unpadded* head sizes of the original inputs.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        bias,
        alibi_slopes,
        deterministic,
        return_lse,
        return_softmax,
        sink,
        is_grad_enabled,
        config=None,
    ):
        """Run the forward kernel; save state for backward when grads are needed.

        Returns ``out`` alone, or a tuple ``(out[, softmax_lse][, S_dmask])``
        depending on ``return_lse`` / ``return_softmax``. ``S_dmask`` is only
        materialised when ``dropout_p > 0``; otherwise ``None`` is appended.
        """
        is_grad = is_grad_enabled and any(
            x is not None and x.requires_grad for x in [q, k, v, sink]
        )
        if softmax_scale is None:
            # Default scale uses the unpadded head size.
            softmax_scale = q.shape[-1] ** (-0.5)
        head_size_og = q.size(3)
        # Remember the caller-visible head sizes: q/k/v are rebound to padded
        # copies below, and backward must return grads in the original shapes.
        head_sizes_og = (q.shape[-1], k.shape[-1], v.shape[-1])
        if head_size_og % 8 != 0:
            pad = [0, 8 - head_size_og % 8]
            q = torch.nn.functional.pad(q, pad)
            k = torch.nn.functional.pad(k, pad)
            v = torch.nn.functional.pad(v, pad)
        out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = (
            _flash_attn_forward(
                q,
                k,
                v,
                dropout_p,
                softmax_scale,
                causal=causal,
                window_size_left=int(window_size[0]),
                window_size_right=int(window_size[1]),
                bias=bias,
                alibi_slopes=alibi_slopes,
                return_lse=return_lse,
                return_softmax=return_softmax and dropout_p > 0,
                max_seqlen_q=q.shape[1],
                max_seqlen_k=k.shape[1],
                sink=sink,
                config=config,
            )
        )

        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, sink)
            ctx.head_sizes_og = head_sizes_og
            ctx.philox_seed = philox_seed
            ctx.philox_offset = philox_offset
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.bias = bias
            ctx.window_size = window_size
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic

        out = out_padded[..., :head_size_og]
        result = [out]
        if return_lse:
            result.append(softmax_lse)
        if return_softmax:
            result.append(S_dmask)

        return result[0] if len(result) == 1 else tuple(result)

    @staticmethod
    def backward(ctx, do, *args):
        """Compute dq/dk/dv (and dsink) for the saved forward.

        Returns one entry per ``forward`` argument after ``ctx``; only q, k,
        v, bias and sink receive (possibly ``None``) gradients.
        """
        # Saved q/k/v are the padded copies made in forward.
        q, k, v, out, softmax_lse, sink = ctx.saved_tensors
        bias = ctx.bias
        dbias = torch.empty_like(bias) if bias is not None else None
        dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v)
        dsink = (
            torch.zeros_like(sink, dtype=torch.float32) if sink is not None else None
        )
        do_head_size = do.size(3)
        do_padded = do
        if do_head_size % 8 != 0:
            do_padded = torch.nn.functional.pad(do, [0, 8 - do_head_size % 8])

        if _USE_FUSED_BWD_KERNEL:
            assert (
                sink is None and dsink is None
            ), "Fused backward doesn't support sinks."
            flash_attn_fused_backward(
                do_padded,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                dbias,
                ctx.softmax_scale,
                ctx.alibi_slopes,
                ctx.causal,
                None,
                None,
                max_seqlen_q=q.shape[1],
                max_seqlen_k=k.shape[1],
                dropout_p=ctx.dropout_p,
                philox_seed=ctx.philox_seed,
                philox_offset=ctx.philox_offset,
                USE_INT64_STRIDES=_USE_INT64_STRIDES,
            )
        else:
            flash_attn_onekernel_backward(
                do_padded,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                dbias,
                ctx.softmax_scale,
                ctx.alibi_slopes,
                ctx.causal,
                None,
                None,
                max_seqlen_q=q.shape[1],
                max_seqlen_k=k.shape[1],
                dropout_p=ctx.dropout_p,
                philox_seed=ctx.philox_seed,
                philox_offset=ctx.philox_offset,
                USE_INT64_STRIDES=_USE_INT64_STRIDES,
                sink=sink,
                dsink=dsink,
            )

        # Drop the columns that belong to the head-dim padding. Slicing by the
        # saved (padded) tensors' own last dim would be a no-op and hand
        # autograd gradients whose shape differs from the original inputs.
        head_size_q_og, head_size_k_og, head_size_v_og = ctx.head_sizes_og
        dq = dq[..., :head_size_q_og]
        dk = dk[..., :head_size_k_og]
        dv = dv[..., :head_size_v_og]
        return (
            dq,
            dk,
            dv,
            None,  # dropout_p
            None,  # softmax_scale
            None,  # causal
            None,  # window_size
            dbias,
            None,  # alibi_slopes
            None,  # deterministic
            None,  # return_lse
            None,  # return_softmax
            dsink,
            None,  # is_grad_enabled
            None,  # config
        )
|
|
|
|
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    bias=None,
    alibi_slopes=None,
    deterministic=True,
    return_lse=False,
    return_attn_probs=False,
    sink=None,
    config: Optional[dict[str, Any]] = None,
):
    """Flash attention over dense (padded) batches via the Triton backend.

    dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim_q).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). Must be (-1, -1); the Triton forward raises
            ValueError for any other value (sliding window is not supported yet).
        bias: Must be None; the Triton forward raises ValueError otherwise.
        alibi_slopes: fp32 ALiBi slopes. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
            NOTE(review): the forward launch reads both stride(0) and stride(1)
            of this tensor, so pass the 2-D (batch_size, nheads) form.
        deterministic: bool. Stored on the autograd context; not otherwise read
            in this module.
        return_lse: bool. Whether to also return the softmax log-sum-exp.
        return_attn_probs: bool. Whether to also return S_dmask. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        sink: (nheads,), attention sink scores (one per Q head), or None
        config: Optional dict of kernel launch parameters forwarded to the forward
            kernel; when None a default is chosen by the backend.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen)
            when dropout_p > 0, otherwise None.
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
        When neither optional output is requested, ``out`` is returned bare
        (not wrapped in a tuple).
    """
    _LOGGER.info(
        f"FLASH_ATTN: q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}"
    )
    # autograd.Function.apply takes positional arguments only; the grad-mode
    # flag is captured here and handed to forward as an explicit argument.
    return _FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        bias,
        alibi_slopes,
        deterministic,
        return_lse,
        return_attn_probs,
        sink,
        torch.is_grad_enabled(),
        config,
    )
|
|
|
|
class _FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length (packed) attention.

    Inputs are packed as (total_tokens, heads, head_dim) with per-sequence
    boundaries given by ``cu_seqlens_q`` / ``cu_seqlens_k``. ``forward`` pads
    the head dimension of q/k/v up to a multiple of 8 (the pad amount is
    derived from q's head size) before calling ``_flash_attn_forward``;
    ``backward`` returns gradients sliced to the *unpadded* head sizes of
    the original inputs.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        bias,
        alibi_slopes,
        deterministic,
        return_lse,
        return_softmax,
        block_table,
        out,
        sink,
        is_grad_enabled,
        config=None,
    ):
        """Run the forward kernel; save state for backward when grads are needed.

        ``deterministic``, ``block_table`` and the incoming ``out`` argument
        are accepted but not read by this method. Returns ``out`` alone, or a
        tuple ``(out[, softmax_lse][, S_dmask])`` depending on
        ``return_lse`` / ``return_softmax``.
        """
        is_grad = is_grad_enabled and any(
            x is not None and x.requires_grad for x in [q, k, v, sink]
        )
        if softmax_scale is None:
            # Default scale uses the unpadded head size.
            softmax_scale = q.shape[-1] ** (-0.5)
        head_size_og = q.size(2)
        # Remember the caller-visible head sizes: q/k/v are rebound to padded
        # copies below, and backward must return grads in the original shapes.
        head_sizes_og = (q.shape[-1], k.shape[-1], v.shape[-1])
        if head_size_og % 8 != 0:
            pad = [0, 8 - head_size_og % 8]
            q = torch.nn.functional.pad(q, pad)
            k = torch.nn.functional.pad(k, pad)
            v = torch.nn.functional.pad(v, pad)
        out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = (
            _flash_attn_forward(
                q,
                k,
                v,
                dropout_p,
                softmax_scale,
                causal=causal,
                window_size_left=int(window_size[0]),
                window_size_right=int(window_size[1]),
                bias=bias,
                alibi_slopes=alibi_slopes,
                return_lse=return_lse,
                return_softmax=return_softmax and dropout_p > 0.0,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                sink=sink,
                config=config,
            )
        )
        if is_grad:
            ctx.save_for_backward(
                q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, sink
            )
            ctx.head_sizes_og = head_sizes_og
            ctx.max_seqlen_q = max_seqlen_q
            ctx.max_seqlen_k = max_seqlen_k
            ctx.philox_seed = philox_seed
            ctx.philox_offset = philox_offset
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.bias = bias
            ctx.alibi_slopes = alibi_slopes
        out = out_padded[..., :head_size_og]

        result = [out]
        if return_lse:
            result.append(softmax_lse)
        if return_softmax:
            result.append(S_dmask)

        return result[0] if len(result) == 1 else tuple(result)

    @staticmethod
    def backward(ctx, do, *args):
        """Compute dq/dk/dv (and dsink) for the saved forward.

        Returns one entry per ``forward`` argument after ``ctx``; only q, k,
        v, bias and sink receive (possibly ``None``) gradients.
        """
        # Saved q/k/v are the padded copies made in forward.
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, sink = ctx.saved_tensors
        dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v)
        bias = ctx.bias
        dbias = torch.empty_like(bias) if bias is not None else None
        dsink = (
            torch.zeros_like(sink, dtype=torch.float32) if sink is not None else None
        )
        do_head_size = do.size(2)
        do_padded = do
        if do_head_size % 8 != 0:
            do_padded = torch.nn.functional.pad(do, [0, 8 - do_head_size % 8])

        if _USE_FUSED_BWD_KERNEL:
            assert (
                sink is None and dsink is None
            ), "Fused backward doesn't support sinks."
            flash_attn_fused_backward(
                do_padded,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                dbias,
                ctx.softmax_scale,
                ctx.alibi_slopes,
                ctx.causal,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q=ctx.max_seqlen_q,
                max_seqlen_k=ctx.max_seqlen_k,
                dropout_p=ctx.dropout_p,
                philox_seed=ctx.philox_seed,
                philox_offset=ctx.philox_offset,
                USE_INT64_STRIDES=_USE_INT64_STRIDES,
            )
        else:
            flash_attn_onekernel_backward(
                do_padded,
                q,
                k,
                v,
                out,
                softmax_lse,
                dq,
                dk,
                dv,
                dbias,
                ctx.softmax_scale,
                ctx.alibi_slopes,
                ctx.causal,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q=ctx.max_seqlen_q,
                max_seqlen_k=ctx.max_seqlen_k,
                dropout_p=ctx.dropout_p,
                philox_seed=ctx.philox_seed,
                philox_offset=ctx.philox_offset,
                USE_INT64_STRIDES=_USE_INT64_STRIDES,
                sink=sink,
                dsink=dsink,
            )

        # Drop the columns that belong to the head-dim padding. Slicing by the
        # saved (padded) tensors' own last dim would be a no-op and hand
        # autograd gradients whose shape differs from the original inputs.
        head_size_q_og, head_size_k_og, head_size_v_og = ctx.head_sizes_og
        dq = dq[..., :head_size_q_og]
        dk = dk[..., :head_size_k_og]
        dv = dv[..., :head_size_v_og]
        return (
            dq,
            dk,
            dv,
            None,  # cu_seqlens_q
            None,  # cu_seqlens_k
            None,  # max_seqlen_q
            None,  # max_seqlen_k
            None,  # dropout_p
            None,  # softmax_scale
            None,  # causal
            None,  # window_size
            dbias,
            None,  # alibi_slopes
            None,  # deterministic
            None,  # return_lse
            None,  # return_softmax
            None,  # block_table
            None,  # out
            dsink,
            None,  # is_grad_enabled
            None,  # config
        )
|
|
|
|
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    bias=None,
    alibi_slopes=None,
    deterministic=False,
    return_lse=False,
    return_attn_probs=False,
    block_table=None,
    out=None,
    sink=None,
    config: Optional[dict[str, Any]] = None,
):
    """Flash attention over variable-length (packed) batches via the Triton backend.

    dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). Must be (-1, -1); the Triton forward raises
            ValueError for any other value (sliding window is not supported yet).
        bias: Must be None; the Triton forward raises ValueError otherwise.
        alibi_slopes: fp32 ALiBi slopes. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
            NOTE(review): the forward launch reads both stride(0) and stride(1)
            of this tensor, so pass the 2-D (batch_size, nheads) form.
        deterministic: bool. Accepted for interface parity; not read by the
            varlen autograd function in this module.
        return_lse: bool. Whether to also return the softmax log-sum-exp.
        return_attn_probs: bool. Whether to also return S_dmask. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        block_table: Accepted for interface parity; not read by the varlen
            autograd function in this module.
        out: Accepted for interface parity; not read by the varlen autograd
            function in this module (a fresh output tensor is allocated).
        sink: (nheads,), attention sink scores (one per Q head), or None
        config: Optional dict of kernel launch parameters forwarded to the forward
            kernel; when None a default is chosen by the backend.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_lse=True]: (total_q, nheads). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen)
            when dropout_p > 0, otherwise None.
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
        When neither optional output is requested, ``out`` is returned bare
        (not wrapped in a tuple).
    """
    _LOGGER.info(
        f"FLASH_ATTN_VARLEN: q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}"
    )
    # autograd.Function.apply takes positional arguments only; the grad-mode
    # flag is captured here and handed to forward as an explicit argument.
    return _FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        bias,
        alibi_slopes,
        deterministic,
        return_lse,
        return_attn_probs,
        block_table,
        out,
        sink,
        torch.is_grad_enabled(),
        config,
    )
|
|
|
|
def flash_attn_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    cache_seqlens: Optional[Union[torch.Tensor, int]] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = True,
    window_size: tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    num_splits: int = 0,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    rotary_interleaved: bool = True,
    return_softmax_lse: bool = False,
):
    """KV-cache attention that mirrors the public flash_attn v2 interface.

    Validates the options this backend cannot honour, normalises the
    defaults, and delegates to ``flash_attn_2.fwd_kvcache``.

    Args:
        q: Query tensor; its last dimension is the head dimension and must
            have stride 1. (Per the mirrored interface the shape is
            (batch, seqlen_q, nheads_q, headdim).)
        k_cache, v_cache: Cache tensors; last-dimension stride must be 1.
            ``k_cache.shape[0]`` is used as the number of entries when an
            integer ``cache_seqlens`` is expanded.
        k, v: Optional new key/value tensors, forwarded to the backend.
        cache_seqlens: Per-entry cache lengths as a tensor, or a single int
            that is broadcast to an int32 tensor on ``k_cache``'s device.
        softmax_scale: Optional override; defaults to ``q.shape[-1] ** -0.5``.
        causal: Forwarded to the backend.
        window_size: (left, right); both entries are cast to int and forwarded.
        softcap: Must be 0.0.
        num_splits: Must be 0 or 1.
        rotary_cos, rotary_sin: Optional rotary tables, forwarded.
        cache_batch_idx, cache_leftpad: Optional metadata, forwarded.
        block_table: Optional paging table, forwarded.
        alibi_slopes: Optional slopes, forwarded.
        rotary_interleaved: Forwarded to the backend as given.
        return_softmax_lse: If True return ``(out, softmax_lse)``, else ``out``.

    Returns:
        ``out``, or ``(out, softmax_lse)`` when ``return_softmax_lse`` is set.

    Raises:
        NotImplementedError: If ``softcap != 0`` or ``num_splits`` is not 0/1.
        AssertionError: If q, k_cache or v_cache is not contiguous in its
            last dimension.
    """
    # Reject options the backend does not implement before doing any work.
    if softcap != 0.0:
        raise NotImplementedError(
            "softcap != 0 not supported in v2 KV cache backend yet"
        )
    if num_splits not in (0, 1):
        raise NotImplementedError(
            "num_splits > 1 not supported in v2 KV cache backend yet"
        )

    scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale

    # A scalar length applies to every cache entry.
    seqlens = cache_seqlens
    if isinstance(seqlens, int):
        seqlens = torch.full(
            (k_cache.shape[0],), seqlens, dtype=torch.int32, device=k_cache.device
        )

    assert all(t.stride(-1) == 1 for t in (q, k_cache, v_cache))

    left_window = int(window_size[0])
    right_window = int(window_size[1])

    attn_out, lse = flash_attn_2.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        cache_leftpad,
        block_table,
        alibi_slopes,
        None,  # presumably the optional preallocated output -- TODO confirm
        scale,
        causal,
        left_window,
        right_window,
        0.0,  # presumably softcap (validated to be 0 above) -- TODO confirm
        rotary_interleaved,
        num_splits,
    )
    if return_softmax_lse:
        return (attn_out, lse)
    return attn_out
|
|