diff --git "a/build/torch-cuda/flash_fwd.py" "b/build/torch-cuda/flash_fwd.py" new file mode 100644--- /dev/null +++ "b/build/torch-cuda/flash_fwd.py" @@ -0,0 +1,2426 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# A reimplementation of +# https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm80.h +# and https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm90.h +# from Cutlass C++ to Cute-DSL. +# Built on Cute-DSL example: https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py + +import math +from types import SimpleNamespace +from typing import Type, Callable, Optional, List +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Constexpr, Float32, Int32, const_expr, Boolean +from cutlass.cute.nvgpu import cpasync, warp, warpgroup +import cutlass.utils as utils_basic +from cutlass.utils import LayoutEnum +import cutlass.utils.hopper_helpers as sm90_utils_basic + +from .quack import copy_utils +from .quack import layout_utils +from .quack import sm90_utils + +from . import ampere_helpers as sm80_utils +from .cute_dsl_utils import assume_tensor_aligned +from . import utils +from .mask import AttentionMask +from .softmax import Softmax, apply_score_mod_inner +from .seqlen_info import SeqlenInfoQK +from .block_info import BlockInfo +from .block_sparsity import BlockSparseTensors +from .block_sparse_utils import ( + produce_block_sparse_loads, + consume_block_sparse_loads, +) +from . import pipeline +from .pack_gqa import PackGQA +from .named_barrier import NamedBarrierFwd +from .quack.cute_dsl_utils import ParamsBase +from .tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, +) +from cutlass.cute import FastDivmodDivisor + + +class FlashAttentionForwardBase: + arch: int = 80 + + def __init__( + self, + dtype: Type[cutlass.Numeric], + head_dim: int, + head_dim_v: Optional[int] = None, + qhead_per_kvhead: int = 1, + is_causal: bool = False, + is_local: bool = False, + pack_gqa: bool = True, + tile_m: int = 128, + tile_n: int = 128, + num_stages: int = 1, + num_threads: int = 128, + Q_in_regs: bool = False, + score_mod: Optional[cutlass.Constexpr] = None, + mask_mod: Optional[cutlass.Constexpr] = None, + has_aux_tensors: bool = False, + q_subtile_factor: int | None = None, + ): + """Initializes the configuration for a flash attention kernel. + + All contiguous dimensions must be at least 16 bytes aligned, which means that the head dimension + should be a multiple of 8. + + :param head_dim: head dimension + :type head_dim: int + :param tile_m: m block size + :type tile_m: int + :param tile_n: n block size + :type tile_n: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + :param score_mod: A callable that takes the attention scores and applies a modification. + Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Any`` + :param mask_mod: A callable that takes the attention scores and returns a boolean representing whether that score should be masked. 
+ Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Boolean`` + """ + self.dtype = dtype + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + # Can save registers (and hence be faster) if we don't have to check hdim predication + self.check_hdim_oob = head_dim != self.tile_hdim + self.check_hdim_v_oob = head_dim_v != self.tile_hdimv + self.qhead_per_kvhead = qhead_per_kvhead + self.is_causal = is_causal + self.is_local = is_local + self.pack_gqa = pack_gqa + self.tile_m = tile_m + self.tile_n = tile_n + self.num_threads = num_threads + self.num_stages = num_stages + self.q_subtile_factor = q_subtile_factor + self.Q_in_regs = Q_in_regs + self.score_mod = score_mod + self.mask_mod = mask_mod + self.qk_acc_dtype = Float32 + self.vec_size: cutlass.Constexpr = getattr( + score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 + ) + + @staticmethod + def can_implement( + dtype, + head_dim, + head_dim_v, + tile_m, + tile_n, + num_stages, + num_threads, + is_causal, + Q_in_regs=False, + ) -> bool: + """Check if the kernel can be implemented with the given parameters. + + :param dtype: data type + :type dtype: cutlass.Numeric + :param head_dim: head dimension + :type head_dim: int + :param tile_m: m block size + :type tile_m: int + :param tile_n: n block size + :type tile_n: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + :type is_causal: bool + + :return: True if the kernel can be implemented, False otherwise + :rtype: bool + """ + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if head_dim_v % 8 != 0: + return False + if tile_n % 16 != 0: + return False + if num_threads % 32 != 0: + return False + # Check if block size setting is out of shared memory capacity + # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size + smem_usage_Q = tile_m * head_dim * 2 + smem_usage_K = tile_n * head_dim * num_stages * 2 + smem_usage_V = tile_n * head_dim_v * num_stages * 2 + smem_usage_QV = ( + (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) + ) + smem_usage = smem_usage_QV + smem_usage_K + # TODO: sm86 and sm89 + smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80") + if smem_usage > smem_capacity: + return False + # Check if twice the block size is divisible by the number of threads + if (tile_m * 2) % num_threads != 0: + return False + return True + + def _check_type( + self, + mQ_type: Type[cutlass.Numeric], + mK_type: Type[cutlass.Numeric], + mV_type: Type[cutlass.Numeric], + mO_type: Type[cutlass.Numeric], + mLSE_type: Type[cutlass.Numeric] | None, + mCuSeqlensQ_type: Type[cutlass.Numeric] | None, + mCuSeqlensK_type: Type[cutlass.Numeric] | None, + mSeqUsedQ_type: Type[cutlass.Numeric] | None, + mSeqUsedK_type: Type[cutlass.Numeric] | None, + ): + # Get the data type and check if it is fp16 or bf16 + if const_expr(not (mQ_type == mK_type == mV_type == mO_type)): + raise TypeError("All tensors must have the same data type") + if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): + raise TypeError("Only Float16 or BFloat16 is supported") + if 
const_expr(mLSE_type not in [None, Float32]): + raise TypeError("LSE tensor must be Float32") + if const_expr(mCuSeqlensQ_type not in [None, Int32]): + raise TypeError("cu_seqlens_q tensor must be Int32") + if const_expr(mCuSeqlensK_type not in [None, Int32]): + raise TypeError("cu_seqlens_k tensor must be Int32") + if const_expr(mSeqUsedQ_type not in [None, Int32]): + raise TypeError("seqused_q tensor must be Int32") + if const_expr(mSeqUsedK_type not in [None, Int32]): + raise TypeError("seqused_k tensor must be Int32") + assert mQ_type == self.dtype + + def _setup_attributes(self): + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout: Q/K/V + # /////////////////////////////////////////////////////////////////////////////// + sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = ( + self._get_smem_layout_atom() + ) + self.sQ_layout = cute.tile_to_shape( + sQ_layout_atom, + (self.tile_m, self.tile_hdim), + (0, 1), + ) + self.sK_layout = cute.tile_to_shape( + sK_layout_atom, + (self.tile_n, self.tile_hdim, self.num_stages), + (0, 1, 2), + ) + self.sV_layout = cute.tile_to_shape( + sV_layout_atom, + (self.tile_n, self.tile_hdimv, self.num_stages), + (0, 1, 2), + ) + self.sO_layout = cute.tile_to_shape( + sO_layout_atom, + (self.tile_m, self.tile_hdimv), + (0, 1), + ) + if const_expr(sP_layout_atom is not None): + self.sP_layout = cute.tile_to_shape( + sP_layout_atom, + (self.tile_m, self.tile_n), + (0, 1), + ) + else: + self.sP_layout = None + + # /////////////////////////////////////////////////////////////////////////////// + # GMEM Tiled copy: + # /////////////////////////////////////////////////////////////////////////////// + # Thread layouts for copies + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype.width + # atom_async_copy: async copy atom for QKV load + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + # atom_universal_copy: universal copy atom for O store + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + # tQ_layout and tK_layout: thread layout for QK load + tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems + assert self.num_Q_load_threads % tQK_shape_dim_1 == 0, ( + "num_Q_load_threads must be divisible by tQK_shape_dim_1" + ) + assert self.num_producer_threads % tQK_shape_dim_1 == 0, ( + "num_producer_threads must be divisible by tQK_shape_dim_1" + ) + tQ_layout = cute.make_ordered_layout( + (self.num_Q_load_threads // tQK_shape_dim_1, tQK_shape_dim_1), + order=(1, 0), + ) + tK_layout = cute.make_ordered_layout( + (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1), + order=(1, 0), + ) + # So that we don't have to check if we overshoot kBlockM when we load Q + assert self.tile_m % tQ_layout.shape[0] == 0 + tV_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems + tV_layout = cute.make_ordered_layout( + (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1), + order=(1, 0), + ) + # TODO: need a different layout for O if O dtype is not the same as V dtype + # tO_layout: thread layout for O store + tO_layout = cute.make_ordered_layout( + (self.num_epilogue_threads // tV_shape_dim_1, tV_shape_dim_1), + order=(1, 0), + ) + # So that we don't have to check if we overshoot kBlockM when we store O + assert self.tile_m % tO_layout.shape[0] == 0
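+ # For example (a sketch, assuming fp16 and a 64-element-wide sQ layout atom): async_copy_elems = 128 // 16 = 8 + # elements per cp.async, so tQK_shape_dim_1 = 64 // 8 = 8 threads along the head dimension, and with + # 128 load threads tQ_layout is (16, 8) ordered (1, 0): each group of 8 consecutive threads loads one + # 64-element row chunk, and the 16 groups together cover 16 rows of the tile per copy iteration.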
+ + # Value layouts for copies + vQKV_layout = cute.make_layout((1, async_copy_elems)) + vO_layout = vQKV_layout + + self.gmem_tiled_copy_Q = cute.make_tiled_copy_tv(atom_async_copy, tQ_layout, vQKV_layout) + self.gmem_tiled_copy_K = cute.make_tiled_copy_tv(atom_async_copy, tK_layout, vQKV_layout) + self.gmem_tiled_copy_V = cute.make_tiled_copy_tv(atom_async_copy, tV_layout, vQKV_layout) + # gmem_tiled_copy_O: tiled copy for O store + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) + + def _get_smem_layout_atom(self): + raise NotImplementedError() + + def _get_tiled_mma(self): + raise NotImplementedError() + + def _get_shared_storage_cls(self): + raise NotImplementedError() + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + stream: cuda.CUstream, + ): + """Configures and launches the flash attention kernel. + + mQ/mK/mV/mO have the same data type (fp16 or bf16) and the same layout: + (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) + """ + raise NotImplementedError() + + @cute.jit + def epilogue( + self, + acc_O: cute.Tensor, + lse: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + sO: cute.Tensor, + seqlen: SeqlenInfoQK, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], + tiled_mma: cute.TiledMma, + tidx: Int32, + m_block: Int32, + head_idx: Int32, + batch_idx: Int32, + ): + # store acc_O + rO = cute.make_fragment_like(acc_O, self.dtype) + rO.store(acc_O.load().to(self.dtype)) + # Make sure all threads have finished reading V + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + ) + smem_copy_atom_O = utils.get_smem_store_atom(self.arch, self.dtype) + smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) + taccOrO = smem_thr_copy_O.retile(rO) + taccOsO = smem_thr_copy_O.partition_D(sO) + # taccOsO = copy_utils.partition_D_position_independent(smem_thr_copy_O, sO) + # copy acc O from rmem to smem with the smem copy atom + cute.copy(smem_copy_atom_O, taccOrO, taccOsO) + + cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv)) + pack_gqa = PackGQA( + self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead + ) + + # Write LSE from rmem -> gmem + if const_expr(mLSE is not None): + if const_expr(not seqlen.has_cu_seqlens_q): + mLSE_cur = mLSE[None, head_idx, batch_idx] + else: + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) + if const_expr(not self.pack_gqa): + gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,)) + gLSE_expanded_layout = cute.append( + gLSE.layout, cute.make_layout((self.tile_hdimv,), stride=(0,)) + ) + gLSE_expanded = cute.make_tensor(gLSE.iterator, gLSE_expanded_layout) + thr_mma = tiled_mma.get_slice(tidx) + taccOgLSE = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(gLSE_expanded)) + assert cute.size(taccOgLSE, mode=[0]) == cute.size(lse) + taccOcO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cO)) + t0accOcO = layout_utils.reshape_acc_to_mn(thr_mma.get_slice(0).partition_C(cO)) + # Only the thread corresponding to column 0 writes out the lse to gmem + if taccOcO[0][1] == 0: + for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])): + if ( + t0accOcO[m, 0][0] + < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0] + ):
+ taccOgLSE[m, 0] = lse[m] + else: + pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) + + if const_expr(not seqlen.has_cu_seqlens_q): + mO_cur = mO[None, None, head_idx, batch_idx] + else: + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx]) + # thr_mma = tiled_mma.get_slice(tidx) + # taccOgO = thr_mma.partition_C(gO) + # cute.autovec_copy(rO, taccOgO) + # sync to make sure all smem stores are done + if const_expr(self.use_tma_O): + # ensure smem writes are visible to TMA + cute.arch.fence_view_async_shared() + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, + ) + gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True + ) + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 4: + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, + ) + store_O() + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + else: + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads, + ) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + tOrO = cute.make_fragment_like(tOsO, self.dtype) + # load acc O from smem to rmem for wider vectorization + cute.autovec_copy(tOsO, tOrO) + if const_expr(not self.pack_gqa): + gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + # copy acc O from rmem to gmem + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if ( + t0OcO[0, rest_m, 0][0] + < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0] + ): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None], + pred=tOpO[None, rest_m, None] + if const_expr(self.check_hdim_v_oob) + else None, + ) + else: + pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q) + + @cute.jit + def advance_pipeline(self, pipeline_index): + return pipeline_index + 1 if pipeline_index < self.num_stages - 1 else 0 + + @cute.jit + def load_Q( + self, + gmem_thr_copy: cute.TiledCopy, + gQ: cute.Tensor, + sQ: cute.Tensor, + block: Int32, + seqlen: Int32, + headdim: Int32, + ): + tQsQ, tQgQ = gmem_thr_copy.partition_D(sQ), gmem_thr_copy.partition_S(gQ) + cQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) + tQcQ = gmem_thr_copy.partition_S(cQ) + t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) + tQpQ = utils.predicate_k(tQcQ, limit=headdim) + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + # Instead of using tQcQ, we use t0QcQ and subtract the offset from the limit + # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time.
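+ # E.g. if this thread's first row offset is tQcQ[0][0] = r, the values t0QcQ[0, m, 0][0] form a + # compile-time sequence (e.g. 0, 8, 16, ... for a 16-rows-per-iteration layout), so the runtime bound + # (seqlen - block * tile_m - r) is computed once and each per-m check folds into a compare against a constant.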
+ if t0QcQ[0, m, 0][0] < seqlen - block * self.tile_m - tQcQ[0][0]: + cute.copy( + gmem_thr_copy, + tQgQ[None, m, None], + tQsQ[None, m, None], + pred=tQpQ[None, m, None] if const_expr(self.check_hdim_oob) else None, + ) + # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + + @cute.jit + def load_K( + self, + gmem_tiled_copy: cute.TiledCopy, + tKgK: cute.Tensor, + tKsK: cute.Tensor, + tKcK: cute.Tensor, + t0KcK: cute.Tensor, + tKpK: cute.Tensor, + block: Int32, + smem_pipe_write: Int32, + seqlen: Int32, + need_predicates: cutlass.Constexpr, + ): + # Do we need to check if we overshoot kBlockN when we load K? + is_even_n_smem_k = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0 + if const_expr(need_predicates or not is_even_n_smem_k): + # Instead of using tKcK, we use t0KcK and subtract the offset from the limit + # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time. + if const_expr(is_even_n_smem_k): + seqlen_limit = seqlen - block * self.tile_n + else: + if const_expr(not need_predicates): + seqlen_limit = self.tile_n + else: + seqlen_limit = cutlass.min(seqlen - block * self.tile_n, self.tile_n) + seqlen_limit -= tKcK[0][0] + for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): + if t0KcK[0, n, 0][0] < seqlen_limit: + cute.copy( + gmem_tiled_copy, + tKgK[None, n, None, block], + tKsK[ + None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0 + ], + pred=tKpK[None, n, None] if const_expr(self.check_hdim_oob) else None, + ) + # We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + else: + cute.copy( + gmem_tiled_copy, + tKgK[None, None, None, block], + tKsK[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + pred=tKpK if const_expr(self.check_hdim_oob) else None, + ) + + @cute.jit + def load_V( + self, + gmem_tiled_copy: cute.TiledCopy, + tVgV: cute.Tensor, + tVsV: cute.Tensor, + tVcV: cute.Tensor, + t0VcV: cute.Tensor, + tVpV: cute.Tensor, + block: Int32, + smem_pipe_write: Int32, + seqlen: Int32, + need_predicates: cutlass.Constexpr, + ): + # Do we need to check if we overshoot kBlockN when we load V?
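+ # Same even-divisibility reasoning as in load_K: when tile_n is a multiple of the copy tiler's + # n extent, every row a thread touches lies inside the tile whenever the tile itself is in range, + # so the per-row in-tile check can be skipped and only the seqlen predicate (if any) remains.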
+ is_even_n_smem_v = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0 + if const_expr(need_predicates or not is_even_n_smem_v): + for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): + # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked + if ( + is_even_n_smem_v + or n < cute.size(tVsV.shape[1]) - 1 + or tVcV[0, n, 0][0] < self.tile_n + ): + predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None + if const_expr(need_predicates): + seqlen_limit = seqlen - block * self.tile_n - tVcV[0][0] + predicate_n = t0VcV[0, n, 0][0] < seqlen_limit + predicate = cute.make_fragment_like(tVpV[None, 0, None]) + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = ( + tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True + ) and predicate_n + cute.copy( + gmem_tiled_copy, + tVgV[None, n, None, block], + tVsV[ + None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0 + ], + pred=predicate, + ) + else: + cute.copy( + gmem_tiled_copy, + tVgV[None, None, None, block], + tVsV[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + pred=tVpV if const_expr(self.check_hdim_v_oob) else None, + ) + + +class FlashAttentionForwardSm80(FlashAttentionForwardBase): + def _get_smem_layout_atom(self): + sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdim) + sK_layout_atom = sQ_layout_atom + sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdimv) + sO_layout_atom = sV_layout_atom + sP_layout_atom = None + return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom + + def _get_tiled_mma(self): + tiled_mma_qk = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), + (self.num_threads // 32, 1, 1), + permutation_mnk=(self.num_threads // 32 * 16, 16, 16), + ) + tiled_mma_pv = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), + (self.num_threads // 32, 1, 1), + permutation_mnk=(self.num_threads // 32 * 16, 16, 16), + ) + return tiled_mma_qk, tiled_mma_pv + + def _get_shared_storage_cls(self): + sQ_struct, sK_struct, sV_struct = [ + cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024] + for layout in (self.sQ_layout, self.sK_layout, self.sV_layout) + ] + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] + + @cute.struct + class SharedStorageQKV: + sV: sV_struct + sQ: sQ_struct + sK: sK_struct + + @cute.struct + class SharedStorageSharedQV: + sQ: sQV_struct + sK: sK_struct + + return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + stream: cuda.CUstream, + softmax_scale: Optional[Float32] = None, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + learnable_sink: Optional[cute.Tensor] = None, + aux_tensors=None, + ): + """Configures and launches the flash attention kernel. 
+ + mQ/mK/mV/mO have the same data type (fp16 or bf16) and the same layout: + (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) + """ + assert learnable_sink is None, "Learnable sink is not supported in this kernel" + self._check_type( + *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE, None, None, None, None)) + ) + tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() + self.num_mma_threads = tiled_mma_pv.size + self.num_producer_threads = self.num_threads + self.num_Q_load_threads = self.num_threads + self.num_epilogue_threads = self.num_threads + # self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None + self.use_tma_O = self.arch >= 90 + self._setup_attributes() + SharedStorage = self._get_shared_storage_cls() + mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)] + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0])) + for t in (mQ, mK, mV, mO) + ] + if const_expr(mLSE is not None): + mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=[2, 1, 0])) + # grid_dim: (m_block, num_head, batch_size) + grid_dim = ( + cute.ceil_div(mQ.shape[0], self.tile_m), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[3]), + ) + LOG2_E = math.log2(math.e) + if const_expr(self.score_mod is None): + softmax_scale_log2 = Float32(softmax_scale * LOG2_E) + softmax_scale = None + else: + # NB: If a user passes in a score mod, we want to apply the score-mod to the sm_scaled qk + # but still in the original base e. We hijack softmax_scale_log2 to just be the change of base + # and correctly apply the softmax_scale prior to score_mod in the softmax step, + # since exp(scale * x) == exp2(log2(e) * scale * x). + softmax_scale_log2 = Float32(LOG2_E) + softmax_scale = Float32(softmax_scale) + + fastdiv_mods = None + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) + seqlen_k = cute.size(mK.shape[0]) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + + self.kernel( + mQ, + mK, + mV, + mO, + mLSE, + softmax_scale_log2, + softmax_scale, + window_size_left, + window_size_right, + self.sQ_layout, + self.sK_layout, + self.sV_layout, + self.sO_layout, + self.sP_layout, + self.gmem_tiled_copy_Q, + self.gmem_tiled_copy_K, + self.gmem_tiled_copy_V, + self.gmem_tiled_copy_O, + tiled_mma_qk, + tiled_mma_pv, + SharedStorage, + aux_tensors, + fastdiv_mods, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + softmax_scale_log2: Float32, + softmax_scale: Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + sP_layout: cute.ComposedLayout | None, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_K: cute.TiledCopy, + gmem_tiled_copy_V: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + SharedStorage: cutlass.Constexpr, + aux_tensors=None, + fastdiv_mods=None, + ): + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + m_block, num_head, batch_size = cute.arch.block_idx() + + block_info = BlockInfo( + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + False, # is_split_kv +
window_size_left, + window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + seqlen = SeqlenInfoQK.create(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0]) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + # TODO: return early if n_block_max == 0 + # if self.is_causal: + # if n_block_max <= 0: + # return + n_block = n_block_max - 1 + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # /////////////////////////////////////////////////////////////////////////////// + blkQ_shape = (self.tile_m, self.tile_hdim) + blkK_shape = (self.tile_n, self.tile_hdim) + blkV_shape = (self.tile_n, self.tile_hdimv) + gQ = cute.local_tile(mQ[None, None, num_head, batch_size], blkQ_shape, (m_block, 0)) + num_head_kv = num_head // self.qhead_per_kvhead + gK = cute.local_tile(mK[None, None, num_head_kv, batch_size], blkK_shape, (None, 0)) + gV = cute.local_tile(mV[None, None, num_head_kv, batch_size], blkV_shape, (None, 0)) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ = storage.sQ.get_tensor(sQ_layout) + sK = storage.sK.get_tensor(sK_layout) + if const_expr(not self.Q_in_regs): + sV = storage.sV.get_tensor(sV_layout) + else: + sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout) + # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma + sVt = layout_utils.transpose_view(sV) + + gmem_thr_copy_K = gmem_tiled_copy_K.get_slice(tidx) + gmem_thr_copy_V = gmem_tiled_copy_V.get_slice(tidx) + # (CPY_Atom, CPY_N, CPY_K, n_block) + tKsK, tKgK = gmem_thr_copy_K.partition_D(sK), gmem_thr_copy_K.partition_S(gK) + # (CPY_Atom, CPY_N, CPY_K, n_block) + tVsV, tVgV = gmem_thr_copy_V.partition_D(sV), gmem_thr_copy_V.partition_S(gV) + + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma_qk = tiled_mma_qk.get_slice(tidx) + thr_mma_pv = tiled_mma_pv.get_slice(tidx) + tSrQ = thr_mma_qk.make_fragment_A(thr_mma_qk.partition_A(sQ)) + tSrK = thr_mma_qk.make_fragment_B(thr_mma_qk.partition_B(sK[None, None, 0])) + tOrVt = thr_mma_pv.make_fragment_B(thr_mma_pv.partition_B(sVt[None, None, 0])) + acc_shape_O = thr_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv)) + acc_O = cute.make_fragment(acc_shape_O, Float32) + acc_O.fill(0.0) + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom_QK = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + self.dtype, + ) + smem_copy_atom_V = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), + self.dtype, + ) + smem_thr_copy_Q = utils.make_tiled_copy_A(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) + smem_thr_copy_K = utils.make_tiled_copy_B(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) + smem_thr_copy_V = utils.make_tiled_copy_B(smem_copy_atom_V, tiled_mma_pv).get_slice(tidx) + + tSsQ = smem_thr_copy_Q.partition_S(sQ) + tSsK = 
smem_thr_copy_K.partition_S(sK) + tOsVt = smem_thr_copy_V.partition_S(sVt) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for KV + cK = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) + tKcK = gmem_thr_copy_K.partition_S(cK) + t0KcK = gmem_thr_copy_K.get_slice(0).partition_S(cK) + if const_expr(self.tile_hdim == self.tile_hdimv): + tVcV = tKcK + t0VcV = t0KcK + else: + cV = cute.make_identity_tensor((self.tile_n, self.tile_hdimv)) + tVcV = gmem_thr_copy_V.partition_S(cV) + t0VcV = gmem_thr_copy_V.get_slice(0).partition_S(cV) + # Allocate predicate tensors along the k dimension only, and use "if" checks on the + # mn dimension. + # This reduces register pressure and gives a 2-3% performance gain. + tKpK = utils.predicate_k(tKcK, limit=mK.shape[1]) + if const_expr(self.same_hdim_kv): + tVpV = tKpK + else: + tVpV = utils.predicate_k(tVcV, limit=mV.shape[1]) + + # shape: (atom_v_m * rest_m) + softmax = Softmax.create( + softmax_scale_log2, + num_rows=acc_O.shape[0][0] * acc_O.shape[1], + softmax_scale=softmax_scale, + ) + softmax.reset() + + # group parameters for compute_one_n_block + mma_params = SimpleNamespace( + thr_mma_qk=thr_mma_qk, + thr_mma_pv=thr_mma_pv, + tSrQ=tSrQ, + tSrK=tSrK, + tOrVt=tOrVt, + acc_O=acc_O, + ) + smem_copy_params = SimpleNamespace( + smem_thr_copy_Q=smem_thr_copy_Q, + smem_thr_copy_K=smem_thr_copy_K, + smem_thr_copy_V=smem_thr_copy_V, + tSsQ=tSsQ, + tSsK=tSsK, + tOsVt=tOsVt, + ) + load_K = partial( + self.load_K, gmem_tiled_copy_K, tKgK, tKsK, tKcK, t0KcK, tKpK, seqlen=seqlen.seqlen_k + ) + load_V = partial( + self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, seqlen=seqlen.seqlen_k + ) + + compute_one_n_block = partial( + self.compute_one_n_block, + mma_params=mma_params, + smem_copy_params=smem_copy_params, + softmax=softmax, + load_K=load_K, + load_V=load_V, + score_mod=self.score_mod, + batch_idx=batch_size, + head_idx=num_head, + m_block=m_block, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Start async loads of the last mn-tile, where we take care of the mn residue + gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) + self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, headdim=mQ.shape[1]) + cute.arch.cp_async_commit_group() + + def preprocess_Q(): + cute.arch.cp_async_wait_group(self.num_stages * 2 - 1) + if const_expr(self.Q_in_regs): + cute.arch.barrier() + tSrQ_copy_view = smem_thr_copy_Q.retile(tSrQ) + cute.copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view) + + # If Q_in_regs, we load Q, then load 1 stage of K, then (optionally) rotate Q and + # read from smem_q to registers, then load V. + # If !Q_in_regs, we load Q, load all stages of K & V, then (optionally) rotate Q.
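+ # Concretely, for num_stages=2 without Q_in_regs the committed cp.async groups are + # [Q], [K(n_block)], [V(n_block)], [K(n_block-1)]; preprocess_Q then calls + # cp_async_wait_group(num_stages * 2 - 1) = wait_group(3), i.e. it blocks only until Q + # has landed while the K/V prefetches stay in flight.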
+ if const_expr(self.Q_in_regs): + load_K(n_block, smem_pipe_write=0, need_predicates=True) + cute.arch.cp_async_commit_group() + preprocess_Q() + cute.arch.barrier() # Make sure all threads have read smem_q before loading V + + for stage in cutlass.range_constexpr(self.num_stages): + if const_expr(not self.Q_in_regs or stage > 0): + if stage == 0 or n_block - stage >= 0: + load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0) + cute.arch.cp_async_commit_group() + if const_expr(stage < self.num_stages - 1): + if stage == 0 or n_block - stage >= 0: + load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0) + cute.arch.cp_async_commit_group() + if const_expr(not self.Q_in_regs): + preprocess_Q() + + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # /////////////////////////////////////////////////////////////////////////////// + # Start processing of the first n-block. + # For performance reasons, we separate out two kinds of iterations: + # those that need masking on S, and those that don't. + # We need masking on S for the very last block when the K/V sequence length is not a multiple of tile_n. + # We also need masking on S for the last several blocks when attention is causal or local. + mask = AttentionMask( + self.tile_m, + self.tile_n, + seqlen.seqlen_q, + seqlen.seqlen_k, + window_size_left, + window_size_right, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + mask_fn = partial( + mask.apply_mask, + m_block=m_block, + thr_mma=thr_mma_qk, + mask_causal=self.is_causal, + mask_local=self.is_local, + fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None, + ) + + # First iteration with seqlen masking + smem_pipe_read = Int32(0) + smem_pipe_write = Int32(self.num_stages - 1) + compute_one_n_block( + n_block, + smem_pipe_read, + smem_pipe_write, + is_first_n_block=True, + check_inf=True, + mask_fn=partial(mask_fn, mask_seqlen=True), + ) + smem_pipe_read = self.advance_pipeline(smem_pipe_read) + smem_pipe_write = self.advance_pipeline(smem_pipe_write) + # Next couple of iterations with causal/local masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 2 - n_tile + compute_one_n_block( + n_block, + smem_pipe_read, + smem_pipe_write, + check_inf=True, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + smem_pipe_read = self.advance_pipeline(smem_pipe_read) + smem_pipe_write = self.advance_pipeline(smem_pipe_write) + # The remaining iterations have no masking + for n_tile in cutlass.range(n_block, unroll=1): + compute_one_n_block( + n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True + ) + smem_pipe_read = self.advance_pipeline(smem_pipe_read) + smem_pipe_write = self.advance_pipeline(smem_pipe_write) + # TODO: local + + # normalize acc_O by row_sum and calculate the lse + row_scale = softmax.finalize() + softmax.rescale_O(acc_O, row_scale) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + # reuse sQ's data iterator + sO = cute.make_tensor(sQ.iterator, sO_layout) + self.epilogue( + acc_O, + softmax.row_sum, + mO, + mLSE, + sO, + seqlen, + gmem_tiled_copy_O, + None, + tiled_mma_pv, + tidx, + m_block, + num_head, +
batch_size, + ) + + @cute.jit + def compute_one_n_block( + self, + n_block: Int32, + smem_pipe_read: Int32, + smem_pipe_write: Int32, + mma_params: SimpleNamespace, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + load_K: Callable, + load_V: Callable, + score_mod: Callable | None, + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + seqlen: SeqlenInfoQK, + aux_tensors=None, + fastdiv_mods=None, + mask_fn: Optional[Callable] = None, + is_first_n_block: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, + ): + """Compute one n_block of S/O. + + This function provides different variants for processing the first n block versus + subsequent blocks. + """ + + def sync(): + cute.arch.cp_async_wait_group(self.num_stages * 2 - 2) + cute.arch.barrier() + + acc_shape_S = mma_params.thr_mma_qk.partition_shape_C((self.tile_m, self.tile_n)) + acc_S = cute.make_fragment(acc_shape_S, Float32) + acc_S.fill(0.0) + # wait for smem tile QK before mma calculation for S + sync() + + # need predicates for the first tile + def load_V_next(): + if self.num_stages == 1 or n_block - self.num_stages + 1 >= 0: + load_V( + n_block - self.num_stages + 1, + smem_pipe_write, + need_predicates=is_first_n_block and self.num_stages == 1, + ) + cute.arch.cp_async_commit_group() + + load_V_next() + sm80_utils.gemm( + mma_params.thr_mma_qk, + acc_S, + mma_params.tSrQ, + mma_params.tSrK, + smem_copy_params.tSsQ, + smem_copy_params.tSsK[ + None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0 + ], + smem_copy_params.smem_thr_copy_Q, + smem_copy_params.smem_thr_copy_K, + # hook_fn=load_V_next, + A_in_regs=self.Q_in_regs, + ) + if const_expr(score_mod is not None): + self.apply_score_mod( + mma_params.thr_mma_qk, + batch_idx, + head_idx, + m_block, + acc_S, + n_block, + seqlen, + softmax_scale=softmax.softmax_scale, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + + smem_pipe_write = self.advance_pipeline(smem_pipe_write) + + def load_K_next(): + if n_block - self.num_stages >= 0: + load_K(n_block - self.num_stages, smem_pipe_write, need_predicates=False) + cute.arch.cp_async_commit_group() + + # wait for smem tile V for O + if const_expr(self.num_stages == 1): + sync() + load_K_next() + if const_expr(mask_fn is not None): + mask_fn(acc_S, n_block=n_block) + row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) + softmax.rescale_O(mma_params.acc_O, row_scale) + rP = cute.make_fragment_like(acc_S, self.dtype) + rP.store(acc_S.load().to(self.dtype)) + tOrP = layout_utils.reshape_acc_to_frgA(rP) + if const_expr(self.num_stages > 1): + sync() + load_K_next() + sm80_utils.gemm_rs( + mma_params.thr_mma_pv, + mma_params.acc_O, + tOrP, + mma_params.tOrVt, + smem_copy_params.tOsVt[ + None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0 + ], + smem_copy_params.smem_thr_copy_V, + # hook_fn=load_K_next, + ) + # if const_expr(self.num_stages > 1): + # load_K_next() + + +class FlashAttentionForwardSm90(FlashAttentionForwardBase): + arch = 90 + + def __init__( + self, + *args, + intra_wg_overlap: bool = True, + mma_pv_is_rs: bool = True, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.intra_wg_overlap = intra_wg_overlap + self.mma_pv_is_rs = mma_pv_is_rs + self.buffer_align_bytes = 1024 + + def _get_smem_layout_atom(self): + sQ_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim), + self.dtype, + ) + 
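# K reuses Q's smem layout atom since both are (tile, head_dim) row-major operands of the + # QK GEMM; V and O get their own atom below only because the V head dimension (tile_hdimv) + # may differ from tile_hdim. +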
sK_layout_atom = sQ_layout_atom + sV_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv + ), + self.dtype, + ) + sO_layout_atom = sV_layout_atom + if not self.mma_pv_is_rs: + sP_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n + ), + self.dtype, + ) + else: + sP_layout_atom = None + return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom + + def _get_tiled_mma(self): + tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.K, + Float32, + atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.tile_n), + ) + tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.MN, + Float32, + atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.tile_hdimv), + a_source=warpgroup.OperandSource.RMEM + if self.mma_pv_is_rs + else warpgroup.OperandSource.SMEM, + ) + return tiled_mma_qk, tiled_mma_pv + + def _get_shared_storage_cls(self): + sQ_struct, sK_struct, sV_struct = [ + cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes] + for layout in (self.sQ_layout, self.sK_layout, self.sV_layout) + + ] + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] + cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0 + sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024] + # 1 for Q, 1 for O, self.num_stages*2 for K, self.num_stages*2 for V, + mbar_ptr_QO_struct = cute.struct.MemRange[cutlass.Int64, 2] + mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + + @cute.struct + class SharedStorageQKV: + mbar_ptr: mbar_ptr_QO_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_V: mbar_ptr_V_struct + sV: sV_struct + sQ: sQ_struct + sK: sK_struct + sP: sP_struct + + @cute.struct + class SharedStorageSharedQV: + mbar_ptr: mbar_ptr_QO_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_V: mbar_ptr_V_struct + sQ: sQV_struct + sK: sK_struct + sP: sP_struct + + return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + 
learnable_sink: Optional[cute.Tensor] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + aux_tensors: Optional[list] = None, + ): + """Configures and launches the flash attention kernel. + + mQ/mK/mV/mO have the same data type (fp16 or bf16) and the same layout: + (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) + """ + + self._check_type( + *( + t.element_type if t is not None else None + for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK) + ) + ) + + mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)] + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ, mO = [layout_utils.select(t, QO_layout_transpose) for t in (mQ, mO)] + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mK, mV = [layout_utils.select(t, KV_layout_transpose) for t in (mK, mV)] + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE = layout_utils.select(mLSE, LSE_layout_transpose) if const_expr(mLSE is not None) else None + + tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() + self.num_mma_threads = tiled_mma_qk.size + self.num_threads_per_warp_group = 128 + self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group + self.num_threads = self.num_threads_per_warp_group * (self.num_mma_warp_groups + 1) + self.num_producer_threads = 32 + self.num_Q_load_threads = self.num_mma_threads # If not TMA_Q, MMA threads load Q + self.num_epilogue_threads = self.num_mma_threads + self.num_mma_regs = ( + 256 + if self.num_mma_warp_groups == 1 + else (240 if self.num_mma_warp_groups == 2 else 160) + ) + self.num_producer_regs = ( + 56 if self.num_mma_warp_groups == 1 else (24 if self.num_mma_warp_groups == 2 else 32) + ) + # self.num_mma_regs = 232 + # self.num_producer_regs = 40 + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + + self.use_scheduler_barrier = ( + (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) + if const_expr(self.intra_wg_overlap) + else (self.num_mma_warp_groups == 2) + ) + self.use_tma_Q = self.arch >= 90 and not ( + self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0 + ) + self.use_tma_O = ( + self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa + ) + # TODO: rescale_O_before_gemm + self._setup_attributes() + # TODO: we prob don't need most of what's in _setup_attributes + self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [ + sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage) + for mX, shape, stage in [ + (mQ, (self.tile_m, self.tile_hdim), None), + (mK, (self.tile_n, self.tile_hdim), self.num_stages), + (mV, (self.tile_n, self.tile_hdimv), self.num_stages), + (mO, (self.tile_m, self.tile_hdimv), None), + ] + ] + self.sP_layout = None + if const_expr(not self.mma_pv_is_rs): + self.sP_layout = sm90_utils.make_smem_layout( + mV.element_type, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n) + ) + + SharedStorage = self._get_shared_storage_cls() + + if const_expr(self.pack_gqa): + shape_Q_packed = ( + (self.qhead_per_kvhead, mQ.shape[0]), + mQ.shape[1], + mK.shape[2], + *mQ.shape[3:], + ) + stride_Q_packed = ( + (mQ.stride[2], mQ.stride[0]), + mQ.stride[1], + mQ.stride[2] * self.qhead_per_kvhead, + *mQ.stride[3:], + ) + mQ = cute.make_tensor( + mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) + ) + shape_O_packed = ( + (self.qhead_per_kvhead, mO.shape[0]), + mO.shape[1],
+ mK.shape[2], + *mO.shape[3:], + ) + stride_O_packed = ( + (mO.stride[2], mO.stride[0]), + mO.stride[1], + mO.stride[2] * self.qhead_per_kvhead, + *mO.stride[3:], + ) + mO = cute.make_tensor( + mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) + ) + if const_expr(mLSE is not None): + shape_LSE_packed = ( + (self.qhead_per_kvhead, mLSE.shape[0]), + mK.shape[2], + *mLSE.shape[2:], + ) + stride_LSE_packed = ( + (mLSE.stride[1], mLSE.stride[0]), + mLSE.stride[1] * self.qhead_per_kvhead, + *mLSE.stride[2:], + ) + mLSE = cute.make_tensor( + mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) + ) + + # TMA + gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp() + gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast + gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp() + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1])) + for name, mX, layout in [ + ("Q", mQ, self.sQ_layout), + ("K", mK, self.sK_layout), + ("V", mV, self.sV_layout), + ] + } + tma_atom_Q, tma_tensor_Q = None, None + if const_expr(self.use_tma_Q): + tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_Q, + mQ, + self.sQ_layout, + (self.tile_m, self.tile_hdim), # No mcast + ) + tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_KV, + mK, + cute.select(self.sK_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdim), + 1, # No mcast for now + ) + tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_KV, + mV, + cute.select(self.sV_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdimv), + 1, # No mcast for now + ) + tma_atom_O, tma_tensor_O = None, None + if const_expr(self.use_tma_O): + tma_atom_O, tma_tensor_O = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_O, + mO, + self.sO_layout, + (self.tile_m, self.tile_hdimv), # No mcast + ) + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler + else: + TileScheduler = ( + SingleTileScheduler + if const_expr(not (self.is_causal or self.is_local)) + else SingleTileLPTScheduler + ) + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1), + 1, # num_splits + cute.size(mK.shape[0]), + mQ.shape[1], + mV.shape[1], + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=(self.tile_m, self.tile_n), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.dtype.width // 8, + is_persistent=False, + lpt=self.is_causal or self.is_local, + ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + LOG2_E = math.log2(math.e) + if const_expr(self.score_mod is None): + softmax_scale_log2 = softmax_scale * LOG2_E + softmax_scale = None + else: + # NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk + # But in the original base e. 
We hijack softmax_scale_log2 to just be the change of base + # and correctly apply the softmax_scale prior to score_mod in the softmax step + softmax_scale_log2 = LOG2_E + softmax_scale = softmax_scale + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + + fastdiv_mods = None + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) + seqlen_k = ( + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1] + ) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + + self.kernel( + tma_tensor_Q if const_expr(self.use_tma_Q) else mQ, + tma_tensor_K, + tma_tensor_V, + tma_tensor_O if const_expr(self.use_tma_O) else mO, + mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_O, + softmax_scale_log2, + softmax_scale, + window_size_left, + window_size_right, + learnable_sink, + blocksparse_tensors, + self.sQ_layout, + self.sK_layout, + self.sV_layout, + self.sO_layout, + self.sP_layout, + self.gmem_tiled_copy_Q, + self.gmem_tiled_copy_K, + self.gmem_tiled_copy_V, + self.gmem_tiled_copy_O, + tiled_mma_qk, + tiled_mma_pv, + tile_sched_params, + TileScheduler, + SharedStorage, + aux_tensors, + fastdiv_mods, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + tma_atom_Q: Optional[cute.CopyAtom], + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], + tma_atom_O: Optional[cute.CopyAtom], + softmax_scale_log2: Float32, + softmax_scale: Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + learnable_sink: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + sP_layout: cute.ComposedLayout | None, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_K: cute.TiledCopy, + gmem_tiled_copy_V: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + SharedStorage: cutlass.Constexpr[Callable], + aux_tensors: Optional[list[cute.Tensor]] = None, + fastdiv_mods=None, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # Prefetch tma descriptor + if warp_idx == 0: + for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O): + if const_expr(tma_atom is not None): + cpasync.prefetch_descriptor(tma_atom) + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # Mbarrier init + mbar_ptr_Q = storage.mbar_ptr.data_ptr() + if warp_idx == 1: + # if tidx < 2: + # # barrierO num threads should be self.num_mma_threads + # cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads) + if const_expr(not self.use_tma_Q): 
cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads) + # cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads) + # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync + pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread + ) + pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE + ) + pipeline_k = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.mbar_ptr_K.data_ptr(), + num_stages=self.num_stages, + producer_group=pipeline_kv_producer_group, + consumer_group=pipeline_kv_consumer_group, + tx_count=self.tma_copy_bytes["K"], + defer_sync=True, + ) + pipeline_v = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.mbar_ptr_V.data_ptr(), + num_stages=self.num_stages, + producer_group=pipeline_kv_producer_group, + consumer_group=pipeline_kv_consumer_group, + tx_count=self.tma_copy_bytes["V"], + defer_sync=False + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + if const_expr(not self.Q_in_regs): + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + else: + sV = storage.sQ.get_tensor( + sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type + ) + # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma + sVt = layout_utils.transpose_view(sV) + sP = None + if const_expr(sP_layout is not None): + sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) + # reuse sQ's data iterator + sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype) + + block_info = BlockInfo( + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + False, # is_split_kv + window_size_left, + window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + ) + AttentionMaskCls = partial( + AttentionMask, + self.tile_m, + self.tile_n, + window_size_left=window_size_left, + window_size_right=window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) + + if warp_idx < 4: # Producer + cute.arch.setmaxregister_decrease(self.num_producer_regs) + self.load( + mQ, + mK, + mV, + sQ, + sK, + sV, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_k, + pipeline_v, + mbar_ptr_Q, + blocksparse_tensors, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + + else: # Consumer + cute.arch.setmaxregister_increase(self.num_mma_regs) + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + tidx, _, _ = cute.arch.thread_idx() + tidx = tidx - 128 + self.mma( + tiled_mma_qk, + tiled_mma_pv, + mQ, + mO, + mLSE, + sQ, + sK, + sVt, + sP, + sO, + learnable_sink, + 
pipeline_k, + pipeline_v, + mbar_ptr_Q, + gmem_tiled_copy_Q, + gmem_tiled_copy_O, + tma_atom_O, + tidx, + softmax_scale_log2, + softmax_scale, + block_info, + SeqlenInfoCls, + AttentionMaskCls, + TileSchedulerCls, + blocksparse_tensors, + aux_tensors, + fastdiv_mods, + ) + + @cute.jit + def load( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + mbar_ptr_Q: cutlass.Pointer, + blocksparse_tensors: Optional[BlockSparseTensors], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + if warp_idx_in_wg == 0: + q_producer_phase = Int32(1) + kv_producer_state = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.num_stages + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + # if work_tile.is_valid_tile: + m_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] + gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0)) + gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) + load_Q = None + if const_expr(self.use_tma_Q): + gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) + load_Q, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True + ) + # TODO: mcast + # TODO check warp_idx if we have 128 producer threads + load_K, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, 0, cute.make_layout(1), gK, sK + ) + load_K = copy_utils.tma_producer_copy_fn(load_K, pipeline_k) + load_V, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, 0, cute.make_layout(1), gV, sV + ) + load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v) + + if const_expr(not self.use_block_sparsity): + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) + # First iteration: load both Q & K with the same mbarrier + n_block = n_block_max - 1 + pipeline_k.producer_acquire( + kv_producer_state, + extra_tx_count=self.tma_copy_bytes["Q"] + if const_expr(self.use_tma_Q) + else 0, + ) + if const_expr(self.use_tma_Q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_K(src_idx=n_block, producer_state=kv_producer_state) + + if const_expr(not self.intra_wg_overlap): + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 1 - i - 1 + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + 
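+                            # K and V for this n_block share one pipeline stage: both
+                            # producer_acquire calls above used the same kv_producer_state,
+                            # so advance it only once both copies are in flight.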
kv_producer_state.advance() + else: + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block_prev = n_block_max - i - 1 + n_block = n_block_prev - 1 + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) + n_block = n_block_min + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + else: + kv_producer_state = produce_block_sparse_loads( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_k, + pipeline_v, + self.use_tma_Q, + self.tma_copy_bytes["Q"], + self.intra_wg_overlap, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + self.q_subtile_factor if self.q_subtile_factor is not None else 1, + ) + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def mma( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + # softmax: Softmax, + # acc_O: cute.Tensor, + mQ: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + sQ: cute.Tensor, + sK: cute.Tensor, + sVt: cute.Tensor, + sP: Optional[cute.Tensor], + sO: cute.Tensor, + learnable_sink: Optional[cute.Tensor], + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + mbar_ptr_Q: cutlass.Pointer, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], + tidx: Int32, + softmax_scale_log2: Float32, + softmax_scale: Optional[Float32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors], + aux_tensors: Optional[list], + fastdiv_mods=None, + ): + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + warp_group_thread_layout = cute.make_layout( + self.num_mma_warp_groups, stride=self.num_threads_per_warp_group + ) + thr_mma_qk = tiled_mma_qk.get_slice(tidx) + wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)) + _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC( + wg_mma_qk, (self.tile_m, self.tile_n, self.tile_hdim), sQ, sK + ) + mma_qk_fn = partial( + sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK + ) + acc_O, tOrP, tOrVt = sm90_utils.partition_fragment_ABC( + wg_mma_pv, (self.tile_m, self.tile_hdimv, self.tile_n), sP, sVt + ) + mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt) + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype) + smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) + tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None + smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) + + self.mma_init() + + mma_one_n_block_all = partial( + 
+            self.mma_one_n_block_intrawg_overlap
+            if const_expr(self.intra_wg_overlap)
+            else self.mma_one_n_block,
+            mma_qk_fn=mma_qk_fn,
+            pipeline_k=pipeline_k,
+            pipeline_v=pipeline_v,
+            acc_O=acc_O,
+            tOrP=tOrP,
+            smem_copy_params=smem_copy_params,
+            check_inf=True,
+        )
+
+        q_consumer_phase = Int32(0)
+        kv_consumer_state = pipeline.make_pipeline_state(
+            cutlass.pipeline.PipelineUserType.Consumer, self.num_stages
+        )
+
+        tile_scheduler = TileSchedulerCls()
+        work_tile = tile_scheduler.initial_work_tile_info()
+        softmax = Softmax.create(
+            softmax_scale_log2,
+            num_rows=acc_O.shape[0][0] * acc_O.shape[1],
+            softmax_scale=softmax_scale,
+        )
+
+        process_first_half_block = partial(
+            self.first_half_block_overlap,
+            mma_qk_fn=mma_qk_fn,
+            pipeline_k=pipeline_k,
+            tOrP=tOrP,
+            smem_copy_params=smem_copy_params,
+            softmax=softmax,
+        )
+        process_last_half_block = partial(
+            self.last_half_block_overlap,
+            pipeline_v=pipeline_v,
+            mma_pv_fn=mma_pv_fn,
+        )
+        while work_tile.is_valid_tile:
+            # if work_tile.is_valid_tile:
+
+            # shape: (atom_v_m * rest_m)
+            m_block, head_idx, batch_idx, _ = work_tile.tile_idx
+            seqlen = SeqlenInfoCls(batch_idx)
+
+            # Recompute fastdiv_mods if necessary for varlen with aux_tensors
+            recompute_fastdiv_mods_q = cutlass.const_expr(
+                aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
+            )
+            recompute_fastdiv_mods_k = cutlass.const_expr(
+                aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
+            )
+            if cutlass.const_expr(fastdiv_mods is not None):
+                seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
+                fastdiv_mods = (
+                    seqlen_q_divmod
+                    if not recompute_fastdiv_mods_q
+                    else FastDivmodDivisor(seqlen.seqlen_q),
+                    seqlen_k_divmod
+                    if not recompute_fastdiv_mods_k
+                    else FastDivmodDivisor(seqlen.seqlen_k),
+                )
+
+            mask = AttentionMaskCls(seqlen)
+            mask_fn = partial(
+                mask.apply_mask,
+                batch_idx=batch_idx,
+                head_idx=head_idx,
+                m_block=m_block,
+                thr_mma=thr_mma_qk,
+                mask_causal=self.is_causal,
+                mask_local=self.is_local,
+                aux_tensors=aux_tensors,
+                fastdiv_mods=fastdiv_mods,
+            )
+            score_mod_fn = None
+            if const_expr(self.score_mod is not None):
+                score_mod_fn = partial(
+                    self.apply_score_mod,
+                    thr_mma_qk,
+                    batch_idx,
+                    head_idx,
+                    m_block,
+                    softmax_scale=softmax_scale,
+                    aux_tensors=aux_tensors,
+                    fastdiv_mods=fastdiv_mods,
+                )
+            mma_one_n_block = partial(
+                mma_one_n_block_all,
+                seqlen=seqlen,
+                softmax=softmax,
+                score_mod_fn=score_mod_fn,
+            )
+            # Load Q if not TMA_Q
+            if const_expr(not self.use_tma_Q):
+                pack_gqa = PackGQA(
+                    self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead
+                )
+                mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
+                # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx)
+                # gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
+                # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q,
+                #             headdim=mQ.shape[1])
+                pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q)
+                cute.arch.cp_async_mbarrier_arrive_noinc(mbar_ptr_Q)
+
+            n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
+            if const_expr(not self.use_tma_Q):
+                cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase)
+                q_consumer_phase ^= 1
+            # For performance reasons, we separate out two kinds of iterations:
+            # those that need masking on S, and those that don't.
+            # We need masking on S for the very last block when the K/V length is not a multiple of tile_n.
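+            # (Unmasked out-of-range columns would otherwise feed garbage scores into
+            # the running row max and row sum of the online softmax.)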
+ # We also need masking on S if it's causal, for the last several blocks. + # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True + O_should_accumulate = False + + # ========================================== + # MAINLOOP + # ========================================== + if const_expr(not self.use_block_sparsity): + # ========================================== + # No block-sparsity (original path) + # ========================================== + # First iteration with seqlen masking + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_first_half_block( + n_block=n_block_max - 1, + seqlen=seqlen, + kv_consumer_state=kv_consumer_state, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod), + score_mod_fn=score_mod_fn, + is_first_block=True, + ) + # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + # acc_O.fill(0.0) + else: + self.warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=True), + is_first_n_block=True, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), + ) + O_should_accumulate = True + # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) + for n_tile in cutlass.range( + n_block_max - n_block_min_causal_local_mask, unroll=1 + ): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + # Last "half" iteration + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_last_half_block( + 
kv_consumer_state=kv_consumer_state, + zero_init=not O_should_accumulate, + ) + O_should_accumulate = True + else: + self.warp_scheduler_barrier_arrive() + + else: + # ========================================== + # Block sparsity + # ========================================== + kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + seqlen, + kv_consumer_state, + mma_pv_fn, + mma_one_n_block, + process_first_half_block, + process_last_half_block, + mask_fn, + score_mod_fn, + O_should_accumulate, + self.mask_mod, + fastdiv_mods, + self.intra_wg_overlap, + self.warp_scheduler_barrier_sync, + self.warp_scheduler_barrier_arrive, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + self.q_subtile_factor if self.q_subtile_factor is not None else 1, + ) + + # Handle empty case (when no blocks to process) + if not processed_any: + softmax.reset() + acc_O.fill(0.0) + + sink_val = None + if const_expr(learnable_sink is not None): + if const_expr(not self.pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + else: # Each thread might have a different sink value due to different q_head + sink_val = cute.make_fragment_like(softmax.row_max, Float32) + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma_qk.partition_C(cS)) + for r in cutlass.range(cute.size(sink_val), unroll_full=True): + row = m_block * self.tile_m + tScS_mn[r][0] + q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + sink_val[r] = Float32(learnable_sink[q_head_idx]) + + # normalize acc_O by row_sum and calculate the lse + row_scale = softmax.finalize(sink_val=sink_val) + softmax.rescale_O(acc_O, row_scale) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + self.epilogue( + acc_O, + softmax.row_sum, + mO, + mLSE, + sO, + seqlen, + gmem_tiled_copy_O, + tma_atom_O, + tiled_mma_pv, + tidx, + m_block, + head_idx, + batch_idx, + ) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + + @cute.jit + def first_half_block_overlap( + self, + n_block: Int32, + mma_qk_fn: Callable, + kv_consumer_state, + pipeline_k, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + seqlen: SeqlenInfoQK, + mask_fn: Callable = None, + score_mod_fn: Optional[Callable] = None, + is_first_block: bool = False, + ): + """Processes the first half block when using intra-warpgroup-overlap""" + + pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) + acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0) + pipeline_k.consumer_release(kv_consumer_state) + + # Apply score modification if present + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) + + # Apply mask; mask_seqlen always True for first block + # Caveat: if full block further right than mask block, seqlen masking is redundant; + # however, masking is being applied anyway, so essentially no perf hit + mask_fn(acc_S, n_block=n_block, mask_seqlen=True) + + softmax.online_softmax(acc_S, is_first=is_first_block) + + tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) + tOrP_cur.store(tOrP_acc.load().to(self.dtype)) + + # if pv gemm not rs + 
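+        # The PV WGMMA takes its A operand (P) from shared memory (SS mode) here, so the
+        # converted fragment must be staged into sP first; when mma_pv_is_rs, P feeds the
+        # MMA directly from registers and this copy is skipped.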
if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + # Fence and barrier to make smem store visible to WGMMA + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() + + return kv_consumer_state + + @cute.jit + def last_half_block_overlap( + self, + kv_consumer_state, + pipeline_v, + mma_pv_fn: Callable, + zero_init: bool, + ): + """Processes the final PV GEMM when using intra-warpgroup-overlap""" + + pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) + mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0) + pipeline_v.consumer_release(kv_consumer_state) + kv_consumer_state.advance() + return kv_consumer_state + + @cute.jit + def mma_one_n_block( + self, + smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + n_block: Int32, + mma_qk_fn: Callable, + mma_pv_fn: Callable, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + acc_O: cute.Tensor, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + seqlen: SeqlenInfoQK, + score_mod_fn: Optional[Callable] = None, + mask_fn: Optional[Callable] = None, + is_first_n_block: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, + ): + pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) + # S = Q @ K.T + acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) + self.warp_scheduler_barrier_arrive() + warpgroup.wait_group(0) + pipeline_k.consumer_release(smem_pipe_read) + + # handle score mods and masking + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) + if const_expr(mask_fn is not None): + mask_fn(acc_S=acc_S, n_block=n_block) + + row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) + # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S)) + tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) + # tOrP.store(tOrP_acc.load().to(self.dtype)) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. 
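+        # (utils.cvt_f16 is assumed to emit the packed two-wide conversion, e.g.
+        # cvt.rn.f16x2.f32 / cvt.rn.bf16x2.f32, converting a pair of fp32 lanes per
+        # instruction.)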
+ utils.cvt_f16(tOrP_acc, tOrP_cur) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + softmax.rescale_O(acc_O, row_scale) + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) + self.warp_scheduler_barrier_sync() + # O += P @ V + mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0) + pipeline_v.consumer_release(smem_pipe_read) + smem_pipe_read.advance() + return smem_pipe_read + + @cute.jit + def mma_one_n_block_intrawg_overlap( + self, + smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + n_block: Int32, + mma_qk_fn: Callable, + mma_pv_fn: Callable, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + acc_O: cute.Tensor, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + seqlen: SeqlenInfoQK, + score_mod_fn: Optional[Callable] = None, + mask_fn: Optional[Callable] = None, + check_inf: cutlass.Constexpr = True, + ): + smem_pipe_read_v = smem_pipe_read.clone() + smem_pipe_read.advance() + pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) + self.warp_scheduler_barrier_sync() + # S = Q @ K.T + acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) + pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v)) + # O += P @ V + mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1) + self.warp_scheduler_barrier_arrive() + warpgroup.wait_group(1) + pipeline_k.consumer_release(smem_pipe_read) + + # handle score mods and masking + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) + if const_expr(mask_fn is not None): + mask_fn(acc_S=acc_S, n_block=n_block) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S)) + + row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) + warpgroup.wait_group(0) + pipeline_v.consumer_release(smem_pipe_read_v) + tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) + # tOrP_cur.store(tOrP_acc.load().to(self.dtype)) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. 
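+        # tOrP can be overwritten safely at this point: wait_group(0) above drained the
+        # in-flight PV WGMMA that was still reading the previous block's P fragment.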
+ utils.cvt_f16(tOrP_acc, tOrP_cur) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + softmax.rescale_O(acc_O, row_scale) + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + return smem_pipe_read + + @cute.jit + def mma_init(self): + warp_group_idx = utils.canonical_warp_group_idx(sync=False) + if const_expr(self.use_scheduler_barrier): + if warp_group_idx == 1: + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + @cute.jit + def apply_score_mod( + self, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + acc_S, + n_block, + softmax_scale, + seqlen, + aux_tensors: Optional[list] = None, + fastdiv_mods=None, + ): + # Prepare index tensor + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS) + tScS = thr_mma_qk.partition_C(cS) + + apply_score_mod_inner( + acc_S, + tScS, + self.score_mod, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info=seqlen, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + + def warp_scheduler_barrier_sync(self): + if const_expr(self.use_scheduler_barrier): + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + - 1 + + utils.canonical_warp_group_idx(sync=False), + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + def warp_scheduler_barrier_arrive(self): + if const_expr(self.use_scheduler_barrier): + assert self.num_mma_warp_groups in [2, 3] + cur_wg = utils.canonical_warp_group_idx(sync=False) - 1 + if const_expr(self.num_mma_warp_groups == 2): + next_wg = 1 - cur_wg + else: + t = cur_wg + 1 + next_wg = t % self.num_mma_warp_groups + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, + number_of_threads=2 * self.num_threads_per_warp_group, + )
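+
+    # Illustrative reference, never called by the kernel: the per-row online-softmax
+    # recurrence that mma_one_n_block implements with Softmax.online_softmax /
+    # rescale_O / finalize, written out in plain Python. The (scores, values) pairs
+    # and the scalar accumulator are stand-ins for the acc_S tiles and the acc_O
+    # fragment; exp2 with softmax_scale_log2 matches the kernel's exponential
+    # (softmax_scale * LOG2_E, or just LOG2_E when score_mod applies the scale
+    # beforehand, as above). The name online_softmax_reference is illustrative and
+    # not part of the kernel API.
+    @staticmethod
+    def online_softmax_reference(score_blocks, softmax_scale_log2):
+        row_max, row_sum, acc_o = -math.inf, 0.0, 0.0
+        for scores, values in score_blocks:
+            new_max = max(row_max, max(scores))
+            # row_scale rescales the partial sum and accumulator to the new running max
+            row_scale = (
+                0.0 if row_max == -math.inf  # first block: nothing to rescale
+                else 2.0 ** (softmax_scale_log2 * (row_max - new_max))
+            )
+            row_sum *= row_scale
+            acc_o *= row_scale  # softmax.rescale_O(acc_O, row_scale)
+            p = [2.0 ** (softmax_scale_log2 * (s - new_max)) for s in scores]
+            row_sum += sum(p)
+            # the PV GEMM, acc_O += P @ V, collapsed to a dot product
+            acc_o += sum(p_i * v_i for p_i, v_i in zip(p, values))
+            row_max = new_max
+        # softmax.finalize + the final rescale_O: normalize by the row sum
+        return acc_o / row_sum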