diff --git a/build/torch211-cxx11-cu128-x86_64-linux/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7d5e4ade468de366bb73eed0ccb38d4e358cf8 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""MiniMax Sparse Attention (MSA) CuTe-DSL kernels for NVIDIA SM100. + +Hub-kernel packaging of the CuTe-DSL sparse attention stack from +https://github.com/MiniMax-AI/MSA (``python/fmha_sm100/cute``). The +host-side helper kernels (CSR builder, decode scheduler) are precompiled +Torch ops; the attention kernels are compiled at runtime through +nvidia-cutlass-dsl. +""" + +# Sparse attention forward / decode. +from .interface import ( + SparseDecodePagedAttentionWrapper, + sparse_atten_func, + sparse_atten_nvfp4_kv_func, + sparse_decode_atten_func, +) + +# CSR + schedule construction. +from .sparse_index_utils import build_k2q_csr + +# SM100 fused CSR builder. +from .src.sm100.prepare_k2q_csr import SparseK2qCsrBuilderSm100 + +# FP4 block-score indexer. Returns per-(Hq, kv_block, q) max scores; topK +# selection + q2k construction remain caller-owned downstream steps. +from .fp4_indexer_interface import fp4_indexer_block_scores + +# NVFP4 quantization helpers used to feed the FP4 indexer / NVFP4 attention. +from .quantize import ( + Nvfp4QuantizedTensor, + dequantize_nvfp4_128x4_to_bf16, + nvfp4_global_scale_from_amax, + quantize_bf16_to_nvfp4_128x4, + quantize_kv_bf16_to_nvfp4_128x4, + swizzle_nvfp4_scale_to_128x4, +) + +__version__ = "0.1.1" + +__all__ = [ + # attention + "sparse_atten_func", + "sparse_atten_nvfp4_kv_func", + "sparse_decode_atten_func", + "SparseDecodePagedAttentionWrapper", + # indexing / CSR + "fp4_indexer_block_scores", + "build_k2q_csr", + "SparseK2qCsrBuilderSm100", + # nvfp4 quantization helpers + "Nvfp4QuantizedTensor", + "quantize_bf16_to_nvfp4_128x4", + "quantize_kv_bf16_to_nvfp4_128x4", + "dequantize_nvfp4_128x4_to_bf16", + "swizzle_nvfp4_scale_to_128x4", + "nvfp4_global_scale_from_amax", +] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_msa_cuda_09d7851.abi3.so b/build/torch211-cxx11-cu128-x86_64-linux/_msa_cuda_09d7851.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..61b23503a28f088ac223d37d361b598de2219004 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_msa_cuda_09d7851.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8dcd8c86e512f3ddd5acb95f6fdcad3cfaa1579bb6f874a714fba066e6877161 +size 1169368 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6be2da4d5d784683e9e2fb8bfe08e93847dc6640 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _msa_cuda_09d7851 +ops = torch.ops._msa_cuda_09d7851 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_msa_cuda_09d7851::{op_name}" diff --git a/build/torch211-cxx11-cu128-x86_64-linux/fp4_indexer_interface.py b/build/torch211-cxx11-cu128-x86_64-linux/fp4_indexer_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..48dc1d05480355d2af4f4e47142ae4cd692184b0 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/fp4_indexer_interface.py @@ -0,0 +1,1061 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Public FP4 sparse-attention indexer block-score interface.""" + +from __future__ import annotations + +from typing import Optional + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Int32 +from cutlass.cute.runtime import make_ptr + +from .src.sm100.fp4_indexer import ( + Fp4FormatSpec, + Fp4IndexerDecodePackedQSm100, + Fp4IndexerDecodeQPackSm100, + Fp4IndexerScaleReorderSm100, + Fp4IndexerStagedMmaSm100, + _BLOCK_K, + _DECODE_K_TILES_PER_CTA, + _DECODE_PACK_Q_LEN, + _DECODE_QHEAD_PER_KV, + _FP4_PACKED_D_BYTES, + _HEAD_DIM, + _MMA_TILER_MN, + _PAGE_SIZE, + ceil_div, + k_tiles_per_cta_for, + normalize_fp4_format, +) + + +_PUBLIC_SCALE_LAYOUT = "public" +_PREORDERED_MMA_SCALE_LAYOUT = "preordered_mma" +_FP4_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +def _device_arch(device: torch.device) -> tuple[int, int]: + major, minor = torch.cuda.get_device_capability(device) + return int(major), int(minor) + + +def _supports_tmem_load_red(device_arch: tuple[int, int]) -> bool: + return device_arch >= (10, 3) + + +def normalize_scale_layout(scale_layout: str) -> str: + """Normalize and validate FP4 indexer scale layout mode. + + Parameters + ---------- + scale_layout : str + Either ``"public"`` for logical scale tensors or ``"preordered_mma"`` + for tensors already laid out with ``fp4_indexer_mma_scale_storage_*``. + + Returns + ------- + str + The normalized scale layout string. + """ + + scale_layout = str(scale_layout) + if scale_layout not in (_PUBLIC_SCALE_LAYOUT, _PREORDERED_MMA_SCALE_LAYOUT): + raise ValueError(f"scale_layout must be 'public' or 'preordered_mma', got {scale_layout!r}") + return scale_layout + + +def _causal_compact_task_count(q_len: int, k_len: int, k_tiles_per_cta: int) -> int: + if q_len <= 0 or k_len <= 0: + return 0 + q_tile_count = ceil_div(q_len, _MMA_TILER_MN[0]) + k_group_count = ceil_div(ceil_div(k_len, _PAGE_SIZE), k_tiles_per_cta) + group_tokens = k_tiles_per_cta * _BLOCK_K + causal_offset = int(k_len) - int(q_len) + tasks = 0 + for q_tile_idx in range(q_tile_count): + q_tile_start = q_tile_idx * _MMA_TILER_MN[0] + q_tile_last = min(q_tile_start + _MMA_TILER_MN[0] - 1, int(q_len) - 1) + visible_limit = q_tile_last + causal_offset + if visible_limit >= 0: + tasks += min(k_group_count, visible_limit // group_tokens + 1) + return tasks + + +def _causal_compact_task_bound(max_q_len: int, max_k_len: int, k_tiles_per_cta: int) -> int: + """Conservative X-grid bound for per-batch causal prefill compact mapping.""" + + if max_q_len <= 0 or max_k_len <= 0: + return 0 + q_tile_count = ceil_div(max_q_len, _MMA_TILER_MN[0]) + candidates = {int(max_q_len)} + for q_tile_idx in range(q_tile_count): + q_len = q_tile_idx * _MMA_TILER_MN[0] + 1 + if q_len <= max_q_len: + candidates.add(q_len) + return max(_causal_compact_task_count(q_len, max_k_len, k_tiles_per_cta) for q_len in candidates) + + +def _require_cuda_tensor(tensor: torch.Tensor, *, name: str) -> None: + if not tensor.is_cuda: + raise ValueError(f"{name} must be a CUDA tensor") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _require_int32_vector(tensor: torch.Tensor, *, name: str, device: torch.device) -> None: + if tensor.device != device: + raise ValueError(f"{name} must be on the same CUDA device") + if tensor.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32") + if tensor.ndim != 1: + raise ValueError(f"{name} must be rank-1") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _require_fp4_packed_dtype(tensor: torch.Tensor, *, name: str) -> None: + fp4_x2_dtype = getattr(torch, "float4_e2m1fn_x2", None) + allowed = {torch.uint8, torch.int8} + if fp4_x2_dtype is not None: + allowed.add(fp4_x2_dtype) + if tensor.dtype not in allowed: + raise TypeError(f"{name} must use packed FP4 storage dtype, got {tensor.dtype}") + + +def _as_fp4_thd_bytes(tensor: torch.Tensor, *, name: str) -> torch.Tensor: + if tensor.ndim != 3: + raise ValueError(f"{name} must have shape [total_q, Hq, 64]") + if int(tensor.shape[-1]) != _FP4_PACKED_D_BYTES: + raise ValueError(f"{name}.shape[-1] must be 64 packed bytes for D=128") + _require_fp4_packed_dtype(tensor, name=name) + if tensor.dtype == torch.uint8: + return tensor + return tensor.view(torch.uint8) + + +def _as_fp4_paged_hnd_bytes(tensor: torch.Tensor, *, name: str) -> torch.Tensor: + if tensor.ndim != 4: + raise ValueError(f"{name} must have shape [total_pages, Hk, 128, 64]") + if int(tensor.shape[-2]) != _PAGE_SIZE: + raise ValueError(f"{name}.shape[-2] must be 128") + if int(tensor.shape[-1]) != _FP4_PACKED_D_BYTES: + raise ValueError(f"{name}.shape[-1] must be 64 packed bytes for D=128") + _require_fp4_packed_dtype(tensor, name=name) + if tensor.dtype == torch.uint8: + return tensor + return tensor.view(torch.uint8) + + +def validate_q_scale_thg( + scale: torch.Tensor, + *, + name: str, + fmt: Fp4FormatSpec, + total_q: int, + heads: int, +) -> None: + """Validate public Q FP4 scale layout ``[total_q, Hq, G]``. + + Parameters + ---------- + scale : torch.Tensor + Logical Q scale tensor. + name : str + Name used in validation error messages. + fmt : Fp4FormatSpec + FP4 format specification from ``normalize_fp4_format``. + total_q : int + Total query token count. + heads : int + Number of Q heads. + """ + + expected = (int(total_q), int(heads), fmt.scale_groups) + if tuple(scale.shape) != expected: + raise ValueError(f"{name} must have shape {expected}, got {tuple(scale.shape)}") + if scale.dtype != fmt.torch_scale_dtype: + raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}") + if not scale.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def validate_k_scale_phsg( + scale: torch.Tensor, + *, + name: str, + fmt: Fp4FormatSpec, + page_count: int, + heads: int, +) -> None: + """Validate public K FP4 scale layout ``[page_count, Hk, 128, G]``. + + Parameters + ---------- + scale : torch.Tensor + Logical K scale tensor. + name : str + Name used in validation error messages. + fmt : Fp4FormatSpec + FP4 format specification from ``normalize_fp4_format``. + page_count : int + Number of physical KV pages. + heads : int + Number of KV heads. + """ + + expected = (int(page_count), int(heads), _PAGE_SIZE, fmt.scale_groups) + if tuple(scale.shape) != expected: + raise ValueError(f"{name} must have shape {expected}, got {tuple(scale.shape)}") + if scale.dtype != fmt.torch_scale_dtype: + raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}") + if not scale.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def fp4_indexer_mma_scale_shape(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return semantic MMA scale view shape ``(32,4,restM,4,restG,L)``.""" + + spec = normalize_fp4_format(fp4_format) + return (32, 4, ceil_div(mn, 128), 4, ceil_div(spec.scale_groups, 4), int(l)) + + +def fp4_indexer_mma_scale_stride(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return element strides for ``fp4_indexer_mma_scale_shape``.""" + + spec = normalize_fp4_format(fp4_format) + rest_m = ceil_div(mn, 128) + rest_g = ceil_div(spec.scale_groups, 4) + return (16, 4, 512 * rest_g, 1, 512, 512 * rest_m * rest_g) + + +def fp4_indexer_mma_scale_storage_shape(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return contiguous storage shape for preordered MMA scales.""" + + spec = normalize_fp4_format(fp4_format) + return (int(l), ceil_div(mn, 128), ceil_div(spec.scale_groups, 4), 32, 4, 4) + + +def fp4_indexer_mma_scale_storage_stride(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return element strides for ``fp4_indexer_mma_scale_storage_shape``.""" + + spec = normalize_fp4_format(fp4_format) + rest_m = ceil_div(mn, 128) + rest_g = ceil_div(spec.scale_groups, 4) + return (512 * rest_m * rest_g, 512 * rest_g, 512, 16, 4, 1) + + +def validate_mma_scale_storage( + scale: torch.Tensor, + *, + name: str, + fmt: Fp4FormatSpec, + mn: int, + l: int, +) -> None: + """Validate preordered MMA scale storage expected by the FP4 indexer. + + Parameters + ---------- + scale : torch.Tensor + Tensor view whose shape/stride should match + ``fp4_indexer_mma_scale_storage_shape`` and + ``fp4_indexer_mma_scale_storage_stride``. + name : str + Name used in validation error messages. + fmt : Fp4FormatSpec + FP4 format specification from ``normalize_fp4_format``. + mn : int + Logical M/N extent of the scale domain. + l : int + Logical batch/head extent folded into the final layout dimension. + """ + + expected_shape = fp4_indexer_mma_scale_storage_shape(mn, l, fp4_format=fmt.name) + expected_stride = fp4_indexer_mma_scale_storage_stride(mn, l, fp4_format=fmt.name) + if tuple(scale.shape) != expected_shape: + raise ValueError(f"{name} must have MMA storage shape {expected_shape}, got {tuple(scale.shape)}") + if tuple(scale.stride()) != expected_stride: + raise ValueError(f"{name} must have MMA storage stride {expected_stride}, got {tuple(scale.stride())}") + if scale.dtype != fmt.torch_scale_dtype: + raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}") + + +def _empty_mma_scale_tensor( + *, + mn: int, + l: int, + spec: Fp4FormatSpec, + device: torch.device, +) -> torch.Tensor: + rest_m = ceil_div(mn, 128) + rest_g = ceil_div(spec.scale_groups, 4) + storage = torch.empty( + (int(l), rest_m, rest_g, 32, 4, 4), + dtype=spec.torch_scale_dtype, + device=device, + ) + return storage.permute(3, 4, 1, 5, 2, 0) + + +def _compile_fp4_scale_reorder_kernel( + *, + fmt: Fp4FormatSpec, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + q_scale_mma_ptr: cute.Pointer, + k_scale_mma_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_scale_reorder_sm100_1cta", + fmt.name, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerScaleReorderSm100(fmt=fmt.name) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_scale_ptr, + k_scale_ptr, + q_scale_mma_ptr, + k_scale_mma_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def fp4_indexer_reorder_scales_for_mma_cute( + q_scale: torch.Tensor, + k_scale: torch.Tensor, + *, + fp4_format: str, +) -> tuple[torch.Tensor, torch.Tensor]: + """Reorder public Q/K FP4 scales to MMA-friendly storage. + + Parameters + ---------- + q_scale : torch.Tensor + Public Q scale tensor with shape ``[total_q, Hq, G]``. + k_scale : torch.Tensor + Public K scale tensor with shape ``[page_count, Hk, 128, G]``. + fp4_format : str + ``"mxfp4"`` or ``"nvfp4"``. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(q_scale_mma, k_scale_mma)`` views in the storage layout validated by + ``validate_mma_scale_storage``. These tensors can be passed to + ``fp4_indexer_block_scores`` with ``scale_layout="preordered_mma"``. + """ + + spec = normalize_fp4_format(fp4_format) + if q_scale.device != k_scale.device: + raise ValueError("q_scale and k_scale must be on the same CUDA device") + _require_cuda_tensor(q_scale, name="q_scale") + _require_cuda_tensor(k_scale, name="k_scale") + if q_scale.ndim != 3: + raise ValueError(f"q_scale must have shape [total_q, Hq, G], got {tuple(q_scale.shape)}") + if k_scale.ndim != 4: + raise ValueError(f"k_scale must have shape [page_count, Hk, 128, G], got {tuple(k_scale.shape)}") + total_q, heads_q, _ = (int(v) for v in q_scale.shape) + page_count, heads_k, _, _ = (int(v) for v in k_scale.shape) + validate_q_scale_thg(q_scale, name="q_scale", fmt=spec, total_q=total_q, heads=heads_q) + validate_k_scale_phsg(k_scale, name="k_scale", fmt=spec, page_count=page_count, heads=heads_k) + + q_scale_mma = _empty_mma_scale_tensor( + mn=total_q, + l=heads_q, + spec=spec, + device=q_scale.device, + ) + k_scale_mma = _empty_mma_scale_tensor( + mn=_PAGE_SIZE, + l=page_count * heads_k, + spec=spec, + device=k_scale.device, + ) + + q_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + q_scale.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + k_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + k_scale.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + q_scale_mma_ptr = make_ptr( + spec.cutlass_scale_dtype, + q_scale_mma.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + k_scale_mma_ptr = make_ptr( + spec.cutlass_scale_dtype, + k_scale_mma.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + problem_size = ( + Int32(total_q), + Int32(heads_q), + Int32(page_count), + Int32(heads_k), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_scale.device).cuda_stream) + compiled = _compile_fp4_scale_reorder_kernel( + fmt=spec, + q_scale_ptr=q_scale_ptr, + k_scale_ptr=k_scale_ptr, + q_scale_mma_ptr=q_scale_mma_ptr, + k_scale_mma_ptr=k_scale_mma_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_scale_ptr, + k_scale_ptr, + q_scale_mma_ptr, + k_scale_mma_ptr, + problem_size, + stream, + ) + return q_scale_mma, k_scale_mma + + +def _compile_fp4_decode_q_pack_kernel( + *, + fmt: Fp4FormatSpec, + q_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + q_pack_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_decode_q_pack_sm100", + fmt.name, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerDecodeQPackSm100(fmt=fmt.name) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_ptr, + q_scale_ptr, + q_pack_ptr, + q_scale_pack_ptr, + cu_seqlens_q_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def _pack_decode_q_for_mma( + q_bytes: torch.Tensor, + q_scale_storage: torch.Tensor, + cu_seqlens_q: torch.Tensor, + *, + fmt: Fp4FormatSpec, + heads_q: int, + heads_k: int, + batch: int, +) -> tuple[torch.Tensor, torch.Tensor]: + q_pack = torch.empty( + (batch * heads_k, _PAGE_SIZE, _FP4_PACKED_D_BYTES), + dtype=torch.uint8, + device=q_bytes.device, + ) + q_scale_pack = torch.empty( + fp4_indexer_mma_scale_storage_shape(_PAGE_SIZE, batch * heads_k, fp4_format=fmt.name), + dtype=fmt.torch_scale_dtype, + device=q_bytes.device, + ) + if q_pack.data_ptr() % 128 != 0: + raise ValueError("internal decode q_pack data pointer must be 128B aligned for TMA") + if q_scale_pack.data_ptr() % 32 != 0: + raise ValueError("internal decode q_scale_pack data pointer must be 32B aligned") + q_ptr = make_ptr(cutlass.Uint8, q_bytes.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + q_scale_ptr = make_ptr( + fmt.cutlass_scale_dtype, + q_scale_storage.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + q_pack_ptr = make_ptr(cutlass.Uint8, q_pack.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + q_scale_pack_ptr = make_ptr( + fmt.cutlass_scale_dtype, + q_scale_pack.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + cu_seqlens_q_ptr = make_ptr( + cutlass.Int32, + cu_seqlens_q.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + problem_size = ( + Int32(q_bytes.shape[0]), + Int32(heads_q), + Int32(heads_k), + Int32(batch), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_bytes.device).cuda_stream) + compiled = _compile_fp4_decode_q_pack_kernel( + fmt=fmt, + q_ptr=q_ptr, + q_scale_ptr=q_scale_ptr, + q_pack_ptr=q_pack_ptr, + q_scale_pack_ptr=q_scale_pack_ptr, + cu_seqlens_q_ptr=cu_seqlens_q_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_ptr, + q_scale_ptr, + q_pack_ptr, + q_scale_pack_ptr, + cu_seqlens_q_ptr, + problem_size, + stream, + ) + return q_pack, q_scale_pack + + +def _compile_fp4_decode_packed_q_kernel( + *, + fmt: Fp4FormatSpec, + causal: bool, + compact_schedule: bool, + device_arch: tuple[int, int], + use_tmem_load_red: bool, + q_pack_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_decode_packed_q_sm100", + fmt.name, + bool(causal), + bool(compact_schedule), + device_arch, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerDecodePackedQSm100( + fmt=fmt.name, + causal=causal, + compact_schedule=compact_schedule, + use_tmem_load_red=use_tmem_load_red, + ) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_pack_ptr, + k_ptr, + q_scale_pack_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def _run_fp4_decode_packed_q_scores( + q_pack: torch.Tensor, + k_bytes: torch.Tensor, + q_scale_pack: torch.Tensor, + k_scale_storage: torch.Tensor, + scores: torch.Tensor, + kv_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_page_offsets: torch.Tensor, + qo_offset_arg: torch.Tensor, + *, + fmt: Fp4FormatSpec, + causal: bool, + has_qo_offset: int, + heads_q: int, + heads_k: int, + batch: int, + max_k_tiles: int, + total_q: int, + device_arch: tuple[int, int], + use_tmem_load_red: bool, +) -> None: + page_count = int(k_bytes.shape[0]) + rectangular_groups = batch * ceil_div(max_k_tiles, _DECODE_K_TILES_PER_CTA) + compact_groups = ceil_div(page_count + batch * (_DECODE_K_TILES_PER_CTA - 1), _DECODE_K_TILES_PER_CTA) + compact_schedule = compact_groups < rectangular_groups + if compact_schedule: + scores.fill_(float("-inf")) + + q_pack_ptr = make_ptr(cutlass.Uint8, q_pack.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + k_ptr = make_ptr(cutlass.Uint8, k_bytes.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + q_scale_pack_ptr = make_ptr( + fmt.cutlass_scale_dtype, + q_scale_pack.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + k_scale_ptr = make_ptr( + fmt.cutlass_scale_dtype, + k_scale_storage.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + scores_ptr = make_ptr(cutlass.Float32, scores.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) + kv_indices_ptr = make_ptr(cutlass.Int32, kv_indices.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + cu_seqlens_q_ptr = make_ptr(cutlass.Int32, cu_seqlens_q.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + cu_seqlens_k_ptr = make_ptr(cutlass.Int32, cu_seqlens_k.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + cu_page_offsets_ptr = make_ptr(cutlass.Int32, cu_page_offsets.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + qo_offset_ptr = make_ptr(cutlass.Int32, qo_offset_arg.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + problem_size = ( + Int32(_PAGE_SIZE), + Int32(max_k_tiles * _PAGE_SIZE), + Int32(_HEAD_DIM), + Int32(batch * heads_k), + Int32(page_count * heads_k), + Int32(heads_q), + Int32(heads_k), + Int32(batch), + Int32(max_k_tiles), + Int32(total_q), + Int32(has_qo_offset), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_pack.device).cuda_stream) + compiled = _compile_fp4_decode_packed_q_kernel( + fmt=fmt, + causal=causal, + compact_schedule=compact_schedule, + device_arch=device_arch, + use_tmem_load_red=use_tmem_load_red, + q_pack_ptr=q_pack_ptr, + k_ptr=k_ptr, + q_scale_pack_ptr=q_scale_pack_ptr, + k_scale_ptr=k_scale_ptr, + scores_ptr=scores_ptr, + kv_indices_ptr=kv_indices_ptr, + cu_seqlens_q_ptr=cu_seqlens_q_ptr, + cu_seqlens_k_ptr=cu_seqlens_k_ptr, + cu_page_offsets_ptr=cu_page_offsets_ptr, + qo_offset_ptr=qo_offset_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_pack_ptr, + k_ptr, + q_scale_pack_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + + +def _compile_fp4_qk_kernel( + *, + fmt: Fp4FormatSpec, + causal: bool, + preordered_q_scale_tma: bool, + compact_schedule: bool, + device_arch: tuple[int, int], + use_tmem_load_red: bool, + q_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_staged_mma_sm100", + fmt.name, + bool(causal), + bool(preordered_q_scale_tma), + bool(compact_schedule), + device_arch, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerStagedMmaSm100( + fmt=fmt.name, + causal=causal, + preordered_q_scale_tma=preordered_q_scale_tma, + compact_schedule=compact_schedule, + use_tmem_load_red=use_tmem_load_red, + ) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_ptr, + k_ptr, + q_scale_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def fp4_indexer_block_scores( + q_fp4: torch.Tensor, + k_fp4: torch.Tensor, + q_scale: torch.Tensor, + k_scale: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_page_offsets: torch.Tensor, + *, + max_seqlen_q: int, + max_seqlen_k: int, + kv_indices: torch.Tensor, + fp4_format: str, + causal: bool = False, + qo_offset: Optional[torch.Tensor] = None, + scale_layout: str = _PREORDERED_MMA_SCALE_LAYOUT, +) -> torch.Tensor: + """Return FP4 QK max scores per 128-token KV page. + + Parameters + ---------- + q_fp4 : torch.Tensor + Packed FP4 Q tensor with shape ``[total_qo_len, Hq, 64]``. The last + dimension stores two FP4 values per byte for logical head dimension + 128. + k_fp4 : torch.Tensor + Packed paged FP4 K tensor with shape ``[total_pages, Hk, 128, 64]``. + q_scale : torch.Tensor + Q scale tensor. With ``scale_layout="public"``, shape is + ``[total_qo_len, Hq, G]``. With ``"preordered_mma"``, use + ``fp4_indexer_reorder_scales_for_mma_cute`` output layout. + k_scale : torch.Tensor + K scale tensor. With ``scale_layout="public"``, shape is + ``[total_pages, Hk, 128, G]``. With ``"preordered_mma"``, use the + preordered MMA scale layout. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + cu_page_offsets : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of per-request + page counts. + max_seqlen_q : int + Maximum Q sequence length. + max_seqlen_k : int + Maximum KV sequence length. + kv_indices : torch.Tensor + Flattened physical page indices with shape ``[sum_pages]`` and dtype + int32. + fp4_format : str + ``"mxfp4"`` or ``"nvfp4"``. + causal : bool, optional + Whether to apply causal masking. + qo_offset : torch.Tensor, optional + Shape ``[batch_size]``, dtype int32. Per-request causal offset. Valid + only when ``causal=True``. + scale_layout : str, optional + ``"public"`` accepts logical public scale tensors and launches a scale + reorder kernel. ``"preordered_mma"`` expects preordered MMA scale + tensors and skips the reorder. + + Returns + ------- + torch.Tensor + Shape ``[Hq, ceil(max_seqlen_k / 128), total_qo_len]``, dtype float32. + Entries beyond the valid KV page range are ``-inf``. + """ + + spec = normalize_fp4_format(fp4_format) + causal = bool(causal) + scale_layout = normalize_scale_layout(scale_layout) + use_preordered_q_scale_tma = int(max_seqlen_q) >= _PAGE_SIZE + q_bytes = _as_fp4_thd_bytes(q_fp4, name="q_fp4") + k_bytes = _as_fp4_paged_hnd_bytes(k_fp4, name="k_fp4") + total_q, heads_q, _ = (int(v) for v in q_bytes.shape) + page_count, heads_k, page_size, _ = (int(v) for v in k_bytes.shape) + if page_size != _PAGE_SIZE: + raise ValueError(f"k_fp4 page_size must be 128, got {page_size}") + if heads_q % heads_k != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads") + _require_cuda_tensor(q_fp4, name="q_fp4") + _require_cuda_tensor(k_fp4, name="k_fp4") + device_arch = _device_arch(q_fp4.device) + use_tmem_load_red = _supports_tmem_load_red(device_arch) + _require_int32_vector(cu_seqlens_q, name="cu_seqlens_q", device=q_fp4.device) + _require_int32_vector(cu_seqlens_k, name="cu_seqlens_k", device=q_fp4.device) + _require_int32_vector(cu_page_offsets, name="cu_page_offsets", device=q_fp4.device) + if q_scale.device != q_fp4.device or k_scale.device != q_fp4.device: + raise ValueError("q_scale and k_scale must be on the same CUDA device as q_fp4") + if scale_layout == _PUBLIC_SCALE_LAYOUT: + validate_q_scale_thg(q_scale, name="q_scale", fmt=spec, total_q=total_q, heads=heads_q) + validate_k_scale_phsg(k_scale, name="k_scale", fmt=spec, page_count=page_count, heads=heads_k) + else: + validate_mma_scale_storage(q_scale, name="q_scale", fmt=spec, mn=total_q, l=heads_q) + validate_mma_scale_storage(k_scale, name="k_scale", fmt=spec, mn=_PAGE_SIZE, l=page_count * heads_k) + batch = int(cu_seqlens_q.shape[0]) - 1 + if batch < 0: + raise ValueError("cu_seqlens_q must have shape [B + 1]") + if cu_seqlens_q.shape != cu_seqlens_k.shape or cu_seqlens_q.shape != cu_page_offsets.shape: + raise ValueError("cu_seqlens_q, cu_seqlens_k, and cu_page_offsets must have shape [B + 1]") + if q_bytes.data_ptr() % 128 != 0: + raise ValueError("q_fp4 data pointer must be 128B aligned for TMA") + if k_bytes.data_ptr() % 128 != 0: + raise ValueError("k_fp4 data pointer must be 128B aligned for TMA") + if kv_indices is None: + raise ValueError("kv_indices is required") + if kv_indices.device != q_fp4.device or kv_indices.dtype != torch.int32 or kv_indices.ndim != 1: + raise ValueError("kv_indices must have shape [sum_pages], dtype torch.int32, and match q_fp4.device") + if not kv_indices.is_contiguous(): + raise ValueError("kv_indices must be contiguous") + if qo_offset is not None: + if not causal: + raise ValueError("qo_offset is only valid when causal=True") + if qo_offset.device != q_fp4.device or qo_offset.dtype != torch.int32 or qo_offset.shape != (batch,): + raise ValueError("qo_offset must have shape [B], dtype torch.int32, and match q_fp4.device") + if not qo_offset.is_contiguous(): + raise ValueError("qo_offset must be contiguous") + + m_extent = int(max_seqlen_q) + max_k_tiles = ceil_div(int(max_seqlen_k), _PAGE_SIZE) + n_aligned = max_k_tiles * _PAGE_SIZE + if max_k_tiles == 0: + return torch.full( + (heads_q, 0, total_q), + float("-inf"), + dtype=torch.float32, + device=q_fp4.device, + ) + + scores = torch.empty( + (heads_q, max_k_tiles, total_q), + dtype=torch.float32, + device=q_fp4.device, + ) + if qo_offset is None: + qo_offset_arg = torch.empty((batch,), dtype=torch.int32, device=q_fp4.device) + has_qo_offset = 0 + else: + qo_offset_arg = qo_offset + has_qo_offset = 1 + if scale_layout == _PUBLIC_SCALE_LAYOUT: + q_scale_arg, k_scale_arg = fp4_indexer_reorder_scales_for_mma_cute( + q_scale, + k_scale, + fp4_format=spec.name, + ) + else: + q_scale_arg = q_scale + k_scale_arg = k_scale + scale_assumed_align = 32 + if q_scale_arg.data_ptr() % scale_assumed_align != 0: + raise ValueError(f"q_scale data pointer must be {scale_assumed_align}B aligned for MMA storage scale") + if k_scale_arg.data_ptr() % scale_assumed_align != 0: + raise ValueError(f"k_scale data pointer must be {scale_assumed_align}B aligned for MMA storage scale") + use_decode_packed_q = int(max_seqlen_q) <= _DECODE_PACK_Q_LEN and heads_q // heads_k == _DECODE_QHEAD_PER_KV + if use_decode_packed_q: + q_pack, q_scale_pack = _pack_decode_q_for_mma( + q_bytes, + q_scale_arg, + cu_seqlens_q, + fmt=spec, + heads_q=heads_q, + heads_k=heads_k, + batch=batch, + ) + _run_fp4_decode_packed_q_scores( + q_pack, + k_bytes, + q_scale_pack, + k_scale_arg, + scores, + kv_indices, + cu_seqlens_q, + cu_seqlens_k, + cu_page_offsets, + qo_offset_arg, + fmt=spec, + causal=causal, + has_qo_offset=has_qo_offset, + heads_q=heads_q, + heads_k=heads_k, + batch=batch, + max_k_tiles=max_k_tiles, + total_q=total_q, + device_arch=device_arch, + use_tmem_load_red=use_tmem_load_red, + ) + return scores + prefill_compact_task_count = 0 + prefill_compact_schedule = False + if causal and has_qo_offset == 0: + k_tiles_per_cta = k_tiles_per_cta_for(causal) + q_tile_count = ceil_div(m_extent, _MMA_TILER_MN[0]) + k_group_count = ceil_div(max_k_tiles, k_tiles_per_cta) + rectangular_task_count = q_tile_count * k_group_count + prefill_compact_task_count = min( + rectangular_task_count, + _causal_compact_task_bound(m_extent, int(max_seqlen_k), k_tiles_per_cta), + ) + prefill_compact_schedule = prefill_compact_task_count * 20 <= rectangular_task_count * 19 + if prefill_compact_schedule: + scores.fill_(float("-inf")) + q_ptr = make_ptr( + cutlass.Uint8, + q_bytes.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=128, + ) + k_ptr = make_ptr( + cutlass.Uint8, + k_bytes.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=128, + ) + q_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + q_scale_arg.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=scale_assumed_align, + ) + k_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + k_scale_arg.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=scale_assumed_align, + ) + scores_ptr = make_ptr( + cutlass.Float32, + scores.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + kv_indices_ptr = make_ptr( + cutlass.Int32, + kv_indices.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + cu_seqlens_q_ptr = make_ptr( + cutlass.Int32, + cu_seqlens_q.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + cu_seqlens_k_ptr = make_ptr( + cutlass.Int32, + cu_seqlens_k.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + cu_page_offsets_ptr = make_ptr( + cutlass.Int32, + cu_page_offsets.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + qo_offset_ptr = make_ptr( + cutlass.Int32, + qo_offset_arg.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + problem_size = ( + Int32(m_extent), + Int32(n_aligned), + Int32(_HEAD_DIM), + Int32(batch * heads_q), + Int32(page_count * heads_k), + Int32(heads_q), + Int32(heads_k), + Int32(batch), + Int32(max_k_tiles), + Int32(total_q), + Int32(has_qo_offset), + Int32(prefill_compact_task_count), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_fp4.device).cuda_stream) + compiled = _compile_fp4_qk_kernel( + fmt=spec, + causal=causal, + preordered_q_scale_tma=use_preordered_q_scale_tma, + compact_schedule=prefill_compact_schedule, + device_arch=device_arch, + use_tmem_load_red=use_tmem_load_red, + q_ptr=q_ptr, + k_ptr=k_ptr, + q_scale_ptr=q_scale_ptr, + k_scale_ptr=k_scale_ptr, + scores_ptr=scores_ptr, + kv_indices_ptr=kv_indices_ptr, + cu_seqlens_q_ptr=cu_seqlens_q_ptr, + cu_seqlens_k_ptr=cu_seqlens_k_ptr, + cu_page_offsets_ptr=cu_page_offsets_ptr, + qo_offset_ptr=qo_offset_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_ptr, + k_ptr, + q_scale_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + return scores + + +__all__ = [ + "fp4_indexer_block_scores", +] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/interface.py b/build/torch211-cxx11-cu128-x86_64-linux/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..9e507961840b3322238646ffffe3e97cf5d13130 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/interface.py @@ -0,0 +1,2011 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Sparse attention interface. + +Current delivery scope: + - head dimension is supported only for D=128 + +Public API: + sparse_atten_func(...) + sparse_decode_atten_func(...) + SparseDecodePagedAttentionWrapper + +Internal forward core: + _sparse_atten_csr_varlen_forward(...) + +Preprocessing (external, done once): + q2k_indices [head_kv, total_q, topK] -> sparse_index_utils.build_k2q_csr() + -> k2q_row_ptr [head_kv, total_rows + 1] int32 + -> k2q_q_indices [head_kv, total_q * topK] int32 +""" + +import math +import os +from typing import Optional + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32 +from cutlass.cute.runtime import from_dlpack + +from .src.sm100.fwd.combine import combine +from .src.sm100.fwd.atten_fwd import SparseAttentionForwardSm100 +from .src.sm100.fwd.atten_fwd_nvfp4_kv import SparseAttentionForwardNvfp4KvSm100 +from .src.sm100.prepare_scheduler import ( + SparseAttentionSchedule, + prepare_sparse_fwd_schedule_and_split, +) +from .src.sm100.decode_schedule import ( + DecodeAttentionSchedule, + prepare_decode_schedule, +) +from .src.common.cute_dsl_utils import to_cute_tensor as to_cute_tensor_kvouter +from .src.common.tma_utils import ( + create_q_gather4_tma_desc, +) + +_compile_cache: dict = {} +_TEMPERATURE_LSE_FAST_PATH_ABS_TOL = 1e-12 +_SUPPORTED_SPARSE_TOPK = (4, 8, 16, 32) +_SUPPORTED_FWD_DTYPES = (torch.bfloat16, torch.float8_e4m3fn) +_SUPPORTED_FWD_MMA_DTYPES = (torch.bfloat16, torch.float8_e4m3fn) +_SUPPORTED_DECODE_QHEAD_PER_KV = 16 + + +def _normalize_partial_dtype(partial_dtype: torch.dtype) -> torch.dtype: + supported = {torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn} + if partial_dtype not in supported: + raise TypeError( + "partial_dtype must be one of torch.float32 / torch.bfloat16 / " + "torch.float16 / torch.float8_e4m3fn, " + f"got {partial_dtype}" + ) + return partial_dtype + + +def _normalize_forward_mma_dtype(dtype: Optional[torch.dtype], fallback: torch.dtype, name: str) -> torch.dtype: + dtype = fallback if dtype is None else dtype + if dtype not in _SUPPORTED_FWD_MMA_DTYPES: + raise TypeError( + f"{name} must be one of torch.bfloat16 / torch.float8_e4m3fn, got {dtype}" + ) + return dtype + + +def _resolve_forward_mma_dtypes( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qk_dtype: Optional[torch.dtype], + pv_dtype: Optional[torch.dtype], +) -> tuple[torch.dtype, torch.dtype]: + qk_dtype = _normalize_forward_mma_dtype(qk_dtype, q.dtype, "qk_dtype") + if pv_dtype is None: + # Preserve the historical fp8 KV-cache path: BF16 Q with FP8 K/V + # stages both K and V as BF16 compute operands. + if ( + q.dtype == torch.bfloat16 + and k.dtype == torch.float8_e4m3fn + and v.dtype == torch.float8_e4m3fn + ): + pv_dtype = torch.bfloat16 + else: + pv_dtype = v.dtype + pv_dtype = _normalize_forward_mma_dtype(pv_dtype, pv_dtype, "pv_dtype") + + if q.dtype != qk_dtype: + raise ValueError( + "qk_dtype must match q storage dtype; Q fp8->bf16 staging is not supported" + ) + if k.dtype != qk_dtype: + if not (k.dtype == torch.float8_e4m3fn and qk_dtype == torch.bfloat16): + raise ValueError( + "unsupported K storage/qk_dtype combination; only fp8 K -> bf16 QK staging is supported" + ) + if v.dtype != pv_dtype: + if not (v.dtype == torch.float8_e4m3fn and pv_dtype == torch.bfloat16): + raise ValueError( + "unsupported V storage/pv_dtype combination; only fp8 V -> bf16 PV staging is supported" + ) + return qk_dtype, pv_dtype + + +def _to_cute_tensor_meta(t: torch.Tensor, assumed_align: int = 4): + tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=True) + return tensor.mark_layout_dynamic(leading_dim=0) + + +def _torch_dtype_to_cutlass_dtype(dtype: torch.dtype): + if dtype == torch.bfloat16: + return cutlass.BFloat16 + if dtype == torch.float16: + return cutlass.Float16 + if dtype == torch.float8_e4m3fn: + return cutlass.Float8E4M3FN + raise TypeError( + f"Only torch.bfloat16, torch.float16, torch.float8_e4m3fn supported, got {dtype}" + ) + + +def _prepare_paged_kv_for_tma(k, v, blk_kv: int): + page_size = int(k.shape[2]) + if page_size != blk_kv: + raise ValueError(f"Sparse Page Attention requires page_size == blk_kv, got {page_size} vs {blk_kv}") + return k, v + + +def _validate_cu_seqlens( + cu_seqlens: torch.Tensor, + *, + name: str, + device: torch.device, +) -> None: + if cu_seqlens.device != device: + raise ValueError(f"{name} must be on the same device as q") + if cu_seqlens.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32") + if cu_seqlens.ndim != 1: + raise ValueError(f"{name} must have shape [B + 1]") + if cu_seqlens.shape[0] < 1: + raise ValueError(f"{name} must have at least one element") + if not cu_seqlens.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _csr_row_capacity(k2q_row_ptr: torch.Tensor) -> int: + return int(k2q_row_ptr.shape[1] - 1) + + +def _validate_csr_varlen_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + blk_kv: int, + page_table: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], +) -> tuple[int, int]: + if q.ndim != 3: + raise ValueError("CSR sparse forward requires q to have shape [total_q, Hq, D]") + if q.dtype not in _SUPPORTED_FWD_DTYPES: + raise TypeError( + "CSR sparse forward supports only torch.bfloat16 and " + f"torch.float8_e4m3fn Q/K/V, got {q.dtype}" + ) + if q.device != k.device or q.device != v.device: + raise ValueError("q, k, v must be on the same device") + mixed_fp8_kv_bf16_q = ( + q.dtype == torch.bfloat16 + and k.dtype == torch.float8_e4m3fn + and v.dtype == torch.float8_e4m3fn + ) + if not mixed_fp8_kv_bf16_q and (q.dtype != k.dtype or q.dtype != v.dtype): + raise ValueError( + "q, k, v must have the same dtype, except q=bf16 with fp8_e4m3 K/V cache" + ) + if q.shape[-1] != k.shape[-1] or q.shape[-1] != v.shape[-1]: + raise ValueError("q, k, v must have the same head dimension") + dim = q.shape[-1] + if dim != 128: + raise NotImplementedError( + f"CSR sparse forward currently supports only D=128, got D={dim}" + ) + if page_table is None: + if k.shape[-2] != v.shape[-2] or k.shape[-1] != v.shape[-1]: + raise ValueError("k and v must have the same [Hkv, D] tail dimensions") + head_kv = k.shape[-2] + else: + if k.ndim != 4 or v.ndim != 4: + raise ValueError( + "Sparse Page Attention requires k and v to have shape " + "[num_pages, Hkv, page_size, D]" + ) + if k.shape[1] != v.shape[1] or k.shape[-1] != v.shape[-1]: + raise ValueError( + "Sparse Page Attention k and v must have the same Hkv and D" + ) + head_kv = k.shape[1] + if ( + q.device != k2q_row_ptr.device + or q.device != k2q_q_indices.device + ): + raise ValueError("CSR metadata must be on the same device as q") + if ( + k2q_row_ptr.dtype != torch.int32 + or k2q_q_indices.dtype != torch.int32 + ): + raise TypeError("CSR metadata tensors must be torch.int32") + if k2q_row_ptr.ndim != 2 or k2q_q_indices.ndim != 2: + raise ValueError("k2q_row_ptr and k2q_q_indices must be rank-2") + + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q", device=q.device) + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k", device=q.device) + if cu_seqlens_k.shape != cu_seqlens_q.shape: + raise ValueError("cu_seqlens_k must have shape [B + 1] matching cu_seqlens_q") + batch = int(cu_seqlens_q.shape[0] - 1) + total_q = q.shape[0] + + head_q = q.shape[1] + if head_q % head_kv != 0: + raise ValueError("q.shape[1] must be divisible by Hkv") + qhead_per_kv = head_q // head_kv + if qhead_per_kv not in (1, 2, 4, 8, 16): + raise NotImplementedError( + "CSR forward is currently supported only for qhead_per_kv in {1, 2, 4, 8, 16}" + ) + if k2q_row_ptr.shape[0] != head_kv or k2q_q_indices.shape[0] != head_kv: + raise ValueError("CSR metadata head dimension must match KV head count") + if k2q_q_indices.shape[1] < total_q * topK: + raise ValueError( + f"k2q_q_indices.shape[1] ({k2q_q_indices.shape[1]}) must be >= total_q * topK ({total_q * topK})" + ) + if k2q_row_ptr.shape[1] < 1: + raise ValueError("k2q_row_ptr must contain at least one row pointer column") + + if page_table is None: + if seqused_k is not None: + raise ValueError("seqused_k is only supported together with page_table") + total_k = k.shape[0] + if k.ndim != 3 or v.ndim != 3: + raise ValueError("Sparse Attention requires k and v to have shape [total_k, Hkv, D]") + if k.shape != (total_k, head_kv, q.shape[-1]) or v.shape != (total_k, head_kv, q.shape[-1]): + raise ValueError("Sparse Attention k and v must match [total_k, Hkv, D]") + else: + if page_table.device != q.device: + raise ValueError("page_table must be on the same device as q") + if page_table.dtype != torch.int32: + raise TypeError("page_table must be torch.int32") + if page_table.ndim != 2 or page_table.shape[0] != batch: + raise ValueError("page_table must have shape [B, max_num_pages_per_seq]") + if page_table.stride(-1) != 1: + raise ValueError("page_table must be contiguous in the last dimension") + if k.ndim != 4 or v.ndim != 4: + raise ValueError( + "Sparse Page Attention requires k and v to have shape " + "[num_pages, Hkv, page_size, D]" + ) + if k.shape != v.shape: + raise ValueError(f"k and v must have the same shape, got {k.shape} and {v.shape}") + if k.shape[1] != head_kv or k.shape[3] != q.shape[-1]: + raise ValueError( + "Sparse Page Attention k and v must match " + "[num_pages, Hkv, page_size, D]" + ) + page_size = int(k.shape[2]) + if page_size != blk_kv: + raise ValueError( + f"Unsupported Sparse Page Attention page_size={page_size} for blk_kv={blk_kv}; " + "require page_size == blk_kv" + ) + if seqused_k is not None: + if seqused_k.device != q.device: + raise ValueError("seqused_k must be on the same device as q") + if seqused_k.dtype != torch.int32: + raise TypeError("seqused_k must be torch.int32") + if seqused_k.shape != (batch,): + raise ValueError("seqused_k must have shape [B]") + if not seqused_k.is_contiguous(): + raise ValueError("seqused_k must be contiguous") + if topK not in _SUPPORTED_SPARSE_TOPK: + raise ValueError( + f"CSR sparse forward supports topK in {_SUPPORTED_SPARSE_TOPK}, got {topK}" + ) + return batch, head_kv + + +def _validate_csr_varlen_nvfp4_kv_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_scale_128x4: torch.Tensor, + v_scale_128x4: torch.Tensor, + k_global_scale: Optional[torch.Tensor], + v_global_scale: Optional[torch.Tensor], + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + blk_kv: int, + page_table: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], +) -> tuple[int, int]: + if q.ndim != 3: + raise ValueError("KVFP4 CSR sparse forward requires q to have shape [total_q, Hq, D]") + if q.dtype not in (torch.bfloat16, torch.float8_e4m3fn): + raise TypeError(f"KVFP4 CSR sparse forward requires BF16 or FP8 E4M3 q, got {q.dtype}") + if q.shape[-1] != 128: + raise NotImplementedError( + f"KVFP4 CSR sparse forward currently supports only D=128, got {q.shape[-1]}" + ) + if k.dtype != torch.uint8 or v.dtype != torch.uint8: + raise TypeError(f"KVFP4 k/v must be torch.uint8, got {k.dtype} and {v.dtype}") + if k_scale_128x4.dtype != torch.uint8 or v_scale_128x4.dtype != torch.uint8: + raise TypeError( + "KVFP4 block scales must be torch.uint8 E4M3 tensors, got " + f"{k_scale_128x4.dtype} and {v_scale_128x4.dtype}" + ) + if k_global_scale is not None and k_global_scale.dtype != torch.float32: + raise TypeError("KVFP4 K global scale must be a torch.float32 tensor or None") + if v_global_scale is not None and v_global_scale.dtype != torch.float32: + raise TypeError("KVFP4 V global scale must be a torch.float32 tensor or None") + tensors = ( + k, + v, + k_scale_128x4, + v_scale_128x4, + k2q_row_ptr, + k2q_q_indices, + cu_seqlens_q, + cu_seqlens_k, + ) + optional_tensors = tuple(t for t in (k_global_scale, v_global_scale) if t is not None) + if any(t.device != q.device for t in tensors + optional_tensors): + raise ValueError("KVFP4 inputs and metadata must be on the same device as q") + if k.shape != v.shape: + raise ValueError(f"KVFP4 k and v must have the same shape, got {k.shape} and {v.shape}") + packed_dim = q.shape[-1] // 2 + scale_cols = q.shape[-1] // 16 + if k_scale_128x4.ndim != 2 or v_scale_128x4.ndim != 2: + raise ValueError("KVFP4 block scales must be rank-2 128x4 tiled tensors") + if k_scale_128x4.shape[1] < scale_cols or v_scale_128x4.shape[1] < scale_cols: + raise ValueError( + "KVFP4 block scales must have at least D/16 columns; " + f"need {scale_cols}, got {k_scale_128x4.shape[1]} and {v_scale_128x4.shape[1]}" + ) + if k_global_scale is not None and k_global_scale.numel() < 1: + raise ValueError("KVFP4 K global scale must contain at least one element") + if v_global_scale is not None and v_global_scale.numel() < 1: + raise ValueError("KVFP4 V global scale must contain at least one element") + + if page_table is None: + if seqused_k is not None: + raise ValueError("seqused_k is only supported together with page_table") + if k.ndim != 3: + raise ValueError("KVFP4 Sparse Attention requires k/v shape [total_k, Hkv, D/2]") + if k.shape[-1] != packed_dim: + raise ValueError(f"KVFP4 packed K/V last dimension must be D/2={packed_dim}") + total_k = int(k.shape[0]) + head_kv = int(k.shape[1]) + required_scale_rows = total_k * head_kv + else: + if k.ndim != 4: + raise ValueError( + "KVFP4 Sparse Page Attention requires k/v shape " + "[num_pages, Hkv, page_size, D/2]" + ) + if k.shape[-1] != packed_dim: + raise ValueError(f"KVFP4 packed K/V last dimension must be D/2={packed_dim}") + page_size = int(k.shape[2]) + if page_size != int(blk_kv): + raise ValueError( + f"KVFP4 Sparse Page Attention requires page_size == blk_kv, got {page_size} vs {blk_kv}" + ) + head_kv = int(k.shape[1]) + required_scale_rows = int(k.shape[0]) * head_kv * page_size + if page_table.device != q.device: + raise ValueError("page_table must be on the same device as q") + if page_table.dtype != torch.int32: + raise TypeError("page_table must be torch.int32") + if page_table.ndim != 2: + raise ValueError("page_table must have shape [B, max_num_pages_per_seq]") + if page_table.stride(-1) != 1: + raise ValueError("page_table must be contiguous in the last dimension") + if seqused_k is not None: + if seqused_k.device != q.device: + raise ValueError("seqused_k must be on the same device as q") + if seqused_k.dtype != torch.int32: + raise TypeError("seqused_k must be torch.int32") + if not seqused_k.is_contiguous(): + raise ValueError("seqused_k must be contiguous") + + padded_scale_rows = ((required_scale_rows + 127) // 128) * 128 + padded_scale_cols = ((scale_cols + 3) // 4) * 4 + for name, scale in (("k_scale_128x4", k_scale_128x4), ("v_scale_128x4", v_scale_128x4)): + if scale.shape[0] < padded_scale_rows or scale.shape[1] < padded_scale_cols: + raise ValueError( + f"{name} is too small for 128x4 layout: got {tuple(scale.shape)}, " + f"need at least {(padded_scale_rows, padded_scale_cols)}" + ) + + if k2q_row_ptr.device != q.device or k2q_q_indices.device != q.device: + raise ValueError("CSR metadata must be on the same device as q") + if k2q_row_ptr.dtype != torch.int32 or k2q_q_indices.dtype != torch.int32: + raise TypeError("CSR metadata tensors must be torch.int32") + if k2q_row_ptr.ndim != 2 or k2q_q_indices.ndim != 2: + raise ValueError("k2q_row_ptr and k2q_q_indices must be rank-2") + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q", device=q.device) + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k", device=q.device) + if cu_seqlens_k.shape != cu_seqlens_q.shape: + raise ValueError("cu_seqlens_k must have shape [B + 1] matching cu_seqlens_q") + batch = int(cu_seqlens_q.shape[0] - 1) + if page_table is not None and page_table.shape[0] != batch: + raise ValueError("page_table must have shape [B, max_num_pages_per_seq]") + if seqused_k is not None and seqused_k.shape != (batch,): + raise ValueError("seqused_k must have shape [B]") + head_q = int(q.shape[1]) + if head_q % head_kv != 0: + raise ValueError("q.shape[1] must be divisible by Hkv") + qhead_per_kv = head_q // head_kv + if qhead_per_kv not in (1, 2, 4, 8, 16): + raise NotImplementedError( + "KVFP4 CSR forward is currently supported only for qhead_per_kv in {1, 2, 4, 8, 16}" + ) + if k2q_row_ptr.shape[0] != head_kv or k2q_q_indices.shape[0] != head_kv: + raise ValueError("CSR metadata head dimension must match KV head count") + if k2q_q_indices.shape[1] < q.shape[0] * topK: + raise ValueError( + f"k2q_q_indices.shape[1] ({k2q_q_indices.shape[1]}) must be >= total_q * topK ({q.shape[0] * topK})" + ) + if k2q_row_ptr.shape[1] < 1: + raise ValueError("k2q_row_ptr must contain at least one row pointer column") + if topK not in _SUPPORTED_SPARSE_TOPK: + raise ValueError( + f"KVFP4 CSR sparse forward supports topK in {_SUPPORTED_SPARSE_TOPK}, got {topK}" + ) + return batch, head_kv + + +def _validate_sparse_decode_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k_indices: Optional[torch.Tensor], + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + blk_kv: int, + causal: bool, +) -> tuple[int, int]: + if q.ndim != 3: + raise ValueError("decode attention requires q to have shape [total_q, Hq, D]") + if k.ndim != 4 or v.ndim != 4: + raise ValueError( + "decode attention requires paged k/v with shape [num_pages, Hkv, page_size, D]" + ) + if q.device != k.device or q.device != v.device: + raise ValueError("decode q, k, and v must be on the same device") + if q.dtype != torch.float8_e4m3fn or k.dtype != q.dtype or v.dtype != q.dtype: + raise TypeError( + "decode attention currently supports only torch.float8_e4m3fn Q/K/V" + ) + if k.shape != v.shape: + raise ValueError(f"decode k and v must have the same shape, got {k.shape} and {v.shape}") + if q.shape[-1] != 128 or k.shape[-1] != 128: + raise NotImplementedError( + f"decode attention currently supports only D=128, got q={q.shape[-1]} k={k.shape[-1]}" + ) + if not bool(causal): + raise NotImplementedError("decode attention currently supports only causal=True") + page_size = int(k.shape[2]) + if page_size != int(blk_kv): + raise ValueError(f"decode attention requires page_size == blk_kv, got {page_size} vs {blk_kv}") + + head_kv = int(k.shape[1]) + head_q = int(q.shape[1]) + if head_q % head_kv != 0: + raise ValueError("decode q.shape[1] must be divisible by Hkv") + qhead_per_kv = head_q // head_kv + if qhead_per_kv != _SUPPORTED_DECODE_QHEAD_PER_KV: + raise NotImplementedError( + "decode attention currently supports only " + f"qhead_per_kv={_SUPPORTED_DECODE_QHEAD_PER_KV}, got {qhead_per_kv}" + ) + + if page_table is None: + raise ValueError("decode attention requires page_table") + if page_table.device != q.device: + raise ValueError("decode page_table must be on the same device as q") + if page_table.dtype != torch.int32: + raise TypeError("decode page_table must be torch.int32") + if page_table.ndim != 2: + raise ValueError("decode page_table must have shape [B, max_num_pages_per_seq]") + batch = int(page_table.shape[0]) + if page_table.stride(-1) != 1: + raise ValueError("decode page_table must be contiguous in the last dimension") + + if seqused_k is None: + raise ValueError("decode attention requires seqused_k") + if seqused_k.device != q.device: + raise ValueError("decode seqused_k must be on the same device as q") + if seqused_k.dtype != torch.int32: + raise TypeError("decode seqused_k must be torch.int32") + if seqused_k.shape != (batch,): + raise ValueError("decode seqused_k must have shape [B]") + if not seqused_k.is_contiguous(): + raise ValueError("decode seqused_k must be contiguous") + + seqlen_q = int(seqlen_q) + max_seqlen_k = int(max_seqlen_k) + if seqlen_q <= 0 or max_seqlen_k <= 0: + raise ValueError("decode seqlen_q and max_seqlen_k must be positive") + if int(q.shape[0]) != batch * seqlen_q: + raise ValueError("decode q.shape[0] must equal batch * seqlen_q") + + if q2k_indices is not None: + if q2k_indices.device != q.device: + raise ValueError("decode q2k_indices must be on the same device as q") + if q2k_indices.dtype != torch.int32: + raise TypeError("decode q2k_indices must be torch.int32") + if q2k_indices.ndim != 3: + raise ValueError("decode q2k_indices must have shape [Hkv, total_q, topK]") + if q2k_indices.shape[0] != head_kv or q2k_indices.shape[1] != q.shape[0]: + raise ValueError("decode q2k_indices must match [Hkv, total_q, topK]") + if not q2k_indices.is_contiguous(): + raise ValueError("decode q2k_indices must be contiguous") + return batch, head_kv + + +def _validate_schedule_common( + schedule: SparseAttentionSchedule, + *, + device: torch.device, +) -> None: + if schedule.scheduler_metadata is None: + raise ValueError("schedule.scheduler_metadata is required") + if schedule.work_count is None: + raise ValueError("schedule.work_count is required") + metadata = schedule.scheduler_metadata + work_count = schedule.work_count + if metadata.device != device or work_count.device != device: + raise ValueError("schedule tensors must be on the same device as q") + if metadata.dtype != torch.int32 or work_count.dtype != torch.int32: + raise TypeError("schedule.scheduler_metadata and schedule.work_count must be torch.int32") + if metadata.ndim != 2 or metadata.shape[1] != 6: + raise ValueError("schedule.scheduler_metadata must have shape [capacity, 6]") + if work_count.shape != (1,): + raise ValueError("schedule.work_count must have shape [1]") + if not metadata.is_contiguous() or not work_count.is_contiguous(): + raise ValueError("schedule.scheduler_metadata and schedule.work_count must be contiguous") + + +def _validate_fwd_schedule( + schedule: SparseAttentionSchedule, + *, + q: torch.Tensor, + k2q_q_indices: torch.Tensor, + head_kv: int, +) -> None: + _validate_schedule_common(schedule, device=q.device) + if schedule.qsplit_indices is None: + raise ValueError("schedule.qsplit_indices is required for forward") + if schedule.split_counts is None: + raise ValueError("schedule.split_counts is required for forward") + qsplit = schedule.qsplit_indices + split_counts = schedule.split_counts + if qsplit.device != q.device or split_counts.device != q.device: + raise ValueError("forward schedule tensors must be on the same device as q") + if qsplit.dtype != torch.int32 or split_counts.dtype != torch.int32: + raise TypeError("schedule.qsplit_indices and schedule.split_counts must be torch.int32") + if qsplit.shape != k2q_q_indices.shape: + raise ValueError("schedule.qsplit_indices shape must match k2q_q_indices") + total_q = q.shape[0] + if split_counts.shape != (total_q, head_kv): + raise ValueError( + "schedule.split_counts must have shape " + f"({total_q}, {head_kv}), got {tuple(split_counts.shape)}" + ) + if not qsplit.is_contiguous() or not split_counts.is_contiguous(): + raise ValueError("schedule.qsplit_indices and schedule.split_counts must be contiguous") + + +def sparse_atten_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + *, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = False, + softmax_scale: Optional[float] = None, + lse_temperature_scale: float = 1.0, + return_temperature_lse: bool = False, + partial_dtype: torch.dtype = torch.bfloat16, + return_softmax_lse: bool = False, + page_table: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + schedule: Optional[SparseAttentionSchedule] = None, + usable_SM_count: int = -1, + qk_dtype: Optional[torch.dtype] = None, + pv_dtype: Optional[torch.dtype] = None, +): + """Run SM100 CSR block-sparse varlen attention. + + This is the public forward-only sparse attention API. It consumes + query-to-key block selections converted to CSR metadata by + ``build_k2q_csr`` and supports both dense KV layout and paged KV layout. + + Parameters + ---------- + q : torch.Tensor + Shape ``[total_q, Hq, 128]`` on CUDA. Supported dtypes are BF16 and + FP8 E4M3. + k : torch.Tensor + Dense layout ``[total_k, Hkv, 128]`` or paged layout + ``[num_pages, Hkv, blk_kv, 128]``. For BF16 Q with FP8 K/V cache, K + may be FP8 E4M3 while QK compute uses BF16 staging. + v : torch.Tensor + Same layout and head count as ``k``. + k2q_row_ptr : torch.Tensor + CSR row pointers with shape ``[Hkv, total_rows + 1]`` and dtype int32. + k2q_q_indices : torch.Tensor + CSR query indices with shape ``[Hkv, >= total_q * topK]`` and dtype + int32. + topK : int + Number of selected KV blocks per query. Supported values are + ``4, 8, 16, 32``. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + max_seqlen_q : int + Maximum Q sequence length in the batch. + max_seqlen_k : int + Maximum KV sequence length in the batch. + blk_kv : int, optional + KV block size. Paged KV requires ``k.shape[2] == blk_kv``. + causal : bool, optional + Whether to apply causal masking. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + lse_temperature_scale : float, optional + Extra divisor used only for temperature-scaled LSE output. + return_temperature_lse : bool, optional + If True, also return LSE computed with logits scaled by + ``softmax_scale / lse_temperature_scale``. Requires + ``return_softmax_lse=True``. + partial_dtype : torch.dtype, optional + Accumulation dtype for per-block partial O. Supported values are + FP32, BF16, FP16, and FP8 E4M3. + return_softmax_lse : bool, optional + If True, return ``(out, softmax_lse)`` or + ``(out, softmax_lse, temperature_lse)``. + page_table : torch.Tensor, optional + Paged-KV physical page table with shape + ``[batch_size, max_num_pages_per_seq]`` and dtype int32. + seqused_k : torch.Tensor, optional + Shape ``[batch_size]``, dtype int32. Effective KV length per request + for paged causal attention. + schedule : SparseAttentionSchedule, optional + Prebuilt sparse forward schedule. If omitted, the schedule is built + during the call. + usable_SM_count : int, optional + Maximum number of SMs used by the scheduler. ``-1`` uses all SMs. + qk_dtype : torch.dtype, optional + Compile-time MMA operand dtype for QK. Defaults to Q storage dtype, + except supported FP8 K/V cache staging modes. + pv_dtype : torch.dtype, optional + Compile-time MMA operand dtype for PV. Defaults to V storage dtype, + except supported FP8 K/V cache staging modes. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + Output shape ``[total_q, Hq, 128]`` with BF16 dtype. Optional LSE + outputs have shape ``[total_q, Hq]`` and dtype float32. + + Notes + ----- + ``Hq / Hkv`` must be one of ``1, 2, 4, 8, 16``. Current kernels support + head dimension 128 only. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse and not return_softmax_lse: + raise ValueError("return_temperature_lse=True requires return_softmax_lse=True") + partial_dtype = _normalize_partial_dtype(partial_dtype) + qk_dtype, pv_dtype = _resolve_forward_mma_dtypes(q, k, v, qk_dtype, pv_dtype) + + if cu_seqlens_q is None or cu_seqlens_k is None: + raise ValueError( + "sparse_atten_func requires CSR varlen metadata: pass cu_seqlens_q and cu_seqlens_k" + ) + batch, head_kv = _validate_csr_varlen_inputs( + q, + k, + v, + k2q_row_ptr, + k2q_q_indices, + topK, + blk_kv, + page_table, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + ) + max_seqlen_q = int(max_seqlen_q) + max_seqlen_k = int(max_seqlen_k) + + return _sparse_atten_csr_varlen_forward( + q.contiguous(), + k.contiguous(), + v.contiguous(), + k2q_row_ptr.contiguous(), + k2q_q_indices.contiguous(), + int(topK), + int(blk_kv), + bool(causal), + float(softmax_scale), + lse_temperature_scale, + return_temperature_lse, + partial_dtype, + bool(return_softmax_lse), + cu_seqlens_q.contiguous(), + cu_seqlens_k.contiguous(), + None if page_table is None else page_table.contiguous(), + None if seqused_k is None else seqused_k.contiguous(), + schedule, + int(usable_SM_count), + int(batch), + int(head_kv), + int(max_seqlen_q), + int(max_seqlen_k), + qk_dtype, + pv_dtype, + ) + + +def sparse_atten_nvfp4_kv_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_scale_128x4: torch.Tensor, + v_scale_128x4: torch.Tensor, + k_global_scale: Optional[torch.Tensor], + v_global_scale: Optional[torch.Tensor], + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + *, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = False, + softmax_scale: Optional[float] = None, + lse_temperature_scale: float = 1.0, + return_temperature_lse: bool = False, + partial_dtype: torch.dtype = torch.bfloat16, + return_softmax_lse: bool = False, + page_table: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + schedule: Optional[SparseAttentionSchedule] = None, +): + """Run SM100 CSR sparse attention with packed NVFP4 K/V. + + Parameters + ---------- + q : torch.Tensor + Shape ``[total_q, Hq, 128]`` on CUDA. Supported dtypes are BF16 and + FP8 E4M3. + k : torch.Tensor + Packed NVFP4 K data. Dense layout is ``[total_k, Hkv, 64]``; paged + layout is ``[num_pages, Hkv, blk_kv, 64]``. Dtype must be uint8 + because each byte packs two FP4 values. + v : torch.Tensor + Packed NVFP4 V data with the same shape as ``k``. + k_scale_128x4 : torch.Tensor + K block scales in cuBLAS/cuDNN 128x4 tiled storage. Dtype uint8 + containing FP8 E4M3 scale values. + v_scale_128x4 : torch.Tensor + V block scales in the same 128x4 tiled storage. + k_global_scale : torch.Tensor, optional + FP32 tensor/global dequant scale for K. May be ``None``. + v_global_scale : torch.Tensor, optional + FP32 tensor/global dequant scale for V. May be ``None``. The V global + scale is applied in the combine stage. + k2q_row_ptr : torch.Tensor + CSR row pointers with shape ``[Hkv, total_rows + 1]`` and dtype int32. + k2q_q_indices : torch.Tensor + CSR query indices with shape ``[Hkv, >= total_q * topK]`` and dtype + int32. + topK : int + Number of selected KV blocks per query. Supported values are + ``4, 8, 16, 32``. + cu_seqlens_q, cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q and KV + lengths. + max_seqlen_q, max_seqlen_k : int + Maximum Q and KV sequence lengths in the batch. + blk_kv : int, optional + KV block/page size. Paged KV requires ``k.shape[2] == blk_kv``. + causal : bool, optional + Whether to apply causal masking. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + lse_temperature_scale : float, optional + Extra divisor used only for temperature-scaled LSE output. + return_temperature_lse : bool, optional + If True, also return temperature-scaled LSE. Requires + ``return_softmax_lse=True``. + partial_dtype : torch.dtype, optional + Accumulation dtype for per-block partial O. + return_softmax_lse : bool, optional + If True, return LSE together with the output. + page_table : torch.Tensor, optional + Paged-KV physical page table with shape + ``[batch_size, max_num_pages_per_seq]`` and dtype int32. + seqused_k : torch.Tensor, optional + Effective KV length per request for paged causal attention. + schedule : SparseAttentionSchedule, optional + Prebuilt sparse forward schedule. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + Output shape ``[total_q, Hq, 128]`` with BF16 dtype. Optional LSE + outputs have shape ``[total_q, Hq]`` and dtype float32. + """ + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse and not return_softmax_lse: + raise ValueError("return_temperature_lse=True requires return_softmax_lse=True") + partial_dtype = _normalize_partial_dtype(partial_dtype) + + if cu_seqlens_q is None or cu_seqlens_k is None: + raise ValueError( + "sparse_atten_nvfp4_kv_func requires CSR varlen metadata: pass cu_seqlens_q and cu_seqlens_k" + ) + batch, head_kv = _validate_csr_varlen_nvfp4_kv_inputs( + q, + k, + v, + k_scale_128x4, + v_scale_128x4, + k_global_scale, + v_global_scale, + k2q_row_ptr, + k2q_q_indices, + topK, + blk_kv, + page_table, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + ) + total_q, head_q, dim = q.shape + max_num_kv_blocks = _csr_row_capacity(k2q_row_ptr) + temperature_lse_fast_path = ( + return_temperature_lse + and math.isclose( + float(lse_temperature_scale), + 1.0, + rel_tol=0.0, + abs_tol=_TEMPERATURE_LSE_FAST_PATH_ABS_TOL, + ) + ) + kernel_return_temperature_lse = ( + return_temperature_lse and not temperature_lse_fast_path + ) + + O_partial = torch.empty( + topK, total_q, head_q, dim, dtype=partial_dtype, device=q.device + ) + LSE_partial = torch.empty( + topK, total_q, head_q, dtype=torch.float32, device=q.device + ) + LSE_temperature_partial = ( + torch.empty(topK, total_q, head_q, dtype=torch.float32, device=q.device) + if kernel_return_temperature_lse + else None + ) + O_out = torch.empty(total_q, head_q, dim, dtype=torch.bfloat16, device=q.device) + LSE_out = torch.empty(total_q, head_q, dtype=torch.float32, device=q.device) + LSE_temperature_out = ( + torch.empty_like(LSE_out) if kernel_return_temperature_lse else None + ) + if schedule is None: + k2q_qsplit_indices = torch.empty_like(k2q_q_indices) + split_counts = torch.zeros( + (total_q, head_kv), + dtype=torch.int32, + device=q.device, + ) + else: + _validate_fwd_schedule( + schedule, + q=q, + k2q_q_indices=k2q_q_indices, + head_kv=head_kv, + ) + k2q_qsplit_indices = schedule.qsplit_indices + split_counts = schedule.split_counts + + schedule = _call_sparse_forward_sm100_csr_varlen_nvfp4_kv( + q.contiguous(), + k.contiguous(), + v.contiguous(), + k_scale_128x4.contiguous(), + v_scale_128x4.contiguous(), + None if k_global_scale is None else k_global_scale.contiguous(), + None if v_global_scale is None else v_global_scale.contiguous(), + k2q_row_ptr.contiguous(), + k2q_q_indices.contiguous(), + k2q_qsplit_indices.contiguous(), + split_counts.contiguous(), + cu_seqlens_q.contiguous(), + cu_seqlens_k.contiguous(), + None if page_table is None else page_table.contiguous(), + None if seqused_k is None else seqused_k.contiguous(), + O_partial, + LSE_partial, + LSE_temperature_partial, + float(softmax_scale), + lse_temperature_scale, + kernel_return_temperature_lse, + max_num_kv_blocks, + int(blk_kv), + head_kv, + int(max_seqlen_q), + causal=bool(causal), + schedule=schedule, + ) + + combine( + O_partial, + LSE_partial, + O_out, + LSE_out, + lse_temperature_partial=LSE_temperature_partial, + lse_temperature_out=LSE_temperature_out, + cu_seqlens=cu_seqlens_q, + split_counts=split_counts, + output_scale=v_global_scale, + use_pdl=True, + ) + if temperature_lse_fast_path: + LSE_temperature_out = LSE_out + + if return_softmax_lse: + if return_temperature_lse: + return O_out, LSE_out, LSE_temperature_out + return O_out, LSE_out + return O_out + + +def sparse_decode_atten_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k_indices: Optional[torch.Tensor] = None, + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = True, + softmax_scale: Optional[float] = None, + return_softmax_lse: bool = False, + schedule: Optional[DecodeAttentionSchedule] = None, + O_partial: Optional[torch.Tensor] = None, + LSE_partial: Optional[torch.Tensor] = None, +): + """Run forward-only paged FP8 decode attention. + + Parameters + ---------- + q : torch.Tensor + Shape ``[batch_size * seqlen_q, Hq, 128]``. Dtype must be FP8 E4M3. + k : torch.Tensor + Paged K cache with shape ``[num_pages, Hkv, blk_kv, 128]`` and FP8 + E4M3 dtype. + v : torch.Tensor + Paged V cache with the same shape and dtype as ``k``. + q2k_indices : torch.Tensor, optional + Sparse selected KV blocks with shape ``[Hkv, total_q, topK]`` and dtype + int32. ``None`` selects the dense all-KV decode path. + page_table : torch.Tensor + Physical page table with shape ``[batch_size, max_num_pages_per_seq]`` + and dtype int32. + seqused_k : torch.Tensor + Shape ``[batch_size]``, dtype int32. Effective KV length per request. + seqlen_q : int + Uniform query length per request. Ragged Q lengths should use prefill + or append paths instead. + max_seqlen_k : int + Maximum KV sequence length in the batch. + blk_kv : int, optional + Page size. Must match ``k.shape[2]``. + causal : bool, optional + Whether to apply causal masking. Current decode kernel requires True. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + return_softmax_lse : bool, optional + If True, return ``(out, lse)``. + schedule : DecodeAttentionSchedule, optional + Prebuilt decode schedule. + O_partial, LSE_partial : torch.Tensor, optional + Optional split-KV partial workspaces. Normally owned by + ``SparseDecodePagedAttentionWrapper``. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + BF16 output with shape ``q.shape``. Optional LSE has shape + ``[batch_size * seqlen_q, Hq]`` and dtype float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + batch, head_kv = _validate_sparse_decode_inputs( + q, + k, + v, + q2k_indices, + page_table=page_table, + seqused_k=seqused_k, + seqlen_q=seqlen_q, + max_seqlen_k=max_seqlen_k, + blk_kv=blk_kv, + causal=causal, + ) + head_q = int(q.shape[1]) + head_dim = int(q.shape[2]) + if schedule is None: + schedule = prepare_decode_schedule( + seqused_k=seqused_k.contiguous(), + page_size=int(blk_kv), + seqlen_q=int(seqlen_q), + num_qo_heads=head_q, + num_kv_heads=head_kv, + head_dim=head_dim, + max_seqlen_k=int(max_seqlen_k), + ) + if schedule.split_kv: + if O_partial is None: + O_partial = torch.empty( + (schedule.partial_rows, head_q, head_dim), + dtype=torch.float32, + device=q.device, + ) + if LSE_partial is None: + LSE_partial = torch.empty( + (schedule.partial_rows, head_q), + dtype=torch.float32, + device=q.device, + ) + out = torch.empty(q.shape, dtype=torch.bfloat16, device=q.device) + lse = torch.empty( + q.shape[:2] if (return_softmax_lse or schedule.split_kv) else (1, head_q), + dtype=torch.float32, + device=q.device, + ) + _call_sparse_decode_forward_sm100_paged_fp8( + q.contiguous(), + k.contiguous(), + v.contiguous(), + None if q2k_indices is None else q2k_indices.contiguous(), + page_table.contiguous(), + seqused_k.contiguous(), + out, + lse, + schedule, + O_partial, + LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=int(seqlen_q), + max_seqlen_k=int(max_seqlen_k), + blk_kv=int(blk_kv), + causal=bool(causal), + return_lse=bool(return_softmax_lse), + ) + if return_softmax_lse: + return out, lse + return out + + +class SparseDecodePagedAttentionWrapper: + """Plan/run helper for paged FP8 decode attention. + + Use this wrapper when the same page table shape and sequence metadata are + reused across multiple decode layers. ``plan`` validates metadata and + allocates persistent schedules/workspaces; ``run`` then launches the decode + kernel with lower per-call overhead than ``sparse_decode_atten_func``. + """ + + def __init__(self, *, blk_kv: int = 128, causal: bool = True): + self.blk_kv = int(blk_kv) + self.causal = bool(causal) + self.batch: Optional[int] = None + self.num_qo_heads: Optional[int] = None + self.num_kv_heads: Optional[int] = None + self.head_dim: Optional[int] = None + self.page_table: Optional[torch.Tensor] = None + self.seqused_k: Optional[torch.Tensor] = None + self.q2k_indices: Optional[torch.Tensor] = None + self.seqlen_q: Optional[int] = None + self.max_seqlen_k: Optional[int] = None + self.is_sparse: bool = False + self.decode_schedule: Optional[DecodeAttentionSchedule] = None + self.request_indices: Optional[torch.Tensor] = None + self.qo_tile_indices: Optional[torch.Tensor] = None + self.kv_tile_indices: Optional[torch.Tensor] = None + self.merge_indptr: Optional[torch.Tensor] = None + self.o_indptr: Optional[torch.Tensor] = None + self.block_valid_mask: Optional[torch.Tensor] = None + self.kv_pages: Optional[torch.Tensor] = None + self.split_counts: Optional[torch.Tensor] = None + self.split_kv: bool = False + self.cta_tile_q: int = 0 + self.num_q_tiles: int = 0 + self.kv_chunk_size_pages: int = 0 + self.kv_chunk_size_tokens: int = 0 + self.work_count: int = 0 + self.padded_work_count: int = 0 + self.O_partial: Optional[torch.Tensor] = None + self.LSE_partial: Optional[torch.Tensor] = None + # Cached dummy buffers used in non-split path to satisfy the kernel's + # positional arg signature without per-call torch.empty (saves ~5us + # on every run() for small kv). + self._O_partial_dummy: Optional[torch.Tensor] = None + self._LSE_partial_dummy: Optional[torch.Tensor] = None + # When the caller doesn't ask for LSE, the kernel still needs a valid + # tensor pointer to write to. Cache a small placeholder so run() can + # skip the per-call torch.empty for it as well. + self._lse_dummy: Optional[torch.Tensor] = None + + def plan( + self, + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + q2k_indices: Optional[torch.Tensor] = None, + num_qo_heads: Optional[int] = None, + num_kv_heads: Optional[int] = None, + head_dim: Optional[int] = 128, + enable_cuda_graph: bool = False, + max_grid_size: Optional[int] = None, + fixed_split_size: Optional[int] = None, + disable_split_kv: bool = False, + ) -> "SparseDecodePagedAttentionWrapper": + """Prepare decode scheduling metadata and reusable workspaces. + + Parameters + ---------- + page_table : torch.Tensor + Shape ``[batch_size, max_num_pages_per_seq]``, dtype int32. Maps + logical pages to physical KV-cache pages. + seqused_k : torch.Tensor + Shape ``[batch_size]``, dtype int32. Effective KV length per + request. + seqlen_q : int + Uniform query length per request. + max_seqlen_k : int + Maximum KV sequence length in the batch. + q2k_indices : torch.Tensor, optional + Sparse selected KV blocks with shape ``[Hkv, total_q, topK]`` and + dtype int32. ``None`` selects the dense all-KV path. + num_qo_heads : int + Number of Q/O heads. + num_kv_heads : int + Number of KV heads. Current decode kernel requires + ``num_qo_heads / num_kv_heads == 16`` at run time. + head_dim : int, optional + Head dimension. Must be 128. + enable_cuda_graph : bool, optional + Build schedule metadata compatible with CUDA graph capture. + max_grid_size : int, optional + Override maximum CTA count used by the scheduler. + fixed_split_size : int, optional + Force a fixed split-KV chunk size in pages. + disable_split_kv : bool, optional + Disable split-KV even for long KV sequences. + + Returns + ------- + SparseDecodePagedAttentionWrapper + ``self``, planned and ready for ``run``. + """ + if page_table.ndim != 2: + raise ValueError("decode plan requires page_table with shape [B, max_num_pages_per_seq]") + if page_table.dtype != torch.int32: + raise TypeError("decode plan requires page_table to be torch.int32") + if seqused_k.dtype != torch.int32: + raise TypeError("decode plan requires seqused_k to be torch.int32") + if not page_table.is_cuda or not seqused_k.is_cuda: + raise ValueError("decode plan requires page_table and seqused_k to be CUDA tensors") + if page_table.device != seqused_k.device: + raise ValueError("decode plan requires page_table and seqused_k on the same device") + if page_table.stride(-1) != 1: + raise ValueError("decode plan requires page_table contiguous in the last dimension") + if seqused_k.shape != (int(page_table.shape[0]),): + raise ValueError("decode plan requires seqused_k with shape [B]") + if q2k_indices is not None and q2k_indices.dtype != torch.int32: + raise TypeError("decode plan requires q2k_indices to be torch.int32") + if int(seqlen_q) <= 0 or int(max_seqlen_k) <= 0: + raise ValueError("decode plan requires positive seqlen_q and max_seqlen_k") + if num_qo_heads is None or num_kv_heads is None or head_dim is None: + raise ValueError("decode plan requires num_qo_heads, num_kv_heads, and head_dim") + if head_dim is not None and int(head_dim) != 128: + raise NotImplementedError("decode plan currently supports only head_dim=128") + if int(num_qo_heads) % int(num_kv_heads) != 0: + raise ValueError("decode plan requires num_qo_heads divisible by num_kv_heads") + + self.batch = int(page_table.shape[0]) + self.num_qo_heads = None if num_qo_heads is None else int(num_qo_heads) + self.num_kv_heads = None if num_kv_heads is None else int(num_kv_heads) + self.head_dim = None if head_dim is None else int(head_dim) + self.page_table = page_table.contiguous() + self.seqused_k = seqused_k.contiguous() + self.q2k_indices = None if q2k_indices is None else q2k_indices.contiguous() + self.seqlen_q = int(seqlen_q) + self.max_seqlen_k = int(max_seqlen_k) + self.is_sparse = q2k_indices is not None + + # max_grid_size is hardcoded to num_sms (1 CTA/SM) inside the C++ + # schedule launcher because the decode attn kernel always runs at + # 1 CTA/SM (its register/smem budget saturates the SM). Callers + # can still override via the explicit max_grid_size kwarg. + schedule = prepare_decode_schedule( + seqused_k=self.seqused_k, + page_size=self.blk_kv, + seqlen_q=self.seqlen_q, + num_qo_heads=self.num_qo_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_seqlen_k=self.max_seqlen_k, + enable_cuda_graph=bool(enable_cuda_graph), + max_grid_size=max_grid_size, + fixed_split_size=fixed_split_size, + disable_split_kv=bool(disable_split_kv), + ) + self.decode_schedule = schedule + self.request_indices = schedule.request_indices + self.qo_tile_indices = schedule.qo_tile_indices + self.kv_tile_indices = schedule.kv_tile_indices + self.merge_indptr = schedule.merge_indptr + self.o_indptr = schedule.o_indptr + self.block_valid_mask = schedule.block_valid_mask + self.kv_pages = schedule.kv_pages + self.split_counts = schedule.split_counts + self.split_kv = schedule.split_kv + self.cta_tile_q = schedule.cta_tile_q + self.num_q_tiles = schedule.num_q_tiles + self.kv_chunk_size_pages = schedule.kv_chunk_size_pages + self.kv_chunk_size_tokens = schedule.kv_chunk_size_tokens + self.work_count = schedule.work_count + self.padded_work_count = schedule.padded_work_count + if schedule.split_kv: + self.O_partial = torch.empty( + (schedule.partial_rows, self.num_qo_heads, self.head_dim), + dtype=torch.float32, + device=page_table.device, + ) + self.LSE_partial = torch.empty( + (schedule.partial_rows, self.num_qo_heads), + dtype=torch.float32, + device=page_table.device, + ) + self._O_partial_dummy = None + self._LSE_partial_dummy = None + else: + self.O_partial = None + self.LSE_partial = None + # decode_forward_paged_fp8 always wants non-None partial buffers + # for the kernel's positional arg layout (compile keeps the slot + # alive even when split_kv=False). Allocate once here and reuse. + self._O_partial_dummy = torch.empty( + (1, self.head_dim), + dtype=torch.float32, + device=page_table.device, + ) + self._LSE_partial_dummy = torch.empty( + (1, self.num_qo_heads), + dtype=torch.float32, + device=page_table.device, + ) + # LSE dummy is shape (1, head_q) — used when caller doesn't request + # LSE and the schedule isn't split-KV (split-KV always writes LSE). + self._lse_dummy = torch.empty( + (1, self.num_qo_heads), + dtype=torch.float32, + device=page_table.device, + ) + return self + + def run( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + softmax_scale: Optional[float] = None, + return_softmax_lse: bool = False, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + ): + """Launch decode using metadata cached by ``plan``. + + Parameters + ---------- + q : torch.Tensor + Shape ``[batch_size * seqlen_q, Hq, 128]`` and dtype FP8 E4M3. + k : torch.Tensor + Paged K cache with shape ``[num_pages, Hkv, blk_kv, 128]``. + v : torch.Tensor + Paged V cache with the same shape as ``k``. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + return_softmax_lse : bool, optional + If True, return ``(out, lse)``. + out : torch.Tensor, optional + Preallocated BF16 output buffer with shape ``q.shape``. + lse : torch.Tensor, optional + Preallocated float32 LSE buffer with shape ``[total_q, Hq]``. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + BF16 output, optionally with float32 LSE. + """ + if self.decode_schedule is None: + raise RuntimeError("decode wrapper must be planned before run") + if self.is_sparse: + # Sparse path still goes through the validating wrapper for now; + # only the dense fast path is collapsed. + return sparse_decode_atten_func( + q, k, v, self.q2k_indices, + page_table=self.page_table, seqused_k=self.seqused_k, + seqlen_q=self.seqlen_q, max_seqlen_k=self.max_seqlen_k, + blk_kv=self.blk_kv, causal=self.causal, + softmax_scale=softmax_scale, return_softmax_lse=return_softmax_lse, + schedule=self.decode_schedule, + O_partial=self.O_partial, LSE_partial=self.LSE_partial, + ) + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + if out is None: + out = torch.empty(q.shape, dtype=torch.bfloat16, device=q.device) + if lse is None: + if return_softmax_lse or self.split_kv: + # Real LSE needed — must allocate per-call (shape depends on q). + lse = torch.empty( + q.shape[:2], dtype=torch.float32, device=q.device, + ) + else: + # Kernel only needs a valid pointer; reuse cached dummy. + lse = self._lse_dummy + from .src.sm100.fwd_decode import decode_forward_paged_fp8 + schedule = self.decode_schedule + decode_forward_paged_fp8( + q, k, v, + self.page_table, self.seqused_k, + out, lse, + schedule.request_indices, schedule.qo_tile_indices, + schedule.kv_tile_indices, schedule.block_valid_mask, + schedule.split_counts, schedule.o_indptr, schedule.merge_indptr, + self.O_partial, self.LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=self.seqlen_q, + page_size=self.blk_kv, + kv_chunk_size_pages=int(schedule.kv_chunk_size_pages), + max_split_count=int(schedule.max_split_count), + split_kv=bool(schedule.split_kv), + causal=self.causal, + return_lse=bool(return_softmax_lse), + # cached dummies — avoid per-call torch.empty inside run_decode_attention + O_partial_dummy=self._O_partial_dummy, + LSE_partial_dummy=self._LSE_partial_dummy, + ) + if return_softmax_lse: + return out, lse + return out + + +def _sparse_atten_csr_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + blk_kv: int, + causal: bool, + softmax_scale: float, + lse_temperature_scale: float, + return_temperature_lse: bool, + partial_dtype: torch.dtype, + return_softmax_lse: bool, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + page_table: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + schedule: Optional[SparseAttentionSchedule], + usable_SM_count: int, + batch: int, + head_kv: int, + max_seqlen_q: int, + max_seqlen_k: int, + qk_dtype: torch.dtype, + pv_dtype: torch.dtype, +): + total_q, head_q, dim = q.shape + if head_q % head_kv != 0: + raise ValueError("q.shape[1] must be divisible by head_kv") + max_num_kv_blocks = _csr_row_capacity(k2q_row_ptr) + temperature_lse_fast_path = ( + return_temperature_lse + and math.isclose( + float(lse_temperature_scale), + 1.0, + rel_tol=0.0, + abs_tol=_TEMPERATURE_LSE_FAST_PATH_ABS_TOL, + ) + ) + kernel_return_temperature_lse = ( + return_temperature_lse and not temperature_lse_fast_path + ) + + O_partial = torch.empty( + topK, total_q, head_q, dim, dtype=partial_dtype, device=q.device + ) + LSE_partial = torch.empty( + topK, total_q, head_q, dtype=torch.float32, device=q.device + ) + LSE_temperature_partial = ( + torch.empty(topK, total_q, head_q, dtype=torch.float32, device=q.device) + if kernel_return_temperature_lse + else None + ) + O_out = torch.empty(total_q, head_q, dim, dtype=torch.bfloat16, device=q.device) + LSE_out = torch.empty(total_q, head_q, dtype=torch.float32, device=q.device) + LSE_temperature_out = ( + torch.empty_like(LSE_out) if kernel_return_temperature_lse else None + ) + if schedule is None: + k2q_qsplit_indices = torch.empty_like(k2q_q_indices) + split_counts = torch.zeros( + (total_q, head_kv), + dtype=torch.int32, + device=q.device, + ) + else: + _validate_fwd_schedule( + schedule, + q=q, + k2q_q_indices=k2q_q_indices, + head_kv=head_kv, + ) + k2q_qsplit_indices = schedule.qsplit_indices + split_counts = schedule.split_counts + schedule = _call_sparse_forward_sm100_csr_varlen( + q, + k, + v, + k2q_row_ptr, + k2q_q_indices, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + cu_seqlens_k, + page_table, + seqused_k, + O_partial, + LSE_partial, + LSE_temperature_partial, + softmax_scale, + lse_temperature_scale, + kernel_return_temperature_lse, + max_num_kv_blocks, + blk_kv, + head_kv, + max_seqlen_q, + usable_SM_count, + causal=causal, + schedule=schedule, + qk_dtype=qk_dtype, + pv_dtype=pv_dtype, + ) + # Sparse Attention and Sparse Page Attention both use the varlen-Q + # combine path; the kernel-written LSE_out is the final contract. + combine( + O_partial, + LSE_partial, + O_out, + LSE_out, + lse_temperature_partial=LSE_temperature_partial, + lse_temperature_out=LSE_temperature_out, + cu_seqlens=cu_seqlens_q, + split_counts=split_counts, + use_pdl=True, + ) + if temperature_lse_fast_path: + LSE_temperature_out = LSE_out + + if return_softmax_lse: + if return_temperature_lse: + return O_out, LSE_out, LSE_temperature_out + return O_out, LSE_out + return O_out + + +def _call_sparse_decode_forward_sm100_paged_fp8( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k_indices: Optional[torch.Tensor], + page_table: torch.Tensor, + seqused_k: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + schedule: DecodeAttentionSchedule, + O_partial: Optional[torch.Tensor], + LSE_partial: Optional[torch.Tensor], + *, + softmax_scale: float, + seqlen_q: int, + max_seqlen_k: int, + blk_kv: int, + causal: bool, + return_lse: bool = True, +) -> None: + """Compile and launch the SM100 paged fp8 decode forward kernel. + + Dense decode is selected by ``q2k_indices=None``. Sparse decode will reuse + the same schedule wrapper but needs a separate q2k gather path. + """ + if q2k_indices is not None: + raise NotImplementedError("SM100 paged fp8 sparse decode forward is not implemented yet") + if schedule.cta_tile_q != 128: + raise NotImplementedError(f"decode forward requires cta_tile_q=128, got {schedule.cta_tile_q}") + if schedule.split_kv: + if O_partial is None or LSE_partial is None: + raise ValueError("split decode forward requires O_partial and LSE_partial") + if O_partial.dtype != torch.float32: + raise TypeError(f"O_partial must be torch.float32, got {O_partial.dtype}") + if LSE_partial.dtype != torch.float32: + raise TypeError(f"LSE_partial must be torch.float32, got {LSE_partial.dtype}") + + from .src.sm100.fwd_decode import decode_forward_paged_fp8 + + decode_forward_paged_fp8( + q, + k, + v, + page_table, + seqused_k, + out, + lse, + schedule.request_indices, + schedule.qo_tile_indices, + schedule.kv_tile_indices, + schedule.block_valid_mask, + schedule.split_counts, + schedule.o_indptr, + schedule.merge_indptr, + O_partial, + LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=int(seqlen_q), + page_size=int(blk_kv), + kv_chunk_size_pages=int(schedule.kv_chunk_size_pages), + max_split_count=int(schedule.max_split_count), + split_kv=bool(schedule.split_kv), + causal=bool(causal), + return_lse=bool(return_lse), + ) + + +def _call_sparse_forward_sm100_csr_varlen( + q, + k, + v, + k2q_row_ptr, + k2q_q_indices, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + cu_seqlens_k, + page_table, + seqused_k, + O_partial, + LSE_partial, + LSE_temperature_partial, + softmax_scale, + lse_temperature_scale, + return_temperature_lse, + max_num_kv_blocks, + blk_kv, + head_kv, + max_seqlen_q, + usable_SM_count=-1, + *, + causal=False, + use_prepare_scheduler=True, + schedule: Optional[SparseAttentionSchedule] = None, + qk_dtype: torch.dtype, + pv_dtype: torch.dtype, +): + """Compile and launch the SM100 sparse forward K1 kernel on CSR metadata.""" + head_dim = q.shape[-1] + dtype = q.dtype + qk_dtype = _normalize_forward_mma_dtype(qk_dtype, q.dtype, "qk_dtype") + pv_dtype = _normalize_forward_mma_dtype(pv_dtype, v.dtype, "pv_dtype") + partial_dtype = O_partial.dtype + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse != (LSE_temperature_partial is not None): + raise ValueError( + "return_temperature_lse must match LSE_temperature_partial presence" + ) + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + lse_temperature_inv_scale = 1.0 / lse_temperature_scale + n_block_size = int(blk_kv) + head_q = q.shape[1] + qhead_per_kv = head_q // head_kv + paged_kv = page_table is not None + if not bool(use_prepare_scheduler): + raise RuntimeError("sparse forward requires prepare scheduler") + schedule_enabled = k2q_row_ptr.shape[1] > 1 + page_size = int(k.shape[2]) if paged_kv else None + if paged_kv: + k_kernel, v_kernel = _prepare_paged_kv_for_tma(k, v, n_block_size) + else: + k_kernel = k + v_kernel = v + O_partial_flat = O_partial.reshape(-1, head_dim).contiguous() + Q_flat = q.reshape(-1, head_dim).contiguous() + Q_gather4_desc = ( + create_q_gather4_tma_desc( + Q_flat, + box_x=128 if q.dtype == torch.float8_e4m3fn else 64, + ) + if qhead_per_kv in (1, 2, 4) + else None + ) + if schedule is None: + schedule = prepare_sparse_fwd_schedule_and_split( + k2q_row_ptr=k2q_row_ptr, + k2q_q_indices=k2q_q_indices, + k2q_qsplit_indices=k2q_qsplit_indices, + split_counts=split_counts, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + total_q=int(q.shape[0]), + max_seqlen_q=max_seqlen_q, + topk=int(O_partial.shape[0]), + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + blk_kv=n_block_size, + device=q.device, + enabled=schedule_enabled, + ) + use_prepare_scheduler = schedule.enabled + scheduler_metadata = schedule.scheduler_metadata + work_count = schedule.work_count + work_capacity = schedule.work_capacity + if not use_prepare_scheduler or scheduler_metadata is None or work_count is None or work_capacity <= 0: + raise RuntimeError("sparse forward requires a non-empty prepared schedule") + + key = ( + "sparse_forward_sm100_csr_varlen", + head_dim, + n_block_size, + qhead_per_kv, + dtype, + k.dtype, + v.dtype, + qk_dtype, + pv_dtype, + partial_dtype, + bool(causal), + bool(paged_kv), + bool(use_prepare_scheduler), + page_size, + bool(seqused_k is not None), + bool(return_temperature_lse), + ) + if key not in _compile_cache: + from .src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _compile_cache[key] = loaded + else: + kernel = SparseAttentionForwardSm100( + head_dim=head_dim, + qheadperkv=qhead_per_kv, + n_block_size=n_block_size, + paged_kv=paged_kv, + page_size=page_size, + has_seqused_k=seqused_k is not None, + causal=bool(causal), + use_prepare_scheduler=use_prepare_scheduler, + qk_dtype=_torch_dtype_to_cutlass_dtype(qk_dtype), + pv_dtype=_torch_dtype_to_cutlass_dtype(pv_dtype), + ) + _compile_cache[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k_kernel), + to_cute_tensor_kvouter(v_kernel), + to_cute_tensor_kvouter(k2q_q_indices), + to_cute_tensor_kvouter(k2q_qsplit_indices), + to_cute_tensor_kvouter(k2q_row_ptr), + None if scheduler_metadata is None else to_cute_tensor_kvouter(scheduler_metadata), + None if work_count is None else to_cute_tensor_kvouter(work_count), + to_cute_tensor_kvouter(O_partial_flat), + to_cute_tensor_kvouter(LSE_partial), + None + if LSE_temperature_partial is None + else to_cute_tensor_kvouter(LSE_temperature_partial), + to_cute_tensor_kvouter(Q_flat), + None if Q_gather4_desc is None else to_cute_tensor_kvouter(Q_gather4_desc), + None if page_table is None else to_cute_tensor_kvouter(page_table), + None if seqused_k is None else to_cute_tensor_kvouter(seqused_k), + to_cute_tensor_kvouter(cu_seqlens_q), + to_cute_tensor_kvouter(cu_seqlens_k), + Float32(softmax_scale), + Float32(lse_temperature_inv_scale), + Int32(max_num_kv_blocks), + Int32(head_kv), + Int32(max_seqlen_q), + Int32(work_capacity), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _compile_cache[key]) + + with torch.cuda.nvtx.range("Fwd_SparseAttn_Sm100_CsrVarlen"): + _compile_cache[key]( + k_kernel, + v_kernel, + k2q_q_indices, + k2q_qsplit_indices, + k2q_row_ptr, + scheduler_metadata, + work_count, + O_partial_flat, + LSE_partial, + LSE_temperature_partial, + Q_flat, + Q_gather4_desc, + page_table, + seqused_k, + cu_seqlens_q, + cu_seqlens_k, + softmax_scale, + lse_temperature_inv_scale, + max_num_kv_blocks, + head_kv, + max_seqlen_q, + work_capacity, + ) + return schedule + + +def _call_sparse_forward_sm100_csr_varlen_nvfp4_kv( + q, + k, + v, + k_scale_128x4, + v_scale_128x4, + k_global_scale, + v_global_scale, + k2q_row_ptr, + k2q_q_indices, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + cu_seqlens_k, + page_table, + seqused_k, + O_partial, + LSE_partial, + LSE_temperature_partial, + softmax_scale, + lse_temperature_scale, + return_temperature_lse, + max_num_kv_blocks, + blk_kv, + head_kv, + max_seqlen_q, + *, + causal=False, + use_prepare_scheduler=True, + schedule: Optional[SparseAttentionSchedule] = None, +): + """Compile and launch the SM100 sparse forward K1 kernel with NVFP4 K/V.""" + + head_dim = q.shape[-1] + dtype = q.dtype + partial_dtype = O_partial.dtype + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse != (LSE_temperature_partial is not None): + raise ValueError( + "return_temperature_lse must match LSE_temperature_partial presence" + ) + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + lse_temperature_inv_scale = 1.0 / lse_temperature_scale + n_block_size = int(blk_kv) + head_q = q.shape[1] + qhead_per_kv = head_q // head_kv + fp8_pair_dequant = os.environ.get("MINIMAX_KVFP4_FP8_PAIR_DEQUANT", "1") != "0" + k_global_scale_kernel = k_global_scale + # V global scale is linear in the final output. Keep K1 on block-scale-only V + # and apply the tensor scale once in K2 combine. + v_global_scale_kernel = None + has_k_global_scale = k_global_scale_kernel is not None + has_v_global_scale = v_global_scale_kernel is not None + paged_kv = page_table is not None + if not bool(use_prepare_scheduler): + raise RuntimeError("KVFP4 sparse forward requires prepare scheduler") + schedule_enabled = k2q_row_ptr.shape[1] > 1 + page_size = int(k.shape[2]) if paged_kv else None + if paged_kv: + _prepare_paged_kv_for_tma(k, v, n_block_size) + k_kernel = k + v_kernel = v + O_partial_flat = O_partial.reshape(-1, head_dim).contiguous() + Q_flat = q.reshape(-1, head_dim).contiguous() + Q_gather4_desc = ( + create_q_gather4_tma_desc( + Q_flat, + box_x=128 if q.dtype == torch.float8_e4m3fn else 64, + ) + if qhead_per_kv in (1, 2, 4) + else None + ) + if schedule is None: + schedule = prepare_sparse_fwd_schedule_and_split( + k2q_row_ptr=k2q_row_ptr, + k2q_q_indices=k2q_q_indices, + k2q_qsplit_indices=k2q_qsplit_indices, + split_counts=split_counts, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + total_q=int(q.shape[0]), + max_seqlen_q=max_seqlen_q, + topk=int(O_partial.shape[0]), + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + blk_kv=n_block_size, + device=q.device, + enabled=schedule_enabled, + ) + use_prepare_scheduler = schedule.enabled + scheduler_metadata = schedule.scheduler_metadata + work_count = schedule.work_count + work_capacity = schedule.work_capacity + if not use_prepare_scheduler or scheduler_metadata is None or work_count is None or work_capacity <= 0: + raise RuntimeError("KVFP4 sparse forward requires a non-empty prepared schedule") + + key = ( + "sparse_forward_sm100_csr_varlen_nvfp4_kv", + head_dim, + n_block_size, + qhead_per_kv, + dtype, + partial_dtype, + bool(causal), + bool(paged_kv), + bool(use_prepare_scheduler), + page_size, + bool(seqused_k is not None), + bool(return_temperature_lse), + bool(fp8_pair_dequant), + bool(has_k_global_scale), + bool(has_v_global_scale), + ) + if key not in _compile_cache: + from .src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _compile_cache[key] = loaded + else: + kernel = SparseAttentionForwardNvfp4KvSm100( + head_dim=head_dim, + qheadperkv=qhead_per_kv, + n_block_size=n_block_size, + paged_kv=paged_kv, + page_size=page_size, + has_seqused_k=seqused_k is not None, + causal=bool(causal), + use_prepare_scheduler=use_prepare_scheduler, + fp8_pair_dequant=bool(fp8_pair_dequant), + has_k_global_scale=bool(has_k_global_scale), + has_v_global_scale=bool(has_v_global_scale), + ) + _compile_cache[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k_kernel), + to_cute_tensor_kvouter(v_kernel), + to_cute_tensor_kvouter(k_scale_128x4), + to_cute_tensor_kvouter(v_scale_128x4), + None if k_global_scale_kernel is None else to_cute_tensor_kvouter(k_global_scale_kernel), + None if v_global_scale_kernel is None else to_cute_tensor_kvouter(v_global_scale_kernel), + to_cute_tensor_kvouter(k2q_q_indices), + to_cute_tensor_kvouter(k2q_qsplit_indices), + to_cute_tensor_kvouter(k2q_row_ptr), + None if scheduler_metadata is None else to_cute_tensor_kvouter(scheduler_metadata), + None if work_count is None else to_cute_tensor_kvouter(work_count), + to_cute_tensor_kvouter(O_partial_flat), + to_cute_tensor_kvouter(LSE_partial), + None + if LSE_temperature_partial is None + else to_cute_tensor_kvouter(LSE_temperature_partial), + to_cute_tensor_kvouter(Q_flat), + None if Q_gather4_desc is None else to_cute_tensor_kvouter(Q_gather4_desc), + None if page_table is None else to_cute_tensor_kvouter(page_table), + None if seqused_k is None else to_cute_tensor_kvouter(seqused_k), + to_cute_tensor_kvouter(cu_seqlens_q), + to_cute_tensor_kvouter(cu_seqlens_k), + Float32(softmax_scale), + Float32(lse_temperature_inv_scale), + Int32(max_num_kv_blocks), + Int32(head_kv), + Int32(max_seqlen_q), + Int32(work_capacity), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _compile_cache[key]) + + with torch.cuda.nvtx.range("Fwd_SparseAttn_Sm100_CsrVarlen_KVFP4"): + _compile_cache[key]( + k_kernel, + v_kernel, + k_scale_128x4, + v_scale_128x4, + k_global_scale_kernel, + v_global_scale_kernel, + k2q_q_indices, + k2q_qsplit_indices, + k2q_row_ptr, + scheduler_metadata, + work_count, + O_partial_flat, + LSE_partial, + LSE_temperature_partial, + Q_flat, + Q_gather4_desc, + page_table, + seqused_k, + cu_seqlens_q, + cu_seqlens_k, + softmax_scale, + lse_temperature_inv_scale, + max_num_kv_blocks, + head_kv, + max_seqlen_q, + work_capacity, + ) + return schedule diff --git a/build/torch211-cxx11-cu128-x86_64-linux/metadata.json b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..aecf3a35ae5adf74beb5a54284d030b75459dfc9 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,71 @@ +{ + "name": "msa", + "id": "_msa_cuda_09d7851", + "version": 0, + "license": "other", + "upstream": "https://github.com/MiniMax-AI/MSA", + "python-depends": [ + "tvm-ffi", + "nvidia-cutlass-dsl" + ], + "backend": { + "type": "cuda", + "archs": [ + "10.0" + ] + }, + "digest": { + "algorithm": "sha256", + "files": { + "__init__.py": "+W+3U1Z5ZKc/dTA+JUG+6dMjfe9H/d9J+8fN+936wbI=", + "_msa_cuda_09d7851.abi3.so": "jc2MhuUS893VrLlfb9ytPPqhV5u2+HSnFPugZuaHcWE=", + "_ops.py": "o9RBC1FB95LP9Sp+GkBILumbSek9oEtxb8F7XXO0F0g=", + "fp4_indexer_interface.py": "M+0e93gWG8CGOrhY5bm1hEQJU+TT5PrCmwJzTofaDAg=", + "interface.py": "B4AHQfNyO+vl6MdyMAHW0GhArl7HGufAEa0ATxsWorY=", + "msa/__init__.py": "DFYPlrhXwYjEqCl/8n0SmWGZV8NFml5DPhMjKfv98GY=", + "quack/__init__.py": "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=", + "quack/activation.py": "T/ypcXoz6a4wPPNZW2gKZuEj8JeucaKtKxQiQl5XrXc=", + "quack/compile_utils.py": "qJ3oTsDlbAiddrJHtEO7LPYVqn/s+neNfiw+/KvfXZU=", + "quack/copy_utils.py": "rdohXm9bKXqDHkMHf8lWQJQnCb0hMLvhzIudkj0Bxeg=", + "quack/cute_dsl_utils.py": "4uQx5aYDG9UvVzbWwJTjjJLrnoympz70/CD8b37FQWo=", + "quack/layout_utils.py": "69N1aTy+840X3seMuLfLxiV3BW8SaVsM3Tf0Vf4NCSI=", + "quantize.py": "1jePLbJngji8ANfnDK6PCG829AMSd+XOMqYVuJ5pXyY=", + "sparse_index_utils.py": "kzYMdtFPRBfaL6Vfw9xLLre7ph8svtEQrB/txC+52Fc=", + "src/__init__.py": "ZdCpznblq9UgGSNgI0hJoDXpk+evcvGjv3GGthxD/nM=", + "src/common/__init__.py": "ZdCpznblq9UgGSNgI0hJoDXpk+evcvGjv3GGthxD/nM=", + "src/common/aot_cache.py": "ya1OHE6Lqx/pb9UhH++Bu8a98Huhmdl084C6cgWdH1s=", + "src/common/barrier.py": "Godvhwwaf9iyDA/A78VoQMMRRn6ZSnq2YPosr7K2SVE=", + "src/common/blackwell_helpers.py": "BYJYCeNQ9cYVhWZlfjv0IgNaNqlnoD21nX3gAA5pRB4=", + "src/common/block_info.py": "U7qL3AZ5ROkNZdL6RTPlLlnLJ6tZ4b2VFVufZLyuuq8=", + "src/common/copy_utils.py": "bEtyb8O7Z7jIKNjN5ESlnh4WVvdf8vr5ZfQxA6vS6zA=", + "src/common/cute_dsl_utils.py": "nd8vII+r49Kk185ja3+VM6dwJlvMqCkjMBRh0WEHakw=", + "src/common/fast_math.py": "nqt6MtzAt7uplC4+kczgBfin4gHNs+QSoufR1TuMZ88=", + "src/common/mask.py": "l9v4End+9k3ZHRO6DCnuOD9K9iOCiN81osRATKvK41k=", + "src/common/mma_sm100_desc.py": "C1PqBdp6CNPA9xadQ2xBnf4wvQlE93SS/7CU+LZBQkA=", + "src/common/named_barrier.py": "5ktJiO+hP80fjTR797CslUGfm2PyhpcW6WJZrNyI5bQ=", + "src/common/pack_gqa.py": "UrAAIge5XLmilqXWGtCZJobgpuA6B0N1Vw3tDhyUi7s=", + "src/common/paged_kv.py": "j0/6stT1A5uEVALEX/GaQhYWIie+6LpGseAW8aQiHbk=", + "src/common/pipeline.py": "MIFfoDDD8Fs//SQSR+JzI/0MJ1qPGml297RtbC2qPRU=", + "src/common/seqlen_info.py": "EX2W8MTGcnAZ+J60tGG9D7IzvdLeIVQshztntGDkPMQ=", + "src/common/softmax.py": "ePjb2TUcr4fHLmw0zx9Lt+vvR6hSm2mQwiENf2J/AoQ=", + "src/common/tile_scheduler.py": "f8UknoE0j9BfPomRI/QCsDJoRk+1IpJrLfBOAh2mlls=", + "src/common/tma_utils.py": "gpAmBh58VOfHRghZTCbQ5SQpbAYy0lFnpvIcFSLBNb8=", + "src/common/utils.py": "eGGo5Ul+0XpKtiw6JLofVdFDj6s2xe4LWqDmlqp9AKk=", + "src/sm100/__init__.py": "JQpQtL58fso8B2Xwvn0XVevVqIjnk15wVQE0UUGGLCs=", + "src/sm100/build_k2q_csr/__init__.py": "75ICu6BIZir0OeyEgZ1TEYNY7pn+lA4P6McCSSC20rI=", + "src/sm100/decode_schedule.py": "/VRAmvrMX+oYLzWK1sqve86tprXkqX0/f4o5WMVeU4I=", + "src/sm100/fp4_indexer.py": "1lc9/rgU09wwF08WBRaXIE0CE2b19pBRwXekDduFs0o=", + "src/sm100/fwd/__init__.py": "A0uq2t4n5Y34mEgxb9Nzxk9sKsYr2FZ4sF+RoEilOmo=", + "src/sm100/fwd/atten_fwd.py": "4LJaUh2pn3QiwcMr+8QOVUJjNIAQqYal1xFJ/1takQY=", + "src/sm100/fwd/atten_fwd_nvfp4_kv.py": "EqU+ehJasAa9NvpDWipMPxaptOw+vcojprVas+b+x18=", + "src/sm100/fwd/combine.py": "7rQW4rUpzy0M19u+/iLfHHGMbAIQhi4HEnYeLu/qmi4=", + "src/sm100/fwd_decode/__init__.py": "XQJdwvLQm29RwVqVZvCstEnTx+dhUrwmH6RcW675pR8=", + "src/sm100/fwd_decode/atten_fwd.py": "3S4iE9h6fXUBjas51fRbakqnOzN79f0QUJ/EBRm+Ckg=", + "src/sm100/fwd_decode/build_decode_schedule/__init__.py": "qUElKK/HC03N9ntOA0sc8LB08jF5MWd7wq3MUnu4wgM=", + "src/sm100/fwd_decode/combine.py": "wIvKZzHissMLe83PUbybUoM39HTMIAexHw5I1yfJH94=", + "src/sm100/fwd_decode/tile_scheduler.py": "OWdID5fCFmwXqz6RtseFphfJtezOOQ091K+bJFcD6bc=", + "src/sm100/prepare_k2q_csr.py": "nCeG6m24dLNwJeQDFppjqR3wVCDxMY0we+20zEEeMy8=", + "src/sm100/prepare_scheduler.py": "CQuJI6Fn0uR0oMcfzmlIH+bjg+2uKTzqCXbw5H0YgSw=" + } + } +} \ No newline at end of file diff --git a/build/torch211-cxx11-cu128-x86_64-linux/metadata.json.sigstore b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json.sigstore new file mode 100644 index 0000000000000000000000000000000000000000..568d4fe339e2d25ed51094b0f3681366956886fa --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json.sigstore @@ -0,0 +1 @@ +{"mediaType":"application/vnd.dev.sigstore.bundle.v0.3+json","verificationMaterial":{"certificate":{"rawBytes":"MIIHTDCCBtGgAwIBAgIUXQHYSDFOSO1tjFUUICxJvOGeZcMwCgYIKoZIzj0EAwMwNzEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MR4wHAYDVQQDExVzaWdzdG9yZS1pbnRlcm1lZGlhdGUwHhcNMjYwNjMwMTc0NDA4WhcNMjYwNjMwMTc1NDA4WjAAMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEPXM0K6Fgcg5CUSklxxl2csu3F3KVSv8zPaW2wSeCwTB487WjsTVM+EqcLz/LSKUD5XL4tCAc1+gFBa30H4iDgKOCBfAwggXsMA4GA1UdDwEB/wQEAwIHgDATBgNVHSUEDDAKBggrBgEFBQcDAzAdBgNVHQ4EFgQUfsSvN2oaJ+OmV0cSOHDNe9Nc/qUwHwYDVR0jBBgwFoAU39Ppz1YkEZb5qNjpKFWixi4YZD8wawYDVR0RAQH/BGEwX4ZdaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5Ly5naXRodWIvd29ya2Zsb3dzL2J1aWxkLnlhbWxAcmVmcy9oZWFkcy9tYWluMDkGCisGAQQBg78wAQEEK2h0dHBzOi8vdG9rZW4uYWN0aW9ucy5naXRodWJ1c2VyY29udGVudC5jb20wHwYKKwYBBAGDvzABAgQRd29ya2Zsb3dfZGlzcGF0Y2gwNgYKKwYBBAGDvzABAwQoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTATBgorBgEEAYO/MAEEBAVCdWlsZDArBgorBgEEAYO/MAEFBB1odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eTAdBgorBgEEAYO/MAEGBA9yZWZzL2hlYWRzL21haW4wOwYKKwYBBAGDvzABCAQtDCtodHRwczovL3Rva2VuLmFjdGlvbnMuZ2l0aHVidXNlcmNvbnRlbnQuY29tMG0GCisGAQQBg78wAQkEXwxdaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5Ly5naXRodWIvd29ya2Zsb3dzL2J1aWxkLnlhbWxAcmVmcy9oZWFkcy9tYWluMDgGCisGAQQBg78wAQoEKgwoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTAbBgorBgEEAYO/MAELBA0MC3NlbGYtaG9zdGVkMEAGCisGAQQBg78wAQwEMgwwaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5MDgGCisGAQQBg78wAQ0EKgwoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTAfBgorBgEEAYO/MAEOBBEMD3JlZnMvaGVhZHMvbWFpbjAaBgorBgEEAYO/MAEPBAwMCjEwNzE0NzU1MjkwLgYKKwYBBAGDvzABEAQgDB5odHRwczovL2dpdGh1Yi5jb20vaHVnZ2luZ2ZhY2UwGAYKKwYBBAGDvzABEQQKDAgyNTcyMDc0MzBtBgorBgEEAYO/MAESBF8MXWh0dHBzOi8vZ2l0aHViLmNvbS9odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eS8uZ2l0aHViL3dvcmtmbG93cy9idWlsZC55YW1sQHJlZnMvaGVhZHMvbWFpbjA4BgorBgEEAYO/MAETBCoMKDA5ZDc4NTE1YzU1MzJlNzAwMjcwZTllMTM1NTZhMmFkMDJlOWY1ZjkwIQYKKwYBBAGDvzABFAQTDBF3b3JrZmxvd19kaXNwYXRjaDBkBgorBgEEAYO/MAEVBFYMVGh0dHBzOi8vZ2l0aHViLmNvbS9odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eS9hY3Rpb25zL3J1bnMvMjg0NjM5NjE5NTUvYXR0ZW1wdHMvMTAWBgorBgEEAYO/MAEWBAgMBnB1YmxpYzBGBgorBgEEAYO/MAEYBDgMNnJlcG86aHVnZ2luZ2ZhY2Uva2VybmVscy1jb21tdW5pdHk6cmVmOnJlZnMvaGVhZHMvbWFpbjCBigYKKwYBBAHWeQIEAgR8BHoAeAB2AN09MGrGxxEyYxkeHJlnNwKiSl643jyt/4eKcoAvKe6OAAABnxmhlrEAAAQDAEcwRQIhAN6iYC5242Rjj5dTsIgyISVMIPYWL2i81TwWknEvZur+AiAt30f5Wif9ZHR/wsWh+ve5O9GtVpL2jPTURJTl0u2xMjAKBggqhkjOPQQDAwNpADBmAjEA4i2QuFAcvw5KQAQADHbn8kVwmCTVfjK5xdQ1bJEu5eVu4PY4Br1zC9GVk7p6opFmAjEAm7jnPQ2jC5BL90FIlwMdeEVPgNmR7svFEElrkQme43Rqt6pvdGksMAzAqaWXQFqT"},"tlogEntries":[{"logIndex":"2024793345","logId":{"keyId":"wNI9atQGlz+VWfO6LRygH4QUfY/8W4RFwiT5i5WRgB0="},"kindVersion":{"kind":"hashedrekord","version":"0.0.1"},"integratedTime":"1782841448","inclusionPromise":{"signedEntryTimestamp":"MEUCIQDoWovnRcuj8EsCnxn/h18ObLX1W2EowGsjOnjj31tjKgIgE1bqiVYG2avTTL3CutjFGVSxSQtlXFYWVfl+DRCyVUk="},"inclusionProof":{"logIndex":"1902889083","rootHash":"rTzAPs80Dh6PVJ0tfFBFa06/Bp0jBkLOYrqCKGcj2Jw=","treeSize":"1902889093","hashes":["o6DK+OhTtiUAKd3yIcR79MoEH+e/lGDEz7/klBOgQgQ=","QFE69AbxzyZT6lYixktLCZ3SnTobLI2F6l/FFy7U7bE=","euXxtVgM7AeowPy83tQZihH1C4RDec9dw20k4Rjy7X8=","mCF45aBQkD6Ga0kRgUZm/6GIWnlvuDEwC1rsiDj7r9A=","wCaOWjILsSS/Bc8GMCLLwZ/lR4z6kHhhDwjBR489Drg=","oREPAC441YAiXLkRB+S3slZaG/rywypoRAOWh9Onh28=","tdRUnZp2XzgIgMBhnUUzZKRYmgMR9VRE4EFRMnBcvN4=","SRE7OpzsmEEBrnt2NvwSO2YvAQJHxIzVKMjw7ssvt3A=","5DB/VRMbICRg24kfvBoq+aFOMwCKvhr1zQj5SpDh5Ck=","NRxwUF55kxkZUtVui8nzfzj4LLT960XpxpXnY6C7pqs=","KTak07KIu/wsxelNu7DaqjZg2G0WnevWjQkjflcCfjI=","o03232Stm2HWKs2uG6lq2NP4O1Zym1pjI+LbQCbPISY=","nGtXNKgDUZj+ZjPgQKuKFp9orlBq81iSk8yjysQUTIU=","+/rlNRIrSvbSLthLGxHY8saYzo8HTl12uoWcFuXbbE0=","tC4XX6tUr8g/3yF+0T8f2DfrTWQmbDBfMxTOmNuWyzI=","E8u2TYaBleTNUd9vupjpxhOMu+bExC1kpTjfOk2GAUA=","cJbCQtmuzzN6T9df9SuhiY4cyCN7ezf1n+yFrgRkcgE=","+/VZ56MsIPxMiyLAodzKXo5TEWdQp36z89qLhpzloAo=","daxmZaajRpZV+JxHiOYZhJBiSKN5ucqjh2WnGbHhirw=","DOCeoSMovIvLExkhIvisow9AuNXgeWs4ECkyR6EcqYU="],"checkpoint":{"envelope":"rekor.sigstore.dev - 1193050959916656506\n1902889093\nrTzAPs80Dh6PVJ0tfFBFa06/Bp0jBkLOYrqCKGcj2Jw=\n\n— rekor.sigstore.dev wNI9ajBFAiBuldB8XClfqbEMlZnWsMAPF1CWf+PfKW6kiBU0RaE3YwIhAKQGXPHErozLpsxzvdgVeeJVRUx9RGAtRP5qoXqfKhJm\n"}},"canonicalizedBody":"eyJhcGlWZXJzaW9uIjoiMC4wLjEiLCJraW5kIjoiaGFzaGVkcmVrb3JkIiwic3BlYyI6eyJkYXRhIjp7Imhhc2giOnsiYWxnb3JpdGhtIjoic2hhMjU2IiwidmFsdWUiOiIyMzVlYjhiNGYxZmIyOWIzZWU4OTNlNzI4ODU1NDc3N2E3YzE3ZTVhNzNkNDM3YTc0M2JlNzAxOGYyOWQ5OGI4In19LCJzaWduYXR1cmUiOnsiY29udGVudCI6Ik1FVUNJQ1dkOUxlZ3ZSb0oxWDZIQUwway9SV1BvTG1sbS9YU3c3VXhOWmNpSFMwc0FpRUE3U1phSlJXVGlHdlJIWWh2d0pLS0RwRDVnRUNZT25GMGMzRURMT0VTOWNNPSIsInB1YmxpY0tleSI6eyJjb250ZW50IjoiTFMwdExTMUNSVWRKVGlCRFJWSlVTVVpKUTBGVVJTMHRMUzB0Q2sxSlNVaFVSRU5EUW5SSFowRjNTVUpCWjBsVldGRklXVk5FUms5VFR6RjBha1pWVlVsRGVFcDJUMGRsV21OTmQwTm5XVWxMYjFwSmVtb3dSVUYzVFhjS1RucEZWazFDVFVkQk1WVkZRMmhOVFdNeWJHNWpNMUoyWTIxVmRWcEhWakpOVWpSM1NFRlpSRlpSVVVSRmVGWjZZVmRrZW1SSE9YbGFVekZ3WW01U2JBcGpiVEZzV2tkc2FHUkhWWGRJYUdOT1RXcFpkMDVxVFhkTlZHTXdUa1JCTkZkb1kwNU5hbGwzVG1wTmQwMVVZekZPUkVFMFYycEJRVTFHYTNkRmQxbElDa3R2V2tsNmFqQkRRVkZaU1V0dldrbDZhakJFUVZGalJGRm5RVVZRV0Uwd1N6WkdaMk5uTlVOVlUydHNlSGhzTW1OemRUTkdNMHRXVTNZNGVsQmhWeklLZDFObFEzZFVRalE0TjFkcWMxUldUU3RGY1dOTWVpOU1VMHRWUkRWWVREUjBRMEZqTVN0blJrSmhNekJJTkdsRVowdFBRMEptUVhkbloxaHpUVUUwUndwQk1WVmtSSGRGUWk5M1VVVkJkMGxJWjBSQlZFSm5UbFpJVTFWRlJFUkJTMEpuWjNKQ1owVkdRbEZqUkVGNlFXUkNaMDVXU0ZFMFJVWm5VVlZtYzFOMkNrNHliMkZLSzA5dFZqQmpVMDlJUkU1bE9VNWpMM0ZWZDBoM1dVUldVakJxUWtKbmQwWnZRVlV6T1ZCd2VqRlphMFZhWWpWeFRtcHdTMFpYYVhocE5Ga0tXa1E0ZDJGM1dVUldVakJTUVZGSUwwSkhSWGRZTkZwa1lVaFNNR05JVFRaTWVUbHVZVmhTYjJSWFNYVlpNamwwVERKb01Wb3laSEJpYldSdFdWZE9iQXBNTW5Sc1kyMDFiR0pJVFhSWk1qbDBZbGhXZFdGWVVqVk1lVFZ1WVZoU2IyUlhTWFprTWpsNVlUSmFjMkl6WkhwTU1rb3hZVmQ0YTB4dWJHaGlWM2hCQ21OdFZtMWplVGx2V2xkR2EyTjVPWFJaVjJ4MVRVUnJSME5wYzBkQlVWRkNaemM0ZDBGUlJVVkxNbWd3WkVoQ2VrOXBPSFprUnpseVdsYzBkVmxYVGpBS1lWYzVkV041Tlc1aFdGSnZaRmRLTVdNeVZubFpNamwxWkVkV2RXUkROV3BpTWpCM1NIZFpTMHQzV1VKQ1FVZEVkbnBCUWtGblVWSmtNamw1WVRKYWN3cGlNMlJtV2tkc2VtTkhSakJaTW1kM1RtZFpTMHQzV1VKQ1FVZEVkbnBCUWtGM1VXOU5SR3hyVG5wbk1VMVVWbXBPVkZWNlRXMVZNMDFFUVhsT2VrSnNDazlYVlhoTmVsVXhUbTFGZVZsWFVYZE5iVlUxV21wV2JVOVVRVlJDWjI5eVFtZEZSVUZaVHk5TlFVVkZRa0ZXUTJSWGJITmFSRUZ5UW1kdmNrSm5SVVVLUVZsUEwwMUJSVVpDUWpGdlpGZGtibUZYTlc1YWJVWnFXbE01Y2xwWVNuVmFWM2g2VEZkT2RtSlhNVEZpYld3d1pWUkJaRUpuYjNKQ1owVkZRVmxQTHdwTlFVVkhRa0U1ZVZwWFducE1NbWhzV1ZkU2Vrd3lNV2hoVnpSM1QzZFpTMHQzV1VKQ1FVZEVkbnBCUWtOQlVYUkVRM1J2WkVoU2QyTjZiM1pNTTFKMkNtRXlWblZNYlVacVpFZHNkbUp1VFhWYU1td3dZVWhXYVdSWVRteGpiVTUyWW01U2JHSnVVWFZaTWpsMFRVY3dSME5wYzBkQlVWRkNaemM0ZDBGUmEwVUtXSGQ0WkdGSVVqQmpTRTAyVEhrNWJtRllVbTlrVjBsMVdUSTVkRXd5YURGYU1tUndZbTFrYlZsWFRteE1NblJzWTIwMWJHSklUWFJaTWpsMFlsaFdkUXBoV0ZJMVRIazFibUZZVW05a1YwbDJaREk1ZVdFeVduTmlNMlI2VERKS01XRlhlR3RNYm14b1lsZDRRV050Vm0xamVUbHZXbGRHYTJONU9YUlpWMngxQ2sxRVowZERhWE5IUVZGUlFtYzNPSGRCVVc5RlMyZDNiMDFFYkd0T2VtY3hUVlJXYWs1VVZYcE5iVlV6VFVSQmVVNTZRbXhQVjFWNFRYcFZNVTV0UlhrS1dWZFJkMDF0VlRWYWFsWnRUMVJCWWtKbmIzSkNaMFZGUVZsUEwwMUJSVXhDUVRCTlF6Tk9iR0pIV1hSaFJ6bDZaRWRXYTAxRlFVZERhWE5IUVZGUlFncG5OemgzUVZGM1JVMW5kM2RoU0ZJd1kwaE5Oa3g1T1c1aFdGSnZaRmRKZFZreU9YUk1NbWd4V2pKa2NHSnRaRzFaVjA1c1RESjBiR050Tld4aVNFMTBDbGt5T1hSaVdGWjFZVmhTTlUxRVowZERhWE5IUVZGUlFtYzNPSGRCVVRCRlMyZDNiMDFFYkd0T2VtY3hUVlJXYWs1VVZYcE5iVlV6VFVSQmVVNTZRbXdLVDFkVmVFMTZWVEZPYlVWNVdWZFJkMDF0VlRWYWFsWnRUMVJCWmtKbmIzSkNaMFZGUVZsUEwwMUJSVTlDUWtWTlJETktiRnB1VFhaaFIxWm9Xa2hOZGdwaVYwWndZbXBCWVVKbmIzSkNaMFZGUVZsUEwwMUJSVkJDUVhkTlEycEZkMDU2UlRCT2VsVXhUV3ByZDB4bldVdExkMWxDUWtGSFJIWjZRVUpGUVZGbkNrUkNOVzlrU0ZKM1kzcHZka3d5WkhCa1IyZ3hXV2sxYW1JeU1IWmhTRlp1V2pKc2RWb3lXbWhaTWxWM1IwRlpTMHQzV1VKQ1FVZEVkbnBCUWtWUlVVc0tSRUZuZVU1VVkzbE5SR013VFhwQ2RFSm5iM0pDWjBWRlFWbFBMMDFCUlZOQ1JqaE5XRmRvTUdSSVFucFBhVGgyV2pKc01HRklWbWxNYlU1MllsTTVid3BrVjJSdVlWYzFibHB0Um1wYVV6bHlXbGhLZFZwWGVIcE1WMDUyWWxjeE1XSnRiREJsVXpoMVdqSnNNR0ZJVm1sTU0yUjJZMjEwYldKSE9UTmplVGxwQ21SWGJITmFRelUxV1ZjeGMxRklTbXhhYmsxMllVZFdhRnBJVFhaaVYwWndZbXBCTkVKbmIzSkNaMFZGUVZsUEwwMUJSVlJDUTI5TlMwUkJOVnBFWXpRS1RsUkZNVmw2VlRGTmVrcHNUbnBCZDAxcVkzZGFWR3hzVFZSTk1VNVVXbWhOYlVaclRVUktiRTlYV1RGYWFtdDNTVkZaUzB0M1dVSkNRVWRFZG5wQlFncEdRVkZVUkVKR00ySXpTbkphYlhoMlpERTVhMkZZVG5kWldGSnFZVVJDYTBKbmIzSkNaMFZGUVZsUEwwMUJSVlpDUmxsTlZrZG9NR1JJUW5wUGFUaDJDbG95YkRCaFNGWnBURzFPZG1KVE9XOWtWMlJ1WVZjMWJscHRSbXBhVXpseVdsaEtkVnBYZUhwTVYwNTJZbGN4TVdKdGJEQmxVemxvV1ROU2NHSXlOWG9LVEROS01XSnVUWFpOYW1jd1RtcE5OVTVxUlRWT1ZGVjJXVmhTTUZwWE1YZGtTRTEyVFZSQlYwSm5iM0pDWjBWRlFWbFBMMDFCUlZkQ1FXZE5RbTVDTVFwWmJYaHdXWHBDUjBKbmIzSkNaMFZGUVZsUEwwMUJSVmxDUkdkTlRtNUtiR05IT0RaaFNGWnVXakpzZFZveVdtaFpNbFYyWVRKV2VXSnRWbk5qZVRGcUNtSXlNWFJrVnpWd1pFaHJObU50Vm0xUGJrcHNXbTVOZG1GSFZtaGFTRTEyWWxkR2NHSnFRMEpwWjFsTFMzZFpRa0pCU0ZkbFVVbEZRV2RTT0VKSWIwRUtaVUZDTWtGT01EbE5SM0pIZUhoRmVWbDRhMlZJU214dVRuZExhVk5zTmpRemFubDBMelJsUzJOdlFYWkxaVFpQUVVGQlFtNTRiV2hzY2tWQlFVRlJSQXBCUldOM1VsRkphRUZPTm1sWlF6VXlOREpTYW1vMVpGUnpTV2Q1U1ZOV1RVbFFXVmRNTW1rNE1WUjNWMnR1UlhaYWRYSXJRV2xCZERNd1pqVlhhV1k1Q2xwSVVpOTNjMWRvSzNabE5VODVSM1JXY0V3eWFsQlVWVkpLVkd3d2RUSjRUV3BCUzBKblozRm9hMnBQVUZGUlJFRjNUbkJCUkVKdFFXcEZRVFJwTWxFS2RVWkJZM1ozTlV0UlFWRkJSRWhpYmpoclZuZHRRMVJXWm1wTE5YaGtVVEZpU2tWMU5XVldkVFJRV1RSQ2NqRjZRemxIVm1zM2NEWnZjRVp0UVdwRlFRcHROMnB1VUZFeWFrTTFRa3c1TUVaSmJIZE5aR1ZGVmxCblRtMVNOM04yUmtWRmJISnJVVzFsTkROU2NYUTJjSFprUjJ0elRVRjZRWEZoVjFoUlJuRlVDaTB0TFMwdFJVNUVJRU5GVWxSSlJrbERRVlJGTFMwdExTMEsifX19fQ=="}],"timestampVerificationData":{"rfc3161Timestamps":[{"signedTimestamp":"MIICyDADAgEAMIICvwYJKoZIhvcNAQcCoIICsDCCAqwCAQMxDTALBglghkgBZQMEAgEwgbcGCyqGSIb3DQEJEAEEoIGnBIGkMIGhAgEBBgkrBgEEAYO/MAIwMTANBglghkgBZQMEAgEFAAQghcKBnsFCpVtXbanqDCSR8zDubO5wb4xvtguYuZJRTKMCFGXfBMQDzomI8IngRpeuarmPZQoDGA8yMDI2MDYzMDE3NDQwOFowAwIBAaAypDAwLjEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MRUwEwYDVQQDEwxzaWdzdG9yZS10c2GgADGCAdowggHWAgEBMFEwOTEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MSAwHgYDVQQDExdzaWdzdG9yZS10c2Etc2VsZnNpZ25lZAIUOhNULwyQYe68wUMvy4qOiyojiwwwCwYJYIZIAWUDBAIBoIH8MBoGCSqGSIb3DQEJAzENBgsqhkiG9w0BCRABBDAcBgkqhkiG9w0BCQUxDxcNMjYwNjMwMTc0NDA4WjAvBgkqhkiG9w0BCQQxIgQgczwr9pKyxDMc0eur+DGt9Mdetezf8UQKp2Sn3wspffwwgY4GCyqGSIb3DQEJEAIvMX8wfTB7MHkEIIX5J7wHq2LKw7RDVsEO/IGyxog/2nq55thw2dE6zQW3MFUwPaQ7MDkxFTATBgNVBAoTDHNpZ3N0b3JlLmRldjEgMB4GA1UEAxMXc2lnc3RvcmUtdHNhLXNlbGZzaWduZWQCFDoTVC8MkGHuvMFDL8uKjosqI4sMMAoGCCqGSM49BAMCBGYwZAIwJmfpM3hVIBsGwNTieyT54BZfQTwFye2f0/les1QzRFpXz5nu59C0tKLFYqcNPDdQAjBI9y5eNjjl9yo9BtpcZmIjURLuYioqzrjahNDmiThJZgRNROaVkPWrE5dlDJoFe58="}]}},"messageSignature":{"messageDigest":{"algorithm":"SHA2_256","digest":"I164tPH7KbPuiT5yiFVHd6fBflpz1DenQ75wGPKdmLg="},"signature":"MEUCICWd9LegvRoJ1X6HAL0k/RWPoLmlm/XSw7UxNZciHS0sAiEA7SZaJRWTiGvRHYhvwJKKDpD5gECYOnF0c3EDLOES9cM="}} \ No newline at end of file diff --git a/build/torch211-cxx11-cu128-x86_64-linux/msa/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/msa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/msa/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/quack/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/quack/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/quack/activation.py b/build/torch211-cxx11-cu128-x86_64-linux/quack/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..e8cbeb29242b92b7cc336cd336604e58c36f4459 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/quack/activation.py @@ -0,0 +1,532 @@ +# Copyright (c) 2025, Tri Dao. + +import math +from typing import Tuple +from functools import partial + +import cutlass.cute as cute +from cutlass import Float32, Boolean, const_expr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm, nvvm + + +F32_or_F32x2 = Float32 | Tuple[Float32, Float32] + + +sub_packed_f32x2 = partial( + cute.arch.calc_packed_f32x2_op, + src_c=None, + calc_func=nvvm.sub_packed_f32x2, +) + + +@dsl_user_op +def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "tanh.approx.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True) + return 0.5 + 0.5 * tanh(0.5 * x) + else: + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) + tanh_x_half = (tanh(x_half[0]), tanh(x_half[1])) + return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5)) + + +@dsl_user_op +def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32: + # return dout * out * (1.0 - out) + return dout * (out - out * out) + + +@dsl_user_op +def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) + else: + return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)) + + +@dsl_user_op +@cute.jit +def drelu( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + if const_expr(not isinstance(x, tuple)): + x_pos = Boolean(x > 0) + return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0)) + else: + x0_pos = Boolean(x[0] > 0) + x1_pos = Boolean(x[1] > 0) + dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0)) + return dx, relu(x) + + +@dsl_user_op +def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) * x + else: + relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))) + return cute.arch.mul_packed_f32x2(relu_x, x) + + +@dsl_user_op +@cute.jit +def drelu_sq( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + """ + ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward + Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out + Returns: (dx, relu_sq_out) where: + - dx = dout * 2 * x if x > 0, else 0 + - relu_sq_out = max(x, 0) * x + """ + if const_expr(not isinstance(x, tuple)): + relu_x = relu(x) + relu_sq_out = relu_x * x + # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0 + dx = 2.0 * (dout * relu_x) + return dx, relu_sq_out + else: + relu_x = relu(x) + relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x) + dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x)) + return dx, relu_sq_out + + +@dsl_user_op +def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """ + gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x))) + """ + sqrt_2_over_pi = math.sqrt(2 / math.pi) # ~0.797885 + sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # ~0.0356774 + if const_expr(not isinstance(x, tuple)): + return 0.5 * ( + x + # Currently cute.math.tanh(x, fastmath=True) generates very slow code + # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True)) + * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)))) + ) + else: + x_sq = cute.arch.mul_packed_f32x2(x, x) + x_sq_scaled = cute.arch.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + z = cute.arch.mul_packed_f32x2(x, x_sq_scaled) + tanh_z = (tanh(z[0]), tanh(z[1])) + x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x) + return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z) + + +@dsl_user_op +def dgelu_tanh_approx( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + """ + GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward + Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out + Returns: (dx, gelu_out) + + Derivative uses the chain rule: + d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2 + and sech^2(z) = 1 - tanh^2(z) + """ + sqrt_2_over_pi = math.sqrt(2 / math.pi) # c1 ~0.797885 + sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # c2 ~0.0356774 + sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff # c3 ~0.01070322 + + if const_expr(not isinstance(x, tuple)): + # Compute z = x * (c1 + c2 * x^2) + x_sq = x * x + # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True) + tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq)) + half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z + gelu_out = x * half_tanh_z_plus_one + + # Compute gradient + # sech^2(z) = 1 - tanh^2(z) + sech2_z = 1 - tanh_z * tanh_z + # dz/dx = c1 + 3 * c2 * x^2 + dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq + # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx)) + + dx = dout * dgelu + return dx, gelu_out + else: + # Compute z = x * (c1 + c2 * x^2) + x_sq = cute.arch.mul_packed_f32x2(x, x) + x_sq_scaled = cute.arch.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + z = cute.arch.mul_packed_f32x2(x, x_sq_scaled) + tanh_z = (tanh(z[0]), tanh(z[1])) + half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5)) + gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one) + + # Compute gradient + # sech^2(z) = 1 - tanh^2(z) + sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0)) + # dz/dx = c1 + 3 * c2 * x^2 + dz_dx = cute.arch.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx) + x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx) + dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one) + + dx = cute.arch.mul_packed_f32x2(dout, dgelu) + return dx, gelu_out + + +@dsl_user_op +@cute.jit +def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + use_linear = Boolean(x > 20.0) + return ( + cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True) + if not use_linear + else x + ) + else: + log2_e = math.log2(math.e) + x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e)) + x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True)) + x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0)) + log_x_exp_p1 = ( + cute.math.log2(x_exp_p1[0], fastmath=True), + cute.math.log2(x_exp_p1[1], fastmath=True), + ) + ln2 = math.log(2.0) + softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2)) + use_linear_0 = Boolean(x[0] > 20.0) + use_linear_1 = Boolean(x[1] > 20.0) + return ( + softplus_x[0] if not use_linear_0 else x[0], + softplus_x[1] if not use_linear_1 else x[1], + ) + + +@dsl_user_op +@cute.jit +def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32: + use_linear = Boolean(out > 20.0) + # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout + dx = dout - dout * cute.math.exp(-out, fastmath=True) + return dx if not use_linear else dout + + +@dsl_user_op +def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2: + """ + silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x) + This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA. + """ + if const_expr(not isinstance(x, tuple)): + x_half = 0.5 * x if const_expr(not already_halved) else x + # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half + return x_half * tanh(x_half) + x_half + else: + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x + tanh_x_half = (tanh(x_half[0]), tanh(x_half[1])) + return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half) + + +@dsl_user_op +def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return silu(x) * y + else: + return cute.arch.mul_packed_f32x2(silu(x), y) + + +@dsl_user_op +def dswiglu( + x: F32_or_F32x2, + y: F32_or_F32x2, + dout: F32_or_F32x2, + *, + already_halved: bool = False, + loc=None, + ip=None, +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out + Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x) + + d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))) + + This has been optimized to use fewer instructions (i.e. we expand things out + to use FFMA instead of FADD and FMUL). + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x)) + # FMUL, MUFU.TANH, then FFMA + if const_expr(not already_halved): + sigmoid_x = sigmoid(x) + silu_x = x * sigmoid_x # FMUL + else: + tanh_x = tanh(x) # MUFU.TANH + sigmoid_x = 0.5 * tanh_x + 0.5 # FFMA + silu_x = x * tanh_x + x # FFMA + silu_x_dout = silu_x * dout # FMUL + # d_silu(x) * dout + # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout + # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout + d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout # FFMA, FFMA + dx = d_silu_x_dout * y # FMUL + dy = silu_x_dout + swiglu_out = silu_x * y # FMUL + # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA + return dx, dy, swiglu_out + else: + # Compute sigmoid(x) and silu(x) + if const_expr(not already_halved): + sigmoid_x = sigmoid(x) + silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x) + else: + tanh_x = (tanh(x[0]), tanh(x[1])) + sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5)) + silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x) + silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout) + # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout + sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2( + sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x + ) + d_silu_x_dout = cute.arch.fma_packed_f32x2( + sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout + ) + dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y) + dy = silu_x_dout + swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y) + return dx, dy, swiglu_out + + +@dsl_user_op +def swiglu_oai( + x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None +) -> F32_or_F32x2: + """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y. + https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249 + x * sigmoid(alpha * x) * (y + 1) + Compile down to FMUL, FMUL, TANH, FFMA, FFMA + """ + # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2)) + if const_expr(not isinstance(x, tuple)): + x_half = 0.5 * x + # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half + silu_x = x_half * tanh(alpha * x_half) + x_half + return silu_x * y + silu_x + else: + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) + alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half) + tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1])) + silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half) + return cute.arch.fma_packed_f32x2(silu_x, y, silu_x) + + +@dsl_user_op +def dswiglu_oai( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + Swiglu OAI backward pass: computes gradients w.r.t. x and y + Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out + Returns: (dx, dy, swiglu_oai_out) + + Derivative of x * sigmoid(alpha * x) w.r.t. x: + d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x)) + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2)) + alpha_x_half = (0.5 * alpha) * x # FMUL + # MUFU.TANH, then FFMA + # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True) + sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half) + silu_x = x * sigmoid_alpha_x # FMUL + silu_x_dout = silu_x * dout # FMUL + # FFMA, FFMA, FMUL + d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout + dx = d_silu_x_dout * y + d_silu_x_dout # FFMA, instead of multiply by y + 1 + dy = silu_x_dout + swiglu_out = silu_x * y + silu_x # FFMA, instead of multiply by y + 1 + # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA + return dx, dy, swiglu_out + else: + # Compute sigmoid(alpha * x) + alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x) + tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1])) + sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5)) + silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x) + silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout) + # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout + silu_x_minus_product = cute.arch.fma_packed_f32x2( + silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x + ) + sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2( + (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x + ) + d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout) + dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout) + dy = silu_x_dout + swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x) + return dx, dy, swiglu_out + + +@dsl_user_op +def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """GLU: Gated Linear Unit + glu(x, y) = sigmoid(x) * y + Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2)) + """ + if const_expr(not isinstance(x, tuple)): + sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA + return sigmoid_x * y # FMUL + else: + sigmoid_x = sigmoid(x) + return cute.arch.mul_packed_f32x2(sigmoid_x, y) + + +@dsl_user_op +def dglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out + Returns: (dx, dy, glu_out) where: + - dx = dout * y * sigmoid(x) * (1 - sigmoid(x)) + - dy = dout * sigmoid(x) + - glu_out = sigmoid(x) * y + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2)) + sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA + sigmoid_x_dout = sigmoid_x * dout # FMUL + glu_out = sigmoid_x * y # FMUL + # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout + # = y * (1 - sigmoid(x)) * sigmoid_x_dout + # = (y - y * sigmoid(x)) * sigmoid_x_dout + # = (y - glu_out) * sigmoid_x_dout + dx = (y - glu_out) * sigmoid_x_dout # FADD, FMUL + dy = sigmoid_x_dout + # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA + return dx, dy, glu_out + else: + sigmoid_x = sigmoid(x) + sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout) + glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y) + # dx = (y - glu_out) * sigmoid_x_dout + y_minus_glu_out = sub_packed_f32x2(y, glu_out) + dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout) + dy = sigmoid_x_dout + return dx, dy, glu_out + + +@dsl_user_op +def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """ReGLU: ReLU Gated Linear Unit + reglu(x, y) = relu(x) * y = max(x, 0) * y + """ + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) * y + else: + relu_x = relu(x) + return cute.arch.mul_packed_f32x2(relu_x, y) + + +@dsl_user_op +@cute.jit +def dreglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out + Returns: (dx, dy, reglu_out) where: + - dx = dout * y if x > 0, else 0 + - dy = dout * relu(x) + - reglu_out = relu(x) * y + """ + if const_expr(not isinstance(x, tuple)): + x_pos = Boolean(x > 0) + relu_x = cute.arch.fmax(x, Float32(0.0)) + dx = (dout * y) if x_pos else Float32(0.0) + dy = dout * relu_x + reglu_out = relu_x * y + return dx, dy, reglu_out + else: + x0_pos = Boolean(x[0] > 0) + x1_pos = Boolean(x[1] > 0) + relu_x = relu(x) + dout_y = cute.arch.mul_packed_f32x2(dout, y) + dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0))) + dy = cute.arch.mul_packed_f32x2(dout, relu_x) + reglu_out = cute.arch.mul_packed_f32x2(relu_x, y) + return dx, dy, reglu_out + + +@dsl_user_op +def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """GeGLU: GELU Gated Linear Unit + geglu(x, y) = gelu(x) * y + Uses the tanh approximation of GELU + """ + if const_expr(not isinstance(x, tuple)): + return gelu_tanh_approx(x) * y + else: + return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y) + + +@dsl_user_op +def dgeglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out + Returns: (dx, dy, geglu_out) where: + - dx = dout * y * d_gelu(x) + - dy = dout * gelu(x) + - geglu_out = gelu(x) * y + """ + if const_expr(not isinstance(x, tuple)): + # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x) + dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout) + # Compute gradients for geglu + dx = dgelu_x_dout * y + dy = gelu_x * dout + geglu_out = gelu_x * y + return dx, dy, geglu_out + else: + # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x) + dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout) + # Compute gradients for geglu + dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y) + dy = cute.arch.mul_packed_f32x2(gelu_x, dout) + geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y) + return dx, dy, geglu_out diff --git a/build/torch211-cxx11-cu128-x86_64-linux/quack/compile_utils.py b/build/torch211-cxx11-cu128-x86_64-linux/quack/compile_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4375594669c8f12d6a79d8878316271cb819568a --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/quack/compile_utils.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +from typing import Optional + +import cutlass.cute as cute + + +def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]: + if leading_dim < 0: + leading_dim = len(shape) + leading_dim + if dtype is None: + return None + stride = tuple( + cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1 + for i in range(len(shape)) + ) + return cute.runtime.make_fake_tensor( + dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8 + ) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/quack/copy_utils.py b/build/torch211-cxx11-cu128-x86_64-linux/quack/copy_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad989559766d6ee6e8ece9d322bf08980706dfa --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/quack/copy_utils.py @@ -0,0 +1,890 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +import re +from typing import Optional, Type, Tuple, Callable, Sequence +from functools import partial + +import cutlass +import cutlass.cute as cute + +from cutlass import Int32, Int16, Boolean, const_expr +from cutlass.cute.nvgpu import cpasync, warp, warpgroup +from cutlass.cute.nvgpu.tcgen05.mma import CtaGroup # noqa +from cutlass.cutlass_dsl import dsl_user_op +import cutlass.pipeline +from cutlass._mlir.dialects import llvm +from cutlass._mlir import ir +from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir + + +Sm100MmaPeerBitMask = 0xFEFFFFFF + + +@dsl_user_op +def cvt_copy( + tiled_copy: cute.TiledCopy, + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + retile: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + if const_expr(src.element_type != dst.element_type): + src_cvt = cute.make_fragment_like(src, dst.element_type) + src_cvt.store(src.load().to(dst.element_type)) + src = src_cvt + if const_expr(retile): + src = tiled_copy.retile(src) + cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +@dsl_user_op +def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip) + cute.autovec_copy(src, dst, loc=loc, ip=ip) + return dst + + +@dsl_user_op +def load_s2r_retile( + tiled_copy: cute.TiledCopy, + src: cute.Tensor, + dst_shape: cute.Tensor | cute.Shape, + *, + loc=None, + ip=None, +) -> cute.Tensor: + # Will also accept dst_shape being a tensor, in which case we write into that tensor + if const_expr(not isinstance(dst_shape, cute.Tensor)): + dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip) + else: + dst = dst_shape + cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip) + return dst + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None +) -> cute.CopyAtom: + num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + num_copy_elems = src.shape[0][0] + copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +def tiled_copy_1d( + dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + thr_layout = cute.make_layout(num_threads) + val_layout = cute.make_layout(num_copy_elems) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def tiled_copy_2d( + dtype: Type[cutlass.Numeric], + threads_per_row: int, + num_threads: int, + num_copy_elems: int = 1, + is_async: bool = False, +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + assert num_threads % threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // threads_per_row, threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, num_copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +@cute.jit +def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor: + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" + tApA = cute.make_rmem_tensor( + cute.make_layout( + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + return tApA + + +# def tiled_copy_2d( +# dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +# ) -> cute.TiledCopy: +# num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width +# copy_elems = num_copy_bits // dtype.width +# copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() +# copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) +# gmem_threads_per_row = major_mode_size // copy_elems +# assert num_threads % gmem_threads_per_row == 0 +# thr_layout = cute.make_ordered_layout( +# (num_threads // gmem_threads_per_row, gmem_threads_per_row), +# order=(1, 0), +# ) +# val_layout = cute.make_layout((1, copy_elems)) +# return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]: + """Extract swizzle parameters from a pointer's swizzle_type. + + The swizzle_type string has the form '!cute.swizzle<"S">' where + b, m, s are the swizzle parameters (bits, base, shift). + + Returns: + A cute.Swizzle object constructed from the extracted parameters + + Raises: + ValueError: If the swizzle_type string cannot be parsed + """ + # Ideally there should be a better API to get swizzle parameters, but we'll just parse + # the string here. + swizzle_str = str(ptr.type.swizzle_type) + # Extract the inner part "S" + match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str) + if match: + b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3)) + return b, m, s + else: + raise ValueError(f"Could not parse swizzle_type: {swizzle_str}") + + +def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32: + bit_msk = (1 << b) - 1 + yyy_msk = bit_msk << (m + s) + return ptr_int ^ ((ptr_int & yyy_msk) >> s) + + +def swizzle_ptr(ptr: cute.Pointer): + b, m, s = parse_swizzle_from_pointer(ptr) + ptr_int = swizzle_int(ptr.toint(), b, m, s) + return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment) + + +def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor: + outer = tensor.layout + width = tensor.element_type.width + inner = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator)) + # Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for + # for 16 bits and <3, 2, 3> for 32 bits) + new_layout = cute.recast_layout( + width, 8, cute.make_composed_layout(inner, 0, cute.recast_layout(8, width, outer)) + ) + # recast_ptr to remove the pointer swizzle + return cute.make_tensor(cute.recast_ptr(tensor.iterator, dtype=tensor.element_type), new_layout) + + +def partition_D_position_independent( + thr_copy: cute.core.ThrCopy, tensor: cute.Tensor +) -> cute.Tensor: + return cute.make_tensor( + swizzle_ptr(thr_copy.partition_D(tensor).iterator), + thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout, + ) + + +def partition_S_position_independent( + thr_copy: cute.core.ThrCopy, tensor: cute.Tensor +) -> cute.Tensor: + return cute.make_tensor( + swizzle_ptr(thr_copy.partition_S(tensor).iterator), + thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout, + ) + + +@dsl_user_op +def sm90_get_smem_load_op( + layout_c: cutlass.utils.LayoutEnum, + elem_ty_c: Type[cutlass.Numeric], + *, + loc=None, + ip=None, +) -> cute.CopyAtom: + """ + Selects the largest vectorized smem load atom available subject to constraint of gmem layout. + + Parameters: + ----------- + layout_c : LayoutEnum + The layout enum of the output tensor D. + + elem_ty_c : Type[Numeric] + The element type for output tensor D. + + Returns: + -------- + Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters. + """ + + if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta): + raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}") + is_m_major = layout_c.is_m_major_c() + if elem_ty_c.width == 16: + return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip) + else: + return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip) + + +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=(2 if not transpose else 1) * element_type.width, + ) + else: + return cute.make_copy_atom( + warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +def get_smem_load_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=(2 if not transpose else 1) * element_type.width, + ) + else: + return cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +def get_smem_store_C( + tiled_mma: cute.TiledMma, + sC: cute.Tensor, + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sC.element_type + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tRS_sC = thr_copy.partition_D(sC) + else: + tRS_sC = partition_D_position_independent(thr_copy, sC) + + def copy_fn(src: cute.Tensor, dst_idx: Optional[Int32] = None, **new_kwargs): + dst_tensor = tRS_sC if const_expr(dst_idx is None) else tRS_sC[None, None, None, dst_idx] + cvt_copy(tiled_copy, src, dst_tensor, retile=True, **new_kwargs) + + return copy_fn, thr_copy, tRS_sC + + +def get_smem_load_C( + tiled_mma: cute.TiledMma, + sC: cute.Tensor, + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sC.element_type + copy_atom = get_smem_load_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tSR_sC = thr_copy.partition_S(sC) + else: + tSR_sC = partition_S_position_independent(thr_copy, sC) + copy_atom_RS = get_smem_store_atom(arch, dtype, transpose) + thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx) + tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape + + def copy_fn(src_idx: Optional[Int32] = None, **new_kwargs): + src_tensor = tSR_sC if const_expr(src_idx is None) else tSR_sC[None, None, None, src_idx] + return load_s2r_retile(tiled_copy, src_tensor, dst_shape=tRS_shape, **new_kwargs) + + return copy_fn, thr_copy, tSR_sC + + +def epilog_smem_copy_atom( + tiled_mma: cute.TiledMma, epi_tile: cute.Shape, transpose: bool = False +) -> cute.TiledCopy: + copy_atom_C = cute.make_copy_atom( + warp.StMatrix8x8x16bOp(transpose, num_matrices=4 if epi_tile[1] % 16 == 0 else 2), + cutlass.Float16, # this is just to get the right source layout + ) + tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + return tiled_copy_C_atom + + +def get_smem_store_epi( + tiled_mma: cute.TiledMma, + epi_tile: cute.Shape, + sC: Optional[cute.Tensor], + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor, cute.Tensor]: + dtype = sC.element_type if const_expr(sC is not None) else cutlass.Float16 + tiled_copy_C_atom = epilog_smem_copy_atom(tiled_mma, epi_tile) + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_S(copy_atom, tiled_copy_C_atom) + thr_copy = tiled_copy.get_slice(tidx) + tRS_sC = None + if const_expr(sC is not None): + if const_expr(not position_independent): + tRS_sC = thr_copy.partition_D(sC) + else: + tRS_sC = partition_D_position_independent(thr_copy, sC) + sC_shape = sC.shape[:2] if sC is not None else epi_tile + # (R2S, R2S_M, R2S_N, PIPE_C) + tRS_rC_shape = thr_copy.partition_S(cute.make_identity_tensor(sC_shape)).shape + tRS_rC = cute.make_rmem_tensor(tRS_rC_shape, tiled_mma.op.acc_dtype) + + def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs): + cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], **new_kwargs) + + return copy_fn if const_expr(sC is not None) else None, thr_copy, tRS_sC, tRS_rC + + +def get_smem_store_A( + tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sA.element_type + transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tRS_sA = thr_copy.partition_D(sA) + else: + tRS_sA = partition_D_position_independent(thr_copy, sA) + + def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs): + cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs) + + return copy_fn, thr_copy, tRS_sA + + +def get_smem_load_A( + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + tidx: Int32, + arch: int, + with_dst_tensor: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sA.element_type + transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN + copy_atom = get_smem_load_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tSR_sA = thr_copy.partition_S(sA) + else: + tSR_sA = partition_S_position_independent(thr_copy, sA) + tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2]) + + def copy_fn(src_idx: Int32, **new_kwargs): + return load_s2r_retile( + tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs + ) + + def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs): + return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs) + + return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA + + +@dsl_user_op +def cpasync_reduce_bulk_add_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: int | Int32, + *, + loc=None, + ip=None, +): + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()], + # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;", + # "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def get_tma_desc_addr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer: + """ + Get the address of the TMA descriptor embedded in a TMA Copy Atom. + + Extracts the constant memory address of the TMA descriptor for use with + custom PTX instructions. + + :param tma_atom: TMA Copy Atom from make_tiled_tma_atom + :return: Pointer to TMA descriptor in constant memory + + Example: + >>> desc_ptr = get_tma_descriptor_address(tma_atom) + """ + exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip) + tma_desc_ptr_type = ir.Type.parse( + "!cute.ptr>" + ) + return _cute_nvgpu_ir.get_tma_desc_addr(tma_desc_ptr_type, exec_atom, loc=loc, ip=ip) + + +@dsl_user_op +def tma_gather4_load( + tma_desc_ptr: cute.Pointer, + dst_smem_ptr: cute.Pointer, + mbarrier_ptr: cute.Pointer, + col_idx: Int32, + row_indices: Sequence[Int32], + *, + num_cta: int = 1, + multicast_mask=None, + loc=None, + ip=None, +) -> None: + """ + Perform TMA gather4 load from global memory to shared memory. + + Issues PTX instruction: + cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes + [dstMem], [tensorMap, {col_idx, row0, row1, row2, row3}], [smem_bar]; + + This loads 4 rows (specified by row_indices) from a 2D tensor at the given + column index into shared memory, using the TMA descriptor. + + :param tma_desc_ptr: Pointer to TMA descriptor in constant memory (128-byte aligned) + :type tma_desc_ptr: Pointer + :param dst_smem_ptr: Destination address in shared memory + :type dst_smem_ptr: Pointer + :param mbarrier_ptr: Pointer to mbarrier in shared memory for completion tracking + :type mbarrier_ptr: Pointer + :param col_idx: Column index + :type col_idx: Int32 + :param row_indices: Sequence of exactly 4 row indices + :type row_indices: Sequence[Int32] + :param num_cta: Number of CTAs participating (default: 1) + :type num_cta: int + :param multicast_mask: Optional multicast mask + :type multicast_mask: Int16 + + Requirements: + - row_indices must contain exactly 4 elements + - Compute capability >= SM_100 (Blackwell) + - TMA descriptor must be properly initialized for 2D tensor + + Example: + >>> from cutlass.cute.nvgpu import cpasync + >>> from cutlass.cute import core + >>> + >>> # Create TMA descriptor + >>> tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(...) + >>> tma_desc_ptr = get_tma_descriptor_address(tma_atom) + >>> + >>> # Compute indices (typically from kernel logic) + >>> col_idx = core.get(...) or 5 # Int32 value + >>> row_indices = [core.get(...) for _ in range(4)] # 4 Int32 values + >>> + >>> # Gather 4 rows at computed column + >>> tma_gather4_load( + ... tma_desc_ptr=tma_desc_ptr, + ... dst_smem_ptr=smem_ptr, + ... mbarrier_ptr=barrier_ptr, + ... col_idx=col_idx, + ... row_indices=row_indices + ... ) + """ + if len(row_indices) != 4: + raise ValueError(f"gather4 requires exactly 4 row indices, got {len(row_indices)}") + col_val = Int32(col_idx).ir_value() + row_vals = [Int32(row_idx).ir_value() for row_idx in row_indices] + # Convert pointers to integer addresses + desc_addr = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value() + dst_addr = dst_smem_ptr.toint(loc=loc, ip=ip).ir_value() + mbar_addr = mbarrier_ptr.toint(loc=loc, ip=ip) + if num_cta > 1: + # Executed by both CTAs. Set peer bit to 0 so that the + # transaction bytes will update CTA0's barrier. + mbar_addr = mbar_addr & Sm100MmaPeerBitMask + mbar_addr = mbar_addr.ir_value() + # Handle multicast_mask - may already be ir.Value or Python int + multicast_mask_val = None + if multicast_mask is not None: + multicast_mask_val = Int16(multicast_mask).ir_value() + assert multicast_mask_val is None, "multicast is not supported yet" + # Emit inline PTX for TMA gather4 + # PTX: cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes + # [dstMem], [tensorMap, {col, row0, row1, row2, row3}], [smem_bar]; + ptx = ( + f"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::{num_cta} " + "[$0], [$1, {$2, $3, $4, $5, $6}], [$7];" + ) + + llvm.inline_asm( + None, + [ + dst_addr, + desc_addr, + col_val, + row_vals[0], + row_vals[1], + row_vals[2], + row_vals[3], + mbar_addr, + ], + ptx, + "r,l,r,r,r,r,r,r", # constraints: register, long, 6x register + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +def cpasync_bulk_get_copy_fn( + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + single_stage: bool = False, + **kwargs, +) -> Callable: + group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0)) + group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + src = cute.group_modes(src_tensor, 0, group_rank_src) + dst = cute.group_modes(dst_tensor, 0, group_rank_dst) + + def copy_bulk(src_idx, dst_idx, tma_bar_ptr: cute.Pointer, **new_kwargs): + atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type) + with cute.arch.elect_one(): + cute.copy( + atom, + src[None, src_idx], + dst[None, dst_idx], + mbar_ptr=tma_bar_ptr, + **new_kwargs, + **kwargs, + ) + + def copy_bulk_single_stage(tma_bar_ptr: cute.Pointer, **new_kwargs): + atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type) + with cute.arch.elect_one(): + cute.copy(atom, src, dst, mbar_ptr=tma_bar_ptr, **new_kwargs, **kwargs) + + return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage + + +def tma_get_copy_fn( + atom: cute.CopyAtom, + cta_coord: cute.Coord, + cta_layout: cute.Layout, + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + filter_zeros: bool = False, + single_stage: bool = False, + **kwargs, +) -> Callable: + src_is_smem = const_expr( + isinstance(src_tensor.iterator, cute.Pointer) + and src_tensor.memspace == cute.AddressSpace.smem + ) + smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) + group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0)) + group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + s, g = cpasync.tma_partition( + atom, + cta_coord, + cta_layout, + cute.group_modes(smem_tensor, 0, group_rank_smem), + cute.group_modes(gmem_tensor, 0, group_rank_gmem), + ) + if const_expr(filter_zeros): + s = cute.filter_zeros(s) + g = cute.filter_zeros(g) + src, dst = (s, g) if src_is_smem else (g, s) + + def copy_tma(src_idx, dst_idx, **new_kwargs): + cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) + + def copy_tma_single_stage(**new_kwargs): + cute.copy(atom, src, dst, **new_kwargs, **kwargs) + + return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g + + +def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): + def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs): + copy( + src_idx=src_idx, + dst_idx=producer_state.index, + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), + **new_kwargs, + ) + + return copy_fn + + +@cute.jit +def gather_m_get_copy_fn( + thr_copy_A: cute.ThrCopy, + mA: cute.Tensor, # (whatever, K) + sA: cute.Tensor, # (tile_M, tile_K, STAGE) + gsAIdx: cute.Tensor, # (tile_M), either gmem or smem + limit_m: Int32, + limit_k: Int32, +) -> Callable: + tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1])) + tAsA = thr_copy_A.partition_D(sA) + # k-major + assert tAsA.shape[2] == 1 + tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2) + + is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0 + if const_expr(not is_even_m_smem): + limit_m = min(limit_m, tile_shape_mk[0]) + elems_per_load = cute.size(tAsA.shape[0][0]) + cA = cute.make_identity_tensor(tile_shape_mk) + tAcA = thr_copy_A.partition_S(cA) + t0AcA = thr_copy_A.get_slice(0).partition_S(cA) + # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0] + # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0]. + # This is so that when we do the comparison, t0AcA is known at compile time. + limit_m = limit_m - tAcA[0][0] + limit_k = limit_k - tAcA[0][1] + # Read and cache indices for A + rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1])) + cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2])) + tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean) + for m in cutlass.range(rows_per_thread, unroll_full=True): + tApA_m[m] = t0AcA[0, m, 0][0] < limit_m + m_idx = cute.make_rmem_tensor(rows_per_thread, Int32) + for m in cutlass.range(rows_per_thread, unroll_full=True): + row_idx = tAcA[0, m, 0][0] + if tApA_m[m]: + m_idx[m] = gsAIdx[row_idx] + else: + m_idx[m] = 0 # It's ok to load row 0 in the case of OOB + + mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1])) + + def copy_fn(src_idx, dst_idx, pred: bool = False): + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + mA_cur = mA_k[None, (None, src_idx)] + for m in cutlass.range_constexpr(tAcA.shape[1]): + # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape + # ((elems_per_load), thread_per_row) + # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA + # So we append 1s to the last dimension and then do tiled_divide, then slice. + mA_row = cute.tiled_divide( + cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1) + )[None, None, 0] + if const_expr(is_even_m_smem) or tApA_m[m]: + # There's only 1 load per row + assert cute.size(tAcA.shape, mode=[2]) == 1 + ki = tAcA[0, 0, 0][1] // elems_per_load + cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k) + + return copy_fn + + +@cute.jit +def gather_k_get_copy_fn( + thr_copy_A: cute.ThrCopy, + mA: cute.Tensor, # (tile_M, whatever) + sA: cute.Tensor, # (tile_M, tile_K, STAGE) + gsAIdx: cute.Tensor, # (tile_K, RestK), either gmem or smem + limit_m: Int32, + limit_k: Int32, +) -> Callable: + gAIdx, sAIdx = None, None + if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem): + gAIdx = gsAIdx + else: + assert gsAIdx.memspace == cute.AddressSpace.smem + sAIdx = gsAIdx + tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1])) + # (atom_v, CPY_M, 1, STAGE) + tAsA = thr_copy_A.partition_D(sA) + # m-major + tAsA = cute.group_modes(tAsA, 0, 3) + + is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0 + if const_expr(not is_even_m_smem): + limit_m = min(limit_m, tile_shape_mk[0]) + elems_per_load = cute.size(tAsA.shape[0][0]) + cA = cute.make_identity_tensor(tile_shape_mk) + tAcA = thr_copy_A.partition_S(cA) + t0AcA = thr_copy_A.get_slice(0).partition_S(cA) + # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0] + # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0]. + # This is so that when we do the comparison, t0AcA is known at compile time. + limit_m = limit_m - tAcA[0][0] + limit_k = limit_k - tAcA[0][1] + # Read and cache indices for A + rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1])) + cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2])) + tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean) + for m in cutlass.range(rows_per_thread, unroll_full=True): + tApA_m[m] = t0AcA[0, m, 0][0] < limit_m + threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load) + # This is very convoluted but idk a better way + # for tile_M=128, flat_divide gives (8, 16, K), + # then logical_divide gives ((8, 1), (8, 2), K). + tidx = thr_copy_A.thr_idx + tAmA = cute.logical_divide( + cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col) + )[None, (tidx % threads_per_col, None), None] # ((8, 1), 2, K) + + def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]: + # Prefetch mAIdx early, even before smem is free + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + gAIdx_cur = gAIdx[None, src_idx] + k_idx = cute.make_rmem_tensor(cols_per_thread, Int32) + for k in cutlass.range(cols_per_thread): + col_idx = tAcA[0, 0, k][1] + if const_expr(not pred): + k_idx[k] = gAIdx_cur[col_idx] + else: + if tApA_k[k]: + k_idx[k] = gAIdx_cur[col_idx] + else: + k_idx[k] = -1 + return k_idx, tApA_k + + def prefetch_from_smem_fn( + a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False + ) -> Tuple[cute.Tensor, cute.Tensor]: + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state) + sAIdx_cur = sAIdx[None, dst_idx] + k_idx = cute.make_rmem_tensor(cols_per_thread, Int32) + for k in cutlass.range(cols_per_thread): + col_idx = tAcA[0, 0, k][1] + k_idx[k] = sAIdx_cur[col_idx] + cute.arch.sync_warp() + with cute.arch.elect_one(): + a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state) + return k_idx, tApA_k + + def copy_fn( + src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False + ): + k_idx, tApA_k = k_idx_tApA_k + tApA_k_pred = None + if const_expr(pred): + tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2) # (1, cols_per_thread) + for k in cutlass.range_constexpr(tAcA.shape[2]): + # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2)) + for m in cutlass.range_constexpr(tAcA.shape[1]): + if tApA_m[m]: + cute.copy( + thr_copy_A, + tAmA[None, m, k_idx[k]], + tAsA[(None, m, k), dst_idx], + pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k], + ) + + return copy_fn, prefetch_from_gmem_fn if const_expr( + gAIdx is not None + ) else prefetch_from_smem_fn + + +@cute.jit +def gather_m_get_tma_copy_fn( + tma_atom: cute.CopyAtom, + mA: cute.Tensor, # (whatever, K) + sA: cute.Tensor, # ((4, 32), (64, 1), STAGE) + sAIdx: cute.Tensor, # (tile_M), + warp_idx: Int32, + num_warps: int, + num_cta: int = 1, +) -> Callable: + tile_M = cute.size(sAIdx, mode=[0]) + tile_K = cute.size(sA[None, None, 0]) // tile_M + assert tile_M % 4 == 0 + # cta_group = 1 if tma_atom.op.cta_group == CtaGroup.ONE else 2 + cta_group = num_cta # Somehow all tma_atom has CtaGroup.ONE inside the kernel + + copy_AIdx_s2r = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128), + cute.make_layout(num_warps), # thr_layout + cute.make_layout(4), # val_layout + ) + warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx) + tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx) + # ((4, 1), 8, (64, 1), STAGE) + tSR_sA = warp_copy_AIdx_s2r.partition_S(sA) + tSR_rAIdx = load_s2r(tSR_sAIdx) + tma_desc_ptr = get_tma_desc_addr(tma_atom) + tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group) + + def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer): + col_idx = tile_K * src_idx + for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True): + row_indices = [tSR_rAIdx[v, m] for v in range(4)] + smem_ptr = tSR_sA[None, m, None, dst_idx].iterator + with cute.arch.elect_one(): + tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices) + + return copy_fn diff --git a/build/torch211-cxx11-cu128-x86_64-linux/quack/cute_dsl_utils.py b/build/torch211-cxx11-cu128-x86_64-linux/quack/cute_dsl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c92cf39ac08b92245316da46526494d7d8370e1 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/quack/cute_dsl_utils.py @@ -0,0 +1,104 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Tuple +from functools import lru_cache +from dataclasses import dataclass, fields + +import torch + +try: + from triton.tools.disasm import extract +except ImportError: + extract = None + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Int64, Float16, BFloat16, Float32 +from cutlass.base_dsl.typing import JitArgument +from cutlass.cutlass_dsl import NumericMeta + + +StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) + + +load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data +cute_compile_og = cute.compile + + +torch2cute_dtype_map = { + torch.float16: Float16, + torch.bfloat16: BFloat16, + torch.float32: Float32, + torch.int32: Int32, + torch.int64: Int64, +} + + +@lru_cache +def get_max_active_clusters(cluster_size): + return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) + + +@lru_cache +def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: + return torch.cuda.get_device_capability(device) + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class ArgumentsBase(JitArgument): + def __c_pointers__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + c_ptrs = [] + for obj in non_constexpr_fields: + if hasattr(obj, "__c_pointers__"): + c_ptrs.extend(obj.__c_pointers__()) + return c_ptrs + + def __get_mlir_types__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + types, self._values_pos = [], [] + for obj in non_constexpr_fields: + if hasattr(obj, "__get_mlir_types__"): + obj_types = obj.__get_mlir_types__() + types.extend(obj_types) + self._values_pos.append(len(obj_types)) + else: + self._values_pos.append(0) + return types + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/quack/layout_utils.py b/build/torch211-cxx11-cu128-x86_64-linux/quack/layout_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..099e0daf54cdac4b25b6d96f01b35451c810249b --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/quack/layout_utils.py @@ -0,0 +1,295 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + + +import cutlass +import cutlass.cute as cute + +from cutlass import Int32, const_expr + + +def transpose_view(a: cute.Tensor) -> cute.Tensor: + """Transpose the first two dimensions of a tensor on smem.""" + shape = (a.shape[1], a.shape[0], *a.shape[2:]) + order = (1, 0, *range(2, cute.rank(a))) + return cute.composition(a, cute.make_ordered_layout(shape, order=order)) + + +def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor: + return cute.make_tensor(a.iterator, cute.select(a.layout, mode)) + + +def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor: + shape = (*a.shape[:dim], size, *a.shape[dim:]) + stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:]) + return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride)) + + +@cute.jit +def permute_gated_Cregs_b16(t: cute.Tensor) -> None: + assert t.element_type.width == 16 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation" + t_u32 = cute.recast_tensor(t, Int32) + + quad_idx = cute.arch.lane_idx() % 4 + lane_03 = quad_idx == 0 or quad_idx == 3 + selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054) + selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276) + # upper_map = [0, 3, 1, 2] + # lower_map = [1, 2, 0, 3] + # upper_idx = upper_map[quad_idx] + # indexing isn't supported so we have to do arithmetic + upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2 + lower_idx = upper_idx ^ 1 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True): + upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1] + upper0 = upper if lane_03 else lower + lower0 = lower if lane_03 else upper + upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp) + lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp) + t_u32[i * 2 + 0] = cute.arch.prmt(upper0, lower0, selector_upper) + t_u32[i * 2 + 1] = cute.arch.prmt(upper0, lower0, selector_lower) + + +@cute.jit +def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None: + """Permute and shuffle within 4 threads to change the layout from + T0 | T1 | T2 | T3 + a b | c d | e f | g h + to + T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3 + a | b | c | d | e | f | g | h + This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict. + """ + + assert t.element_type.width == 32 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation" + + quad_idx = cute.arch.lane_idx() % 4 + # left_map = [0, 2, 1, 3] + # right_map = [2, 0, 3, 1] + # indexing isn't supported so we have to do arithmetic + left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2 + right_idx = left_idx ^ 0b10 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True): + for r in cutlass.range(2, unroll_full=True): + left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1] + # a b | c d | e f | g h -> a b | c d | f e | h g + left0 = left if quad_idx < 2 else right + right0 = right if quad_idx < 2 else left + # a b | c d | f e | h g -> a b | f d | c e | h g + left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp) + # a b | f d | c e | h g -> a e | f b | c g | h d + right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp) + # a e | f b | c g | h d -> a e | b f | c g | d h + t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0 + t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0 + t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1] + + +@cute.jit +def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None: + """Permute and shuffle within 4 threads to change the layout from + T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3 + a | b | c | d | e | f | g | h + to + T0 | T1 | T2 | T3 + a b | c d | e f | g h + This is so that we can use LDSM (instead of LDS.64) to store C registers without bank conflict. + """ + + assert t.element_type.width == 32 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation" + + quad_idx = cute.arch.lane_idx() % 4 + # left_map = [0, 2, 1, 3] + # right_map = [1, 3, 0, 2] + # indexing isn't supported so we have to do arithmetic + left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2 + right_idx = left_idx ^ 0b01 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + # This is just the inverse of permute_Cregs_b32_for_stsm + for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True): + t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1] + for r in cutlass.range(2, unroll_full=True): + left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1] + # a e | b f | c g | d h -> a e | f b | c g | h d + left0 = left if quad_idx % 2 == 0 else right + right0 = right if quad_idx % 2 == 0 else left + # a e | f b | c g | h d -> a b | f d | c e | h g + right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp) + # a b | f d | c e | h g -> a b | c d | f e | h g + left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp) + # a b | c d | f e | h g -> a b | c d | e f | g h + t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0 + t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0 + + +@cute.jit +def concat_layout(*layouts: cute.Layout) -> cute.Layout: + return cute.make_layout( + tuple(l.shape for l in layouts), + stride=tuple(l.stride for l in layouts), + ) + + +def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout: + """ + For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...). + For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...). + """ + acc_layout_col_major = cute.make_layout(acc_layout.shape) + shape = ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M + ( + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), # MMA_N + *acc_layout_col_major.shape[3:], + ) + stride = ( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), # MMA_N + *acc_layout_col_major.stride[3:], + ) + if const_expr(transpose): + shape = (shape[1], shape[0], *shape[2:]) + stride = (stride[1], stride[0], *stride[2:]) + acc_layout_mn = cute.make_layout(shape, stride=stride) + return cute.composition(acc_layout, acc_layout_mn) + + +def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose)) + + +def reshape_acc_to_mn(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose)) + + +@cute.jit +def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: + # For back to back gemm, convert layout of acc0 to gemm 1 accept layout. + # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) + # TODO: Sm90 FP8 + if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90 + l = cute.logical_divide( + acc_layout, ((None, None, 2), None, None) + ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N) + rA_mma_view = cute.make_layout( + ( + (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]), + l.shape[1], + (l.shape[0][2][1], l.shape[2]), + ), + stride=( + (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]), + l.stride[1], + (l.stride[0][2][1], l.stride[2]), + ), + ) + else: # Sm80 + # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) + l = cute.logical_divide(acc_layout, (None, None, 2)) + rA_mma_view = cute.make_layout( + ( + (l.shape[0], l.shape[2][0]), + l.shape[1], + l.shape[2][1], + ), + stride=( + (l.stride[0], l.stride[2][0]), + l.stride[1], + l.stride[2][1], + ), + ) + return rA_mma_view + + +def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout)) + + +def convert_layout_zero_stride( + input: cute.Tensor | cute.Layout, ref_layout: cute.Layout +) -> cute.Layout: + layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input + # Group the modes with non-zero stride in the ref_layout together, + # and the modes with zero stride together + layout_flat = cute.flatten(layout) + ref_layout_flat = cute.flatten(ref_layout) + nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0] + zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0] + # There's an edge case when all modes are zero stride + new_shape = ( + tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,), + tuple(layout_flat[i].shape for i in zero_modes), + ) + new_stride = ( + tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,), + tuple(layout_flat[i].stride for i in zero_modes), + ) + out_layout = cute.make_layout(new_shape, stride=new_stride) + if const_expr(isinstance(input, cute.Tensor)): + return cute.make_tensor(input.iterator, out_layout) + else: + return out_layout + + +def mma_partition_C_vec( + sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_C(sVec_mma)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] + + +def mma_partition_A_vec( + sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/quantize.py b/build/torch211-cxx11-cu128-x86_64-linux/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..4719a4854bc9388b2a866598f9e21c1f14921181 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/quantize.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Transformer Engine NVFP4 quantization helper. + +This file is intended as a customer-facing example for preparing KV tensors +for the KVFP4 attention kernel: + - BF16/FP16 K/V input + - packed E2M1 FP4 data from Transformer Engine + - E4M3 block scales in cuBLAS/cuDNN 128x4 tiled layout + - one FP32 tensor/global scale per tensor +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Tuple + +import torch + + +NVFP4_BLOCK_SIZE = 16 +NVFP4_FP4_MAX = 6.0 +NVFP4_FP8_E4M3_MAX = 448.0 + + +@dataclass(frozen=True) +class Nvfp4QuantizedTensor: + """Packed NVFP4 tensor plus dequantization metadata. + + Attributes + ---------- + data : torch.Tensor + Packed E2M1 FP4 data from Transformer Engine. The last dimension is + half of the original logical last dimension because each byte stores + two FP4 values. + scale_128x4 : torch.Tensor + E4M3 block scales in cuBLAS/cuDNN 128x4 tiled rowwise storage. + global_scale : torch.Tensor + FP32 tensor/global dequant scale. + logical_scale_shape : tuple[int, int] + Logical 2D scale shape ``(rows, cols)`` before 128x4 swizzling. + original_shape : tuple[int, ...] + Original BF16/FP16 tensor shape before quantization. + """ + + data: torch.Tensor + scale_128x4: torch.Tensor + global_scale: torch.Tensor + logical_scale_shape: Tuple[int, int] + original_shape: Tuple[int, ...] + + +def _round_up(x: int, multiple: int) -> int: + return ((int(x) + multiple - 1) // multiple) * multiple + + +def nvfp4_scale_128x4_offset( + row: torch.Tensor, + col: torch.Tensor, + scale_cols: int, +) -> torch.Tensor: + """Return flat offsets for cuBLAS/cuDNN 128x4 rowwise scale storage. + + Parameters + ---------- + row : torch.Tensor + Logical row indices. + col : torch.Tensor + Logical scale-column indices. + scale_cols : int + Logical number of scale columns before padding to a multiple of 4. + + Returns + ------- + torch.Tensor + Flat offsets into the padded 128x4 tiled storage. + """ + + tiles_n = _round_up(scale_cols, 4) // 4 + tile_m = row // 128 + tile_n = col // 4 + outer = row % 128 + inner = col % 4 + return ( + (tile_m * tiles_n + tile_n) * 512 + + (outer % 32) * 16 + + (outer // 32) * 4 + + inner + ) + + +def swizzle_nvfp4_scale_to_128x4( + scale: torch.Tensor, + *, + rows: int, + cols: int, +) -> torch.Tensor: + """Convert TE logical rowwise scales to cuBLAS/cuDNN 128x4 tiled layout. + + Parameters + ---------- + scale : torch.Tensor + Logical rowwise scale tensor with at least shape ``[rows, cols]``. + rows : int + Number of logical rows to convert. + cols : int + Number of logical scale columns to convert. + + Returns + ------- + torch.Tensor + Scale tensor padded to ``round_up(rows, 128)`` by ``round_up(cols, 4)`` + and swizzled into 128x4 tiled storage. + """ + + if scale.ndim != 2: + raise ValueError(f"scale must be 2D, got shape {tuple(scale.shape)}") + + rows = int(rows) + cols = int(cols) + padded_rows = _round_up(rows, 128) + padded_cols = _round_up(cols, 4) + if scale.shape[0] < rows or scale.shape[1] < cols: + raise ValueError( + "scale is smaller than the requested logical shape: " + f"got {tuple(scale.shape)}, need at least {(rows, cols)}" + ) + + logical = scale[:rows, :cols].contiguous() + if logical.shape != (padded_rows, padded_cols): + logical = torch.nn.functional.pad( + logical.to(torch.float32), + (0, padded_cols - cols, 0, padded_rows - rows), + ).to(scale.dtype) + swizzled = torch.empty_like(logical) + + row = torch.arange(padded_rows, device=scale.device, dtype=torch.int64)[:, None] + col = torch.arange(padded_cols, device=scale.device, dtype=torch.int64)[None, :] + offset = nvfp4_scale_128x4_offset(row, col, padded_cols).reshape(-1) + swizzled.reshape(-1)[offset] = logical.reshape(-1) + return swizzled + + +def nvfp4_global_scale_from_amax(amax: torch.Tensor) -> torch.Tensor: + """Compute TE NVFP4 tensor/global dequant scale from rowwise amax. + + Parameters + ---------- + amax : torch.Tensor + Rowwise absolute maxima returned by Transformer Engine. + + Returns + ------- + torch.Tensor + FP32 global scale equal to ``amax / (448 * 6)``. + """ + + return amax.to(torch.float32) / (NVFP4_FP8_E4M3_MAX * NVFP4_FP4_MAX) + + +def _import_te_nvfp4_quantizer(): + try: + from transformer_engine.pytorch.tensor import NVFP4Quantizer + except Exception as exc: # pragma: no cover - environment dependent + raise RuntimeError( + "Transformer Engine NVFP4 quantization is unavailable. Install a " + "Transformer Engine build with its PyTorch dependencies, including " + "FlashAttention v3 when required by that TE build." + ) from exc + return NVFP4Quantizer + + +def quantize_bf16_to_nvfp4_128x4(x: torch.Tensor) -> Nvfp4QuantizedTensor: + """Quantize a BF16/FP16 tensor to NVFP4 using Transformer Engine. + + TE returns rowwise scales in logical padded layout. This helper returns + the scales in physical 128x4 tiled storage, so the attention kernel can + load them with ``nvfp4_scale_128x4_offset``. + + Parameters + ---------- + x : torch.Tensor + CUDA BF16 or FP16 tensor. The last dimension must be divisible by 16, + and the flattened row dimension ``prod(x.shape[:-1])`` must also be + divisible by 16. + + Returns + ------- + Nvfp4QuantizedTensor + Packed FP4 data, 128x4-swizzled block scales, global scale, and shape + metadata needed by the KVFP4 attention kernel or by reference + dequantization. + """ + + if not x.is_cuda: + raise ValueError("NVFP4 quantization requires a CUDA tensor") + if x.dtype not in (torch.bfloat16, torch.float16): + raise TypeError(f"x must be bf16 or fp16, got {x.dtype}") + if x.ndim < 2: + raise ValueError(f"x must have at least 2 dimensions, got {x.ndim}") + if x.shape[-1] % NVFP4_BLOCK_SIZE != 0: + raise ValueError( + f"last dimension must be divisible by {NVFP4_BLOCK_SIZE}, got {x.shape[-1]}" + ) + + rows = 1 + for dim in x.shape[:-1]: + rows *= int(dim) + if rows % NVFP4_BLOCK_SIZE != 0: + raise ValueError( + "flattened row dimension must be divisible by " + f"{NVFP4_BLOCK_SIZE}, got {rows}" + ) + + NVFP4Quantizer = _import_te_nvfp4_quantizer() + quantizer = NVFP4Quantizer(rowwise=True, columnwise=False) + qx = quantizer.quantize(x.contiguous()) + meta = qx.get_metadata() + + data = meta["rowwise_data"] + if data.dtype is not torch.uint8: + data = data.view(torch.uint8) + logical_scale = meta["rowwise_scale_inv"] + amax = meta["amax_rowwise"] + scale_cols = int(x.shape[-1]) // NVFP4_BLOCK_SIZE + scale_128x4 = swizzle_nvfp4_scale_to_128x4( + logical_scale, + rows=rows, + cols=scale_cols, + ) + global_scale = nvfp4_global_scale_from_amax(amax).contiguous() + + return Nvfp4QuantizedTensor( + data=data, + scale_128x4=scale_128x4, + global_scale=global_scale, + logical_scale_shape=(rows, scale_cols), + original_shape=tuple(int(v) for v in x.shape), + ) + + +def quantize_kv_bf16_to_nvfp4_128x4( + k: torch.Tensor, + v: torch.Tensor, +) -> tuple[Nvfp4QuantizedTensor, Nvfp4QuantizedTensor]: + """Quantize BF16/FP16 K and V tensors independently for KVFP4 attention. + + Parameters + ---------- + k : torch.Tensor + CUDA BF16 or FP16 K tensor. + v : torch.Tensor + CUDA BF16 or FP16 V tensor. + + Returns + ------- + tuple[Nvfp4QuantizedTensor, Nvfp4QuantizedTensor] + Quantized K and V tensors with independent scales. + """ + + return quantize_bf16_to_nvfp4_128x4(k), quantize_bf16_to_nvfp4_128x4(v) + + +def dequantize_nvfp4_128x4_to_bf16( + qx: Nvfp4QuantizedTensor, + *, + include_global_scale: bool = True, +) -> torch.Tensor: + """Reference dequantization for validation. + + This mirrors the kernel contract: + x = e2m1 * E4M3_block_scale_1x16 * FP32_global_scale + + Parameters + ---------- + qx : Nvfp4QuantizedTensor + Quantized tensor returned by ``quantize_bf16_to_nvfp4_128x4``. + include_global_scale : bool, optional + If True, multiply by ``qx.global_scale`` after applying per-block + scales. + + Returns + ------- + torch.Tensor + BF16 tensor with shape ``qx.original_shape``. + """ + + data = qx.data if qx.data.dtype is torch.uint8 else qx.data.view(torch.uint8) + if data.shape[-1] * 2 != qx.original_shape[-1]: + raise ValueError( + "packed data last dimension does not match original shape: " + f"{data.shape[-1]} packed vs {qx.original_shape[-1]} logical" + ) + + rows, scale_cols = qx.logical_scale_shape + logical_dim = int(qx.original_shape[-1]) + if scale_cols * NVFP4_BLOCK_SIZE != logical_dim: + raise ValueError( + "logical scale columns do not match original last dimension: " + f"{scale_cols} scale cols vs dim {logical_dim}" + ) + + fp4_lut = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float32, + device=data.device, + ) + packed = data.reshape(rows, logical_dim // 2) + lo = packed & 0x0F + hi = packed >> 4 + values = torch.empty((rows, logical_dim), dtype=torch.float32, device=data.device) + values[:, 0::2] = fp4_lut[lo.long()] + values[:, 1::2] = fp4_lut[hi.long()] + + row = torch.arange(rows, device=data.device, dtype=torch.int64)[:, None] + col = torch.arange(scale_cols, device=data.device, dtype=torch.int64)[None, :] + offset = nvfp4_scale_128x4_offset(row, col, scale_cols) + scale_u8 = qx.scale_128x4.reshape(-1)[offset.reshape(-1)].reshape(rows, scale_cols) + scale = scale_u8.view(torch.float8_e4m3fn).to(torch.float32) + scale = scale.repeat_interleave(NVFP4_BLOCK_SIZE, dim=1) + out = values * scale + if include_global_scale: + global_scale = qx.global_scale.reshape(-1)[0].to(torch.float32) + out = out * global_scale + return out.reshape(qx.original_shape).to(torch.bfloat16) + + +def _example() -> None: + device = torch.device("cuda") + k = torch.randn(128, 2, 128, device=device, dtype=torch.bfloat16) + v = torch.randn_like(k) + k_q, v_q = quantize_kv_bf16_to_nvfp4_128x4(k, v) + print("K FP4 data:", tuple(k_q.data.shape), k_q.data.dtype) + print("K scale 128x4:", tuple(k_q.scale_128x4.shape), k_q.scale_128x4.dtype) + print("K global scale:", tuple(k_q.global_scale.shape), k_q.global_scale.dtype) + print("V FP4 data:", tuple(v_q.data.shape), v_q.data.dtype) + print("V scale 128x4:", tuple(v_q.scale_128x4.shape), v_q.scale_128x4.dtype) + print("V global scale:", tuple(v_q.global_scale.shape), v_q.global_scale.dtype) + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + raise RuntimeError("quantize.py requires CUDA") + _example() diff --git a/build/torch211-cxx11-cu128-x86_64-linux/sparse_index_utils.py b/build/torch211-cxx11-cu128-x86_64-linux/sparse_index_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a54c982c9230b189051e3a0bdf76d22b397dd62a --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/sparse_index_utils.py @@ -0,0 +1,411 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Host-side q2k <-> k2q index conversion for sparse attention. + +These utilities prepare sparse metadata on the Python side for tests, +benchmarks, and other offline preprocessing flows. They are not kernel +runtime helpers, so they intentionally live outside `src/common`. + +Sparse attention pattern: + - Each Q token independently selects up to topK KV blocks (blk_kv tokens each). + - Under GQA, all Q heads in one group share the same sparsity pattern, + so indices are defined at the head_kv level. + +Shapes: + q2k_indices: [batch, head_kv, Sq, topK] int32, valid values in [0, num_kv_blocks), + trailing unused slots padded with -1 + k2q_indices: [batch, head_kv, Nkv, Sq] int32, padded with -1 + k2q_counts: [batch, head_kv, Nkv] int32 + +CSR reverse-index format: + q2k_indices: [head_kv, total_q, topK] int32, values are batch-local kv_block indices + k2q_row_ptr: [head_kv, total_rows + 1] int32 + k2q_q_indices: [head_kv, total_q * topK] int32, values are batch-local q_idx +""" + +from typing import Optional, Tuple + +import torch + +from .src.sm100.prepare_k2q_csr import SparseK2qCsrBuilderSm100 + + +def q2k_to_k2q( + q2k_indices: torch.Tensor, + num_kv_blocks: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert q2k sparse indices to k2q representation. + + For each KV block, find which Q tokens attend to it. + + Args: + q2k_indices: [batch, head_kv, Sq, topK] int32. + For each Q token, the KV blocks it attends to. Unused slots must + be padded with -1. + num_kv_blocks: Total number of KV blocks (= Skv / blk_kv). + + Returns: + k2q_indices: [batch, head_kv, num_kv_blocks, Sq] int32. + For each KV block, the Q token indices that attend to it, + left-packed and padded with -1. Last dim fixed to Sq (upper bound). + k2q_counts: [batch, head_kv, num_kv_blocks] int32. + Actual number of Q tokens per KV block. + """ + B, H, Sq, topK = q2k_indices.shape + device = q2k_indices.device + N = Sq * topK + + kv_flat = q2k_indices.reshape(B, H, N).long() + valid_flat = kv_flat >= 0 + q_flat = ( + torch.arange(Sq, device=device) + .unsqueeze(-1) + .expand(Sq, topK) + .reshape(N) + .unsqueeze(0) + .unsqueeze(0) + .expand(B, H, N) + ) + + k2q_counts = torch.zeros(B, H, num_kv_blocks, dtype=torch.int32, device=device) + safe_kv_flat = torch.where(valid_flat, kv_flat, torch.zeros_like(kv_flat)) + k2q_counts.scatter_add_( + 2, + safe_kv_flat, + valid_flat.to(torch.int32), + ) + + sort_keys = torch.where( + valid_flat, + kv_flat, + torch.full_like(kv_flat, num_kv_blocks), + ) + sorted_kv, sort_idx = sort_keys.sort(dim=-1, stable=True) + sorted_q = q_flat.gather(-1, sort_idx) + sorted_valid = valid_flat.gather(-1, sort_idx) + + offsets = torch.zeros(B, H, num_kv_blocks, dtype=torch.int64, device=device) + offsets[:, :, 1:] = k2q_counts[:, :, :-1].cumsum(dim=-1).long() + + global_pos = torch.arange(N, device=device).unsqueeze(0).unsqueeze(0).expand(B, H, N) + group_offset = offsets.gather(2, sorted_kv.clamp(max=num_kv_blocks - 1)) + pos_in_group = global_pos - group_offset + + k2q_indices = torch.full( + (B, H, num_kv_blocks, Sq), -1, dtype=torch.int32, device=device + ) + flat_k2q = k2q_indices.reshape(B, H, -1) + flat_idx = sorted_kv.clamp(max=num_kv_blocks - 1) * Sq + pos_in_group + for b in range(B): + for h in range(H): + valid = sorted_valid[b, h] + flat_k2q[b, h, flat_idx[b, h, valid]] = sorted_q[b, h, valid].int() + + return k2q_indices, k2q_counts + + +def k2q_to_q2k( + k2q_indices: torch.Tensor, + k2q_counts: torch.Tensor, + Sq: int, + topK: int, +) -> torch.Tensor: + """Convert dense k2q indices back to q2k representation. + + Parameters + ---------- + k2q_indices : torch.Tensor + Shape ``[batch, head_kv, num_kv_blocks, Sq]`` and dtype int32. Values + are Q token indices padded with ``-1``. + k2q_counts : torch.Tensor + Shape ``[batch, head_kv, num_kv_blocks]`` and dtype int32. Number of + valid Q indices per KV block. + Sq : int + Q sequence length per batch item in this dense reference format. + topK : int + Maximum number of KV blocks selected per Q token. + + Returns + ------- + torch.Tensor + Shape ``[batch, head_kv, Sq, topK]``, dtype int32. Entries are sorted + by KV block index with ``-1`` padding at the tail. + """ + B, H, Nkv, _ = k2q_indices.shape + device = k2q_indices.device + + q2k = torch.full((B, H, Sq, topK), -1, dtype=torch.int32, device=device) + counters = torch.zeros(B, H, Sq, dtype=torch.int64, device=device) + + for b in range(B): + for h in range(H): + for kv_blk in range(Nkv): + count = k2q_counts[b, h, kv_blk].item() + for j in range(count): + qt = k2q_indices[b, h, kv_blk, j].item() + if qt < 0: + continue + p = counters[b, h, qt].item() + if p < topK: + q2k[b, h, qt, p] = kv_blk + counters[b, h, qt] += 1 + + q2k_sort_key = torch.where(q2k < 0, torch.full_like(q2k, Nkv), q2k) + _, sort_idx = q2k_sort_key.sort(dim=-1) + q2k = q2k.gather(-1, sort_idx) + return q2k + + +def _validate_cu_seqlens(cu_seqlens: torch.Tensor, *, name: str) -> None: + if cu_seqlens.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32, got {cu_seqlens.dtype}") + if cu_seqlens.ndim != 1: + raise ValueError(f"{name} must be rank-1, got shape {tuple(cu_seqlens.shape)}") + if cu_seqlens.numel() < 1: + raise ValueError(f"{name} must have at least one element") + if not cu_seqlens.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _rows_per_batch(cu_seqlens_k: torch.Tensor, kv_block_size: int) -> torch.Tensor: + seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1] + return (seqlens_k + kv_block_size - 1) // kv_block_size + + +def _build_packed_row_map(rows_per_batch: torch.Tensor) -> tuple[torch.Tensor, int]: + rows_per_batch_cpu = rows_per_batch.to("cpu", non_blocking=False).tolist() + batch = len(rows_per_batch_cpu) + max_rows = max(rows_per_batch_cpu, default=0) + row_dtype = ( + torch.int32 + if sum(rows_per_batch_cpu) < torch.iinfo(torch.int32).max + else torch.int64 + ) + row_map_cpu = torch.full((batch, max_rows), -1, dtype=row_dtype) + row_linear = 0 + for kv_block_idx in range(max_rows): + for batch_idx, row_count in enumerate(rows_per_batch_cpu): + if kv_block_idx < row_count: + row_map_cpu[batch_idx, kv_block_idx] = row_linear + row_linear += 1 + return row_map_cpu.to(rows_per_batch.device), row_linear + + +def build_k2q_csr_torch_reference( + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + kv_block_size: int, +) -> tuple: + """Torch reference for q2k -> k2q CSR conversion. + + Parameters + ---------- + q2k_indices : torch.Tensor + Shape ``[head_kv, total_q, topK]``, dtype int32. Values are + batch-local KV block indices padded with ``-1``. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + kv_block_size : int + Number of KV tokens per sparse block. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(k2q_row_ptr, k2q_q_indices)`` where ``k2q_row_ptr`` has shape + ``[head_kv, total_rows + 1]`` and ``k2q_q_indices`` has shape + ``[head_kv, total_q * topK]``. + """ + if kv_block_size <= 0: + raise ValueError(f"kv_block_size must be > 0, got {kv_block_size}") + if q2k_indices.dtype != torch.int32: + raise TypeError(f"q2k_indices must be torch.int32, got {q2k_indices.dtype}") + if q2k_indices.ndim != 3: + raise ValueError( + "q2k_indices must have shape [head_kv, total_q, topK], " + f"got {tuple(q2k_indices.shape)}" + ) + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q") + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError("cu_seqlens_q and cu_seqlens_k must have the same shape [B + 1]") + if q2k_indices.device != cu_seqlens_q.device or q2k_indices.device != cu_seqlens_k.device: + raise ValueError("q2k_indices, cu_seqlens_q, and cu_seqlens_k must be on the same device") + + head_kv, total_q, topk = q2k_indices.shape + if total_q != int(cu_seqlens_q[-1].item()): + raise ValueError( + f"q2k_indices.shape[1] ({total_q}) must equal cu_seqlens_q[-1] " + f"({int(cu_seqlens_q[-1].item())})" + ) + + rows_per_batch = _rows_per_batch(cu_seqlens_k, kv_block_size) + row_map, total_rows = _build_packed_row_map(rows_per_batch) + nnz_upper_bound = total_q * topk + + k2q_row_ptr = torch.zeros((head_kv, total_rows + 1), dtype=torch.int32, device=q2k_indices.device) + k2q_q_indices = torch.full( + (head_kv, nnz_upper_bound), -1, dtype=torch.int32, device=q2k_indices.device + ) + if total_rows == 0 or total_q == 0 or topk == 0: + return k2q_row_ptr, k2q_q_indices + + counts = torch.zeros((head_kv, total_rows), dtype=torch.int32, device=q2k_indices.device) + total_entries = total_q * topk + row_dtype = torch.int32 if total_rows < torch.iinfo(torch.int32).max else torch.int64 + row_all = torch.empty((head_kv, total_entries), dtype=row_dtype, device=q2k_indices.device) + q_all = torch.empty((head_kv, total_entries), dtype=torch.int32, device=q2k_indices.device) + valid_all = torch.empty((head_kv, total_entries), dtype=torch.bool, device=q2k_indices.device) + rows_per_batch_cpu = rows_per_batch.to("cpu", non_blocking=False).tolist() + q_cu_cpu = cu_seqlens_q.to("cpu", non_blocking=False).tolist() + entry_cursor = 0 + + for batch_idx, kv_rows in enumerate(rows_per_batch_cpu): + q_start = q_cu_cpu[batch_idx] + q_end = q_cu_cpu[batch_idx + 1] + q_len = q_end - q_start + if q_len == 0: + continue + num_entries = q_len * topk + q2k_batch = q2k_indices[:, q_start:q_end, :] + valid_batch = q2k_batch >= 0 + if valid_batch.any(): + max_valid_kv = int(q2k_batch[valid_batch].max().item()) + if max_valid_kv >= kv_rows: + raise ValueError( + f"q2k_indices references kv_block {max_valid_kv} for batch {batch_idx}, " + f"but that batch only has {kv_rows} logical kv blocks" + ) + kv_flat = q2k_batch.reshape(head_kv, num_entries).long() + valid_flat = valid_batch.reshape(head_kv, num_entries) + safe_kv_flat = torch.where(valid_flat, kv_flat, torch.zeros_like(kv_flat)) + row_map_batch = row_map[batch_idx] + row_flat = row_map_batch[safe_kv_flat] + q_flat = ( + torch.arange(q_len, device=q2k_indices.device, dtype=torch.int32) + .view(1, q_len, 1) + .expand(head_kv, q_len, topk) + .reshape(head_kv, num_entries) + ) + row_all[:, entry_cursor : entry_cursor + num_entries] = row_flat + q_all[:, entry_cursor : entry_cursor + num_entries] = q_flat + valid_all[:, entry_cursor : entry_cursor + num_entries] = valid_flat + counts.scatter_add_(1, row_flat.to(torch.int64), valid_flat.to(torch.int32)) + entry_cursor += num_entries + + k2q_row_ptr[:, 1:] = counts.cumsum(dim=1, dtype=torch.int32) + + sort_stride = max(total_q, 1) + invalid_key = total_rows * sort_stride + max_sort_key = invalid_key + max(total_q - 1, 0) + if max_sort_key < torch.iinfo(torch.int32).max: + sort_keys = torch.full_like(row_all, invalid_key, dtype=torch.int32) + sort_keys[valid_all] = row_all[valid_all] * sort_stride + q_all[valid_all] + else: + sort_keys = torch.full_like(row_all, invalid_key, dtype=torch.int64) + sort_keys[valid_all] = ( + row_all[valid_all].to(torch.int64) * sort_stride + + q_all[valid_all].to(torch.int64) + ) + _, sort_idx = sort_keys.sort(dim=1, stable=True) + sorted_q = q_all.gather(1, sort_idx) + + valid_counts = valid_all.sum(dim=1) + write_mask = ( + torch.arange(total_entries, device=q2k_indices.device) + .unsqueeze(0) + .expand(head_kv, -1) + < valid_counts.unsqueeze(1) + ) + k2q_q_indices[write_mask] = sorted_q[write_mask] + + return k2q_row_ptr, k2q_q_indices + + +_K2Q_CSR_BUILDER = SparseK2qCsrBuilderSm100() + + +def build_k2q_csr( + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + kv_block_size: int, + *, + total_k: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + max_seqlen_q: Optional[int] = None, + total_rows: Optional[int] = None, + qhead_per_kv: int = 1, + return_schedule: bool = False, +) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, object]: + """Build the public k2q CSR reverse index on GPU. + + Runtime construction does not read device-side ``cu_seqlens`` on the host, + so callers must provide size hints such as ``total_k`` from already-known + tensor shapes. + + Parameters + ---------- + q2k_indices : torch.Tensor + Shape ``[head_kv, total_q, topK]``, dtype int32, contiguous. Values are + batch-local KV block indices with trailing ``-1`` padding. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + kv_block_size : int + Number of KV tokens per sparse block. + total_k : int + Total KV token count. Required; normally ``k.shape[0]`` for dense KV + or ``sum(kv_segment_lens)`` for paged KV. + max_seqlen_k : int, optional + Maximum KV sequence length. Passing this avoids recomputing a bound. + max_seqlen_q : int, optional + Maximum Q sequence length. + total_rows : int, optional + Total number of packed KV-block rows across the batch. If omitted, + the builder derives it from ``cu_seqlens_k`` and ``kv_block_size``. + qhead_per_kv : int, optional + Number of Q heads per KV head under GQA. + return_schedule : bool, optional + If True, also return the sparse forward schedule object produced by the + SM100 builder. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] or tuple[torch.Tensor, torch.Tensor, object] + ``(k2q_row_ptr, k2q_q_indices)`` or + ``(k2q_row_ptr, k2q_q_indices, schedule)``. CSR tensors are int32 on + the same CUDA device as ``q2k_indices``. + """ + if total_k is None: + raise ValueError("build_k2q_csr requires total_k from k.shape[0]") + if kv_block_size <= 0: + raise ValueError(f"kv_block_size must be > 0, got {kv_block_size}") + if q2k_indices.dtype != torch.int32: + raise TypeError(f"q2k_indices must be torch.int32, got {q2k_indices.dtype}") + if q2k_indices.ndim != 3: + raise ValueError(f"q2k_indices must be rank-3, got shape {tuple(q2k_indices.shape)}") + if not q2k_indices.is_contiguous(): + raise ValueError("q2k_indices must be contiguous with layout [head_kv, total_q, topK]") + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q") + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError("cu_seqlens_q and cu_seqlens_k must have the same shape [B + 1]") + if q2k_indices.device != cu_seqlens_q.device or q2k_indices.device != cu_seqlens_k.device: + raise ValueError("q2k_indices, cu_seqlens_q, and cu_seqlens_k must be on the same device") + return _K2Q_CSR_BUILDER( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + total_k=int(total_k), + blk_kv=int(kv_block_size), + max_seqlen_k=max_seqlen_k, + max_seqlen_q=max_seqlen_q, + total_rows=total_rows, + qhead_per_kv=qhead_per_kv, + return_schedule=return_schedule, + ) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15dfab1c4fc45dcec26ba7489b35624f5c46a698 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15dfab1c4fc45dcec26ba7489b35624f5c46a698 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/aot_cache.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/aot_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..99fd0b4da4ddb6fba21bcb18c924f5e9e8b583e6 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/aot_cache.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Persistent AOT cache for CuTe DSL compiled kernels. + +Saves compiled TVM FFI kernels as .o files on first compile, +loads them on subsequent runs to skip JIT compilation. + +Environment variables: + MM_SPARSE_ATTN_AOT_CACHE: Override cache directory + (default: ~/.cache/minfer/mm_sparse_attn) + MM_SPARSE_ATTN_AOT_DISABLE=1: Disable AOT cache entirely +""" + +import hashlib +import os +import time + +import cutlass.cute as cute + +_AOT_CACHE_DIR = os.environ.get( + "MM_SPARSE_ATTN_AOT_CACHE", + os.path.expanduser("~/.cache/minfer/mm_sparse_attn"), +) +_AOT_DISABLE = os.environ.get("MM_SPARSE_ATTN_AOT_DISABLE", "0") == "1" + +_loaded_modules: dict[str, object] = {} + + +def _key_to_path(key: tuple) -> str: + h = hashlib.sha256(repr(key).encode()).hexdigest()[:16] + name = str(key[0]).replace("/", "_") + return os.path.join(_AOT_CACHE_DIR, f"{name}_{h}") + + +def try_load_aot(key: tuple): + if _AOT_DISABLE: + return None + obj_path = _key_to_path(key) + ".o" + if not os.path.isfile(obj_path): + return None + func_name = str(key[0]) + try: + if obj_path not in _loaded_modules: + _loaded_modules[obj_path] = cute.runtime.load_module( + obj_path, enable_tvm_ffi=True + ) + return getattr(_loaded_modules[obj_path], func_name) + except Exception as e: + print(f"[aot_cache] Failed to load {obj_path}: {e}") + return None + + +def save_aot(key: tuple, compiled) -> None: + if _AOT_DISABLE: + return + if not hasattr(compiled, "export_to_c"): + return + obj_path = _key_to_path(key) + ".o" + os.makedirs(_AOT_CACHE_DIR, exist_ok=True) + tmp_path = obj_path + f".tmp.{os.getpid()}" + func_name = str(key[0]) + try: + t0 = time.time() + compiled.export_to_c(tmp_path, function_name=func_name) + os.replace(tmp_path, obj_path) + dt = time.time() - t0 + print(f"[aot_cache] Saved {func_name} -> {obj_path} ({dt:.1f}s)") + except Exception as e: + print(f"[aot_cache] Failed to save {func_name}: {e}") + if os.path.exists(tmp_path): + os.remove(tmp_path) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/barrier.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/barrier.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5753a8a175b529567e0be238f47fd4cc8401bf --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/barrier.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm + + +@dsl_user_op +def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + state = llvm.inline_asm( + T.i32(), + [lock_ptr_i64], + "ld.global.acquire.gpu.b32 $0, [$1];", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + return cutlass.Int32(state) + + +@dsl_user_op +def red_relaxed( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.relaxed.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def red_release( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.release.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + read_val = Int32(0) + while read_val != val: + read_val = ld_acquire(flag_ptr) + + +@cute.jit +def arrive_inc( + lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32] +) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + red_release(flag_ptr, val) + # red_relaxed(flag_ptr, val) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/blackwell_helpers.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/blackwell_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..fd22f7efa3cef9988b4036c2d00fc1d3b9c816e8 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/blackwell_helpers.py @@ -0,0 +1,1093 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Optional, Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Boolean, const_expr +from cutlass.cute.nvgpu import tcgen05 +from cutlass._mlir.dialects import llvm + +from . import mma_sm100_desc as sm100_desc + + +@cute.jit +def gemm_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, + swap_AB: bool = False, + num_unroll_groups: int = 1, +) -> None: + if const_expr(swap_AB): + return gemm_w_idx( + tiled_mma, acc, tCrB, tCrA, B_idx, A_idx, zero_init=zero_init, swap_AB=False + ) + else: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + + mma_atom = cute.make_mma_atom(tiled_mma.op) + for k in cutlass.range( + cute.size(tCrA.shape[2]), unroll=cute.size(tCrA.shape[2]) // num_unroll_groups + ): + mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc) + + +@cute.jit +def gemm_ptx_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, + cta_group: int = 1, + **kwargs, +) -> None: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + sA_cur = None + if const_expr(sA is not None): + sA_cur = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx] + sB_cur = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx] + mma_atom = cute.make_mma_atom(tiled_mma.op) + acc_tmem_addr = acc.iterator.toint() + gemm_ptx_partial( + mma_atom.op, + acc_tmem_addr, + rA, + rB, + sA_cur, + sB_cur, + zero_init=zero_init, + cta_group=cta_group, + **kwargs, + ) + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + mma_atom = cute.make_mma_atom(tiled_mma.op) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + + +def i64_to_i32x2(i: int) -> Tuple[int, int]: + """Convert a 64-bit integer to a tuple of two 32-bit integers.""" + return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF + + +@cute.jit +def gemm_ptx( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else None + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = sA.iterator.type.swizzle_type + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = sB.iterator.type.swizzle_type + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr( + sA[None, None, 0].iterator + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr( + sB[None, None, 0].iterator + ) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + if const_expr(not is_ts): + smem_desc_a_lo = smem_desc_start_a_lo + ( + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + ) + smem_desc_b_lo = smem_desc_start_b_lo + ( + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + ) + # with cute.arch.elect_one(): + # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo) + # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct) + with cute.arch.elect_one(): + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + smem_desc_a_lo.ir_value(), + smem_desc_b_lo.ir_value(), + Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + ".reg .b32 idesc;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + tCrA[None, None, k].iterator.toint().ir_value(), + smem_desc_b_lo.ir_value(), + Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_b;\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_loop( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = sA.iterator.type.swizzle_type + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = sB.iterator.type.swizzle_type + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + if const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [ + offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [ + offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2])) + ] + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 smem_desc_a_lo, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 tmem_a, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: Int32, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[Int32] = None, + split_arrive: Optional[int] = None, + zero_init: bool | Boolean = False, + # sA_offset: Int32 = 0, + # acc_offset: Int32 = 0, + tA_addr: Optional[Int32] = None, + cta_group: int = 1, + mma_kind: str = "f16", +) -> None: + # acc_tmem_addr += acc_offset + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = sA.iterator.type.swizzle_type + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = sB.iterator.type.swizzle_type + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + tCrA_layout = ( + tCrA.layout + if const_expr(not is_ts) + else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + # ) + sA_offset + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + "mov.b32 smem_desc_a_lo_start, $0;\n\t" + "mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + # "r,r,r", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + # For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to + # explicitly pass in the tA_addr for correctness. + tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr + input_args = [ + # Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(), + Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ] + if const_expr(mbar_ptr is not None): + assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + assert split_arrive is not None, ( + "split_arrive must be provided when mbar_ptr is not None" + ) + split_arrive_idx = split_arrive // op.shape_mnk[2] + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" + llvm.inline_asm( + None, + # [ + # # acc.iterator.toint().ir_value(), + # Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + # Int32(smem_desc_start_b_lo).ir_value(), + # Int32(not zero_init).ir_value(), + # ], + input_args, + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range( + 1, + cute.size(tCrA.shape[2]) if const_expr(mbar_ptr is None) else split_arrive_idx, + ) + ) + + mbar_wait_str + + ( + "".join( + ( + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(split_arrive_idx, cute.size(tCrA.shape[2])) + ) + if const_expr(mbar_ptr is not None) + else "" + ) + + "}\n", + "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial1( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: cutlass.Constexpr[int], + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA_base_addr_for_desc: Int32, + sA_addr_offset_for_desc: cutlass.Constexpr[int], + sA_stage: Int32, + sB_base_addr_for_desc: Int32, + sB_addr_offset_for_desc: cutlass.Constexpr[int], + sB_stage: Int32, + sA_layout: Optional[cute.Layout], + sB_layout: Optional[cute.Layout], + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + mask = [Int32(0)] * 4 + + if const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 + for k in range(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in range(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4 + for k in range(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if const_expr(not is_ts): + # smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + smem_desc_start_a_lo = const_expr(smem_desc_base_a_lo) + else: + smem_desc_start_a_lo = None + # smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + smem_desc_start_b_lo = const_expr(smem_desc_base_b_lo) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(sA_base_addr_for_desc).ir_value(), + Int32(sA_stage).ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(sB_base_addr_for_desc).ir_value(), + Int32(sB_stage).ir_value(), + Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + # "mov.b32 smem_desc_a_lo, $0;\n\t" + # f"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\n\t" + f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t" + # "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $4, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_a, $1;\n\t" + f"mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_precomputed( + acc_tmem_addr: Int32, + smem_desc_start_a: Int32, # If TS, then this is the tmem start address for A + smem_desc_start_b: Int32, + idesc: int, + smem_desc_base_a: Optional[int], + smem_desc_base_b: int, + tCrA_layout: cute.Layout, + tCrB_layout: cute.Layout, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[Int32] = None, + zero_init: bool | Boolean = False, + cta_group: int = 1, +) -> None: + # acc_tmem_addr += acc_offset + is_ts = const_expr(smem_desc_base_a is None) + num_k_tile = cute.size(tCrA_layout.shape[2]) + if const_expr(not is_ts): + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + else: + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + + tCrA_layout = ( + tCrA_layout + if const_expr(not is_ts) + # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout) + # currently hard-coding the width to 16 + else cute.recast_layout(32, 16, tCrA_layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, num_k_tile)] + offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, num_k_tile)] + + smem_desc_start_a_lo = None + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a) + # smem_desc_start_a_lo = smem_desc_start_a + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + "mov.b32 smem_desc_a_lo_start, $0;\n\t" + "mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t" + f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, num_k_tile) + ) + + "}\n", + # "r,r,r", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + input_args = [ + Int32(cute.arch.make_warp_uniform(smem_desc_start_a)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ] + if const_expr(mbar_ptr is not None): + assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" + llvm.inline_asm( + None, + # [ + # # acc.iterator.toint().ir_value(), + # Int32(tCrA_layout[None, None, 0].iterator.toint()).ir_value(), + # Int32(smem_desc_start_b_lo).ir_value(), + # Int32(not zero_init).ir_value(), + # ], + input_args, + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range( + 1, + num_k_tile if const_expr(mbar_ptr is None) else num_k_tile // 4 * 3, + ) + ) + + mbar_wait_str + + ( + "".join( + ( + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(num_k_tile // 4 * 3, num_k_tile) + ) + if const_expr(mbar_ptr is not None) + else "" + ) + + "}\n", + "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def declare_ptx_smem_desc( + smem_desc_start_a: Int32, # If TS, then this is the tmem start address for A + smem_desc_base_a: Optional[int], + tCrA_layout: cute.Layout, + var_name_prefix: str = "smem_desc", +) -> None: + is_ts = const_expr(smem_desc_base_a is None) + num_k_tile = cute.size(tCrA_layout.shape[2]) + smem_desc_base_a_lo, smem_desc_a_hi = None, None + if const_expr(not is_ts): + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + tCrA_layout = ( + tCrA_layout + if const_expr(not is_ts) + # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout) + # currently hard-coding the width to 16 + else cute.recast_layout(32, 16, tCrA_layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)] + smem_desc_start_a_lo = None + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a) + if const_expr(not is_ts): + llvm.inline_asm( + None, + [Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value()], + f".reg .b32 {var_name_prefix}_lo;\n\t" + f".reg .b64 {var_name_prefix}_<{num_k_tile}>;\n\t" + f"mov.b64 {var_name_prefix}_0, {{$0, {hex(smem_desc_a_hi)}}};\n\t" + + "".join( + ( + f"add.s32 {var_name_prefix}_lo, $0, {hex(offset_a[k])};\n\t" + f"mov.b64 {var_name_prefix}_{k}, {{{var_name_prefix}_lo, {hex(smem_desc_a_hi)}}};\n\t" + ) + for k in range(1, num_k_tile) + ), + "r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def declare_ptx_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp, var_name: str = "idesc") -> None: + idesc = const_expr(sm100_desc.mma_op_to_idesc(op)) + llvm.inline_asm( + None, + [], + f".reg .b32 {var_name};\n\t" # noqa + f"mov.b32 {var_name}, {hex(idesc)};\n\t", + constraints="", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_precomputed_varname( + acc_tmem_addr: Int32, + smem_desc_start_b: Int32, + # idesc: int, + smem_desc_base_b: int, + tCrB_layout: cute.Layout, + smem_var_name_prefix: str, + idesc_var_name: str, + smem_offset: int, + zero_init: bool | Boolean = False, + cta_group: int = 1, + mma_kind: str = "f16", +) -> None: + is_ts = False + num_k_tile = cute.size(tCrB_layout.shape[2]) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)] + + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + # ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + # ".reg .b64 smem_desc_b;\n\t" + f".reg .b64 smem_desc_b_<{num_k_tile}>;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + # f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $2;\n\t" + "mov.b32 smem_desc_b_lo_start, $0;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_0;\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + f"mov.b64 {smem_var_name_prefix}_0, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b_0, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + + "".join( + ( + f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b_{k}, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + ) + for k in range(1, num_k_tile) + ) + + "setp.ne.b32 p, $1, 0;\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t" + + "".join( + ( + # f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t" + # f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + # f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + # f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + # f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t" + ) + for k in range(1, num_k_tile) + ) + + "}\n", + "r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/block_info.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/block_info.py new file mode 100644 index 0000000000000000000000000000000000000000..463290ab3b022a8883e7d40b84ff1ab31827e5dc --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/block_info.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Tuple +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + +from ...src.common.seqlen_info import SeqlenInfoQK + + +@dataclass(frozen=True) +class BlockInfo: + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] + is_causal: cutlass.Constexpr[bool] + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + + @cute.jit + def get_n_block_min_max( + self, + seqlen_info: SeqlenInfoQK, + m_block: Int32, + split_idx: Int32 = 0, + num_splits: Int32 = 1, + ) -> Tuple[Int32, Int32]: + n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) + if const_expr(self.is_causal): + m_idx_max = (m_block + 1) * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_block_max = min(n_block_max, cute.ceil_div(n_idx, self.tile_n)) + n_block_min = 0 + if num_splits > 1: + num_n_blocks_per_split = ( + Int32(0) + if n_block_max <= n_block_min + else (n_block_max - n_block_min + num_splits - 1) // num_splits + ) + n_block_min = n_block_min + split_idx * num_n_blocks_per_split + n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max) + return n_block_min, n_block_max + + @cute.jit + def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]: + m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m) + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_block_max = cute.ceil_div( + seqlen_info.seqlen_q * self.qhead_per_kvhead_packgqa, self.tile_m + ) + m_block_min = 0 + if const_expr(self.is_causal): + n_idx_min = n_block * self.tile_n + m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx *= self.qhead_per_kvhead_packgqa + m_block_min = cutlass.max(m_block_min, m_idx // self.tile_m) + return m_block_min, m_block_max diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/copy_utils.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/copy_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..98ba5f40b7b9543744e663a96bcdf637c7e2a146 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/copy_utils.py @@ -0,0 +1,1179 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Copy, store, and layout execution helpers. + +`copy_utils.py` is the canonical owner for generic copy primitives, async +bulk copy orchestration, TMA copy adapters, and non-TMA store/layout helpers. +""" + +import math +from typing import Optional, Type, Callable + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm +import cutlass.pipeline + + +# Generic Copy Primitives + +@dsl_user_op +def cvt_copy( + atom: cute.CopyAtom, + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + if const_expr(src.element_type != dst.element_type): + src_cvt = cute.make_rmem_tensor_like(src, dst.element_type, loc=loc, ip=ip) + src_cvt.store(src.load().to(dst.element_type)) + src = src_cvt + cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +@dsl_user_op +def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip) + cute.autovec_copy(src, dst, loc=loc, ip=ip) + return dst + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None +) -> cute.CopyAtom: + num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@dsl_user_op +def make_tmem_copy( + tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None +) -> cute.CopyAtom: + num_dp, num_bits, num_rep, _ = sm100_utils.get_tmem_copy_properties(tmem_copy_atom) + assert num_dp == 32 + assert num_bits == 32 + tiler_mn = (cute.make_layout((128 * num_rep * num_wg // 32, 32), stride=(32, 1)),) + layout_tv = cute.make_layout( + ((32, 4, num_wg), (num_rep, 32)), stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg)) + ) + return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + num_copy_elems: int = 1, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +def tiled_copy_1d( + dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + thr_layout = cute.make_layout(num_threads) + val_layout = cute.make_layout(num_copy_elems) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def tiled_copy_2d( + dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width + copy_elems = num_copy_bits // dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + gmem_threads_per_row = major_mode_size // copy_elems + assert num_threads % gmem_threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +@dsl_user_op +def atomic_add_fp32x4( + a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None +) -> None: + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x12F0000000000000) + llvm.inline_asm( + None, + [ + gmem_ptr_i64, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], + "{\n\t" + # ".reg .b128 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + ".reg .v4 .f32 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + "mov.f32 abcd.x, $1;\n\t" + "mov.f32 abcd.y, $2;\n\t" + "mov.f32 abcd.z, $3;\n\t" + "mov.f32 abcd.w, $4;\n\t" + "red.global.add.v4.f32 [$0], abcd;\n\t" + # "red.global.add.L2::cache_hint.v4.f32 [$0], abcd, 0x14F0000000000000;\n\t" + "}\n", + # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;", + # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;", + "l,f,f,f,f", + # "l,f,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +# Store/Layout Helpers + +@dsl_user_op +def atomic_add_i32(gmem_ptr, *, loc=None, ip=None): + """Simple atomicAdd. Intended for use under a single-thread guard.""" + result = llvm.inline_asm( + T.i32(), + [gmem_ptr.toint().ir_value(loc=loc, ip=ip)], + "atom.global.add.u32 $0, [$1], 1;\n", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + return Int32(result) + + +@dsl_user_op +def atomic_add_broadcast_i32(gmem_ptr, *, loc=None, ip=None): + """Lane-0 atomicAdd broadcast to the whole warp via shfl.""" + result = llvm.inline_asm( + T.i32(), + [gmem_ptr.toint().ir_value(loc=loc, ip=ip)], + "{\n" + ".reg .pred p;\n" + ".reg .u32 lane, r;\n" + "mov.u32 lane, %laneid;\n" + "mov.u32 r, 0;\n" + "setp.eq.u32 p, lane, 0;\n" + "@p atom.global.add.u32 r, [$1], 1;\n" + "shfl.sync.idx.b32 r, r, 0, 31, 0xffffffff;\n" + "mov.u32 $0, r;\n" + "}\n", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + return Int32(result) + + +@dsl_user_op +def stg_128( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "st.global.v4.f32 [$4], {$5, $6, $7, $8}; " + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "st.global.cs.v4.f32 [$4], {$5, $6, $7, $8}; " + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_64_bf16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3;\n" + ".reg .b32 p0, p1;\n" + "cvt.rn.bf16.f32 h0, $5;\n" + "cvt.rn.bf16.f32 h1, $6;\n" + "cvt.rn.bf16.f32 h2, $7;\n" + "cvt.rn.bf16.f32 h3, $8;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "st.global.v2.b32 [$4], {p0, p1};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_64_f16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .f16 h0, h1, h2, h3;\n" + ".reg .b32 p0, p1;\n" + "cvt.rn.f16.f32 h0, $5;\n" + "cvt.rn.f16.f32 h1, $6;\n" + "cvt.rn.f16.f32 h2, $7;\n" + "cvt.rn.f16.f32 h3, $8;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "st.global.v2.b32 [$4], {p0, p1};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_32_fp8_e4m3( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1;\n" + ".reg .b32 p0;\n" + "cvt.rn.satfinite.e4m3x2.f32 h0, $6, $5;\n" + "cvt.rn.satfinite.e4m3x2.f32 h1, $8, $7;\n" + "mov.b32 p0, {h0, h1};\n" + "st.global.b32 [$4], p0;\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def sts_32_bf16( + smem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + *, + loc=None, + ip=None, +): + """Store two bf16 values to shared memory as one 32-bit transaction.""" + llvm.inline_asm( + None, + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa;\n" + ".reg .b16 h0, h1;\n" + ".reg .b32 p0;\n" + "cvt.u32.u64 sa, $0;\n" + "cvt.rn.bf16.f32 h0, $1;\n" + "cvt.rn.bf16.f32 h1, $2;\n" + "mov.b32 p0, {h0, h1};\n" + "st.shared.b32 [sa], p0;\n" + "}\n", + "l,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def sts_32_f16( + smem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + *, + loc=None, + ip=None, +): + """Store two fp16 values to shared memory as one 32-bit transaction.""" + llvm.inline_asm( + None, + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa;\n" + ".reg .f16 h0, h1;\n" + ".reg .b32 p0;\n" + "cvt.u32.u64 sa, $0;\n" + "cvt.rn.f16.f32 h0, $1;\n" + "cvt.rn.f16.f32 h1, $2;\n" + "mov.b32 p0, {h0, h1};\n" + "st.shared.b32 [sa], p0;\n" + "}\n", + "l,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_bf16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.bf16.f32 h0, $9;\n" + "cvt.rn.bf16.f32 h1, $10;\n" + "cvt.rn.bf16.f32 h2, $11;\n" + "cvt.rn.bf16.f32 h3, $12;\n" + "cvt.rn.bf16.f32 h4, $13;\n" + "cvt.rn.bf16.f32 h5, $14;\n" + "cvt.rn.bf16.f32 h6, $15;\n" + "cvt.rn.bf16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_bf16_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.bf16.f32 h0, $9;\n" + "cvt.rn.bf16.f32 h1, $10;\n" + "cvt.rn.bf16.f32 h2, $11;\n" + "cvt.rn.bf16.f32 h3, $12;\n" + "cvt.rn.bf16.f32 h4, $13;\n" + "cvt.rn.bf16.f32 h5, $14;\n" + "cvt.rn.bf16.f32 h6, $15;\n" + "cvt.rn.bf16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.cs.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_f16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .f16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.f16.f32 h0, $9;\n" + "cvt.rn.f16.f32 h1, $10;\n" + "cvt.rn.f16.f32 h2, $11;\n" + "cvt.rn.f16.f32 h3, $12;\n" + "cvt.rn.f16.f32 h4, $13;\n" + "cvt.rn.f16.f32 h5, $14;\n" + "cvt.rn.f16.f32 h6, $15;\n" + "cvt.rn.f16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_f16_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .f16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.f16.f32 h0, $9;\n" + "cvt.rn.f16.f32 h1, $10;\n" + "cvt.rn.f16.f32 h2, $11;\n" + "cvt.rn.f16.f32 h3, $12;\n" + "cvt.rn.f16.f32 h4, $13;\n" + "cvt.rn.f16.f32 h5, $14;\n" + "cvt.rn.f16.f32 h6, $15;\n" + "cvt.rn.f16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.cs.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_fp8_e4m3_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + v8: Float32, + v9: Float32, + v10: Float32, + v11: Float32, + v12: Float32, + v13: Float32, + v14: Float32, + v15: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [ + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + ] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + Float32(v8).ir_value(loc=loc, ip=ip), + Float32(v9).ir_value(loc=loc, ip=ip), + Float32(v10).ir_value(loc=loc, ip=ip), + Float32(v11).ir_value(loc=loc, ip=ip), + Float32(v12).ir_value(loc=loc, ip=ip), + Float32(v13).ir_value(loc=loc, ip=ip), + Float32(v14).ir_value(loc=loc, ip=ip), + Float32(v15).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.satfinite.e4m3x2.f32 h0, $18, $17;\n" + "cvt.rn.satfinite.e4m3x2.f32 h1, $20, $19;\n" + "cvt.rn.satfinite.e4m3x2.f32 h2, $22, $21;\n" + "cvt.rn.satfinite.e4m3x2.f32 h3, $24, $23;\n" + "cvt.rn.satfinite.e4m3x2.f32 h4, $26, $25;\n" + "cvt.rn.satfinite.e4m3x2.f32 h5, $28, $27;\n" + "cvt.rn.satfinite.e4m3x2.f32 h6, $30, $29;\n" + "cvt.rn.satfinite.e4m3x2.f32 h7, $32, $31;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.cs.v4.b32 [$16], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000; " + "mov.f32 $8, 0f00000000; mov.f32 $9, 0f00000000; " + "mov.f32 $10, 0f00000000; mov.f32 $11, 0f00000000; " + "mov.f32 $12, 0f00000000; mov.f32 $13, 0f00000000; " + "mov.f32 $14, 0f00000000; mov.f32 $15, 0f00000000;", + ( + "=f,=f,=f,=f,=f,=f,=f,=f," + "=f,=f,=f,=f,=f,=f,=f,=f," + "l,f,f,f,f,f,f,f,f,f,f,f,f,f,f,f,f" + ), + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +def convert_layout_from_tmem16x256b_to_acc_sm90(acc_layout: cute.Layout) -> cute.Layout: + acc_layout_col_major = cute.make_layout(acc_layout.shape) + acc_layout_mn = cute.make_layout( + ( + acc_layout_col_major.shape[0][0], + acc_layout_col_major.shape[0][1], + acc_layout_col_major.shape[1], + *acc_layout_col_major.shape[2:], + ), + stride=( + acc_layout_col_major.stride[0][0], + acc_layout_col_major.stride[0][1], + acc_layout_col_major.stride[1], + *acc_layout_col_major.stride[2:], + ), + ) + return cute.composition(acc_layout, acc_layout_mn) + + +def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout: + acc_layout_col_major = cute.make_layout(acc_layout.shape) + acc_layout_mn = cute.make_layout( + ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), + ( + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), + *acc_layout_col_major.shape[3:], + ), + stride=( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), + *acc_layout_col_major.stride[3:], + ), + ) + return cute.composition(acc_layout, acc_layout_mn) + + +def make_16x256b_tensor_mn_view(tensor: cute.Tensor) -> cute.Tensor: + layout = convert_layout_acc_mn( + convert_layout_from_tmem16x256b_to_acc_sm90(tensor.layout) + ) + return cute.make_tensor(tensor.iterator, layout) + + +def real_col_to_stg128_fake_col(col: Int32) -> Int32: + nt = col // Int32(16) + col16 = col - nt * Int32(16) + pair = col16 // Int32(2) + rank = pair % Int32(4) + kv = (pair // Int32(4)) * Int32(2) + (col16 % Int32(2)) + return nt * Int32(16) + rank * Int32(4) + kv + + +def stg128_fake_col_to_real_col(fake_col: Int32) -> Int32: + nt = fake_col // Int32(16) + fake16 = fake_col - nt * Int32(16) + rank = fake16 // Int32(4) + kv = fake16 % Int32(4) + return nt * Int32(16) + rank * Int32(2) + (kv // Int32(2)) * Int32(8) + (kv % Int32(2)) + + +def real_col_to_stg128_half_fake_col(col: Int32) -> Int32: + nt = col // Int32(32) + col32 = col - nt * Int32(32) + lane = (col32 % Int32(8)) // Int32(2) + group = col32 // Int32(8) + elem = col32 % Int32(2) + return nt * Int32(32) + lane * Int32(8) + group * Int32(2) + elem + + +def stg128_half_fake_col_to_real_col(fake_col: Int32) -> Int32: + nt = fake_col // Int32(32) + fake32 = fake_col - nt * Int32(32) + lane = fake32 // Int32(8) + lane_slot = fake32 - lane * Int32(8) + group = lane_slot // Int32(2) + elem = lane_slot - group * Int32(2) + return nt * Int32(32) + group * Int32(8) + lane * Int32(2) + elem + + +def real_col_to_stg128_fp8_fake_col(col: Int32) -> Int32: + nt = col // Int32(64) + col64 = col - nt * Int32(64) + lane = (col64 % Int32(8)) // Int32(2) + group = col64 // Int32(8) + elem = col64 % Int32(2) + return nt * Int32(64) + lane * Int32(16) + group * Int32(2) + elem + + +def stg128_fp8_fake_col_to_real_col(fake_col: Int32) -> Int32: + nt = fake_col // Int32(64) + fake64 = fake_col - nt * Int32(64) + lane = fake64 // Int32(16) + lane_slot = fake64 - lane * Int32(16) + group = lane_slot // Int32(2) + elem = lane_slot - group * Int32(2) + return nt * Int32(64) + group * Int32(8) + lane * Int32(2) + elem + + +# Cluster & Bulk Async Ops + +@dsl_user_op +def set_block_rank( + smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None +) -> Int32: + """Map the given smem pointer to the address at another CTA rank in the cluster.""" + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + return Int32( + llvm.inline_asm( + T.i32(), + [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def store_shared_remote_fp32x4( + a: Float32, + b: Float32, + c: Float32, + d: Float32, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: Int32, + *, + loc=None, + ip=None, +) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + llvm.inline_asm( + None, + [ + remote_smem_ptr_i32, + remote_mbar_ptr_i32, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .v4 .f32 abcd;\n\t" + "mov.f32 abcd.x, $2;\n\t" + "mov.f32 abcd.y, $3;\n\t" + "mov.f32 abcd.z, $4;\n\t" + "mov.f32 abcd.w, $5;\n\t" + "st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, [$1];\n\t" + "}\n", + "r,r,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_bulk_s2cluster( + smem_src_ptr: cute.Pointer, + smem_dst_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + size: int | Int32, + peer_cta_rank_in_cluster: Int32, + *, + loc=None, + ip=None, +): + smem_src_ptr_i32 = smem_src_ptr.toint(loc=loc, ip=ip).ir_value() + smem_dst_ptr_i32 = set_block_rank( + smem_dst_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + mbar_ptr_i32 = set_block_rank(mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [ + smem_dst_ptr_i32, + smem_src_ptr_i32, + mbar_ptr_i32, + Int32(size).ir_value(loc=loc, ip=ip), + ], + "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [$0], [$1], $3, [$2];", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_bulk_g2s( + gmem_ptr: cute.Pointer, + smem_ptr: cute.Pointer, + tma_bar_ptr: cute.Pointer, + size: int | Int32, + *, + loc=None, + ip=None, +): + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + mbar_ptr_i32 = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [gmem_ptr_i64, smem_ptr_i32, mbar_ptr_i32, Int32(size).ir_value()], + "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];", + "l,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_reduce_bulk_add_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: int | Int32, + *, + loc=None, + ip=None, +): + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()], + # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;", + # "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +def cpasync_bulk_get_copy_fn( + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + single_stage: bool = False, + **kwargs, +) -> Callable: + # src_is_smem = const_expr( + # isinstance(src_tensor.iterator, cute.Pointer) + # and src_tensor.memspace == cute.AddressSpace.smem + # ) + group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0)) + group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + src = cute.group_modes(src_tensor, 0, group_rank_src) + dst = cute.group_modes(dst_tensor, 0, group_rank_dst) + + def copy_bulk(src_idx, dst_idx, **new_kwargs): + size = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8) + cpasync_bulk_g2s( + src[None, src_idx].iterator, + dst[None, dst_idx].iterator, + size=size, + **new_kwargs, + **kwargs, + ) + + def copy_bulk_single_stage(**new_kwargs): + size = const_expr(cute.size(src.shape) * src.element_type.width // 8) + cpasync_bulk_g2s(src.iterator, dst.iterator, size=size, **new_kwargs, **kwargs) + + return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage + + +# TMA Copy Adapters + +def tma_get_copy_fn( + atom: cute.CopyAtom, + cta_coord: cute.Coord, + cta_layout: cute.Layout, + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + filter_zeros: bool = False, + single_stage: bool = False, + **kwargs, +) -> Callable: + src_is_smem = const_expr( + isinstance(src_tensor.iterator, cute.Pointer) + and src_tensor.memspace == cute.AddressSpace.smem + ) + smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) + group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0)) + group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + s, g = cpasync.tma_partition( + atom, + cta_coord, + cta_layout, + cute.group_modes(smem_tensor, 0, group_rank_smem), + cute.group_modes(gmem_tensor, 0, group_rank_gmem), + ) + if const_expr(filter_zeros): + s = cute.filter_zeros(s) + g = cute.filter_zeros(g) + src, dst = (s, g) if src_is_smem else (g, s) + + def copy_tma(src_idx, dst_idx, **new_kwargs): + cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) + + def copy_tma_single_stage(**new_kwargs): + cute.copy(atom, src, dst, **new_kwargs, **kwargs) + + return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g + + +def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): + def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs): + copy( + src_idx=src_idx, + dst_idx=producer_state.index, + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), + **new_kwargs, + ) + + return copy_fn + + +__all__ = [ + "atomic_add_broadcast_i32", + "atomic_add_fp32x4", + "atomic_add_i32", + "convert_layout_acc_mn", + "convert_layout_from_tmem16x256b_to_acc_sm90", + "copy", + "cpasync_bulk_g2s", + "cpasync_bulk_get_copy_fn", + "cpasync_bulk_s2cluster", + "cpasync_reduce_bulk_add_f32", + "cvt_copy", + "get_copy_atom", + "load_s2r", + "make_16x256b_tensor_mn_view", + "make_tmem_copy", + "real_col_to_stg128_fake_col", + "real_col_to_stg128_fp8_fake_col", + "real_col_to_stg128_half_fake_col", + "set_block_rank", + "stg128_fake_col_to_real_col", + "stg128_fp8_fake_col_to_real_col", + "stg128_half_fake_col_to_real_col", + "stg_128", + "stg_128_cs", + "stg_128_bf16", + "stg_128_bf16_cs", + "stg_128_f16", + "stg_128_f16_cs", + "stg_128_fp8_e4m3_cs", + "stg_32_fp8_e4m3", + "stg_64_bf16", + "stg_64_f16", + "sts_32_bf16", + "sts_32_f16", + "store_shared_remote_fp32x4", + "tiled_copy_1d", + "tiled_copy_2d", + "tma_get_copy_fn", + "tma_producer_copy_fn", +] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/cute_dsl_utils.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/cute_dsl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3473fbbf77fa1261abfc8fd960102c70d3e64bd --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/cute_dsl_utils.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import logging +import os +import pathlib +import time +from typing import Tuple +from functools import partial, lru_cache +from dataclasses import dataclass, fields + +import torch + +logger = logging.getLogger("minimax") + +try: + from triton.tools.disasm import extract +except ImportError: + extract = None + +import cutlass +import cutlass.cute as cute +from cutlass.base_dsl.typing import JitArgument +from cutlass.cutlass_dsl import NumericMeta +from cutlass.cute.runtime import from_dlpack + +StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) + + +load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data +cute_compile_og = cute.compile + + +torch2cute_dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, +} + + +@lru_cache +def get_max_active_clusters(cluster_size): + return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) + + +@lru_cache +def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: + return torch.cuda.get_device_capability(device) + + +@dataclass +class ArgumentsBase(JitArgument): + def __c_pointers__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + c_ptrs = [] + for obj in non_constexpr_fields: + if hasattr(obj, "__c_pointers__"): + c_ptrs.extend(obj.__c_pointers__()) + return c_ptrs + + def __get_mlir_types__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + types, self._values_pos = [], [] + for obj in non_constexpr_fields: + if hasattr(obj, "__get_mlir_types__"): + obj_types = obj.__get_mlir_types__() + types.extend(obj_types) + self._values_pos.append(len(obj_types)) + else: + self._values_pos.append(0) + return types + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +def load_cubin_module_data_patched(cubin_data, filepath): + pathlib.Path(filepath).write_bytes(cubin_data) + return load_cubin_module_data_og(cubin_data) + + +def cute_compile_patched(*args, **kwargs): + """A patched version of cute.compile. + + Behaviour: + - Dumps SASS to a file if ``CUTE_CUBIN_PATH`` is set. + - Logs JIT compile wall time at DEBUG level via the ``minimax`` logger, + tagged with the kernel's class name when available. Enable with + ``logging.getLogger("minimax").setLevel(logging.DEBUG)`` or env + ``MINIMAX_LOG_COMPILE=1``; this is how we distinguish a slow JIT + (~2-10s) from a kernel hang (>30s = deadlock, see CLAUDE.md). + """ + cubin_path = os.getenv("CUTE_CUBIN_PATH", None) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( + load_cubin_module_data_patched, filepath=cubin_path + ) + kernel_obj = args[0] if args else kwargs.get("op") + kernel_name = type(kernel_obj).__name__ if kernel_obj is not None else "" + t0 = time.time() + output = cute_compile_og(*args, **kwargs) + dt = time.time() - t0 + logger.debug("[%s] compiled in %.1fs", kernel_name, dt) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og + if extract is not None: + sass = extract(cubin_path, None) + pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) + return output + + +if os.getenv("MINIMAX_LOG_COMPILE", "0") == "1": + if not logger.handlers: + _h = logging.StreamHandler() + _h.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")) + logger.addHandler(_h) + logger.setLevel(logging.DEBUG) + + +# Monkey-patch cute.compile so every JIT compile across the repo gets timed +# without touching individual call sites. Idempotent: only patches once. +if cute.compile is not cute_compile_patched: + cute.compile = cute_compile_patched + + +def assume_strides_aligned(t): + """Assume all strides except the last are divisible by 128 bits. + + Python int strides (e.g., stride=0 from GQA expand) are kept as-is + since they're static and don't need alignment assumptions. + """ + divby = 128 // t.element_type.width + strides = tuple(s if isinstance(s, int) else cute.assume(s, divby=divby) for s in t.stride[:-1]) + return (*strides, t.stride[-1]) + + +def assume_tensor_aligned(t): + """Rebuild a tensor with 128-bit aligned stride assumptions. Passes through None.""" + if t is None: + return None + return cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=assume_strides_aligned(t))) + + +def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True): + """Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1.""" + tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi) + if fully_dynamic: + return tensor.mark_layout_dynamic() + if leading_dim == -1: + leading_dim = t.ndim - 1 + return tensor.mark_layout_dynamic(leading_dim=leading_dim) + + +def to_cute_aux_tensor(t, enable_tvm_ffi=True): + """Convert torch tensor to cute tensor for TVM FFI, tailored to FlexAttention aux tensors. + This allows the user to specify alignment and leading dimension for aux tensors used in + custom score_mod callables. + """ + assumed_align: int = getattr(t, "__assumed_align__", None) + leading_dim: int = getattr(t, "__leading_dim__", None) + fully_dynamic: bool = leading_dim is None + + return to_cute_tensor( + t, + assumed_align=assumed_align, + leading_dim=leading_dim, + fully_dynamic=fully_dynamic, + enable_tvm_ffi=enable_tvm_ffi, + ) + + +def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]: + """Return tuple of bools indicating which dims have stride=0 (broadcast). + + This is useful for compile keys since CuTe's mark_layout_dynamic() keeps + stride=0 as static, meaning kernels compiled with different broadcast + patterns are not interchangeable. + """ + return tuple(s == 0 for s in tensor.stride()) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/fast_math.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/fast_math.py new file mode 100644 index 0000000000000000000000000000000000000000..63a8b4a501ac499e372056a07d499832c830b474 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/fast_math.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 + + +@cute.jit +def clz(x: Int32) -> Int32: + # for i in cutlass.range_constexpr(32): + # if (1 << (31 - i)) & x: + # return Int32(i) + # return Int32(32) + # Early exit is not supported yet + res = Int32(32) + done = False + for i in cutlass.range(32): + if ((1 << (31 - i)) & x) and not done: + res = Int32(i) + done = True + return res diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/mask.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..0da42c3be9bf1c3dcff81ccde579b54131bfa4c6 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/mask.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Callable, Optional, TypeAlias +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, Uint32, const_expr + +from ...src.common import utils as utils +from ...src.common.seqlen_info import SeqlenInfoQK + +MaskGenFn: TypeAlias = Callable[[int], Uint32] +MASK_R2P_CHUNK_SIZE: int = 32 + + +@cute.jit +def r2p_bitmask_below(limit: Int32, s: int) -> Uint32: + m = max((s + 1) * MASK_R2P_CHUNK_SIZE - limit, 0) + return utils.shr_u32(Uint32(0xFFFFFFFF), Uint32(m)) + + +@cute.jit +def r2p_bitmask_above(limit: Int32, s: int) -> Uint32: + n = max(limit - s * MASK_R2P_CHUNK_SIZE, 0) + return utils.shl_u32(Uint32(0xFFFFFFFF), Uint32(n)) + + +@cute.jit +def mask_r2p_lambda( + X: cute.Tensor, + mask_gen_fn: cutlass.Constexpr[MaskGenFn], + rank1: bool = False, +) -> None: + ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape)) + for s in cutlass.range_constexpr(cute.ceil_div(ncol, MASK_R2P_CHUNK_SIZE)): + mask = mask_gen_fn(s) + for i in cutlass.range_constexpr(min(MASK_R2P_CHUNK_SIZE, ncol - s * MASK_R2P_CHUNK_SIZE)): + in_bound = cutlass.Boolean(mask & (Uint32(1) << i)) + c = s * MASK_R2P_CHUNK_SIZE + i + if const_expr(rank1): + X[c] = X[c] if in_bound else -Float32.inf + else: + for r in cutlass.range_constexpr(cute.size(X.shape[0])): + X[r, c] = X[r, c] if in_bound else -Float32.inf + + +@cute.jit +def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32: + return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep) + + +@dataclass(frozen=True) +class AttentionMask: + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] + seqlen_info: SeqlenInfoQK + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + swap_AB: cutlass.Constexpr[bool] = False + + @property + def seqlen_q(self) -> Int32: + return self.seqlen_info.seqlen_q + + @property + def seqlen_k(self) -> Int32: + return self.seqlen_info.seqlen_k + + @cute.jit + def apply_mask_sm100( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + m_block: Int32, + n_block: Int32, + mask_seqlen: cutlass.Constexpr[bool], + mask_causal: cutlass.Constexpr[bool], + row_idx: Optional[Int32] = None, + kv_valid_cols: Optional[Int32] = None, + kv_block_col_start: Optional[Int32] = None, + ) -> None: + if const_expr(not mask_seqlen and not mask_causal): + return + + col_limit = Int32(self.tile_n) + if const_expr(mask_seqlen): + if const_expr(kv_valid_cols is not None): + col_limit = kv_valid_cols + else: + col_limit = self.seqlen_k - n_block * Int32(self.tile_n) + + if const_expr(mask_causal): + if const_expr(row_idx is None): + row_axis = 0 if const_expr(not self.swap_AB) else 1 + row_idx_cur = tScS_t2r[0][row_axis] + m_block * Int32(self.tile_m) + if const_expr(self.qhead_per_kvhead_packgqa > 1): + row_idx_cur = row_idx_cur // Int32(self.qhead_per_kvhead_packgqa) + else: + row_idx_cur = row_idx + if const_expr(kv_block_col_start is not None): + block_col_start = kv_block_col_start + else: + block_col_start = n_block * Int32(self.tile_n) + causal_col_limit = ( + row_idx_cur + self.seqlen_k - self.seqlen_q + - block_col_start + Int32(1) + ) + col_limit = ( + cutlass.min(col_limit, causal_col_limit) + if const_expr(mask_seqlen) + else causal_col_limit + ) + + if col_limit < Int32(self.tile_n): + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_below(col_limit, s), + rank1=True, + ) + + @cute.jit + def apply_mask_sm100_transposed( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + t0ScS_t2r: cute.Tensor, + m_block: cutlass.Int32, + n_block: cutlass.Int32, + mask_seqlen: cutlass.Constexpr, + mask_causal: cutlass.Constexpr, + is_full_block: bool = False, + check_m_boundary: bool = True, + valid_tok_count: Optional[Int32] = None, + q_idx_tile: Optional[cute.Tensor] = None, + masked_tok_count: Optional[Int32] = None, + ) -> None: + del is_full_block, check_m_boundary + del t0ScS_t2r + row_axis = 0 if const_expr(not self.swap_AB) else 1 + col_axis = 1 if const_expr(not self.swap_AB) else 0 + + if const_expr(valid_tok_count is not None): + kv_block_col_start = n_block * Int32(self.tile_n) + causal_q_offset = self.seqlen_k - self.seqlen_q + nfrag = const_expr(cute.size(acc_S.shape)) + for i in cutlass.range(nfrag, unroll_full=True): + row_idx = tScS_t2r[i][row_axis] + tok_idx = row_idx // Int32(self.qhead_per_kvhead_packgqa) + acc_S[i] = -Float32.inf if tok_idx >= valid_tok_count else acc_S[i] + if const_expr(mask_seqlen): + kv_idx = kv_block_col_start + tScS_t2r[i][col_axis] + acc_S[i] = -Float32.inf if kv_idx >= self.seqlen_k else acc_S[i] + if const_expr(mask_causal): + if const_expr(q_idx_tile is not None): + causal_tok_count = ( + masked_tok_count + if const_expr(masked_tok_count is not None) + else Int32(0) + ) + if tok_idx < causal_tok_count: + q_idx = q_idx_tile[tok_idx] + kv_idx = kv_block_col_start + tScS_t2r[i][col_axis] + acc_S[i] = ( + -Float32.inf if kv_idx > q_idx + causal_q_offset else acc_S[i] + ) + return + + thr_col_offset = tScS_t2r[0][col_axis] + seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset + + if const_expr(not mask_causal): + if const_expr(mask_seqlen) and seqlenk_col_limit <= 0: + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + acc_S[i] = -cutlass.Float32.inf + return + + thr_row_offset = tScS_t2r[0][row_axis] + seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset + row_limit_top = seqlenq_row_limit - seqlenk_col_limit + if const_expr(mask_seqlen) and seqlenk_col_limit <= 0: + row_limit_top = self.tile_m + num_rep = cute.size(tScS_t2r, mode=[0]) + row_limit = row_to_r2p_idx(row_limit_top, num_rep, 2) + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_above(row_limit, s), + rank1=True, + ) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/mma_sm100_desc.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/mma_sm100_desc.py new file mode 100644 index 0000000000000000000000000000000000000000..53c58d17f5085d207f2a1d7b6b45d627ff3322e3 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/mma_sm100_desc.py @@ -0,0 +1,304 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT +# +# The bit-field encodings, enum values, and descriptor layout below mirror the +# SM100 tcgen05 MMA instruction descriptor as documented and +# implemented in NVIDIA CUTLASS (BSD-3-Clause). The numeric values MUST stay +# identical to the hardware/ISA encodings; see the "Third-party licenses" +# section of README.md at the repo root for attribution. + +from enum import IntEnum + +import cutlass +import cutlass.cute as cute + +# --------------------------------------------------------------------------- +# Enumerations that match the HW encodings (values MUST stay identical) +# --------------------------------------------------------------------------- + + +class Major(IntEnum): # matrix "layout" in the ISA docs + K = 0 + MN = 1 + + +class ScaleIn(IntEnum): # negate flags + One = 0 + Neg = 1 + + +class Saturate(IntEnum): + False_ = 0 + True_ = 1 + + +class CFormat(IntEnum): # 2-bit field (bits 4-5) + F16 = 0 + F32 = 1 + S32 = 2 + + +class F16F32Format(IntEnum): # 3-bit field (A/B element type) + F16 = 0 + BF16 = 1 + TF32 = 2 + + +class S8Format(IntEnum): + UINT8 = 0 + INT8 = 1 + + +class MXF8F6F4Format(IntEnum): + E4M3 = 0 + E5M2 = 1 + E2M3 = 3 + E3M2 = 4 + E2M1 = 5 + + +class MaxShift(IntEnum): + NoShift = 0 + MaxShift8 = 1 + MaxShift16 = 2 + MaxShift32 = 3 + + +# --------------------------------------------------------------------------- +# CUTLASS-type -> encoding helpers +# --------------------------------------------------------------------------- + + +def to_UMMA_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B. + """ + if cutlass_type is cutlass.Int8: + return S8Format.INT8 + # Unsigned 8-bit (if available in your CUTLASS build) + if cutlass_type is cutlass.Uint8: + return S8Format.UINT8 + # FP-16 / BF-16 + if cutlass_type is cutlass.Float16: + return F16F32Format.F16 + if cutlass_type is cutlass.BFloat16: + return F16F32Format.BF16 + # TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits) + if cutlass_type is cutlass.TFloat32: + return F16F32Format.TF32 + # Float-8 / Float-6 / Float-4 + if cutlass_type is cutlass.Float8E4M3FN: + return MXF8F6F4Format.E4M3 + if cutlass_type is cutlass.Float8E5M2: + return MXF8F6F4Format.E5M2 + raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}") + + +def to_C_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 2-bit accumulator encoding. + """ + if cutlass_type is cutlass.Float16: + return CFormat.F16 + if cutlass_type is cutlass.Float32: + return CFormat.F32 + if cutlass_type is cutlass.Int32: + return CFormat.S32 + raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}") + + +# --------------------------------------------------------------------------- +# The constructor – accepts only CUTLASS scalar classes +# --------------------------------------------------------------------------- + + +def make_instr_desc( + a_type, # CUTLASS scalar class, e.g. cutlass.Int8 + b_type, + c_type, + M: int, # 64, 128 or 256 + N: int, # 8 … 256 (multiple of 8) + a_major: Major, + b_major: Major, + a_neg: ScaleIn = ScaleIn.One, + b_neg: ScaleIn = ScaleIn.One, + c_sat: Saturate = Saturate.False_, + is_sparse: bool = False, + max_shift: MaxShift = MaxShift.NoShift, +) -> int: + """ + Build the 32-bit instruction descriptor for SM100 MMA. + All matrix/accumulator **types must be CUTLASS scalar classes** – + passing integers is forbidden. + """ + # --- encode element formats ------------------------------------------------- + a_fmt = int(to_UMMA_format(a_type)) + b_fmt = int(to_UMMA_format(b_type)) + c_fmt = int(to_C_format(c_type)) + is_f8f6f4 = a_type in (cutlass.Float8E4M3FN, cutlass.Float8E5M2) + + # --- range checks on M/N ----------------------------------------------------- + if M not in (64, 128, 256): + raise ValueError("M must be 64, 128 or 256") + if N < 8 or N > 256 or (N & 7): + raise ValueError("N must be a multiple of 8 in the range 8…256") + + m_dim = M >> 4 # 5-bit field + n_dim = N >> 3 # 6-bit field + + # fmt: off + # --- pack the bit-fields ----------------------------------------------------- + desc = 0 + desc |= (0 & 0x3) << 0 # sparse_id2 (always 0 here) + desc |= (int(is_sparse) & 0x1) << 2 # sparse_flag + desc |= (int(c_sat) & 0x1) << 3 # saturate + desc |= (c_fmt & 0x3) << 4 # c_format + desc |= (a_fmt & 0x7) << 7 # a_format + desc |= (b_fmt & 0x7) << 10 # b_format + desc |= (int(a_neg) & 0x1) << 13 # a_negate + desc |= (int(b_neg) & 0x1) << 14 # b_negate + desc |= (int(a_major) & 0x1) << 15 # a_major + desc |= (int(b_major) & 0x1) << 16 # b_major + desc |= (n_dim & 0x3F) << 17 # n_dim (6 bits) + # CUTLASS' tcgen05 lowering sets bit 23 for dense f8f6f4 MMAs; keep this + # descriptor aligned with generated/reference SM100 FP8 kernels. + desc |= (int(is_f8f6f4) & 0x1) << 23 + desc |= (m_dim & 0x1F) << 24 # m_dim (5 bits) + desc |= (int(max_shift) & 0x3) << 30 # max_shift (2 bits) + # fmt: on + + return desc & 0xFFFF_FFFF # ensure 32-bit result + + +def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp): + return make_instr_desc( + op.a_dtype, + op.b_dtype, + op.acc_dtype, + op.shape_mnk[0], + op.shape_mnk[1], + Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + ) + + +class LayoutType(IntEnum): # occupies the top-3 bits [61:64) + SWIZZLE_NONE = 0 # (a.k.a. "INTERLEAVE" in older docs) + SWIZZLE_128B_BASE32B = 1 + SWIZZLE_128B = 2 + SWIZZLE_64B = 4 + SWIZZLE_32B = 6 + # values 3,5,7 are reserved / illegal for UMMA + + +# --------------------------------------------------------------------------- +# Helpers – figure out the SWIZZLE_* family from the tensor layout +# --------------------------------------------------------------------------- + + +def _layout_type(swizzle: cute.Swizzle) -> LayoutType: + B, M, S = swizzle.num_bits, swizzle.num_base, swizzle.num_shift + + if M == 4: # Swizzle<*,4,3> + if S != 3: + raise ValueError("Unexpected swizzle shift – want S==3 for M==4") + return { + 0: LayoutType.SWIZZLE_NONE, + 1: LayoutType.SWIZZLE_32B, + 2: LayoutType.SWIZZLE_64B, + 3: LayoutType.SWIZZLE_128B, + }[B] # KeyError ⇒ invalid B→ raise + if M == 5: # Swizzle<2,5,2> (the only legal triple for M==5) + if (B, S) != (2, 2): + raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B") + return LayoutType.SWIZZLE_128B_BASE32B + + # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout + raise ValueError("Unsupported swizzle triple for UMMA smem descriptor") + + +def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int: + """ + Convert a 2-D *shared-memory* Cute layout into the SM100 64-bit + smem-descriptor, without the smem start address. + layout must correspond to layout of an uint128 tensor. + """ + # ------------------------------------------------------------------ meta + layout_type = _layout_type(swizzle) # resolve SWIZZLE_* family + + VERSION = 1 # bits 46–47 + LBO_MODE = 0 # bit 52 + BASE_OFFSET = 0 # bits 49–51 (CUTLASS always 0) + + # ---------------------------------------------------------- strides (units: uint128_t = 16 B) + swizzle_atom_mn_size = { + LayoutType.SWIZZLE_NONE: 1, + LayoutType.SWIZZLE_32B: 2, + LayoutType.SWIZZLE_64B: 4, + LayoutType.SWIZZLE_128B: 8, + LayoutType.SWIZZLE_128B_BASE32B: 8, + }[layout_type] + + if major is Major.MN: + swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8 + canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if stride_10 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1] + if layout_type is LayoutType.SWIZZLE_NONE: + stride_byte_offset, leading_byte_offset = stride_01, stride_11 + else: + stride_byte_offset, leading_byte_offset = stride_11, stride_01 + else: + if layout_type == LayoutType.SWIZZLE_128B_BASE32B: + raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K") + if not cute.size(layout.shape[0]) % 8 == 0: + raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.") + canonical_layout = cute.logical_divide(layout, (8, 2)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if stride_00 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_01 = canonical_layout.stride[0][1] + stride_byte_offset, leading_byte_offset = stride_01, stride_10 + + # ------------------------------------------------------------------ pack + desc = 0 + # leading_byte_offset_ [16:30) + desc |= (leading_byte_offset & 0x3FFF) << 16 + # stride_byte_offset_ [32:46) + desc |= (stride_byte_offset & 0x3FFF) << 32 + # version_ [46:48) + desc |= (VERSION & 0x3) << 46 + # base_offset_ [49:52) + desc |= (BASE_OFFSET & 0x7) << 49 + # lbo_mode_ [52:53) + desc |= (LBO_MODE & 0x1) << 52 + # layout_type_ [61:64) + desc |= (int(layout_type) & 0x7) << 61 + + return desc & 0xFFFF_FFFF_FFFF_FFFF # force 64-bit width + + +def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32: + # 14 bits, remove 4 LSB (bits 0-13 in desc) + return (start_addr.toint() & 0x3FFFF) >> 4 + + +def smem_desc_base_from_tensor(sA: cute.Tensor, major: Major) -> int: + sA_swizzle = sA.iterator.type.swizzle_type + return make_smem_desc_base( + cute.recast_layout(128, sA.element_type.width, sA.layout[0]), + sA_swizzle, + major, + ) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/named_barrier.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/named_barrier.py new file mode 100644 index 0000000000000000000000000000000000000000..a7722a471ca011a94d5fd7774224906001979b78 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/named_barrier.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import enum + + +class NamedBarrierFwdSm100(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() + TmemPtr = enum.auto() + SoftmaxStatsW0 = enum.auto() + SoftmaxStatsW1 = enum.auto() + SoftmaxStatsW2 = enum.auto() + SoftmaxStatsW3 = enum.auto() + SoftmaxStatsW4 = enum.auto() + SoftmaxStatsW5 = enum.auto() + SoftmaxStatsW6 = enum.auto() + SoftmaxStatsW7 = enum.auto() + LoadWG = enum.auto() + StoreEpilogue = enum.auto() + KvLoad = enum.auto() + KvDequantK = enum.auto() + KvDequantV = enum.auto() diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/pack_gqa.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/pack_gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5dc25edd3f48fbe2c77ec94c8ab3f1ea417507 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/pack_gqa.py @@ -0,0 +1,320 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""PackGQA primitives for GQA (grouped-query attention) tile layouts. + +Contains: +- ``pack_gqa_layout`` / ``unpack_gqa_layout``: fold/unfold ``qhead_per_kvhead`` + into the seqlen dimension of a tensor layout (zero-copy view). +- ``PackGQA``: base class with ``compute_ptr`` / ``load_Q`` / ``store_LSE`` / + ``store_O`` helpers for kernels that treat ``(qhead_per_kvhead × seqlen_q)`` + as a single packed row dimension. +- ``PackGQAComb``: subclass used by the K2 combine kernel; adds ``load_LSE`` + for coalesced GMEM→SMEM async copies when LSE_partial is laid out with H_q + innermost (stride-1). +""" + +from dataclasses import dataclass +from typing import Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute import FastDivmodDivisor + +from ...quack import layout_utils + +from . import utils + + +def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): + """Reshape a tensor to fold qhead_per_kvhead into the seqlen dimension (mode 0). + + The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) + are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept + as-is (e.g. batch). + + For Q/O tensors (head_idx=2): + (seqlen_q, headdim, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) + For LSE tensors (head_idx=1): + (seqlen_q, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) + """ + head_stride = T.stride[head_idx] + shape_packed = ( + (qhead_per_kvhead, T.shape[0]), + *[T.shape[i] for i in range(1, head_idx)], + nheads_kv, + *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], + ) + stride_packed = ( + (head_stride, T.stride[0]), + *[T.stride[i] for i in range(1, head_idx)], + head_stride * qhead_per_kvhead, + *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], + ) + return cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed)) + + +def unpack_gqa_layout(T, qhead_per_kvhead, head_idx): + """Reverse of pack_gqa_layout: unfold qhead_per_kvhead from the seqlen dimension (mode 0). + + The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) + are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept + as-is (e.g. batch). + + For Q/O tensors (head_idx=2): + ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) -> (seqlen_q, headdim, nheads, batch, ...) + For LSE tensors (head_idx=1): + ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) -> (seqlen_q, nheads, batch, ...) + """ + seqlen_stride = T.stride[0][1] + head_stride = T.stride[0][0] + shape_unpacked = ( + T.shape[0][1], + *[T.shape[i] for i in range(1, head_idx)], + T.shape[head_idx] * qhead_per_kvhead, + *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], + ) + stride_unpacked = ( + seqlen_stride, + *[T.stride[i] for i in range(1, head_idx)], + head_stride, + *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], + ) + return cute.make_tensor(T.iterator, cute.make_layout(shape_unpacked, stride=stride_unpacked)) + + +@dataclass +class PackGQA: + m_block_size: cutlass.Constexpr[int] + head_dim_padded: cutlass.Constexpr[int] + check_hdim_oob: cutlass.Constexpr[bool] + qhead_per_kvhead: cutlass.Constexpr[bool] + + @cute.jit + def compute_ptr( + self, + tensor: cute.Tensor, + cRows: cute.Tensor, + tidx: cutlass.Int32, + block: cutlass.Int32, + threads_per_row: cutlass.Constexpr[int], + num_threads: cutlass.Constexpr[int], + ): + num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row) + tPrPtr = cute.make_rmem_tensor(num_ptr_per_thread, cutlass.Int64) + for i in cutlass.range_constexpr(num_ptr_per_thread): + row = i * num_threads + cRows[tidx % threads_per_row][0] + idx = block * self.m_block_size + row + m_idx = idx // self.qhead_per_kvhead + h_idx = idx - m_idx * self.qhead_per_kvhead + tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint() + return tPrPtr + + @cute.jit + def load_Q( + self, + mQ: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + sQ: cute.Tensor, # (m_block_size, head_dim_padded) + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tQsQ = gmem_thr_copy.partition_D(sQ) + tQcQ = gmem_thr_copy.partition_S(cQ) + t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) + tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1]) + tQcQ_row = tQcQ[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + q_ptr_i64 = utils.shuffle_sync( + tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row + ) + q_gmem_ptr = cute.make_ptr( + mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0QcQ[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0] + ): + mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tQsQ.shape[0][0]) + mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])): + ki = tQcQ[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + mQ_cur_copy[None, ki], + tQsQ[None, m, k], + pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + + @cute.jit + def store_LSE( + self, + mLSE: cute.Tensor, # (qhead_per_kvhead, seqlen_q) + tLSErLSE: cute.Tensor, # (m_block_size, head_dim_padded) + tiled_mma: cute.TiledMma, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + thr_mma = tiled_mma.get_slice(tidx) + caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + taccOcO = thr_mma.partition_C(caccO) + taccOcO_row = layout_utils.reshape_acc_to_mn(taccOcO)[None, 0] + assert cute.size(tLSErLSE) == cute.size(taccOcO_row) + threads_per_row = tiled_mma.tv_layout_C.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + assert cute.size(tLSErLSE) <= threads_per_row + num_threads = tiled_mma.size + tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tLSErLSE)): + lse_ptr_i64 = utils.shuffle_sync( + tPrLSEPtr[m // threads_per_row], + m % threads_per_row, + width=threads_per_row, + ) + lse_gmem_ptr = cute.make_ptr( + mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4 + ) + row = block * self.m_block_size + taccOcO_row[m][0] + # Only the thread corresponding to column 0 writes out the lse to gmem + if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead: + mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,)) + mLSE_copy[0] = tLSErLSE[m] + + @cute.jit + def store_O( + self, + mO: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + tOrO: cute.Tensor, # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tOcO = gmem_thr_copy.partition_S(cO) + t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + tOcO_row = tOcO[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + o_ptr_i64 = utils.shuffle_sync( + tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row + ) + o_gmem_ptr = cute.make_ptr( + mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0OcO[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0] + ): + mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tOrO.shape[0][0]) + mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])): + ki = tOcO[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + tOrO[None, m, k], + mO_cur_copy[None, ki], + pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + + +@dataclass +class PackGQAComb(PackGQA): + """PackGQA subclass for the K2 combine kernel. + + Inherits ``compute_ptr`` / ``load_Q`` / ``store_LSE`` / ``store_O`` from + ``PackGQA``. Adds ``load_LSE`` for coalesced GMEM→SMEM async copies when + LSE_partial is laid out with H_q innermost. + + K2 combine treats each query head independently (no GQA grouping in combine + itself), so ``qhead_per_kvhead`` is set to ``num_heads_q`` by the caller — + all heads are folded into one "group" per Sq position. + """ + + @cute.jit + def load_LSE( + self, + mLSE_partial: cute.Tensor, + # Packed layout after caller-side reshape: + # shape ((qhead_per_kvhead, seqlen_q), num_splits) + # stride ((1, qhead_per_kvhead), ...) + # — H_q is the innermost (stride-1) element of the packed first dim. + sLSE: cute.Tensor, + # SMEM destination: ``(topk, m_block_size)`` fp32. + topk: cutlass.Constexpr[int], + # Explicit topk so the identity tensor shape is a plain int, + # avoiding compound-shape traps from sLSE.shape[0] after tile_to_shape. + gmem_tiled_copy: cute.TiledCopy, + tidx: Int32, + block: Int32, + num_splits: Int32, + seqlen: Int32, + num_heads_divmod: FastDivmodDivisor, + mCounter: Optional[cute.Tensor] = None, + batch_idx: Optional[Int32] = None, + qhead_per_kvhead: Int32 = Int32(1), + # divmod for ``m_pos = idx // qhead_per_kvhead``; passed explicitly so + # caller controls whether the divisor is constexpr or a runtime value. + ): + """Coalesced GMEM→SMEM async load of LSE_partial for one tile. + + For each (split, row) slot this thread owns in the tile, compute the + GMEM coordinate ``(h_pos, m_pos)`` via PackGQA divmod and copy one fp32. + Out-of-bounds rows (``m_pos >= seqlen``) and splits (``si >= num_splits``) + are filled with ``-inf`` so they flow cleanly through downstream reductions. + + Coalescing: adjacent thread rows correspond to adjacent ``h_pos`` values + (head varies fast under ``divmod(idx, qhead_per_kvhead)``), which map to + adjacent GMEM addresses when H_q is stride-1 — one sector per warp. + """ + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cLSE = cute.make_identity_tensor((topk, self.m_block_size)) + tLSEcLSE = gmem_thr_copy.partition_S(cLSE) + tLSEsLSE = gmem_thr_copy.partition_D(sLSE) + + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] + idx = block * self.m_block_size + mi + m_pos, h_pos = divmod(idx, num_heads_divmod) + + if m_pos < seqlen: + row_count = ( + mCounter[batch_idx, m_pos, h_pos // qhead_per_kvhead] + if const_expr(mCounter is not None) + else num_splits + ) + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] + if si < num_splits and si < row_count: + # Build a 1-element GMEM tensor at ((h_pos, m_pos), si), + # matching PackGQA.store_LSE's ptr pattern so cute.copy + # receives a proper Tensor, not a scalar. + src_ptr_i64 = utils.elem_pointer( + mLSE_partial, ((h_pos, m_pos), si)).toint() + src_ptr = cute.make_ptr( + Float32, src_ptr_i64, + cute.AddressSpace.gmem, assumed_align=4, + ) + src_t = cute.make_tensor(src_ptr, (1,)) + cute.copy(gmem_thr_copy, src_t, tLSEsLSE[None, s, m]) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + else: + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + tLSEsLSE[None, s, m].fill(-Float32.inf) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/paged_kv.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/paged_kv.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5f6923c42a826d4f3dd1f192ce2fdb38eefbf5 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/paged_kv.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + + +@dataclass(frozen=True) +class PagedKVManager: + mPageTable: cute.Tensor + page_size: cutlass.Constexpr[int] + n_block_size: cutlass.Constexpr[int] + + @staticmethod + def create( + mPageTable: cute.Tensor, + *, + page_size: int, + n_block_size: int, + ): + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal blk_kv ({n_block_size})" + ) + return PagedKVManager( + mPageTable, + page_size=page_size, + n_block_size=n_block_size, + ) + + @cute.jit + def logical_length( + self, + batch_idx: Int32, + num_kv_blocks: Int32, + mSeqUsedK=None, + ) -> Int32: + if const_expr(mSeqUsedK is not None): + return mSeqUsedK[batch_idx] + return num_kv_blocks * Int32(self.n_block_size) + + @cute.jit + def valid_cols_in_block( + self, + batch_idx: Int32, + kv_block_idx: Int32, + num_kv_blocks: Int32, + mSeqUsedK=None, + ) -> Int32: + seqlen_k = self.logical_length(batch_idx, num_kv_blocks, mSeqUsedK) + block_start = kv_block_idx * Int32(self.n_block_size) + remaining = seqlen_k - block_start + remaining = cutlass.max(remaining, Int32(0)) + return cutlass.min(remaining, Int32(self.n_block_size)) + + @cute.jit + def physical_block_index( + self, + batch_idx: Int32, + kv_block_idx: Int32, + ) -> Int32: + return self.mPageTable[batch_idx, kv_block_idx] + +__all__ = ["PagedKVManager"] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/pipeline.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..27f711772f5c6fa16a86f4aa305f42a0ca9322eb --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/pipeline.py @@ -0,0 +1,372 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +# import math +from typing import Optional +from dataclasses import dataclass + +import cutlass.cute as cute +from cutlass import Boolean, Int32, const_expr +from cutlass.cutlass_dsl import if_generate, dsl_user_op +from cutlass.pipeline import PipelineState +from cutlass.pipeline import PipelineUserType +from cutlass.pipeline import NamedBarrier as NamedBarrierOg +from cutlass.pipeline import PipelineAsync as PipelineAsyncOg +from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg +from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg +from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg +from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg +import cutlass.pipeline as cutlass_pipeline + + +def make_pipeline_state(type: PipelineUserType, stages: int): + """Compatibility wrapper for FA-style helpers now vendored into src.common.""" + return cutlass_pipeline.make_pipeline_state(type, stages) + +@dataclass(frozen=True) +class NamedBarrier(NamedBarrierOg): + @staticmethod + def create(*args, **kwargs): + obj = NamedBarrierOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", NamedBarrier) + return obj + + @dsl_user_op + def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None: + """ + The aligned flavor of arrive is used when all threads in the CTA will execute the + same instruction. See PTX documentation. + """ + cute.arch.barrier_arrive( + barrier_id=self.barrier_id + index, + number_of_threads=self.num_threads, + loc=loc, + ip=ip, + ) + + @dsl_user_op + def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None: + cute.arch.barrier( + barrier_id=self.barrier_id + index, + number_of_threads=self.num_threads, + loc=loc, + ip=ip, + ) + + +@dataclass(frozen=True) +class PipelineAsync(PipelineAsyncOg): + @staticmethod + def create(*args, **kwargs): + obj = PipelineAsyncOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + # obj.__class__ = PipelineAsync + object.__setattr__(obj, "__class__", PipelineAsync) + return obj + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_try_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + *, + loc=None, + ip=None, + ): + return self.sync_object_empty.try_wait(index, phase, loc=loc, ip=ip) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineTmaAsync(PipelineTmaAsyncOg): + """ + Override producer_acquire to take in extra_tx_count parameter. + """ + + @staticmethod + def create(*args, **kwargs): + obj = PipelineTmaAsyncOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", PipelineTmaAsync) + return obj + + @dsl_user_op + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + if const_expr(extra_tx_count == 0): + self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineTmaUmma(PipelineTmaUmmaOg): + """ + Override producer_acquire to take in extra_tx_count parameter. + """ + + @staticmethod + def create(*args, **kwargs): + obj = PipelineTmaUmmaOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + # obj.__class__ = PipelineTmaUmma + object.__setattr__(obj, "__class__", PipelineTmaUmma) + return obj + + @dsl_user_op + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + if const_expr(extra_tx_count == 0): + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive( + state.index, self.producer_mask, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive_and_expect_tx( + state.index, tx_count, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineUmmaAsync(PipelineUmmaAsyncOg): + @staticmethod + def create(*args, **kwargs): + obj = PipelineUmmaAsyncOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", PipelineUmmaAsync) + return obj + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + """ + UMMA producer commit buffer full, cta_group needs to be provided. + """ + self.sync_object_full.arrive(index, self.producer_mask, self.cta_group, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineAsyncUmma(PipelineAsyncUmmaOg): + @staticmethod + def create(*args, **kwargs): + obj = PipelineAsyncUmmaOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", PipelineAsyncUmma) + return obj + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/seqlen_info.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/seqlen_info.py new file mode 100644 index 0000000000000000000000000000000000000000..873304f71c2cb47ffdd1453fe771c754783f51a4 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/seqlen_info.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Optional +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + +from ...quack import copy_utils + +""" +This consolidates all the info related to sequence length. This is so that we can do all +the gmem reads once at the beginning of each tile, rather than having to repeat these reads +to compute various things like n_block_min, n_block_max, etc. +""" + + +@dataclass(frozen=True) +class SeqlenInfo: + offset: Int32 + offset_padded: Int32 + seqlen: Int32 + has_cu_seqlens: cutlass.Constexpr[bool] = False + + @staticmethod + def create( + batch_idx: Int32, + seqlen_static: Int32, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + tile: cutlass.Constexpr[int] = 128, + ): + offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] + offset_padded = ( + 0 + if const_expr(cu_seqlens is None) + # Add divby so that the compiler knows the alignment when moving by offset_padded + else cute.assume((offset + batch_idx * tile) // tile * tile, divby=tile) + ) + if const_expr(seqused is not None): + seqlen = seqused[batch_idx] + elif const_expr(cu_seqlens is not None): + seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] + else: + seqlen = seqlen_static + return SeqlenInfo(offset, offset_padded, seqlen, has_cu_seqlens=cu_seqlens is not None) + + def offset_batch( + self, + mT: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + multiple: int = 1, + ) -> cute.Tensor: + """Offset a tensor by batch index. batch dim is at position `dim`, seqlen is at dim=0.""" + if const_expr(not self.has_cu_seqlens): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mT) - 1 - dim) + return mT[idx] + else: + off = multiple * (self.offset if const_expr(not padded) else self.offset_padded) + offset = off if const_expr(cute.rank(mT.shape[0]) == 1) else (0, off) + idx = (offset,) + (None,) * (cute.rank(mT) - 1) + return cute.domain_offset(idx, mT) + + +@dataclass(frozen=True) +class SeqlenInfoQK: + offset_q: Int32 + offset_k: Int32 + padded_offset_q: Int32 + padded_offset_k: Int32 + seqlen_q: Int32 + seqlen_k: Int32 + has_cu_seqlens_q: cutlass.Constexpr[bool] + has_cu_seqlens_k: cutlass.Constexpr[bool] + has_seqused_q: cutlass.Constexpr[bool] + has_seqused_k: cutlass.Constexpr[bool] + + @staticmethod + def create( + batch_idx: Int32, + seqlen_q_static: Int32, + seqlen_k_static: Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + tile_m: cutlass.Constexpr[Int32] = 128, + tile_n: cutlass.Constexpr[Int32] = 128, + ): + offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] + offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] + padded_offset_q = ( + 0 + if const_expr(mCuSeqlensQ is None) + else cute.assume((offset_q + batch_idx * tile_m) // tile_m * tile_m, divby=tile_m) + ) + padded_offset_k = ( + 0 + if const_expr(mCuSeqlensK is None) + else cute.assume((offset_k + batch_idx * tile_n) // tile_n * tile_n, divby=tile_n) + ) + if const_expr(mSeqUsedQ is not None): + seqlen_q = mSeqUsedQ[batch_idx] + else: + seqlen_q = ( + seqlen_q_static + if const_expr(mCuSeqlensQ is None) + else mCuSeqlensQ[batch_idx + 1] - offset_q + ) + if const_expr(mSeqUsedK is not None): + seqlen_k = mSeqUsedK[batch_idx] + else: + seqlen_k = ( + seqlen_k_static + if const_expr(mCuSeqlensK is None) + else mCuSeqlensK[batch_idx + 1] - offset_k + ) + return SeqlenInfoQK( + offset_q, + offset_k, + padded_offset_q, + padded_offset_k, + seqlen_q, + seqlen_k, + has_cu_seqlens_q=mCuSeqlensQ is not None, + has_cu_seqlens_k=mCuSeqlensK is not None, + has_seqused_q=mSeqUsedQ is not None, + has_seqused_k=mSeqUsedK is not None, + ) + + def offset_batch_Q( + self, + mQ: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + ragged: cutlass.Constexpr[bool] = False, + ) -> cute.Tensor: + """Seqlen must be the first dimension of mQ""" + if const_expr(not ragged): + if const_expr(not self.has_cu_seqlens_q): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + return mQ[idx] + else: + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q + offset_q = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (None, offset_q) + idx = (offset_q,) + (None,) * (cute.rank(mQ) - 1) + return cute.domain_offset(idx, mQ) + else: + if const_expr(not self.has_cu_seqlens_q): + offset_q = 0 + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + mQ = mQ[idx] + else: + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q + if const_expr(cute.rank(mQ.shape[0]) == 1): + return copy_utils.offset_ragged_tensor( + mQ, offset_q, self.seqlen_q, ragged_dim=0, ptr_shift=True + ) + else: # PackGQA + assert cute.rank(mQ.shape[0]) == 2 + # Unpack before calling offset_ragged_tensor, then pack + idx = ((None, None),) + (None,) * (cute.rank(mQ) - 1) + mQ = mQ[idx] + mQ = copy_utils.offset_ragged_tensor( + mQ, offset_q, self.seqlen_q, ragged_dim=1, ptr_shift=True + ) + return cute.group_modes(mQ, 0, 2) + + def offset_batch_K( + self, + mK: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + ragged: cutlass.Constexpr[bool] = False, + multiple: int = 1, + ) -> cute.Tensor: + """Seqlen must be the first dimension of mK""" + if const_expr(not ragged): + if const_expr(not self.has_cu_seqlens_k): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + return mK[idx] + else: + offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k + offset_k *= multiple + idx = (offset_k,) + (None,) * (cute.rank(mK) - 1) + return cute.domain_offset(idx, mK) + else: + if const_expr(not self.has_cu_seqlens_k): + offset_k = 0 + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + mK = mK[idx] + else: + offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k + offset_k *= multiple + return copy_utils.offset_ragged_tensor( + mK, offset_k, self.seqlen_k, ragged_dim=0, ptr_shift=True + ) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/softmax.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..8f94c1c9e40aeb44c0a128165d90a502feb04afd --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/softmax.py @@ -0,0 +1,498 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Online softmax primitives. + +Contains: +- ``Softmax``: SM80/90 base class with online softmax + finalize + rescale_O. + The ``rescale_O`` path branches on ``arch >= 100`` to emit SM100 packed + ``fmul.f32x2`` (2× CUDA-core throughput) when available. +- ``SoftmaxSm100``: SM100-specific subclass exposing fused ``update_row_max``, + ``scale_apply_exp2_convert`` etc. used by the UTCMMA warp-specialized kernel. +""" + +import math +import operator +from dataclasses import dataclass +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Float32 + +from ...quack import layout_utils +from ...quack.cute_dsl_utils import ParamsBase + +from . import utils + + +@dataclass +class Softmax(ParamsBase): + scale_log2: Float32 + num_rows: cutlass.Constexpr[int] + row_max: cute.Tensor + row_sum: cute.Tensor + arch: cutlass.Constexpr[int] = 80 + softmax_scale: Float32 | None = None + + @staticmethod + def create( + scale_log2: Float32, + num_rows: cutlass.Constexpr[int], + arch: cutlass.Constexpr[int] = 80, + softmax_scale: Float32 | None = None, + ): + row_max = cute.make_rmem_tensor(num_rows, Float32) + row_sum = cute.make_rmem_tensor(num_rows, Float32) + return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale) + + def reset(self) -> None: + self.row_max.fill(-Float32.inf) + self.row_sum.fill(0.0) + + def _compute_row_max( + self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch) + + def _compute_row_sum( + self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch) + + @cute.jit + def online_softmax( + self, + acc_S: cute.Tensor, + is_first: cutlass.Constexpr[bool] = False, + check_inf: cutlass.Constexpr[bool] = True, + ) -> cute.Tensor: + """Apply online softmax and return the row_scale to rescale O. + + On SM100+ the inner ``acc_S_row * scale_log2 - row_max_scaled`` is + rewritten as explicit ``fma_packed_f32x2`` intrinsics — the DSL + compiler does not fuse TensorSSA ``mul + sub`` into FFMA2 (NCU + confirms: FFMA2 count is 0 for the TensorSSA path). The packed + rewrite issues one FFMA.F32X2 per pair, halving the scalar FFMA + instruction count for the softmax scale/subtract stage. + """ + acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S) + row_scale = cute.make_rmem_tensor_like(self.row_max, Float32) + + row_max = self.row_max + row_sum = self.row_sum + scale_log2 = self.scale_log2 + arch = self.arch + + for r in cutlass.range(cute.size(row_max), unroll_full=True): + acc_S_row_slice = acc_S_mn[r, None] + acc_S_row = acc_S_row_slice.load() + + row_max_cur = utils.fmax_reduce( + acc_S_row, + init_val=row_max[r] if cutlass.const_expr(not is_first) else None, + arch=arch, + ) + + row_max_cur = cute.arch.warp_reduction_max(row_max_cur, threads_in_group=4) + row_max_prev = row_max[r] + row_max[r] = row_max_cur + + if cutlass.const_expr(check_inf): + row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur + + row_max_cur_scaled = row_max_cur * scale_log2 + minus_row_max_scaled = -row_max_cur_scaled + n = cute.size(acc_S_row_slice) + + if cutlass.const_expr(arch >= 100 and n % 2 == 0): + # SM100 packed f32x2 FMA path: scale + subtract in one pass. + for i in cutlass.range(0, n, 2, unroll_full=True): + acc_S_row_slice[i], acc_S_row_slice[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row_slice[i], acc_S_row_slice[i + 1]), + (scale_log2, scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + for i in cutlass.range(n, unroll_full=True): + acc_S_row_slice[i] = cute.math.exp2(acc_S_row_slice[i], fastmath=True) + acc_S_row_exp = acc_S_row_slice.load() + else: + acc_S_row_exp = cute.math.exp2( + acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True + ) + acc_S_row_slice.store(acc_S_row_exp) + + if cutlass.const_expr(is_first): + acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch) + row_scale[r] = 1.0 + else: + row_scale[r] = cute.math.exp2( + (row_max_prev - row_max_cur) * scale_log2, fastmath=True + ) + acc_S_row_sum = utils.fadd_reduce( + acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch + ) + + row_sum[r] = acc_S_row_sum + + return row_scale + + @cute.jit + def finalize( + self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None + ) -> cute.Tensor: + """Finalize the online softmax by computing the scale and logsumexp. + + On SM100+ with an even ``num_rows`` and no sink_val, the loop is + unrolled in pairs so the key per-row arithmetic ― rcp*final_scale, + max*scale_log2 + log2(sum), and the final *LN2 ― collapses into one + ``mul_packed_f32x2`` + one ``fma_packed_f32x2`` + one more + ``mul_packed_f32x2`` per row pair. Sink_val path stays scalar (rare). + """ + if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)): + assert cute.size(sink_val) == cute.size(self.row_sum) + row_sum = self.row_sum + row_max = self.row_max + scale_log2 = self.scale_log2 + + row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4)) + row_scale = cute.make_rmem_tensor_like(row_max, Float32) + + LN2 = math.log(2.0) + num_rows = cute.size(row_sum) + use_packed = cutlass.const_expr( + self.arch >= 100 and num_rows % 2 == 0 and sink_val is None + ) + + if use_packed: + for r in cutlass.range(0, num_rows, 2, unroll_full=True): + s0 = row_sum[r] + s1 = row_sum[r + 1] + m0 = row_max[r] + m1 = row_max[r + 1] + bad0 = s0 == 0.0 or s0 != s0 + bad1 = s1 == 0.0 or s1 != s1 + + # row_scale = rcp_approx(safe_sum) * final_scale — rcp is scalar + # (no packed rcp intrinsic); the trailing multiply packs. + rcp0 = cute.arch.rcp_approx(1.0 if bad0 else s0) + rcp1 = cute.arch.rcp_approx(1.0 if bad1 else s1) + row_scale[r], row_scale[r + 1] = cute.arch.mul_packed_f32x2( + (rcp0, rcp1), (final_scale, final_scale) + ) + + # LSE = (row_max * scale_log2 + log2(row_sum)) * LN2 + # packed FMA for (max*scale_log2 + log2_sum), packed MUL for *LN2. + log0 = cute.math.log2(s0, fastmath=True) + log1 = cute.math.log2(s1, fastmath=True) + lse_pre_0, lse_pre_1 = cute.arch.fma_packed_f32x2( + (m0, m1), (scale_log2, scale_log2), (log0, log1) + ) + lse_0, lse_1 = cute.arch.mul_packed_f32x2( + (lse_pre_0, lse_pre_1), (LN2, LN2) + ) + row_sum[r] = -Float32.inf if bad0 else lse_0 + row_sum[r + 1] = -Float32.inf if bad1 else lse_1 + else: + for r in cutlass.range(num_rows, unroll_full=True): + if cutlass.const_expr(sink_val is not None): + sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r] + LOG2_E = math.log2(math.e) + row_sum[r] += cute.math.exp2( + sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True + ) + + acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r] + row_scale[r] = ( + cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0) + ) * final_scale + row_sum_cur = row_sum[r] + row_sum[r] = ( + (row_max[r] * scale_log2 + cute.math.log2(row_sum_cur, fastmath=True)) * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + return row_scale + + @cute.jit + def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: + """Scale each row of acc_O by the given scale tensor.""" + acc_O_mn = layout_utils.reshape_acc_to_mn(acc_O) + assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0]) + n = cute.size(acc_O_mn, mode=[1]) + if cutlass.const_expr(self.arch >= 100 and n % 2 == 0): + # SM100: pack adjacent pairs into fmul.f32x2 (2× CUDA-core throughput). + for r in cutlass.range(cute.size(row_scale), unroll_full=True): + scale = row_scale[r] + for j in cutlass.range(0, n, 2, unroll_full=True): + acc_O_mn[r, j], acc_O_mn[r, j + 1] = cute.arch.mul_packed_f32x2( + (acc_O_mn[r, j], acc_O_mn[r, j + 1]), (scale, scale) + ) + else: + for r in cutlass.range(cute.size(row_scale), unroll_full=True): + acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r]) + + +@dataclass +class SoftmaxSm100(Softmax): + """SM100-specific softmax: single-row, explicit f32x2 pack for FMA/exp2 paths.""" + + rescale_threshold: cutlass.Constexpr[float] = 0.0 + + @staticmethod + def create( + scale_log2: Float32, + rescale_threshold: cutlass.Constexpr[float] = 0.0, + softmax_scale: Float32 | None = None, + ): + num_rows = 1 + arch = 100 + row_max = cute.make_rmem_tensor(num_rows, Float32) + row_sum = cute.make_rmem_tensor(num_rows, Float32) + return SoftmaxSm100( + scale_log2, + num_rows, + row_max, + row_sum, + arch, + softmax_scale, + rescale_threshold=rescale_threshold, + ) + + @cute.jit + def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + acc_scale = cute.math.exp2(acc_scale_, fastmath=True) + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale = 1.0 + self.row_max[0] = row_max_new + return row_max_safe, acc_scale + + @cute.jit + def update_row_max_deferred_exp2( + self, + acc_S_row: cute.TensorSSA, + is_first: int, + ) -> Tuple[Float32, Float32]: + """update_row_max variant that publishes the log2-delta (un-exp2'd) so + the consumer can do the exp2 only when an actual rescale fires. + + Returns ``(row_max_safe, acc_scale_log2_or_zero)`` where: + - ``row_max_safe`` is the same row-max as ``update_row_max`` (with + ``rescale_threshold`` rollback applied). + - ``acc_scale_log2_or_zero`` is ``0.0`` for the first iteration or when + the threshold rollback fired (consumer treats as no rescale), else + the raw log2-domain value ``(row_max_old - row_max_safe)*scale_log2`` + (consumer computes ``cute.math.exp2`` and rescales). + + This keeps MUFU.EX2 off the sm_stats publication critical path that + gates the correction WG's consumer wait. + """ + publish = Float32(0.0) + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + # publish stays 0.0 (signal: no rescale needed) + else: + publish = acc_scale_ + else: + publish = acc_scale_ + self.row_max[0] = row_max_new + return row_max_safe, publish + + @cute.jit + def update_row_max_only(self, acc_S_row: cute.TensorSSA, is_first: int) -> None: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + else: + row_max_new = self._compute_row_max(acc_S_row, init_val=self.row_max[0]) + self.row_max[0] = row_max_new + + def update_row_sum( + self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False + ) -> None: + init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None + self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val) + + @cute.jit + def compute_scaled_exp2_row_sum( + self, + acc_S_row: cute.Tensor, + scale: Float32, + ) -> Float32: + return utils.fadd_exp2_scaled_reduce(acc_S_row, scale, arch=self.arch) + + @cute.jit + def scale_subtract_rowmax( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + row_max_scaled = row_max * self.scale_log2 + for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (-row_max_scaled, -row_max_scaled), + ) + + @cute.jit + def apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + acc_S_row_converted: cute.Tensor, + ex2_emu_freq: cutlass.Constexpr[int] = 0, + ex2_emu_res: cutlass.Constexpr[int] = 4, + ex2_emu_start_frg: cutlass.Constexpr[int] = 0, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + if cutlass.const_expr(ex2_emu_freq == 0): + acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) + else: + if cutlass.const_expr( + k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res + or j >= frg_cnt - 1 + or j < ex2_emu_start_frg + ): + acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2( + acc_S_row_frg[k + 1, j], fastmath=True + ) + else: + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2( + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] + ) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + @cute.jit + def scale_apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + acc_S_row_converted: cute.Tensor, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + minus_row_max_scaled = -row_max * self.scale_log2 + for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + @cute.jit + def scale_apply_exp2_convert_sum( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + acc_S_row_converted: cute.Tensor, + init_sum: Float32, + ex2_emu_freq: cutlass.Constexpr[int] = 0, + ex2_emu_res: cutlass.Constexpr[int] = 4, + ex2_emu_start_frg: cutlass.Constexpr[int] = 0, + ) -> Float32: + # When ex2_emu_freq > 0, the (k % ex2_emu_freq) >= ex2_emu_freq - ex2_emu_res + # pairs in the inner loop use the FFMA2-based polynomial ex2 emulation + # (ex2_emulation_2) instead of MUFU exp2 — mirrors prefill's + # apply_exp2_convert. This removes the MUFU "wait" stall that dominates + # the second-largest stall bucket in decode (~22% of total). + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + minus_row_max_scaled = -row_max * self.scale_log2 + acc_sum = (init_sum, Float32(0.0)) + + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = cute.arch.fma_packed_f32x2( + (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + if cutlass.const_expr(ex2_emu_freq == 0): + acc_S_row_frg[k, j] = cute.math.exp2( + acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2( + acc_S_row_frg[k + 1, j], fastmath=True) + else: + use_real = cutlass.const_expr( + k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res + or j >= frg_cnt - 1 + or j < ex2_emu_start_frg + ) + if cutlass.const_expr(use_real): + acc_S_row_frg[k, j] = cute.math.exp2( + acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2( + acc_S_row_frg[k + 1, j], fastmath=True) + else: + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( + utils.ex2_emulation_2( + acc_S_row_frg[k, j], + acc_S_row_frg[k + 1, j], + ) + ) + acc_sum = cute.arch.add_packed_f32x2( + acc_sum, + (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + ) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + return acc_sum[0] + acc_sum[1] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/tile_scheduler.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/tile_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..985b4289e146288355dfecd7169383eb64df4f09 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/tile_scheduler.py @@ -0,0 +1,967 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from enum import IntEnum, auto +from typing import Optional, Tuple, Protocol, runtime_checkable +from dataclasses import dataclass + +try: + from typing import override +except ImportError: # Python < 3.12 + from typing_extensions import override + +import cutlass +from cutlass.pipeline import PipelineClcFetchAsync, PipelineState +from cutlass._mlir import ir +import cutlass.cute as cute +from cutlass import Int32, const_expr +from cutlass.cute import FastDivmodDivisor +from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams + +from ...quack.cute_dsl_utils import ParamsBase + +from ...src.common import utils as utils +from ...src.common.fast_math import clz + + +class SchedulingMode(IntEnum): + NONE = auto() + STATIC = auto() + DYNAMIC = auto() + CLC = auto() + + +@dataclass +class ClcState(ParamsBase): + """Owns the runtime state shared by CLC-capable tile schedulers. + + `SparseAttentionForwardSm100` constructs this state because it owns the CLC + response buffer, mbarrier storage, and launch geometry needed to initialize + the hardware scheduler and async pipeline. Individual tile schedulers then + consume this state and map the returned hardware work tiles into their own + logical `WorkTileInfo` coordinates. + + To add CLC support to a scheduler: + - implement `clc_problem_shape(params)` so the kernel can create the hardware scheduler + - accept `clc: ClcState | None` in `create(...)` / `__init__` + - map `clc.initial_work_tile_info()` and `clc.get_current_work()` into scheduler coordinates + """ + + _hw_scheduler: ClcDynamicPersistentTileScheduler + _pipeline: PipelineClcFetchAsync + _consumer_state: PipelineState + _producer_state: PipelineState + + @staticmethod + def create( + *, + hw_scheduler: ClcDynamicPersistentTileScheduler, + pipeline: PipelineClcFetchAsync, + consumer_state: PipelineState, + producer_state: PipelineState, + ) -> "ClcState": + return ClcState(hw_scheduler, pipeline, consumer_state, producer_state) + + def initial_work_tile_info(self): + return self._hw_scheduler.initial_work_tile_info() + + def get_current_work(self): + return self._hw_scheduler.get_current_work() + + def prefetch_next_work(self, *, loc=None, ip=None): + self._pipeline.producer_acquire(self._producer_state, loc=loc, ip=ip) + mbarrier_addr = self._pipeline.producer_get_barrier(self._producer_state, loc=loc, ip=ip) + self._hw_scheduler.advance_to_next_work(mbarrier_addr, loc=loc, ip=ip) + self._producer_state.advance(loc=loc, ip=ip) + + def consumer_wait(self, *, loc=None, ip=None): + self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip) + + def consumer_release(self, *, loc=None, ip=None): + self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip) + self._consumer_state.advance(loc=loc, ip=ip) + + def producer_tail(self, *, loc=None, ip=None): + self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip) + + +class WorkTileInfo(cutlass.utils.WorkTileInfo): + """Altered WorkTileInfo which includes four axes: (block, head, batch, split)""" + + @override + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": + assert len(values) == 5 + new_tile_idx = cutlass.new_from_mlir_values(self._tile_idx, values[:-1]) + new_is_valid_tile = cutlass.new_from_mlir_values(self._is_valid_tile, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + +@runtime_checkable +class TileSchedulerProtocol(Protocol): + """Protocol defining the interface all tile schedulers must implement. + + Schedulers are responsible for: + 1. Coordinate mapping: linear tile index -> (m_block, head, batch, split) + 2. Work distribution: how to get the next tile (static grid-stride vs CLC dynamic) + """ + + def get_current_work(self) -> WorkTileInfo: + """Get the current work tile coordinates.""" + ... + + def initial_work_tile_info(self) -> WorkTileInfo: + """Get the initial work tile for this CTA.""" + ... + + def advance_to_next_work(self, *, loc=None, ip=None): + """Consumer-side advance: move to next tile and return it. + + For static schedulers: grid-stride increment + get_current_work. + For CLC schedulers: consumer wait + get_current_work + consumer release + state advance. + """ + ... + + def prefetch_next_work(self, *, loc=None, ip=None) -> None: + """Producer-side prefetch of next work tile (no-op for static schedulers). + + For CLC schedulers: producer acquire + issue CLC query + producer state advance. + Only called by the scheduler warp. + """ + ... + + def producer_tail(self, *, loc=None, ip=None) -> None: + """Producer-side cleanup after the last tile. + + No-op for static schedulers. For CLC schedulers: pipeline producer_tail. + """ + ... + + +@dataclass +class TileSchedulerArguments(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + seqlen_k: Int32 + headdim: Int32 + headdim_v: Int32 + total_q: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + element_size: cutlass.Constexpr[int] = 2 + is_persistent: cutlass.Constexpr[bool] = False + lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False + use_cluster_idx: cutlass.Constexpr[bool] = False + + +class SingleTileScheduler: + @dataclass + class Params(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + num_splits_divmod: FastDivmodDivisor + is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + use_cluster_idx: cutlass.Constexpr[bool] = False + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileScheduler.Params": + return SingleTileScheduler.Params( + args.num_block, + args.num_head, + args.num_batch, + args.num_splits, + FastDivmodDivisor(args.num_splits), + args.is_split_kv, + args.cluster_shape_mn, + args.use_cluster_idx, + ) + + def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): + self.params = params + self._blk_coord = blk_coord + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"SingleTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) + return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileScheduler": + if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx): + blk_coord = cute.arch.block_idx() + else: + blk_coord = cute.arch.cluster_idx() + return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1) + assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + if const_expr(params.use_cluster_idx): + # Grid must have num_block * cluster_m physical blocks so that there are num_block clusters + grid_x = params.num_block * params.cluster_shape_mn[0] + else: + grid_x = cute.round_up(params.num_block, params.cluster_shape_mn[0]) + return ( + grid_x, + params.num_head * params.num_splits, + params.num_batch, + ) + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + block_idx, head_idx, batch_idx = self._blk_coord + if const_expr(self.params.is_split_kv): + head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod) + else: + split_idx = Int32(0) + return WorkTileInfo( + (block_idx, head_idx, batch_idx, split_idx), + self._is_first_block, + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._blk_coord]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._blk_coord], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class StaticPersistentTileScheduler: + @dataclass + class Params(ParamsBase): + num_block_cluster_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor + total_blocks_cluster: Int32 + cluster_shape_m: cutlass.Constexpr[int] = 1 + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler.Params": + num_block_cluster = cute.ceil_div(args.num_block, cute.size(args.cluster_shape_mn)) + total_blocks_cluster = num_block_cluster * args.num_head * args.num_batch + return StaticPersistentTileScheduler.Params( + FastDivmodDivisor(num_block_cluster), + FastDivmodDivisor(args.num_head), + total_blocks_cluster, + cluster_shape_m=args.cluster_shape_mn[0], + ) + + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"StaticPersistentTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) + return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler": + if const_expr(cute.size(params.cluster_shape_m) == 1): + tile_idx = cute.arch.block_idx()[0] + else: + tile_idx = cute.arch.cluster_idx()[0] + return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + usable_SM_count=0, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + hardware_info = cutlass.utils.HardwareInfo() + cluster_shape_m = int(params.cluster_shape_m) + if usable_SM_count > 0: + sm_count = usable_SM_count + else: + sm_count = hardware_info.get_device_multiprocessor_count() + max_ctas = (sm_count // cluster_shape_m) * cluster_shape_m + max_ctas = max(max_ctas, cluster_shape_m) + grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * cluster_shape_m) + return (grid_x, Int32(1), Int32(1)) + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod) + batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) + is_valid = self._tile_idx < self.params.total_blocks_cluster + return WorkTileInfo( + (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.cluster_shape_m == 1): + self._tile_idx += cute.arch.grid_dim()[0] + else: + self._tile_idx += cute.arch.cluster_dim()[0] + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.params, self._tile_idx], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileLPTScheduler: + @dataclass + class Params(ParamsBase): + total_blocks: Int32 + num_splits: Int32 + num_block: Int32 + num_head: Int32 + num_batch: Int32 + l2_minor: Int32 + num_head_divmod: FastDivmodDivisor + l2_minor_divmod: FastDivmodDivisor + l2_major_divmod: FastDivmodDivisor + l2_minor_residual_divmod: FastDivmodDivisor + num_hb_quotient: Int32 + num_splits_divmod: FastDivmodDivisor + is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + lpt: cutlass.Constexpr[bool] = True + use_cluster_idx: cutlass.Constexpr[bool] = True + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> "SingleTileLPTScheduler.Params": + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) + size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size + size_one_head = size_one_kv_head + size_l2 = 50 * 1024 * 1024 # 40 MB for K & V + # Swizzle is the size of each "section". Round swizzle to a power of 2 + # Need to be careful about the case where only one head will fit + # swizzle is how many heads can fit in L2 + # Seems faster if swizzle is a power of 2 + log2_floor = lambda n: 31 - clz(n) + swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + num_hb_quotient = (args.num_head * args.num_batch) // swizzle + num_hb_remainder = (args.num_head * args.num_batch) % swizzle + return SingleTileLPTScheduler.Params( + total_blocks=args.num_block * args.num_head * args.num_batch, + num_block=args.num_block, + num_head=args.num_head, + num_batch=args.num_batch, + l2_minor=Int32(swizzle), + num_head_divmod=FastDivmodDivisor(args.num_head), + l2_minor_divmod=FastDivmodDivisor(swizzle), + l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block), + l2_minor_residual_divmod=FastDivmodDivisor(max(num_hb_remainder, 1)), + num_hb_quotient=Int32(num_hb_quotient), + num_splits=args.num_splits, + num_splits_divmod=FastDivmodDivisor(args.num_splits), + is_split_kv=args.is_split_kv, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + lpt=args.lpt, + use_cluster_idx=args.use_cluster_idx, + ) + + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): + self.params = params + self._tile_idx = tile_idx + self._split_idx = split_idx + self.clc = clc + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileLPTScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) + + @staticmethod + def _clc_grid_shape(params: Params): + num_batch_splits = ( + params.num_batch * params.num_splits + if const_expr(params.is_split_kv) + else params.num_batch + ) + return ( + cute.round_up(params.num_block, params.cluster_shape_m), + params.num_head, + num_batch_splits, + ) + + @staticmethod + @cute.jit + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileLPTScheduler._clc_grid_shape(params), + cluster_shape_mnk=(params.cluster_shape_m, 1, 1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileLPTScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler( + params, cute.arch.block_idx()[0], Int32(0), clc, loc=loc, ip=ip + ) + tile_idx, split_idx, _ = cute.arch.block_idx() + return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler._clc_grid_shape(params) + return (params.total_blocks, params.num_splits, Int32(1)) + + @cute.jit + def clc_work_to_coords(self, work) -> WorkTileInfo: + """Convert CLC response (block, head, batch_split) to WorkTileInfo. + + CLC returns raw grid coordinates — no L2 swizzle (hardware decides order). + We only apply cluster division, optional LPT block reversal, and split_kv unpacking. + """ + block_idx = work.tile_idx[0] + if const_expr(self.params.cluster_shape_m > 1): + block_idx = block_idx // self.params.cluster_shape_m + if const_expr(self.params.lpt): + # Longest-processing-time-first: reverse block order + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + num_block = self.params.num_block // self.params.cluster_shape_m + else: + num_block = self.params.num_block + block_idx = num_block - 1 - block_idx + split_idx = Int32(0) + if const_expr(self.params.is_split_kv): + batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod) + else: + batch_idx = work.tile_idx[2] + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block_idx = block_idx * self.params.cluster_shape_m + bidx_in_cluster[0] + return WorkTileInfo( + (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)), + work.is_valid_tile, + ) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.get_current_work() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) + # Static path: L2-swizzled coordinate mapping + params = self.params + # Implement LPT scheduling coordinate calculation + bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + block, bidhb_residual = 0, 0 + if bidhb < params.num_hb_quotient: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod) + else: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod) + bidhb_actual = bidhb * params.l2_minor + bidhb_residual + batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) + # Longest-processing-time-first + if const_expr(params.lpt): + block = params.num_block - 1 - block + is_valid = self._tile_idx < params.total_blocks + return WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid + ) + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.initial_work_tile_info() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._tile_idx = self.params.total_blocks + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*obj_list, loc=self._loc) + + +class SingleTileVarlenScheduler: + @dataclass + class Params(ParamsBase): + num_head: Int32 + num_batch: Int32 + total_q: Int32 + num_splits: Int32 + max_kvblock_in_l2: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> "SingleTileVarlenScheduler.Params": + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + kv_block_size = ( + (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + ) + if args.head_swizzle: + kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] + max_kvblock_in_l2 = size_l2 // kv_block_size + assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( + "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" + ) + assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + # TODO: Support varlen CLC with cluster_shape_m > 1 by refactoring the + # flattened-tile decode so cluster unpacking semantics are explicit. + assert scheduling_mode != SchedulingMode.CLC or args.cluster_shape_mn[0] == 1, ( + "Varlen CLC currently requires cluster_shape_mn[0] == 1" + ) + return SingleTileVarlenScheduler.Params( + num_head=args.num_head, + num_batch=args.num_batch, + total_q=args.total_q, + num_splits=args.num_splits, + max_kvblock_in_l2=max_kvblock_in_l2, + tile_shape_mn=args.tile_shape_mn, + mCuSeqlensQ=args.mCuSeqlensQ, + mSeqUsedQ=args.mSeqUsedQ, + qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, + lpt=args.lpt, + is_split_kv=args.is_split_kv, + head_swizzle=args.head_swizzle, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + ) + + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): + self.params = params + self._tile_idx = tile_idx + self._split_idx = split_idx + self._is_first_block = True + self.clc = clc + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileVarlenScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) + + @staticmethod + @cute.jit + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileVarlenScheduler.get_grid_shape(params), + cluster_shape_mnk=(1, 1, 1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileVarlenScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + block_idx = cute.arch.block_idx() + split_idx = Int32(0) + if const_expr(params.is_split_kv): + split_idx = block_idx[1] + return SingleTileVarlenScheduler( + params, + block_idx[0], + split_idx, + clc, + loc=loc, + ip=ip, + ) + tile_idx, split_idx, _ = cute.arch.block_idx() + return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + total_blocks_max = ( + params.total_q + + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1) + ) // params.tile_shape_mn[0] + # Round down to nearest multiple of cluster since odd excess is always padding. + total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m + return (total_blocks_max * params.num_head, params.num_splits, Int32(1)) + + @cute.jit + def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + params = self.params + batch_idx = lane + bidb_start + if cutlass.const_expr(params.mSeqUsedQ is not None): + seqlen = Int32(0) + if batch_idx < params.num_batch: + seqlen = params.mSeqUsedQ[batch_idx] + else: + assert params.mCuSeqlensQ is not None + cur_cu_seqlen = Int32(0) + if batch_idx <= params.num_batch: + cur_cu_seqlen = params.mCuSeqlensQ[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): + seqlen *= params.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), params.cluster_shape_m) + if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 + else Int32(0) + ) + + @cute.jit + def _varlen_coord_map(self) -> WorkTileInfo: + """Map self._tile_idx to (block, head, batch) via warp-level prefix sums.""" + params = self.params + lane_idx = cute.arch.lane_idx() + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + # Total number of blocks for the next 31 batches + m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1) + # Same for all lanes + group_end_tile = m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group) + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0) + next_tile_idx = self._tile_idx // params.cluster_shape_m + while group_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= params.num_batch: + batch_idx = Int32(params.num_batch) + group_end_tile = next_tile_idx + 1 + else: + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync( + num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 + ) + group_end_tile += m_blocks_in_group * params.num_head + is_valid = False + if batch_idx >= params.num_batch: + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch) + else: + group_start_tile = group_end_tile - m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx) + # The next problem to process is the first one that does not have ending tile position + # that is greater than or equal to tile index. + batch_idx_in_group = cute.arch.popc( + cute.arch.vote_ballot_sync( + group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx + ) + ) + batch_idx += batch_idx_in_group + num_m_blocks_prev_lane = ( + 0 + if batch_idx_in_group == 0 + else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1) + ) + num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) + mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head + if cutlass.const_expr(params.lpt or params.head_swizzle): + # This is a version of the SingleTileLPTScheduler, complicated by the fact that + # the seqlen can vary per batch. + # TODO: is there any case where num_m_blocks is 0? + # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here + num_n_blocks = ( + num_m_blocks + * params.tile_shape_mn[0] + * params.cluster_shape_m + // params.qhead_per_kvhead_packgqa + // params.tile_shape_mn[1] + ) + # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) + # Seems faster to have this be a power of 2 + nheads_in_l2 = ( + 16 + if num_n_blocks * 16 <= params.max_kvblock_in_l2 + else ( + 8 + if num_n_blocks * 8 <= params.max_kvblock_in_l2 + else ( + 4 + if num_n_blocks * 4 <= params.max_kvblock_in_l2 + else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1) + ) + ) + ) + nheads_in_l2 = min(nheads_in_l2, params.num_head) + mh_in_l2 = nheads_in_l2 * num_m_blocks + section_idx = mh_block // mh_in_l2 + l2_mod = mh_block - section_idx * mh_in_l2 + # Deal with tail section + nheads_in_this_section = ( + nheads_in_l2 + if nheads_in_l2 * (section_idx + 1) <= params.num_head + else params.num_head - section_idx * nheads_in_l2 + ) + block = l2_mod // nheads_in_this_section + head_idx_residual = l2_mod - block * nheads_in_this_section + head_idx = section_idx * nheads_in_l2 + head_idx_residual + if cutlass.const_expr(params.lpt): + block = num_m_blocks - 1 - block + else: + head_idx = mh_block // num_m_blocks + block = mh_block - head_idx * num_m_blocks + is_valid = self._is_first_block and batch_idx < params.num_batch + if cutlass.const_expr(params.cluster_shape_m > 1): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block = block * params.cluster_shape_m + bidx_in_cluster[0] + # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) + split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) + return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.get_current_work() + # Default to grid_dim (one past last valid flat index) so _varlen_coord_map + # returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when + # invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural + # mismatch on self inside the runtime if. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.initial_work_tile_info() + # See get_current_work for why grid_dim and local-then-assign. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() + + def prefetch_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work + self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*obj_list, loc=self._loc) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/tma_utils.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/tma_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5bdc19a08eacf9a060f2c0a7a4d50a4adb735094 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/tma_utils.py @@ -0,0 +1,515 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Raw TMA ops and descriptor builders. + +`tma_utils.py` is the canonical owner for raw TMA inline-asm helpers and TMA +descriptor construction. Non-TMA store/layout helpers are re-exported from +`copy_utils.py` for backward compatibility. +""" + +import ctypes + +from cutlass import Int32, Int64 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm +import cutlass._mlir.dialects.cute as cute_ir +import cutlass._mlir.dialects.cute_nvgpu as cute_nvgpu_ir +from cutlass._mlir.dialects import _cute_nvgpu_ops_gen as cute_nvgpu_gen + + +# Raw TMA Ops + +TMA_CACHE_EVICT_FIRST = 0x12F0000000000000 +TMA_CACHE_EVICT_LAST = 0x14F0000000000000 + + +@dsl_user_op +def tma_tile_load( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row_idx, + mbar_ptr, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile with mbar completion.""" + llvm.inline_asm( + T.i32(), + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $1;\n" + "add.u32 sa, sa, $2;\n" + "cvt.u32.u64 ma, $6;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile" + ".mbarrier::complete_tx::bytes [sa], [$3, {$4, $5}], [ma];\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_gather4( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row0, + row1, + row2, + row3, + mbar_ptr, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4 with mbar.""" + llvm.inline_asm( + T.i32(), + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row0).ir_value(loc=loc, ip=ip), + Int32(row1).ir_value(loc=loc, ip=ip), + Int32(row2).ir_value(loc=loc, ip=ip), + Int32(row3).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $1;\n" + "add.u32 sa, sa, $2;\n" + "cvt.u32.u64 ma, $9;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4" + ".mbarrier::complete_tx::bytes [sa], [$3, {$4, $5, $6, $7, $8}], [ma];\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,l,r,r,r,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def prefetch_tma_desc_raw(tma_desc_ptr, *, loc=None, ip=None): + """Prefetch a raw TMA descriptor pointer into the descriptor cache.""" + ptr_i64 = tma_desc_ptr.toint().ir_value(loc=loc, ip=ip) + ptr_i64_align_ty = cute_ir.ConstrainedIntType.get(128, ptr_i64.type.width) + ptr_i64_align = cute_ir.assume(ptr_i64_align_ty, ptr_i64, loc=loc, ip=ip) + ptr_ty = cute_ir.PtrType.get( + cute_nvgpu_ir.TmaDescriptorTiledType.get(), + cute_ir.AddressSpace.gmem, + 128, + ) + desc_ptr = cute_ir.inttoptr(ptr_ty, ptr_i64_align, loc=loc, ip=ip) + cute_nvgpu_gen.arch_prefetch_tma_desc(desc_ptr.value, loc=loc, ip=ip) + + +@dsl_user_op +def tma_tile_prefetch( + tma_desc_ptr, + col_idx, + row_idx, + cache_hint=TMA_CACHE_EVICT_FIRST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.prefetch.tensor.2d.L2.global.tile with cache hint.""" + llvm.inline_asm( + None, + [ + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "cp.async.bulk.prefetch.tensor.2d.L2.global.tile.L2::cache_hint " + "[$0, {$1, $2}], $3;\n", + "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_gather4_prefetch( + tma_desc_ptr, + col_idx, + row0, + row1, + row2, + row3, + cache_hint=TMA_CACHE_EVICT_LAST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4 with cache hint.""" + llvm.inline_asm( + None, + [ + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row0).ir_value(loc=loc, ip=ip), + Int32(row1).ir_value(loc=loc, ip=ip), + Int32(row2).ir_value(loc=loc, ip=ip), + Int32(row3).ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4.L2::cache_hint " + "[$0, {$1, $2, $3, $4, $5}], $6;\n", + "l,r,r,r,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_tile_load_cached( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row_idx, + mbar_ptr, + cache_hint=TMA_CACHE_EVICT_FIRST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile with cache hint and mbar.""" + llvm.inline_asm( + T.i32(), + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $1;\n" + "add.u32 sa, sa, $2;\n" + "cvt.u32.u64 ma, $6;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile" + ".mbarrier::complete_tx::bytes.L2::cache_hint " + "[sa], [$3, {$4, $5}], [ma], $7;\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,l,r,r,l,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_gather4_cached( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row0, + row1, + row2, + row3, + mbar_ptr, + cache_hint=TMA_CACHE_EVICT_LAST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4 with cache hint.""" + llvm.inline_asm( + None, + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row0).ir_value(loc=loc, ip=ip), + Int32(row1).ir_value(loc=loc, ip=ip), + Int32(row2).ir_value(loc=loc, ip=ip), + Int32(row3).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $0;\n" + "add.u32 sa, sa, $1;\n" + "cvt.u32.u64 ma, $8;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4" + ".mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint " + "[sa], [$2, {$3, $4, $5, $6, $7}], [ma], $9;\n" + "}\n", + "l,r,l,r,r,r,r,r,l,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_tile_store( + tma_desc_ptr, + col_idx, + row_idx, + smem_ptr, + smem_byte_offset, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.global.shared::cta.bulk_group store.""" + llvm.inline_asm( + T.i32(), + [ + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa;\n" + "cvt.u32.u64 sa, $4;\n" + "add.u32 sa, sa, $5;\n" + "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group" + " [$1, {$2, $3}], [sa];\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,r,l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +# Descriptor Builders + +_TMA_DESC_BYTES = 128 + + +def _encode_tma_desc_2d_bytes(tensor_2d, *, box_x, box_y, context: str) -> bytes: + import torch + import cuda.bindings.driver as cuda + + if tensor_2d.ndim != 2: + raise ValueError(f"{context} tensor must be rank-2, got {tuple(tensor_2d.shape)}") + rows, cols = tensor_2d.shape + if tensor_2d.stride(-1) != 1: + raise ValueError(f"{context} tensor must be contiguous in the last dimension") + dtype_map = { + torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, + torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + } + if tensor_2d.dtype not in dtype_map: + raise TypeError(f"Unsupported dtype for {context} TMA descriptor: {tensor_2d.dtype}") + + sizes = [cuda.cuuint64_t(cols), cuda.cuuint64_t(rows)] + strides = [cuda.cuuint64_t(tensor_2d.stride(0) * tensor_2d.element_size())] + box = [cuda.cuuint32_t(box_x), cuda.cuuint32_t(box_y)] + elem_stride = [cuda.cuuint32_t(1), cuda.cuuint32_t(1)] + err, tm = cuda.cuTensorMapEncodeTiled( + dtype_map[tensor_2d.dtype], + 2, + tensor_2d.data_ptr(), + sizes, + strides, + box, + elem_stride, + cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, + cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, + cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, + ) + assert err == cuda.CUresult.CUDA_SUCCESS, f"TMA encode failed: {err}" + buf = (ctypes.c_uint8 * _TMA_DESC_BYTES).from_address(tm.getPtr()) + return bytes(buf) + + +def _desc_bytes_to_device_tensor(desc_bytes: bytes | bytearray, *, device): + import torch + + desc_bytes = bytes(desc_bytes) + device = torch.device(device) + if device.type != "cuda": + raise ValueError(f"TMA descriptors require a CUDA device, got {device}") + + host_desc = torch.empty((len(desc_bytes),), dtype=torch.uint8, pin_memory=True) + host_desc.copy_(torch.frombuffer(bytearray(desc_bytes), dtype=torch.uint8)) + device_desc = torch.empty((len(desc_bytes),), dtype=torch.uint8, device=device) + stream = torch.cuda.current_stream(device) + with torch.cuda.stream(stream): + device_desc.copy_(host_desc, non_blocking=True) + device_desc.record_stream(stream) + # Keep the staging buffer alive for the async copy without caching descriptors. + device_desc._tma_host_desc = host_desc + return device_desc + + +def create_flat_gather4_tma_desc(tensor_2d, box_x=64): + """Create a gather4 CUtensorMap descriptor for a flat 2D row-major tensor.""" + if tensor_2d.ndim != 2: + raise ValueError( + f"tensor_2d must be rank-2 [rows, dim], got {tuple(tensor_2d.shape)}" + ) + desc = _encode_tma_desc_2d_bytes( + tensor_2d, + box_x=box_x, + box_y=1, + context="gather4", + ) + return _desc_bytes_to_device_tensor(desc, device=tensor_2d.device) + + +def create_q_gather4_tma_desc(q_flat, box_x=64): + return create_flat_gather4_tma_desc(q_flat, box_x=box_x) + + +def create_strided_2d_tma_desc(tensor_2d, *, box_x, box_y): + """Create a CUtensorMap descriptor for a rank-2 tensor with arbitrary row stride.""" + desc = _encode_tma_desc_2d_bytes( + tensor_2d, + box_x=box_x, + box_y=box_y, + context="strided 2D", + ) + return _desc_bytes_to_device_tensor(desc, device=tensor_2d.device) + + +def create_flat_kv_tma_descs(kv_flat, *, box_x=64, box_y=128): + """Create per-KV-head token-major TMA descriptors for flat [total_k, H, D] storage.""" + import torch + + if kv_flat.ndim != 3: + raise ValueError( + f"kv_flat must be rank-3 [total_k, H, D], got {tuple(kv_flat.shape)}" + ) + total_k, head_kv, dim = kv_flat.shape + row_stride = head_kv * dim + desc_table = bytearray() + for h in range(head_kv): + head_view = torch.as_strided( + kv_flat, + size=(total_k, dim), + stride=(row_stride, 1), + storage_offset=h * dim, + ) + desc_table.extend( + _encode_tma_desc_2d_bytes( + head_view, + box_x=box_x, + box_y=box_y, + context="flat KV", + ) + ) + return _desc_bytes_to_device_tensor(desc_table, device=kv_flat.device).reshape( + head_kv, _TMA_DESC_BYTES + ) + + +# Compatibility Re-exports + +from .copy_utils import ( + atomic_add_broadcast_i32, + atomic_add_i32, + convert_layout_acc_mn, + convert_layout_from_tmem16x256b_to_acc_sm90, + make_16x256b_tensor_mn_view, + real_col_to_stg128_fake_col, + real_col_to_stg128_fp8_fake_col, + real_col_to_stg128_half_fake_col, + stg128_fake_col_to_real_col, + stg128_fp8_fake_col_to_real_col, + stg128_half_fake_col_to_real_col, + stg_128, + stg_128_cs, + stg_128_bf16, + stg_128_bf16_cs, + stg_128_f16, + stg_128_f16_cs, + stg_128_fp8_e4m3_cs, + stg_32_fp8_e4m3, + stg_64_bf16, + stg_64_f16, +) + + +__all__ = [ + "TMA_CACHE_EVICT_FIRST", + "TMA_CACHE_EVICT_LAST", + "atomic_add_broadcast_i32", + "atomic_add_i32", + "convert_layout_acc_mn", + "convert_layout_from_tmem16x256b_to_acc_sm90", + "create_flat_gather4_tma_desc", + "create_flat_kv_tma_descs", + "create_q_gather4_tma_desc", + "create_strided_2d_tma_desc", + "make_16x256b_tensor_mn_view", + "prefetch_tma_desc_raw", + "real_col_to_stg128_fake_col", + "real_col_to_stg128_fp8_fake_col", + "real_col_to_stg128_half_fake_col", + "stg128_fake_col_to_real_col", + "stg128_fp8_fake_col_to_real_col", + "stg128_half_fake_col_to_real_col", + "stg_128", + "stg_128_cs", + "stg_128_bf16", + "stg_128_bf16_cs", + "stg_128_f16", + "stg_128_f16_cs", + "stg_128_fp8_e4m3_cs", + "stg_32_fp8_e4m3", + "stg_64_bf16", + "stg_64_f16", + "tma_gather4", + "tma_gather4_cached", + "tma_gather4_prefetch", + "tma_tile_load", + "tma_tile_load_cached", + "tma_tile_prefetch", + "tma_tile_store", +] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/common/utils.py b/build/torch211-cxx11-cu128-x86_64-linux/src/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e1bd0ba76b532cb54c159eba5e82320266c80c63 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/common/utils.py @@ -0,0 +1,1088 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import math +import hashlib +import inspect +from typing import Type, Callable, Optional, Tuple, overload + +import cutlass +import cutlass.cute as cute + +from cutlass import Float32, const_expr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import nvvm, llvm +from cutlass.cute.runtime import from_dlpack + + +from ...quack import activation +_MIXER_ATTRS = ("__vec_size__",) + +# Obtained from sollya: +# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative); +POLY_EX2 = { + 0: (1.0), + 1: ( + 1.0, + 0.922497093677520751953125, + ), + 2: ( + 1.0, + 0.6657850742340087890625, + 0.330107033252716064453125, + ), + 3: ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ), + 4: ( + 1.0, + 0.693042695522308349609375, + 0.2412912547588348388671875, + 5.2225358784198760986328125e-2, + 1.3434938155114650726318359375e-2, + ), + 5: ( + 1.0, + 0.693151414394378662109375, + 0.24016360938549041748046875, + 5.5802188813686370849609375e-2, + 9.01452265679836273193359375e-3, + 1.86810153536498546600341796875e-3, + ), +} + + +def _compute_base_hash(func: Callable) -> str: + """Compute hash from source code or bytecode and closure values.""" + try: + data = inspect.getsource(func).encode() + except (OSError, TypeError): + if hasattr(func, "__code__") and func.__code__ is not None: + data = func.__code__.co_code + else: + data = repr(func).encode() + + hasher = hashlib.sha256(data) + + if hasattr(func, "__closure__") and func.__closure__ is not None: + for cell in func.__closure__: + hasher.update(repr(cell.cell_contents).encode()) + + return hasher.hexdigest() + + +def hash_callable( + func: Callable, mixer_attrs: Tuple[str] = _MIXER_ATTRS, set_cute_hash: bool = True +) -> str: + """Hash a callable based on the source code or bytecode and closure values. + Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__`` + attribute, that value is returned immediately as the base hash, then + metadata dunders are mixed in to produce the final dict-key hash. + set_cute_hash: whether or not to set func.__cute_hash__ + """ + # Resolve base hash + if hasattr(func, "__cute_hash__"): + base_hash = func.__cute_hash__ + else: + # Unwrap decorated functions (e.g., cute.jit wrappers). + base_func = getattr(func, "__wrapped__", func) + + if hasattr(base_func, "__cute_hash__"): + base_hash = base_func.__cute_hash__ + else: + base_hash = _compute_base_hash(base_func) + + if set_cute_hash: + base_func.__cute_hash__ = base_hash + + # Mix in mutable metadata dunders + mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs) + + if all(v is None for v in mixer_values): + return base_hash + + hasher = hashlib.sha256(base_hash.encode()) + + for attr, val in zip(_MIXER_ATTRS, mixer_values): + hasher.update(f"{attr}={val!r}".encode()) + + return hasher.hexdigest() + + +LOG2_E = math.log2(math.e) + + +def compute_softmax_scale_log2(softmax_scale): + """Compute softmax_scale_log2 from softmax_scale. + + Returns (softmax_scale_log2, None). + """ + return softmax_scale * LOG2_E, None + + +def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: + return ( + from_dlpack(x, assumed_align=alignment) + .mark_layout_dynamic(leading_dim=leading_dim) + .mark_compact_shape_dynamic( + mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility + ) + ) + + +def make_tiled_copy_A( + copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.TiledCopy: + if const_expr(swapAB): + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + + +def make_tiled_copy_B( + copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.TiledCopy: + if const_expr(swapAB): + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + + +def mma_make_fragment_A( + smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.Tensor: + if const_expr(swapAB): + return mma_make_fragment_B(smem, thr_mma) + else: + return thr_mma.make_fragment_A(thr_mma.partition_A(smem)) + + +def mma_make_fragment_B( + smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.Tensor: + if const_expr(swapAB): + return mma_make_fragment_A(smem, thr_mma) + else: + return thr_mma.make_fragment_B(thr_mma.partition_B(smem)) + + +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=2 * element_type.width, + ) + else: + return cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +@cute.jit +def warp_reduce( + val: cute.TensorSSA | cute.Numeric, + op: Callable, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.TensorSSA | cute.Numeric: + if const_expr(isinstance(val, cute.TensorSSA)): + res = cute.make_rmem_tensor(val.shape, val.dtype) + res.store(val) + for i in cutlass.range_constexpr(cute.size(val.shape)): + res[i] = warp_reduce(res[i], op, width) + return res.load() + else: + for i in cutlass.range_constexpr(int(math.log2(width))): + val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i)) + return val + + +@dsl_user_op +def fmax( + a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None +) -> Float32: + from cutlass import CUDA_VERSION + + # * NVVM call based on nvvm version + if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9: + # Old API: requires explicit result type as first positional argument + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + else: + # New API: infers result type automatically + return Float32( + nvvm.fmax( + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + + +@cute.jit +def fmax_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + res = cute.make_rmem_tensor(x.shape, Float32) + res.store(x) + local_max = [res[0], res[1], res[2], res[3]] + for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + local_max[0] = fmax(local_max[0], res[i + 0]) + local_max[1] = fmax(local_max[1], res[i + 1]) + local_max[2] = fmax(local_max[2], res[i + 2]) + local_max[3] = fmax(local_max[3], res[i + 3]) + local_max[0] = fmax(local_max[0], local_max[1]) + local_max[2] = fmax(local_max[2], local_max[3]) + local_max[0] = fmax(local_max[0], local_max[2]) + return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val) + else: + res = cute.make_rmem_tensor(x.shape, Float32) + res.store(x) + local_max_0 = ( + fmax(init_val, res[0], res[1]) + if const_expr(init_val is not None) + else fmax(res[0], res[1]) + ) + local_max = [ + local_max_0, + fmax(res[2], res[3]), + fmax(res[4], res[5]), + fmax(res[6], res[7]), + ] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_max[0] = fmax(local_max[0], res[i], res[i + 1]) + local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3]) + local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5]) + local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7]) + local_max[0] = fmax(local_max[0], local_max[1]) + return fmax(local_max[0], local_max[2], local_max[3]) + + +@cute.jit +def fadd_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + if const_expr(init_val is None): + init_val = Float32.zero + return x.reduce(cute.ReductionOp.ADD, init_val, 0) + else: + res = cute.make_rmem_tensor(x.shape, Float32) + res.store(x) + local_sum_0 = ( + cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1])) + if const_expr(init_val is not None) + else (res[0], res[1]) + ) + local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1])) + local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3])) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5])) + local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7])) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + + +@cute.jit +def fadd_exp2_scaled_reduce( + x: cute.Tensor, scale: Float32, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + assert cute.size(x.shape) % 2 == 0, "x must have an even number of elements" + if const_expr(arch < 100): + return fadd_reduce(cute.math.exp2(x.load() * scale, fastmath=True), arch=arch) + elif const_expr(cute.size(x.shape) % 8 == 0): + local_sum = [ + (Float32(0.0), Float32(0.0)), + (Float32(0.0), Float32(0.0)), + (Float32(0.0), Float32(0.0)), + (Float32(0.0), Float32(0.0)), + ] + for i in cutlass.range_constexpr(0, cute.size(x.shape), 8): + acc0, acc1 = cute.arch.mul_packed_f32x2( + (x[i + 0], x[i + 1]), (scale, scale) + ) + acc2, acc3 = cute.arch.mul_packed_f32x2( + (x[i + 2], x[i + 3]), (scale, scale) + ) + acc4, acc5 = cute.arch.mul_packed_f32x2( + (x[i + 4], x[i + 5]), (scale, scale) + ) + acc6, acc7 = cute.arch.mul_packed_f32x2( + (x[i + 6], x[i + 7]), (scale, scale) + ) + acc0 = cute.math.exp2(acc0, fastmath=True) + acc1 = cute.math.exp2(acc1, fastmath=True) + acc2 = cute.math.exp2(acc2, fastmath=True) + acc3 = cute.math.exp2(acc3, fastmath=True) + acc4 = cute.math.exp2(acc4, fastmath=True) + acc5 = cute.math.exp2(acc5, fastmath=True) + acc6 = cute.math.exp2(acc6, fastmath=True) + acc7 = cute.math.exp2(acc7, fastmath=True) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (acc0, acc1)) + local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (acc2, acc3)) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (acc4, acc5)) + local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (acc6, acc7)) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + else: + row_sum = Float32(0.0) + for i in cutlass.range_constexpr(0, cute.size(x.shape), 2): + acc0, acc1 = cute.arch.mul_packed_f32x2( + (x[i], x[i + 1]), (scale, scale) + ) + acc0 = cute.math.exp2(acc0, fastmath=True) + acc1 = cute.math.exp2(acc1, fastmath=True) + row_sum += acc0 + acc1 + return row_sum + + +@dsl_user_op +def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None: + nvvm.atomicrmw( + res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value() + ) + + +@dsl_user_op +def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + +@cute.jit +def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" + tApA = cute.make_rmem_tensor( + cute.make_layout( + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + return tApA + + +def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: + warp_group_idx = cute.arch.thread_idx()[0] // 128 + if const_expr(sync): + warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx) + return warp_group_idx + + +@cute.jit +def shuffle_sync( + value: cute.Numeric, + offset: cute.typing.Int, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.Numeric: + assert value.width % 32 == 0, "value type must be a multiple of 32 bits" + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + # important: need stride 1 and not 0 for recast_tensor to work + val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value)) + val[0] = value + val_i32 = cute.recast_tensor(val, cutlass.Int32) + for i in cutlass.range_constexpr(cute.size(val_i32)): + val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp) + return val[0] + + +@dsl_user_op +def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + """ + Left-shift val by shift bits using PTX shl.b32 (sign-agnostic). + + Named ``shl_u32`` (not ``shl_b32``) because python type annotations + distinguish signed/unsigned. + + PTX semantics (9.7.8.8): "Shift amounts greater than the register width N + are clamped to N." So ``shl.b32 d, a, 32`` is well-defined and yields 0. + + This differs from C/C++ and LLVM IR, where shifting by >= the type width is + undefined behavior. CuTeDSL compiles through MLIR -> LLVM IR, so a plain + Python-level ``Uint32(x) << Uint32(n)`` inherits LLVM's UB: the optimizer + may treat the result as poison and eliminate dependent code. Inline PTX + bypasses the LLVM IR shift entirely -- the instruction is emitted verbatim + into PTX where clamping makes it safe for all shift amounts. + """ + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], + "shl.b32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + """ + Unsigned right-shift val by shift bits using PTX shr.u32 (zero-fills). + + See ``shl_u32`` docstring for why inline PTX is used instead of plain + CuTeDSL shift operators (LLVM shift-by-type-width UB). + """ + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], + "shr.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32: + if const_expr(lane is None): + lane = cute.arch.lane_idx() + for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))): + offset = 1 << i + # Very important that we set mask_and_clamp to 0 + partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0) + if lane >= offset: + val += partial_sum + return val + + +@dsl_user_op +def cvt_f16x2_f32( + a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None +) -> cutlass.Int32: + assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16" + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)], + f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;", + "=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def cvt_fp8x4_e4m3_f32( + a: float | Float32, + b: float | Float32, + c: float | Float32, + d: float | Float32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [ + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1;\n" + "cvt.rn.satfinite.e4m3x2.f32 h0, $2, $1;\n" + "cvt.rn.satfinite.e4m3x2.f32 h1, $4, $3;\n" + "mov.b32 $0, {h0, h1};\n" + "}\n", + "=r,f,f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def cvt_fp8x4_e4m3_bf16x4( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32]: + """Convert packed e4m3x4 bits into two packed bf16x2 registers.""" + out0 = cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b32 q, mant, out, bias, zero;\n\t" + "prmt.b32 q, $1, $1, 0x1302;\n\t" + "and.b32 out, q, 0x80008000;\n\t" + "and.b32 mant, q, 0x7f007f00;\n\t" + "shr.u32 mant, mant, 4;\n\t" + "or.b32 out, out, mant;\n\t" + "mov.b32 bias, 0x7b807b80;\n\t" + "mov.b32 zero, 0;\n\t" + "fma.rn.bf16x2 $0, out, bias, zero;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + out1 = cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b32 q, qs, mant, out, bias, zero;\n\t" + "prmt.b32 q, $1, $1, 0x1302;\n\t" + "shl.b32 qs, q, 8;\n\t" + "and.b32 out, qs, 0x80008000;\n\t" + "and.b32 mant, qs, 0x7f007f00;\n\t" + "shr.u32 mant, mant, 4;\n\t" + "or.b32 out, out, mant;\n\t" + "mov.b32 bias, 0x7b807b80;\n\t" + "mov.b32 zero, 0;\n\t" + "fma.rn.bf16x2 $0, out, bias, zero;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return out0, out1 + + +@dsl_user_op +def cvt_fp4x2_e2m1_f16x2( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Convert one packed E2M1 byte into one packed f16x2 register.""" + + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b8 byte0;\n\t" + "mov.b32 {byte0, _, _, _}, $1;\n\t" + "cvt.rn.f16x2.e2m1x2 $0, byte0;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def cvt_fp4x8_e2m1_f16x8( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Convert four packed E2M1 bytes into four packed f16x2 registers.""" + + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $4;\n\t" + "cvt.rn.f16x2.e2m1x2 $0, byte0;\n\t" + "cvt.rn.f16x2.e2m1x2 $1, byte1;\n\t" + "cvt.rn.f16x2.e2m1x2 $2, byte2;\n\t" + "cvt.rn.f16x2.e2m1x2 $3, byte3;\n\t" + "}\n", + "=r,=r,=r,=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + out2 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [2], loc=loc, ip=ip)) + out3 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [3], loc=loc, ip=ip)) + return out0, out1, out2, out3 + + +@dsl_user_op +def cvt_fp4x8_e2m1_bf16x8( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Convert four packed E2M1 bytes into four packed bf16x2 registers.""" + + from cutlass import CUDA_VERSION + + if CUDA_VERSION.major > 13 or (CUDA_VERSION.major == 13 and CUDA_VERSION.minor >= 2): + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $4;\n\t" + "cvt.rn.bf16x2.e2m1x2 $0, byte0;\n\t" + "cvt.rn.bf16x2.e2m1x2 $1, byte1;\n\t" + "cvt.rn.bf16x2.e2m1x2 $2, byte2;\n\t" + "cvt.rn.bf16x2.e2m1x2 $3, byte3;\n\t" + "}\n", + "=r,=r,=r,=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + out2 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [2], loc=loc, ip=ip)) + out3 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [3], loc=loc, ip=ip)) + return out0, out1, out2, out3 + + f16_pair0, f16_pair1, f16_pair2, f16_pair3 = cvt_fp4x8_e2m1_f16x8( + src, loc=loc, ip=ip + ) + return ( + cvt_f16x2_to_bf16x2(f16_pair0, loc=loc, ip=ip), + cvt_f16x2_to_bf16x2(f16_pair1, loc=loc, ip=ip), + cvt_f16x2_to_bf16x2(f16_pair2, loc=loc, ip=ip), + cvt_f16x2_to_bf16x2(f16_pair3, loc=loc, ip=ip), + ) + + +@dsl_user_op +def cvt_fp4x8_e2m1_scaled_e4m3x8( + src: cutlass.Int32, + scale_e4m3: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32]: + """Scale eight packed E2M1 values by one E4M3 byte and convert to E4M3.""" + + from cutlass import CUDA_VERSION + + if CUDA_VERSION.major > 13 or (CUDA_VERSION.major == 13 and CUDA_VERSION.minor >= 2): + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32()]), + [ + cutlass.Int32(src).ir_value(loc=loc, ip=ip), + cutlass.Int32(scale_e4m3).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .b32 tmp, ra;\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + "prmt.b32 tmp, $3, 0, 0;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $2;\n\t" + "mov.b32 ra, {byte0, byte1, _, _};\n\t" + "mul.e4m3x4.e2m1x4.e4m3x4.satfinite $0, ra, tmp;\n\t" + "mov.b32 ra, {_, _, byte2, byte3};\n\t" + "mul.e4m3x4.e2m1x4.e4m3x4.satfinite $1, ra, tmp;\n\t" + "}\n", + "=r,=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + return out0, out1 + + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32()]), + [ + cutlass.Int32(src).ir_value(loc=loc, ip=ip), + cutlass.Int32(scale_e4m3).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .b32 sf_bytes, sf_f16x2;\n\t" + ".reg .b16 sf_pair, e0, e1, e2, e3;\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + ".reg .b32 h0, h1, h2, h3;\n\t" + "prmt.b32 sf_bytes, $3, 0, 0;\n\t" + "mov.b32 {sf_pair, _}, sf_bytes;\n\t" + "cvt.rn.f16x2.e4m3x2 sf_f16x2, sf_pair;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $2;\n\t" + "cvt.rn.f16x2.e2m1x2 h0, byte0;\n\t" + "cvt.rn.f16x2.e2m1x2 h1, byte1;\n\t" + "cvt.rn.f16x2.e2m1x2 h2, byte2;\n\t" + "cvt.rn.f16x2.e2m1x2 h3, byte3;\n\t" + "mul.rn.f16x2 h0, h0, sf_f16x2;\n\t" + "mul.rn.f16x2 h1, h1, sf_f16x2;\n\t" + "mul.rn.f16x2 h2, h2, sf_f16x2;\n\t" + "mul.rn.f16x2 h3, h3, sf_f16x2;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e0, h0;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e1, h1;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e2, h2;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e3, h3;\n\t" + "mov.b32 $0, {e0, e1};\n\t" + "mov.b32 $1, {e2, e3};\n\t" + "}\n", + "=r,=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + return out0, out1 + + +@dsl_user_op +def cvt_f16x2_to_bf16x2( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Convert a packed f16x2 register into a packed bf16x2 register.""" + + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b16 h0, h1;\n\t" + ".reg .f32 f0, f1;\n\t" + "mov.b32 {h0, h1}, $1;\n\t" + "cvt.f32.f16 f0, h0;\n\t" + "cvt.f32.f16 f1, h1;\n\t" + "cvt.rn.bf16x2.f32 $0, f1, f0;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def mul_bf16x2( + a: cutlass.Int32, + b: cutlass.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Multiply two packed bf16x2 registers.""" + + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Int32(a).ir_value(loc=loc, ip=ip), + cutlass.Int32(b).ir_value(loc=loc, ip=ip), + ], + "mul.rn.bf16x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def cvt_fp8_e4m3_to_bf16x2_replicated(src: cutlass.Int32) -> cutlass.Int32: + """Decode one E4M3 byte and replicate it into a packed bf16x2 register.""" + + src_u8 = src & cutlass.Int32(0xFF) + packed = src_u8 * cutlass.Int32(0x01010101) + out0, _ = cvt_fp8x4_e4m3_bf16x4(packed) + return out0 + + +@overload +def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ... + + +@overload +def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ... + + +@cute.jit +def cvt_f16(src: cute.Tensor, dst_or_dtype): + """Convert Float32 tensor to Float16/BFloat16. + + Args: + src: Source tensor with Float32 element type + dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16) + + Returns: + None if dst is a tensor, or a new tensor if dtype is provided + """ + if const_expr(isinstance(dst_or_dtype, type)): + # dtype variant: create new tensor and call the tensor variant + dtype = dst_or_dtype + dst = cute.make_rmem_tensor(src.shape, dtype) + cvt_f16(src, dst) + return dst + else: + # tensor variant: write to dst + dst = dst_or_dtype + assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size" + assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements" + assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], ( + "dst must be BFloat16 or Float16" + ) + assert src.element_type is Float32, "src must be Float32" + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) + for i in cutlass.range_constexpr(cute.size(dst_i32)): + dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) + + +@cute.jit +def cvt_f32(src: cute.Tensor, dst: cute.Tensor) -> None: + """Convert a Float32 rmem tensor to dst's element type. + + fp8 path uses the reference fp8 quantize pattern: fragment-by-fragment + ``.store(.load().to(fp8))`` over groups of ``frg_tile=4``. This lets the + DSL emit ``cvt.rn.satfinite.e4m3x2.f32`` pairs and pack the resulting fp8 + bytes within a 32-bit register cell in the order DSL chooses, which is + expected to match the K-adjacency that SM100 fp8 UMMA fragment_A reads. + """ + if const_expr(dst.element_type in [cutlass.BFloat16, cutlass.Float16]): + cvt_f16(src, dst) + elif const_expr(dst.element_type is cutlass.Float8E4M3FN): + assert src.element_type is Float32, "src must be Float32" + assert cute.size(src.shape) == cute.size(dst.shape), "dst and src must have the same size" + assert cute.size(src.shape) % 4 == 0, "src must have a multiple of 4 elements" + frg_tile = 4 + src_frg = cute.logical_divide(src, cute.make_layout(frg_tile)) + dst_frg = cute.logical_divide(dst, cute.make_layout(frg_tile)) + for i in cutlass.range_constexpr(cute.size(src_frg, mode=[1])): + dst_frg[None, i].store(src_frg[None, i].load().to(dst.element_type)) + else: + assert src.element_type is Float32, "src must be Float32" + dst_view = cute.make_tensor(dst.iterator, src.layout) + dst_view.store(src.load().to(dst.element_type)) + + +@dsl_user_op +@cute.jit +def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32: + deg = len(poly) - 1 + out = poly[deg] + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = out * x + poly[i] + return out + + +@dsl_user_op +@cute.jit +def evaluate_polynomial_2( + x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None +) -> Tuple[Float32, Float32]: + deg = len(poly) - 1 + out = (poly[deg], poly[deg]) + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i])) + return out + + +@dsl_user_op +def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32: + # There's probably a way to call llvm or nvvm to do this instead of ptx + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)], + "add.rm.ftz.f32 $0, $1, $2;", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32: + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [ + Float32(x_rounded).ir_value(loc=loc, ip=ip), + Float32(frac_ex2).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t" + "mov.b32 x_rounded_i, $1;\n\t" + "mov.b32 frac_ex_i, $2;\n\t" + "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t" + # add.u32 generates IMAD instruction and add.s32 generates LEA instruction + # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik + "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t" + "mov.b32 $0, out_i;\n\t" + "}\n", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32: + assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported" + # We assume x <= 127.0 + fp32_round_int = float(2**23 + 2**22) + x_clamped = cute.arch.fmax(x, -127.0) + # We want to round down here, so that the fractional part is in [0, 1) + x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip) + # The integer floor of x is now in the last 8 bits of x_rounded + # We assume the next 2 ops round to nearest even. The rounding mode is important. + x_rounded_back = x_rounded - fp32_round_int + x_frac = x_clamped - x_rounded_back + x_frac_ex2 = evaluate_polynomial(x_frac, POLY_EX2[poly_degree], loc=loc, ip=ip) + return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip) + + +@dsl_user_op +def ex2_emulation_2( + x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None +) -> Tuple[Float32, Float32]: + # We assume x <= 127.0 and y <= 127.0 + fp32_round_int = float(2**23 + 2**22) + xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0)) + # We want to round down here, so that the fractional part is in [0, 1) + xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd="rm") + # The integer floor of x & y are now in the last 8 bits of xy_rounded + # We want the next 2 ops to round to nearest even. The rounding mode is important. + xy_rounded_back = activation.sub_packed_f32x2( + xy_rounded, (fp32_round_int, fp32_round_int) + ) + xy_frac = activation.sub_packed_f32x2(xy_clamped, xy_rounded_back) + xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, POLY_EX2[poly_degree], loc=loc, ip=ip) + x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip) + y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip) + return x_out, y_out + + +@dsl_user_op +def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + out_f32x2 = llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32()]), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()], + "{\n\t" + ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t" + ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t" + ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t" + "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t" + "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t" + "mov.b64 l1, {f1, f2};\n\t" + "mov.f32 f3, 0f4B400000;\n\t" + "mov.b64 l2, {f3, f3};\n\t" + "add.rm.ftz.f32x2 l7, l1, l2;\n\t" + "sub.rn.ftz.f32x2 l8, l7, l2;\n\t" + "sub.rn.ftz.f32x2 l9, l1, l8;\n\t" + "mov.f32 f7, 0f3D9DF09D;\n\t" + "mov.b64 l6, {f7, f7};\n\t" + "mov.f32 f6, 0f3E6906A4;\n\t" + "mov.b64 l5, {f6, f6};\n\t" + "mov.f32 f5, 0f3F31F519;\n\t" + "mov.b64 l4, {f5, f5};\n\t" + "mov.f32 f4, 0f3F800000;\n\t" + "mov.b64 l3, {f4, f4};\n\t" + "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t" + "mov.b64 {r1, r2}, l7;\n\t" + "mov.b64 {r3, r4}, l10;\n\t" + "shl.b32 r5, r1, 23;\n\t" + "add.s32 r7, r5, r3;\n\t" + "shl.b32 r6, r2, 23;\n\t" + "add.s32 r8, r6, r4;\n\t" + "mov.b32 $0, r7;\n\t" + "mov.b32 $1, r8;\n\t" + "}\n", + "=r,=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip)) + out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip)) + return out0, out1 + + +@dsl_user_op +def domain_offset_aligned( + coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None +) -> cute.Tensor: + assert isinstance(tensor.iterator, cute.Pointer) + # We assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + elem_pointer(tensor, coord).toint(), + tensor.memspace, + assumed_align=tensor.iterator.alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@cute.jit +def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: + """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype""" + vec = cute.make_rmem_tensor(1, dtype) + vec[0] = a + return vec.load() + + +def ssa_to_scalar(val): + """Could inline but nice for reflecting the above api""" + return val[0] + + +# ------------------------------------------------------------------ +# Host-side Python helpers (not @cute.jit — called from PyTorch host code) +# ------------------------------------------------------------------ + +def default_softmax_scale(dim: int) -> float: + """Default softmax scale: 1 / sqrt(dim).""" + return 1.0 / math.sqrt(dim) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f23267fe73800d35db382a1919bc28196da5aa8c --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 sparse attention kernels.""" diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/build_k2q_csr/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/build_k2q_csr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf19c60a32d2f57595c9666323b47738b878115 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/build_k2q_csr/__init__.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""q2k -> k2q CSR builder backed by the precompiled Torch ops. + +The CUDA implementation lives in ``csrc/build_k2q_csr.cu`` and is built +ahead of time by kernel-builder; it is reached through the ``_ops`` +namespace instead of being JIT-compiled at import time. + +The kernel pipeline is tuned and verified for SM100; other +architectures are not supported. +""" + +from __future__ import annotations + +import torch + +from ...._ops import ops + + +def run_build_k2q_csr( + q2k: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + row_ptr: torch.Tensor, + q_idx: torch.Tensor, + topk: int, + blk_kv: int, + total_rows: int, + max_kv_blocks: int, +) -> None: + """In-place fill of ``row_ptr`` and ``q_idx``. + + Args: + q2k: int32 [H, total_q, topK] contiguous (CUDA). + cu_seqlens_q: int32 [B+1] contiguous (CUDA). + cu_seqlens_k: int32 [B+1] contiguous (CUDA). + row_ptr: int32 [H, total_rows + 1] CUDA, written in place. + q_idx: int32 [H, total_q * topK] CUDA, written in place + (trailing slots set to -1). + topk: must be in {4, 8, 16, 32}. + blk_kv: must equal 128. + total_rows: sum over batches of ceil(seqlen_k / blk_kv). + max_kv_blocks: max over batches of ceil(seqlen_k / blk_kv); upper bound + used to size the row_map workspace and clamp valid kv ids. + """ + ops.run_build_k2q_csr( + q2k, + cu_seqlens_q, + cu_seqlens_k, + row_ptr, + q_idx, + int(topk), + int(blk_kv), + int(total_rows), + int(max_kv_blocks), + ) + + +def run_build_k2q_csr_with_schedule( + q2k: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + row_ptr: torch.Tensor, + q_idx: torch.Tensor, + scheduler_metadata: torch.Tensor, + work_count: torch.Tensor, + qsplit_idx: torch.Tensor, + split_counts: torch.Tensor, + topk: int, + blk_kv: int, + total_rows: int, + max_kv_blocks: int, + target_q_per_cta: int, + work_capacity: int, + max_seqlen_q: int, +) -> None: + """In-place fill of CSR plus fused sparse attention schedule metadata.""" + ops.run_build_k2q_csr_with_schedule( + q2k, + cu_seqlens_q, + cu_seqlens_k, + row_ptr, + q_idx, + scheduler_metadata, + work_count, + qsplit_idx, + split_counts, + int(topk), + int(blk_kv), + int(total_rows), + int(max_kv_blocks), + int(target_q_per_cta), + int(work_capacity), + int(max_seqlen_q), + ) + + +def is_supported(topk: int, blk_kv: int) -> bool: + return int(topk) in (4, 8, 16, 32) and int(blk_kv) == 128 + + +__all__ = ["run_build_k2q_csr", "run_build_k2q_csr_with_schedule", "is_supported"] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/decode_schedule.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/decode_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..037791818feb030a5969ebf6ac3cc3943cdb7dce --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/decode_schedule.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Split-KV schedule for paged fp8 decode attention. + +The public PageKV representation remains this repo's rectangular page table: +``page_table [B, max_pages]`` plus ``seqused_k [B]``. The schedule only +describes how query tiles and KV chunks are split into work items. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class DecodeAttentionSchedule: + split_kv: bool + cta_tile_q: int + num_q_tiles: int + kv_chunk_size_pages: int + kv_chunk_size_tokens: int + work_count: int + padded_work_count: int + partial_rows: int + max_split_count: int + max_grid_size: int + active_blocks_per_sm: int + num_sms: int + base_cta: int + request_indices: torch.Tensor + qo_tile_indices: torch.Tensor + kv_tile_indices: torch.Tensor + merge_indptr: torch.Tensor + o_indptr: torch.Tensor + block_valid_mask: torch.Tensor + kv_pages: torch.Tensor + split_counts: torch.Tensor + + +def _require_i32_cuda_1d(tensor: torch.Tensor, *, name: str) -> None: + if tensor.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32") + if tensor.ndim != 1: + raise ValueError(f"{name} must be rank-1") + if not tensor.is_cuda: + raise ValueError(f"{name} must be a CUDA tensor") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def prepare_decode_schedule( + *, + seqused_k: torch.Tensor, + page_size: int, + seqlen_q: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + max_seqlen_k: int, + enable_cuda_graph: bool = False, + max_grid_size: Optional[int] = None, + fixed_split_size: Optional[int] = None, + disable_split_kv: bool = False, +) -> DecodeAttentionSchedule: + """Build paged decode split-KV schedule on the GPU. + + A single CUDA kernel reads ``seqused_k`` on device and writes all + schedule index arrays. Only a small summary tensor is D2H-synced so + the wrapper can size O_partial / pick the kernel grid / choose the + split-vs-non-split compile path. + + ``max_seqlen_k`` is the host-side worst-case bound used to pad the + work-tile arrays. It must satisfy ``max(seqused_k) <= max_seqlen_k``. + """ + _require_i32_cuda_1d(seqused_k, name="seqused_k") + # Hard cap: current single-CTA schedule kernel stores per-batch state + # in shared memory. Larger batches require a multi-CTA cooperative + # scheduler (unimplemented). Fail fast at the Python boundary so the + # error doesn't surface from inside the CUDA extension. + if int(seqused_k.shape[0]) > 1024: + raise NotImplementedError( + "decode schedule currently supports batch <= 1024 " + f"(got batch={int(seqused_k.shape[0])}). Larger batches need " + "the multi-CTA scheduler — not yet implemented." + ) + # Two API-boundary checks tied to the kernel's packed-GQA layout + # (q_tokens_per_group = m_block_size / qhead_per_kv = 128/16 = 8): + # + # (1) seqused_k[b] >= seqlen_q. The kernel computes the causal mask as + # col_limit = row_idx + seqlen_k - seqlen_q + 1. For row 0 (first + # q-token in the packed group) this is col_limit = seqlen_k - seqlen_q + # + 1, which goes <= 0 whenever seqlen_k < seqlen_q. That all-masked + # row then enters a mask-codegen path with PTX-undefined shift counts + # and the kernel hangs. The condition is also semantically invalid + # in batched-decode: you can't emit seqlen_q new tokens with fewer + # than seqlen_q total context tokens (seqlen_k includes them). + # + # (2) seqused_k[b] % page_size ∈ {0, 8, 16, ..., 120}. Same hang fires + # when the LAST partial page has < q_tokens_per_group=8 valid + # columns, because then the *last MMA tile* hits the same all-masked + # row case for the trailing q-tokens. + # + # Both are tracked as a separate kernel-level TODO (un-pack the + # all-masked row → skip mask call, or saturate causal_col_limit at >= 1 + # in mask.py). Until then, fail fast at the Python boundary with a + # clear message rather than letting the kernel timeout. + seqlen_q_i = int(seqlen_q) + bad_q = seqused_k < seqlen_q_i + if bool(bad_q.any().item()): + bad_idx = int(torch.nonzero(bad_q, as_tuple=True)[0][0].item()) + bad_val = int(seqused_k[bad_idx].item()) + raise ValueError( + f"decode kernel requires seqused_k[b] >= seqlen_q (= {seqlen_q_i}) " + f"for every batch. Got seqused_k[{bad_idx}]={bad_val}. " + f"This is also a batched-decode invariant: seqlen_k must include " + f"the seqlen_q new tokens being emitted." + ) + rem = seqused_k % int(page_size) + bad_rem = (rem > 0) & (rem < seqlen_q_i) + if bool(bad_rem.any().item()): + bad_idx = int(torch.nonzero(bad_rem, as_tuple=True)[0][0].item()) + bad_val = int(seqused_k[bad_idx].item()) + raise ValueError( + f"decode kernel requires seqused_k[b] % page_size ∈ " + f"{{0, {seqlen_q_i}, {seqlen_q_i*2}, ..., {(page_size//seqlen_q_i)*seqlen_q_i}}}. " + f"Got seqused_k[{bad_idx}]={bad_val}, last partial page has " + f"{bad_val % int(page_size)} valid columns (< seqlen_q={seqlen_q_i}). " + f"Round seqused_k up to the next multiple of {seqlen_q_i} OR to " + f"a multiple of {page_size}." + ) + if int(page_size) <= 0: + raise ValueError("page_size must be positive") + if int(seqlen_q) <= 0: + raise ValueError("seqlen_q must be positive") + if int(num_qo_heads) <= 0 or int(num_kv_heads) <= 0: + raise ValueError("head counts must be positive") + if int(num_qo_heads) % int(num_kv_heads) != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads") + if int(num_qo_heads) // int(num_kv_heads) != 16: + raise NotImplementedError("decode schedule currently supports only qhead_per_kv=16") + if int(head_dim) != 128: + raise NotImplementedError("decode schedule currently supports only head_dim=128") + if int(max_seqlen_k) <= 0: + raise ValueError("max_seqlen_k must be positive") + + from ...src.sm100.fwd_decode.build_decode_schedule import build_decode_schedule + + raw = build_decode_schedule( + seqused_k, + page_size=int(page_size), + seqlen_q=int(seqlen_q), + num_qo_heads=int(num_qo_heads), + num_kv_heads=int(num_kv_heads), + head_dim=int(head_dim), + max_seqlen_k=int(max_seqlen_k), + enable_cuda_graph=bool(enable_cuda_graph), + max_grid_size=0 if max_grid_size is None else int(max_grid_size), + fixed_split_size=-1 if fixed_split_size is None else int(fixed_split_size), + disable_split_kv=bool(disable_split_kv), + ) + return DecodeAttentionSchedule( + split_kv=bool(raw["split_kv"]), + cta_tile_q=int(raw["cta_tile_q"]), + num_q_tiles=int(raw["num_q_tiles"]), + kv_chunk_size_pages=int(raw["kv_chunk_size_pages"]), + kv_chunk_size_tokens=int(raw["kv_chunk_size_tokens"]), + work_count=int(raw["work_count"]), + padded_work_count=int(raw["padded_work_count"]), + partial_rows=int(raw["partial_rows"]), + max_split_count=int(raw["max_split_count"]), + max_grid_size=int(raw["max_grid_size"]), + active_blocks_per_sm=int(raw["active_blocks_per_sm"]), + num_sms=int(raw["num_sms"]), + base_cta=int(raw["base_cta"]), + request_indices=raw["request_indices"], + qo_tile_indices=raw["qo_tile_indices"], + kv_tile_indices=raw["kv_tile_indices"], + merge_indptr=raw["merge_indptr"], + o_indptr=raw["o_indptr"], + block_valid_mask=raw["block_valid_mask"], + kv_pages=raw["kv_pages"], + split_counts=raw["split_counts"], + ) + + +__all__ = [ + "DecodeAttentionSchedule", + "prepare_decode_schedule", +] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fp4_indexer.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fp4_indexer.py new file mode 100644 index 0000000000000000000000000000000000000000..aa83e39a5504ac6cf8d732255e495e48b35fa20a --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fp4_indexer.py @@ -0,0 +1,1956 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 FP4 sparse-attention indexer kernels.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +import torch +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 + +from ...src.common import pipeline as common_pipeline + + +FP4_FORMAT = Literal["mxfp4", "nvfp4"] +_FP4_PACKED_D_BYTES = 64 +_HEAD_DIM = 128 +_BLOCK_K = 128 +_PAGE_SIZE = 128 +_MMA_TILER_MN = (128, 128) +_MMA_INST_SHAPE_K = 64 +_NON_CAUSAL_K_TILES_PER_CTA = 16 +_CAUSAL_K_TILES_PER_CTA = 16 +_DECODE_PACK_Q_LEN = 8 +_DECODE_QHEAD_PER_KV = 16 +_DECODE_K_TILES_PER_CTA = 16 +_AB_DTYPE = cutlass.Float4E2M1FN + + +@dataclass(frozen=True) +class Fp4FormatSpec: + name: FP4_FORMAT + sf_vec_size: int + scale_groups: int + torch_scale_dtype: torch.dtype + cutlass_scale_dtype: type + + +_FORMAT_SPECS: dict[str, Fp4FormatSpec] = { + "mxfp4": Fp4FormatSpec( + name="mxfp4", + sf_vec_size=32, + scale_groups=4, + torch_scale_dtype=torch.float8_e8m0fnu, + cutlass_scale_dtype=cutlass.Float8E8M0FNU, + ), + "nvfp4": Fp4FormatSpec( + name="nvfp4", + sf_vec_size=16, + scale_groups=8, + torch_scale_dtype=torch.float8_e4m3fn, + cutlass_scale_dtype=cutlass.Float8E4M3FN, + ), +} + + +def normalize_fp4_format(fmt: str) -> Fp4FormatSpec: + key = str(fmt).lower() + try: + return _FORMAT_SPECS[key] + except KeyError as exc: + raise ValueError(f"format must be one of {sorted(_FORMAT_SPECS)}, got {fmt!r}") from exc + + +def ceil_div(x: int, y: int) -> int: + return (int(x) + int(y) - 1) // int(y) + + +def k_tiles_per_cta_for(causal: bool) -> int: + return _CAUSAL_K_TILES_PER_CTA if bool(causal) else _NON_CAUSAL_K_TILES_PER_CTA + + +class Fp4IndexerScaleReorderSm100: + """Reorder public FP4 indexer scales to the 1CTA blockscaled MMA layout.""" + + def __init__(self, *, fmt: str): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.sf_dtype = spec.cutlass_scale_dtype + self.scale_groups = spec.scale_groups + self.threads_per_cta = 256 + + @cute.jit + def __call__( + self, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + q_scale_mma_ptr: cute.Pointer, + k_scale_mma_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + total_q, heads_q, page_count, heads_k = problem_size + rest_q_m = cute.ceil_div(total_q, 128) + rest_g = cute.ceil_div(self.scale_groups, 4) + k_l = page_count * heads_k + + q_scale = cute.make_tensor( + q_scale_ptr, + cute.make_layout( + (total_q, heads_q, self.scale_groups), + stride=(heads_q * self.scale_groups, self.scale_groups, 1), + ), + ) + k_scale = cute.make_tensor( + k_scale_ptr, + cute.make_layout( + (page_count, heads_k, _PAGE_SIZE, self.scale_groups), + stride=( + heads_k * _PAGE_SIZE * self.scale_groups, + _PAGE_SIZE * self.scale_groups, + self.scale_groups, + 1, + ), + ), + ) + + q_mma_layout = cute.make_ordered_layout( + (32, 4, rest_q_m, 4, rest_g, heads_q), + order=(2, 1, 4, 0, 3, 5), + ) + k_mma_layout = cute.make_ordered_layout( + (32, 4, 1, 4, rest_g, k_l), + order=(2, 1, 4, 0, 3, 5), + ) + q_scale_mma = cute.make_tensor(q_scale_mma_ptr, q_mma_layout) + k_scale_mma = cute.make_tensor(k_scale_mma_ptr, k_mma_layout) + q_scale_mma = cute.group_modes(q_scale_mma, 0, 3) + q_scale_mma = cute.group_modes(q_scale_mma, 1, 3) + k_scale_mma = cute.group_modes(k_scale_mma, 0, 3) + k_scale_mma = cute.group_modes(k_scale_mma, 1, 3) + + q_scale_count = total_q * heads_q * Int32(self.scale_groups) + k_scale_count = page_count * heads_k * Int32(_PAGE_SIZE * self.scale_groups) + total_scale_count = q_scale_count + k_scale_count + grid_ctas = cute.ceil_div(total_scale_count, self.threads_per_cta) + self.kernel( + q_scale, + k_scale, + q_scale_mma, + k_scale_mma, + heads_q, + heads_k, + q_scale_count, + total_scale_count, + ).launch( + grid=(grid_ctas, 1, 1), + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + q_scale: cute.Tensor, + k_scale: cute.Tensor, + q_scale_mma: cute.Tensor, + k_scale_mma: cute.Tensor, + heads_q: Int32, + heads_k: Int32, + q_scale_count: Int32, + total_scale_count: Int32, + ): + tidx, _, _ = cute.arch.thread_idx() + block_idx, _, _ = cute.arch.block_idx() + grid_dim, _, _ = cute.arch.grid_dim() + linear = block_idx * Int32(self.threads_per_cta) + tidx + stride = grid_dim * Int32(self.threads_per_cta) + + while linear < total_scale_count: + if linear < q_scale_count: + group = linear % Int32(self.scale_groups) + tmp = linear // Int32(self.scale_groups) + head = tmp % heads_q + row = tmp // heads_q + q_scale_mma[row, group, head] = q_scale[row, head, group] + else: + k_linear = linear - q_scale_count + group = k_linear % Int32(self.scale_groups) + tmp = k_linear // Int32(self.scale_groups) + row = tmp % Int32(_PAGE_SIZE) + tmp = tmp // Int32(_PAGE_SIZE) + head = tmp % heads_k + page = tmp // heads_k + scale_l = page * heads_k + head + k_scale_mma[row, group, scale_l] = k_scale[page, head, row, group] + linear += stride + + +class Fp4IndexerStagedMmaSm100: + """Single-kernel FP4 indexer for preordered MMA scale storage.""" + + def __init__( + self, + *, + fmt: str, + causal: bool, + preordered_q_scale_tma: bool = False, + compact_schedule: bool = False, + use_tmem_load_red: bool = False, + ): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.is_causal = bool(causal) + self.preordered_q_scale_tma = bool(preordered_q_scale_tma) + self.compact_schedule = bool(compact_schedule) + self.use_tmem_load_red = bool(use_tmem_load_red) + self.sf_vec_size = spec.sf_vec_size + self.sf_dtype = spec.cutlass_scale_dtype + self.scale_groups = spec.scale_groups + self.use_nvfp4 = spec.name == "nvfp4" + self.epi_threads_per_cta = 128 + self.epi_warps_per_group = 4 + self.num_epi_warpgroups = 2 + self.mma_warp_id = self.epi_warps_per_group * self.num_epi_warpgroups + self.load_warp_id = self.mma_warp_id + 1 + self.threads_per_cta = 384 + self.num_tmem_alloc_cols = 512 + self.num_q_stage = 1 + self.num_acc_stage = 3 + self.num_ab_stage = 3 + self.k_tiles_per_cta = k_tiles_per_cta_for(self.is_causal) + + @cute.jit + def __call__( + self, + q_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + ( + m, + _, + k, + _, + lk, + heads_q, + heads_k, + batch, + max_k_tiles, + total_q, + has_qo_offset, + compact_task_count, + ) = problem_size + page_count = lk // heads_k + self.mma_tiler = (_MMA_TILER_MN[0], _MMA_TILER_MN[1], _MMA_INST_SHAPE_K * 2) + self.cta_tile_shape_mnk = self.mma_tiler + + q_tma_tensor = cute.make_tensor( + cute.recast_ptr(q_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (total_q, _HEAD_DIM, heads_q), + stride=(heads_q * _HEAD_DIM, 1, _HEAD_DIM), + ), + ) + k_tma_tensor = cute.make_tensor( + cute.recast_ptr(k_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (_PAGE_SIZE, _HEAD_DIM, heads_k, page_count), + stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM, heads_k * _PAGE_SIZE * _HEAD_DIM), + ), + ) + q_scale_tensor = cute.make_tensor( + q_scale_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (total_q, _HEAD_DIM, heads_q), + self.sf_vec_size, + ), + ) + k_scale_tensor = cute.make_tensor( + k_scale_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (_PAGE_SIZE, _HEAD_DIM, page_count * heads_k), + self.sf_vec_size, + ), + ) + scores_tensor = cute.make_tensor( + scores_ptr, + cute.make_layout((heads_q, max_k_tiles, total_q), stride=(max_k_tiles * total_q, total_q, 1)), + ) + kv_indices_tensor = cute.make_tensor( + kv_indices_ptr, + cute.make_layout((page_count,), stride=(1,)), + ) + cu_layout = cute.make_layout((batch + 1,), stride=(1,)) + cu_q_tensor = cute.make_tensor(cu_seqlens_q_ptr, cu_layout) + cu_k_tensor = cute.make_tensor(cu_seqlens_k_ptr, cu_layout) + cu_page_offsets_tensor = cute.make_tensor(cu_page_offsets_ptr, cu_layout) + qo_offset_tensor = cute.make_tensor(qo_offset_ptr, cute.make_layout((batch,), stride=(1,))) + + if const_expr(self.use_nvfp4): + mma_op = tcgen05.MmaMXF4NVF4Op( + self.sf_dtype, + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + else: + mma_op = tcgen05.MmaMXF4Op( + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + q_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_q_stage) + k_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_ab_stage) + q_scale_smem_layout = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_q_stage, + ) + k_scale_smem_layout = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + cluster_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + q_smem_layout_stage = cute.slice_(q_smem_layout, (None, None, None, 0)) + k_smem_layout_stage = cute.slice_(k_smem_layout, (None, None, None, 0)) + tma_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_tma_tensor, + q_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + tma_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_tma_tensor, + k_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + if const_expr(self.preordered_q_scale_tma): + tma_qs = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_scale_tensor, + q_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + else: + tma_qs = tma_q + tma_ks = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_scale_tensor, + k_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + grid_q_tiles = cute.ceil_div(m, self.cta_tile_shape_mnk[0]) + grid_k_groups = cute.ceil_div(max_k_tiles, self.k_tiles_per_cta) + if const_expr(self.compact_schedule): + grid_x = compact_task_count + else: + grid_x = grid_q_tiles * grid_k_groups + self.kernel( + tiled_mma, + tma_q, + tma_qs, + tma_k, + tma_ks, + q_scale_tensor, + k_scale_tensor, + scores_tensor, + kv_indices_tensor, + cu_q_tensor, + cu_k_tensor, + cu_page_offsets_tensor, + qo_offset_tensor, + q_smem_layout, + k_smem_layout, + q_scale_smem_layout, + k_scale_smem_layout, + heads_q, + heads_k, + has_qo_offset, + max_k_tiles, + grid_k_groups, + ).launch( + grid=(grid_x, batch * heads_q, 1), + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.jit + def _group_has_visible( + self, + q_tile_start: Int32, + q_tile_last: Int32, + q_len: Int32, + group_first_ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = q_tile_start < q_len and group_first_ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and group_first_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + return visible + + @cute.jit + def _tile_has_visible( + self, + q_tile_start: Int32, + q_tile_last: Int32, + q_len: Int32, + ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = q_tile_start < q_len and ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + return visible + + @cute.jit + def _tile_mask_free(self, q_tile_start: Int32, ktile: Int32, causal_offset: Int32): + if const_expr(self.is_causal): + return ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) <= q_tile_start + causal_offset + return True + + @cute.jit + def _full_tile_coord_visible( + self, + coord_m: Int32, + target_m: Int32, + q_local: Int32, + k_local: Int32, + causal_offset: Int32, + ): + visible = coord_m == target_m + if const_expr(self.is_causal): + visible = visible and k_local <= q_local + causal_offset + return visible + + @cute.jit + def _partial_tile_coord_visible( + self, + coord_m: Int32, + target_m: Int32, + q_local: Int32, + k_local: Int32, + q_len: Int32, + k_len: Int32, + causal_offset: Int32, + ): + visible = coord_m == target_m and q_local < q_len and k_local < k_len + if const_expr(self.is_causal): + visible = visible and k_local <= q_local + causal_offset + return visible + + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_q: cpasync.TmaInfo, + tma_qs: cpasync.TmaInfo, + tma_k: cpasync.TmaInfo, + tma_ks: cpasync.TmaInfo, + mQS: cute.Tensor, + mKS: cute.Tensor, + mScores: cute.Tensor, + mKvIndices: cute.Tensor, + mCuQ: cute.Tensor, + mCuK: cute.Tensor, + mCuPages: cute.Tensor, + mQoOffset: cute.Tensor, + q_smem_layout: cute.ComposedLayout, + k_smem_layout: cute.ComposedLayout, + q_scale_smem_layout: cute.Layout, + k_scale_smem_layout: cute.Layout, + heads_q: Int32, + heads_k: Int32, + has_qo_offset: Int32, + max_k_tiles: Int32, + k_group_count: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + lane_idx = cute.arch.lane_idx() + epi_tidx = tidx % Int32(self.epi_threads_per_cta) + epi_warpgroup_idx = warp_idx // Int32(self.epi_warps_per_group) + task_idx, q_l, _ = cute.arch.block_idx() + batch_idx = q_l // heads_q + hq = q_l - batch_idx * heads_q + hk = hq // (heads_q // heads_k) + q_begin = mCuQ[batch_idx] + q_end = mCuQ[batch_idx + 1] + k_begin = mCuK[batch_idx] + k_end = mCuK[batch_idx + 1] + q_len = q_end - q_begin + k_len = k_end - k_begin + page_begin = mCuPages[batch_idx] + batch_k_tiles = (k_len + Int32(_PAGE_SIZE - 1)) // Int32(_PAGE_SIZE) + causal_offset = Int32(0) + if const_expr(self.is_causal): + causal_offset = k_len - q_len + if has_qo_offset != 0: + causal_offset = mQoOffset[batch_idx] + task_valid = True + q_tile_idx = Int32(0) + ktile_group = Int32(0) + if const_expr(self.compact_schedule): + remaining = task_idx + q_tile_count = (q_len + Int32(self.cta_tile_shape_mnk[0] - 1)) // Int32(self.cta_tile_shape_mnk[0]) + batch_k_group_count = (batch_k_tiles + Int32(self.k_tiles_per_cta - 1)) // Int32(self.k_tiles_per_cta) + q_scan = Int32(0) + task_valid = False + while q_scan < q_tile_count and not task_valid: + q_scan_start = q_scan * Int32(self.cta_tile_shape_mnk[0]) + q_scan_last = q_scan_start + Int32(self.cta_tile_shape_mnk[0] - 1) + if q_scan_last >= q_len: + q_scan_last = q_len - Int32(1) + visible_limit = q_scan_last + causal_offset + visible_group_count = Int32(0) + if visible_limit >= Int32(0): + visible_group_count = visible_limit // Int32(self.k_tiles_per_cta * _BLOCK_K) + Int32(1) + if visible_group_count > batch_k_group_count: + visible_group_count = batch_k_group_count + task_valid = remaining < visible_group_count + if not task_valid: + remaining -= visible_group_count + q_scan += Int32(1) + if task_valid: + q_tile_idx = q_scan + ktile_group = remaining + else: + q_len = Int32(0) + k_len = Int32(0) + else: + q_tile_idx = task_idx // k_group_count + ktile_group = task_idx - q_tile_idx * k_group_count + q_tile_start = q_tile_idx * Int32(self.cta_tile_shape_mnk[0]) + q_tile_last = q_tile_start + Int32(self.cta_tile_shape_mnk[0] - 1) + if q_tile_last >= q_len: + q_tile_last = q_len - Int32(1) + q_tile_full = q_tile_start + Int32(self.cta_tile_shape_mnk[0] - 1) < q_len + q_tile_global_start = q_begin + q_tile_start + q_scale_tma_safe = q_tile_global_start == (q_tile_global_start // Int32(128)) * Int32(128) + group_first_ktile = ktile_group * Int32(self.k_tiles_per_cta) + group_has_visible = self._group_has_visible( + q_tile_start, + q_tile_last, + q_len, + group_first_ktile, + batch_k_tiles, + causal_offset, + ) + + @cute.struct + class SharedStorage: + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + qs_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ_public = smem.allocate_tensor(_AB_DTYPE, q_smem_layout.outer, 128, swizzle=q_smem_layout.inner) + sK_public = smem.allocate_tensor(_AB_DTYPE, k_smem_layout.outer, 128, swizzle=k_smem_layout.inner) + sQS_public = smem.allocate_tensor(self.sf_dtype, q_scale_smem_layout, 128) + sKS_public = smem.allocate_tensor(self.sf_dtype, k_scale_smem_layout, 128) + mQ_tma = tma_q.tma_tensor + mQS_tma = tma_qs.tma_tensor + mK_tma = tma_k.tma_tensor + mKS_tma = tma_ks.tma_tensor + thr_mma = tiled_mma.get_slice(0) + tCsQ = thr_mma.partition_A(sQ_public) + tCsK = thr_mma.partition_B(sK_public) + mQ_tma_cur = cute.domain_offset((q_begin, 0, 0), mQ_tma) + gQ_tma = cute.local_tile( + mQ_tma_cur, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQ_tma = thr_mma.partition_A(gQ_tma) + tQsQ_tma, tQgQ_tma = cpasync.tma_partition( + tma_q.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQ_public, 0, 3), + cute.group_modes(tCgQ_tma, 0, 3), + ) + if const_expr(self.preordered_q_scale_tma): + mQS_tma_cur = cute.domain_offset((q_begin, 0, 0), mQS_tma) + gQS_tma = cute.local_tile( + mQS_tma_cur, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQS_tma = thr_mma.partition_A(gQS_tma) + tQsQS_tma, tQgQS_tma = cpasync.tma_partition( + tma_qs.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQS_public, 0, 3), + cute.group_modes(tCgQS_tma, 0, 3), + ) + tQsQS_tma = cute.filter_zeros(tQsQS_tma) + tQgQS_tma = cute.filter_zeros(tQgQS_tma) + gK_tma = cute.local_tile( + mK_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None, None), + ) + tCgK_tma = thr_mma.partition_B(gK_tma) + tKsK_tma, tKgK_tma = cpasync.tma_partition( + tma_k.atom, + 0, + cute.make_layout(1), + cute.group_modes(sK_public, 0, 3), + cute.group_modes(tCgK_tma, 0, 3), + ) + gKS_tma = cute.local_tile( + mKS_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + tCgKS_tma = thr_mma.partition_B(gKS_tma) + tKsKS_tma, tKgKS_tma = cpasync.tma_partition( + tma_ks.atom, + 0, + cute.make_layout(1), + cute.group_modes(sKS_public, 0, 3), + cute.group_modes(tCgKS_tma, 0, 3), + ) + tKsKS_tma = cute.filter_zeros(tKsKS_tma) + tKgKS_tma = cute.filter_zeros(tKgKS_tma) + sQS = sQS_public + sKS = sKS_public + + tCrQ = tiled_mma.make_fragment_A(sQ_public) + tCrK = tiled_mma.make_fragment_B(sK_public) + tCcC = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler[:2])) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + tmem = utils.TmemAllocator( + storage.tmem_holding_buf.ptr, + barrier_for_retrieve=pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * (self.mma_warp_id + 1), + ), + ) + + acc_pipeline = common_pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, self.epi_threads_per_cta), + defer_sync=True, + ) + acc_producer, _ = acc_pipeline.make_participants() + q_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_q.smem_layout) + k_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_k.smem_layout) + if const_expr(self.preordered_q_scale_tma): + qs_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_qs.smem_layout, mode=[0, 1, 2]), + ) + ks_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_ks.smem_layout, mode=[0, 1, 2]), + ) + k_pair_tma_copy_bytes = k_tma_copy_bytes + ks_tma_copy_bytes + q_producer, q_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.q_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=q_tma_copy_bytes, + defer_sync=True, + ).make_participants() + if const_expr(self.preordered_q_scale_tma): + qs_producer, qs_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.qs_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=qs_tma_copy_bytes, + defer_sync=True, + ).make_participants() + k_producer, k_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.k_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=k_pair_tma_copy_bytes, + defer_sync=True, + ).make_participants() + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + if warp_idx == self.load_warp_id: + if group_has_visible: + q_empty = q_producer.acquire_and_advance() + if const_expr(self.preordered_q_scale_tma): + if q_scale_tma_safe: + qs_empty = qs_producer.acquire_and_advance() + cute.copy( + tma_qs.atom, + tQgQS_tma[(None, q_tile_idx, 0, hq)], + tQsQS_tma[(None, qs_empty.index)], + tma_bar_ptr=qs_empty.barrier, + ) + qs_empty.commit() + else: + for row_base in cutlass.range(0, Int32(self.cta_tile_shape_mnk[0]), 32): + row = row_base + lane_idx + q_local = q_tile_start + row + row_major = row // Int32(32) + row_atom = row - row_major * Int32(32) + for group in cutlass.range_constexpr(self.scale_groups): + group_i = Int32(group) + mma_k = group_i // Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + group_in_mma_k = group_i - mma_k * Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + sf_coord = ((((row_atom, row_major), Int32(0)), (Int32(0), group_in_mma_k)), Int32(0), mma_k, Int32(0)) + q_scale_row = q_begin + q_local + if q_local >= q_len: + q_scale_row = q_begin + sQS[sf_coord] = mQS[q_scale_row, group_i * Int32(self.sf_vec_size), hq] + else: + for row_base in cutlass.range(0, Int32(self.cta_tile_shape_mnk[0]), 32): + row = row_base + lane_idx + q_local = q_tile_start + row + row_major = row // Int32(32) + row_atom = row - row_major * Int32(32) + for group in cutlass.range_constexpr(self.scale_groups): + group_i = Int32(group) + mma_k = group_i // Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + group_in_mma_k = group_i - mma_k * Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + sf_coord = ((((row_atom, row_major), Int32(0)), (Int32(0), group_in_mma_k)), Int32(0), mma_k, Int32(0)) + q_scale_row = q_begin + q_local + if q_local >= q_len: + q_scale_row = q_begin + sQS[sf_coord] = mQS[q_scale_row, group_i * Int32(self.sf_vec_size), hq] + cute.copy( + tma_q.atom, + tQgQ_tma[(None, q_tile_idx, 0, hq)], + tQsQ_tma[(None, q_empty.index)], + tma_bar_ptr=q_empty.barrier, + ) + q_empty.commit() + + if warp_idx == self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + # Move block scales into TMEM and issue one FP4 GEMM per visible K tile. + tCtQS_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(q_scale_smem_layout, (None, None, None, 0)), + ) + tCtKS_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(k_scale_smem_layout, (None, None, None, 0)), + ) + tCtQS = tmem_pool.allocate_tensor(tCtQS_layout, self.sf_dtype) + tCtKS = tmem_pool.allocate_tensor(tCtKS_layout, self.sf_dtype) + copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), self.sf_dtype) + tCsQS_compact = cute.filter_zeros(sQS) + tCtQS_compact = cute.filter_zeros(tCtQS) + tiled_copy_s2t_qs = tcgen05.make_s2t_copy(copy_atom_s2t, tCtQS_compact) + thr_copy_s2t_qs = tiled_copy_s2t_qs.get_slice(0) + tCsQS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_qs, + thr_copy_s2t_qs.partition_S(tCsQS_compact), + ) + tCtQS_compact_s2t = thr_copy_s2t_qs.partition_D(tCtQS_compact) + tCsKS_compact = cute.filter_zeros(sKS) + tCtKS_compact = cute.filter_zeros(tCtKS) + tiled_copy_s2t_ks = tcgen05.make_s2t_copy(copy_atom_s2t, tCtKS_compact) + thr_copy_s2t_ks = tiled_copy_s2t_ks.get_slice(0) + tCsKS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_ks, + thr_copy_s2t_ks.partition_S(tCsKS_compact), + ) + tCtKS_compact_s2t = thr_copy_s2t_ks.partition_D(tCtKS_compact) + if group_has_visible: + q_full = q_consumer.wait_and_advance() + if const_expr(self.preordered_q_scale_tma): + if q_scale_tma_safe: + qs_full = qs_consumer.wait_and_advance() + qs_full.release() + q_full.release() + cute.copy(tiled_copy_s2t_qs, tCsQS_compact_s2t[(None, None, None, None, 0)], tCtQS_compact_s2t) + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + q_tile_crd = (None, None, None, 0) + if const_expr(self.is_causal): + causal_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + causal_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + causal_group_full = causal_group_full and causal_group_last_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + ktile = Int32(0) + if causal_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_tile_start, + q_tile_last, + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + k_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + ktile = Int32(0) + if k_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < batch_k_tiles: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + acc_producer.tail() + + if warp_idx == self.load_warp_id: + if group_has_visible: + load_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + if const_expr(self.is_causal): + load_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + load_group_full = load_group_full and load_group_last_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + ktile = Int32(0) + if load_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_tile_start, + q_tile_last, + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + k_producer.tail() + q_producer.tail() + if const_expr(self.preordered_q_scale_tma): + if q_scale_tma_safe: + qs_producer.tail() + + if warp_idx < self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + # Load accumulators from TMEM, reduce per-row max, and store scores. + if const_expr(self.use_tmem_load_red): + copy_atom_t2r = cute.make_copy_atom( + tcgen05.LdRed32x32bOp( + tcgen05.Repetition.x128, + tcgen05.Pack.NONE, + tcgen05.TmemLoadRedOp.MAX, + ), + Float32, + ) + else: + copy_atom_t2r = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE), + Float32, + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None, None, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx) + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + tTR_cC = thr_copy_t2r.partition_D(tCcC) + tTR_rAcc = cute.make_rmem_tensor(tTR_cC.shape, Float32) + if const_expr(self.use_tmem_load_red): + tTR_rRed = cute.make_rmem_tensor((1,), Float32) + q_local_store0 = q_tile_start + epi_tidx + q_global_store0 = q_begin + q_local_store0 + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + q_local_store1 = q_tile_start + epi_tidx + Int32(self.epi_threads_per_cta) + q_global_store1 = q_begin + q_local_store1 + if group_has_visible: + visible_tile_count = Int32(0) + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_tile_start, + q_tile_last, + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + epilogue_owns_tile = epi_warpgroup_idx == Int32( + ktile_inner % self.num_epi_warpgroups + ) + if epilogue_owns_tile: + acc_stage_index = visible_tile_count % Int32(self.num_acc_stage) + acc_stage_phase = (visible_tile_count // Int32(self.num_acc_stage)) % Int32(2) + tile_mask_free = self._tile_mask_free(q_tile_start, ktile, causal_offset) + k_tile_full = ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) < k_len + tile_full = q_tile_full and k_tile_full + acc_pipeline.consumer_wait_w_index_phase(acc_stage_index, acc_stage_phase) + tTR_tAcc_stage = tTR_tAcc[(None, None, None, None, acc_stage_index)] + if const_expr(self.use_tmem_load_red): + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, [tTR_rAcc, tTR_rRed]) + else: + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, tTR_rAcc) + row_max0 = -Float32.inf + row_max1 = -Float32.inf + if tile_mask_free: + if tile_full: + if const_expr(not self.use_tmem_load_red or self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, _ = tTR_cC[i] + if coord_m == epi_tidx: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if coord_m == epi_tidx + Int32(self.epi_threads_per_cta): + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + else: + row_max0 = tTR_rRed[0] + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + q_local = q_tile_start + coord_m + k_local = ktile * Int32(_BLOCK_K) + coord_n + if coord_m == epi_tidx and q_local < q_len and k_local < k_len: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if coord_m == epi_tidx + Int32(self.epi_threads_per_cta) and q_local < q_len and k_local < k_len: + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + else: + if tile_full: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + q_local = q_tile_start + coord_m + k_local = ktile * Int32(_BLOCK_K) + coord_n + if self._full_tile_coord_visible( + coord_m, + epi_tidx, + q_local, + k_local, + causal_offset, + ): + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if self._full_tile_coord_visible( + coord_m, + epi_tidx + Int32(self.epi_threads_per_cta), + q_local, + k_local, + causal_offset, + ): + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + q_local = q_tile_start + coord_m + k_local = ktile * Int32(_BLOCK_K) + coord_n + if self._partial_tile_coord_visible( + coord_m, + epi_tidx, + q_local, + k_local, + q_len, + k_len, + causal_offset, + ): + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if self._partial_tile_coord_visible( + coord_m, + epi_tidx + Int32(self.epi_threads_per_cta), + q_local, + k_local, + q_len, + k_len, + causal_offset, + ): + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + if q_tile_full: + mScores[hq, ktile, q_global_store0] = row_max0 + elif q_local_store0 < q_len: + mScores[hq, ktile, q_global_store0] = row_max0 + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if q_tile_full: + mScores[hq, ktile, q_global_store1] = row_max1 + elif q_local_store1 < q_len: + mScores[hq, ktile, q_global_store1] = row_max1 + cute.arch.fence_view_async_tmem_load() + acc_pipeline.consumer_release_w_index(acc_stage_index) + visible_tile_count += Int32(1) + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + if q_tile_full: + mScores[hq, ktile, q_global_store0] = -Float32.inf + elif q_local_store0 < q_len: + mScores[hq, ktile, q_global_store0] = -Float32.inf + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = ktile_group * Int32(self.k_tiles_per_cta) + Int32(ktile_inner) + if ktile < max_k_tiles: + if q_tile_full: + mScores[hq, ktile, q_global_store0] = -Float32.inf + elif q_local_store0 < q_len: + mScores[hq, ktile, q_global_store0] = -Float32.inf + cute.arch.barrier() + tmem.free(tmem_pool.base_ptr) + +class Fp4IndexerDecodeQPackSm100: + """Pack decode Q rows as ``[B * Hk, 128, 64]`` and pack Q scales to MMA storage.""" + + def __init__(self, *, fmt: str): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.sf_dtype = spec.cutlass_scale_dtype + self.scale_groups = spec.scale_groups + self.threads_per_cta = 256 + + @cute.jit + def __call__( + self, + q_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + q_pack_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + total_q, heads_q, heads_k, batch = problem_size + rest_q_m = cute.ceil_div(total_q, 128) + rest_g = ceil_div(self.scale_groups, 4) + q = cute.make_tensor( + q_ptr, + cute.make_layout( + (total_q, heads_q, _FP4_PACKED_D_BYTES), + stride=(heads_q * _FP4_PACKED_D_BYTES, _FP4_PACKED_D_BYTES, 1), + ), + ) + q_scale = cute.make_tensor( + q_scale_ptr, + cute.make_layout( + (heads_q, rest_q_m, rest_g, 32, 4, 4), + stride=(512 * rest_q_m * rest_g, 512 * rest_g, 512, 16, 4, 1), + ), + ) + q_pack_l = batch * heads_k + q_pack = cute.make_tensor( + q_pack_ptr, + cute.make_layout( + (q_pack_l, _PAGE_SIZE, _FP4_PACKED_D_BYTES), + stride=(_PAGE_SIZE * _FP4_PACKED_D_BYTES, _FP4_PACKED_D_BYTES, 1), + ), + ) + q_scale_pack = cute.make_tensor( + q_scale_pack_ptr, + cute.make_layout( + (q_pack_l, 1, rest_g, 32, 4, 4), + stride=(512 * rest_g, 512 * rest_g, 512, 16, 4, 1), + ), + ) + cu_q = cute.make_tensor(cu_seqlens_q_ptr, cute.make_layout((batch + 1,), stride=(1,))) + self.kernel(q, q_scale, q_pack, q_scale_pack, cu_q, heads_q, heads_k).launch( + grid=(q_pack_l, 1, 1), + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mQS: cute.Tensor, + mQPack: cute.Tensor, + mQSPack: cute.Tensor, + mCuQ: cute.Tensor, + heads_q: Int32, + heads_k: Int32, + ): + tidx, _, _ = cute.arch.thread_idx() + q_pack_l, _, _ = cute.arch.block_idx() + batch_idx = q_pack_l // heads_k + hk = q_pack_l - batch_idx * heads_k + q_begin = mCuQ[batch_idx] + q_end = mCuQ[batch_idx + 1] + q_len = q_end - q_begin + qhead_per_kv = heads_q // heads_k + + linear = tidx + while linear < Int32(_PAGE_SIZE * _FP4_PACKED_D_BYTES): + row = linear // Int32(_FP4_PACKED_D_BYTES) + byte = linear - row * Int32(_FP4_PACKED_D_BYTES) + h_in_group = row // Int32(_DECODE_PACK_Q_LEN) + q_local = row - h_in_group * Int32(_DECODE_PACK_Q_LEN) + hq = hk * qhead_per_kv + h_in_group + if q_local < q_len and h_in_group < qhead_per_kv: + mQPack[q_pack_l, row, byte] = mQ[q_begin + q_local, hq, byte] + else: + mQPack[q_pack_l, row, byte] = cutlass.Uint8(0) + linear += Int32(self.threads_per_cta) + + scale_linear = tidx + while scale_linear < Int32(_PAGE_SIZE * self.scale_groups): + row = scale_linear // Int32(self.scale_groups) + group = scale_linear - row * Int32(self.scale_groups) + h_in_group = row // Int32(_DECODE_PACK_Q_LEN) + q_local = row - h_in_group * Int32(_DECODE_PACK_Q_LEN) + hq = hk * qhead_per_kv + h_in_group + q_abs = q_begin + q_local + if q_local >= q_len or h_in_group >= qhead_per_kv: + q_abs = q_begin + hq = hk * qhead_per_kv + src_rest_m = q_abs // Int32(128) + src_row = q_abs - src_rest_m * Int32(128) + src_row_atom = src_row % Int32(32) + src_row_major = src_row // Int32(32) + dst_row_atom = row % Int32(32) + dst_row_major = row // Int32(32) + rest_g = group // Int32(4) + group_in_rest = group - rest_g * Int32(4) + mQSPack[q_pack_l, Int32(0), rest_g, dst_row_atom, dst_row_major, group_in_rest] = mQS[ + hq, src_rest_m, rest_g, src_row_atom, src_row_major, group_in_rest + ] + scale_linear += Int32(self.threads_per_cta) + + +class Fp4IndexerDecodePackedQSm100: + """Decode score kernel with M packed as ``qhead_per_kv * q_len == 128``.""" + + def __init__(self, *, fmt: str, causal: bool, compact_schedule: bool, use_tmem_load_red: bool = False): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.is_causal = bool(causal) + self.compact_schedule = bool(compact_schedule) + self.use_tmem_load_red = bool(use_tmem_load_red) + self.sf_vec_size = spec.sf_vec_size + self.sf_dtype = spec.cutlass_scale_dtype + self.use_nvfp4 = spec.name == "nvfp4" + self.epi_threads_per_cta = 128 + self.epi_warps_per_group = 4 + self.num_epi_warpgroups = 2 + self.mma_warp_id = self.epi_warps_per_group * self.num_epi_warpgroups + self.load_warp_id = self.mma_warp_id + 1 + self.threads_per_cta = 384 + self.num_tmem_alloc_cols = 512 + self.num_q_stage = 1 + self.num_acc_stage = 3 + self.num_ab_stage = 3 + self.k_tiles_per_cta = _DECODE_K_TILES_PER_CTA + self.mma_tiler = (_MMA_TILER_MN[0], _MMA_TILER_MN[1], _MMA_INST_SHAPE_K * 2) + self.cta_tile_shape_mnk = self.mma_tiler + + @cute.jit + def __call__( + self, + q_pack_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + ( + _, + _, + _, + _, + lk, + heads_q, + heads_k, + batch, + max_k_tiles, + total_q, + has_qo_offset, + ) = problem_size + page_count = lk // heads_k + q_pack_l = batch * heads_k + q_tma_tensor = cute.make_tensor( + cute.recast_ptr(q_pack_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (_PAGE_SIZE, _HEAD_DIM, q_pack_l), + stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM), + ), + ) + k_tma_tensor = cute.make_tensor( + cute.recast_ptr(k_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (_PAGE_SIZE, _HEAD_DIM, heads_k, page_count), + stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM, heads_k * _PAGE_SIZE * _HEAD_DIM), + ), + ) + q_scale_tensor = cute.make_tensor( + q_scale_pack_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (_PAGE_SIZE, _HEAD_DIM, q_pack_l), + self.sf_vec_size, + ), + ) + k_scale_tensor = cute.make_tensor( + k_scale_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (_PAGE_SIZE, _HEAD_DIM, page_count * heads_k), + self.sf_vec_size, + ), + ) + scores_tensor = cute.make_tensor( + scores_ptr, + cute.make_layout((heads_q, max_k_tiles, total_q), stride=(max_k_tiles * total_q, total_q, 1)), + ) + kv_indices_tensor = cute.make_tensor(kv_indices_ptr, cute.make_layout((page_count,), stride=(1,))) + cu_layout = cute.make_layout((batch + 1,), stride=(1,)) + cu_q_tensor = cute.make_tensor(cu_seqlens_q_ptr, cu_layout) + cu_k_tensor = cute.make_tensor(cu_seqlens_k_ptr, cu_layout) + cu_page_offsets_tensor = cute.make_tensor(cu_page_offsets_ptr, cu_layout) + qo_offset_tensor = cute.make_tensor(qo_offset_ptr, cute.make_layout((batch,), stride=(1,))) + + if const_expr(self.use_nvfp4): + mma_op = tcgen05.MmaMXF4NVF4Op( + self.sf_dtype, + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + else: + mma_op = tcgen05.MmaMXF4Op( + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + q_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_q_stage) + k_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_ab_stage) + q_scale_smem_layout = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_q_stage, + ) + k_scale_smem_layout = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + cluster_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + q_smem_layout_stage = cute.slice_(q_smem_layout, (None, None, None, 0)) + k_smem_layout_stage = cute.slice_(k_smem_layout, (None, None, None, 0)) + tma_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_tma_tensor, + q_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + tma_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_tma_tensor, + k_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + tma_qs = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_scale_tensor, + q_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + tma_ks = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_scale_tensor, + k_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + grid_k_groups = cute.ceil_div(max_k_tiles, self.k_tiles_per_cta) + compact_k_groups = cute.ceil_div(page_count + batch * (self.k_tiles_per_cta - 1), self.k_tiles_per_cta) + if const_expr(self.compact_schedule): + grid = (compact_k_groups, heads_k, 1) + else: + grid = (grid_k_groups, batch * heads_k, 1) + self.kernel( + tiled_mma, + tma_q, + tma_qs, + tma_k, + tma_ks, + scores_tensor, + kv_indices_tensor, + cu_q_tensor, + cu_k_tensor, + cu_page_offsets_tensor, + qo_offset_tensor, + q_smem_layout, + k_smem_layout, + q_scale_smem_layout, + k_scale_smem_layout, + heads_q, + heads_k, + batch, + has_qo_offset, + max_k_tiles, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.jit + def _group_has_visible( + self, + q_len: Int32, + group_first_ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = q_len > Int32(0) and group_first_ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and group_first_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + return visible + + @cute.jit + def _tile_has_visible( + self, + q_len: Int32, + ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + return visible + + @cute.jit + def _tile_mask_free(self, ktile: Int32, causal_offset: Int32): + if const_expr(self.is_causal): + return ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) <= causal_offset + return True + + @cute.jit + def _packed_coord_visible( + self, + coord_m: Int32, + target_m: Int32, + h_in_group: Int32, + qhead_per_kv: Int32, + q_local: Int32, + q_len: Int32, + k_local: Int32, + k_len: Int32, + causal_offset: Int32, + ): + visible = coord_m == target_m and h_in_group < qhead_per_kv and q_local < q_len and k_local < k_len + if const_expr(self.is_causal): + visible = visible and k_local <= q_local + causal_offset + return visible + + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_q: cpasync.TmaInfo, + tma_qs: cpasync.TmaInfo, + tma_k: cpasync.TmaInfo, + tma_ks: cpasync.TmaInfo, + mScores: cute.Tensor, + mKvIndices: cute.Tensor, + mCuQ: cute.Tensor, + mCuK: cute.Tensor, + mCuPages: cute.Tensor, + mQoOffset: cute.Tensor, + q_smem_layout: cute.ComposedLayout, + k_smem_layout: cute.ComposedLayout, + q_scale_smem_layout: cute.Layout, + k_scale_smem_layout: cute.Layout, + heads_q: Int32, + heads_k: Int32, + batch: Int32, + has_qo_offset: Int32, + max_k_tiles: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + epi_tidx = tidx % Int32(self.epi_threads_per_cta) + epi_warpgroup_idx = warp_idx // Int32(self.epi_warps_per_group) + task_x, task_y, _ = cute.arch.block_idx() + task_valid = True + batch_idx = Int32(0) + hk = Int32(0) + ktile_group = Int32(0) + q_l = Int32(0) + if const_expr(self.compact_schedule): + hk = task_y + group_base = Int32(0) + scan_batch = Int32(0) + task_valid = False + while scan_batch < batch and not task_valid: + batch_pages = mCuPages[scan_batch + Int32(1)] - mCuPages[scan_batch] + batch_groups = (batch_pages + Int32(self.k_tiles_per_cta - 1)) // Int32(self.k_tiles_per_cta) + task_valid = task_x < group_base + batch_groups + if not task_valid: + group_base += batch_groups + scan_batch += Int32(1) + if task_valid: + batch_idx = scan_batch + ktile_group = task_x - group_base + q_l = batch_idx * heads_k + hk + else: + ktile_group = task_x + q_l = task_y + batch_idx = q_l // heads_k + hk = q_l - batch_idx * heads_k + qhead_per_kv = heads_q // heads_k + q_begin = mCuQ[batch_idx] + q_end = mCuQ[batch_idx + 1] + k_begin = mCuK[batch_idx] + k_end = mCuK[batch_idx + 1] + q_len = q_end - q_begin + k_len = k_end - k_begin + if const_expr(self.compact_schedule): + if not task_valid: + q_len = Int32(0) + k_len = Int32(0) + page_begin = mCuPages[batch_idx] + batch_k_tiles = (k_len + Int32(_PAGE_SIZE - 1)) // Int32(_PAGE_SIZE) + causal_offset = Int32(0) + if const_expr(self.is_causal): + causal_offset = k_len - q_len + if has_qo_offset != 0: + causal_offset = mQoOffset[batch_idx] + group_first_ktile = ktile_group * Int32(self.k_tiles_per_cta) + group_has_visible = self._group_has_visible( + q_len, + group_first_ktile, + batch_k_tiles, + causal_offset, + ) + + @cute.struct + class SharedStorage: + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ_public = smem.allocate_tensor(_AB_DTYPE, q_smem_layout.outer, 128, swizzle=q_smem_layout.inner) + sK_public = smem.allocate_tensor(_AB_DTYPE, k_smem_layout.outer, 128, swizzle=k_smem_layout.inner) + sQS_public = smem.allocate_tensor(self.sf_dtype, q_scale_smem_layout, 128) + sKS_public = smem.allocate_tensor(self.sf_dtype, k_scale_smem_layout, 128) + mQ_tma = tma_q.tma_tensor + mQS_tma = tma_qs.tma_tensor + mK_tma = tma_k.tma_tensor + mKS_tma = tma_ks.tma_tensor + thr_mma = tiled_mma.get_slice(0) + tCrQ = tiled_mma.make_fragment_A(sQ_public) + tCrK = tiled_mma.make_fragment_B(sK_public) + tCcC = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler[:2])) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + gQ_tma = cute.local_tile( + mQ_tma, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQ_tma = thr_mma.partition_A(gQ_tma) + tQsQ_tma, tQgQ_tma = cpasync.tma_partition( + tma_q.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQ_public, 0, 3), + cute.group_modes(tCgQ_tma, 0, 3), + ) + gQS_tma = cute.local_tile( + mQS_tma, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQS_tma = thr_mma.partition_A(gQS_tma) + tQsQS_tma, tQgQS_tma = cpasync.tma_partition( + tma_qs.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQS_public, 0, 3), + cute.group_modes(tCgQS_tma, 0, 3), + ) + tQsQS_tma = cute.filter_zeros(tQsQS_tma) + tQgQS_tma = cute.filter_zeros(tQgQS_tma) + gK_tma = cute.local_tile( + mK_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None, None), + ) + tCgK_tma = thr_mma.partition_B(gK_tma) + tKsK_tma, tKgK_tma = cpasync.tma_partition( + tma_k.atom, + 0, + cute.make_layout(1), + cute.group_modes(sK_public, 0, 3), + cute.group_modes(tCgK_tma, 0, 3), + ) + gKS_tma = cute.local_tile( + mKS_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + tCgKS_tma = thr_mma.partition_B(gKS_tma) + tKsKS_tma, tKgKS_tma = cpasync.tma_partition( + tma_ks.atom, + 0, + cute.make_layout(1), + cute.group_modes(sKS_public, 0, 3), + cute.group_modes(tCgKS_tma, 0, 3), + ) + tKsKS_tma = cute.filter_zeros(tKsKS_tma) + tKgKS_tma = cute.filter_zeros(tKgKS_tma) + + tmem = utils.TmemAllocator( + storage.tmem_holding_buf.ptr, + barrier_for_retrieve=pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * (self.mma_warp_id + 1), + ), + ) + acc_pipeline = common_pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, self.epi_threads_per_cta), + defer_sync=True, + ) + acc_producer, _ = acc_pipeline.make_participants() + q_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_q.smem_layout) + qs_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_qs.smem_layout, mode=[0, 1, 2]), + ) + k_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_k.smem_layout) + ks_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_ks.smem_layout, mode=[0, 1, 2]), + ) + q_pair_tma_copy_bytes = q_tma_copy_bytes + qs_tma_copy_bytes + k_pair_tma_copy_bytes = k_tma_copy_bytes + ks_tma_copy_bytes + q_producer, q_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.q_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=q_pair_tma_copy_bytes, + defer_sync=True, + ).make_participants() + k_producer, k_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.k_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=k_pair_tma_copy_bytes, + defer_sync=True, + ).make_participants() + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + + if warp_idx == self.load_warp_id: + if group_has_visible: + q_pair_empty = q_producer.acquire_and_advance() + cute.copy( + tma_q.atom, + tQgQ_tma[(None, 0, 0, q_l)], + tQsQ_tma[(None, q_pair_empty.index)], + tma_bar_ptr=q_pair_empty.barrier, + ) + cute.copy( + tma_qs.atom, + tQgQS_tma[(None, 0, 0, q_l)], + tQsQS_tma[(None, q_pair_empty.index)], + tma_bar_ptr=q_pair_empty.barrier, + ) + q_pair_empty.commit() + load_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + if const_expr(self.is_causal): + load_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + load_group_full = load_group_full and load_group_last_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + if load_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + k_producer.tail() + q_producer.tail() + + if warp_idx == self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + tCtQS_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(q_scale_smem_layout, (None, None, None, 0)), + ) + tCtKS_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(k_scale_smem_layout, (None, None, None, 0)), + ) + tCtQS = tmem_pool.allocate_tensor(tCtQS_layout, self.sf_dtype) + tCtKS = tmem_pool.allocate_tensor(tCtKS_layout, self.sf_dtype) + copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), self.sf_dtype) + tCsQS_compact = cute.filter_zeros(sQS_public) + tCtQS_compact = cute.filter_zeros(tCtQS) + tiled_copy_s2t_qs = tcgen05.make_s2t_copy(copy_atom_s2t, tCtQS_compact) + thr_copy_s2t_qs = tiled_copy_s2t_qs.get_slice(0) + tCsQS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_qs, + thr_copy_s2t_qs.partition_S(tCsQS_compact), + ) + tCtQS_compact_s2t = thr_copy_s2t_qs.partition_D(tCtQS_compact) + tCsKS_compact = cute.filter_zeros(sKS_public) + tCtKS_compact = cute.filter_zeros(tCtKS) + tiled_copy_s2t_ks = tcgen05.make_s2t_copy(copy_atom_s2t, tCtKS_compact) + thr_copy_s2t_ks = tiled_copy_s2t_ks.get_slice(0) + tCsKS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_ks, + thr_copy_s2t_ks.partition_S(tCsKS_compact), + ) + tCtKS_compact_s2t = thr_copy_s2t_ks.partition_D(tCtKS_compact) + if group_has_visible: + q_pair_full = q_consumer.wait_and_advance() + q_pair_full.release() + cute.copy(tiled_copy_s2t_qs, tCsQS_compact_s2t[(None, None, None, None, 0)], tCtQS_compact_s2t) + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + q_tile_crd = (None, None, None, 0) + if const_expr(self.is_causal): + causal_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + causal_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + causal_group_full = causal_group_full and causal_group_last_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + if causal_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + k_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + if k_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < batch_k_tiles: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + acc_producer.tail() + + if warp_idx < self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + if const_expr(self.use_tmem_load_red): + copy_atom_t2r = cute.make_copy_atom( + tcgen05.LdRed32x32bOp( + tcgen05.Repetition.x128, + tcgen05.Pack.NONE, + tcgen05.TmemLoadRedOp.MAX, + ), + Float32, + ) + else: + copy_atom_t2r = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE), + Float32, + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None, None, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx) + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + tTR_cC = thr_copy_t2r.partition_D(tCcC) + tTR_rAcc = cute.make_rmem_tensor(tTR_cC.shape, Float32) + if const_expr(self.use_tmem_load_red): + tTR_rRed = cute.make_rmem_tensor((1,), Float32) + h_store = epi_tidx // Int32(_DECODE_PACK_Q_LEN) + q_local_store = epi_tidx - h_store * Int32(_DECODE_PACK_Q_LEN) + h_global_store = hk * qhead_per_kv + h_store + q_global_store = q_begin + q_local_store + if group_has_visible: + visible_tile_count = Int32(0) + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + epilogue_owns_tile = epi_warpgroup_idx == Int32( + ktile_inner % self.num_epi_warpgroups + ) + if epilogue_owns_tile: + acc_stage_index = visible_tile_count % Int32(self.num_acc_stage) + acc_stage_phase = (visible_tile_count // Int32(self.num_acc_stage)) % Int32(2) + tile_mask_free = self._tile_mask_free(ktile, causal_offset) + k_tile_full = ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) < k_len + q_pack_full = q_len == Int32(_DECODE_PACK_Q_LEN) + tile_full = q_pack_full and k_tile_full + acc_pipeline.consumer_wait_w_index_phase(acc_stage_index, acc_stage_phase) + tTR_tAcc_stage = tTR_tAcc[(None, None, None, None, acc_stage_index)] + if const_expr(self.use_tmem_load_red): + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, [tTR_rAcc, tTR_rRed]) + else: + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, tTR_rAcc) + row_max0 = -Float32.inf + if tile_mask_free and tile_full: + if const_expr(self.use_tmem_load_red): + row_max0 = tTR_rRed[0] + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, _ = tTR_cC[i] + if coord_m == epi_tidx: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + h_in_group = coord_m // Int32(_DECODE_PACK_Q_LEN) + q_local = coord_m - h_in_group * Int32(_DECODE_PACK_Q_LEN) + k_local = ktile * Int32(_BLOCK_K) + coord_n + valid = self._packed_coord_visible( + coord_m, + epi_tidx, + h_in_group, + qhead_per_kv, + q_local, + q_len, + k_local, + k_len, + causal_offset, + ) + if valid: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if h_store < qhead_per_kv and q_local_store < q_len: + mScores[h_global_store, ktile, q_global_store] = row_max0 + cute.arch.fence_view_async_tmem_load() + acc_pipeline.consumer_release_w_index(acc_stage_index) + visible_tile_count += Int32(1) + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + if h_store < qhead_per_kv and q_local_store < q_len: + mScores[h_global_store, ktile, q_global_store] = -Float32.inf + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + if h_store < qhead_per_kv and q_local_store < q_len: + mScores[h_global_store, ktile, q_global_store] = -Float32.inf + cute.arch.barrier() + tmem.free(tmem_pool.base_ptr) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18b99aea3f8b4915c03fe8147127374d920970f3 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 forward kernels and combine paths.""" + +from .atten_fwd_nvfp4_kv import SparseAttentionForwardNvfp4KvSm100 + +__all__ = ["SparseAttentionForwardNvfp4KvSm100"] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..531b27c9e6b4bd8c1bc74fb1f92ed98a192ca0b2 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd.py @@ -0,0 +1,3020 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 sparse attention forward kernel. + +This kernel implements the delivered Sparse Attention / Sparse Page Attention +forward contract: +- CSR sparse metadata (`k2q_row_ptr`, `k2q_q_indices`) +- varlen Q metadata via `cu_seqlens_q` +- Sparse Attention with flat varlen K/V +- Sparse Page Attention with paged K/V +""" + +import math +from functools import partial +from typing import Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.pipeline as cutlass_pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.cutlass_dsl import BaseDSL + +from ....quack import copy_utils + +from ....src.common.cute_dsl_utils import assume_tensor_aligned +from ....src.common import utils +from ....src.common import pipeline +from ....src.common import mma_sm100_desc as sm100_desc +from ....src.common import blackwell_helpers as sm100_helpers +from ....src.common.softmax import SoftmaxSm100 +from ....src.common.named_barrier import NamedBarrierFwdSm100 +from ....src.common.mask import AttentionMask +from ....src.common.seqlen_info import SeqlenInfoQK +# Shared raw PTX helpers and layout conversions used by the lean kernel. +from ....src.common.paged_kv import PagedKVManager +from ....src.common.tma_utils import ( + tma_gather4_cached, + tma_gather4_prefetch, + prefetch_tma_desc_raw, + TMA_CACHE_EVICT_LAST, + make_16x256b_tensor_mn_view, + real_col_to_stg128_fake_col, + real_col_to_stg128_fp8_fake_col, + real_col_to_stg128_half_fake_col, + stg_128_cs, + stg_128_bf16_cs, + stg_128_f16_cs, + stg_128_fp8_e4m3_cs, +) + + +class SparseAttentionForwardSm100: + """SM100 sparse attention forward kernel.""" + + k_tile = 64 # UTCMMA bf16 K-tile (matches sparse_fwd_utcmma.py) + + def __init__( + self, + head_dim: int = 128, + qheadperkv: int = 16, + m_block_size: int = 128, + n_block_size: int = 128, + paged_kv: bool = False, + page_size: Optional[int] = None, + has_seqused_k: bool = False, + causal: bool = False, + use_prepare_scheduler: bool = True, + qk_dtype=None, + pv_dtype=None, + ): + if head_dim != 128: + raise NotImplementedError( + f"SparseAttentionForwardSm100 currently supports only D=128, got D={head_dim}" + ) + self.head_dim = 128 + self.qheadperkv = qheadperkv + self.use_q_gather4 = qheadperkv in (4, 2, 1) + if qheadperkv not in (16, 8, 4, 2, 1): + raise ValueError( + "SparseAttentionForwardSm100 supports qheadperkv in " + f"{{1, 2, 4, 8, 16}}, got {qheadperkv}" + ) + self.tokens_per_gather4 = 4 // qheadperkv if self.use_q_gather4 else 0 + self.m_block_size = m_block_size # 128 packed Q heads + self.n_block_size = n_block_size # 128 KV-block width + self.paged_kv = paged_kv + self.page_size = page_size + self.has_seqused_k = has_seqused_k + self.causal = causal + self.qk_dtype_param = qk_dtype + self.pv_dtype_param = pv_dtype + if not use_prepare_scheduler: + raise ValueError("SparseAttentionForwardSm100 requires prepare scheduler") + self.use_prepare_scheduler = True + if self.paged_kv: + if page_size is None: + raise ValueError("page_size must be provided when paged_kv=True") + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal blk_kv ({n_block_size})" + ) + else: + self.page_size = n_block_size + self.q_tokens_per_group = m_block_size // qheadperkv # 8 + + self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim) + self.mma_tiler_pv = (m_block_size, self.head_dim, n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + + # Pipeline configuration — deeper Q prefetch ring plus 2-slot S/O rings. + self.q_stage = 2 + self.s_stage = 2 + self.o_stage = 2 + self.kv_stage = 1 + # Sparse q_idx metadata ring bridging load -> epilogue. Sized larger + # than the in-flight group distance so epilogue can reuse q_idx + # without rereading mK2qIndices. + self.qidx_meta_stages = 16 + + self.k_stages = 2 + self.q_stage_stride_bytes = m_block_size * self.head_dim * 2 + self.k_tile_stride_bytes = m_block_size * self.k_tile * 2 + self.token_stride_bytes = qheadperkv * self.k_tile * 2 + + # Warp layout: two softmax WGs, one Q-load/epilogue WG, one + # MMA issue warp, two K/V load warps, and one empty warp. + self.warps_per_group = 4 + self.softmax0_warp_base = 0 + self.softmax1_warp_base = self.softmax0_warp_base + self.warps_per_group + self.store_warp_base = self.softmax1_warp_base + self.warps_per_group + self.mma_warp_id = self.store_warp_base + self.warps_per_group + self.load_warp_base = self.mma_warp_id + 1 + self.q_load_warp_base = self.store_warp_base + self.kv_load_warp_base = self.load_warp_base + self.num_kv_load_warps = 2 + self.num_q_load_warps = self.warps_per_group + self.total_warps = 16 + self.threads_per_cta = cute.arch.WARP_SIZE * self.total_warps # 512 + + # TMEM layout follows FA SM100: + # S0/S1: [0:128], [128:256] + # O0/O1: [256:384], [384:512] for hdim_v=128 + # P dtype follows the PV operand policy and is packed into each S tile. + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + self.tmem_s_offset = 0 + self.tmem_stage_stride = n_block_size + self.tmem_o_stage_stride = self.head_dim + self.tmem_o_offset = self.s_stage * n_block_size + self.tmem_s_to_p_offset = n_block_size // 2 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + raw_tmem_total = self.tmem_o_offset + self.o_stage * self.tmem_o_stage_stride + # SM100 TMEM allocation requires a power-of-two column count. The + # 128-wide path naturally uses 512 columns; 64-wide KV blocks use 384 + # columns and must round the allocation up while keeping the same + # logical offsets. + self.tmem_total = 1 << (raw_tmem_total - 1).bit_length() + + # Let PV start once the first 3/4 of P is visible in TMEM. The final + # split is synchronized by a separate mbarrier consumed inside PV MMA. + self.split_P_arrive = n_block_size // 4 * 3 + self.split_P_arrive = int(self.split_P_arrive / 32) * 32 + assert self.split_P_arrive % 32 == 0 + assert self.split_P_arrive < self.n_block_size + + # Register allocation per role. The causal hdim128 split gives the + # epilogue enough room for partial-O/LSE address generation while the + # two softmax WGs still have enough registers to avoid S/P spills. + self.num_regs_softmax = 176 if causal else 192 + self.num_regs_store = 112 if causal else 80 + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_store + self.num_regs_mma = self.num_regs_other + self.num_regs_load = self.num_regs_other + self.num_regs_empty = self.num_regs_other + self.store_reg_decrease = self.num_regs_store <= 128 + self.ex2_emu_freq = 16 if causal else 0 + self.ex2_emu_start_frg = 1 + self.buffer_align_bytes = 1024 + + # SM100 config. + self.use_2cta_instrs = False + self.cta_group_size = 1 + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnk = (1, 1, 1) + + self.arch = BaseDSL._get_dsl().get_arch_enum() + + @cute.jit + def _batch_q_offset( + self, + batch_idx: Int32, + mCuSeqlensQ, + ) -> Int32: + return mCuSeqlensQ[batch_idx] + + @cute.jit + def _logical_seqlen_k( + self, + batch_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + if const_expr(self.has_seqused_k): + return mSeqUsedK[batch_idx] + if const_expr(self.paged_kv): + return Int32(mPageTable.shape[1]) * Int32(self.page_size) + return mCuSeqlensK[batch_idx + Int32(1)] - mCuSeqlensK[batch_idx] + + @cute.jit + def _valid_cols_in_block( + self, + batch_idx: Int32, + kv_block_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + seqlen_k = self._logical_seqlen_k( + batch_idx, mPageTable, mSeqUsedK, mCuSeqlensK + ) + block_start = kv_block_idx * Int32(self.n_block_size) + remaining = seqlen_k - block_start + remaining = cutlass.max(remaining, Int32(0)) + return cutlass.min(remaining, Int32(self.n_block_size)) + + @cute.jit + def _load_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _load_qsplit_idx( + self, + mK2qQSplitIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qQSplitIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _decode_q_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return qsplit & Int32(0x00FF_FFFF) + + @cute.jit + def _decode_split_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return (qsplit >> Int32(24)) & Int32(0xFF) + + @cute.jit + def _lower_bound_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + count: Int32, + q_value: Int32, + ) -> Int32: + left = Int32(0) + right = count + # k2q_q_indices is sorted by q_idx within each CSR row. A fixed + # 32-step loop covers int32-sized rows and keeps this CTA-level. + for _ in cutlass.range(32, unroll=1): + if left < right: + mid = (left + right) // Int32(2) + q_idx = self._load_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + mid, + ) + if q_idx < q_value: + left = mid + Int32(1) + else: + right = mid + return left + + # ------------------------------------------------------------------ + # Host-side: TMA descriptors, SMEM layout, launch + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + mK: cute.Tensor, # Sparse Attention: [total_k, head_kv, dim] / Sparse Page Attention: prepared paged KV tensor + mV: cute.Tensor, # Sparse Attention: [total_k, head_kv, dim] / Sparse Page Attention: prepared paged KV tensor + mK2qIndices: cute.Tensor, # csr payload: [head_kv, nnz] + mK2qQSplitIndices: cute.Tensor, # csr payload: [head_kv, nnz] packed q_idx/split slot + mK2qCounts: cute.Tensor, # csr row_ptr: [head_kv, total_rows + 1] + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, # fp32 O_partial buffer (kept alive) + mLSE_partial: cute.Tensor, # fp32 LSE_partial + mLSE_temperature_partial: Optional[cute.Tensor], # fp32 temperature-scaled LSE_partial + mQ_flat: cute.Tensor, # [batch*Sq*head_q, dim] bf16, pre-flattened + mQ_gather4_desc: Optional[cute.Tensor], # [128] uint8 tensor map for gather4 Q load + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + softmax_scale: Float32, + lse_temperature_scale: Float32, + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + stream=None, + ): + self.q_dtype = mQ_flat.element_type + self.k_input_dtype = mK.element_type + self.v_input_dtype = mV.element_type + self.qk_dtype = ( + self.q_dtype if const_expr(self.qk_dtype_param is None) else self.qk_dtype_param + ) + if const_expr(self.pv_dtype_param is None): + legacy_fp8_kv_cache = ( + self.q_dtype == cutlass.BFloat16 + and self.k_input_dtype == cutlass.Float8E4M3FN + and self.v_input_dtype == cutlass.Float8E4M3FN + ) + self.pv_dtype = cutlass.BFloat16 if legacy_fp8_kv_cache else self.v_input_dtype + else: + self.pv_dtype = self.pv_dtype_param + self.k_dtype = self.qk_dtype + self.v_dtype = self.pv_dtype + self.p_dtype = self.pv_dtype + if const_expr(self.q_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"Unsupported Q/K/V dtype: {self.q_dtype}") + if const_expr(self.qk_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"Unsupported qk_dtype: {self.qk_dtype}") + if const_expr(self.pv_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"Unsupported pv_dtype: {self.pv_dtype}") + if const_expr(self.q_dtype != self.qk_dtype): + raise TypeError("Q storage dtype must match qk_dtype") + if const_expr( + self.k_input_dtype != self.k_dtype + and not (self.k_input_dtype == cutlass.Float8E4M3FN and self.k_dtype == cutlass.BFloat16) + ): + raise TypeError("Only FP8 K -> BF16 QK staging is supported") + if const_expr( + self.v_input_dtype != self.v_dtype + and not (self.v_input_dtype == cutlass.Float8E4M3FN and self.v_dtype == cutlass.BFloat16) + ): + raise TypeError("Only FP8 V -> BF16 PV staging is supported") + self.k_fp8_to_bf16 = ( + self.k_input_dtype == cutlass.Float8E4M3FN + and self.k_dtype == cutlass.BFloat16 + ) + self.v_fp8_to_bf16 = ( + self.v_input_dtype == cutlass.Float8E4M3FN + and self.v_dtype == cutlass.BFloat16 + ) + self.kv_fp8_to_bf16 = self.k_fp8_to_bf16 or self.v_fp8_to_bf16 + self.qk_mma_kind = "f8f6f4" if const_expr(self.qk_dtype.width == 8) else "f16" + self.pv_mma_kind = "f8f6f4" if const_expr(self.pv_dtype.width == 8) else "f16" + elem_bytes = const_expr(self.q_dtype.width // 8) + self.q_stage_stride_bytes = self.m_block_size * self.head_dim * elem_bytes + self.k_tile_stride_bytes = self.m_block_size * self.k_tile * elem_bytes + self.token_stride_bytes = self.qheadperkv * self.k_tile * elem_bytes + p_cols_as_fp32 = const_expr(self.n_block_size * self.p_dtype.width // Float32.width) + self.tmem_s_to_p_offset = self.n_block_size - p_cols_as_fp32 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + self.o_dtype = mO_partial.element_type + if const_expr( + self.o_dtype + not in [Float32, cutlass.BFloat16, cutlass.Float16, cutlass.Float8E4M3FN] + ): + raise TypeError(f"Unsupported O_partial dtype: {self.o_dtype}") + mK, mV = [assume_tensor_aligned(t) for t in (mK, mV)] + + if const_expr(not self.paged_kv): + # Flat varlen K/V use CUTE-managed TMA descriptors, matching FA: + # K: [total_k, h, d] -> [total_k, d, h]. + # V: [total_k, h, d] -> [d, total_k, h] for MN-major PV. + layout_t = [0, 2, 1] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2])) + else: + # Sparse Page Attention with page-sized blocks can use the blocked + # paged TMA layout directly. Host input is [page, head, token, dim]. + layout_t = [2, 3, 1, 0] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + # V: (s,d,h,b) -> (d,s,h,b) for MN-major + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2, 3])) + + # ------------------------------------------------------------------ + # UTCMMA TiledMma: QK^T and PV + # ------------------------------------------------------------------ + cta_group = tcgen05.CtaGroup.ONE + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, + Float32, cta_group, self.mma_tiler_qk[:2]) + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN, + Float32, cta_group, self.mma_tiler_pv[:2], tcgen05.OperandSource.TMEM) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM layouts: sQ/sK/sV only. O_partial is written directly from + # registers to GMEM in the epilogue. + # ------------------------------------------------------------------ + total_q_stages = self.q_stage + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, total_q_stages) + q_load_tile = ( + self.head_dim + if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_tile + ) + q_load_subtiles_per_token = const_expr(self.head_dim // q_load_tile) + num_subtiles_total = ( + total_q_stages * self.q_tokens_per_group * q_load_subtiles_per_token + ) + sQ_load_layout = sm100_utils.make_smem_layout( + tcgen05.OperandMajorMode.K, + (self.qheadperkv, q_load_tile), self.q_dtype, num_subtiles_total) + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage) + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage) + sK_fp8_layout = cute.append( + cute.make_layout( + (self.n_block_size, self.head_dim), + stride=(self.head_dim, 1), + ), + cute.make_layout((1,)), + ) + sV_fp8_layout = cute.append( + cute.make_layout( + (self.head_dim, self.n_block_size), + stride=(1, self.head_dim), + ), + cute.make_layout((1,)), + ) + # P SMEM layout metadata (no actual SMEM allocation — P lives in TMEM, + # overlaying the S region; this layout is only used to compute the PV + # A-operand TMEM descriptor shape at the MMA issue site.) + tP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, self.p_dtype, self.s_stage) + + # ------------------------------------------------------------------ + # TMA atoms + # ------------------------------------------------------------------ + k_tma_layout = ( + cute.select(sK_fp8_layout, mode=[0, 1]) + if const_expr(self.k_fp8_to_bf16) + else cute.select(sK_layout, mode=[0, 1, 2]) + ) + v_tma_layout = ( + cute.select(sV_fp8_layout, mode=[0, 1]) + if const_expr(self.v_fp8_to_bf16) + else cute.select(sV_layout, mode=[0, 1, 2]) + ) + kv_tma_bytes = ( + cute.size_in_bytes(self.k_input_dtype, k_tma_layout) + + cute.size_in_bytes(self.v_input_dtype, v_tma_layout)) + q_tma_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + if const_expr(self.k_fp8_to_bf16): + tma_atom_K, mK = cpasync.make_tiled_tma_atom( + tma_load_op, + mK, + cute.select(sK_fp8_layout, mode=[0, 1]), + (self.n_block_size, self.head_dim), + ) + else: + tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mK, cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, tiled_mma_qk, cta_layout_vmnk.shape) + if const_expr(self.v_fp8_to_bf16): + tma_atom_V, mV = cpasync.make_tiled_tma_atom( + tma_load_op, + mV, + cute.select(sV_fp8_layout, mode=[0, 1]), + (self.head_dim, self.n_block_size), + ) + else: + tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mV, cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_pv, tiled_mma_pv, cta_layout_vmnk.shape) + + # Q per-sub-tile TMA atom: bf16 uses two 64-element halves; fp8 uses + # one 128-element row because 128 fp8 elements occupy the same 128B + # swizzle span as 64 bf16 elements. + mQ_flat = assume_tensor_aligned(mQ_flat) + mQ_2d = cute.make_tensor( + mQ_flat.iterator, cute.select(mQ_flat.layout, mode=[0, 1])) + if const_expr(self.use_q_gather4): + # Placeholder atom for unified kernel signature. Small-GQA Q load + # uses raw gather4 and keeps mQ_2d as a plain row-major GMEM tensor. + tma_atom_Q = tma_atom_V + else: + tma_atom_Q, mQ_2d = cpasync.make_tiled_tma_atom( + tma_load_op, mQ_2d, + cute.select(sQ_load_layout, mode=[0, 1]), + (self.qheadperkv, q_load_tile)) + q_subtile_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_load_layout, mode=[0, 1])) + + softmax_scale_log2 = softmax_scale * Float32(math.log2(math.e)) + lse_temperature_scale_log2 = softmax_scale_log2 * lse_temperature_scale + + # ------------------------------------------------------------------ + # SharedStorage — lean: just the mbars and tiles we actually use. + # + # Mbarriers (all storage rings stay below the 64-per-CTA limit): + # mbar_kv [2] one-shot K/V load handshake (full + empty) + # mbar_q [q_stage * 2] Q producer/consumer ring + # mbar_s [2] QK UTCMMA -> softmax (full + empty) + # mbar_o [2] PV UTCMMA -> epilogue (full + empty) + # mbar_p [s_stage * 2] softmax early-P arrive -> PV + # mbar_p_lastsplit [s_stage * 2] softmax final-P arrive -> PV + # (used only when ``self.split_P_arrive > 0``) + # mbar_sm_stats [s_stage * 2] softmax row_sum/row_max + # publish -> epilogue consumer read. In lean + # 1-WG topology the producer and consumer are + # the same 128 WG_C threads, but we keep the + # barrier for structural parity with FA so the + # softmax body reads identically. + # ------------------------------------------------------------------ + @cute.struct + class SharedStorage: + mbar_k: cute.struct.MemRange[Int64, 2] + mbar_v: cute.struct.MemRange[Int64, 2] + if const_expr(self.k_fp8_to_bf16): + mbar_k_tma: cute.struct.MemRange[Int64, 2] + if const_expr(self.v_fp8_to_bf16): + mbar_v_tma: cute.struct.MemRange[Int64, 2] + mbar_q: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_s: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p_lastsplit: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_o: cute.struct.MemRange[Int64, self.o_stage * 2] + mbar_sm_stats: cute.struct.MemRange[Int64, self.o_stage * 2] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: Int32 + # Per-row softmax stats cache (for epilogue LSE + rescale): + # [0 : m_block_size) row_sum + # [m_block_size : 2*m_block_size) row_max + sScale: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size * 2] + # Per-row temperature LSE row_sum cache. The row_max is shared with + # sScale because lse_temperature_scale is positive. + sScaleTemperature: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size] + # Per-token split_id from prepare-time per-edge metadata. + sSplitIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Per-token q_idx cache to avoid reloading sparse indices in epilogue. + sQIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Prefix length of q_idx-sorted row entries that may need causal + # masking for this KV block. This is CTA-level metadata, not a + # token-count cap. + sDiagQCount: cute.struct.MemRange[Int32, 1] + # CTA-wide row metadata, published once by tidx 0 and reused by + # all warp-specialized roles: + # [0] batch_idx + # [1] kv_block_idx + # [2] row_start + # [3] count_raw + # [4] kv_valid_cols + # [5] q_batch_offset + # [6] k_batch_offset + # [7] causal_q_offset = seqlen_k - seqlen_q + sRowMeta: cute.struct.MemRange[Int32, 8] + sPagedKvIdx: cute.struct.MemRange[Int32, 1] + sQLoadMIdx: cute.struct.MemRange[ + Int32, self.q_stage * self.q_tokens_per_group] + # Packed per-edge q/split metadata: + # low 24 bits = q_idx, high 8 bits = split slot. + sQIdxMeta: cute.struct.MemRange[ + Int32, self.qidx_meta_stages * self.q_tokens_per_group] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes] + if const_expr(self.k_fp8_to_bf16): + sKFp8: cute.struct.Align[ + cute.struct.MemRange[ + self.k_input_dtype, cute.cosize(sK_fp8_layout) + ], + self.buffer_align_bytes] + if const_expr(self.v_fp8_to_bf16): + sVFp8: cute.struct.Align[ + cute.struct.MemRange[ + self.v_input_dtype, cute.cosize(sV_fp8_layout) + ], + self.buffer_align_bytes] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes] + + self.shared_storage = SharedStorage + num_ctas = work_capacity + + self.kernel( + mK, mV, mK2qIndices, mK2qQSplitIndices, mK2qCounts, + mSchedulerMetadata, mWorkCount, + mO_partial, mLSE_partial, mLSE_temperature_partial, mQ_2d, mQ_gather4_desc, + mPageTable, mSeqUsedK, mCuSeqlensQ, mCuSeqlensK, + softmax_scale_log2, lse_temperature_scale_log2, lse_temperature_scale, + sQ_layout, sQ_load_layout, sK_layout, sV_layout, + sK_fp8_layout, sV_fp8_layout, tP_layout, + tma_atom_K, tma_atom_V, tma_atom_Q, + tiled_mma_qk, tiled_mma_pv, + kv_tma_bytes, q_tma_bytes, q_subtile_bytes, + num_kv_blocks, num_heads_kv, seq_len_q, work_capacity, + ).launch( + grid=(num_ctas,), + block=[self.threads_per_cta, 1, 1], + smem=max(SharedStorage.size_in_bytes(), 49152), + stream=stream, + min_blocks_per_mp=1, + ) + + # ------------------------------------------------------------------ + # Device-side: kernel entry, dispatch by warpgroup + # ------------------------------------------------------------------ + + @cute.kernel + def kernel( + self, + # Runtime tensors + tma_K: cute.Tensor, + tma_V: cute.Tensor, + mK2qIndices: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mK2qCounts: cute.Tensor, + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + mQ_2d: cute.Tensor, + mQ_gather4_desc: Optional[cute.Tensor], + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + # Scalars + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + # Layouts + sQ_layout: cute.ComposedLayout, + sQ_load_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sK_fp8_layout: cute.Layout, + sV_fp8_layout: cute.Layout, + tP_layout: cute.ComposedLayout, + # TMA atoms + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_Q: cute.CopyAtom, + # MMA + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + # Transfer sizes + kv_tma_bytes: cutlass.Constexpr[int], + q_tma_bytes: cutlass.Constexpr[int], + q_subtile_bytes: cutlass.Constexpr[int], + # Iteration bounds + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + ): + # ------------------------------------------------------------------ + # Thread / warp identity, CTA coordinate + # ------------------------------------------------------------------ + bidx, _, _ = cute.arch.block_idx() + row_linear = Int32(0) + head_kv_idx = Int32(0) + batch_idx = Int32(0) + kv_block_idx = Int32(0) + work_q_begin = Int32(0) + work_q_count = Int32(0) + cta_valid_work = True + work_idx = bidx + cta_valid_work = work_idx < mWorkCount[Int32(0)] + if cta_valid_work: + head_kv_idx = mSchedulerMetadata[work_idx, Int32(0)] + row_linear = mSchedulerMetadata[work_idx, Int32(1)] + work_q_begin = mSchedulerMetadata[work_idx, Int32(2)] + work_q_count = mSchedulerMetadata[work_idx, Int32(3)] + batch_idx = mSchedulerMetadata[work_idx, Int32(4)] + kv_block_idx = mSchedulerMetadata[work_idx, Int32(5)] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + head_q = num_heads_kv * Int32(self.qheadperkv) + paged_kv_manager = ( + PagedKVManager.create( + mPageTable, + page_size=self.page_size, + n_block_size=self.n_block_size, + ) + if const_expr(self.paged_kv) + else None + ) + + # Prefetch TMA descriptors (warp 0 once). + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + if const_expr(not self.use_q_gather4): + cpasync.prefetch_descriptor(tma_atom_Q) + else: + with cute.arch.elect_one(): + prefetch_tma_desc_raw(mQ_gather4_desc.iterator) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM allocation (all warps — same SharedStorage type from __call__) + # ------------------------------------------------------------------ + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + if const_expr(self.k_fp8_to_bf16): + sKFp8 = storage.sKFp8.get_tensor(sK_fp8_layout) + mbar_k_tma_ptr = storage.mbar_k_tma.data_ptr() + if const_expr(self.v_fp8_to_bf16): + sVFp8 = storage.sVFp8.get_tensor(sV_fp8_layout) + mbar_v_tma_ptr = storage.mbar_v_tma.data_ptr() + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sQ_load = storage.sQ.get_tensor(sQ_load_layout.outer, swizzle=sQ_load_layout.inner) + sScale = storage.sScale.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size * 2)) + sScaleTemperature = storage.sScaleTemperature.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size)) + sSplitIdx = storage.sSplitIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sQIdx = storage.sQIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sDiagQCount = storage.sDiagQCount.get_tensor(cute.make_layout((1,))) + sRowMeta = storage.sRowMeta.get_tensor(cute.make_layout((8,))) + sPagedKvIdx = storage.sPagedKvIdx.get_tensor(cute.make_layout((1,))) + sQLoadMIdx = storage.sQLoadMIdx.get_tensor( + cute.make_layout((self.q_stage * self.q_tokens_per_group,))) + sQIdxMeta = storage.sQIdxMeta.get_tensor( + cute.make_layout((self.qidx_meta_stages * self.q_tokens_per_group,))) + mbar_k_ptr = storage.mbar_k.data_ptr() + mbar_v_ptr = storage.mbar_v.data_ptr() + + # ------------------------------------------------------------------ + # TMEM allocator — allocator warp 0 serves the whole CTA. + # ------------------------------------------------------------------ + tmem_alloc_warps: cutlass.Constexpr[int] = self.warps_per_group * 2 + 1 + tmem_alloc_threads = cute.arch.WARP_SIZE * tmem_alloc_warps + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), + num_threads=tmem_alloc_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # ------------------------------------------------------------------ + # Warp-specialized pipelines. + # ------------------------------------------------------------------ + ThreadCooperativeGroup = partial( + cutlass_pipeline.CooperativeGroup, cutlass_pipeline.Agent.Thread) + tma_thread = ThreadCooperativeGroup(1) + mma_thread = ThreadCooperativeGroup(1) + softmax_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group) + epilogue_threads = softmax_threads + + pipeline_q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=q_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_s = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_s.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=softmax_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p_lastsplit = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p_lastsplit.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_o = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_o.data_ptr(), + num_stages=self.o_stage, + producer_group=mma_thread, + consumer_group=epilogue_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_sm_stats = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_sm_stats.data_ptr(), + num_stages=self.o_stage, + producer_group=softmax_threads, + consumer_group=epilogue_threads, + defer_sync=True, + ) + # Cluster sync (no-op for 1CTA cluster). + pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + # ------------------------------------------------------------------ + # Work count: how many Q tokens reference this CTA's KV block + # ------------------------------------------------------------------ + k_tma_bytes = cute.size_in_bytes( + self.k_input_dtype, + cute.select(sK_fp8_layout, mode=[0, 1]) + if const_expr(self.k_fp8_to_bf16) + else cute.select(sK_layout, mode=[0, 1, 2])) + v_tma_bytes = cute.size_in_bytes( + self.v_input_dtype, + cute.select(sV_fp8_layout, mode=[0, 1]) + if const_expr(self.v_fp8_to_bf16) + else cute.select(sV_layout, mode=[0, 1, 2])) + if tidx == 0: + row_batch_idx = batch_idx + row_kv_block_idx = kv_block_idx + base_row_start = mK2qCounts[head_kv_idx, row_linear] + row_start = base_row_start + count_raw = ( + mK2qCounts[head_kv_idx, row_linear + Int32(1)] + - base_row_start + ) + row_start = base_row_start + work_q_begin + count_raw = work_q_count + kv_valid_cols = ( + self._valid_cols_in_block( + row_batch_idx, + row_kv_block_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + ) + q_batch_offset = self._batch_q_offset( + row_batch_idx, mCuSeqlensQ + ) + k_batch_offset = ( + Int32(0) + if const_expr(self.paged_kv) + else mCuSeqlensK[row_batch_idx] + ) + sRowMeta[0] = row_batch_idx + sRowMeta[1] = row_kv_block_idx + sRowMeta[2] = row_start + sRowMeta[3] = count_raw + sRowMeta[4] = kv_valid_cols + sRowMeta[5] = q_batch_offset + sRowMeta[6] = k_batch_offset + causal_q_offset = Int32(0) + if const_expr(self.causal): + seqlen_q = mCuSeqlensQ[row_batch_idx + Int32(1)] - q_batch_offset + seqlen_k = self._logical_seqlen_k( + row_batch_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + causal_q_offset = seqlen_k - seqlen_q + sRowMeta[7] = causal_q_offset + if const_expr(self.paged_kv): + sPagedKvIdx[0] = paged_kv_manager.physical_block_index( + row_batch_idx, row_kv_block_idx + ) + cute.arch.mbarrier_init(mbar_k_ptr, 1) + cute.arch.mbarrier_init(mbar_v_ptr, 1) + if const_expr(self.k_fp8_to_bf16): + cute.arch.mbarrier_init(mbar_k_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_k_tma_ptr, k_tma_bytes) + else: + cute.arch.mbarrier_expect_tx(mbar_k_ptr, k_tma_bytes) + if const_expr(self.v_fp8_to_bf16): + cute.arch.mbarrier_init(mbar_v_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_v_tma_ptr, v_tma_bytes) + else: + cute.arch.mbarrier_expect_tx(mbar_v_ptr, v_tma_bytes) + diag_q_count = Int32(0) + if const_expr(self.causal): + row_has_visible_cols = (count_raw > Int32(0)) & (kv_valid_cols > Int32(0)) + if row_has_visible_cols: + kv_valid_end = ( + row_kv_block_idx * Int32(self.n_block_size) + + kv_valid_cols + ) + q_threshold = kv_valid_end - causal_q_offset + diag_q_count = self._lower_bound_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + count_raw, + q_threshold, + ) + sDiagQCount[0] = diag_q_count + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + thr_mma_qk = tiled_mma_qk.get_slice(0) + thr_mma_pv = tiled_mma_pv.get_slice(0) + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_base = thr_mma_qk.make_fragment_C(qk_acc_shape) + tStS = cute.make_tensor( + tStS_base.iterator, + cute.append( + tStS_base.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + tP_width_ratio = Float32.width // self.v_dtype.width + tP_stage_stride = self.tmem_stage_stride * tP_width_ratio + tOrP = cute.make_tensor( + tOrP.iterator + self.tmem_p_offset * tP_width_ratio, + cute.append( + tOrP.layout, + cute.make_layout((self.s_stage,), stride=(tP_stage_stride,)), + ), + ) + + tmem_cols = self.tmem_total + + load_wg_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.LoadWG), + num_threads=cute.arch.WARP_SIZE * self.num_q_load_warps, + ) + if const_expr(self.kv_fp8_to_bf16): + kv_load_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvLoad), + num_threads=cute.arch.WARP_SIZE * self.num_kv_load_warps, + ) + if const_expr(self.k_fp8_to_bf16): + kv_dequant_k_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantK), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + if const_expr(self.v_fp8_to_bf16): + kv_dequant_v_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantV), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + sm_stats_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), + num_threads=cute.arch.WARP_SIZE * 2, + ) + epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.StoreEpilogue), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + if warp_idx == Int32(self.total_warps - 1): + cute.arch.setmaxregister_decrease(self.num_regs_empty) + + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + q_load_thread_end = Int32( + (self.q_load_warp_base + self.num_q_load_warps) * cute.arch.WARP_SIZE + ) + is_q_load_thread = tidx >= q_load_thread_base and tidx < q_load_thread_end + if is_q_load_thread and cta_valid_work: + if self.store_reg_decrease: + cute.arch.setmaxregister_decrease(self.num_regs_store) + else: + cute.arch.setmaxregister_increase(self.num_regs_store) + row_start_load = sRowMeta[2] + count_raw_load = sRowMeta[3] + q_batch_offset_load = sRowMeta[5] + # Do not gate on KV validity here; sparse entries past seqused_k + # still need the all-masked path to produce neutral partials. + has_work_load = count_raw_load > Int32(0) + num_q_groups_load = ( + count_raw_load + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + if const_expr(self.use_q_gather4): + self._wg_load_q_gather4( + mQ_2d, + mQ_gather4_desc, + mK2qQSplitIndices, + sQIdxMeta, + sQ, + pipeline_q, + load_wg_barrier, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + ) + else: + self._wg_load_q_tma( + tma_atom_Q, + mQ_2d, + mK2qQSplitIndices, + sQLoadMIdx, + sQIdxMeta, + sQ_load, + pipeline_q, + load_wg_barrier, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + ) + + if ( + warp_idx >= Int32(self.kv_load_warp_base) + and warp_idx < Int32(self.kv_load_warp_base + self.num_kv_load_warps) + and cta_valid_work + ): + cute.arch.setmaxregister_decrease(self.num_regs_load) + kv_block_idx_load = sRowMeta[1] + k_batch_offset_load = sRowMeta[6] + has_work_load = sRowMeta[3] > Int32(0) + if const_expr(self.kv_fp8_to_bf16): + self._wg_load_kv_maybe_cast( + tma_atom_K, tma_atom_V, + tma_K, tma_V, + sPagedKvIdx, + sK, sV, + sKFp8 if const_expr(self.k_fp8_to_bf16) else None, + sVFp8 if const_expr(self.v_fp8_to_bf16) else None, + tiled_mma_qk, tiled_mma_pv, + mbar_k_ptr, mbar_v_ptr, + mbar_k_tma_ptr if const_expr(self.k_fp8_to_bf16) else None, + mbar_v_tma_ptr if const_expr(self.v_fp8_to_bf16) else None, + kv_load_barrier, + has_work_load, + head_kv_idx, kv_block_idx_load, + k_batch_offset_load, + ) + else: + self._wg_load_kv( + tma_atom_K, tma_atom_V, + tma_K, tma_V, + sPagedKvIdx, + sK, sV, + tiled_mma_qk, tiled_mma_pv, + mbar_k_ptr, mbar_v_ptr, + has_work_load, + head_kv_idx, kv_block_idx_load, + k_batch_offset_load, + ) + + if warp_idx == Int32(self.mma_warp_id) and cta_valid_work: + cute.arch.setmaxregister_decrease(self.num_regs_mma) + count_raw_mma = sRowMeta[3] + has_work_mma = count_raw_mma > Int32(0) + num_q_groups_mma = ( + count_raw_mma + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + tmem.allocate(tmem_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_mma_issue( + tiled_mma_qk, tiled_mma_pv, + thr_mma_qk, thr_mma_pv, + tStS, tOrP, + sK, sV, sQ, + pipeline_q, pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + mbar_k_ptr, mbar_v_ptr, + num_q_groups_mma, has_work_mma, + ) + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr, num_columns=tmem_cols) + cute.arch.griddepcontrol_launch_dependents() + + if ( + warp_idx >= Int32(self.softmax0_warp_base) + and warp_idx < Int32(self.softmax1_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + if const_expr(self.k_fp8_to_bf16): + self._wg_convert_fp8_kv_to_bf16_smem( + sKFp8, + sK, + mbar_k_tma_ptr, + mbar_k_ptr, + kv_dequant_k_barrier, + has_work_softmax, + self.softmax0_warp_base, + self.warps_per_group, + False, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 0, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + + if ( + warp_idx >= Int32(self.softmax1_warp_base) + and warp_idx < Int32(self.store_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + if const_expr(self.v_fp8_to_bf16): + self._wg_convert_fp8_kv_to_bf16_smem( + sVFp8, + sV, + mbar_v_tma_ptr, + mbar_v_ptr, + kv_dequant_v_barrier, + has_work_softmax, + self.softmax1_warp_base, + self.warps_per_group, + True, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 1, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + # ------------------------------------------------------------------ + # Warp-specialized helpers + # ------------------------------------------------------------------ + + @cute.jit + def _convert_fp8x16_to_bf16x16( + self, + src: cute.Tensor, + dst: cute.Tensor, + ): + src_i32 = cute.recast_tensor(src, cutlass.Int32) + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + for word_idx in cutlass.range_constexpr(4): + ( + dst_i32[word_idx * 2], + dst_i32[word_idx * 2 + 1], + ) = utils.cvt_fp8x4_e4m3_bf16x4(src_i32[word_idx]) + + @cute.jit + def _convert_fp8_kv_to_bf16_smem( + self, + sFp8: cute.Tensor, + sBf16: cute.Tensor, + lane: Int32, + warp_idx_in_wg: Int32, + num_dequant_warps: cutlass.Constexpr[int], + is_v: cutlass.Constexpr[bool], + ): + elems_per_load: cutlass.Constexpr[int] = 16 + elems_per_store: cutlass.Constexpr[int] = 8 + chunks_per_row: cutlass.Constexpr[int] = self.head_dim // elems_per_load + r_fp8 = cute.make_rmem_tensor((elems_per_load,), cutlass.Float8E4M3FN) + r_bf16 = cute.make_rmem_tensor((elems_per_load,), cutlass.BFloat16) + total_tasks: cutlass.Constexpr[int] = self.n_block_size * chunks_per_row + task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(task, total_tasks, task_stride, unroll=1): + row = task_idx // Int32(chunks_per_row) + chunk = task_idx - row * Int32(chunks_per_row) + col = chunk * Int32(elems_per_load) + smem_offset = row * Int32(self.head_dim) + col + s_fp8_ptr = cute.make_ptr( + cutlass.Float8E4M3FN, + sFp8.iterator.toint() + Int64(smem_offset), + mem_space=sFp8.iterator.memspace, + assumed_align=elems_per_load, + ) + s_fp8_vec = cute.make_tensor( + s_fp8_ptr, + cute.make_layout(elems_per_load), + ) + cute.autovec_copy(s_fp8_vec, r_fp8) + self._convert_fp8x16_to_bf16x16(r_fp8, r_bf16) + if const_expr(is_v): + sBf16_view = sBf16[(None, row % Int32(16)), 0, row // Int32(16), 0] + sBf16_vec = cute.local_tile(sBf16_view, (elems_per_load,), (chunk,)) + else: + sBf16_vec = sBf16[ + (row, None), + 0, + (chunk % Int32(4), chunk // Int32(4)), + 0, + ] + r_tiles = cute.logical_divide(r_bf16, cute.make_layout(elems_per_store)) + s_tiles = cute.logical_divide(sBf16_vec, cute.make_layout(elems_per_store)) + for v in cutlass.range_constexpr(elems_per_load // elems_per_store): + cute.autovec_copy(r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + + @cute.jit + def _wg_load_kv_maybe_cast( + self, + tma_atom_K, + tma_atom_V, + tma_K: cute.Tensor, + tma_V: cute.Tensor, + sPagedKvIdx: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sKFp8: Optional[cute.Tensor], + sVFp8: Optional[cute.Tensor], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mbar_k_ptr, + mbar_v_ptr, + mbar_k_tma_ptr: Optional[cutlass.Pointer], + mbar_v_tma_ptr: Optional[cutlass.Pointer], + kv_load_barrier, + has_work: Int32, + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(self.kv_load_warp_base) + + if has_work: + if warp_idx_in_wg == Int32(0): + if const_expr(self.k_fp8_to_bf16): + if const_expr(self.paged_kv): + mK_cur = tma_K[None, None, head_kv_idx, sPagedKvIdx[0]] + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim), + (None, 0), + ) + src_idx = Int32(0) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + tma_K[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim), + (None, 0), + ) + src_idx = kv_block_idx + load_K_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, + 0, + cute.make_layout(1), + gK, + sKFp8, + ) + load_K_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_k_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_tma_ptr) + else: + thr_mma_qk = tiled_mma_qk.get_slice(0) + if const_expr(self.paged_kv): + mK_cur = tma_K[None, None, head_kv_idx, None] + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0, None), + ) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + tma_K[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0), + ) + tSgK = thr_mma_qk.partition_B(gK) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + gmem_k_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_K, + tKgK[(None, 0, gmem_k_idx)] + if const_expr(self.paged_kv) + else tKgK[(None, gmem_k_idx)], + tKsK[(None, 0)], + tma_bar_ptr=mbar_k_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_ptr) + + if warp_idx_in_wg == Int32(1): + if const_expr(self.v_fp8_to_bf16): + if const_expr(self.paged_kv): + mV_cur = tma_V[None, None, head_kv_idx, sPagedKvIdx[0]] + gV = cute.local_tile( + mV_cur, + (self.head_dim, self.n_block_size), + (0, None), + ) + src_idx = Int32(0) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + tma_V[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + (self.head_dim, self.n_block_size), + (0, None), + ) + src_idx = kv_block_idx + load_V_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, + 0, + cute.make_layout(1), + gV, + sVFp8, + ) + load_V_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_v_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_tma_ptr) + else: + thr_mma_pv = tiled_mma_pv.get_slice(0) + if const_expr(self.paged_kv): + mV_cur = tma_V[None, None, head_kv_idx, None] + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None, None), + ) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + tma_V[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None), + ) + tOgV = thr_mma_pv.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + gmem_v_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_V, + tVgV[(None, 0, gmem_v_idx)] + if const_expr(self.paged_kv) + else tVgV[(None, gmem_v_idx)], + tVsV[(None, 0)], + tma_bar_ptr=mbar_v_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_ptr) + + kv_load_barrier.arrive_and_wait() + + @cute.jit + def _wg_convert_fp8_kv_to_bf16_smem( + self, + sFp8: cute.Tensor, + sBf16: cute.Tensor, + mbar_tma_ptr, + mbar_ready_ptr, + dequant_barrier, + has_work: Int32, + dequant_warp_base: cutlass.Constexpr[int], + num_dequant_warps: cutlass.Constexpr[int], + is_v: cutlass.Constexpr[bool], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(dequant_warp_base) + lane = cute.arch.lane_idx() + if has_work: + cute.arch.mbarrier_wait(mbar_tma_ptr, 0) + self._convert_fp8_kv_to_bf16_smem( + sFp8, + sBf16, + lane, + warp_idx_in_wg, + num_dequant_warps, + is_v, + ) + dequant_barrier.arrive_and_wait() + if warp_idx_in_wg == Int32(0): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_ready_ptr) + + @cute.jit + def _wg_load_kv( + self, + tma_atom_K, + tma_atom_V, + tma_K: cute.Tensor, + tma_V: cute.Tensor, + sPagedKvIdx: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mbar_k_ptr, + mbar_v_ptr, + has_work: Int32, + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(self.kv_load_warp_base) + + if has_work: + if warp_idx_in_wg == Int32(0): + thr_mma_qk = tiled_mma_qk.get_slice(0) + if const_expr(self.paged_kv): + mK_cur = tma_K[None, None, head_kv_idx, None] + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0, None), + ) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + tma_K[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0), + ) + tSgK = thr_mma_qk.partition_B(gK) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + gmem_k_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_K, + tKgK[(None, 0, gmem_k_idx)] + if const_expr(self.paged_kv) + else tKgK[(None, gmem_k_idx)], + tKsK[(None, 0)], + tma_bar_ptr=mbar_k_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_ptr) + + if warp_idx_in_wg == Int32(1): + thr_mma_pv = tiled_mma_pv.get_slice(0) + if const_expr(self.paged_kv): + mV_cur = tma_V[None, None, head_kv_idx, None] + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None, None), + ) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + tma_V[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None), + ) + tOgV = thr_mma_pv.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + gmem_v_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_V, + tVgV[(None, 0, gmem_v_idx)] + if const_expr(self.paged_kv) + else tVgV[(None, gmem_v_idx)], + tVsV[(None, 0)], + tma_bar_ptr=mbar_v_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_ptr) + + @cute.jit + def _wg_load_q_gather4( + self, + mQ_2d: cute.Tensor, + mQ_gather4_desc: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + load_wg_barrier, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + ): + tidx = cute.arch.thread_idx()[0] + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + group_tidx = tidx - q_load_thread_base + producer_warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + gathers_per_warp: cutlass.Constexpr[int] = self.m_block_size // ( + self.num_q_load_warps * 4) + + if has_work: + for qi_group in cutlass.range(num_q_groups, unroll=1): + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if producer_warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + group_tidx = tidx - q_load_thread_base + warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + lane_idx = group_tidx % Int32(cute.arch.WARP_SIZE) + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + meta_iters: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + + self.num_q_load_warps * cute.arch.WARP_SIZE - 1) + // (self.num_q_load_warps * cute.arch.WARP_SIZE) + ) + for meta_iter in cutlass.range_constexpr(meta_iters): + tok_idx_g4 = ( + (Int32(meta_iter) * Int32(self.num_q_load_warps) + + warp_idx_in_wg) + * Int32(cute.arch.WARP_SIZE) + + lane_idx + ) + if tok_idx_g4 < Int32(self.q_tokens_per_group): + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx_g4 + if qi < count_raw: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = Int32(0) + load_wg_barrier.arrive_and_wait() + + with cute.arch.elect_one(): + q_desc_ptr = mQ_gather4_desc.iterator + sQ_ptr = sQ.iterator + for gather_slot in cutlass.range_constexpr(gathers_per_warp): + gather_idx = ( + Int32(gather_slot) * Int32(self.num_q_load_warps) + + warp_idx_in_wg + ) + tok_base = gather_idx * Int32(self.tokens_per_gather4) + if const_expr(self.qheadperkv == 1): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + qi2 = qi0 + Int32(2) + qi3 = qi0 + Int32(3) + row0 = q_oob_m_idx + row1 = q_oob_m_idx + row2 = q_oob_m_idx + row3 = q_oob_m_idx + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row0 = (q_batch_offset + q_idx0) * num_heads_kv + head_kv_idx + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row1 = (q_batch_offset + q_idx1) * num_heads_kv + head_kv_idx + if qi2 < count_raw: + q_idx2 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(2)]) + row2 = (q_batch_offset + q_idx2) * num_heads_kv + head_kv_idx + if qi3 < count_raw: + q_idx3 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(3)]) + row3 = (q_batch_offset + q_idx3) * num_heads_kv + head_kv_idx + elif const_expr(self.qheadperkv == 2): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + row_base1 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row_base1 = ( + (q_batch_offset + q_idx1) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base1 + row3 = row_base1 + Int32(1) + else: + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base0 + Int32(2) + row3 = row_base0 + Int32(3) + group_byte_off = gather_idx * Int32( + 4 * self.k_tile * (self.q_dtype.width // 8) + ) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + stage_byte_off = slot * Int32(self.q_stage_stride_bytes) + full_group_byte_off = gather_idx * Int32( + 4 * self.head_dim * (self.q_dtype.width // 8) + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + full_group_byte_off, + q_desc_ptr, + Int32(0), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + else: + for ks_c in cutlass.range_constexpr(self.k_stages): + stage_idx = slot * Int32(self.k_stages) + Int32(ks_c) + stage_byte_off = stage_idx * Int32(self.k_tile_stride_bytes) + if const_expr(ks_c + 1 < self.k_stages): + tma_gather4_prefetch( + q_desc_ptr, + Int32((ks_c + 1) * self.k_tile), + row0, + row1, + row2, + row3, + TMA_CACHE_EVICT_LAST, + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + group_byte_off, + q_desc_ptr, + Int32(ks_c * self.k_tile), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + load_wg_barrier.arrive_and_wait() + + if producer_warp_idx_in_wg == Int32(0): + next_slot = num_q_groups % Int32(self.q_stage) + next_phase = ( + (num_q_groups // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_load_q_tma( + self, + tma_atom_Q, + mQ_2d: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQLoadMIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ_load: cute.Tensor, + pipeline_q, + load_wg_barrier, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + lane_idx = cute.arch.lane_idx() + warp_idx_in_wg = warp_idx - Int32(self.q_load_warp_base) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + gQ_full = cute.local_tile( + mQ_2d, (self.qheadperkv, self.head_dim), (None, 0)) + load_Q_fn_full, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_full, sQ_load) + load_Q_fn_k0, load_Q_fn_k1 = None, None + else: + gQ_k0 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 0)) + gQ_k1 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 1)) + load_Q_fn_k0, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k0, sQ_load) + load_Q_fn_k1, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k1, sQ_load) + load_Q_fn_full = None + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + tokens_per_warp: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + self.num_q_load_warps - 1) + // self.num_q_load_warps + ) + + if has_work: + for qi_group in cutlass.range(num_q_groups, unroll=1): + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + q_load_subtiles_per_token = ( + 1 if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_stages + ) + sub_stage_base = slot * Int32( + self.q_tokens_per_group * q_load_subtiles_per_token) + load_meta_slot = slot * Int32(self.q_tokens_per_group) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + if warp_idx_in_wg == Int32(0) and lane_idx < Int32(self.q_tokens_per_group): + tok_idx = lane_idx + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx + if qi < count_raw: + qsplit = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + q_idx = self._decode_q_idx_from_qsplit(qsplit) + q_abs = q_batch_offset + q_idx + sQIdxMeta[qidx_meta_slot + tok_idx] = qsplit + sQLoadMIdx[load_meta_slot + tok_idx] = ( + q_abs * num_heads_kv + head_kv_idx + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx] = Int32(0) + sQLoadMIdx[load_meta_slot + tok_idx] = q_oob_m_idx + load_wg_barrier.arrive_and_wait() + + for qi_slot in cutlass.range_constexpr(tokens_per_warp): + tok_idx = ( + warp_idx_in_wg * Int32(tokens_per_warp) + + Int32(qi_slot) + ) + if tok_idx < Int32(self.q_tokens_per_group): + m_tile_idx = sQLoadMIdx[load_meta_slot + tok_idx] + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + load_Q_fn_full( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + else: + load_Q_fn_k0( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + load_Q_fn_k1( + src_idx=m_tile_idx, + dst_idx=( + sub_stage_base + + Int32(self.q_tokens_per_group) + + tok_idx + ), + tma_bar_ptr=mbar_ptr, + ) + + if warp_idx_in_wg == Int32(0): + next_slot = num_q_groups % Int32(self.q_stage) + next_phase = ( + (num_q_groups // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_mma_issue( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + thr0_qk: cute.core.ThrMma, + thr0_pv: cute.core.ThrMma, + tStS: cute.Tensor, + tOrP: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + mbar_k_ptr, + mbar_v_ptr, + num_q_groups: Int32, + has_work: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + is_mma_warp = warp_idx == Int32(self.mma_warp_id) + + if is_mma_warp: + if has_work: + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tSrQ0 = tSrQ[None, None, None, 0] + tSrK0 = tSrK[None, None, None, 0] + tOrV = tiled_mma_pv.make_fragment_B(sV) + tOrV0 = tOrV[None, None, None, 0] + sV0 = sV[None, None, None, 0] + pv_mma_op = tiled_mma_pv.op + qk_mma_op = tiled_mma_qk.op + q_smem_base = sm100_desc.smem_desc_base_from_tensor( + sQ, sm100_desc.Major.K) + k_smem_base = sm100_desc.smem_desc_base_from_tensor( + sK, sm100_desc.Major.K) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[None, None, None, 0].iterator) + q_smem_start = sm100_desc.make_smem_desc_start_addr( + sQ[None, None, None, self.q_stage - 1].iterator) + sm100_helpers.declare_ptx_smem_desc( + q_smem_start, + q_smem_base, + tSrQ0.layout, + var_name_prefix="lean_q_desc", + ) + sm100_helpers.declare_ptx_idesc( + qk_mma_op, var_name="lean_qk_idesc") + sQ_stage_stride = ( + sQ.layout.stride[-1] * sQ.element_type.width // 8) >> 4 + if const_expr(self.q_stage == 1): + sQ_stage_stride = 0 + q_wrap_offset = -(self.q_stage - 1) * sQ_stage_stride + q_advance_offset = sQ_stage_stride + gemm_qk_s0_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_qk_s0_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_qk_s1_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_qk_s1_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_pv_0 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset), + tOrP[None, None, None, 0], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.pv_mma_kind, + ) + gemm_pv_1 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset + self.tmem_o_stage_stride), + tOrP[None, None, None, 1], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_stage_stride + self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.pv_mma_kind, + ) + + cute.arch.mbarrier_wait(mbar_k_ptr, 0) + # Issue order: + # Q0K, Q1K, P0V, Q2K, P1V, Q3K, ... + # This reuses each slot as soon as its previous PV drains, + # instead of batching both PVs after both QKs of a pair. + # The schedule is still 2-slot safe: + # - QK(qi) consumes slot qi&1 + # - PV(qi-2) frees the same slot before QK(qi) reuses it + # - phases still toggle every 2 groups per slot + + # Prologue: issue up to the first two QK tiles. Q slots come + # from the q_stage ring; S slots remain a 2-slot ring. + pipeline_q.consumer_wait_w_index_phase(Int32(0), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(0), Int32(1)) + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(0)) + pipeline_q.consumer_release_w_index(Int32(0)) + + if num_q_groups > Int32(1): + pipeline_q.consumer_wait_w_index_phase(Int32(1), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(1), Int32(1)) + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(1)) + pipeline_q.consumer_release_w_index(Int32(1)) + + cute.arch.mbarrier_wait(mbar_v_ptr, 0) + + # Steady-state: for qi >= 2, reuse the S/P slot as a BLK128- + # style handoff. The MMA warp waits for softmax to release + # the slot after P is visible, issues PV, then immediately + # reuses the same acquired slot for the next QK. + for qi in cutlass.range(Int32(2), num_q_groups, unroll=1): + pv_qi = qi - Int32(2) + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + q_slot = qi % Int32(self.q_stage) + q_phase = (qi // Int32(self.q_stage)) & Int32(1) + s_slot = qi & Int32(1) + s_phase = (qi // Int32(2)) & Int32(1) + pipeline_q.consumer_wait_w_index_phase(q_slot, q_phase) + pipeline_s.producer_acquire_w_index_phase( + s_slot, s_phase ^ Int32(1)) + if s_slot == Int32(0): + if q_slot == Int32(0): + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s0_advance(smem_desc_start_b=k_smem_start) + else: + if q_slot == Int32(0): + gemm_qk_s1_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(s_slot) + pipeline_q.consumer_release_w_index(q_slot) + + # Drain the remaining one or two PV tiles. + drain_begin = Int32(0) if num_q_groups == Int32(1) else num_q_groups - Int32(2) + for pv_qi in cutlass.range(drain_begin, num_q_groups, unroll=1): + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + @cute.jit + def _softmax_step( + self, + slot: cutlass.Constexpr[int], + s_consumer_phase: Int32, + p_producer_phase: Int32, + sm_stats_producer_phase: Int32, + softmax: SoftmaxSm100, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx: Int32, + thr_tmem_load, + thr_tmem_store, + tStS_t2r: cute.Tensor, + tStP_r2t: cute.Tensor, + tScS_t2r: cute.Tensor, + tScP_shape, + sQIdxMeta: cute.Tensor, + qidx_meta_slot: Int32, + group_tidx: Int32, + masked_tok_count: Int32, + kv_block_col_start: Int32, + seq_len_q: Int32, + causal_q_offset: Int32, + kv_valid_cols: Int32, + lse_temperature_scale: Float32, + return_temperature_lse: cutlass.Constexpr[bool], + apply_causal_mask: cutlass.Constexpr[bool] = False, + signal_stats_barrier: cutlass.Constexpr[bool] = True, + ): + slot_rt = Int32(slot) + + pipeline_s.consumer_wait_w_index_phase(slot_rt, s_consumer_phase) + + tSrS_t2r = cute.make_rmem_tensor( + tScS_t2r.shape, + self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + + seqlen_info = SeqlenInfoQK( + Int32(0), + Int32(0), + Int32(0), + Int32(0), + seq_len_q, + seq_len_q + causal_q_offset, + False, + False, + False, + False, + ) + mask = AttentionMask( + self.m_block_size, + self.n_block_size, + seqlen_info, + ) + if const_expr(self.causal and apply_causal_mask): + need_causal_mask = masked_tok_count > Int32(0) + if need_causal_mask: + tok_idx = group_tidx // Int32(self.qheadperkv) + q_idx = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_idx]) + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=True, + row_idx=q_idx, + kv_valid_cols=kv_valid_cols, + kv_block_col_start=kv_block_col_start, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + + # Each sparse CTA computes exactly one KV block for the current Q group, + # so full-tile softmax is always the first and only online-softmax step. + row_max, _ = softmax.update_row_max(tSrS_t2r.load(), True) + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + if const_expr(return_temperature_lse): + lse_temperature_row_sum = softmax.compute_scaled_exp2_row_sum( + tSrS_t2r, + lse_temperature_scale, + ) + + if cutlass.const_expr(self.split_P_arrive > 0): + # This full barrier is the late-P handoff consumed inside + # gemm_ptx_partial after its early PV k-slices are issued. + pipeline_p_lastsplit.producer_acquire_w_index_phase( + slot_rt, p_producer_phase) + pipeline_p.producer_acquire_w_index_phase(slot_rt, p_producer_phase) + tSrP_r2t_f32 = cute.make_rmem_tensor( + thr_tmem_store.partition_S( + cute.make_identity_tensor(tScP_shape)).shape, + Float32) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.p_dtype), + tSrS_t2r.layout) + softmax.apply_exp2_convert( + tSrS_t2r, + tSrP_r2t, + ex2_emu_freq=self.ex2_emu_freq, + ex2_emu_start_frg=self.ex2_emu_start_frg, + ) + + for k in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): + cute.copy( + thr_tmem_store, + tSrP_r2t_f32[None, None, k], + tStP_r2t[None, None, k]) + if cutlass.const_expr(self.split_P_arrive > 0): + split_idx = ( + cute.size(tStP_r2t.shape[2]) + * self.split_P_arrive // self.n_block_size) + if cutlass.const_expr(k + 1 == split_idx): + cute.arch.fence_view_async_tmem_store() + pipeline_p.producer_commit_w_index(slot_rt) + cute.arch.fence_view_async_tmem_store() + if cutlass.const_expr(self.split_P_arrive == 0): + pipeline_p.producer_commit_w_index(slot_rt) + else: + pipeline_p_lastsplit.producer_commit_w_index(slot_rt) + pipeline_sm_stats.producer_acquire_w_index_phase( + slot_rt, sm_stats_producer_phase) + softmax.update_row_sum(tSrS_t2r.load(), Float32(0.0), True) + del tSrS_t2r + sScale_slot = cute.make_tensor( + sScale.iterator + slot_rt * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_slot[group_tidx] = softmax.row_sum[0] + sScale_slot[group_tidx + Int32(self.m_block_size)] = ( + softmax.row_max[0]) + if const_expr(return_temperature_lse): + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot_rt * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sScale_temperature_slot[group_tidx] = lse_temperature_row_sum + cute.arch.fence_view_async_shared() + + if const_expr(signal_stats_barrier): + sm_stats_barrier.arrive_w_index(index=stats_barrier_idx) + pipeline_s.consumer_release_w_index(slot_rt) + + @cute.jit + def _wg_softmax( + self, + stage: cutlass.Constexpr[int], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tStS: cute.Tensor, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + kv_block_idx: Int32, + kv_valid_cols: Int32, + diag_q_count: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + causal_q_offset: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + mQ_2d: cute.Tensor, + ): + tidx = cute.arch.thread_idx()[0] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx % Int32(self.warps_per_group) + group_tidx = ( + warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + stats_barrier_idx = ( + Int32(stage) * Int32(self.warps_per_group) + warp_idx_in_wg) + + thr0_qk = tiled_mma_qk.get_slice(0) + tScS = thr0_qk.partition_C( + cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScS = tScS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_tiler_qk[0] // thr0_qk.thr_id.shape, + self.mma_tiler_qk[1]) + tilePlikeFP32 = ( + self.mma_tiler_qk[1] // Float32.width * self.p_dtype.width) + tScP_shape = (cta_qk_tiler[0], tilePlikeFP32) + tSAcc = tStS[(None, None), 0, 0, stage] + + softmax = SoftmaxSm100.create(softmax_scale_log2) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype) + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tSAcc).get_slice(group_tidx) + tStS_t2r = thr_tmem_load.partition_S(tSAcc) + tScS_t2r = thr_tmem_load.partition_D(tScS) + tStP_layout = cute.composition( + tSAcc.layout, + cute.make_layout((self.m_block_size, tilePlikeFP32))) + tStP = cute.make_tensor( + tSAcc.iterator + self.tmem_s_to_p_offset, tStP_layout) + # P-store Repetition is dtype-aware: each PV MMA K-segment is + # ``32 / (p_dtype.width / 8)`` fp8/bf16 columns wide, which equals + # ``32 * Float32.width / p_dtype.width`` packed fp32 TMEM columns + # ``// (p_dtype.width / 8)``. Concretely, R=16 packs two bf16 PV K + # segments per chunk (shape[2]=4 ⇒ 3/4 publish boundary aligns), + # while fp8 (PV K=32 fp8 ⇒ 8 fp32 cols) needs R=8 so + # shape[2]=4 and split_idx=3 publishes exactly 24 fp32 cols + # (= 96 fp8 cols = 3 PV K segments) at the early-arrive edge. + store_rep = const_expr( + 8 if self.p_dtype == cutlass.Float8E4M3FN else 16 + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(store_rep)), + Float32) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom, tStP).get_slice(group_tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + + total_q = mQ_2d.shape[0] // head_q + thr0_pv = tiled_mma_pv.get_slice(0) + pv_acc_shape = thr0_pv.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO_base = thr0_pv.make_fragment_C(pv_acc_shape) + corr_tile_size = 64 + tOcO = thr0_pv.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + tOcO_i = cute.logical_divide( + tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + o_tmem_copy_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.pv_acc_dtype) + + if has_work: + kv_block_col_start = Int32(0) + if const_expr(self.causal): + kv_block_col_start = kv_block_idx * Int32(self.n_block_size) + + num_stage_groups = ( + num_q_groups + Int32(1 - stage)) // Int32(2) + for qi_iter in cutlass.range(num_stage_groups, unroll=1): + qi_group = qi_iter * Int32(2) + Int32(stage) + phase = qi_iter & Int32(1) + producer_phase = phase ^ Int32(1) + qidx_meta_slot = ( + (qi_group & Int32(self.qidx_meta_stages - 1)) + * Int32(self.q_tokens_per_group) + ) + + softmax.reset() + + if const_expr(self.causal): + qi_group_start = qi_group * Int32(self.q_tokens_per_group) + masked_tok_count = cutlass.max( + Int32(0), + cutlass.min( + Int32(self.q_tokens_per_group), + diag_q_count - qi_group_start, + ), + ) + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + masked_tok_count, + kv_block_col_start, + seq_len_q, + causal_q_offset, + kv_valid_cols, + lse_temperature_scale, + const_expr(mLSE_temperature_partial is not None), + True, + False, + ) + else: + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + Int32(0), + kv_block_col_start, + seq_len_q, + Int32(0), + kv_valid_cols, + lse_temperature_scale, + const_expr(mLSE_temperature_partial is not None), + False, + False, + ) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(stage)) + self._epilogue_step( + qi_group, + group_tidx, + warp_idx_in_wg, + tOtO_base, + tOcO_i, + o_tmem_copy_atom, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, + mLSE_partial, + mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + count_raw, + batch_idx, + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + q_batch_offset, + total_q, + False, + stage, + ) + + @cute.jit + def _store_o_partial_vec4( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + ): + stg_128_cs(ptr, v0, v1, v2, v3) + + @cute.jit + def _store_o_partial_vec8_half( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + ): + if cutlass.const_expr(self.o_dtype is cutlass.BFloat16): + stg_128_bf16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + else: + stg_128_f16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + + @cute.jit + def _store_o_partial_vec16_fp8( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + v8: Float32, + v9: Float32, + v10: Float32, + v11: Float32, + v12: Float32, + v13: Float32, + v14: Float32, + v15: Float32, + ): + stg_128_fp8_e4m3_cs( + ptr, + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7, + v8, + v9, + v10, + v11, + v12, + v13, + v14, + v15, + ) + + @cute.jit + def _epilogue_step( + self, + qi_group: Int32, + group_tidx: Int32, + warp_idx_in_wg: Int32, + tOtO_base: cute.Tensor, + tOcO_i: cute.Tensor, + o_tmem_copy_atom, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + count_raw: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + total_q: Int32, + use_stats_barrier: cutlass.Constexpr[bool], + softmax_stage: cutlass.Constexpr[int], + ): + slot = qi_group & Int32(1) + phase = (qi_group // Int32(2)) & Int32(1) + stage_base = slot * Int32(self.tmem_o_stage_stride) + corr_tile_size = 64 + sScale_slot = cute.make_tensor( + sScale.iterator + slot * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sSplitIdx_slot = cute.make_tensor( + sSplitIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + sQIdx_slot = cute.make_tensor( + sQIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + pipeline_o.consumer_wait_w_index_phase(slot, phase) + if const_expr(use_stats_barrier): + sm_stats_barrier.arrive_and_wait_w_index( + index=slot * Int32(self.warps_per_group) + warp_idx_in_wg) + + if group_tidx < Int32(self.q_tokens_per_group): + tok = group_tidx + qi = qi_group * Int32(self.q_tokens_per_group) + tok + if qi < count_raw: + qsplit = sQIdxMeta[qidx_meta_slot + tok] + q_idx = self._decode_q_idx_from_qsplit(qsplit) + sQIdx_slot[tok] = q_idx + sSplitIdx_slot[tok] = self._decode_split_idx_from_qsplit(qsplit) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + tOtO = cute.make_tensor( + tOtO_base.iterator + stage_base + Int32(self.tmem_o_offset), + tOtO_base.layout) + for col_pass_idx in cutlass.range(Int32(2), unroll=1): + col_pass = col_pass_idx * Int32(corr_tile_size) + tOtO_pass_ptr = cute.make_ptr( + self.pv_acc_dtype, + tOtO.iterator.toint() + col_pass, + cute.AddressSpace.tmem, + assumed_align=8, + ) + tOtO_pass = cute.make_tensor(tOtO_pass_ptr, tOtO.layout) + tOtO_pass_i = cute.logical_divide( + tOtO_pass, + cute.make_layout((self.m_block_size, corr_tile_size))) + tiled_tmem_load_pass = tcgen05.make_tmem_copy( + o_tmem_copy_atom, tOtO_pass_i[(None, None), 0]) + thr_tmem_load_pass = tiled_tmem_load_pass.get_slice(group_tidx) + tOtO_t2r_pass = thr_tmem_load_pass.partition_S( + tOtO_pass_i[(None, None), None]) + tOcO_t2r_pass = thr_tmem_load_pass.partition_D( + tOcO_i[(None, None), None]) + + tOtO_t2r_i = tOtO_t2r_pass[None, None, None, 0] + tOcO_t2r_i = tOcO_t2r_pass[None, None, None, 0] + tOrO_frg = cute.make_rmem_tensor_like( + tOcO_t2r_i, self.pv_acc_dtype) + cute.copy(tiled_tmem_load_pass, tOtO_t2r_i, tOrO_frg) + + tOrO_mn = make_16x256b_tensor_mn_view(tOrO_frg) + tOrO_mn = cute.make_tensor( + tOrO_mn.iterator, + cute.select(tOrO_mn.layout, mode=[0, 1])) + tOcO_mn = make_16x256b_tensor_mn_view(tOcO_t2r_i) + tOcO_mn = cute.make_tensor( + tOcO_mn.iterator, + cute.select(tOcO_mn.layout, mode=[0, 1])) + num_rows = cute.size(tOrO_mn, mode=[0]) + num_cols = cute.size(tOrO_mn, mode=[1]) + + for r in cutlass.range_constexpr(num_rows): + if const_expr(self.o_dtype is Float32): + for c4 in cutlass.range_constexpr(num_cols // 4): + c_base = Int32(c4) * Int32(4) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + fake_col = real_col_to_stg128_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec4( + ptr, + o0, + o1, + o2, + o3, + ) + elif const_expr(self.o_dtype in [cutlass.BFloat16, cutlass.Float16]): + assert num_cols % 8 == 0, ( + "half O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 8" + ) + for c8 in cutlass.range_constexpr(num_cols // 8): + c_base = Int32(c8) * Int32(8) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2( + (o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2( + (o6, o7), scale_pair) + fake_col = real_col_to_stg128_half_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec8_half( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + ) + else: + assert num_cols % 16 == 0, ( + "fp8 O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 16" + ) + for c16 in cutlass.range_constexpr(num_cols // 16): + c_base = Int32(c16) * Int32(16) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = ( + qi_group * Int32(self.q_tokens_per_group) + + tok + ) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) + * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val + ) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0) + ) + row_base_ptr = flat_row * Int64(self.head_dim) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + o8 = tOrO_mn[r, c_base + Int32(8)] + o9 = tOrO_mn[r, c_base + Int32(9)] + o10 = tOrO_mn[r, c_base + Int32(10)] + o11 = tOrO_mn[r, c_base + Int32(11)] + o12 = tOrO_mn[r, c_base + Int32(12)] + o13 = tOrO_mn[r, c_base + Int32(13)] + o14 = tOrO_mn[r, c_base + Int32(14)] + o15 = tOrO_mn[r, c_base + Int32(15)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2((o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2((o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2((o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2((o6, o7), scale_pair) + o8, o9 = cute.arch.mul_packed_f32x2((o8, o9), scale_pair) + o10, o11 = cute.arch.mul_packed_f32x2((o10, o11), scale_pair) + o12, o13 = cute.arch.mul_packed_f32x2((o12, o13), scale_pair) + o14, o15 = cute.arch.mul_packed_f32x2((o14, o15), scale_pair) + fake_col = real_col_to_stg128_fp8_fake_col(col) + ptr = ( + mO_partial.iterator + + row_base_ptr + + Int64(fake_col) + ) + self._store_o_partial_vec16_fp8( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + o8, + o9, + o10, + o11, + o12, + o13, + o14, + o15, + ) + cute.arch.fence_view_async_tmem_load() + + tok_local = Int32(group_tidx) // Int32(self.qheadperkv) + h_local = Int32(group_tidx) % Int32(self.qheadperkv) + qi_lse = qi_group * Int32(self.q_tokens_per_group) + tok_local + if qi_lse < count_raw: + row_sum_val = sScale_slot[group_tidx] + row_max_val = sScale_slot[group_tidx + Int32(self.m_block_size)] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + LN2 = Float32(math.log(2.0)) + lse_cur = ( + (row_max_val * softmax_scale_log2 + + cute.math.log2(row_sum_val, fastmath=True)) * LN2 + if not is_zero_or_nan else -Float32.inf) + q_idx_lse = sQIdx_slot[tok_local] + h_abs = head_kv_idx * Int32(self.qheadperkv) + h_local + split_lse = sSplitIdx_slot[tok_local] + q_abs_lse = q_batch_offset + q_idx_lse + mLSE_partial[split_lse, q_abs_lse, h_abs] = lse_cur + if const_expr(mLSE_temperature_partial is not None): + row_sum_temperature_val = sScale_temperature_slot[group_tidx] + is_temperature_zero_or_nan = ( + row_sum_temperature_val == Float32(0.0) + or row_sum_temperature_val != row_sum_temperature_val) + lse_temperature_cur = ( + (row_max_val * lse_temperature_scale_log2 + + cute.math.log2(row_sum_temperature_val, fastmath=True)) * LN2 + if not is_temperature_zero_or_nan else -Float32.inf) + mLSE_temperature_partial[split_lse, q_abs_lse, h_abs] = ( + lse_temperature_cur) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + pipeline_sm_stats.consumer_release_w_index(slot) + pipeline_o.consumer_release_w_index(slot) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd1a8d6bf92b16d2943aa5e40fd91e26224ac40 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py @@ -0,0 +1,3305 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 sparse attention forward kernel with NVFP4 K/V. + +This kernel implements the delivered Sparse Attention / Sparse Page Attention +forward contract: +- CSR sparse metadata (`k2q_row_ptr`, `k2q_q_indices`) +- varlen Q metadata via `cu_seqlens_q` +- BF16 Q +- packed NVFP4 K/V data +- E4M3 per-1x16 K/V scales in cuBLAS/cuDNN 128x4 tiled layout +- FP32 per-tensor K/V global scales +""" + +import math +from functools import partial +from typing import Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.pipeline as cutlass_pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.cutlass_dsl import BaseDSL + +from ....quack import copy_utils + +from ....src.common.cute_dsl_utils import assume_tensor_aligned +from ....src.common import utils +from ....src.common import pipeline +from ....src.common import mma_sm100_desc as sm100_desc +from ....src.common import blackwell_helpers as sm100_helpers +from ....src.common.softmax import SoftmaxSm100 +from ....src.common.named_barrier import NamedBarrierFwdSm100 +from ....src.common.mask import AttentionMask +from ....src.common.seqlen_info import SeqlenInfoQK +# Shared raw PTX helpers and layout conversions used by the lean kernel. +from ....src.common.paged_kv import PagedKVManager +from ....src.common.tma_utils import ( + tma_gather4_cached, + tma_gather4_prefetch, + prefetch_tma_desc_raw, + TMA_CACHE_EVICT_LAST, + make_16x256b_tensor_mn_view, + real_col_to_stg128_fake_col, + real_col_to_stg128_fp8_fake_col, + real_col_to_stg128_half_fake_col, + stg_128_cs, + stg_128_bf16_cs, + stg_128_f16_cs, + stg_128_fp8_e4m3_cs, +) + + +class SparseAttentionForwardNvfp4KvSm100: + """SM100 sparse attention forward kernel with NVFP4 K/V.""" + + k_tile = 64 # UTCMMA bf16 K-tile (matches sparse_fwd_utcmma.py) + + def __init__( + self, + head_dim: int = 128, + qheadperkv: int = 16, + m_block_size: int = 128, + n_block_size: int = 128, + paged_kv: bool = False, + page_size: Optional[int] = None, + has_seqused_k: bool = False, + causal: bool = False, + use_prepare_scheduler: bool = True, + fp8_pair_dequant: bool = True, + has_k_global_scale: bool = True, + has_v_global_scale: bool = True, + ): + if head_dim != 128: + raise NotImplementedError( + f"SparseAttentionForwardNvfp4KvSm100 currently supports only D=128, got D={head_dim}" + ) + self.head_dim = 128 + self.qheadperkv = qheadperkv + self.use_q_gather4 = qheadperkv in (4, 2, 1) + if qheadperkv not in (16, 8, 4, 2, 1): + raise ValueError( + "SparseAttentionForwardNvfp4KvSm100 supports qheadperkv in " + f"{{1, 2, 4, 8, 16}}, got {qheadperkv}" + ) + self.tokens_per_gather4 = 4 // qheadperkv if self.use_q_gather4 else 0 + self.m_block_size = m_block_size # 128 packed Q heads + self.n_block_size = n_block_size # 128 KV-block width + self.paged_kv = paged_kv + self.page_size = page_size + self.has_seqused_k = has_seqused_k + self.causal = causal + self.fp8_pair_dequant = fp8_pair_dequant + self.has_k_global_scale = has_k_global_scale + self.has_v_global_scale = has_v_global_scale + if not use_prepare_scheduler: + raise ValueError("SparseAttentionForwardNvfp4KvSm100 requires prepare scheduler") + self.use_prepare_scheduler = True + if self.paged_kv: + if page_size is None: + raise ValueError("page_size must be provided when paged_kv=True") + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal blk_kv ({n_block_size})" + ) + else: + self.page_size = n_block_size + self.q_tokens_per_group = m_block_size // qheadperkv # 8 + + self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim) + self.mma_tiler_pv = (m_block_size, self.head_dim, n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + + # Pipeline configuration — deeper Q prefetch ring plus 2-slot S/O rings. + self.q_stage = 2 + self.s_stage = 2 + self.o_stage = 2 + self.kv_stage = 1 + # Sparse q_idx metadata ring bridging load -> epilogue. Sized larger + # than the in-flight group distance so epilogue can reuse q_idx + # without rereading mK2qIndices. + self.qidx_meta_stages = 16 + + self.k_stages = 2 + self.q_stage_stride_bytes = m_block_size * self.head_dim * 2 + self.k_tile_stride_bytes = m_block_size * self.k_tile * 2 + self.token_stride_bytes = qheadperkv * self.k_tile * 2 + + # Warp layout: two softmax WGs, one Q-load/epilogue WG, one + # MMA issue warp, two K/V load warps, and one empty warp. + self.warps_per_group = 4 + self.softmax0_warp_base = 0 + self.softmax1_warp_base = self.softmax0_warp_base + self.warps_per_group + self.store_warp_base = self.softmax1_warp_base + self.warps_per_group + self.mma_warp_id = self.store_warp_base + self.warps_per_group + self.load_warp_base = self.mma_warp_id + 1 + self.q_load_warp_base = self.store_warp_base + self.kv_load_warp_base = self.load_warp_base + self.num_kv_load_warps = 2 + self.num_q_load_warps = self.warps_per_group + self.total_warps = 16 + self.threads_per_cta = cute.arch.WARP_SIZE * self.total_warps # 512 + + # TMEM layout follows FA SM100: + # S0/S1: [0:128], [128:256] + # O0/O1: [256:384], [384:512] for hdim_v=128 + # P is bf16 and starts halfway into each S tile. + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + self.tmem_s_offset = 0 + self.tmem_stage_stride = n_block_size + self.tmem_o_stage_stride = self.head_dim + self.tmem_o_offset = self.s_stage * n_block_size + self.tmem_s_to_p_offset = n_block_size // 2 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + raw_tmem_total = self.tmem_o_offset + self.o_stage * self.tmem_o_stage_stride + # SM100 TMEM allocation requires a power-of-two column count. The + # 128-wide path naturally uses 512 columns; 64-wide KV blocks use 384 + # columns and must round the allocation up while keeping the same + # logical offsets. + self.tmem_total = 1 << (raw_tmem_total - 1).bit_length() + + # Let PV start once the first 3/4 of P is visible in TMEM. The final + # split is synchronized by a separate mbarrier consumed inside PV MMA. + self.split_P_arrive = n_block_size // 4 * 3 + self.split_P_arrive = int(self.split_P_arrive / 32) * 32 + assert self.split_P_arrive % 32 == 0 + assert self.split_P_arrive < self.n_block_size + + # Register allocation per role. The causal hdim128 split gives the + # epilogue enough room for partial-O/LSE address generation while the + # two softmax WGs still have enough registers to avoid S/P spills. + self.num_regs_softmax = 176 if causal else 192 + self.num_regs_store = 112 if causal else 80 + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_store + self.num_regs_mma = self.num_regs_other + self.num_regs_load = self.num_regs_other + self.num_regs_empty = self.num_regs_other + self.store_reg_decrease = self.num_regs_store <= 128 + self.ex2_emu_freq = 16 if causal else 0 + self.ex2_emu_start_frg = 1 + self.buffer_align_bytes = 1024 + + # SM100 config. + self.use_2cta_instrs = False + self.cta_group_size = 1 + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnk = (1, 1, 1) + + self.arch = BaseDSL._get_dsl().get_arch_enum() + + @cute.jit + def _batch_q_offset( + self, + batch_idx: Int32, + mCuSeqlensQ, + ) -> Int32: + return mCuSeqlensQ[batch_idx] + + @cute.jit + def _logical_seqlen_k( + self, + batch_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + if const_expr(self.has_seqused_k): + return mSeqUsedK[batch_idx] + if const_expr(self.paged_kv): + return Int32(mPageTable.shape[1]) * Int32(self.page_size) + return mCuSeqlensK[batch_idx + Int32(1)] - mCuSeqlensK[batch_idx] + + @cute.jit + def _valid_cols_in_block( + self, + batch_idx: Int32, + kv_block_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + seqlen_k = self._logical_seqlen_k( + batch_idx, mPageTable, mSeqUsedK, mCuSeqlensK + ) + block_start = kv_block_idx * Int32(self.n_block_size) + remaining = seqlen_k - block_start + remaining = cutlass.max(remaining, Int32(0)) + return cutlass.min(remaining, Int32(self.n_block_size)) + + @cute.jit + def _load_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _load_qsplit_idx( + self, + mK2qQSplitIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qQSplitIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _decode_q_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return qsplit & Int32(0x00FF_FFFF) + + @cute.jit + def _decode_split_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return (qsplit >> Int32(24)) & Int32(0xFF) + + @cute.jit + def _lower_bound_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + count: Int32, + q_value: Int32, + ) -> Int32: + left = Int32(0) + right = count + # k2q_q_indices is sorted by q_idx within each CSR row. A fixed + # 32-step loop covers int32-sized rows and keeps this CTA-level. + for _ in cutlass.range(32, unroll=1): + if left < right: + mid = (left + right) // Int32(2) + q_idx = self._load_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + mid, + ) + if q_idx < q_value: + left = mid + Int32(1) + else: + right = mid + return left + + # ------------------------------------------------------------------ + # Host-side: TMA descriptors, SMEM layout, launch + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + mK: cute.Tensor, # packed NVFP4: flat [total_k, head_kv, dim/2] or paged [page, head_kv, token, dim/2] + mV: cute.Tensor, # packed NVFP4: flat [total_k, head_kv, dim/2] or paged [page, head_kv, token, dim/2] + mKScale: cute.Tensor, # E4M3 uint8, 128x4 tiled over flattened K rows and dim/16 cols + mVScale: cute.Tensor, # E4M3 uint8, 128x4 tiled over flattened V rows and dim/16 cols + mKGlobalScale: Optional[cute.Tensor], # optional FP32 tensor/global dequant scale + mVGlobalScale: Optional[cute.Tensor], # optional FP32 tensor/global dequant scale + mK2qIndices: cute.Tensor, # csr payload: [head_kv, nnz] + mK2qQSplitIndices: cute.Tensor, # csr payload: [head_kv, nnz] packed q_idx/split slot + mK2qCounts: cute.Tensor, # csr row_ptr: [head_kv, total_rows + 1] + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, # fp32 O_partial buffer (kept alive) + mLSE_partial: cute.Tensor, # fp32 LSE_partial + mLSE_temperature_partial: Optional[cute.Tensor], # fp32 temperature-scaled LSE_partial + mQ_flat: cute.Tensor, # [batch*Sq*head_q, dim] bf16, pre-flattened + mQ_gather4_desc: Optional[cute.Tensor], # [128] uint8 tensor map for gather4 Q load + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + softmax_scale: Float32, + lse_temperature_scale: Float32, + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + stream=None, + ): + self.q_dtype = mQ_flat.element_type + self.k_cache_dtype = mK.element_type + self.v_cache_dtype = mV.element_type + self.k_scale_dtype = mKScale.element_type + self.v_scale_dtype = mVScale.element_type + if const_expr(self.q_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"KVFP4 forward requires BF16 or FP8 E4M3 Q, got {self.q_dtype}") + self.k_dtype = self.q_dtype + self.v_dtype = self.q_dtype + if const_expr(self.k_cache_dtype is not cutlass.Uint8 or self.v_cache_dtype is not cutlass.Uint8): + raise TypeError( + f"KVFP4 forward expects packed uint8 K/V, got {self.k_cache_dtype}, {self.v_cache_dtype}" + ) + if const_expr(self.k_scale_dtype is not cutlass.Uint8 or self.v_scale_dtype is not cutlass.Uint8): + raise TypeError( + f"KVFP4 forward expects uint8 E4M3 scales, got {self.k_scale_dtype}, {self.v_scale_dtype}" + ) + if const_expr(self.has_k_global_scale and mKGlobalScale.element_type is not Float32): + raise TypeError("KVFP4 forward expects FP32 K global scale") + if const_expr(self.has_v_global_scale and mVGlobalScale.element_type is not Float32): + raise TypeError("KVFP4 forward expects FP32 V global scale") + self.mma_kind = "f8f6f4" if const_expr(self.q_dtype.width == 8) else "f16" + elem_bytes = const_expr(self.q_dtype.width // 8) + self.q_stage_stride_bytes = self.m_block_size * self.head_dim * elem_bytes + self.k_tile_stride_bytes = self.m_block_size * self.k_tile * elem_bytes + self.token_stride_bytes = self.qheadperkv * self.k_tile * elem_bytes + p_cols_as_fp32 = const_expr(self.n_block_size * self.q_dtype.width // Float32.width) + self.tmem_s_to_p_offset = self.n_block_size - p_cols_as_fp32 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + self.o_dtype = mO_partial.element_type + if const_expr( + self.o_dtype + not in [Float32, cutlass.BFloat16, cutlass.Float16, cutlass.Float8E4M3FN] + ): + raise TypeError(f"Unsupported O_partial dtype: {self.o_dtype}") + mK, mV, mKScale, mVScale = [ + assume_tensor_aligned(t) for t in (mK, mV, mKScale, mVScale) + ] + + if const_expr(not self.paged_kv): + # Flat varlen K/V: + # K: [total_k, h, d/2] -> [total_k, d/2, h]. + # V: [total_k, h, d/2] -> [d/2, total_k, h] for MN-major PV. + layout_t = [0, 2, 1] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2])) + else: + # Host input is [page, head, token, dim/2]. + layout_t = [2, 3, 1, 0] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + # V: (s,d/2,h,b) -> (d/2,s,h,b) for MN-major + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2, 3])) + + # ------------------------------------------------------------------ + # UTCMMA TiledMma: QK^T and PV + # ------------------------------------------------------------------ + cta_group = tcgen05.CtaGroup.ONE + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, + Float32, cta_group, self.mma_tiler_qk[:2]) + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN, + Float32, cta_group, self.mma_tiler_pv[:2], tcgen05.OperandSource.TMEM) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM layouts: sQ/sK/sV only. O_partial is written directly from + # registers to GMEM in the epilogue. + # ------------------------------------------------------------------ + total_q_stages = self.q_stage + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, total_q_stages) + q_load_tile = ( + self.head_dim + if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_tile + ) + q_load_subtiles_per_token = const_expr(self.head_dim // q_load_tile) + num_subtiles_total = ( + total_q_stages * self.q_tokens_per_group * q_load_subtiles_per_token + ) + sQ_load_layout = sm100_utils.make_smem_layout( + tcgen05.OperandMajorMode.K, + (self.qheadperkv, q_load_tile), self.q_dtype, num_subtiles_total) + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage) + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage) + sK_fp4_layout = cute.append( + cute.make_layout( + (self.n_block_size, self.head_dim // 2), + stride=(self.head_dim // 2, 1), + ), + cute.make_layout((1,)), + ) + sV_fp4_layout = cute.append( + cute.make_layout( + (self.head_dim // 2, self.n_block_size), + stride=(1, self.head_dim // 2), + ), + cute.make_layout((1,)), + ) + # P SMEM layout metadata (no actual SMEM allocation — P lives in TMEM, + # overlaying the S region; this layout is only used to compute the PV + # A-operand TMEM descriptor shape at the MMA issue site.) + tP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.s_stage) + + # ------------------------------------------------------------------ + # TMA atoms. Packed FP4 K/V are staged by TMA, then dequantized into + # BF16 MMA SMEM layout by the KV load warps. + # ------------------------------------------------------------------ + k_fp4_tma_bytes = cute.size_in_bytes( + self.k_cache_dtype, cute.select(sK_fp4_layout, mode=[0, 1])) + v_fp4_tma_bytes = cute.size_in_bytes( + self.v_cache_dtype, cute.select(sV_fp4_layout, mode=[0, 1])) + q_tma_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_atom_K_fp4, mK_tma = cpasync.make_tiled_tma_atom( + tma_load_op, + mK, + cute.select(sK_fp4_layout, mode=[0, 1]), + (self.n_block_size, self.head_dim // 2), + ) + tma_atom_V_fp4, mV_tma = cpasync.make_tiled_tma_atom( + tma_load_op, + mV, + cute.select(sV_fp4_layout, mode=[0, 1]), + (self.head_dim // 2, self.n_block_size), + ) + mK = mK_tma + mV = mV_tma + + # Q per-sub-tile TMA atom: bf16 uses two 64-element halves; fp8 uses + # one 128-element row because 128 fp8 elements occupy the same 128B + # swizzle span as 64 bf16 elements. + mQ_flat = assume_tensor_aligned(mQ_flat) + mQ_2d = cute.make_tensor( + mQ_flat.iterator, cute.select(mQ_flat.layout, mode=[0, 1])) + if const_expr(self.use_q_gather4): + # Placeholder atom for the unified kernel signature. Small-GQA Q + # loading uses raw gather4, so mQ_2d must stay as the plain GMEM + # tensor. The placeholder uses the natural SMEM top-level shape. + tma_atom_Q, _ = cpasync.make_tiled_tma_atom( + tma_load_op, mQ_2d, + cute.select(sQ_load_layout, mode=[0, 1]), + (8, q_load_tile)) + else: + tma_atom_Q, mQ_2d_tma = cpasync.make_tiled_tma_atom( + tma_load_op, mQ_2d, + cute.select(sQ_load_layout, mode=[0, 1]), + (self.qheadperkv, q_load_tile)) + mQ_2d = mQ_2d_tma + q_subtile_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_load_layout, mode=[0, 1])) + + softmax_scale_log2 = softmax_scale * Float32(math.log2(math.e)) + lse_temperature_scale_log2 = softmax_scale_log2 * lse_temperature_scale + + # ------------------------------------------------------------------ + # SharedStorage — lean: just the mbars and tiles we actually use. + # + # Mbarriers (all storage rings stay below the 64-per-CTA limit): + # mbar_kv [2] one-shot K/V load handshake (full + empty) + # mbar_q [q_stage * 2] Q producer/consumer ring + # mbar_s [2] QK UTCMMA -> softmax (full + empty) + # mbar_o [2] PV UTCMMA -> epilogue (full + empty) + # mbar_p [s_stage * 2] softmax early-P arrive -> PV + # mbar_p_lastsplit [s_stage * 2] softmax final-P arrive -> PV + # (used only when ``self.split_P_arrive > 0``) + # mbar_sm_stats [s_stage * 2] softmax row_sum/row_max + # publish -> epilogue consumer read. In lean + # 1-WG topology the producer and consumer are + # the same 128 WG_C threads, but we keep the + # barrier for structural parity with FA so the + # softmax body reads identically. + # ------------------------------------------------------------------ + @cute.struct + class SharedStorage: + mbar_k: cute.struct.MemRange[Int64, 2] + mbar_v: cute.struct.MemRange[Int64, 2] + mbar_k_tma: cute.struct.MemRange[Int64, 2] + mbar_v_tma: cute.struct.MemRange[Int64, 2] + mbar_q: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_s: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p_lastsplit: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_o: cute.struct.MemRange[Int64, self.o_stage * 2] + mbar_sm_stats: cute.struct.MemRange[Int64, self.o_stage * 2] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: Int32 + # Per-row softmax stats cache (for epilogue LSE + rescale): + # [0 : m_block_size) row_sum + # [m_block_size : 2*m_block_size) row_max + sScale: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size * 2] + # Per-row temperature LSE row_sum cache. The row_max is shared with + # sScale because lse_temperature_scale is positive. + sScaleTemperature: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size] + # Per-token split_id from prepare-time per-edge metadata. + sSplitIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Per-token q_idx cache to avoid reloading sparse indices in epilogue. + sQIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Prefix length of q_idx-sorted row entries that may need causal + # masking for this KV block. This is CTA-level metadata, not a + # token-count cap. + sDiagQCount: cute.struct.MemRange[Int32, 1] + # CTA-wide row metadata, published once by tidx 0 and reused by + # all warp-specialized roles: + # [0] batch_idx + # [1] kv_block_idx + # [2] row_start + # [3] count_raw + # [4] kv_valid_cols + # [5] q_batch_offset + # [6] k_batch_offset + # [7] causal_q_offset = seqlen_k - seqlen_q + sRowMeta: cute.struct.MemRange[Int32, 8] + sPagedKvIdx: cute.struct.MemRange[Int32, 1] + sQLoadMIdx: cute.struct.MemRange[ + Int32, self.q_stage * self.q_tokens_per_group] + # Packed per-edge q/split metadata: + # low 24 bits = q_idx, high 8 bits = split slot. + sQIdxMeta: cute.struct.MemRange[ + Int32, self.qidx_meta_stages * self.q_tokens_per_group] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes] + sKFp4: cute.struct.Align[ + cute.struct.MemRange[self.k_cache_dtype, cute.cosize(sK_fp4_layout)], + self.buffer_align_bytes] + sVFp4: cute.struct.Align[ + cute.struct.MemRange[self.v_cache_dtype, cute.cosize(sV_fp4_layout)], + self.buffer_align_bytes] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes] + + self.shared_storage = SharedStorage + num_ctas = work_capacity + + self.kernel( + mK, mV, mKScale, mVScale, mKGlobalScale, mVGlobalScale, + mK2qIndices, mK2qQSplitIndices, mK2qCounts, + mSchedulerMetadata, mWorkCount, + mO_partial, mLSE_partial, mLSE_temperature_partial, mQ_2d, mQ_gather4_desc, + mPageTable, mSeqUsedK, mCuSeqlensQ, mCuSeqlensK, + softmax_scale_log2, lse_temperature_scale_log2, lse_temperature_scale, + sQ_layout, sQ_load_layout, sK_layout, sV_layout, + sK_fp4_layout, sV_fp4_layout, tP_layout, + tma_atom_K_fp4, tma_atom_V_fp4, tma_atom_Q, + tiled_mma_qk, tiled_mma_pv, + k_fp4_tma_bytes, v_fp4_tma_bytes, q_tma_bytes, q_subtile_bytes, + num_kv_blocks, num_heads_kv, seq_len_q, work_capacity, + ).launch( + grid=(num_ctas,), + block=[self.threads_per_cta, 1, 1], + smem=max(SharedStorage.size_in_bytes(), 49152), + stream=stream, + min_blocks_per_mp=1, + ) + + # ------------------------------------------------------------------ + # Device-side: kernel entry, dispatch by warpgroup + # ------------------------------------------------------------------ + + @cute.kernel + def kernel( + self, + # Runtime tensors + mK: cute.Tensor, + mV: cute.Tensor, + mKScale: cute.Tensor, + mVScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + mVGlobalScale: Optional[cute.Tensor], + mK2qIndices: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mK2qCounts: cute.Tensor, + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + mQ_2d: cute.Tensor, + mQ_gather4_desc: Optional[cute.Tensor], + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + # Scalars + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + # Layouts + sQ_layout: cute.ComposedLayout, + sQ_load_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sK_fp4_layout: cute.Layout, + sV_fp4_layout: cute.Layout, + tP_layout: cute.ComposedLayout, + # TMA atom + tma_atom_K_fp4: cute.CopyAtom, + tma_atom_V_fp4: cute.CopyAtom, + tma_atom_Q: cute.CopyAtom, + # MMA + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + # Transfer sizes + k_fp4_tma_bytes: cutlass.Constexpr[int], + v_fp4_tma_bytes: cutlass.Constexpr[int], + q_tma_bytes: cutlass.Constexpr[int], + q_subtile_bytes: cutlass.Constexpr[int], + # Iteration bounds + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + ): + # ------------------------------------------------------------------ + # Thread / warp identity, CTA coordinate + # ------------------------------------------------------------------ + bidx, _, _ = cute.arch.block_idx() + row_linear = Int32(0) + head_kv_idx = Int32(0) + batch_idx = Int32(0) + kv_block_idx = Int32(0) + work_q_begin = Int32(0) + work_q_count = Int32(0) + cta_valid_work = True + work_idx = bidx + cta_valid_work = work_idx < mWorkCount[Int32(0)] + if cta_valid_work: + head_kv_idx = mSchedulerMetadata[work_idx, Int32(0)] + row_linear = mSchedulerMetadata[work_idx, Int32(1)] + work_q_begin = mSchedulerMetadata[work_idx, Int32(2)] + work_q_count = mSchedulerMetadata[work_idx, Int32(3)] + batch_idx = mSchedulerMetadata[work_idx, Int32(4)] + kv_block_idx = mSchedulerMetadata[work_idx, Int32(5)] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + head_q = num_heads_kv * Int32(self.qheadperkv) + paged_kv_manager = ( + PagedKVManager.create( + mPageTable, + page_size=self.page_size, + n_block_size=self.n_block_size, + ) + if const_expr(self.paged_kv) + else None + ) + + # Prefetch TMA descriptors (warp 0 once). + if warp_idx == 0: + if const_expr(not self.use_q_gather4): + cpasync.prefetch_descriptor(tma_atom_Q) + else: + with cute.arch.elect_one(): + prefetch_tma_desc_raw(mQ_gather4_desc.iterator) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM allocation (all warps — same SharedStorage type from __call__) + # ------------------------------------------------------------------ + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sKFp4 = storage.sKFp4.get_tensor(sK_fp4_layout) + sVFp4 = storage.sVFp4.get_tensor(sV_fp4_layout) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sQ_load = storage.sQ.get_tensor(sQ_load_layout.outer, swizzle=sQ_load_layout.inner) + sScale = storage.sScale.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size * 2)) + sScaleTemperature = storage.sScaleTemperature.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size)) + sSplitIdx = storage.sSplitIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sQIdx = storage.sQIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sDiagQCount = storage.sDiagQCount.get_tensor(cute.make_layout((1,))) + sRowMeta = storage.sRowMeta.get_tensor(cute.make_layout((8,))) + sPagedKvIdx = storage.sPagedKvIdx.get_tensor(cute.make_layout((1,))) + sQLoadMIdx = storage.sQLoadMIdx.get_tensor( + cute.make_layout((self.q_stage * self.q_tokens_per_group,))) + sQIdxMeta = storage.sQIdxMeta.get_tensor( + cute.make_layout((self.qidx_meta_stages * self.q_tokens_per_group,))) + mbar_k_ptr = storage.mbar_k.data_ptr() + mbar_v_ptr = storage.mbar_v.data_ptr() + mbar_k_tma_ptr = storage.mbar_k_tma.data_ptr() + mbar_v_tma_ptr = storage.mbar_v_tma.data_ptr() + + # ------------------------------------------------------------------ + # TMEM allocator — allocator warp 0 serves the whole CTA. + # ------------------------------------------------------------------ + tmem_alloc_warps: cutlass.Constexpr[int] = self.warps_per_group * 2 + 1 + tmem_alloc_threads = cute.arch.WARP_SIZE * tmem_alloc_warps + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), + num_threads=tmem_alloc_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # ------------------------------------------------------------------ + # Warp-specialized pipelines. + # ------------------------------------------------------------------ + ThreadCooperativeGroup = partial( + cutlass_pipeline.CooperativeGroup, cutlass_pipeline.Agent.Thread) + tma_thread = ThreadCooperativeGroup(1) + mma_thread = ThreadCooperativeGroup(1) + softmax_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group) + epilogue_threads = softmax_threads + + pipeline_q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=q_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_s = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_s.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=softmax_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p_lastsplit = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p_lastsplit.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_o = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_o.data_ptr(), + num_stages=self.o_stage, + producer_group=mma_thread, + consumer_group=epilogue_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_sm_stats = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_sm_stats.data_ptr(), + num_stages=self.o_stage, + producer_group=softmax_threads, + consumer_group=epilogue_threads, + defer_sync=True, + ) + # Cluster sync (no-op for 1CTA cluster). + pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + # ------------------------------------------------------------------ + # Work count: how many Q tokens reference this CTA's KV block + # ------------------------------------------------------------------ + k_smem_bytes = cute.size_in_bytes( + self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) + v_smem_bytes = cute.size_in_bytes( + self.v_dtype, cute.select(sV_layout, mode=[0, 1, 2])) + if tidx == 0: + row_batch_idx = batch_idx + row_kv_block_idx = kv_block_idx + base_row_start = mK2qCounts[head_kv_idx, row_linear] + row_start = base_row_start + count_raw = ( + mK2qCounts[head_kv_idx, row_linear + Int32(1)] + - base_row_start + ) + row_start = base_row_start + work_q_begin + count_raw = work_q_count + kv_valid_cols = ( + self._valid_cols_in_block( + row_batch_idx, + row_kv_block_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + ) + q_batch_offset = self._batch_q_offset( + row_batch_idx, mCuSeqlensQ + ) + k_batch_offset = ( + Int32(0) + if const_expr(self.paged_kv) + else mCuSeqlensK[row_batch_idx] + ) + sRowMeta[0] = row_batch_idx + sRowMeta[1] = row_kv_block_idx + sRowMeta[2] = row_start + sRowMeta[3] = count_raw + sRowMeta[4] = kv_valid_cols + sRowMeta[5] = q_batch_offset + sRowMeta[6] = k_batch_offset + causal_q_offset = Int32(0) + if const_expr(self.causal): + seqlen_q = mCuSeqlensQ[row_batch_idx + Int32(1)] - q_batch_offset + seqlen_k = self._logical_seqlen_k( + row_batch_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + causal_q_offset = seqlen_k - seqlen_q + sRowMeta[7] = causal_q_offset + if const_expr(self.paged_kv): + sPagedKvIdx[0] = paged_kv_manager.physical_block_index( + row_batch_idx, row_kv_block_idx + ) + cute.arch.mbarrier_init(mbar_k_ptr, 1) + cute.arch.mbarrier_init(mbar_v_ptr, 1) + cute.arch.mbarrier_init(mbar_k_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_k_tma_ptr, k_fp4_tma_bytes) + cute.arch.mbarrier_init(mbar_v_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_v_tma_ptr, v_fp4_tma_bytes) + diag_q_count = Int32(0) + if const_expr(self.causal): + row_has_visible_cols = (count_raw > Int32(0)) & (kv_valid_cols > Int32(0)) + if row_has_visible_cols: + kv_valid_end = ( + row_kv_block_idx * Int32(self.n_block_size) + + kv_valid_cols + ) + q_threshold = kv_valid_end - causal_q_offset + diag_q_count = self._lower_bound_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + count_raw, + q_threshold, + ) + sDiagQCount[0] = diag_q_count + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + thr_mma_qk = tiled_mma_qk.get_slice(0) + thr_mma_pv = tiled_mma_pv.get_slice(0) + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_base = thr_mma_qk.make_fragment_C(qk_acc_shape) + tStS = cute.make_tensor( + tStS_base.iterator, + cute.append( + tStS_base.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + tP_width_ratio = Float32.width // self.v_dtype.width + tP_stage_stride = self.tmem_stage_stride * tP_width_ratio + tOrP = cute.make_tensor( + tOrP.iterator + self.tmem_p_offset * tP_width_ratio, + cute.append( + tOrP.layout, + cute.make_layout((self.s_stage,), stride=(tP_stage_stride,)), + ), + ) + + tmem_cols = self.tmem_total + + load_wg_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.LoadWG), + num_threads=cute.arch.WARP_SIZE * self.num_q_load_warps, + ) + kv_load_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvLoad), + num_threads=cute.arch.WARP_SIZE * self.num_kv_load_warps, + ) + kv_dequant_k_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantK), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + kv_dequant_v_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantV), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + sm_stats_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), + num_threads=cute.arch.WARP_SIZE * 2, + ) + epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.StoreEpilogue), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + if ( + warp_idx == Int32(self.total_warps - 1) + and warp_idx >= Int32(self.kv_load_warp_base + self.num_kv_load_warps) + ): + cute.arch.setmaxregister_decrease(self.num_regs_empty) + + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + q_load_thread_end = Int32( + (self.q_load_warp_base + self.num_q_load_warps) * cute.arch.WARP_SIZE + ) + is_q_load_thread = tidx >= q_load_thread_base and tidx < q_load_thread_end + if is_q_load_thread and cta_valid_work: + if self.store_reg_decrease: + cute.arch.setmaxregister_decrease(self.num_regs_store) + else: + cute.arch.setmaxregister_increase(self.num_regs_store) + row_start_load = sRowMeta[2] + count_raw_load = sRowMeta[3] + q_batch_offset_load = sRowMeta[5] + # Do not gate on KV validity here; sparse entries past seqused_k + # still need the all-masked path to produce neutral partials. + has_work_load = count_raw_load > Int32(0) + num_q_groups_load = ( + count_raw_load + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + q_group_start = Int32(0) + if const_expr(self.use_q_gather4): + self._wg_load_q_gather4( + mQ_2d, + mQ_gather4_desc, + mK2qQSplitIndices, + sQIdxMeta, + sQ, + pipeline_q, + load_wg_barrier, + q_group_start, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + True, + ) + else: + self._wg_load_q_tma( + tma_atom_Q, + mQ_2d, + mK2qQSplitIndices, + sQLoadMIdx, + sQIdxMeta, + sQ_load, + pipeline_q, + load_wg_barrier, + q_group_start, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + True, + ) + + if ( + warp_idx >= Int32(self.kv_load_warp_base) + and warp_idx < Int32(self.kv_load_warp_base + self.num_kv_load_warps) + and cta_valid_work + ): + cute.arch.setmaxregister_decrease(self.num_regs_load) + kv_block_idx_load = sRowMeta[1] + k_batch_offset_load = sRowMeta[6] + has_work_load = sRowMeta[3] > Int32(0) + self._wg_load_kv( + tma_atom_K_fp4, tma_atom_V_fp4, + mK, mV, + mKScale, mVScale, + mKGlobalScale, mVGlobalScale, + sPagedKvIdx, + sKFp4, sVFp4, sK, sV, + mbar_k_tma_ptr, mbar_v_tma_ptr, + mbar_k_ptr, mbar_v_ptr, + kv_load_barrier, + has_work_load, + head_kv_idx, kv_block_idx_load, + k_batch_offset_load, + num_heads_kv, + ) + + if warp_idx == Int32(self.mma_warp_id) and cta_valid_work: + cute.arch.setmaxregister_decrease(self.num_regs_mma) + count_raw_mma = sRowMeta[3] + has_work_mma = count_raw_mma > Int32(0) + num_q_groups_mma = ( + count_raw_mma + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + tmem.allocate(tmem_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_mma_issue( + tiled_mma_qk, tiled_mma_pv, + thr_mma_qk, thr_mma_pv, + tStS, tOrP, + sK, sV, sQ, + pipeline_q, pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + mbar_k_ptr, mbar_v_ptr, + num_q_groups_mma, has_work_mma, + ) + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr, num_columns=tmem_cols) + cute.arch.griddepcontrol_launch_dependents() + + if ( + warp_idx >= Int32(self.softmax0_warp_base) + and warp_idx < Int32(self.softmax1_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + self._wg_dequant_k_from_tma_staging( + mKScale, + mKGlobalScale, + sPagedKvIdx, + sKFp4, sK, + mbar_k_tma_ptr, + mbar_k_ptr, + kv_dequant_k_barrier, + has_work_softmax, + self.softmax0_warp_base, + self.warps_per_group, + head_kv_idx, kv_block_idx_softmax, + sRowMeta[6], + num_heads_kv, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 0, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + mKGlobalScale, + mVGlobalScale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + + if ( + warp_idx >= Int32(self.softmax1_warp_base) + and warp_idx < Int32(self.store_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + self._wg_dequant_v_from_tma_staging( + mVScale, + mVGlobalScale, + sPagedKvIdx, + sVFp4, sV, + mbar_v_tma_ptr, + mbar_v_ptr, + kv_dequant_v_barrier, + has_work_softmax, + self.softmax1_warp_base, + self.warps_per_group, + head_kv_idx, kv_block_idx_softmax, + sRowMeta[6], + num_heads_kv, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 1, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + mKGlobalScale, + mVGlobalScale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + # ------------------------------------------------------------------ + # Warp-specialized helpers + # ------------------------------------------------------------------ + + @cute.jit + def _scale_128x4_offset( + self, + row: Int32, + col: Int32, + scale_cols: cutlass.Constexpr[int], + ) -> Int32: + tiles_n: cutlass.Constexpr[int] = (scale_cols + 3) // 4 + tile_m = row // Int32(128) + tile_n = col // Int32(4) + outer = row % Int32(128) + inner = col % Int32(4) + return ( + (tile_m * Int32(tiles_n) + tile_n) * Int32(512) + + (outer % Int32(32)) * Int32(16) + + (outer // Int32(32)) * Int32(4) + + inner + ) + + @cute.jit + def _load_scale_bf16x2( + self, + scale: cute.Tensor, + logical_row: Int32, + scale_col: Int32, + ) -> Int32: + scale_offset = self._scale_128x4_offset( + logical_row, + scale_col, + self.head_dim // 16, + ) + scale_ptr = cute.make_ptr( + cutlass.Uint8, + scale.iterator.toint() + Int64(scale_offset), + mem_space=scale.iterator.memspace, + assumed_align=1, + ) + scale_byte = cute.make_tensor(scale_ptr, cute.make_layout(1))[0] + return utils.cvt_fp8_e4m3_to_bf16x2_replicated(cutlass.Int32(scale_byte)) + + @cute.jit + def _load_scale_e4m3_u8( + self, + scale: cute.Tensor, + logical_row: Int32, + scale_col: Int32, + ) -> Int32: + scale_offset = self._scale_128x4_offset( + logical_row, + scale_col, + self.head_dim // 16, + ) + scale_ptr = cute.make_ptr( + cutlass.Uint8, + scale.iterator.toint() + Int64(scale_offset), + mem_space=scale.iterator.memspace, + assumed_align=1, + ) + scale_byte = cute.make_tensor(scale_ptr, cute.make_layout(1))[0] + return cutlass.Int32(scale_byte) + + @cute.jit + def _dequant_fp4x16_to_bf16( + self, + src_words: cute.Tensor, + combined_scale_bf16x2: Int32, + dst: cute.Tensor, + ): + r_bf16 = cute.make_rmem_tensor((2,), cutlass.BFloat16) + r_bf16_i32 = cute.recast_tensor(r_bf16, cutlass.Int32) + for word_idx in cutlass.range_constexpr(2): + bf16_pair0, bf16_pair1, bf16_pair2, bf16_pair3 = utils.cvt_fp4x8_e2m1_bf16x8( + src_words[word_idx] + ) + bf16_pairs = (bf16_pair0, bf16_pair1, bf16_pair2, bf16_pair3) + for pair_idx in cutlass.range_constexpr(4): + r_bf16_i32[0] = utils.mul_bf16x2( + bf16_pairs[pair_idx], + combined_scale_bf16x2, + ) + dst[word_idx * 8 + 2 * pair_idx + 0] = r_bf16[0] + dst[word_idx * 8 + 2 * pair_idx + 1] = r_bf16[1] + + @cute.jit + def _dequant_fp4x16_to_fp8( + self, + src_words: cute.Tensor, + scale_e4m3: Int32, + dst: cute.Tensor, + ): + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + for word_idx in cutlass.range_constexpr(2): + fp8_lo, fp8_hi = utils.cvt_fp4x8_e2m1_scaled_e4m3x8( + src_words[word_idx], + scale_e4m3, + ) + dst_i32[word_idx * 2] = fp8_lo + dst_i32[word_idx * 2 + 1] = fp8_hi + + @cute.jit + def _dequant_fp4x32_to_fp8( + self, + src_words: cute.Tensor, + scale_e4m3_lo: Int32, + scale_e4m3_hi: Int32, + dst: cute.Tensor, + ): + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + for word_idx in cutlass.range_constexpr(2): + fp8_lo, fp8_hi = utils.cvt_fp4x8_e2m1_scaled_e4m3x8( + src_words[word_idx], + scale_e4m3_lo, + ) + dst_i32[word_idx * 2] = fp8_lo + dst_i32[word_idx * 2 + 1] = fp8_hi + for word_idx in cutlass.range_constexpr(2): + fp8_lo, fp8_hi = utils.cvt_fp4x8_e2m1_scaled_e4m3x8( + src_words[word_idx + 2], + scale_e4m3_hi, + ) + dst_i32[word_idx * 2 + 4] = fp8_lo + dst_i32[word_idx * 2 + 5] = fp8_hi + + @cute.jit + def _flat_kv_scale_row( + self, + token_idx: Int32, + head_kv_idx: Int32, + num_heads_kv: Int32, + ) -> Int32: + return token_idx * num_heads_kv + head_kv_idx + + @cute.jit + def _paged_kv_scale_row( + self, + page_idx: Int32, + token_idx: Int32, + head_kv_idx: Int32, + num_heads_kv: Int32, + ) -> Int32: + return (page_idx * num_heads_kv + head_kv_idx) * Int32(self.page_size) + token_idx + + @cute.jit + def _load_k_fp4_to_smem( + self, + sKFp4: cute.Tensor, + mKScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sK: cute.Tensor, + lane: Int32, + warp_idx_in_wg: Int32, + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + elems_per_block: cutlass.Constexpr[int] = 16 + bytes_per_block: cutlass.Constexpr[int] = 8 + scale_cols: cutlass.Constexpr[int] = self.head_dim // elems_per_block + token_in_block_base = kv_block_idx * Int32(self.n_block_size) + if const_expr(self.k_dtype == cutlass.Float8E4M3FN and self.fp8_pair_dequant): + elems_per_pair: cutlass.Constexpr[int] = 32 + bytes_per_pair: cutlass.Constexpr[int] = 16 + scale_pairs: cutlass.Constexpr[int] = scale_cols // 2 + r_words_pair = cute.make_rmem_tensor((bytes_per_pair // 4,), cutlass.Int32) + r_vals_pair = cute.make_rmem_tensor((elems_per_pair,), cutlass.Float8E4M3FN) + g2r_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_pair * cutlass.Uint8.width, + ) + r2s_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float8E4M3FN, + num_bits_per_copy=16 * cutlass.Float8E4M3FN.width, + ) + total_pair_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_pairs + pair_task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + pair_task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(pair_task, total_pair_tasks, pair_task_stride, unroll=1): + row = task_idx // Int32(scale_pairs) + pair_col = task_idx - row * Int32(scale_pairs) + scale_col = pair_col * Int32(2) + byte_col = pair_col * Int32(bytes_per_pair) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sKFp4.iterator.toint() + Int64(smem_offset), + mem_space=sKFp4.iterator.memspace, + assumed_align=bytes_per_pair, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_pair // 4)) + cute.copy(g2r_pair_atom, s_vec, r_words_pair) + scale_e4m3_lo = self._load_scale_e4m3_u8(mKScale, scale_row, scale_col) + scale_e4m3_hi = self._load_scale_e4m3_u8( + mKScale, + scale_row, + scale_col + Int32(1), + ) + self._dequant_fp4x32_to_fp8( + r_words_pair, + scale_e4m3_lo, + scale_e4m3_hi, + r_vals_pair, + ) + sK_vec = sK[(row, None), 0, pair_col, 0] + r_tiles = cute.logical_divide(r_vals_pair, cute.make_layout(16)) + s_tiles = cute.logical_divide(sK_vec, cute.make_layout(16)) + for v in cutlass.range_constexpr(elems_per_pair // 16): + cute.copy(r2s_pair_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + return + r_words = cute.make_rmem_tensor((bytes_per_block // 4,), cutlass.Int32) + r_vals = cute.make_rmem_tensor((elems_per_block,), self.k_dtype) + g2r_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_block * cutlass.Uint8.width, + ) + elems_per_store: cutlass.Constexpr[int] = ( + 16 if self.k_dtype == cutlass.Float8E4M3FN else 8 + ) + r2s_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.k_dtype, + num_bits_per_copy=elems_per_store * self.k_dtype.width, + ) + total_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_cols + task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(task, total_tasks, task_stride, unroll=1): + row = task_idx // Int32(scale_cols) + scale_col = task_idx - row * Int32(scale_cols) + byte_col = scale_col * Int32(bytes_per_block) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sKFp4.iterator.toint() + Int64(smem_offset), + mem_space=sKFp4.iterator.memspace, + assumed_align=bytes_per_block, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_block // 4)) + cute.copy(g2r_atom, s_vec, r_words) + if const_expr(self.k_dtype == cutlass.Float8E4M3FN): + scale_e4m3 = self._load_scale_e4m3_u8(mKScale, scale_row, scale_col) + else: + combined_bf16x2 = self._load_scale_bf16x2(mKScale, scale_row, scale_col) + if const_expr(self.has_k_global_scale): + global_bf16x2 = utils.cvt_f16x2_f32( + mKGlobalScale[0], + mKGlobalScale[0], + cutlass.BFloat16, + ) + combined_bf16x2 = utils.mul_bf16x2(combined_bf16x2, global_bf16x2) + if const_expr(self.k_dtype == cutlass.Float8E4M3FN): + sK_cols = sK[(row, None), 0, scale_col // Int32(2), 0] + sK_vec = cute.local_tile( + sK_cols, + (elems_per_block,), + (scale_col % Int32(2),), + ) + else: + sK_vec = sK[ + (row, None), + 0, + (scale_col % Int32(4), scale_col // Int32(4)), + 0, + ] + if const_expr(self.k_dtype == cutlass.Float8E4M3FN): + self._dequant_fp4x16_to_fp8(r_words, scale_e4m3, r_vals) + else: + self._dequant_fp4x16_to_bf16( + r_words, + combined_bf16x2, + r_vals, + ) + r_tiles = cute.logical_divide(r_vals, cute.make_layout(elems_per_store)) + s_tiles = cute.logical_divide(sK_vec, cute.make_layout(elems_per_store)) + for v in cutlass.range_constexpr(elems_per_block // elems_per_store): + cute.copy(r2s_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + + @cute.jit + def _load_v_fp4_to_smem( + self, + sVFp4: cute.Tensor, + mVScale: cute.Tensor, + mVGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sV: cute.Tensor, + lane: Int32, + warp_idx_in_wg: Int32, + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + elems_per_block: cutlass.Constexpr[int] = 16 + bytes_per_block: cutlass.Constexpr[int] = 8 + scale_cols: cutlass.Constexpr[int] = self.head_dim // elems_per_block + token_in_block_base = kv_block_idx * Int32(self.n_block_size) + if const_expr(self.v_dtype == cutlass.Float8E4M3FN and self.fp8_pair_dequant): + elems_per_pair: cutlass.Constexpr[int] = 32 + bytes_per_pair: cutlass.Constexpr[int] = 16 + scale_pairs: cutlass.Constexpr[int] = scale_cols // 2 + r_words_pair = cute.make_rmem_tensor((bytes_per_pair // 4,), cutlass.Int32) + r_vals_pair = cute.make_rmem_tensor((elems_per_pair,), cutlass.Float8E4M3FN) + g2r_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_pair * cutlass.Uint8.width, + ) + r2s_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float8E4M3FN, + num_bits_per_copy=16 * cutlass.Float8E4M3FN.width, + ) + total_pair_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_pairs + pair_task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + pair_task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(pair_task, total_pair_tasks, pair_task_stride, unroll=1): + row = task_idx // Int32(scale_pairs) + pair_col = task_idx - row * Int32(scale_pairs) + scale_col = pair_col * Int32(2) + byte_col = pair_col * Int32(bytes_per_pair) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sVFp4.iterator.toint() + Int64(smem_offset), + mem_space=sVFp4.iterator.memspace, + assumed_align=bytes_per_pair, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_pair // 4)) + cute.copy(g2r_pair_atom, s_vec, r_words_pair) + scale_e4m3_lo = self._load_scale_e4m3_u8(mVScale, scale_row, scale_col) + scale_e4m3_hi = self._load_scale_e4m3_u8( + mVScale, + scale_row, + scale_col + Int32(1), + ) + self._dequant_fp4x32_to_fp8( + r_words_pair, + scale_e4m3_lo, + scale_e4m3_hi, + r_vals_pair, + ) + sV_cols = sV[(None, row % Int32(32)), 0, row // Int32(32), 0] + sV_vec = cute.local_tile(sV_cols, (elems_per_pair,), (pair_col,)) + r_tiles = cute.logical_divide(r_vals_pair, cute.make_layout(16)) + s_tiles = cute.logical_divide(sV_vec, cute.make_layout(16)) + for v in cutlass.range_constexpr(elems_per_pair // 16): + cute.copy(r2s_pair_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + return + r_words = cute.make_rmem_tensor((bytes_per_block // 4,), cutlass.Int32) + r_vals = cute.make_rmem_tensor((elems_per_block,), self.v_dtype) + g2r_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_block * cutlass.Uint8.width, + ) + elems_per_store: cutlass.Constexpr[int] = ( + 16 if self.v_dtype == cutlass.Float8E4M3FN else 8 + ) + r2s_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.v_dtype, + num_bits_per_copy=elems_per_store * self.v_dtype.width, + ) + total_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_cols + task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(task, total_tasks, task_stride, unroll=1): + row = task_idx // Int32(scale_cols) + scale_col = task_idx - row * Int32(scale_cols) + byte_col = scale_col * Int32(bytes_per_block) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sVFp4.iterator.toint() + Int64(smem_offset), + mem_space=sVFp4.iterator.memspace, + assumed_align=bytes_per_block, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_block // 4)) + cute.copy(g2r_atom, s_vec, r_words) + if const_expr(self.v_dtype == cutlass.Float8E4M3FN): + scale_e4m3 = self._load_scale_e4m3_u8(mVScale, scale_row, scale_col) + self._dequant_fp4x16_to_fp8(r_words, scale_e4m3, r_vals) + else: + combined_bf16x2 = self._load_scale_bf16x2(mVScale, scale_row, scale_col) + if const_expr(self.has_v_global_scale): + global_bf16x2 = utils.cvt_f16x2_f32( + mVGlobalScale[0], + mVGlobalScale[0], + cutlass.BFloat16, + ) + combined_bf16x2 = utils.mul_bf16x2(combined_bf16x2, global_bf16x2) + self._dequant_fp4x16_to_bf16( + r_words, + combined_bf16x2, + r_vals, + ) + if const_expr(self.v_dtype == cutlass.Float8E4M3FN): + sV_cols = sV[(None, row % Int32(32)), 0, row // Int32(32), 0] + else: + sV_cols = sV[(None, row % Int32(16)), 0, row // Int32(16), 0] + sV_vec = cute.local_tile(sV_cols, (elems_per_block,), (scale_col,)) + r_tiles = cute.logical_divide(r_vals, cute.make_layout(elems_per_store)) + s_tiles = cute.logical_divide(sV_vec, cute.make_layout(elems_per_store)) + for v in cutlass.range_constexpr(elems_per_block // elems_per_store): + cute.copy(r2s_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + + @cute.jit + def _wg_load_kv( + self, + tma_atom_K_fp4, + tma_atom_V_fp4, + mK: cute.Tensor, + mV: cute.Tensor, + mKScale: cute.Tensor, + mVScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + mVGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sKFp4: cute.Tensor, + sVFp4: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + mbar_k_tma_ptr, + mbar_v_tma_ptr, + mbar_k_ptr, + mbar_v_ptr, + kv_load_barrier, + has_work: Int32, + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(self.kv_load_warp_base) + + if has_work: + if warp_idx_in_wg == Int32(0): + if const_expr(self.paged_kv): + mK_cur = mK[None, None, head_kv_idx, sPagedKvIdx[0]] + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim // 2), + (None, 0), + ) + src_idx = Int32(0) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + mK[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim // 2), + (None, 0), + ) + src_idx = kv_block_idx + load_K_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K_fp4, + 0, + cute.make_layout(1), + gK, + sKFp4, + ) + load_K_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_k_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_tma_ptr) + + if warp_idx_in_wg == Int32(1): + if const_expr(self.paged_kv): + mV_cur = mV[None, None, head_kv_idx, sPagedKvIdx[0]] + gV = cute.local_tile( + mV_cur, + (self.head_dim // 2, self.n_block_size), + (0, None), + ) + src_idx = Int32(0) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + mV[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + (self.head_dim // 2, self.n_block_size), + (0, None), + ) + src_idx = kv_block_idx + load_V_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V_fp4, + 0, + cute.make_layout(1), + gV, + sVFp4, + ) + load_V_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_v_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_tma_ptr) + + kv_load_barrier.arrive_and_wait() + + @cute.jit + def _wg_dequant_k_from_tma_staging( + self, + mKScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sKFp4: cute.Tensor, + sK: cute.Tensor, + mbar_k_tma_ptr, + mbar_k_ptr, + dequant_barrier, + has_work: Int32, + dequant_warp_base: cutlass.Constexpr[int], + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(dequant_warp_base) + lane = cute.arch.lane_idx() + + if has_work: + cute.arch.mbarrier_wait(mbar_k_tma_ptr, 0) + self._load_k_fp4_to_smem( + sKFp4, + mKScale, + mKGlobalScale, + sPagedKvIdx, + sK, + lane, + warp_idx_in_wg, + num_dequant_warps, + head_kv_idx, + kv_block_idx, + k_batch_offset, + num_heads_kv, + ) + dequant_barrier.arrive_and_wait() + if warp_idx_in_wg == Int32(0): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_ptr) + + @cute.jit + def _wg_dequant_v_from_tma_staging( + self, + mVScale: cute.Tensor, + mVGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sVFp4: cute.Tensor, + sV: cute.Tensor, + mbar_v_tma_ptr, + mbar_v_ptr, + dequant_barrier, + has_work: Int32, + dequant_warp_base: cutlass.Constexpr[int], + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(dequant_warp_base) + lane = cute.arch.lane_idx() + + if has_work: + cute.arch.mbarrier_wait(mbar_v_tma_ptr, 0) + self._load_v_fp4_to_smem( + sVFp4, + mVScale, + mVGlobalScale, + sPagedKvIdx, + sV, + lane, + warp_idx_in_wg, + num_dequant_warps, + head_kv_idx, + kv_block_idx, + k_batch_offset, + num_heads_kv, + ) + dequant_barrier.arrive_and_wait() + if warp_idx_in_wg == Int32(0): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_ptr) + + @cute.jit + def _wg_load_q_gather4( + self, + mQ_2d: cute.Tensor, + mQ_gather4_desc: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + load_wg_barrier, + qi_group_start: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + do_final_acquire: cutlass.Constexpr[bool], + ): + tidx = cute.arch.thread_idx()[0] + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + group_tidx = tidx - q_load_thread_base + producer_warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + gathers_per_warp: cutlass.Constexpr[int] = self.m_block_size // ( + self.num_q_load_warps * 4) + + if has_work: + for qi_group_rel in cutlass.range(num_q_groups, unroll=1): + qi_group = qi_group_start + qi_group_rel + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if producer_warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + group_tidx = tidx - q_load_thread_base + warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + lane_idx = group_tidx % Int32(cute.arch.WARP_SIZE) + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + meta_iters: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + + self.num_q_load_warps * cute.arch.WARP_SIZE - 1) + // (self.num_q_load_warps * cute.arch.WARP_SIZE) + ) + for meta_iter in cutlass.range_constexpr(meta_iters): + tok_idx_g4 = ( + (Int32(meta_iter) * Int32(self.num_q_load_warps) + + warp_idx_in_wg) + * Int32(cute.arch.WARP_SIZE) + + lane_idx + ) + if tok_idx_g4 < Int32(self.q_tokens_per_group): + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx_g4 + if qi < count_raw: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = Int32(0) + load_wg_barrier.arrive_and_wait() + + with cute.arch.elect_one(): + q_desc_ptr = mQ_gather4_desc.iterator + sQ_ptr = sQ.iterator + for gather_slot in cutlass.range_constexpr(gathers_per_warp): + gather_idx = ( + Int32(gather_slot) * Int32(self.num_q_load_warps) + + warp_idx_in_wg + ) + tok_base = gather_idx * Int32(self.tokens_per_gather4) + if const_expr(self.qheadperkv == 1): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + qi2 = qi0 + Int32(2) + qi3 = qi0 + Int32(3) + row0 = q_oob_m_idx + row1 = q_oob_m_idx + row2 = q_oob_m_idx + row3 = q_oob_m_idx + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row0 = (q_batch_offset + q_idx0) * num_heads_kv + head_kv_idx + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row1 = (q_batch_offset + q_idx1) * num_heads_kv + head_kv_idx + if qi2 < count_raw: + q_idx2 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(2)]) + row2 = (q_batch_offset + q_idx2) * num_heads_kv + head_kv_idx + if qi3 < count_raw: + q_idx3 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(3)]) + row3 = (q_batch_offset + q_idx3) * num_heads_kv + head_kv_idx + elif const_expr(self.qheadperkv == 2): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + row_base1 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row_base1 = ( + (q_batch_offset + q_idx1) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base1 + row3 = row_base1 + Int32(1) + else: + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base0 + Int32(2) + row3 = row_base0 + Int32(3) + group_byte_off = gather_idx * Int32( + 4 * self.k_tile * (self.q_dtype.width // 8) + ) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + stage_byte_off = slot * Int32(self.q_stage_stride_bytes) + full_group_byte_off = gather_idx * Int32( + 4 * self.head_dim * (self.q_dtype.width // 8) + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + full_group_byte_off, + q_desc_ptr, + Int32(0), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + else: + for ks_c in cutlass.range_constexpr(self.k_stages): + stage_idx = slot * Int32(self.k_stages) + Int32(ks_c) + stage_byte_off = stage_idx * Int32(self.k_tile_stride_bytes) + if const_expr(ks_c + 1 < self.k_stages): + tma_gather4_prefetch( + q_desc_ptr, + Int32((ks_c + 1) * self.k_tile), + row0, + row1, + row2, + row3, + TMA_CACHE_EVICT_LAST, + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + group_byte_off, + q_desc_ptr, + Int32(ks_c * self.k_tile), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + load_wg_barrier.arrive_and_wait() + + if const_expr(do_final_acquire) and producer_warp_idx_in_wg == Int32(0): + next_qi_group = qi_group_start + num_q_groups + next_slot = next_qi_group % Int32(self.q_stage) + next_phase = ( + (next_qi_group // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_load_q_tma( + self, + tma_atom_Q, + mQ_2d: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQLoadMIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ_load: cute.Tensor, + pipeline_q, + load_wg_barrier, + qi_group_start: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + do_final_acquire: cutlass.Constexpr[bool], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + lane_idx = cute.arch.lane_idx() + warp_idx_in_wg = warp_idx - Int32(self.q_load_warp_base) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + gQ_full = cute.local_tile( + mQ_2d, (self.qheadperkv, self.head_dim), (None, 0)) + load_Q_fn_full, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_full, sQ_load) + load_Q_fn_k0, load_Q_fn_k1 = None, None + else: + gQ_k0 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 0)) + gQ_k1 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 1)) + load_Q_fn_k0, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k0, sQ_load) + load_Q_fn_k1, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k1, sQ_load) + load_Q_fn_full = None + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + tokens_per_warp: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + self.num_q_load_warps - 1) + // self.num_q_load_warps + ) + + if has_work: + for qi_group_rel in cutlass.range(num_q_groups, unroll=1): + qi_group = qi_group_start + qi_group_rel + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + q_load_subtiles_per_token = ( + 1 if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_stages + ) + sub_stage_base = slot * Int32( + self.q_tokens_per_group * q_load_subtiles_per_token) + load_meta_slot = slot * Int32(self.q_tokens_per_group) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + if warp_idx_in_wg == Int32(0) and lane_idx < Int32(self.q_tokens_per_group): + tok_idx = lane_idx + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx + if qi < count_raw: + qsplit = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + q_idx = self._decode_q_idx_from_qsplit(qsplit) + q_abs = q_batch_offset + q_idx + sQIdxMeta[qidx_meta_slot + tok_idx] = qsplit + sQLoadMIdx[load_meta_slot + tok_idx] = ( + q_abs * num_heads_kv + head_kv_idx + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx] = Int32(0) + sQLoadMIdx[load_meta_slot + tok_idx] = q_oob_m_idx + load_wg_barrier.arrive_and_wait() + + for qi_slot in cutlass.range_constexpr(tokens_per_warp): + tok_idx = ( + warp_idx_in_wg * Int32(tokens_per_warp) + + Int32(qi_slot) + ) + if tok_idx < Int32(self.q_tokens_per_group): + m_tile_idx = sQLoadMIdx[load_meta_slot + tok_idx] + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + load_Q_fn_full( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + else: + load_Q_fn_k0( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + load_Q_fn_k1( + src_idx=m_tile_idx, + dst_idx=( + sub_stage_base + + Int32(self.q_tokens_per_group) + + tok_idx + ), + tma_bar_ptr=mbar_ptr, + ) + + if const_expr(do_final_acquire) and warp_idx_in_wg == Int32(0): + next_qi_group = qi_group_start + num_q_groups + next_slot = next_qi_group % Int32(self.q_stage) + next_phase = ( + (next_qi_group // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_mma_issue( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + thr0_qk: cute.core.ThrMma, + thr0_pv: cute.core.ThrMma, + tStS: cute.Tensor, + tOrP: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + mbar_k_ptr, + mbar_v_ptr, + num_q_groups: Int32, + has_work: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + is_mma_warp = warp_idx == Int32(self.mma_warp_id) + + if is_mma_warp: + if has_work: + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tSrQ0 = tSrQ[None, None, None, 0] + tSrK0 = tSrK[None, None, None, 0] + tOrV = tiled_mma_pv.make_fragment_B(sV) + tOrV0 = tOrV[None, None, None, 0] + sV0 = sV[None, None, None, 0] + pv_mma_op = tiled_mma_pv.op + qk_mma_op = tiled_mma_qk.op + q_smem_base = sm100_desc.smem_desc_base_from_tensor( + sQ, sm100_desc.Major.K) + k_smem_base = sm100_desc.smem_desc_base_from_tensor( + sK, sm100_desc.Major.K) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[None, None, None, 0].iterator) + q_smem_start = sm100_desc.make_smem_desc_start_addr( + sQ[None, None, None, self.q_stage - 1].iterator) + sm100_helpers.declare_ptx_smem_desc( + q_smem_start, + q_smem_base, + tSrQ0.layout, + var_name_prefix="lean_q_desc", + ) + sm100_helpers.declare_ptx_idesc( + qk_mma_op, var_name="lean_qk_idesc") + sQ_stage_stride = ( + sQ.layout.stride[-1] * sQ.element_type.width // 8) >> 4 + if const_expr(self.q_stage == 1): + sQ_stage_stride = 0 + q_wrap_offset = -(self.q_stage - 1) * sQ_stage_stride + q_advance_offset = sQ_stage_stride + gemm_qk_s0_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_qk_s0_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_qk_s1_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_qk_s1_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_pv_0 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset), + tOrP[None, None, None, 0], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_pv_1 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset + self.tmem_o_stage_stride), + tOrP[None, None, None, 1], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_stage_stride + self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + + cute.arch.mbarrier_wait(mbar_k_ptr, 0) + # Issue order: + # Q0K, Q1K, P0V, Q2K, P1V, Q3K, ... + # This reuses each slot as soon as its previous PV drains, + # instead of batching both PVs after both QKs of a pair. + # The schedule is still 2-slot safe: + # - QK(qi) consumes slot qi&1 + # - PV(qi-2) frees the same slot before QK(qi) reuses it + # - phases still toggle every 2 groups per slot + + # Prologue: issue up to the first two QK tiles. Q slots come + # from the q_stage ring; S slots remain a 2-slot ring. + pipeline_q.consumer_wait_w_index_phase(Int32(0), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(0), Int32(1)) + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(0)) + pipeline_q.consumer_release_w_index(Int32(0)) + + if num_q_groups > Int32(1): + pipeline_q.consumer_wait_w_index_phase(Int32(1), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(1), Int32(1)) + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(1)) + pipeline_q.consumer_release_w_index(Int32(1)) + + cute.arch.mbarrier_wait(mbar_v_ptr, 0) + + # Steady-state: for qi >= 2, reuse the S/P slot as a BLK128- + # style handoff. The MMA warp waits for softmax to release + # the slot after P is visible, issues PV, then immediately + # reuses the same acquired slot for the next QK. + for qi in cutlass.range(Int32(2), num_q_groups, unroll=1): + pv_qi = qi - Int32(2) + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + q_slot = qi % Int32(self.q_stage) + q_phase = (qi // Int32(self.q_stage)) & Int32(1) + s_slot = qi & Int32(1) + s_phase = (qi // Int32(2)) & Int32(1) + pipeline_q.consumer_wait_w_index_phase(q_slot, q_phase) + pipeline_s.producer_acquire_w_index_phase( + s_slot, s_phase ^ Int32(1)) + if s_slot == Int32(0): + if q_slot == Int32(0): + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s0_advance(smem_desc_start_b=k_smem_start) + else: + if q_slot == Int32(0): + gemm_qk_s1_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(s_slot) + pipeline_q.consumer_release_w_index(q_slot) + + # Drain the remaining one or two PV tiles. + drain_begin = Int32(0) if num_q_groups == Int32(1) else num_q_groups - Int32(2) + for pv_qi in cutlass.range(drain_begin, num_q_groups, unroll=1): + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + @cute.jit + def _softmax_step( + self, + slot: cutlass.Constexpr[int], + s_consumer_phase: Int32, + p_producer_phase: Int32, + sm_stats_producer_phase: Int32, + softmax: SoftmaxSm100, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx: Int32, + thr_tmem_load, + thr_tmem_store, + tStS_t2r: cute.Tensor, + tStP_r2t: cute.Tensor, + tScS_t2r: cute.Tensor, + tScP_shape, + sQIdxMeta: cute.Tensor, + qidx_meta_slot: Int32, + group_tidx: Int32, + masked_tok_count: Int32, + kv_block_col_start: Int32, + seq_len_q: Int32, + causal_q_offset: Int32, + kv_valid_cols: Int32, + lse_temperature_scale: Float32, + mKGlobalScale: Optional[cute.Tensor], + return_temperature_lse: cutlass.Constexpr[bool], + apply_causal_mask: cutlass.Constexpr[bool] = False, + signal_stats_barrier: cutlass.Constexpr[bool] = True, + ): + slot_rt = Int32(slot) + + pipeline_s.consumer_wait_w_index_phase(slot_rt, s_consumer_phase) + + tSrS_t2r = cute.make_rmem_tensor( + tScS_t2r.shape, + self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_k_global_scale + ): + k_global = mKGlobalScale[0] + for i in cutlass.range_constexpr(0, cute.size(tSrS_t2r.shape), 2): + tSrS_t2r[i], tSrS_t2r[i + 1] = cute.arch.mul_packed_f32x2( + (tSrS_t2r[i], tSrS_t2r[i + 1]), + (k_global, k_global), + ) + + seqlen_info = SeqlenInfoQK( + Int32(0), + Int32(0), + Int32(0), + Int32(0), + seq_len_q, + seq_len_q + causal_q_offset, + False, + False, + False, + False, + ) + mask = AttentionMask( + self.m_block_size, + self.n_block_size, + seqlen_info, + ) + if const_expr(self.causal and apply_causal_mask): + need_causal_mask = masked_tok_count > Int32(0) + if need_causal_mask: + tok_idx = group_tidx // Int32(self.qheadperkv) + q_idx = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_idx]) + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=True, + row_idx=q_idx, + kv_valid_cols=kv_valid_cols, + kv_block_col_start=kv_block_col_start, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + + # Each sparse CTA computes exactly one KV block for the current Q group, + # so full-tile softmax is always the first and only online-softmax step. + row_max, _ = softmax.update_row_max(tSrS_t2r.load(), True) + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + if const_expr(return_temperature_lse): + lse_temperature_row_sum = softmax.compute_scaled_exp2_row_sum( + tSrS_t2r, + lse_temperature_scale, + ) + + if cutlass.const_expr(self.split_P_arrive > 0): + # This full barrier is the late-P handoff consumed inside + # gemm_ptx_partial after its early PV k-slices are issued. + pipeline_p_lastsplit.producer_acquire_w_index_phase( + slot_rt, p_producer_phase) + pipeline_p.producer_acquire_w_index_phase(slot_rt, p_producer_phase) + tSrP_r2t_f32 = cute.make_rmem_tensor( + thr_tmem_store.partition_S( + cute.make_identity_tensor(tScP_shape)).shape, + Float32) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r.layout) + softmax.apply_exp2_convert( + tSrS_t2r, + tSrP_r2t, + ex2_emu_freq=self.ex2_emu_freq, + ex2_emu_start_frg=self.ex2_emu_start_frg, + ) + + for k in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): + cute.copy( + thr_tmem_store, + tSrP_r2t_f32[None, None, k], + tStP_r2t[None, None, k]) + if cutlass.const_expr(self.split_P_arrive > 0): + split_idx = ( + cute.size(tStP_r2t.shape[2]) + * self.split_P_arrive // self.n_block_size) + if cutlass.const_expr(k + 1 == split_idx): + cute.arch.fence_view_async_tmem_store() + pipeline_p.producer_commit_w_index(slot_rt) + cute.arch.fence_view_async_tmem_store() + if cutlass.const_expr(self.split_P_arrive == 0): + pipeline_p.producer_commit_w_index(slot_rt) + else: + pipeline_p_lastsplit.producer_commit_w_index(slot_rt) + pipeline_sm_stats.producer_acquire_w_index_phase( + slot_rt, sm_stats_producer_phase) + softmax.update_row_sum(tSrS_t2r.load(), Float32(0.0), True) + del tSrS_t2r + sScale_slot = cute.make_tensor( + sScale.iterator + slot_rt * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_slot[group_tidx] = softmax.row_sum[0] + sScale_slot[group_tidx + Int32(self.m_block_size)] = ( + softmax.row_max[0]) + if const_expr(return_temperature_lse): + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot_rt * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sScale_temperature_slot[group_tidx] = lse_temperature_row_sum + cute.arch.fence_view_async_shared() + + if const_expr(signal_stats_barrier): + sm_stats_barrier.arrive_w_index(index=stats_barrier_idx) + pipeline_s.consumer_release_w_index(slot_rt) + + @cute.jit + def _wg_softmax( + self, + stage: cutlass.Constexpr[int], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tStS: cute.Tensor, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + mKGlobalScale: Optional[cute.Tensor], + mVGlobalScale: Optional[cute.Tensor], + kv_block_idx: Int32, + kv_valid_cols: Int32, + diag_q_count: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + causal_q_offset: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + mQ_2d: cute.Tensor, + ): + tidx = cute.arch.thread_idx()[0] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx % Int32(self.warps_per_group) + group_tidx = ( + warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + stats_barrier_idx = ( + Int32(stage) * Int32(self.warps_per_group) + warp_idx_in_wg) + + thr0_qk = tiled_mma_qk.get_slice(0) + tScS = thr0_qk.partition_C( + cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScS = tScS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_tiler_qk[0] // thr0_qk.thr_id.shape, + self.mma_tiler_qk[1]) + tilePlikeFP32 = ( + self.mma_tiler_qk[1] // Float32.width * self.q_dtype.width) + tScP_shape = (cta_qk_tiler[0], tilePlikeFP32) + tSAcc = tStS[(None, None), 0, 0, stage] + + softmax = SoftmaxSm100.create(softmax_scale_log2) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype) + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tSAcc).get_slice(group_tidx) + tStS_t2r = thr_tmem_load.partition_S(tSAcc) + tScS_t2r = thr_tmem_load.partition_D(tScS) + tStP_layout = cute.composition( + tSAcc.layout, + cute.make_layout((self.m_block_size, tilePlikeFP32))) + tStP = cute.make_tensor( + tSAcc.iterator + self.tmem_s_to_p_offset, tStP_layout) + # P-store Repetition is dtype-aware: each PV MMA K-segment is + # ``32 / (q_dtype.width / 8)`` fp8/bf16 columns wide, which equals + # ``32 * Float32.width / q_dtype.width`` packed fp32 TMEM columns + # ``// (q_dtype.width / 8)``. Concretely, R=16 packs two bf16 PV K + # segments per chunk (shape[2]=4 ⇒ 3/4 publish boundary aligns), + # while fp8 (PV K=32 fp8 ⇒ 8 fp32 cols) needs R=8 so + # shape[2]=4 and split_idx=3 publishes exactly 24 fp32 cols + # (= 96 fp8 cols = 3 PV K segments) at the early-arrive edge. + store_rep = const_expr( + 8 if self.q_dtype == cutlass.Float8E4M3FN else 16 + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(store_rep)), + Float32) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom, tStP).get_slice(group_tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + + total_q = mQ_2d.shape[0] // head_q + thr0_pv = tiled_mma_pv.get_slice(0) + pv_acc_shape = thr0_pv.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO_base = thr0_pv.make_fragment_C(pv_acc_shape) + corr_tile_size = 64 + tOcO = thr0_pv.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + tOcO_i = cute.logical_divide( + tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + o_tmem_copy_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.pv_acc_dtype) + + if has_work: + kv_block_col_start = Int32(0) + if const_expr(self.causal): + kv_block_col_start = kv_block_idx * Int32(self.n_block_size) + + num_stage_groups = ( + num_q_groups + Int32(1 - stage)) // Int32(2) + for qi_iter in cutlass.range(num_stage_groups, unroll=1): + qi_group = qi_iter * Int32(2) + Int32(stage) + phase = qi_iter & Int32(1) + producer_phase = phase ^ Int32(1) + qidx_meta_slot = ( + (qi_group & Int32(self.qidx_meta_stages - 1)) + * Int32(self.q_tokens_per_group) + ) + + softmax.reset() + + if const_expr(self.causal): + qi_group_start = qi_group * Int32(self.q_tokens_per_group) + masked_tok_count = cutlass.max( + Int32(0), + cutlass.min( + Int32(self.q_tokens_per_group), + diag_q_count - qi_group_start, + ), + ) + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + masked_tok_count, + kv_block_col_start, + seq_len_q, + causal_q_offset, + kv_valid_cols, + lse_temperature_scale, + mKGlobalScale, + const_expr(mLSE_temperature_partial is not None), + True, + False, + ) + else: + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + Int32(0), + kv_block_col_start, + seq_len_q, + Int32(0), + kv_valid_cols, + lse_temperature_scale, + mKGlobalScale, + const_expr(mLSE_temperature_partial is not None), + False, + False, + ) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(stage)) + self._epilogue_step( + qi_group, + group_tidx, + warp_idx_in_wg, + tOtO_base, + tOcO_i, + o_tmem_copy_atom, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, + mLSE_partial, + mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + mVGlobalScale, + count_raw, + batch_idx, + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + q_batch_offset, + total_q, + False, + stage, + ) + + @cute.jit + def _store_o_partial_vec4( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + ): + stg_128_cs(ptr, v0, v1, v2, v3) + + @cute.jit + def _store_o_partial_vec8_half( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + ): + if cutlass.const_expr(self.o_dtype is cutlass.BFloat16): + stg_128_bf16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + else: + stg_128_f16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + + @cute.jit + def _store_o_partial_vec16_fp8( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + v8: Float32, + v9: Float32, + v10: Float32, + v11: Float32, + v12: Float32, + v13: Float32, + v14: Float32, + v15: Float32, + ): + stg_128_fp8_e4m3_cs( + ptr, + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7, + v8, + v9, + v10, + v11, + v12, + v13, + v14, + v15, + ) + + @cute.jit + def _epilogue_step( + self, + qi_group: Int32, + group_tidx: Int32, + warp_idx_in_wg: Int32, + tOtO_base: cute.Tensor, + tOcO_i: cute.Tensor, + o_tmem_copy_atom, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + mVGlobalScale: Optional[cute.Tensor], + count_raw: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + total_q: Int32, + use_stats_barrier: cutlass.Constexpr[bool], + softmax_stage: cutlass.Constexpr[int], + ): + slot = qi_group & Int32(1) + phase = (qi_group // Int32(2)) & Int32(1) + stage_base = slot * Int32(self.tmem_o_stage_stride) + corr_tile_size = 64 + sScale_slot = cute.make_tensor( + sScale.iterator + slot * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sSplitIdx_slot = cute.make_tensor( + sSplitIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + sQIdx_slot = cute.make_tensor( + sQIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + pipeline_o.consumer_wait_w_index_phase(slot, phase) + if const_expr(use_stats_barrier): + sm_stats_barrier.arrive_and_wait_w_index( + index=slot * Int32(self.warps_per_group) + warp_idx_in_wg) + + if group_tidx < Int32(self.q_tokens_per_group): + tok = group_tidx + qi = qi_group * Int32(self.q_tokens_per_group) + tok + if qi < count_raw: + qsplit = sQIdxMeta[qidx_meta_slot + tok] + q_idx = self._decode_q_idx_from_qsplit(qsplit) + sQIdx_slot[tok] = q_idx + sSplitIdx_slot[tok] = self._decode_split_idx_from_qsplit(qsplit) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + tOtO = cute.make_tensor( + tOtO_base.iterator + stage_base + Int32(self.tmem_o_offset), + tOtO_base.layout) + for col_pass_idx in cutlass.range(Int32(2), unroll=1): + col_pass = col_pass_idx * Int32(corr_tile_size) + tOtO_pass_ptr = cute.make_ptr( + self.pv_acc_dtype, + tOtO.iterator.toint() + col_pass, + cute.AddressSpace.tmem, + assumed_align=8, + ) + tOtO_pass = cute.make_tensor(tOtO_pass_ptr, tOtO.layout) + tOtO_pass_i = cute.logical_divide( + tOtO_pass, + cute.make_layout((self.m_block_size, corr_tile_size))) + tiled_tmem_load_pass = tcgen05.make_tmem_copy( + o_tmem_copy_atom, tOtO_pass_i[(None, None), 0]) + thr_tmem_load_pass = tiled_tmem_load_pass.get_slice(group_tidx) + tOtO_t2r_pass = thr_tmem_load_pass.partition_S( + tOtO_pass_i[(None, None), None]) + tOcO_t2r_pass = thr_tmem_load_pass.partition_D( + tOcO_i[(None, None), None]) + + tOtO_t2r_i = tOtO_t2r_pass[None, None, None, 0] + tOcO_t2r_i = tOcO_t2r_pass[None, None, None, 0] + tOrO_frg = cute.make_rmem_tensor_like( + tOcO_t2r_i, self.pv_acc_dtype) + cute.copy(tiled_tmem_load_pass, tOtO_t2r_i, tOrO_frg) + + tOrO_mn = make_16x256b_tensor_mn_view(tOrO_frg) + tOrO_mn = cute.make_tensor( + tOrO_mn.iterator, + cute.select(tOrO_mn.layout, mode=[0, 1])) + tOcO_mn = make_16x256b_tensor_mn_view(tOcO_t2r_i) + tOcO_mn = cute.make_tensor( + tOcO_mn.iterator, + cute.select(tOcO_mn.layout, mode=[0, 1])) + num_rows = cute.size(tOrO_mn, mode=[0]) + num_cols = cute.size(tOrO_mn, mode=[1]) + + for r in cutlass.range_constexpr(num_rows): + if const_expr(self.o_dtype is Float32): + for c4 in cutlass.range_constexpr(num_cols // 4): + c_base = Int32(c4) * Int32(4) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_v_global_scale + ): + row_scale *= mVGlobalScale[0] + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + fake_col = real_col_to_stg128_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec4( + ptr, + o0, + o1, + o2, + o3, + ) + elif const_expr(self.o_dtype in [cutlass.BFloat16, cutlass.Float16]): + assert num_cols % 8 == 0, ( + "half O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 8" + ) + for c8 in cutlass.range_constexpr(num_cols // 8): + c_base = Int32(c8) * Int32(8) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_v_global_scale + ): + row_scale *= mVGlobalScale[0] + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2( + (o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2( + (o6, o7), scale_pair) + fake_col = real_col_to_stg128_half_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec8_half( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + ) + else: + assert num_cols % 16 == 0, ( + "fp8 O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 16" + ) + for c16 in cutlass.range_constexpr(num_cols // 16): + c_base = Int32(c16) * Int32(16) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = ( + qi_group * Int32(self.q_tokens_per_group) + + tok + ) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) + * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val + ) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0) + ) + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_v_global_scale + ): + row_scale *= mVGlobalScale[0] + row_base_ptr = flat_row * Int64(self.head_dim) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + o8 = tOrO_mn[r, c_base + Int32(8)] + o9 = tOrO_mn[r, c_base + Int32(9)] + o10 = tOrO_mn[r, c_base + Int32(10)] + o11 = tOrO_mn[r, c_base + Int32(11)] + o12 = tOrO_mn[r, c_base + Int32(12)] + o13 = tOrO_mn[r, c_base + Int32(13)] + o14 = tOrO_mn[r, c_base + Int32(14)] + o15 = tOrO_mn[r, c_base + Int32(15)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2((o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2((o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2((o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2((o6, o7), scale_pair) + o8, o9 = cute.arch.mul_packed_f32x2((o8, o9), scale_pair) + o10, o11 = cute.arch.mul_packed_f32x2((o10, o11), scale_pair) + o12, o13 = cute.arch.mul_packed_f32x2((o12, o13), scale_pair) + o14, o15 = cute.arch.mul_packed_f32x2((o14, o15), scale_pair) + fake_col = real_col_to_stg128_fp8_fake_col(col) + ptr = ( + mO_partial.iterator + + row_base_ptr + + Int64(fake_col) + ) + self._store_o_partial_vec16_fp8( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + o8, + o9, + o10, + o11, + o12, + o13, + o14, + o15, + ) + cute.arch.fence_view_async_tmem_load() + + tok_local = Int32(group_tidx) // Int32(self.qheadperkv) + h_local = Int32(group_tidx) % Int32(self.qheadperkv) + qi_lse = qi_group * Int32(self.q_tokens_per_group) + tok_local + if qi_lse < count_raw: + row_sum_val = sScale_slot[group_tidx] + row_max_val = sScale_slot[group_tidx + Int32(self.m_block_size)] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + LN2 = Float32(math.log(2.0)) + lse_cur = ( + (row_max_val * softmax_scale_log2 + + cute.math.log2(row_sum_val, fastmath=True)) * LN2 + if not is_zero_or_nan else -Float32.inf) + q_idx_lse = sQIdx_slot[tok_local] + h_abs = head_kv_idx * Int32(self.qheadperkv) + h_local + split_lse = sSplitIdx_slot[tok_local] + q_abs_lse = q_batch_offset + q_idx_lse + mLSE_partial[split_lse, q_abs_lse, h_abs] = lse_cur + if const_expr(mLSE_temperature_partial is not None): + row_sum_temperature_val = sScale_temperature_slot[group_tidx] + is_temperature_zero_or_nan = ( + row_sum_temperature_val == Float32(0.0) + or row_sum_temperature_val != row_sum_temperature_val) + lse_temperature_cur = ( + (row_max_val * lse_temperature_scale_log2 + + cute.math.log2(row_sum_temperature_val, fastmath=True)) * LN2 + if not is_temperature_zero_or_nan else -Float32.inf) + mLSE_temperature_partial[split_lse, q_abs_lse, h_abs] = ( + lse_temperature_cur) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + pipeline_sm_stats.consumer_release_w_index(slot) + pipeline_o.consumer_release_w_index(slot) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/combine.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/combine.py new file mode 100644 index 0000000000000000000000000000000000000000..a3894130432f6483291fe23c064efa7369f6d509 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/combine.py @@ -0,0 +1,1498 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Sparse forward combine kernel and public launcher. + +This keeps the local fake-layout -> real-layout epilogue needed by the lean +sparse forward path. +""" + +# Modified Step 7: O_out write with SMEM fake->real column permutation. +# O_partial dim is in STG.128 fake layout; O_out dim is real layout. +import math +from typing import Type, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import torch +from cutlass.cute.nvgpu import cpasync +from cutlass import Float32, Int32, Int64, Boolean, const_expr + +from ....src.common import utils +from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map +from ....src.common.seqlen_info import SeqlenInfo +from cutlass.cute import FastDivmodDivisor + +from ....src.common.pack_gqa import PackGQAComb +from ....src.common.tma_utils import ( + stg128_fake_col_to_real_col, + stg128_fp8_fake_col_to_real_col, + stg128_half_fake_col_to_real_col, +) + + +class SparseAttentionForwardCombine: + def __init__( + self, + dtype: Type[cutlass.Numeric], + dtype_partial: Type[cutlass.Numeric], + head_dim: int, + tile_m: int = 8, + k_block_size: int = 64, + topk: int = 16, + num_threads: int = 256, + stages: int = 4, + use_pdl: bool = False, + min_blocks_per_mp: int = 0, + ): + """ + Forward combine kernel for split attention computation. + + :param dtype: output data type + :param dtype_partial: partial accumulation data type + :param head_dim: head dimension + :param tile_m: m block size + :param k_block_size: k block size + :param topk: exact number of split partials + :param num_threads: number of threads + :param varlen: whether using variable length sequences + :param stages: number of pipeline stages + """ + self.dtype = dtype + self.dtype_partial = dtype_partial + self.head_dim = head_dim + self.tile_m = tile_m + self.k_block_size = k_block_size + self.topk = topk + self.num_threads = num_threads + self.is_even_k = head_dim % k_block_size == 0 + self.stages = stages + self.use_pdl = use_pdl + self.min_blocks_per_mp = min_blocks_per_mp + self.use_stg128_half_layout = dtype_partial in (cutlass.BFloat16, cutlass.Float16) + self.use_stg128_fp8_layout = dtype_partial is cutlass.Float8E4M3FN + + @staticmethod + def can_implement( + dtype, + dtype_partial, + head_dim, + tile_m, + k_block_size, + topk, + num_threads, + ) -> bool: + """Check if the kernel can be implemented with the given parameters.""" + if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]: + return False + if dtype_partial not in [ + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + Float32, + ]: + return False + if head_dim % 8 != 0: + return False + if num_threads % 32 != 0: + return False + if tile_m % 8 != 0: + return False + if topk > 256: + return False + if (tile_m * topk) % num_threads != 0: + return False + return True + + def _setup_attributes(self): + # GMEM copy setup for O partial + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype_partial.width + assert self.k_block_size % async_copy_elems == 0 + + k_block_gmem = ( + 128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32) + ) + gmem_threads_per_row = k_block_gmem // async_copy_elems + assert self.num_threads % gmem_threads_per_row == 0 + + # Async copy atom for O partial load + atom_async_copy_partial = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype_partial, + num_bits_per_copy=universal_copy_bits, + ) + tOpartial_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + vOpartial_layout = cute.make_layout((1, async_copy_elems)) + self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv( + atom_async_copy_partial, tOpartial_layout, vOpartial_layout + ) + + # GMEM copy setup for final O (use universal copy for store). + # Keep this independent from O_partial: fp8 partial uses 16 elements + # per 128b transaction, while bf16/fp16 O stores must remain 8-wide. + output_copy_elems = universal_copy_bits // self.dtype.width + assert self.k_block_size % output_copy_elems == 0 + gmem_threads_per_row_o = k_block_gmem // output_copy_elems + assert self.num_threads % gmem_threads_per_row_o == 0 + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + tO_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_o, gmem_threads_per_row_o), + order=(1, 0), + ) + vO_layout = cute.make_layout((1, output_copy_elems)) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, + tO_layout, + vO_layout, + ) + # LSE copy setup with async copy (alignment = 1) + lse_copy_bits = Float32.width # 1 element per copy, width is in bits + m_block_smem = ( + 128 + if self.tile_m % 128 == 0 + else ( + 64 + if self.tile_m % 64 == 0 + else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8)) + ) + ) + gmem_threads_per_row_lse = m_block_smem + assert self.num_threads % gmem_threads_per_row_lse == 0 + + # Async copy atom for LSE load + atom_async_copy_lse = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + Float32, + num_bits_per_copy=lse_copy_bits, + ) + tLSE_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse), + order=(1, 0), + ) + vLSE_layout = cute.make_layout(1) + self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + atom_async_copy_lse, tLSE_layout, vLSE_layout + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory + # /////////////////////////////////////////////////////////////////////////////// + + # Shared memory to register copy for LSE + self.smem_threads_per_col_lse = self.num_threads // m_block_smem + assert 32 % self.smem_threads_per_col_lse == 0 # Must divide warp size + + s2r_layout_atom_lse = cute.make_ordered_layout( + (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse), + order=(0, 1), + ) + self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), + s2r_layout_atom_lse, + cute.make_layout(1), + ) + + # LSE shared memory layout with swizzling to avoid bank conflicts + # This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts + if const_expr(m_block_smem == 8): + smem_lse_swizzle = cute.make_swizzle(5, 0, 5) + elif const_expr(m_block_smem == 16): + smem_lse_swizzle = cute.make_swizzle(4, 0, 4) + else: + smem_lse_swizzle = cute.make_swizzle(3, 2, 3) + lse_atom_splits = min(self.topk, 8) + smem_layout_atom_lse = cute.make_composed_layout( + smem_lse_swizzle, + 0, + cute.make_ordered_layout((lse_atom_splits, m_block_smem), order=(1, 0)), + ) + self.smem_layout_lse = cute.tile_to_shape( + smem_layout_atom_lse, (self.topk, self.tile_m), (0, 1) + ) + + # O_partial staging layout. + if const_expr( + self.dtype_partial + in [cutlass.Float16, cutlass.BFloat16, cutlass.Float8E4M3FN] + ): + smem_layout_atom_o = _get_cpasync_smem_layout_atom( + self.dtype_partial, self.k_block_size + ) + self.smem_layout_o = cute.tile_to_shape( + smem_layout_atom_o, + (self.tile_m, self.k_block_size, self.stages), + (0, 1, 2), + ) + else: + self.smem_layout_o = cute.make_ordered_layout( + (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2) + ) + + @cute.jit + def __call__( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor] = None, + mLSE_temperature_partial: Optional[cute.Tensor] = None, + mLSE_temperature: Optional[cute.Tensor] = None, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + varlen_batch_idx: Optional[cute.Tensor] = None, + semaphore_to_reset: Optional[cute.Tensor] = None, + mSplitCounts: Optional[cute.Tensor] = None, + mOutputScale: Optional[cute.Tensor] = None, + qhead_per_kvhead: Int32 = Int32(1), + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, + ): + # Type checking + if const_expr(not (mO_partial.element_type == self.dtype_partial)): + raise TypeError("O partial tensor must match dtype_partial") + if const_expr(not (mO.element_type == self.dtype)): + raise TypeError("O tensor must match dtype") + if const_expr(mLSE_partial.element_type not in [Float32]): + raise TypeError("LSE partial tensor must be Float32") + if const_expr(mLSE is not None and mLSE.element_type not in [Float32]): + raise TypeError("LSE tensor must be Float32") + if const_expr( + mLSE_temperature_partial is not None + and mLSE_temperature_partial.element_type not in [Float32] + ): + raise TypeError("temperature LSE partial tensor must be Float32") + if const_expr(mLSE_temperature is not None and mLSE_temperature.element_type not in [Float32]): + raise TypeError("temperature LSE tensor must be Float32") + if const_expr((mLSE_temperature_partial is None) != (mLSE_temperature is None)): + raise ValueError( + "temperature LSE partial and output tensors must either both be provided or both be None" + ) + + # Shape validation - input tensors are in user format, need to be converted to kernel format + if const_expr(len(mO_partial.shape) not in [4, 5]): + raise ValueError( + "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)" + ) + if const_expr(len(mLSE_partial.shape) not in [3, 4]): + raise ValueError( + "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)" + ) + if const_expr(len(mO.shape) not in [3, 4]): + raise ValueError( + "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)" + ) + if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]): + raise ValueError( + "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)" + ) + if const_expr(mLSE_temperature_partial is not None and len(mLSE_temperature_partial.shape) not in [3, 4]): + raise ValueError( + "temperature LSE partial tensor must have 3 or 4 dimensions: " + "(num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)" + ) + if const_expr(mLSE_temperature is not None and len(mLSE_temperature.shape) not in [2, 3]): + raise ValueError( + "temperature LSE tensor must have 2 or 3 dimensions: " + "(batch, seqlen, nheads) or (total_q, nheads)" + ) + if const_expr(mSplitCounts is not None): + if const_expr(mSplitCounts.element_type not in [Int32]): + raise TypeError("split_counts tensor must be Int32") + if const_expr(cu_seqlens is not None): + if const_expr(len(mSplitCounts.shape) != 2): + raise ValueError("varlen split_counts tensor must have shape (total_q, nheads_kv)") + elif const_expr(len(mSplitCounts.shape) != 3): + raise ValueError("batched split_counts tensor must have shape (batch, seqlen, nheads_kv)") + if const_expr(mOutputScale is not None and mOutputScale.element_type not in [Float32]): + raise TypeError("output_scale tensor must be Float32") + + mO_partial, mO = [assume_tensor_aligned(t) for t in (mO_partial, mO)] + # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b) + # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h) + O_partial_layout_transpose = ( + [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2] + ) + # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h) + mO_partial = cute.make_tensor( + mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose) + ) + O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1] + mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose)) + # (num_splits, b, h, seqlen) -> (seqlen, num_splits, h, b) + # Input is pre-transposed: [topK, B, Hq, Sq] with Sq innermost for K2-friendly reads. + # or (num_splits, total_q, h) -> (total_q, num_splits, h) + LSE_partial_layout_transpose = [3, 0, 2, 1] if const_expr(cu_seqlens is None) else [1, 0, 2] + mLSE_partial = cute.make_tensor( + mLSE_partial.iterator, + cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose), + ) + # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h) + LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1] + mLSE = ( + cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + if mLSE is not None + else None + ) + mLSE_temperature_partial = ( + cute.make_tensor( + mLSE_temperature_partial.iterator, + cute.select(mLSE_temperature_partial.layout, mode=LSE_partial_layout_transpose), + ) + if mLSE_temperature_partial is not None + else None + ) + mLSE_temperature = ( + cute.make_tensor( + mLSE_temperature.iterator, + cute.select(mLSE_temperature.layout, mode=LSE_layout_transpose), + ) + if mLSE_temperature is not None + else None + ) + + # Determine if we have variable length sequences + varlen = const_expr(cu_seqlens is not None or seqused is not None) + + self._setup_attributes() + + # Output-dtype permutation buffer for Step 7 (tile_m × k_block_size). + # Accumulation stays fp32; the final dtype conversion happens before + # the fake→real SMEM scatter to reduce half-output SMEM pressure. + if const_expr(self.dtype in [cutlass.Float16, cutlass.BFloat16]): + smem_layout_perm = cute.make_layout( + (self.tile_m, self.k_block_size), + stride=(self.k_block_size + 16, 1), + ) + else: + smem_layout_perm = cute.make_ordered_layout( + (self.tile_m, self.k_block_size), order=(1, 0) + ) + + @cute.struct + class SharedStorage: + sLSE: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sLSETemperature: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.tile_m], 128] + sO: cute.struct.Align[ + cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 + ] + sO_perm: cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_perm)], 128 + ] + + smem_size = SharedStorage.size_in_bytes() + + # Grid: (ceil(seqlen/tile_m), ceil(dim/k_block), num_head * batch) + # Head separated from seqlen → enables future TMA (contiguous Sq tiles) + seqlen = mO_partial.shape[0] + num_head = mO_partial.shape[3] + batch_size = ( + mO_partial.shape[4] + if const_expr(cu_seqlens is None) + else Int32(cu_seqlens.shape[0] - 1) + ) + + seqlen_divmod = FastDivmodDivisor(seqlen) + head_divmod = FastDivmodDivisor(num_head) + + grid_dim = ( + cute.ceil_div(seqlen * num_head, self.tile_m), + cute.ceil_div(self.head_dim, self.k_block_size), + batch_size, + ) + + self.kernel( + mO_partial, + mLSE_partial, + mO, + mLSE, + mLSE_temperature_partial, + mLSE_temperature, + cu_seqlens, + seqused, + num_splits_dynamic_ptr, + varlen_batch_idx, + semaphore_to_reset, + mSplitCounts, + mOutputScale, + qhead_per_kvhead, + SharedStorage, + self.smem_layout_lse, + self.smem_layout_o, + smem_layout_perm, + self.gmem_tiled_copy_O_partial, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_LSE, + self.s2r_tiled_copy_LSE, + seqlen_divmod, + head_divmod, + self.use_pdl, + varlen, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=smem_size, + stream=stream, + min_blocks_per_mp=self.min_blocks_per_mp, + use_pdl=self.use_pdl, + ) + + @cute.jit + def decode_flat_row_idx( + self, + idx: Int32, + head_divmod: FastDivmodDivisor, + ): + """Decode flattened tile rows under the H_q-innermost contract.""" + q_idx_local, head_idx = divmod(idx, head_divmod) + return q_idx_local, head_idx + + @cute.kernel + def kernel( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mLSE_temperature_partial: Optional[cute.Tensor], + mLSE_temperature: Optional[cute.Tensor], + cu_seqlens: Optional[cute.Tensor], + seqused: Optional[cute.Tensor], + num_splits_dynamic_ptr: Optional[cute.Tensor], + varlen_batch_idx: Optional[cute.Tensor], + semaphore_to_reset: Optional[cute.Tensor], + mSplitCounts: Optional[cute.Tensor], + mOutputScale: Optional[cute.Tensor], + qhead_per_kvhead: Int32, + SharedStorage: cutlass.Constexpr, + smem_layout_lse: cute.Layout | cute.ComposedLayout, + smem_layout_o: cute.Layout | cute.ComposedLayout, + smem_layout_perm: cute.Layout, + gmem_tiled_copy_O_partial: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + s2r_tiled_copy_LSE: cute.TiledCopy, + seqlen_divmod: FastDivmodDivisor, + head_divmod: FastDivmodDivisor, + use_pdl: cutlass.Constexpr[bool], + varlen: cutlass.Constexpr[bool], + ): + # Thread and block indices + tidx, _, _ = cute.arch.thread_idx() + m_block, k_block, maybe_virtual_batch = cute.arch.block_idx() + + batch_idx = ( + varlen_batch_idx[maybe_virtual_batch] + if const_expr(varlen_batch_idx is not None) + else maybe_virtual_batch + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sLSE = storage.sLSE.get_tensor(smem_layout_lse) + sLSE_temperature = storage.sLSETemperature.get_tensor(smem_layout_lse) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,)) + sO = storage.sO.get_tensor(smem_layout_o) + sO_perm_buf = storage.sO_perm.get_tensor(smem_layout_perm) + + # Handle semaphore reset — wait for dependent grids first + if const_expr(use_pdl and semaphore_to_reset is not None): + if ( + tidx == 0 + and m_block == cute.arch.grid_dim()[0] - 1 + and k_block == cute.arch.grid_dim()[1] - 1 + and maybe_virtual_batch == cute.arch.grid_dim()[2] - 1 + ): + cute.arch.griddepcontrol_wait() + semaphore_to_reset[0] = 0 + + if const_expr(num_splits_dynamic_ptr is not None): + raise ValueError("K2 combine requires compile-time exact topK") + num_splits = Int32(self.topk) + # Handle variable length sequences using SeqlenInfo + seqlen_info = SeqlenInfo.create( + batch_idx=batch_idx, + seqlen_static=mO_partial.shape[0], + cu_seqlens=cu_seqlens, + seqused=seqused, + # Don't need to pass in tile size since we won't use offset_padded + ) + seqlen, offset = seqlen_info.seqlen, seqlen_info.offset + + num_head = mO_partial.shape[3] + max_idx = seqlen * num_head + output_scale = Float32(1.0) + if const_expr(mOutputScale is not None): + output_scale = mOutputScale[0] + + if const_expr(not varlen) or m_block * self.tile_m < max_idx: + # Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial) + if const_expr(use_pdl): + cute.arch.griddepcontrol_wait() + + # =============================== + # Step 1: Load LSE_partial from gmem to shared memory + # =============================== + # `cLSE` (identity tensor for row/split coord tracking) is reused + # later in steps 4-5, so it must be defined on both branches. + cLSE = cute.make_identity_tensor((self.topk, self.tile_m)) + # Reshape mLSE_partial to PackGQA packed layout and delegate the + # tile load to PackGQAComb.load_LSE. The packed form folds (H_q, Sq) + # into one compound dim with H_q innermost (stride 1), so thread + # rows that vary along h_pos produce one-sector coalesced reads. + # Non-varlen path only — varlen keeps the original inline loop. + if const_expr(not varlen): + mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3) + # mLSE_partial_cur: (H_q, topK, Sq) — after initial transpose + # [3,0,2,1] on [topK,B,Sq,H_q] and dropping B. + # Reorder to (H_q, Sq, topK) then group modes 0..1 for packed dim: + mLSE_partial_reord = cute.make_tensor( + mLSE_partial_cur.iterator, + cute.select(mLSE_partial_cur.layout, mode=[0, 2, 1]), + ) + mLSE_partial_packed = cute.group_modes(mLSE_partial_reord, 0, 2) + # shape ((H_q, Sq), topK) with H_q innermost. + packgqa = PackGQAComb( + m_block_size=self.tile_m, + head_dim_padded=0, # unused for LSE load + check_hdim_oob=False, # unused for LSE load + qhead_per_kvhead=1, # unused; num_heads_divmod is passed explicitly + ) + packgqa.load_LSE( + mLSE_partial_packed, + sLSE, + self.topk, + gmem_tiled_copy_LSE, + tidx, + m_block, + num_splits, + seqlen, + head_divmod, + mSplitCounts, + batch_idx, + qhead_per_kvhead, + ) + if const_expr(mLSE_temperature_partial is not None): + mLSE_temperature_partial_cur = seqlen_info.offset_batch( + mLSE_temperature_partial, batch_idx, dim=3) + mLSE_temperature_partial_reord = cute.make_tensor( + mLSE_temperature_partial_cur.iterator, + cute.select(mLSE_temperature_partial_cur.layout, mode=[0, 2, 1]), + ) + mLSE_temperature_partial_packed = cute.group_modes( + mLSE_temperature_partial_reord, 0, 2) + packgqa.load_LSE( + mLSE_temperature_partial_packed, + sLSE_temperature, + self.topk, + gmem_tiled_copy_LSE, + tidx, + m_block, + num_splits, + seqlen, + head_divmod, + mSplitCounts, + batch_idx, + qhead_per_kvhead, + ) + else: + # Varlen path keeps the same H_q-innermost flat-row contract: + # after transpose [1, 0, 2], mLSE_partial_cur is + # (q_local, split, head). + # mSplitCounts is the authoritative valid-split count per + # packed (q_abs, kv_head); masked splits stay at -inf and + # therefore drop out of the final kernel LSE_out reduction. + mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3) + mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,)) + gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) + tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) + tLSEsLSE_temperature = gmem_thr_copy_LSE.partition_D(sLSE_temperature) + tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) + if const_expr(mLSE_temperature_partial is not None): + mLSE_temperature_partial_cur = seqlen_info.offset_batch( + mLSE_temperature_partial, batch_idx, dim=3) + mLSE_temperature_partial_copy = cute.tiled_divide( + mLSE_temperature_partial_cur, (1,)) + + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] + idx = m_block * self.tile_m + mi + if idx < max_idx: + m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod) + row_count = ( + mSplitCounts[offset + m_idx, head_idx // qhead_per_kvhead] + if const_expr(mSplitCounts is not None) + else num_splits + ) + mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx] + if const_expr(mLSE_temperature_partial is not None): + mLSE_temperature_partial_cur_copy = ( + mLSE_temperature_partial_copy[None, m_idx, None, head_idx]) + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] + if si < num_splits and si < row_count: + cute.copy( + gmem_thr_copy_LSE, + mLSE_partial_cur_copy[None, si], + tLSEsLSE[None, s, m], + ) + if const_expr(mLSE_temperature_partial is not None): + cute.copy( + gmem_thr_copy_LSE, + mLSE_temperature_partial_cur_copy[None, si], + tLSEsLSE_temperature[None, s, m], + ) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + if const_expr(mLSE_temperature_partial is not None): + tLSEsLSE_temperature[None, s, m].fill(-Float32.inf) + else: + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + tLSEsLSE[None, s, m].fill(-Float32.inf) + if const_expr(mLSE_temperature_partial is not None): + tLSEsLSE_temperature[None, s, m].fill(-Float32.inf) + cute.arch.cp_async_commit_group() + + # =============================== + # Step 2: Load O_partial for pipeline stages + # =============================== + + gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) + cO = cute.make_identity_tensor((self.tile_m, self.k_block_size)) + tOcO = gmem_thr_copy_O_partial.partition_D(cO) + tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) + mO_partial_cur = seqlen_info.offset_batch(mO_partial, batch_idx, dim=4) + + # Precompute per-row values for flattened (q_local, head) tiles. + num_rows = const_expr(cute.size(tOcO, mode=[1])) + tOmidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOhidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOSplitCount = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOrOptr = cute.make_rmem_tensor(num_rows, cutlass.Int64) + for m in cutlass.range(num_rows, unroll_full=True): + mi = tOcO[0, m, 0][0] # m coordinate in tile + idx = m_block * self.tile_m + mi + if idx >= max_idx: + tOhidx[m] = -1 + tOmidx[m] = 0 + tOSplitCount[m] = 0 + tOrOptr[m] = cutlass.Int64(0) + else: + tOmidx[m], tOhidx[m] = self.decode_flat_row_idx(idx, head_divmod) + if const_expr(mSplitCounts is None): + tOSplitCount[m] = num_splits + elif const_expr(cu_seqlens is None): + tOSplitCount[m] = mSplitCounts[ + batch_idx, tOmidx[m], tOhidx[m] // qhead_per_kvhead + ] + else: + tOSplitCount[m] = mSplitCounts[ + offset + tOmidx[m], tOhidx[m] // qhead_per_kvhead + ] + tOrOptr[m] = utils.elem_pointer( + mO_partial_cur, + (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m]), + ).toint() + + tOpO = None + if const_expr(not self.is_even_k): + tOpO = cute.make_rmem_tensor(cute.size(tOcO, mode=[2]), Boolean) + for k in cutlass.range(cute.size(tOpO), unroll_full=True): + tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size + # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO) + + load_O_partial = partial( + self.load_O_partial, + gmem_tiled_copy_O_partial, + tOrOptr, + tOsO_partial, + tOhidx, + tOSplitCount, + tOpO, + tOcO, + mO_partial_cur.layout, + ) + + # Load first few stages of O_partial + for stage in cutlass.range(self.stages - 1, unroll_full=True): + if stage < num_splits: + load_O_partial(stage, stage) + cute.arch.cp_async_commit_group() + + # =============================== + # Step 3: Load and transpose LSE from smem to registers + # =============================== + + # Wait for LSE and initial O partial stages to complete + cute.arch.cp_async_wait_group(self.stages - 1) + cute.arch.sync_threads() + # if cute.arch.thread_idx()[0] == 0: + # # cute.print_tensor(sLSE) + # for i in range(64): + # cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0]) + # cute.arch.sync_threads() + + s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) + ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) + ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE) + cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) + if const_expr(mLSE_temperature_partial is not None): + ts2rsLSE_temperature = s2r_thr_copy_LSE.partition_S(sLSE_temperature) + ts2rrLSE_temperature = cute.make_rmem_tensor_like(ts2rsLSE_temperature) + cute.copy( + s2r_tiled_copy_LSE, + ts2rsLSE_temperature, + ts2rrLSE_temperature, + ) + + # =============================== + # Step 4: Compute final LSE along split dimension + # =============================== + + final_lse = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32) + ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) + # We compute the max valid split for each row to short-circuit the computation later + max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32) + assert cute.size(ts2rrLSE, mode=[0]) == 1 + # Compute max, scales, and final LSE for each row. Invalid splits + # have already been filled with -inf, so Step 5 can write the + # kernel-native LSE_out directly. + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + # Find max LSE value across splits + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_max = cute.arch.warp_reduction_max( + ts2rrLSE[None, None, m] + .load() + .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + threads_in_group=threads_per_col, + ) + # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max) + # Find max valid split index + max_valid_idx = -1 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + if ts2rrLSE[0, s, m] != -Float32.inf: + max_valid_idx = ts2rcLSE[0, s, 0][0] # Get split coordinate + # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx) + max_valid_split[m] = cute.arch.warp_reduction_max( + max_valid_idx, threads_in_group=threads_per_col + ) + # Compute exp scales and sum + lse_max_cur = ( + 0.0 if lse_max == -Float32.inf else lse_max + ) # In case all local LSEs are -inf + LOG2_E = math.log2(math.e) + lse_sum_cur = 0.0 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + scale = cute.math.exp2( + ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E), fastmath=True + ) + lse_sum_cur += scale + ts2rrLSE[0, s, m] = scale # Store scale for later use + lse_sum_cur = cute.arch.warp_reduction_sum( + lse_sum_cur, threads_in_group=threads_per_col + ) + # Normalize scales + inv_sum = 0.0 + if max_valid_split[m] < 0 or lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur: + final_lse[m] = -Float32.inf + else: + final_lse[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max + inv_sum = 1.0 / lse_sum_cur + ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) + # Store the scales exp(lse - lse_logsum) back to smem + cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) + + if const_expr(mLSE_temperature_partial is not None): + final_lse_temperature = cute.make_rmem_tensor( + cute.size(ts2rrLSE_temperature, mode=[2]), Float32) + for m in cutlass.range(cute.size(ts2rrLSE_temperature, mode=[2]), unroll_full=True): + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_temperature_max = cute.arch.warp_reduction_max( + ts2rrLSE_temperature[None, None, m] + .load() + .reduce( + cute.ReductionOp.MAX, + init_val=-Float32.inf, + reduction_profile=0, + ), + threads_in_group=threads_per_col, + ) + lse_temperature_max_cur = ( + 0.0 if lse_temperature_max == -Float32.inf else lse_temperature_max + ) + LOG2_E = math.log2(math.e) + lse_temperature_sum_cur = 0.0 + for s in cutlass.range( + cute.size(ts2rrLSE_temperature, mode=[1]), unroll_full=True): + scale = cute.math.exp2( + ts2rrLSE_temperature[0, s, m] * LOG2_E + - (lse_temperature_max_cur * LOG2_E), + fastmath=True, + ) + lse_temperature_sum_cur += scale + lse_temperature_sum_cur = cute.arch.warp_reduction_sum( + lse_temperature_sum_cur, threads_in_group=threads_per_col + ) + if ( + max_valid_split[m] < 0 + or lse_temperature_sum_cur == 0.0 + or lse_temperature_sum_cur != lse_temperature_sum_cur + ): + final_lse_temperature[m] = -Float32.inf + else: + final_lse_temperature[m] = ( + cute.math.log(lse_temperature_sum_cur, fastmath=True) + + lse_temperature_max + ) + + # Store max valid split to smem + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + if mi < self.tile_m: + sMaxValidSplit[mi] = max_valid_split[m] + + # =============================== + # Step 5: Store final LSE to gmem + # This writeback is the authoritative LSE_out returned by the + # public Sparse Attention / Sparse Page Attention interface. + # =============================== + + if const_expr(mLSE is not None): + if const_expr(cu_seqlens is None): + mLSE_cur = mLSE[None, None, batch_idx] + else: + mLSE_cur = cute.domain_offset((offset, 0), mLSE) + if const_expr(mLSE_temperature is not None): + if const_expr(cu_seqlens is None): + mLSE_temperature_cur = mLSE_temperature[None, None, batch_idx] + else: + mLSE_temperature_cur = cute.domain_offset( + (offset, 0), mLSE_temperature) + if k_block == 0: # Only first k_block writes LSE when mLSE is provided + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + idx = m_block * self.tile_m + mi + if idx < max_idx: + m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod) + mLSE_cur[m_idx, head_idx] = final_lse[m] + if const_expr(mLSE_temperature is not None): + mLSE_temperature_cur[m_idx, head_idx] = ( + final_lse_temperature[m]) + + # =============================== + # Step 6: Read O_partial and accumulate final O + # =============================== + + cute.arch.sync_threads() + + # Get max valid split for this thread + thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] + for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True): + thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]]) + + tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32) + tOrO.fill(0.0) + + stage_load = self.stages - 1 + stage_compute = 0 + + # Main accumulation loop + for s in cutlass.range(thr_max_valid_split + 1, unroll=4): + # Get scales for this split + scale = cute.make_rmem_tensor(num_rows, Float32) + for m in cutlass.range(num_rows, unroll_full=True): + scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem + + # Load next stage if needed + split_to_load = s + self.stages - 1 + if split_to_load <= thr_max_valid_split: + load_O_partial(split_to_load, stage_load) + cute.arch.cp_async_commit_group() + stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1 + + # Wait for the current stage to be ready + cute.arch.cp_async_wait_group(self.stages - 1) + # We don't need __syncthreads() because each thread is just reading its own data from smem + # Copy from smem to registers + cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial) + stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1 + + # Accumulate scaled partial results + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0 and scale[m] > 0.0: + tOrO[None, m, None].store( + tOrO[None, m, None].load() + + scale[m] * tOrO_partial[None, m, None].load().to(Float32) + ) + + # Flush any outstanding async-copy groups before the local Step-7 + # permutation buffer is read on the tail of the kernel. + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + + # =============================== + # Step 7: Write final O to gmem (fake→real via SMEM) + # =============================== + + mO_cur = seqlen_info.offset_batch(mO, batch_idx, dim=3) + if const_expr(cu_seqlens is None): + mO_cur = mO[None, None, None, batch_idx] + else: + mO_cur = cute.domain_offset((offset, 0, 0), mO) + mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur) + num_vals = const_expr(cute.size(tOcO, mode=[0])) + if const_expr(not use_pdl): + # Direct / standalone calls don't participate in the K1->K2 + # dependency chain. Use a simple per-element real-column store + # path here to keep mixed-shape launches stable. + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0: + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + if const_expr(self.is_even_k) or tOpO[k]: + for v in cutlass.range(num_vals, unroll_full=True): + fake_col = tOcO[v, 0, k][1] + if const_expr(self.use_stg128_fp8_layout): + real_col = stg128_fp8_fake_col_to_real_col(fake_col) + elif const_expr(self.use_stg128_half_layout): + real_col = stg128_half_fake_col_to_real_col(fake_col) + else: + real_col = stg128_fake_col_to_real_col(fake_col) + o_val = tOrO[v, m, k] + if const_expr(mOutputScale is not None): + o_val = o_val * output_scale + mO_cur[tOmidx[m], real_col, tOhidx[m]] = o_val.to(self.dtype) + else: + # 7a: fp32 accumulator -> output dtype SMEM with fake→real + # permutation. The dedicated permutation buffer stays separate + # from the O_partial pipeline staging buffer. + sO_perm = sO_perm_buf + + if const_expr(self.dtype in [cutlass.BFloat16, cutlass.Float16]): + # O_partial uses a dtype-specific STG.128 fake layout, but + # sO_perm is in the final O dtype. For all supported fake + # layouts, adjacent fake pairs map to adjacent real columns, + # so write the final BF16/F16 O pair as one 32-bit SMEM store. + assert num_vals % 2 == 0 + r2s_o_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=32, + ) + rO_pair_word = cute.make_rmem_tensor((1,), cutlass.Int32) + sO_perm_i32_base = cute.make_ptr( + dtype=cutlass.Int32, + value=sO_perm.iterator.toint(), + mem_space=sO_perm.iterator.memspace, + assumed_align=4, + ) + sO_perm_i32_row_stride = Int32((self.k_block_size + 16) // 2) + for m in cutlass.range(num_rows, unroll_full=True): + row_local = tOcO[0, m, 0][0] + if tOhidx[m] >= 0: + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + for v_pair in cutlass.range(num_vals // 2, unroll_full=True): + v = v_pair * 2 + fake_col = tOcO[v, 0, k][1] + if const_expr(self.use_stg128_fp8_layout): + real_col = stg128_fp8_fake_col_to_real_col(fake_col) + elif const_expr(self.use_stg128_half_layout): + real_col = stg128_half_fake_col_to_real_col(fake_col) + else: + real_col = stg128_fake_col_to_real_col(fake_col) + o0 = tOrO[v, m, k] + o1 = tOrO[v + 1, m, k] + if const_expr(mOutputScale is not None): + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), + (output_scale, output_scale), + ) + rO_pair_word[0] = utils.cvt_f16x2_f32(o0, o1, self.dtype) + smem_pair_ptr = cute.make_ptr( + dtype=cutlass.Int32, + value=( + sO_perm_i32_base.toint() + + Int64( + row_local * sO_perm_i32_row_stride + + real_col // Int32(2) + ) + * Int64(4) + ), + mem_space=sO_perm.iterator.memspace, + assumed_align=4, + ) + sO_pair = cute.make_tensor( + smem_pair_ptr, + cute.make_layout((1,), stride=(1,)), + ) + cute.copy(r2s_o_pair_atom, rO_pair_word, sO_pair) + else: + # 7a: iterate over ALL val elements in mode[0]. + # tOcO[v, m, k][1] gives different fake_col for each v. + r2s_o_scalar_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=self.dtype.width, + ) + rO_scalar = cute.make_rmem_tensor((1,), self.dtype) + for m in cutlass.range(num_rows, unroll_full=True): + row_local = tOcO[0, m, 0][0] + if tOhidx[m] >= 0: + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + for v in cutlass.range(num_vals, unroll_full=True): + fake_col = tOcO[v, 0, k][1] + if const_expr(self.use_stg128_fp8_layout): + real_col = stg128_fp8_fake_col_to_real_col(fake_col) + elif const_expr(self.use_stg128_half_layout): + real_col = stg128_half_fake_col_to_real_col(fake_col) + else: + real_col = stg128_fake_col_to_real_col(fake_col) + o_val = tOrO[v, m, k] + if const_expr(mOutputScale is not None): + o_val = o_val * output_scale + rO_scalar[0] = o_val.to(self.dtype) + smem_ptr = utils.elem_pointer(sO_perm, (row_local, real_col)) + smem_scalar_ptr = cute.make_ptr( + dtype=self.dtype, + value=smem_ptr.toint(), + mem_space=sO_perm.iterator.memspace, + assumed_align=self.dtype.width // 8, + ) + sO_scalar = cute.make_tensor( + smem_scalar_ptr, + cute.make_layout((1,), stride=(1,)), + ) + cute.copy(r2s_o_scalar_atom, rO_scalar, sO_scalar) + + cute.arch.sync_threads() + + # 7b: SMEM (real order, output dtype) → GMEM + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOcO_store = gmem_thr_copy_O.partition_D(cO) + tOsO_store = gmem_thr_copy_O.partition_D(sO_perm) + rO = cute.make_rmem_tensor(tOcO_store.shape, self.dtype) + elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1])) + num_store_rows = const_expr(cute.size(tOcO_store, mode=[1])) + num_store_vals = const_expr(cute.size(tOcO_store, mode=[0])) + tOpO_store = None + if const_expr(not self.is_even_k): + tOpO_store = cute.make_rmem_tensor(cute.size(tOcO_store, mode=[2]), Boolean) + for k in cutlass.range(cute.size(tOpO_store), unroll_full=True): + tOpO_store[k] = ( + tOcO_store[0, 0, k][1] + < mO_partial.shape[1] - k_block * self.k_block_size + ) + + # Read output dtype from SMEM (now in real column order). + for m in cutlass.range(num_store_rows, unroll_full=True): + for k in cutlass.range(cute.size(tOcO_store, mode=[2]), unroll_full=True): + if const_expr(self.is_even_k) or tOpO_store[k]: + cute.autovec_copy(tOsO_store[None, m, k], rO[None, m, k]) + + # Write bf16 to GMEM using gmem_tiled_copy_O (same as original FA Step 7) + for m in cutlass.range(num_store_rows, unroll_full=True): + row_local = tOcO_store[0, m, 0][0] + idx = m_block * self.tile_m + row_local + if idx < max_idx: + m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod) + mO_cur_copy = cute.tiled_divide( + mO_cur[m_idx, None, head_idx], (elems_per_store,) + ) + for k in cutlass.range(cute.size(tOcO_store, mode=[2]), unroll_full=True): + k_idx = tOcO_store[0, 0, k][1] // elems_per_store + if const_expr(self.is_even_k) or tOpO_store[k]: + cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx]) + + @cute.jit + def load_O_partial( + self, + gmem_tiled_copy_O_partial: cute.TiledCopy, + tOrOptr: cute.Tensor, + tOsO_partial: cute.Tensor, + tOhidx: cute.Tensor, + tOSplitCount: cute.Tensor, + tOpO: Optional[cute.Tensor], + tOcO: cute.Tensor, + mO_cur_partial_layout: cute.Layout, + split: Int32, + stage: Int32, + ) -> None: + elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1])) + tOsO_partial_cur = tOsO_partial[None, None, None, stage] + for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True): + if tOhidx[m] >= 0: + o_gmem_ptr = cute.make_ptr( + tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16 + ) + mO_partial_cur = cute.make_tensor( + o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)) + ) + mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_load + if split < tOSplitCount[m] and (const_expr(tOpO is None) or tOpO[k]): + cute.copy( + gmem_tiled_copy_O_partial, + mO_partial_cur_copy[None, k_idx, split], + tOsO_partial_cur[None, m, k], + ) + else: + tOsO_partial_cur[None, m, k].fill(0) + + +def _get_cutlass_dtype(torch_dtype: torch.dtype): + if torch_dtype not in torch2cute_dtype_map: + raise TypeError(f"Unsupported dtype: {torch_dtype}") + return torch2cute_dtype_map[torch_dtype] + + +_combine_compile_cache = {} + + +def _get_cpasync_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout: + dtype_byte = const_expr(dtype.width // 8) + bytes_per_row = const_expr(k_dim * dtype_byte) + smem_k_block_size = ( + const_expr( + 128 + if bytes_per_row % 128 == 0 + else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16)) + ) + // dtype_byte + ) + swizzle_bits = ( + 4 + if smem_k_block_size == 128 + else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1)) + ) + swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4) + return cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base), + 0, + cute.make_ordered_layout( + (8 if const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), + order=(1, 0), + ), + ) + + +def combine( + o_partial_fake, + lse_partial, + o_out, + lse_out, + *, + lse_temperature_partial=None, + lse_temperature_out=None, + cu_seqlens=None, + seqused=None, + split_counts=None, + output_scale=None, + use_pdl=False, +): + """K2: merge sparse forward split partials into the final output. + + STG.128 fake-layout handling remains an internal implementation detail. + When lse_out is provided, the kernel writes the final authoritative + log-sum-exp for each query row/head directly into that tensor. + + Args: + o_partial_fake: + Batched: [num_splits, batch, Sq, head_q, dim] + Varlen: [num_splits, total_q, head_q, dim] + lse_partial: + Batched: [num_splits, batch, Sq, head_q] + Varlen: [num_splits, total_q, head_q] + o_out: + Batched: [batch, Sq, head_q, dim] + Varlen: [total_q, head_q, dim] + lse_out: + Batched: [batch, Sq, head_q] + Varlen: [total_q, head_q] + lse_temperature_partial: + Optional temperature-scaled LSE partial with the same shape as + lse_partial. + lse_temperature_out: + Optional temperature-scaled final LSE with the same shape as + lse_out. + cu_seqlens: Optional [batch + 1] int32 for varlen-Q combine. + seqused: Optional [batch] int32 effective lengths for combine. + split_counts: Optional int32 rowwise valid split counts prepared from + q2k metadata. Batched: [batch, seqlen, head_kv]. Varlen: + [total_q, head_kv]. + output_scale: Optional fp32 tensor with at least one element. When + provided, the final O accumulator is multiplied once before store. + use_pdl: When True, wait on PDL dependencies from the producer K1 + kernel. When False, launch without PDL waits. + """ + D = o_partial_fake.shape[-1] + num_splits = o_partial_fake.shape[0] + return_temperature_lse = ( + lse_temperature_partial is not None or lse_temperature_out is not None + ) + if (lse_temperature_partial is None) != (lse_temperature_out is None): + raise ValueError( + "lse_temperature_partial and lse_temperature_out must either both be provided or both be None" + ) + if lse_temperature_partial is not None and lse_temperature_partial.shape != lse_partial.shape: + raise ValueError( + "lse_temperature_partial must have the same shape as lse_partial, " + f"got {lse_temperature_partial.shape} vs {lse_partial.shape}" + ) + if lse_temperature_out is not None: + if lse_out is None: + raise ValueError("lse_temperature_out requires lse_out") + if lse_temperature_out.shape != lse_out.shape: + raise ValueError( + "lse_temperature_out must have the same shape as lse_out, " + f"got {lse_temperature_out.shape} vs {lse_out.shape}" + ) + if lse_temperature_out.dtype != torch.float32 or lse_temperature_partial.dtype != torch.float32: + raise TypeError("temperature LSE tensors must be torch.float32") + + partial_dtype = _get_cutlass_dtype(o_partial_fake.dtype) + out_dtype = _get_cutlass_dtype(o_out.dtype) + if output_scale is not None: + if output_scale.dtype != torch.float32: + raise TypeError(f"output_scale must be torch.float32, got {output_scale.dtype}") + if output_scale.numel() < 1: + raise ValueError("output_scale must contain at least one element") + if output_scale.device != o_out.device: + raise ValueError("output_scale must be on the same device as o_out") + output_scale = output_scale.contiguous() + if split_counts is not None: + if split_counts.dtype != torch.int32: + raise TypeError(f"split_counts must be torch.int32, got {split_counts.dtype}") + if o_out.ndim == 4: + if split_counts.ndim != 3: + raise ValueError( + f"batched split_counts must have shape [batch, seqlen, head_kv], got {split_counts.shape}" + ) + if split_counts.shape[:2] != o_out.shape[:2]: + raise ValueError( + f"split_counts shape {split_counts.shape} must match batch/seqlen of o_out {o_out.shape}" + ) + else: + if cu_seqlens is None: + raise ValueError("split_counts with varlen output requires cu_seqlens") + if split_counts.ndim != 2: + raise ValueError( + f"varlen split_counts must have shape [total_q, head_kv], got {split_counts.shape}" + ) + if split_counts.shape[0] != o_out.shape[0]: + raise ValueError( + f"split_counts total_q ({split_counts.shape[0]}) must match o_out total_q " + f"({o_out.shape[0]})" + ) + if o_out.shape[-2] % split_counts.shape[-1] != 0: + raise ValueError( + f"o_out heads ({o_out.shape[-2]}) must be divisible by split_counts heads ({split_counts.shape[-1]})" + ) + qheadperkv = o_out.shape[-2] // split_counts.shape[-1] + else: + qheadperkv = 1 + if cu_seqlens is not None: + if cu_seqlens.dtype != torch.int32: + raise TypeError(f"cu_seqlens must be torch.int32, got {cu_seqlens.dtype}") + if cu_seqlens.ndim != 1: + raise ValueError(f"cu_seqlens must be rank-1, got {cu_seqlens.shape}") + if not cu_seqlens.is_contiguous(): + raise ValueError("cu_seqlens must be contiguous") + if seqused is not None: + if seqused.dtype != torch.int32: + raise TypeError(f"seqused must be torch.int32, got {seqused.dtype}") + if seqused.ndim != 1: + raise ValueError(f"seqused must be rank-1, got {seqused.shape}") + if not seqused.is_contiguous(): + raise ValueError("seqused must be contiguous") + + k_block_size = 128 if D > 64 else 64 + tile_m = 64 + has_cu_seqlens = cu_seqlens is not None + has_seqused = seqused is not None + has_lse = lse_out is not None + has_split_counts = split_counts is not None + has_output_scale = output_scale is not None + min_blocks_per_mp = 3 if has_output_scale and use_pdl else 0 + + key = ( + "combine", + D, + k_block_size, + tile_m, + num_splits, + partial_dtype, + out_dtype, + has_cu_seqlens, + has_seqused, + has_lse, + bool(return_temperature_lse), + has_split_counts, + has_output_scale, + use_pdl, + min_blocks_per_mp, + ) + if key not in _combine_compile_cache: + from ....src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _combine_compile_cache[key] = loaded + else: + from ....quack.compile_utils import make_fake_tensor + + kernel = SparseAttentionForwardCombine( + dtype=out_dtype, + dtype_partial=partial_dtype, + head_dim=D, + tile_m=tile_m, + k_block_size=k_block_size, + topk=num_splits, + use_pdl=use_pdl, + min_blocks_per_mp=min_blocks_per_mp, + # stages=2 halves per-block SMEM (168 KB -> 103 KB) -> 2 blocks/SM, + # theoretical occupancy 12.5% -> 25%. NCU DRAM throughput 76.35% + # -> 88.64%. Runtime latency within noise (kernel already at HBM + # bandwidth ceiling in practice) but the cleaner SOL profile + # matters for downstream NCU comparison. + stages=2, + ) + div = 128 // partial_dtype.width + if has_cu_seqlens: + total_q, nheads = (cute.sym_int64() for _ in range(2)) + mO_partial = make_fake_tensor( + partial_dtype, (num_splits, total_q, nheads, D), divisibility=div + ) + mLSE_partial = make_fake_tensor( + Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=2 + ) + mO = make_fake_tensor( + out_dtype, (total_q, nheads, D), divisibility=128 // out_dtype.width + ) + mLSE = ( + make_fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=1) + if has_lse + else None + ) + mLSE_temperature_partial = ( + make_fake_tensor( + Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=2 + ) + if return_temperature_lse + else None + ) + mLSE_temperature = ( + make_fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=1) + if return_temperature_lse + else None + ) + else: + batch, sq, nheads = (cute.sym_int64() for _ in range(3)) + mO_partial = make_fake_tensor( + partial_dtype, (num_splits, batch, sq, nheads, D), divisibility=div + ) + mLSE_partial = make_fake_tensor( + Float32, (num_splits, batch, sq, nheads), divisibility=1, leading_dim=3 + ) + mO = make_fake_tensor( + out_dtype, (batch, sq, nheads, D), divisibility=128 // out_dtype.width + ) + mLSE = ( + make_fake_tensor(Float32, (batch, sq, nheads), divisibility=1, leading_dim=2) + if has_lse + else None + ) + mLSE_temperature_partial = ( + make_fake_tensor( + Float32, (num_splits, batch, sq, nheads), divisibility=1, leading_dim=3 + ) + if return_temperature_lse + else None + ) + mLSE_temperature = ( + make_fake_tensor(Float32, (batch, sq, nheads), divisibility=1, leading_dim=2) + if return_temperature_lse + else None + ) + if not has_split_counts: + mSplitCounts = None + elif has_cu_seqlens: + total_q_ctr, nheads_kv = (cute.sym_int64() for _ in range(2)) + mSplitCounts = make_fake_tensor( + Int32, (total_q_ctr, nheads_kv), divisibility=1, leading_dim=1 + ) + else: + nheads_kv = cute.sym_int64() + mSplitCounts = make_fake_tensor( + Int32, (batch, sq, nheads_kv), divisibility=1, leading_dim=2 + ) + mOutputScale = ( + make_fake_tensor(Float32, (cute.sym_int64(),), divisibility=1, leading_dim=0) + if has_output_scale + else None + ) + stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + _combine_compile_cache[key] = cute.compile( + kernel, + mO_partial, + mLSE_partial, + mO, + mLSE, + mLSE_temperature_partial, + mLSE_temperature, + None + if cu_seqlens is None + else make_fake_tensor(Int32, (cute.sym_int64(),), divisibility=1, leading_dim=0), + None + if seqused is None + else make_fake_tensor(Int32, (cute.sym_int64(),), divisibility=1, leading_dim=0), + None, + None, + None, + mSplitCounts, + mOutputScale, + Int32(qheadperkv), + stream, + options="--enable-tvm-ffi", + ) + save_aot(key, _combine_compile_cache[key]) + + with torch.cuda.nvtx.range("K2_Combine"): + _combine_compile_cache[key]( + o_partial_fake, + lse_partial, + o_out, + lse_out, + lse_temperature_partial, + lse_temperature_out, + cu_seqlens, + seqused, + None, + None, + None, + split_counts, + output_scale, + qheadperkv, + ) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d64a0616bd5bb9c987e43b87bcbf9e89001fbb36 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/__init__.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""CUTE DSL launchers for paged fp8 decode forward.""" + +from __future__ import annotations + +import torch + +from .atten_fwd import run_decode_attention +from .combine import run_decode_combine + + +def decode_forward_paged_fp8( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + request_indices: torch.Tensor, + qo_tile_indices: torch.Tensor, + kv_tile_indices: torch.Tensor, + block_valid_mask: torch.Tensor, + split_counts: torch.Tensor, + o_indptr: torch.Tensor, + merge_indptr: torch.Tensor, + O_partial: torch.Tensor | None, + LSE_partial: torch.Tensor | None, + *, + softmax_scale: float, + seqlen_q: int, + page_size: int, + kv_chunk_size_pages: int, + max_split_count: int, + split_kv: bool, + causal: bool, + return_lse: bool = True, + O_partial_dummy: torch.Tensor | None = None, + LSE_partial_dummy: torch.Tensor | None = None, +) -> None: + """Launch dense paged fp8 decode forward and optional compressed combine. + + ``O_partial_dummy`` / ``LSE_partial_dummy`` are caller-provided pre-allocated + placeholder buffers for the non-split path. When supplied, ``run_decode_attention`` + skips the per-call ``torch.empty`` it would otherwise need to satisfy the + kernel's positional arg signature, saving ~5us on small-kv calls. + """ + + run_decode_attention( + q, + k, + v, + page_table, + seqused_k, + request_indices, + qo_tile_indices, + kv_tile_indices, + block_valid_mask, + split_counts, + o_indptr, + out, + lse, + O_partial, + LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=int(seqlen_q), + page_size=int(page_size), + kv_chunk_size_pages=int(kv_chunk_size_pages), + split_kv=bool(split_kv), + causal=bool(causal), + return_lse=bool(return_lse), + O_partial_dummy=O_partial_dummy, + LSE_partial_dummy=LSE_partial_dummy, + ) + if split_kv: + if O_partial is None or LSE_partial is None: + raise ValueError("split decode requires O_partial and LSE_partial") + qhead_per_kv = q.shape[1] // k.shape[1] + q_tokens_per_group = 128 // int(qhead_per_kv) + run_decode_combine( + O_partial, + LSE_partial, + split_counts, + o_indptr, + out, + lse, + seqlen_q=int(seqlen_q), + q_tokens_per_group=q_tokens_per_group, + max_split_count=int(max_split_count), + ) + + +__all__ = ["decode_forward_paged_fp8", "run_decode_attention", "run_decode_combine"] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..9a56bb20363deffd4c850533484427bc128b3c84 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py @@ -0,0 +1,2691 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Dense paged fp8 decode forward path. + +This file owns the CUTE DSL entry point for decode attention via +``SparseDecodeAttentionForwardSm100`` — SM100 UTCMMA + persistent +scheduling, paged fp8 Q/K/V, BSA blk128-style intra-warp overlap pipeline. +Forward only. +""" + +from __future__ import annotations + +import math +from functools import partial +from typing import Callable, Optional + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.pipeline as cutlass_pipeline +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cutlass_dsl import BaseDSL +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +from ....quack import copy_utils, layout_utils + +from ....src.common import pipeline +from ....src.common import blackwell_helpers as sm100_helpers +from ....src.common import mma_sm100_desc as sm100_desc +from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map +from ....src.common.named_barrier import NamedBarrierFwdSm100 +from ....src.common.softmax import SoftmaxSm100 +from ....src.common.mask import AttentionMask +from ....src.common.seqlen_info import SeqlenInfoQK +from ....src.common.pack_gqa import pack_gqa_layout +from ....src.common.tile_scheduler import SchedulingMode +from ....src.sm100.fwd_decode.tile_scheduler import ( + DecodeTileScheduler, + DecodeTileSchedulerArguments, +) + + +class SparseDecodeAttentionForwardSm100: + """SM100 dense paged fp8 decode forward attention (UTCMMA + CLC). + + Scope (Phase 1): + - Dense decode, ``split_kv=False``, single q-tile per work item + (``packed_q = seqlen_q * qhead_per_kv <= tile_m=128``). + - Causal only. KV reverse page loop; first reverse block applies + causal/seqlen mask, the rest is unmasked. + - fp8 Q/K/V, bf16 O, fp32 LSE. P is quantized to fp8_e4m3fn before PV + via ``SoftmaxSm100.apply_exp2_convert`` (mirror of prefill fp8 PV). + - per-batch ``mSeqUsedK[b]`` heterogeneous; no uniform-length assumptions. + + Production scope reached at Phase 4+: + - Multi q-tile (Phase 2), split-KV partial writeback (Phase 3), + CLC persistent scheduling (Phase 4), TC SOL >= 90% (Phase 7). + """ + + # UTCMMA K-tile width (matches prefill SparseAttentionForwardSm100). + k_tile = 64 + + def __init__( + self, + head_dim: int = 128, + qhead_per_kv: int = 16, + m_block_size: int = 128, + n_block_size: int = 128, + page_size: int = 128, + split_kv: bool = False, + causal: bool = True, + write_lse: bool = True, + disable_softmax_exp2: bool = False, + ): + # --- structural constraints (Phase 1 scope) ------------------------- + if head_dim != 128: + raise NotImplementedError( + f"SparseDecodeAttentionForwardSm100 currently supports only D=128, " + f"got D={head_dim}" + ) + if m_block_size != 128: + raise NotImplementedError( + f"decode UMMA forward requires tile_m=128, got {m_block_size}" + ) + if n_block_size != 128: + raise NotImplementedError( + f"decode UMMA forward requires n_block_size=128, got {n_block_size}" + ) + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal n_block_size ({n_block_size})" + ) + if qhead_per_kv not in (16, 8, 4, 2, 1): + raise ValueError( + f"qhead_per_kv must be in {{1, 2, 4, 8, 16}}, got {qhead_per_kv}" + ) + if not causal: + raise NotImplementedError( + "decode UMMA forward currently supports only causal=True" + ) + + self.head_dim = int(head_dim) + self.qhead_per_kv = int(qhead_per_kv) + self.m_block_size = int(m_block_size) + self.n_block_size = int(n_block_size) + self.page_size = int(page_size) + self.tile_m = int(m_block_size) + self.split_kv = bool(split_kv) + self.causal = bool(causal) + self.write_lse = bool(write_lse) + self.disable_softmax_exp2 = bool(disable_softmax_exp2) + # FA fp8 SM100 fwd uses a threshold of 4.0 to avoid rescaling O for + # small row-max movements; correction receives acc_scale directly. + self.rescale_threshold = 4.0 + + # q tokens packed per (m_block_size) row group along M. + self.q_tokens_per_group = self.m_block_size // self.qhead_per_kv + + self.mma_tiler_qk = (self.m_block_size, self.n_block_size, self.head_dim) + self.mma_tiler_pv = (self.m_block_size, self.head_dim, self.n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + + # --- pipeline ring stages (BSA blk128 q_stage=1, s_stage=2) --- + self.q_stage = 1 + self.s_stage = 2 + self.o_stage = 2 + # Keep the fp8 decode KV ring deep enough to cover the K0/Q/K1/V0... + # order. This matches sage's fp8 setting and removes the underfed + # two-stage KV pipeline seen in the q8/16K non-split case. + self.kv_stage = 4 + self.k_stages = 2 + # Match prefill: PV is split at 3/4 of n_block_size for fp8. The + # producer (P store) must publish exactly 3N/4 fp8 columns at the + # signal point; that requires the TMEM-store atom Repetition to be + # ``8`` (one PV ``f8f6f4`` K=32 segment = 8 fp32 packed cols), so + # ``shape[2]=4`` chunks and ``split_idx=3`` lands on the 3N/4 + # boundary exactly. The previous N/2 cap was a workaround for + # ``Repetition(16)`` whose coarser chunk boundary could not + # represent 3N/4. + self.split_P_arrive = self.n_block_size // 4 * 3 + self.split_P_arrive = int(self.split_P_arrive / 32) * 32 + assert self.split_P_arrive % 32 == 0 + assert self.split_P_arrive < self.n_block_size + + # --- warp layout (16 warps / 512 threads) — BSA-aligned (Phase 1.10.6b) + # 0-3 softmax WG 0 + # 4-7 softmax WG 1 + # 8-11 correction WG (acc_O rescale across pages + final epilogue + # write-back; participates in TmemPtr barrier) + # 12 MMA issue warp + # 13 spare / future CLC scheduler + # 14 load warp (serial Q + K + V TMA loads) + # 15 empty / register-budget reserve + self.warps_per_group = 4 + self.softmax0_warp_base = 0 + self.softmax1_warp_base = self.softmax0_warp_base + self.warps_per_group + self.correction_warp_base = ( + self.softmax1_warp_base + self.warps_per_group) + self.mma_warp_id = self.correction_warp_base + self.warps_per_group + self.spare_warp_id = self.mma_warp_id + 1 + self.load_warp_id = self.spare_warp_id + 1 + self.empty_warp_id = self.load_warp_id + 1 + self.total_warps = 16 + self.threads_per_cta = cute.arch.WARP_SIZE * self.total_warps + + # --- TMEM layout (fp8 P width-pack: 4 fp8 lanes per fp32 column) --- + # S0/S1: [0:128], [128:256] + # O0/O1: [256:384], [384:512] for head_dim_v=128 + # P (fp8) overlays the second half of each S tile via recast_ptr. + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + self.tmem_s_offset = 0 + self.tmem_stage_stride = self.n_block_size + self.tmem_o_stage_stride = self.head_dim + self.tmem_o_offset = self.s_stage * self.n_block_size + # fp8 P occupies n_block_size * fp8_width / fp32_width = n/4 fp32 cols. + # P offset is set in __call__ once q_dtype is known (defer to Phase 1.3). + raw_tmem_total = self.tmem_o_offset + self.o_stage * self.tmem_o_stage_stride + # SM100 TMEM allocation requires a power-of-two column count. + self.tmem_total = 1 << (raw_tmem_total - 1).bit_length() + + # --- register budget per role (BSA hdim>=96 default) --- + self.num_regs_softmax = 184 + self.num_regs_correction = 88 + self.num_regs_other = 56 + self.num_regs_mma = self.num_regs_other + self.num_regs_load = self.num_regs_other + self.num_regs_epilogue = self.num_regs_other + self.num_regs_empty = self.num_regs_other + + # exp2 emulation for causal: matches prefill ex2_emu_freq=16. + # disable_softmax_exp2 (Phase 7 SOL gate) bypasses both emulation and + # native exp2 — the convert pass becomes a pure fp32 -> fp8 cast. + self.ex2_emu_freq = 16 if (self.causal and not self.disable_softmax_exp2) else 0 + self.ex2_emu_start_frg = 1 + self.buffer_align_bytes = 1024 + + # --- SM100 cluster config (single-CTA for decode, no 2-CTA pair) - + self.use_2cta_instrs = False + self.cta_group_size = 1 + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnk = (1, 1, 1) + self.use_clc_scheduler = True + self.scheduling_mode = SchedulingMode.CLC + self.sched_stages = 2 + self.clc_scheduler_warp_id = self.empty_warp_id + + self.arch = BaseDSL._get_dsl().get_arch_enum() + + # ------------------------------------------------------------------ + # Host-side: TMA descriptors, SMEM layout, launch + # Phase 1.2+ fills in the body. Phase 1.1 keeps signatures stable so + # the rest of the codepath (run_decode_attention dispatch in 1.10) + # can wire to this class without further churn. + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # [B, Sq, Hq, D] fp8 + mK: cute.Tensor, # [num_pages, Hkv, page_size, D] fp8 + mV: cute.Tensor, # [num_pages, Hkv, page_size, D] fp8 + mPageTable: cute.Tensor, # [B, max_pages] int32 + mSeqUsedK: cute.Tensor, # [B] int32 + mRequestIndices: cute.Tensor, # [work_capacity] int32 + mQoTileIndices: cute.Tensor, # [work_capacity] int32 + mKvTileIndices: cute.Tensor, # [work_capacity] int32 + mBlockValidMask: cute.Tensor, # [work_capacity] int32 + mSplitCounts: cute.Tensor, # [B] int32 + mOIndptr: cute.Tensor, # [B + 1] int32 + mO: cute.Tensor, # [total_q, Hq, D] bf16 + mLSE: cute.Tensor, # [total_q, Hq] fp32 + mO_partial: Optional[cute.Tensor], + mLSE_partial: Optional[cute.Tensor], + softmax_scale: Float32, + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + stream: cuda.CUstream = None, + ): + # --- dtype contract ------------------------------------------------ + if const_expr(mQ.element_type is not cutlass.Float8E4M3FN): + raise TypeError("decode UMMA Q must be Float8E4M3FN") + if const_expr(mK.element_type is not cutlass.Float8E4M3FN): + raise TypeError("decode UMMA K must be Float8E4M3FN") + if const_expr(mV.element_type is not cutlass.Float8E4M3FN): + raise TypeError("decode UMMA V must be Float8E4M3FN") + if const_expr(mO.element_type is not cutlass.BFloat16): + raise TypeError("decode UMMA output O must be BFloat16") + if const_expr(mLSE.element_type is not Float32): + raise TypeError("decode UMMA output LSE must be Float32") + if const_expr(self.split_kv): + if const_expr(mO_partial is None or mO_partial.element_type is not Float32): + raise TypeError("decode UMMA split path requires Float32 O_partial") + if const_expr(mLSE_partial is None or mLSE_partial.element_type is not Float32): + raise TypeError("decode UMMA split path requires Float32 LSE_partial") + + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type + self.o_dtype = ( + mO_partial.element_type if const_expr(self.split_kv) + else mO.element_type + ) + # f8f6f4 MMA descriptor kind for fp8 Q/K/V. + self.mma_kind = "f8f6f4" + # fp8 P width-pack ratio: each fp32 TMEM column holds 4 fp8 P lanes. + # Computed here so __init__ stays dtype-agnostic and the TMEM offsets + # can later be derived from this ratio in Phase 1.3. + elem_bytes = const_expr(self.q_dtype.width // 8) + p_cols_as_fp32 = const_expr( + self.n_block_size * self.q_dtype.width // Float32.width + ) + self.tmem_s_to_p_offset = self.n_block_size - p_cols_as_fp32 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + + mQ, mK, mV, mO, mLSE = [ + assume_tensor_aligned(t) for t in (mQ, mK, mV, mO, mLSE) + ] + if const_expr(mO_partial is not None): + mO_partial = assume_tensor_aligned(mO_partial) + if const_expr(mLSE_partial is not None): + mLSE_partial = assume_tensor_aligned(mLSE_partial) + mO_epilogue = mO_partial if const_expr(self.split_kv) else mO + self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO_epilogue) + self.epi_tile = (self.m_block_size, self.head_dim) + + # ------------------------------------------------------------------ + # UTCMMA TiledMma: QK^T + PV. PV uses MN-major V operand (V already + # transposed in the layout below) and a TMEM operand source for P. + # Phase 1.4 builds tiled_mma_qk; Phase 1.5 adds tiled_mma_pv so sV + # layout can derive the MN-major swizzle. + # ------------------------------------------------------------------ + cta_group = tcgen05.CtaGroup.ONE + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, + Float32, cta_group, self.mma_tiler_qk[:2]) + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN, + Float32, cta_group, self.mma_tiler_pv[:2], tcgen05.OperandSource.TMEM) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # Paged K/V tensor view permutation. + # Input layout [num_pages, Hkv, page_size, D] (nhsd) is permuted to + # [page_size, D, Hkv, num_pages] for the paged TMA descriptor (K). + # V gets an additional (s,d) swap to become MN-major: + # [D, page_size, Hkv, num_pages]. + # ------------------------------------------------------------------ + mK_paged = cute.make_tensor( + mK.iterator, cute.select(mK.layout, mode=[2, 3, 1, 0]) + ) + mV_kv = cute.make_tensor( + mV.iterator, cute.select(mV.layout, mode=[2, 3, 1, 0]) + ) + mV_paged = cute.make_tensor( + mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2, 3]) + ) + + # ------------------------------------------------------------------ + # Q SMEM layout + BSA/FA PackGQA full-tile TMA atom. + # + # Runtime Q is [B, Sq, Hq, D]. We transpose to [Sq, D, Hq, B], then + # fold qhead_per_kv into the M dimension: + # ((qhead_per_kv, Sq), D, Hkv, B) + # This lets one Q TMA load cover the whole packed (tile_m, D) tile + # instead of issuing one TMA per q token. + # ------------------------------------------------------------------ + total_q_stages = self.q_stage + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, total_q_stages) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + mQ = cute.make_tensor( + mQ.iterator, cute.select(mQ.layout, mode=[1, 3, 2, 0])) + nheads_kv = mK.shape[1] + mQ = pack_gqa_layout(mQ, self.qhead_per_kv, nheads_kv, head_idx=2) + tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mQ, + cute.select(sQ_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + cta_layout_vmnk.shape, + ) + + # ------------------------------------------------------------------ + # K / V SMEM layouts + TMA atoms (paged). + # sK uses the QK MMA operand B swizzle; sV uses the PV MMA operand B + # swizzle (MN-major). tP_layout is the TMEM-side P descriptor — no + # SMEM is actually allocated for P, it overlays the S region in TMEM + # via cute.recast_ptr in Phase 1.7. + # ------------------------------------------------------------------ + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage) + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage) + tP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.s_stage) + + tma_atom_K, mK_paged = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mK_paged, cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, tiled_mma_qk, cta_layout_vmnk.shape) + tma_atom_V, mV_paged = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mV_paged, cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_pv, tiled_mma_pv, cta_layout_vmnk.shape) + + # ------------------------------------------------------------------ + # Phase 1.10.6b-B-2: TMA-store atom for the epilogue write-back. + # Non-split writes bf16 final O; split-KV writes fp32 O_partial. + # sO follows FA/BSA epilogue layout: one full m_block x D tile in + # SMEM. Both paths expose global O as a packed-GQA tensor view so the + # final store is a full BSA-style m_block x D TMA tile. + # ------------------------------------------------------------------ + sO_layout = sm100_utils.make_smem_layout_epi( + self.o_dtype, + self.o_layout, + self.epi_tile, + self.q_stage, + ) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + num_heads_kv_tma = mK.shape[1] + total_o_rows_tma = ( + mO_epilogue.shape[0] + // (num_heads_kv_tma * self.qhead_per_kv) + ) + head_stride_tma = self.head_dim + o_row_stride_tma = ( + num_heads_kv_tma * self.qhead_per_kv * self.head_dim) + kv_head_stride_tma = self.qhead_per_kv * self.head_dim + mO_epilogue_tma = cute.make_tensor( + mO_epilogue.iterator, + cute.make_layout( + ((self.qhead_per_kv, total_o_rows_tma), self.head_dim, num_heads_kv_tma), + stride=((head_stride_tma, o_row_stride_tma), 1, kv_head_stride_tma), + ), + ) + tma_atom_O, mO_tma = cpasync.make_tiled_tma_atom( + tma_store_op, + mO_epilogue_tma, + cute.select(sO_layout, mode=[0, 1]), + self.epi_tile, + ) + + # Pre-multiply softmax scale by log2(e) so the inner exp2 path can + # operate without re-scaling at every iteration. Mirrors prefill. + softmax_scale_log2 = softmax_scale * Float32(math.log2(math.e)) + + work_capacity = mRequestIndices.shape[0] + num_heads_kv = mK.shape[1] + tile_sched_args = DecodeTileSchedulerArguments( + Int32(work_capacity), + Int32(num_heads_kv), + cluster_shape_mn=self.cluster_shape_mn, + ) + tile_sched_params = DecodeTileScheduler.to_underlying_arguments( + tile_sched_args, + scheduling_mode=self.scheduling_mode, + ) + self.tile_scheduler_cls = DecodeTileScheduler + grid = DecodeTileScheduler.get_grid_shape(tile_sched_params) + + clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 + clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0 + + # ------------------------------------------------------------------ + # SharedStorage mirrors BSA blk128's pipeline mesh for dense paged + # decode: Q, shared K/V, S/P/O, P-lastsplit, O-acc, O-epilogue and + # softmax stats mbarriers, plus the TMEM allocator state and SMEM + # staging tensors. + # ------------------------------------------------------------------ + @cute.struct + class SharedStorage: + mbar_load_Q: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_load_KV: cute.struct.MemRange[Int64, self.kv_stage * 2] + mbar_S_full_P_full_O_rescaled: cute.struct.MemRange[ + Int64, self.s_stage * 2] + mbar_P_full_lastsplit: cute.struct.MemRange[ + Int64, self.s_stage * 2] + mbar_O_full: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_softmax_stats0: cute.struct.MemRange[Int64, 2] + mbar_softmax_stats1: cute.struct.MemRange[Int64, 2] + mbar_O_epi: cute.struct.MemRange[Int64, self.s_stage * 2] + # Phase 1.10.6b-B-2: bf16 sO SMEM staging buffer for the TMA + # store epilogue. Sized for one full m_block_size × head_dim + # tile (single stage; overlap with sQ left for later perf tune). + sO: cute.struct.Align[ + cute.struct.MemRange[self.o_dtype, cute.cosize(sO_layout)], + self.buffer_align_bytes, + ] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: Int32 + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] + clc_response: cute.struct.MemRange[Int32, clc_response_size] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes, + ] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # ------------------------------------------------------------------ + # Launch — decode tasks are consumed from the + # (work_idx, head_kv_idx) scheduler space. In CLC mode grid is the + # BSA-style hardware problem shape; in static mode it is capped to the + # SM count and each CTA walks the flattened task stream. + # ------------------------------------------------------------------ + # q_tma_bytes (and Phase 1.5+: kv_tma_bytes / q_subtile_bytes) are + # recomputed inside the kernel from the constexpr SMEM layouts. + # Passing them as Constexpr[int] kernel args ended up marshalling + # to dynamic Int32 here, which then tripped MbarrierArray's + # `if tx_count < 0` check inside PipelineTmaUmma.create. + self.kernel( + mQ, mK_paged, mV_paged, + mPageTable, mSeqUsedK, + mRequestIndices, mQoTileIndices, mKvTileIndices, mBlockValidMask, + mSplitCounts, mOIndptr, + mO, mO_tma, mLSE, + mO_partial, mLSE_partial, + softmax_scale_log2, + sQ_layout, sK_layout, sV_layout, tP_layout, sO_layout, + tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O, + tiled_mma_qk, tiled_mma_pv, + tile_sched_params, + seqlen_q, page_size, kv_chunk_size_pages, + Int32(num_heads_kv), + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=( + self.cluster_shape_mnk + if cute.size(self.cluster_shape_mnk) > 1 else None + ), + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + # --- runtime tensors ------------------------------------------------- + mQ: cute.Tensor, # [((qhead_per_kv, Sq), D, Hkv, B)] + mK_paged: cute.Tensor, # [page_size, D, Hkv, num_pages] fp8 + mV_paged: cute.Tensor, # [D, page_size, Hkv, num_pages] fp8 + mPageTable: cute.Tensor, + mSeqUsedK: cute.Tensor, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mBlockValidMask: cute.Tensor, + mSplitCounts: cute.Tensor, + mOIndptr: cute.Tensor, + mO: cute.Tensor, + mO_tma: cute.Tensor, + mLSE: cute.Tensor, + mO_partial: Optional[cute.Tensor], + mLSE_partial: Optional[cute.Tensor], + # --- scalars --------------------------------------------------------- + softmax_scale_log2: Float32, + # --- SMEM layouts ---------------------------------------------------- + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + tP_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + # --- TMA atoms ------------------------------------------------------- + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_O: cute.CopyAtom, + # --- TiledMma -------------------------------------------------------- + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tile_sched_params: DecodeTileScheduler.Params, + # --- Int32 iteration bounds ------------------------------------------ + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + num_heads_kv: Int32, + ): + # ------------------------------------------------------------------ + # Thread / warp identity, work-item dispatch. + # ------------------------------------------------------------------ + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + if warp_idx == Int32(0): + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_O) + + # ------------------------------------------------------------------ + # SMEM allocation — same SharedStorage type was registered on the + # class in __call__ (Phase 1.3). Every warp materialises the same + # storage view; later phases populate sQ/sK/sV/mbar contents. + # ------------------------------------------------------------------ + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + # sQ is the MMA-operand layout and now also the Q TMA load target: + # PackGQA makes the global Q view match the full BSA (tile_m, D) tile. + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) + + # ------------------------------------------------------------------ + # TMEM allocator — MMA warp performs the allocation, all softmax / + # store / MMA warps participate in the TmemPtr named barrier that + # broadcasts the allocator pointer. Spare warp and KV-load warps + # do not touch TMEM directly. + # ------------------------------------------------------------------ + # TmemPtr participants: 2 softmax WGs (8 warps) + correction WG + # (4 warps) + MMA warp = 13 warps × WARP_SIZE. Load / spare / + # empty warps don't touch TMEM and don't arrive on this barrier. + tmem_alloc_warps: cutlass.Constexpr[int] = ( + self.warps_per_group * 3 + 1) + tmem_alloc_threads = cute.arch.WARP_SIZE * tmem_alloc_warps + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), + num_threads=tmem_alloc_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + ) + tmem_cols = self.tmem_total + + # ------------------------------------------------------------------ + # Cluster layout + warp-specialized pipelines. + # Mirrors prefill (src/sm100/fwd/atten_fwd.py:617-683): cta_layout_vmnk + # is rebuilt in-kernel from tiled_mma_qk.thr_id.shape so its size is + # constexpr (the `cute.size(cta_layout_vmnk) == 1` check inside + # PipelineTmaUmma.create folds at compile time). pipeline_q is + # joined by the BSA S/P/O and shared K/V pipelines below. + # ------------------------------------------------------------------ + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + ThreadCooperativeGroup = partial( + cutlass_pipeline.CooperativeGroup, cutlass_pipeline.Agent.Thread) + tma_thread = ThreadCooperativeGroup(1) + mma_thread = ThreadCooperativeGroup(1) + # One softmax WG participates per S/P/O stage; correction and the + # epilogue warp handle O rescale and TMA write-back. + softmax_warps = ThreadCooperativeGroup(self.warps_per_group) + softmax_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group) + + # Recompute TMA byte counts inside the kernel from the constexpr SMEM + # layouts — see note in __call__ above the self.kernel(...) call for + # why these can't be plumbed through as Constexpr[int] kernel args. + q_tma_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + k_tma_bytes = cute.size_in_bytes( + self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) + + pipeline_q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_load_Q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=q_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + # Decode KV follows BSA's single K/V ring: K0 is primed before Q, + # then K1, V0, K2, V1, ... share one PipelineTmaUmma state while + # landing in separate sK/sV SMEM tensors. For fp8 decode K/V TMA + # tiles have the same byte count, so the shared barrier uses K's count. + pipeline_kv = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_load_KV.data_ptr(), + num_stages=self.kv_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=k_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + # ------------------------------------------------------------------ + # BSA pipeline mesh. + # pipeline_s_p_o — MMA→{softmax,correction} (8-warp cluster + # consumer). MMA producer_commit signals + # "S ready"; consumer_release signals "P stored + # and acc_O rescaled — MMA can issue next QK". + # pipeline_o_acc — MMA→correction (acc_O updated by PV). + # pipeline_sm_stats0/1 — softmax→correction stage-local stats. + # This avoids the per-warp NamedBarrier used by + # the BSA reference while preserving the same + # first/rescale/final signal sequence. + # pipeline_o_epi — correction→epilogue warp 13 (final O ready). + # ------------------------------------------------------------------ + softmax_correction_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE + * (self.warps_per_group + self.warps_per_group) # = 256 + ) + correction_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group # = 128 + ) + epilogue_warp_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE # warp 13 = 32 threads + ) + + pipeline_s_p_o = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_S_full_P_full_O_rescaled.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=softmax_correction_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p_lastsplit = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_P_full_lastsplit.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_warps, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_o_acc = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_O_full.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=correction_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_sm_stats0 = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_softmax_stats0.data_ptr(), + num_stages=1, + producer_group=softmax_threads, + consumer_group=correction_threads, + defer_sync=True, + ) + pipeline_sm_stats1 = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_softmax_stats1.data_ptr(), + num_stages=1, + producer_group=softmax_threads, + consumer_group=correction_threads, + defer_sync=True, + ) + pipeline_o_epi = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_O_epi.data_ptr(), + num_stages=self.s_stage, + producer_group=correction_threads, + consumer_group=epilogue_warp_threads, + defer_sync=True, + ) + + # Fence mbar init across all regular pipelines. CLC pipeline setup + # follows the BSA ordering: arrive after mbar init, create scheduler + # state, then wait before TMEM allocation and role dispatch. + pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + + if const_expr(self.use_clc_scheduler): + clc_response_ptr = storage.clc_response.data_ptr() + clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() + clc_pipeline_producer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread + ) + num_clc_consumer_warps = ( + self.threads_per_cta // cute.arch.WARP_SIZE + ) * self.cta_group_size + clc_pipeline_consumer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread, + cute.arch.WARP_SIZE * num_clc_consumer_warps, + ) + clc_pipeline = cutlass_pipeline.PipelineClcFetchAsync.create( + barrier_storage=clc_mbar_ptr, + num_stages=self.sched_stages, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ) + tile_scheduler = self.tile_scheduler_cls.create( + tile_sched_params, clc_response_ptr=clc_response_ptr + ) + clc_consumer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, + self.sched_stages, + ) + tile_scheduler.set_clc_pipeline( + clc_pipeline, clc_consumer_state) + else: + clc_pipeline = None + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + + pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + # Single load warp issues Q + K + V TMA serially; no inter-warp + # broadcast / Q-load WG barrier needed (the BSA-aligned layout + # collapses the previous 4-warp Q-load fan-out into one warp). + + # ------------------------------------------------------------------ + # Phase 1.10.3: pre-dispatch TMEM partitions for softmax read/write. + # Mirrors prefill softmax body setup + # (src/sm100/fwd/atten_fwd.py:807-829, 1891-1921). Built once across + # all warps so each softmax WG can take its stage slice. + # ------------------------------------------------------------------ + thr_mma_qk_pre = tiled_mma_qk.get_slice(0) + qk_acc_shape_pre = thr_mma_qk_pre.partition_shape_C( + self.mma_tiler_qk[:2]) + tStS_base_pre = thr_mma_qk_pre.make_fragment_C(qk_acc_shape_pre) + tStS_pre = cute.make_tensor( + tStS_base_pre.iterator, + cute.append( + tStS_base_pre.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tScS_pre = thr_mma_qk_pre.partition_C( + cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScS_pre = tScS_pre[(None, None), 0, 0] + # fp8 P occupies n_block_size * fp8_width / fp32_width fp32 cols. + tilePlikeFP32 = const_expr( + self.mma_tiler_qk[1] * self.q_dtype.width // Float32.width) + tmem_load_atom_pre = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype, + ) + # Repetition(8) gives ``tStP_r2t.shape[2] = tilePlikeFP32 / 8 = 4`` + # chunks for fp8 (tilePlikeFP32=32), with each chunk publishing + # 8 fp32 cols = 32 fp8 cols = exactly one PV ``f8f6f4`` K=32 + # segment. ``split_idx = 4 * 3N/4 / N = 3`` aligns the early + # publish edge to the producer/consumer K boundary. Larger + # Repetition (e.g. 16) would coarsen shape[2] to 2 and force + # split_idx to floor to 1, publishing only N/2 of P before MMA's + # first three K=32 segments need cols 0..3N/4 — that mismatch is + # the NaN source the workaround used to dodge with split=N/2. + tmem_store_atom_pre = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), + Float32, + ) + tmem_store_vec_atom_pre = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + tmem_load_vec_atom_pre = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + + # ------------------------------------------------------------------ + # Warp role dispatch. Bodies are filled in Phase 1.3-1.9: + # softmax WG 0/1 (warps 0-3, 4-7) — softmax + P fp32->fp8 convert + # store / Q-load WG (warps 8-11) — Q TMA gather + epilogue store + # MMA warp (warp 12) — UTCMMA QK + PV issue + # correction WG (warps 8-11) — per-page acc_O rescale + epilogue + # MMA warp (warp 12) — UTCMMA QK + PV issue + # spare warp (warp 13) — empty / future CLC scheduler + # load warp (warp 14) — serial Q + K + V TMA loads + # empty warp (warp 15) — register-budget reserve + # ------------------------------------------------------------------ + is_softmax0_warp = ( + warp_idx >= Int32(self.softmax0_warp_base) + and warp_idx < Int32(self.softmax1_warp_base) + ) + is_softmax1_warp = ( + warp_idx >= Int32(self.softmax1_warp_base) + and warp_idx < Int32(self.correction_warp_base) + ) + is_correction_warp = ( + warp_idx >= Int32(self.correction_warp_base) + and warp_idx < Int32(self.mma_warp_id) + ) + is_mma_warp = warp_idx == Int32(self.mma_warp_id) + is_spare_warp = warp_idx == Int32(self.spare_warp_id) + is_load_warp = warp_idx == Int32(self.load_warp_id) + is_empty_warp = warp_idx == Int32(self.empty_warp_id) + + if const_expr(self.use_clc_scheduler): + if warp_idx == Int32(self.clc_scheduler_warp_id): + cute.arch.setmaxregister_decrease(self.num_regs_empty) + self.clc_scheduler_warp(clc_pipeline, tile_scheduler) + is_empty_warp = False + + if is_softmax0_warp: + cute.arch.setmaxregister_increase(self.num_regs_softmax) + tmem.wait_for_alloc() + tmem_ptr_wg0 = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr_wg0 + self.softmax_loop( + 0, + self.softmax0_warp_base, + softmax_scale_log2, + tStS_pre, + tScS_pre, + tilePlikeFP32, + tmem_load_atom_pre, + tmem_store_atom_pre, + tmem_store_vec_atom_pre, + thr_mma_qk_pre, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_sm_stats0, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + seqlen_q, + page_size, + kv_chunk_size_pages, + ) + tmem_alloc_barrier.arrive() + + if is_softmax1_warp: + cute.arch.setmaxregister_increase(self.num_regs_softmax) + tmem.wait_for_alloc() + tmem_ptr_wg1 = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr_wg1 + self.softmax_loop( + 1, + self.softmax1_warp_base, + softmax_scale_log2, + tStS_pre, + tScS_pre, + tilePlikeFP32, + tmem_load_atom_pre, + tmem_store_atom_pre, + tmem_store_vec_atom_pre, + thr_mma_qk_pre, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_sm_stats1, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + seqlen_q, + page_size, + kv_chunk_size_pages, + ) + tmem_alloc_barrier.arrive() + + if is_correction_warp: + cute.arch.setmaxregister_decrease(self.num_regs_correction) + # Participate in TmemPtr handshake so the MMA warp can free. + tmem.wait_for_alloc() + tmem_ptr_corr = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr_corr + + self.correction_loop( + tiled_mma_pv, + tStS_pre, + tScS_pre, + tmem_load_vec_atom_pre, + pipeline_s_p_o, + pipeline_sm_stats0, + pipeline_sm_stats1, + pipeline_o_acc, + pipeline_o_epi, + sO, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mSplitCounts, + mOIndptr, + mLSE, + mLSE_partial, + mBlockValidMask, + tile_scheduler, + seqlen_q, + page_size, + kv_chunk_size_pages, + num_heads_kv, + softmax_scale_log2, + ) + tmem_alloc_barrier.arrive() + + if is_spare_warp: + cute.arch.setmaxregister_decrease(self.num_regs_epilogue) + self.epilogue_s2g( + mO_tma, + sO, + tma_atom_O, + pipeline_o_epi, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mOIndptr, + mBlockValidMask, + tile_scheduler, + seqlen_q, + ) + + if is_load_warp: + self.load( + tiled_mma_qk, + tiled_mma_pv, + mQ, + mK_paged, + mV_paged, + sQ, + sK, + sV, + mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_q, + pipeline_kv, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + page_size, + kv_chunk_size_pages, + ) + + if is_empty_warp: + cute.arch.setmaxregister_decrease(self.num_regs_empty) + + if is_mma_warp: + cute.arch.setmaxregister_decrease(self.num_regs_mma) + # ---------------------------------------------------------------- + # MMA warp — Phase 1.6: QK fp8×fp8→fp32 UMMA. Phase 1.10.1 now + # wraps the body in the real TMEM allocator lifecycle: + # tmem.allocate(cols) -> wait_for_alloc -> retrieve_ptr + # -> ... QK work ... + # -> relinquish_alloc_permit -> tmem_alloc_barrier.arrive_and_wait + # -> free(ptr, cols) + # Softmax WG 0/1 participate via wait_for_alloc + retrieve_ptr + + # tmem_alloc_barrier.arrive (4+4+1 = 9 warps). + # ---------------------------------------------------------------- + tmem.allocate(tmem_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr # consumed by gemm_pv via raw TMEM offsets + + self.mma( + sQ, + sK, + sV, + tP_layout, + tiled_mma_qk, + tiled_mma_pv, + pipeline_q, + pipeline_kv, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_o_acc, + mRequestIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + page_size, + kv_chunk_size_pages, + ) + + # Phase 1.10.1: TMEM allocator teardown. + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr, num_columns=tmem_cols) + + @cute.jit + def clc_scheduler_warp( + self, + clc_pipeline: cutlass_pipeline.PipelineClcFetchAsync, + tile_scheduler: DecodeTileScheduler, + ) -> None: + clc_producer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Producer, + self.sched_stages, + ) + clc_consumer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, + self.sched_stages, + ) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + clc_pipeline.producer_acquire(clc_producer_state) + mbarrier_addr = clc_pipeline.producer_get_barrier( + clc_producer_state) + tile_scheduler.advance_to_next_work( + mbarrier_addr=mbarrier_addr, + response_stage=clc_producer_state.index, + ) + clc_producer_state.advance() + + clc_pipeline.consumer_wait(clc_consumer_state) + work_tile = tile_scheduler.get_current_work( + response_stage=clc_consumer_state.index) + clc_pipeline.consumer_release(clc_consumer_state) + clc_consumer_state.advance() + clc_pipeline.producer_tail(clc_producer_state) + + @cute.jit + def correction_loop( + self, + tiled_mma_pv: cute.TiledMma, + tStS_pre: cute.Tensor, + tScS_pre: cute.Tensor, + tmem_load_vec_atom_pre: cute.CopyAtom, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_sm_stats0: pipeline.PipelineAsync, + pipeline_sm_stats1: pipeline.PipelineAsync, + pipeline_o_acc: pipeline.PipelineAsync, + pipeline_o_epi: pipeline.PipelineAsync, + sO: cute.Tensor, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mSplitCounts: cute.Tensor, + mOIndptr: cute.Tensor, + mLSE: cute.Tensor, + mLSE_partial: Optional[cute.Tensor], + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + num_heads_kv: Int32, + softmax_scale_log2: Float32, + ) -> None: + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + warp_idx_in_wg_corr = warp_idx - Int32(self.correction_warp_base) + group_tidx_corr = ( + warp_idx_in_wg_corr * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + + # First iter: no correction is required. Notify MMA that the + # initial O slots are available, matching BSA's correction_loop. + for stage_init in cutlass.range_constexpr(self.s_stage): + pipeline_s_p_o.consumer_release_w_index(Int32(stage_init)) + + o_corr_consumer_phase = Int32(0) + sm_stats0_consumer_phase = Int32(0) + sm_stats1_consumer_phase = Int32(0) + corr_epi_producer_phase = Int32(1) + + thr0_rs = tiled_mma_pv.get_slice(0) + pv_acc_shape_rs_c = thr0_rs.partition_shape_C( + self.mma_tiler_pv[:2]) + tOtO_base_rs_c = thr0_rs.make_fragment_C(pv_acc_shape_rs_c) + tOtO_rs_c = cute.make_tensor( + tOtO_base_rs_c.iterator + Int32(self.tmem_o_offset), + cute.append( + tOtO_base_rs_c.layout, + cute.make_layout( + (self.s_stage,), + stride=(self.tmem_o_stage_stride,), + ), + ), + ) + tScS_vec_layout_corr = cute.composition( + tScS_pre.layout, cute.make_layout((self.m_block_size, 2))) + tScS_vec_corr = cute.make_tensor( + tScS_pre.iterator, tScS_vec_layout_corr) + tSAcc_corr0 = tStS_pre[(None, None), 0, 0, 0] + tSAcc_corr1 = tStS_pre[(None, None), 0, 0, 1] + tStS_vec0_layout_corr = cute.composition( + tSAcc_corr0.layout, cute.make_layout((self.m_block_size, 2))) + tStS_vec1_layout_corr = cute.composition( + tSAcc_corr1.layout, cute.make_layout((self.m_block_size, 2))) + tStStats0_t2r_src = cute.make_tensor( + tSAcc_corr0.iterator, tStS_vec0_layout_corr) + tStStats1_t2r_src = cute.make_tensor( + tSAcc_corr1.iterator, tStS_vec1_layout_corr) + thr_tmem_load_vec = tcgen05.make_tmem_copy( + tmem_load_vec_atom_pre, + tStStats0_t2r_src, + ).get_slice(group_tidx_corr) + tStStats0_t2r = thr_tmem_load_vec.partition_S(tStStats0_t2r_src) + tStStats1_t2r = thr_tmem_load_vec.partition_S(tStStats1_t2r_src) + tScStats_t2r = thr_tmem_load_vec.partition_D(tScS_vec_corr) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, head_kv_idx, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx_corr = mRequestIndices[work_idx] + qo_tile_corr = mQoTileIndices[work_idx] + seqused_k_corr = mSeqUsedK[batch_idx_corr] + split_idx_corr = mKvTileIndices[work_idx] + kv_pages_corr = ( + seqused_k_corr + page_size - Int32(1)) // page_size + kv_page_begin_corr = split_idx_corr * kv_chunk_size_pages + kv_page_end_corr = cutlass.min( + kv_pages_corr, + kv_page_begin_corr + kv_chunk_size_pages, + ) + page_count_corr = kv_page_end_corr - kv_page_begin_corr + block_iter_count_corr = ( + page_count_corr + Int32(1)) & ~Int32(1) + stage0_count_corr = block_iter_count_corr // Int32(2) + stage1_count_corr = block_iter_count_corr // Int32(2) + + if stage0_count_corr > Int32(0): + pipeline_sm_stats0.consumer_wait_w_index_phase( + Int32(0), sm_stats0_consumer_phase) + sm_stats0_consumer_phase = ( + sm_stats0_consumer_phase ^ Int32(1)) + pipeline_sm_stats0.consumer_release_w_index(Int32(0)) + if stage1_count_corr > Int32(0): + pipeline_sm_stats1.consumer_wait_w_index_phase( + Int32(0), sm_stats1_consumer_phase) + sm_stats1_consumer_phase = ( + sm_stats1_consumer_phase ^ Int32(1)) + pipeline_sm_stats1.consumer_release_w_index(Int32(0)) + + for page_rel_corr in cutlass.range( + Int32(self.s_stage), block_iter_count_corr, unroll=1 + ): + # sm_stats[0] now holds the deferred-exp2 log2-delta: + # 0.0 means "no rescale needed", a negative value is the + # raw delta that needs exp2 to become a true scale factor. + if (page_rel_corr & Int32(1)) == Int32(0): + pipeline_sm_stats0.consumer_wait_w_index_phase( + Int32(0), sm_stats0_consumer_phase) + sm_stats0_consumer_phase = ( + sm_stats0_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats0_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + scale_corr_log2 = tSrStats[0] + pipeline_sm_stats0.consumer_release_w_index(Int32(0)) + should_rescale = ( + cute.arch.vote_ballot_sync( + scale_corr_log2 < Float32(0.0)) != 0 + ) + if should_rescale: + scale_corr = cute.math.exp2( + scale_corr_log2, fastmath=True) + self.correction_rescale( + tiled_mma_pv, + tOtO_rs_c[None, None, None, 0], + group_tidx_corr, + scale_corr, + ) + pipeline_s_p_o.consumer_release_w_index(Int32(0)) + else: + pipeline_sm_stats1.consumer_wait_w_index_phase( + Int32(0), sm_stats1_consumer_phase) + sm_stats1_consumer_phase = ( + sm_stats1_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats1_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + scale_corr_log2 = tSrStats[0] + pipeline_sm_stats1.consumer_release_w_index(Int32(0)) + should_rescale = ( + cute.arch.vote_ballot_sync( + scale_corr_log2 < Float32(0.0)) != 0 + ) + if should_rescale: + scale_corr = cute.math.exp2( + scale_corr_log2, fastmath=True) + self.correction_rescale( + tiled_mma_pv, + tOtO_rs_c[None, None, None, 1], + group_tidx_corr, + scale_corr, + ) + pipeline_s_p_o.consumer_release_w_index(Int32(1)) + + for stage_wait in cutlass.range_constexpr(self.s_stage): + stage_count_wait = ( + stage0_count_corr + if const_expr(stage_wait == 0) + else stage1_count_corr + ) + if stage_count_wait > Int32(0): + pipeline_o_acc.consumer_wait_w_index_phase( + Int32(stage_wait), o_corr_consumer_phase) + + row_sum0 = Float32(0.0) + row_sum1 = Float32(0.0) + row_max0 = -Float32.inf + row_max1 = -Float32.inf + for stage_final in cutlass.range_constexpr(self.s_stage): + if const_expr(stage_final == 0): + pipeline_sm_stats0.consumer_wait_w_index_phase( + Int32(0), sm_stats0_consumer_phase) + sm_stats0_consumer_phase = ( + sm_stats0_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats0_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + row_sum0 = tSrStats[0] + row_max0 = tSrStats[1] + pipeline_sm_stats0.consumer_release_w_index(Int32(0)) + else: + pipeline_sm_stats1.consumer_wait_w_index_phase( + Int32(0), sm_stats1_consumer_phase) + sm_stats1_consumer_phase = ( + sm_stats1_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats1_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + row_sum1 = tSrStats[0] + row_max1 = tSrStats[1] + pipeline_sm_stats1.consumer_release_w_index(Int32(0)) + + zero0 = row_sum0 == Float32(0.0) or row_sum0 != row_sum0 + zero1 = row_sum1 == Float32(0.0) or row_sum1 != row_sum1 + rm0 = -Float32.inf if zero0 else row_max0 + rm1 = -Float32.inf if zero1 else row_max1 + row_max_comb = cutlass.max(rm0, rm1) + row_max_safe = ( + Float32(0.0) + if row_max_comb == -Float32.inf + else row_max_comb + ) + scale0 = ( + Float32(0.0) + if zero0 + else cute.math.exp2( + (rm0 - row_max_safe) * softmax_scale_log2, + fastmath=True, + ) + ) + scale1 = ( + Float32(0.0) + if zero1 + else cute.math.exp2( + (rm1 - row_max_safe) * softmax_scale_log2, + fastmath=True, + ) + ) + row_sum_comb = row_sum0 * scale0 + row_sum1 * scale1 + combined_zero_or_nan = ( + row_sum_comb == Float32(0.0) + or row_sum_comb != row_sum_comb + ) + inv_sum = cute.arch.rcp_approx( + Float32(1.0) + if combined_zero_or_nan else row_sum_comb) + final_scale0 = scale0 * inv_sum + final_scale1 = scale1 * inv_sum + + pipeline_o_epi.producer_acquire_w_index_phase( + Int32(0), corr_epi_producer_phase) + self.correction_epilogue_combine( + tiled_mma_pv, + sO[None, None, 0], + group_tidx_corr, + final_scale0, + final_scale1, + ) + + if const_expr(self.write_lse or self.split_kv): + if group_tidx_corr < Int32(self.m_block_size): + is_bad_lse = ( + row_sum_comb == Float32(0.0) + or row_sum_comb != row_sum_comb + ) + LN2 = Float32(math.log(2.0)) + lse_val = ( + -Float32.inf if is_bad_lse + else ( + row_max_safe * softmax_scale_log2 + + cute.math.log2(row_sum_comb, fastmath=True) + ) * LN2 + ) + tok_lse = group_tidx_corr // Int32(self.qhead_per_kv) + if tok_lse < seqlen_q: + h_in_kv_lse = ( + group_tidx_corr + - tok_lse * Int32(self.qhead_per_kv)) + q_idx_lse = ( + qo_tile_corr * Int32(self.q_tokens_per_group) + + tok_lse + ) + h_abs_lse = ( + head_kv_idx * Int32(self.qhead_per_kv) + + h_in_kv_lse + ) + if const_expr(self.split_kv): + q_tokens_per_group = Int32( + self.q_tokens_per_group) + q_stride_partial = ( + (seqlen_q + q_tokens_per_group - Int32(1)) + // q_tokens_per_group + ) * q_tokens_per_group + partial_row_lse = ( + mOIndptr[batch_idx_corr] + + split_idx_corr * q_stride_partial + + q_idx_lse + ) + mLSE_partial[ + partial_row_lse, h_abs_lse] = lse_val + else: + q_abs_lse = ( + batch_idx_corr * seqlen_q + q_idx_lse) + mLSE[q_abs_lse, h_abs_lse] = lse_val + + for stage_release in cutlass.range_constexpr(self.s_stage): + stage_count_release = ( + stage0_count_corr + if const_expr(stage_release == 0) + else stage1_count_corr + ) + if stage_count_release > Int32(0): + pipeline_s_p_o.consumer_release_w_index( + Int32(stage_release)) + pipeline_o_acc.consumer_release_w_index( + Int32(stage_release)) + if block_iter_count_corr > Int32(0): + o_corr_consumer_phase = ( + o_corr_consumer_phase ^ Int32(1)) + + pipeline_o_epi.producer_commit_w_index(Int32(0)) + corr_epi_producer_phase = ( + corr_epi_producer_phase ^ Int32(1)) + + work_tile = tile_scheduler.consumer_advance() + + pipeline_o_epi.producer_acquire_w_index_phase( + Int32(self.q_stage - 1), corr_epi_producer_phase) + + @cute.jit + def epilogue_s2g( + self, + mO_tma: cute.Tensor, + sO: cute.Tensor, + tma_atom_O: cute.CopyAtom, + pipeline_o_epi: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mOIndptr: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + seqlen_q: Int32, + ) -> None: + epi_consumer_phase = Int32(0) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, head_kv_idx, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx = mRequestIndices[work_idx] + qo_tile = mQoTileIndices[work_idx] + split_idx = mKvTileIndices[work_idx] + + pipeline_o_epi.consumer_wait_w_index_phase( + Int32(0), epi_consumer_phase) + q_tokens_per_group = Int32(self.q_tokens_per_group) + gO = cute.local_tile( + mO_tma[None, None, head_kv_idx], + self.epi_tile, + (None, 0), + ) + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, cute.make_layout(1), sO, gO) + if const_expr(not self.split_kv): + q_abs = ( + batch_idx * seqlen_q + + qo_tile * q_tokens_per_group + ) + dst_idx = q_abs // q_tokens_per_group + else: + q_stride_partial = ( + (seqlen_q + q_tokens_per_group - Int32(1)) + // q_tokens_per_group + ) * q_tokens_per_group + partial_row = ( + mOIndptr[batch_idx] + + split_idx * q_stride_partial + + qo_tile * q_tokens_per_group + ) + dst_idx = partial_row // q_tokens_per_group + store_O(src_idx=Int32(0), dst_idx=dst_idx) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0) + pipeline_o_epi.consumer_release_w_index(Int32(0)) + epi_consumer_phase = epi_consumer_phase ^ Int32(1) + + work_tile = tile_scheduler.consumer_advance() + + @cute.jit + def correction_epilogue_combine( + self, + tiled_mma_pv: cute.TiledMma, + sO: cute.Tensor, + tidx: Int32, + scale0: Float32, + scale1: Float32, + ) -> None: + thr_mma = tiled_mma_pv.get_slice(0) + pv_acc_shape = thr_mma.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO_base = thr_mma.make_fragment_C(pv_acc_shape) + tOtO = cute.make_tensor( + tOtO_base.iterator + Int32(self.tmem_o_offset), + cute.append( + tOtO_base.layout, + cute.make_layout( + (self.s_stage,), + stride=(self.tmem_o_stage_stride,), + ), + ), + ) + tOsO = thr_mma.get_slice(0).partition_C(sO) + tOcO_full = thr_mma.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + corr_tile_size: cutlass.Constexpr[int] = ( + 8 * 32 // self.o_dtype.width + ) + tOsO_i = cute.logical_divide( + tOsO, + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + tOcO_i = cute.logical_divide( + tOcO_full, + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + tOtO0_i = cute.logical_divide( + tOtO[None, None, None, 0], + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + tOtO1_i = cute.logical_divide( + tOtO[None, None, None, 1], + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + epi_subtile = (self.epi_tile[0], corr_tile_size) + tmem_load_atom = sm100_utils.get_tmem_load_op( + self.mma_tiler_pv, + self.o_layout, + self.o_dtype, + self.pv_acc_dtype, + epi_subtile, + use_2cta_instrs=self.use_2cta_instrs, + ) + tiled_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tOtO0_i[(None, None), 0]) + thr_tmem_load = tiled_tmem_load.get_slice(tidx) + smem_copy_atom = sm100_utils.get_smem_store_op( + self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load) + tiled_smem_store = cute.make_tiled_copy_D( + smem_copy_atom, tiled_tmem_load) + tOtO0_t2r = thr_tmem_load.partition_S( + tOtO0_i[(None, None), None]) + tOtO1_t2r = thr_tmem_load.partition_S( + tOtO1_i[(None, None), None]) + tOsO_s2r = copy_utils.partition_D_position_independent( + thr_tmem_load, tOsO_i[(None, None), None]) + tOcO_t2r = thr_tmem_load.partition_D( + tOcO_i[(None, None), None]) + + for col_pass_idx in cutlass.range( + self.head_dim // corr_tile_size, unroll_full=True): + tOtO0_t2r_i = tOtO0_t2r[None, 0, 0, col_pass_idx] + tOtO1_t2r_i = tOtO1_t2r[None, 0, 0, col_pass_idx] + tOsO_r2s_i = tOsO_s2r[None, 0, 0, col_pass_idx] + frg_shape = tOcO_t2r[None, 0, 0, col_pass_idx].shape + tOrO0_frg = cute.make_fragment(frg_shape, self.pv_acc_dtype) + tOrO1_frg = cute.make_fragment(frg_shape, self.pv_acc_dtype) + is_zero_output = ( + scale0 == Float32(0.0) and scale1 == Float32(0.0) + ) + if not is_zero_output: + cute.copy(tiled_tmem_load, tOtO0_t2r_i, tOrO0_frg) + cute.copy(tiled_tmem_load, tOtO1_t2r_i, tOrO1_frg) + for j in cutlass.range( + 0, cute.size(tOrO0_frg), 2, unroll_full=True + ): + o0_a, o0_b = cute.arch.mul_packed_f32x2( + (tOrO0_frg[j], tOrO0_frg[j + 1]), + (scale0, scale0), + ) + o1_a, o1_b = cute.arch.mul_packed_f32x2( + (tOrO1_frg[j], tOrO1_frg[j + 1]), + (scale1, scale1), + ) + tOrO0_frg[j], tOrO0_frg[j + 1] = ( + cute.arch.add_packed_f32x2( + (o0_a, o0_b), (o1_a, o1_b)) + ) + else: + tOrO0_frg.fill(Float32(0.0)) + copy_utils.cvt_copy(tiled_smem_store, tOrO0_frg, tOsO_r2s_i) + cute.arch.fence_view_async_shared() + + @cute.jit + def correction_rescale( + self, + tiled_mma_pv: cute.TiledMma, + tOtO: cute.Tensor, + tidx: Int32, + scale: Float32, + ) -> None: + thr_mma = tiled_mma_pv.get_slice(0) + tOcO = thr_mma.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + corr_tile_size: cutlass.Constexpr[int] = 16 + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp( + tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp( + tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tOtO_i = cute.composition( + tOtO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOcO_i = cute.composition( + tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tOtO_i).get_slice(tidx) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom, tOtO_i).get_slice(tidx) + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i) + tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape + tOtO_r2t = thr_tmem_store.partition_D(tOtO_i) + + frg_count: cutlass.Constexpr[int] = self.head_dim // corr_tile_size + for fi in cutlass.range_constexpr(frg_count): + tOrO_frg = cute.make_fragment( + tOrO_t2r_shape, self.pv_acc_dtype) + tOtO_t2r_i = cute.make_tensor( + tOtO_t2r.iterator + fi * corr_tile_size, + tOtO_t2r.layout, + ) + cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg) + for j in cutlass.range( + 0, cute.size(tOrO_frg), 2, unroll_full=True + ): + tOrO_frg[j], tOrO_frg[j + 1] = ( + cute.arch.mul_packed_f32x2( + (tOrO_frg[j], tOrO_frg[j + 1]), + (scale, scale), + ) + ) + tOtO_r2t_i = cute.make_tensor( + tOtO_r2t.iterator + fi * corr_tile_size, + tOtO_r2t.layout, + ) + cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i) + cute.arch.fence_view_async_tmem_store() + + @cute.jit + def mma( + self, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tP_layout: cute.ComposedLayout, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + pipeline_q: pipeline.PipelineAsync, + pipeline_kv: pipeline.PipelineAsync, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_o_acc: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + page_size: Int32, + kv_chunk_size_pages: Int32, + ) -> None: + thr_mma_qk = tiled_mma_qk.get_slice(0) + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tSrQ0_layout = tSrQ[None, None, None, 0].layout + tSrK0_layout = tSrK[None, None, None, 0].layout + qk_mma_op = tiled_mma_qk.op + q_smem_base = sm100_desc.smem_desc_base_from_tensor( + sQ, sm100_desc.Major.K) + k_smem_base = sm100_desc.smem_desc_base_from_tensor( + sK, sm100_desc.Major.K) + q_smem_start = sm100_desc.make_smem_desc_start_addr( + sQ[None, None, None, 0].iterator) + sm100_helpers.declare_ptx_smem_desc( + q_smem_start, q_smem_base, tSrQ0_layout, + var_name_prefix="decode_q_smem_desc", + ) + sm100_helpers.declare_ptx_idesc( + qk_mma_op, var_name="decode_qk_idesc") + gemm_qk = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0_layout, + smem_var_name_prefix="decode_q_smem_desc", + idesc_var_name="decode_qk_idesc", + smem_offset=0, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + + thr_mma_pv = tiled_mma_pv.get_slice(0) + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_base = thr_mma_qk.make_fragment_C(qk_acc_shape) + tStS = cute.make_tensor( + tStS_base.iterator, + cute.append( + tStS_base.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP_base = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + tP_width_ratio = const_expr(Float32.width // self.v_dtype.width) + tP_stage_stride = const_expr( + self.tmem_stage_stride * tP_width_ratio) + tOrP = cute.make_tensor( + tOrP_base.iterator + self.tmem_p_offset * tP_width_ratio, + cute.append( + tOrP_base.layout, + cute.make_layout((self.s_stage,), stride=(tP_stage_stride,)), + ), + ) + tOrV = tiled_mma_pv.make_fragment_B(sV) + pv_mma_op = tiled_mma_pv.op + sm100_helpers.declare_ptx_idesc( + pv_mma_op, var_name="decode_pv_idesc") + + mma_q_consumer_phase = Int32(0) + mma_kv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kv_stage) + phase_s0 = Int32(0) + phase_s1 = Int32(0) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, _, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx_mma = mRequestIndices[work_idx] + split_idx_mma = mKvTileIndices[work_idx] + seqused_k_mma = mSeqUsedK[batch_idx_mma] + kv_pages_mma = ( + seqused_k_mma + page_size - Int32(1)) // page_size + kv_page_begin_mma = split_idx_mma * kv_chunk_size_pages + kv_page_end_mma = cutlass.min( + kv_pages_mma, + kv_page_begin_mma + kv_chunk_size_pages, + ) + page_count_mma = kv_page_end_mma - kv_page_begin_mma + block_iter_count_mma = ( + page_count_mma + Int32(1)) & ~Int32(1) + + pipeline_q.consumer_wait_w_index_phase( + Int32(0), mma_q_consumer_phase) + mma_q_consumer_phase = mma_q_consumer_phase ^ Int32(1) + if block_iter_count_mma > Int32(0): + pipeline_kv.consumer_wait(mma_kv_consumer_state) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[ + None, None, None, + mma_kv_consumer_state.index, + ].iterator + ) + gemm_qk( + Int32(self.tmem_s_offset), + smem_desc_start_b=k_smem_start, + ) + pipeline_s_p_o.producer_commit_w_index(Int32(0)) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + if block_iter_count_mma > Int32(1): + pipeline_kv.consumer_wait(mma_kv_consumer_state) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[ + None, None, None, + mma_kv_consumer_state.index, + ].iterator + ) + gemm_qk( + Int32(self.tmem_s_offset) + + Int32(self.tmem_stage_stride), + smem_desc_start_b=k_smem_start, + ) + pipeline_s_p_o.producer_commit_w_index(Int32(1)) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + if block_iter_count_mma > Int32(self.s_stage): + for page_rel_pv in cutlass.range( + Int32(0), + block_iter_count_mma - Int32(self.s_stage), + unroll=1, + ): + pv_slot = page_rel_pv & Int32(1) + pv_stage_iter = page_rel_pv // Int32(self.s_stage) + pv_phase = phase_s0 + if pv_slot != Int32(0): + pv_phase = phase_s1 + pipeline_s_p_o.producer_acquire_w_index_phase( + pv_slot, pv_phase) + pipeline_kv.consumer_wait(mma_kv_consumer_state) + v_idx = mma_kv_consumer_state.index + sm100_helpers.gemm_ptx_partial( + pv_mma_op, + Int32(self.tmem_o_offset) + + pv_slot * Int32(self.tmem_o_stage_stride), + tOrP[None, None, None, pv_slot], + tOrV[None, None, None, v_idx], + sA=None, + sB=sV[None, None, None, v_idx], + tA_addr=( + Int32(self.tmem_p_offset) + + pv_slot * Int32(self.tmem_stage_stride) + ), + zero_init=pv_stage_iter == Int32(0), + mbar_ptr=( + pipeline_p_lastsplit + .sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None + ), + mbar_phase=( + pv_phase + if self.split_P_arrive > 0 else None), + split_arrive=( + self.split_P_arrive + if self.split_P_arrive > 0 else None + ), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + if pv_slot == Int32(0): + phase_s0 = phase_s0 ^ Int32(1) + else: + phase_s1 = phase_s1 ^ Int32(1) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + page_rel_qk = page_rel_pv + Int32(self.s_stage) + qk_slot = page_rel_qk & Int32(1) + pipeline_kv.consumer_wait(mma_kv_consumer_state) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[ + None, None, None, + mma_kv_consumer_state.index, + ].iterator + ) + gemm_qk( + Int32(self.tmem_s_offset) + + qk_slot * Int32(self.tmem_stage_stride), + smem_desc_start_b=k_smem_start, + ) + pipeline_s_p_o.producer_commit_w_index(qk_slot) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + pipeline_q.consumer_release_w_index(Int32(0)) + + if block_iter_count_mma > Int32(0): + page_rel_epi_begin = cutlass.max( + Int32(0), + block_iter_count_mma - Int32(self.s_stage), + ) + for page_rel_epi in cutlass.range( + page_rel_epi_begin, block_iter_count_mma, unroll=1 + ): + pv_slot = page_rel_epi & Int32(1) + pv_stage_iter = page_rel_epi // Int32(self.s_stage) + pv_phase = phase_s0 + if pv_slot != Int32(0): + pv_phase = phase_s1 + pipeline_s_p_o.producer_acquire_w_index_phase( + pv_slot, pv_phase) + pipeline_kv.consumer_wait(mma_kv_consumer_state) + v_idx = mma_kv_consumer_state.index + sm100_helpers.gemm_ptx_partial( + pv_mma_op, + Int32(self.tmem_o_offset) + + pv_slot * Int32(self.tmem_o_stage_stride), + tOrP[None, None, None, pv_slot], + tOrV[None, None, None, v_idx], + sA=None, + sB=sV[None, None, None, v_idx], + tA_addr=( + Int32(self.tmem_p_offset) + + pv_slot * Int32(self.tmem_stage_stride) + ), + zero_init=pv_stage_iter == Int32(0), + mbar_ptr=( + pipeline_p_lastsplit + .sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None + ), + mbar_phase=( + pv_phase + if self.split_P_arrive > 0 else None), + split_arrive=( + self.split_P_arrive + if self.split_P_arrive > 0 else None + ), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + pipeline_o_acc.producer_commit_w_index(pv_slot) + if pv_slot == Int32(0): + phase_s0 = phase_s0 ^ Int32(1) + else: + phase_s1 = phase_s1 ^ Int32(1) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + work_tile = tile_scheduler.consumer_advance() + + @cute.jit + def softmax_loop( + self, + stage: cutlass.Constexpr[int], + warp_base: cutlass.Constexpr[int], + softmax_scale_log2: Float32, + tStS_pre: cute.Tensor, + tScS_pre: cute.Tensor, + tilePlikeFP32: cutlass.Constexpr[int], + tmem_load_atom_pre: cute.CopyAtom, + tmem_store_atom_pre: cute.CopyAtom, + tmem_store_vec_atom_pre: cute.CopyAtom, + thr_mma_qk_pre: cute.core.ThrMma, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + ) -> None: + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + warp_idx_in_wg = warp_idx - Int32(warp_base) + group_tidx = ( + warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + stage_i32 = Int32(stage) + + tSAcc = tStS_pre[(None, None), 0, 0, stage] + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom_pre, tSAcc).get_slice(group_tidx) + tStS_t2r = thr_tmem_load.partition_S(tSAcc) + tScS_t2r = thr_tmem_load.partition_D(tScS_pre) + tStP_layout = cute.composition( + tSAcc.layout, + cute.make_layout((self.m_block_size, tilePlikeFP32)), + ) + tStP = cute.make_tensor( + tSAcc.iterator + self.tmem_s_to_p_offset, + tStP_layout, + ) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom_pre, tStP).get_slice(group_tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + tScS_vec_layout = cute.composition( + tScS_pre.layout, cute.make_layout((self.m_block_size, 2))) + tScS_vec = cute.make_tensor(tScS_pre.iterator, tScS_vec_layout) + tStS_vec_layout = cute.composition( + tSAcc.layout, cute.make_layout((self.m_block_size, 2))) + tStStats_r2t_dst = cute.make_tensor( + tSAcc.iterator, tStS_vec_layout) + thr_tmem_store_vec = tcgen05.make_tmem_copy( + tmem_store_vec_atom_pre, + tStStats_r2t_dst, + ).get_slice(group_tidx) + tStStats_r2t = thr_tmem_store_vec.partition_D(tStStats_r2t_dst) + tScStats_r2t = thr_tmem_store_vec.partition_S(tScS_vec) + tScP_shape = ( + self.mma_tiler_qk[0] // thr_mma_qk_pre.thr_id.shape, + tilePlikeFP32, + ) + + tSrP_r2t_f32 = cute.make_rmem_tensor( + thr_tmem_store.partition_S( + cute.make_identity_tensor(tScP_shape)).shape, + Float32, + ) + s_consumer_phase = Int32(0) + sm_stats_producer_phase = Int32(1) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, _, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + softmax = SoftmaxSm100.create( + softmax_scale_log2, + rescale_threshold=self.rescale_threshold, + ) + softmax.reset() + batch_idx = mRequestIndices[work_idx] + qo_tile = mQoTileIndices[work_idx] + seqused_k = mSeqUsedK[batch_idx] + split_idx = mKvTileIndices[work_idx] + kv_pages = ( + seqused_k + page_size - Int32(1)) // page_size + kv_page_begin = split_idx * kv_chunk_size_pages + kv_page_end = cutlass.min( + kv_pages, kv_page_begin + kv_chunk_size_pages + ) + page_count = kv_page_end - kv_page_begin + block_iter_count = (page_count + Int32(1)) & ~Int32(1) + if const_expr(stage == 0): + stage_page_count = block_iter_count // Int32(2) + else: + stage_page_count = block_iter_count // Int32(2) + + seqlen_info = SeqlenInfoQK( + Int32(0), + Int32(0), + Int32(0), + Int32(0), + seqlen_q, + seqused_k, + False, + False, + False, + True, + ) + mask = AttentionMask( + self.m_block_size, + self.n_block_size, + seqlen_info, + qhead_per_kvhead_packgqa=self.qhead_per_kv, + ) + wg_count = stage_page_count + if wg_count > Int32(0): + page_rel0 = stage_i32 + page_rel0_clamped = cutlass.min( + page_rel0, page_count - Int32(1)) + page_idx_global = kv_page_end - Int32(1) - page_rel0_clamped + kv_valid_cols = cutlass.min( + Int32(self.n_block_size), + seqused_k - page_idx_global * page_size, + ) + if page_rel0 >= page_count: + kv_valid_cols = Int32(0) + sm_stats_producer_phase = self.softmax_step( + softmax, + mask, + stage_i32, + s_consumer_phase, + page_idx_global, + qo_tile, + kv_valid_cols, + tStS_t2r, + tScS_t2r, + tStP_r2t, + tSrP_r2t_f32, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_sm_stats, + group_tidx, + warp_idx_in_wg, + tStStats_r2t, + tScStats_r2t, + sm_stats_producer_phase, + is_first=True, + ) + s_consumer_phase = s_consumer_phase ^ Int32(1) + + for stage_iter in cutlass.range( + Int32(1), wg_count, unroll=1 + ): + page_rel = ( + stage_iter * Int32(self.s_stage) + stage_i32) + page_rel_clamped = cutlass.min( + page_rel, page_count - Int32(1)) + page_idx_global_n = ( + kv_page_end - Int32(1) - page_rel_clamped) + kv_valid_cols_n = cutlass.min( + Int32(self.n_block_size), + seqused_k - page_idx_global_n * page_size, + ) + # Dummy-iter analysis: with s_stage=2, the WG that + # handles stage_i32=0 only ever sees page_rel ≤ + # block_iter_count - 2 < page_count → NEVER dummy. + # The WG with stage_i32=1 sees page_rel = + # block_iter_count - 1 at its last iter, which + # equals page_count iff page_count is odd → only + # WG1 may need the runtime mask_dummy_only guard. + # Pass None for WG0 so the const_expr branch in + # softmax_step eliminates the runtime check + # entirely (compile-time disappears). + if const_expr(stage == 0): + sm_stats_producer_phase = self.softmax_step( + softmax, mask, stage_i32, s_consumer_phase, + page_idx_global_n, qo_tile, kv_valid_cols_n, + tStS_t2r, tScS_t2r, tStP_r2t, tSrP_r2t_f32, + thr_tmem_load, thr_tmem_store, thr_tmem_store_vec, + pipeline_s_p_o, pipeline_p_lastsplit, + pipeline_sm_stats, + group_tidx, warp_idx_in_wg, + tStStats_r2t, tScStats_r2t, + sm_stats_producer_phase, + is_first=False, + apply_mask=False, + # mask_dummy_only=None → no runtime check + ) + else: + is_dummy = page_rel >= page_count + if is_dummy: + kv_valid_cols_n = Int32(0) + sm_stats_producer_phase = self.softmax_step( + softmax, mask, stage_i32, s_consumer_phase, + page_idx_global_n, qo_tile, kv_valid_cols_n, + tStS_t2r, tScS_t2r, tStP_r2t, tSrP_r2t_f32, + thr_tmem_load, thr_tmem_store, thr_tmem_store_vec, + pipeline_s_p_o, pipeline_p_lastsplit, + pipeline_sm_stats, + group_tidx, warp_idx_in_wg, + tStStats_r2t, tScStats_r2t, + sm_stats_producer_phase, + is_first=False, + apply_mask=False, + mask_dummy_only=is_dummy, + ) + s_consumer_phase = s_consumer_phase ^ Int32(1) + + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + tSrStats = cute.make_rmem_tensor( + tScStats_r2t.shape, self.qk_acc_dtype) + tSrStats[0] = softmax.row_sum[0] + tSrStats[1] = softmax.row_max[0] + cute.copy(thr_tmem_store_vec, tSrStats, tStStats_r2t) + cute.arch.fence_view_async_tmem_store() + else: + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + tSrStats = cute.make_rmem_tensor( + tScStats_r2t.shape, self.qk_acc_dtype) + tSrStats[0] = Float32(0.0) + tSrStats[1] = -Float32.inf + cute.copy(thr_tmem_store_vec, tSrStats, tStStats_r2t) + cute.arch.fence_view_async_tmem_store() + pipeline_sm_stats.producer_commit_w_index(Int32(0)) + sm_stats_producer_phase = sm_stats_producer_phase ^ Int32(1) + + work_tile = tile_scheduler.consumer_advance() + + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + + @cute.jit + def softmax_step( + self, + softmax: SoftmaxSm100, + mask: AttentionMask, + stage: Int32, + s_phase: Int32, + page_idx: Int32, + qo_tile: Int32, + kv_valid_cols: Int32, + tStS_t2r: cute.Tensor, + tScS_t2r: cute.Tensor, + tStP_r2t: cute.Tensor, + tSrP_r2t_f32: cute.Tensor, + thr_tmem_load: cute.CopyAtom, + thr_tmem_store: cute.CopyAtom, + thr_tmem_store_vec: cute.CopyAtom, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + group_tidx: Int32, + warp_idx_in_wg: Int32, + tStStats_r2t: cute.Tensor, + tScStats_r2t: cute.Tensor, + sm_stats_producer_phase: Int32, + is_first: cutlass.Constexpr[bool], + apply_mask: cutlass.Constexpr[bool] = True, + mask_dummy_only: Optional[cutlass.Boolean] = None, + ) -> Int32: + # apply_mask=False is the inner-page fast path: skip both the seqlen + # bounds check and the causal-diagonal check, which together cost ~15 + # cyc per iter on the producer pre-publication critical path that + # gates correction WG's consumer_wait (top long_scoreboard PC in NCU). + # Callers must only set apply_mask=False when they can prove the tile + # is fully unmasked (no partial-page seqlen tail, no causal diagonal + # cut). + # + # mask_dummy_only (runtime bool, used only when apply_mask=False): + # when True the iter is a "dummy" rounded-up iter that needs the + # mask to zero out garbage S — runs the mask at runtime cost. For + # non-dummy iters it stays the fast no-mask path. + pipeline_s_p_o.consumer_wait_w_index_phase(stage, s_phase) + sm_stats_try_acquire = ( + pipeline_sm_stats.producer_try_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + ) + tSrS_t2r = cute.make_rmem_tensor( + tScS_t2r.shape, self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + if const_expr(apply_mask): + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=qo_tile, + n_block=page_idx, + mask_seqlen=True, + mask_causal=self.causal, + kv_valid_cols=kv_valid_cols, + ) + elif const_expr(mask_dummy_only is not None): + if mask_dummy_only: + # Dummy iter — zero everything via mask (kv_valid_cols=0 + # makes mask_r2p_lambda set all positions to -inf). + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=qo_tile, + n_block=page_idx, + mask_seqlen=True, + mask_causal=self.causal, + kv_valid_cols=kv_valid_cols, + ) + # Publish acc_scale in log2-domain (un-exp2'd); correction WG does + # the exp2 only when an actual rescale fires. Removes MUFU.EX2 from + # the sm_stats publication critical path that gates correction's + # consumer_wait (the dominant long_scoreboard hot PC in NCU). + row_max, acc_scale_log2 = softmax.update_row_max_deferred_exp2( + tSrS_t2r.load(), is_first) + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase, sm_stats_try_acquire) + tSrStats = cute.make_rmem_tensor( + tScStats_r2t.shape, self.qk_acc_dtype) + tSrStats[0] = acc_scale_log2 + tSrStats[1] = row_max + cute.copy(thr_tmem_store_vec, tSrStats, tStStats_r2t) + cute.arch.fence_view_async_tmem_store() + pipeline_sm_stats.producer_commit_w_index(Int32(0)) + sm_stats_producer_phase = sm_stats_producer_phase ^ Int32(1) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r.layout, + ) + # exp2 for the internal row_sum carry happens AFTER producer_commit, so + # it no longer extends correction's consumer-wait window. + # acc_scale_log2 == 0.0 in the threshold/first-iter paths makes + # exp2(0)=1.0, which is the no-rescale identity for the row_sum carry — + # semantically equivalent to the original ``acc_scale=1.0`` branch. + if const_expr(is_first): + row_sum_init = Float32(0.0) + else: + acc_scale_mult = cute.math.exp2(acc_scale_log2, fastmath=True) + row_sum_init = softmax.row_sum[0] * acc_scale_mult + # Bulk EX2 emulation parameters. + # + # ex2_emu_freq=16 emulate exp2 with FFMA2 polynomial on + # 15 of every 16 (j, k) positions; the + # remaining 1/16 still issues MUFU.EX2. + # This cuts the MUFU.EX2 throughput bottleneck + # in the softmax inner loop (≈22k cyc + # saved per stage at baseline). + # ex2_emu_res=3 degree-3 polynomial; res=4 broke + # kv=1024 close-tolerance even with + # poly_degree=5 — 3 is the most aggressive + # setting that still passes cos_sim ≥ 0.99 + # against the reference for the fp8 PV path. + # ex2_emu_start_frg=1 skip the emulation for fragment index 0 + # (preserves accuracy on the first iter + # where row_max is least settled). + # + # If you tune these, re-run the variable-kv self-consistency check + # (split vs non-split must stay at cos_min ≥ 0.99). + softmax.row_sum[0] = softmax.scale_apply_exp2_convert_sum( + tSrS_t2r, + row_max, + tSrP_r2t, + row_sum_init, + ex2_emu_freq=16, + ex2_emu_res=3, + ex2_emu_start_frg=1, + ) + for k in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): + cute.copy( + thr_tmem_store, + tSrP_r2t_f32[None, None, k], + tStP_r2t[None, None, k], + ) + if const_expr(self.split_P_arrive > 0): + split_P_arrive_idx = ( + cute.size(tStP_r2t.shape[2]) + * self.split_P_arrive + // self.n_block_size + ) + if const_expr(k + 1 == split_P_arrive_idx): + cute.arch.fence_view_async_tmem_store() + pipeline_s_p_o.consumer_release_w_index(stage) + cute.arch.fence_view_async_tmem_store() + if const_expr(self.split_P_arrive > 0): + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_p_lastsplit.producer_commit_w_index(stage) + else: + pipeline_s_p_o.consumer_release_w_index(stage) + return sm_stats_producer_phase + + @cute.jit + def load( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mQ: cute.Tensor, + mK_paged: cute.Tensor, + mV_paged: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + mPageTable: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + pipeline_q: pipeline.PipelineAsync, + pipeline_kv: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + page_size: Int32, + kv_chunk_size_pages: Int32, + ) -> None: + cute.arch.setmaxregister_decrease(self.num_regs_load) + thr_mma_qk_ld = tiled_mma_qk.get_slice(0) + thr_mma_pv_ld = tiled_mma_pv.get_slice(0) + q_producer_phase = Int32(1) + kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kv_stage) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, head_kv_idx, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx_ld = mRequestIndices[work_idx] + qo_tile_ld = mQoTileIndices[work_idx] + split_idx_ld = mKvTileIndices[work_idx] + seqused_k_ld = mSeqUsedK[batch_idx_ld] + kv_pages_ld = ( + seqused_k_ld + page_size - Int32(1)) // page_size + kv_page_begin_ld = split_idx_ld * kv_chunk_size_pages + kv_page_end_ld = cutlass.min( + kv_pages_ld, kv_page_begin_ld + kv_chunk_size_pages + ) + page_count_ld = kv_page_end_ld - kv_page_begin_ld + block_iter_count_ld = ( + page_count_ld + Int32(1)) & ~Int32(1) + physical_page_v0 = Int32(0) + physical_page_v1 = Int32(0) + + mQ_cur_ld = mQ[None, None, None, batch_idx_ld][ + None, None, head_kv_idx + ] + tiler_gQ_ld = ( + (self.mma_tiler_qk[0] * self.q_stage), + self.head_dim, + ) + gQ_ld = cute.local_tile( + mQ_cur_ld, tiler_gQ_ld, (qo_tile_ld, 0)) + gQ_ld = layout_utils.select( + cute.flat_divide(gQ_ld, (self.mma_tiler_qk[0],)), + mode=[0, 2, 1], + ) + tSgQ_ld = thr_mma_qk_ld.partition_A(gQ_ld) + load_Q_fn_full, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), tSgQ_ld, sQ + ) + mK_cur_ld = mK_paged[None, None, head_kv_idx, None] + gK_ld = cute.local_tile( + mK_cur_ld, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0, None), + ) + tSgK_ld = thr_mma_qk_ld.partition_B(gK_ld) + tKsK_ld, tKgK_ld = cpasync.tma_partition( + tma_atom_K, 0, cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_ld, 0, 3), + ) + mV_cur_ld = mV_paged[None, None, head_kv_idx, None] + gV_ld = cute.local_tile( + mV_cur_ld, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None, None), + ) + tOgV_ld = thr_mma_pv_ld.partition_B(gV_ld) + tVsV_ld, tVgV_ld = cpasync.tma_partition( + tma_atom_V, 0, cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV_ld, 0, 3), + ) + + if block_iter_count_ld > Int32(0): + # Prime K0 before Q; then follow BSA order + # K1, V0, K2, V1, ... + page_idx_ld0 = kv_page_end_ld - Int32(1) + physical_page_v0 = mPageTable[batch_idx_ld, page_idx_ld0] + physical_page_v1 = physical_page_v0 + self.load_KV_physical( + tma_atom_K, + tKgK_ld, + tKsK_ld, + physical_page_v0, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + + self.load_Q( + load_Q_fn_full, + pipeline_q, + Int32(0), + q_producer_phase, + ) + q_producer_phase = q_producer_phase ^ Int32(1) + + if block_iter_count_ld > Int32(0): + if block_iter_count_ld > Int32(1): + page_rel_k1 = cutlass.min( + Int32(1), page_count_ld - Int32(1)) + page_idx_ld1 = kv_page_end_ld - Int32(1) - page_rel_k1 + physical_page_v1 = mPageTable[ + batch_idx_ld, page_idx_ld1] + self.load_KV_physical( + tma_atom_K, + tKgK_ld, + tKsK_ld, + physical_page_v1, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + + if block_iter_count_ld > Int32(2): + for page_rel in cutlass.range( + Int32(0), + block_iter_count_ld - Int32(2), + unroll=1, + ): + page_rel_v_ld = cutlass.min( + page_rel, page_count_ld - Int32(1)) + physical_page_v_ld = physical_page_v0 + if (page_rel & Int32(1)) != Int32(0): + physical_page_v_ld = physical_page_v1 + self.load_KV_physical( + tma_atom_V, + tVgV_ld, + tVsV_ld, + physical_page_v_ld, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + page_rel_k_ld = cutlass.min( + page_rel + Int32(2), + page_count_ld - Int32(1), + ) + page_idx_k_ld = ( + kv_page_end_ld - Int32(1) - page_rel_k_ld) + physical_page_k_ld = mPageTable[ + batch_idx_ld, page_idx_k_ld] + self.load_KV_physical( + tma_atom_K, + tKgK_ld, + tKsK_ld, + physical_page_k_ld, + pipeline_kv, + kv_producer_state, + ) + if (page_rel & Int32(1)) == Int32(0): + physical_page_v0 = physical_page_k_ld + else: + physical_page_v1 = physical_page_k_ld + kv_producer_state.advance() + + page_rel_epi_begin_ld = cutlass.max( + Int32(0), + block_iter_count_ld - Int32(2), + ) + for page_rel_epi in cutlass.range( + page_rel_epi_begin_ld, + block_iter_count_ld, + unroll=1, + ): + page_rel_v_ld = cutlass.min( + page_rel_epi, page_count_ld - Int32(1)) + physical_page_v_ld = physical_page_v0 + if (page_rel_epi & Int32(1)) != Int32(0): + physical_page_v_ld = physical_page_v1 + self.load_KV_physical( + tma_atom_V, + tVgV_ld, + tVsV_ld, + physical_page_v_ld, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + + tile_scheduler.prefetch_next_work() + work_tile = tile_scheduler.consumer_advance() + + pipeline_kv.producer_tail(kv_producer_state) + pipeline_q.producer_acquire_w_index_phase( + Int32(self.q_stage - 1), q_producer_phase) + + @cute.jit + def load_Q( + self, + load_Q_fn: Callable, + pipeline_q: pipeline.PipelineAsync, + stage: Int32, + phase: Int32, + ) -> None: + pipeline_q.producer_acquire_w_index_phase(stage, phase) + load_Q_fn( + src_idx=Int32(0), + dst_idx=stage, + tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(stage), + ) + + @cute.jit + def load_KV_physical( + self, + tma_atom: cute.CopyAtom, + tXgX: cute.Tensor, + tXsX: cute.Tensor, + physical_page: Int32, + pipeline_kv: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + ) -> None: + pipeline_kv.producer_acquire(producer_state) + cute.copy( + tma_atom, + tXgX[(None, 0, physical_page)], + tXsX[(None, producer_state.index)], + tma_bar_ptr=pipeline_kv.producer_get_barrier(producer_state), + ) + +_atten_compile_cache: dict[tuple[object, ...], object] = {} + + +def run_decode_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + request_indices: torch.Tensor, + qo_tile_indices: torch.Tensor, + kv_tile_indices: torch.Tensor, + block_valid_mask: torch.Tensor, + split_counts: torch.Tensor, + o_indptr: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + O_partial: Optional[torch.Tensor], + LSE_partial: Optional[torch.Tensor], + *, + softmax_scale: float, + seqlen_q: int, + page_size: int, + kv_chunk_size_pages: int, + split_kv: bool, + causal: bool, + return_lse: bool = True, + disable_softmax_exp2: bool = False, + O_partial_dummy: Optional[torch.Tensor] = None, + LSE_partial_dummy: Optional[torch.Tensor] = None, +) -> None: + """Launch the SM100 UMMA paged decode attention CUTE DSL kernel. + + qhead_per_kv is derived from input shapes (q.shape[1] // k.shape[1]). + disable_softmax_exp2 toggles the sage-style host flag (decision §1.7); + default False keeps full ex2 emulation. + + ``O_partial_dummy`` / ``LSE_partial_dummy`` let callers pre-allocate the + placeholder buffers for the non-split path, avoiding ~5us of per-call + ``torch.empty`` overhead in tight decoding loops. + """ + + q_dtype = torch2cute_dtype_map[q.dtype] + o_dtype = torch2cute_dtype_map[out.dtype] + qhead_per_kv = q.shape[1] // k.shape[1] + q_tokens_per_group = 128 // int(qhead_per_kv) + write_lse = bool(return_lse) or bool(split_kv) + if int(seqlen_q) != q_tokens_per_group: + raise NotImplementedError( + "decode fp8 currently assumes one full packed-q tile: " + f"seqlen_q must equal {q_tokens_per_group}, got {seqlen_q}" + ) + key = ( + "decode_attention", + q.shape[-1], + q_dtype, + o_dtype, + bool(split_kv), + bool(causal), + int(qhead_per_kv), + int(seqlen_q), + bool(write_lse), + bool(disable_softmax_exp2), + ) + if key not in _atten_compile_cache: + from ....quack.compile_utils import make_fake_tensor + + total_q = cute.sym_int64() + head_q = cute.sym_int64() + num_pages = cute.sym_int64() + head_kv = cute.sym_int64() + batch = cute.sym_int64() + batch_plus_one = cute.sym_int64() + max_pages = cute.sym_int64() + work_capacity = cute.sym_int64() + partial_rows = cute.sym_int64() + partial_rows_flat = cute.sym_int64() + head_dim = int(q.shape[-1]) + kernel = SparseDecodeAttentionForwardSm100( + head_dim=head_dim, + qhead_per_kv=int(qhead_per_kv), + page_size=int(page_size), + split_kv=bool(split_kv), + causal=bool(causal), + write_lse=bool(write_lse), + disable_softmax_exp2=bool(disable_softmax_exp2), + ) + # Always pass non-None fake tensors so the @cute.kernel positional + # arg marshalling stays stable; the kernel only reads these when + # split_kv=True (decision #10 epilogue branch). + fake_O_partial = make_fake_tensor( + Float32, (partial_rows_flat, head_dim), divisibility=4) + fake_LSE_partial = make_fake_tensor( + Float32, (partial_rows, head_q), divisibility=1, leading_dim=1) + # Q is passed as a [B, Sq, Hq, D] view so the kernel can build the same + # PackGQA TMA view used by FA/BSA and issue one full-tile Q TMA. + # O still uses the compact 2D view for the packed-GQA TMA epilogue. + total_q_flat = cute.sym_int64() + _atten_compile_cache[key] = cute.compile( + kernel, + make_fake_tensor( + q_dtype, (batch, int(seqlen_q), head_q, head_dim), + divisibility=16), + make_fake_tensor(q_dtype, (num_pages, head_kv, int(page_size), head_dim), divisibility=16), + make_fake_tensor(q_dtype, (num_pages, head_kv, int(page_size), head_dim), divisibility=16), + make_fake_tensor(Int32, (batch, max_pages), divisibility=1, leading_dim=1), + make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (batch_plus_one,), divisibility=1, leading_dim=0), + make_fake_tensor(o_dtype, (total_q_flat, head_dim), divisibility=128 // o_dtype.width), + make_fake_tensor(Float32, (total_q, head_q), divisibility=1, leading_dim=1), + fake_O_partial, + fake_LSE_partial, + Float32(float(softmax_scale)), + Int32(int(seqlen_q)), + Int32(int(page_size)), + Int32(int(kv_chunk_size_pages)), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + q_4d = q.view( + q.shape[0] // int(seqlen_q), int(seqlen_q), q.shape[1], q.shape[2]) + out_2d = out.view(out.shape[0] * out.shape[1], out.shape[2]) + # Compile keeps non-None fake partial buffers for positional stability + # (see fake_O_partial / fake_LSE_partial above). Runtime callers that + # don't need them (split_kv=False) pass None; allocate small uninitialized + # dummy buffers so the kernel signature still matches without launching + # torch fill kernels. + if O_partial is None: + # Reuse caller-cached dummy when available (e.g. the + # SparseDecodePagedAttentionWrapper plan() pre-allocation), else + # allocate a small placeholder on the fly. + O_partial_kernel = ( + O_partial_dummy + if O_partial_dummy is not None + else torch.empty( + (1, q.shape[2]), dtype=torch.float32, device=q.device) + ) + else: + O_partial_kernel = O_partial.view( + O_partial.shape[0] * O_partial.shape[1], O_partial.shape[2]) + if LSE_partial is None: + LSE_partial = ( + LSE_partial_dummy + if LSE_partial_dummy is not None + else torch.empty( + (1, q.shape[1]), dtype=torch.float32, device=q.device) + ) + with torch.cuda.nvtx.range("Decode_Attention"): + _atten_compile_cache[key]( + q_4d, k, v, page_table, seqused_k, + request_indices, qo_tile_indices, kv_tile_indices, block_valid_mask, + split_counts, o_indptr, + out_2d, lse, O_partial_kernel, LSE_partial, + softmax_scale, seqlen_q, page_size, kv_chunk_size_pages, + ) + + +__all__ = ["SparseDecodeAttentionForwardSm100", "run_decode_attention"] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6bab26c200fff9c62644849b18e55f060fa8783f --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Paged decode split-KV scheduling backed by the precompiled Torch op. + +The CUDA implementation lives in ``csrc/build_decode_schedule.cu`` and is +built ahead of time by kernel-builder. The op returns the schedule arrays +plus a fixed-order scalar summary, which is reassembled into the schedule +dict here. +""" + +from __future__ import annotations + +import torch + +from ....._ops import ops + +# Order of the scalar summary returned by the op; must match +# csrc/build_decode_schedule.cu. +_SCALAR_KEYS = ( + "split_kv", + "cta_tile_q", + "num_q_tiles", + "kv_chunk_size_pages", + "kv_chunk_size_tokens", + "work_count", + "padded_work_count", + "partial_rows", + "max_split_count", + "max_grid_size", + "active_blocks_per_sm", + "num_sms", + "base_cta", +) + + +def build_decode_schedule( + seqused_k: torch.Tensor, + *, + page_size: int, + seqlen_q: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + max_seqlen_k: int, + enable_cuda_graph: bool = False, + max_grid_size: int = 0, + fixed_split_size: int = -1, + disable_split_kv: bool = False, +) -> dict[str, object]: + """GPU-only schedule build: single CUDA kernel produces all schedule + index arrays on device. Only a small summary tensor is D2H'd at the end + so the wrapper can size O_partial, pick the kernel grid, and choose + split/non-split compile path. + + ``max_seqlen_k`` is required as the host-side worst-case bound for + padding the work-tile arrays. + """ + + ( + request_indices, + qo_tile_indices, + kv_tile_indices, + block_valid_mask, + split_counts, + kv_pages, + merge_indptr, + o_indptr, + scalars, + ) = ops.build_decode_schedule( + seqused_k, + int(page_size), + int(seqlen_q), + int(num_qo_heads), + int(num_kv_heads), + int(head_dim), + int(max_seqlen_k), + bool(enable_cuda_graph), + int(max_grid_size), + int(fixed_split_size), + bool(disable_split_kv), + ) + + raw: dict[str, object] = dict(zip(_SCALAR_KEYS, (int(s) for s in scalars))) + raw["split_kv"] = bool(raw["split_kv"]) + raw["request_indices"] = request_indices + raw["qo_tile_indices"] = qo_tile_indices + raw["kv_tile_indices"] = kv_tile_indices + raw["block_valid_mask"] = block_valid_mask + raw["split_counts"] = split_counts + raw["kv_pages"] = kv_pages + raw["merge_indptr"] = merge_indptr + raw["o_indptr"] = o_indptr + + # The CUDA kernel writes into worst-case-padded buffers (size = + # batch * num_q_tiles * max_pages_global) but only the first + # ``padded_work_count`` entries are valid. Downstream consumers + # (tile_scheduler) take grid size from ``request_indices.shape[0]`` + # so we narrow the views to that count; the underlying allocation + # is unchanged so this is a view, no copy. + pad = int(raw["padded_work_count"]) + for key in ( + "request_indices", + "qo_tile_indices", + "kv_tile_indices", + "block_valid_mask", + ): + raw[key] = raw[key].narrow(0, 0, pad) + return raw + + +__all__ = ["build_decode_schedule"] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/combine.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/combine.py new file mode 100644 index 0000000000000000000000000000000000000000..3d308bd26c281e744cc7289b1265d8192c1f39e7 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/combine.py @@ -0,0 +1,680 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""LDGSTS split-KV combine for paged decode attention.""" + +import math +from functools import partial +from typing import Type + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute import FastDivmodDivisor +from cutlass.cute.nvgpu import cpasync + +from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map + + +class SparseDecodeForwardCombine: + """Combine split-KV decode partials with FA-style LDGSTS staging. + + ``mO_partial`` and ``mLSE_partial`` use the split-major padded layout: + ``partial_row = o_indptr[b] + split_idx * q_stride + q_token`` where + ``q_stride = ceil_div(seqlen_q, q_tokens_per_group) * q_tokens_per_group``. + A CTA covers ``tile_m`` flattened ``(q_token, q_head)`` rows and one + ``k_block_size`` slice of D. O_partial and LSE_partial are loaded to SMEM + via ``cpasync.CopyG2SOp`` before the split reduction. + """ + + def __init__( + self, + dtype: Type[cutlass.Numeric], + dtype_partial: Type[cutlass.Numeric], + head_dim: int, + *, + tile_m: int = 64, + k_block_size: int = 128, + max_splits: int = 4, + num_threads: int = 256, + stages: int = 2, + ): + if head_dim != 128: + raise NotImplementedError( + f"SparseDecodeForwardCombine currently supports only D=128, got D={head_dim}" + ) + if dtype not in [cutlass.BFloat16, cutlass.Float16, cutlass.Float32]: + raise TypeError(f"Unsupported output dtype: {dtype}") + if dtype_partial is not Float32: + raise TypeError("decode O_partial must be Float32") + if k_block_size != head_dim: + raise NotImplementedError("decode combine currently uses one D=128 k block") + if tile_m % 8 != 0: + raise ValueError("decode combine tile_m must be divisible by 8") + if max_splits < 1 or max_splits > 256: + raise ValueError("decode combine max_splits must be in [1, 256]") + + self.dtype = dtype + self.dtype_partial = dtype_partial + self.head_dim = head_dim + self.tile_m = tile_m + self.k_block_size = k_block_size + self.max_splits = max_splits + self.num_threads = num_threads + self.stages = stages + self.is_even_k = head_dim % k_block_size == 0 + + def _setup_attributes(self) -> None: + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype_partial.width + assert self.k_block_size % async_copy_elems == 0 + + k_block_gmem = ( + 128 + if self.k_block_size % 128 == 0 + else (64 if self.k_block_size % 64 == 0 else 32) + ) + gmem_threads_per_row = k_block_gmem // async_copy_elems + assert self.num_threads % gmem_threads_per_row == 0 + + atom_async_copy_partial = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype_partial, + num_bits_per_copy=universal_copy_bits, + ) + tOpartial_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + vOpartial_layout = cute.make_layout((1, async_copy_elems)) + self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv( + atom_async_copy_partial, tOpartial_layout, vOpartial_layout + ) + + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=async_copy_elems * self.dtype.width, + ) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, tOpartial_layout, vOpartial_layout + ) + + lse_copy_bits = Float32.width + m_block_smem = ( + 128 + if self.tile_m % 128 == 0 + else ( + 64 + if self.tile_m % 64 == 0 + else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8)) + ) + ) + gmem_threads_per_row_lse = m_block_smem + assert self.num_threads % gmem_threads_per_row_lse == 0 + + atom_async_copy_lse = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + Float32, + num_bits_per_copy=lse_copy_bits, + ) + tLSE_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse), + order=(1, 0), + ) + self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + atom_async_copy_lse, tLSE_layout, cute.make_layout(1) + ) + + self.smem_threads_per_col_lse = self.num_threads // m_block_smem + assert 32 % self.smem_threads_per_col_lse == 0 + s2r_layout_atom_lse = cute.make_ordered_layout( + (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse), + order=(0, 1), + ) + self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), + s2r_layout_atom_lse, + cute.make_layout(1), + ) + + if const_expr(m_block_smem == 8): + smem_lse_swizzle = cute.make_swizzle(5, 0, 5) + elif const_expr(m_block_smem == 16): + smem_lse_swizzle = cute.make_swizzle(4, 0, 4) + else: + smem_lse_swizzle = cute.make_swizzle(3, 2, 3) + lse_atom_splits = min(self.max_splits, 8) + smem_layout_atom_lse = cute.make_composed_layout( + smem_lse_swizzle, + 0, + cute.make_ordered_layout((lse_atom_splits, m_block_smem), order=(1, 0)), + ) + self.smem_layout_lse = cute.tile_to_shape( + smem_layout_atom_lse, (self.max_splits, self.tile_m), (0, 1) + ) + self.smem_layout_o = cute.make_ordered_layout( + (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2) + ) + + @cute.jit + def __call__( + self, + mO_partial: cute.Tensor, # [partial_rows, Hq, D] fp32 + mLSE_partial: cute.Tensor, # [partial_rows, Hq] fp32 + mSplitCounts: cute.Tensor, # [B] int32 + mOIndptr: cute.Tensor, # [B + 1] int32 + mO: cute.Tensor, # [total_q, Hq, D] + mLSE: cute.Tensor, # [total_q, Hq] fp32 + seqlen_q: Int32, + q_tokens_per_group: Int32, + stream: cuda.CUstream = None, + ): + if const_expr(mO_partial.element_type is not Float32): + raise TypeError("decode O_partial tensor must be Float32") + if const_expr(mLSE_partial.element_type is not Float32): + raise TypeError("decode LSE_partial tensor must be Float32") + if const_expr(mLSE.element_type is not Float32): + raise TypeError("decode LSE tensor must be Float32") + if const_expr(mO.element_type != self.dtype): + raise TypeError("decode O tensor dtype must match kernel dtype") + if const_expr(mSplitCounts.element_type is not Int32): + raise TypeError("decode split_counts tensor must be Int32") + if const_expr(mOIndptr.element_type is not Int32): + raise TypeError("decode o_indptr tensor must be Int32") + + mO_partial, mLSE_partial, mSplitCounts, mOIndptr, mO, mLSE = [ + assume_tensor_aligned(t) + for t in (mO_partial, mLSE_partial, mSplitCounts, mOIndptr, mO, mLSE) + ] + self._setup_attributes() + + @cute.struct + class SharedStorage: + sLSE: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sMaxValidSplit: cute.struct.Align[ + cute.struct.MemRange[Int32, self.tile_m], 128 + ] + sO: cute.struct.Align[ + cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 + ] + + total_q = mO.shape[0] + head_q = mO.shape[1] + batch = mSplitCounts.shape[0] + head_divmod = FastDivmodDivisor(head_q) + grid = ( + cute.ceil_div(seqlen_q * head_q, self.tile_m), + cute.ceil_div(self.head_dim, self.k_block_size), + batch, + ) + + self.kernel( + mO_partial, + mLSE_partial, + mSplitCounts, + mOIndptr, + mO, + mLSE, + SharedStorage, + self.smem_layout_lse, + self.smem_layout_o, + self.gmem_tiled_copy_O_partial, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_LSE, + self.s2r_tiled_copy_LSE, + head_divmod, + Int32(total_q), + Int32(head_q), + seqlen_q, + q_tokens_per_group, + ).launch( + grid=grid, + block=[self.num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mSplitCounts: cute.Tensor, + mOIndptr: cute.Tensor, + mO: cute.Tensor, + mLSE: cute.Tensor, + SharedStorage: cutlass.Constexpr, + smem_layout_lse: cute.Layout | cute.ComposedLayout, + smem_layout_o: cute.Layout, + gmem_tiled_copy_O_partial: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + s2r_tiled_copy_LSE: cute.TiledCopy, + head_divmod: FastDivmodDivisor, + total_q: Int32, + head_q: Int32, + seqlen_q: Int32, + q_tokens_per_group: Int32, + ): + tidx, _, _ = cute.arch.thread_idx() + m_block, k_block, batch_idx = cute.arch.block_idx() + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sLSE = storage.sLSE.get_tensor(smem_layout_lse) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,)) + sO = storage.sO.get_tensor(smem_layout_o) + + split_count = mSplitCounts[batch_idx] + q_stride = ( + (seqlen_q + q_tokens_per_group - Int32(1)) + // q_tokens_per_group + ) * q_tokens_per_group + max_idx = seqlen_q * head_q + + if m_block * Int32(self.tile_m) < max_idx: + gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) + tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) + cLSE = cute.make_identity_tensor((self.max_splits, self.tile_m)) + tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) + + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] + idx = m_block * Int32(self.tile_m) + mi + if idx < max_idx: + q_idx, q_head = divmod(idx, head_divmod) + partial_base = mOIndptr[batch_idx] + q_idx + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] + if si < split_count: + partial_row = partial_base + si * q_stride + lse_ptr = ( + mLSE_partial.iterator + + Int64(partial_row) * Int64(head_q) + + Int64(q_head) + ) + lse_gmem_ptr = cute.make_ptr( + Float32, + lse_ptr.toint(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + lse_src = cute.make_tensor(lse_gmem_ptr, (1,)) + cute.copy( + gmem_thr_copy_LSE, + lse_src, + tLSEsLSE[None, s, m], + ) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + else: + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + tLSEsLSE[None, s, m].fill(-Float32.inf) + cute.arch.cp_async_commit_group() + + gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) + cO = cute.make_identity_tensor((self.tile_m, self.k_block_size)) + tOcO = gmem_thr_copy_O_partial.partition_D(cO) + tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) + + num_rows = const_expr(cute.size(tOcO, mode=[1])) + tOqidx = cute.make_rmem_tensor(num_rows, Int32) + tOhidx = cute.make_rmem_tensor(num_rows, Int32) + for m in cutlass.range(num_rows, unroll_full=True): + mi = tOcO[0, m, 0][0] + idx = m_block * Int32(self.tile_m) + mi + if idx >= max_idx: + tOqidx[m] = Int32(0) + tOhidx[m] = -Int32(1) + else: + tOqidx[m], tOhidx[m] = divmod(idx, head_divmod) + + load_O_partial = partial( + self.load_O_partial, + mO_partial, + mOIndptr, + gmem_tiled_copy_O_partial, + tOsO_partial, + tOqidx, + tOhidx, + tOcO, + batch_idx, + q_stride, + split_count, + head_q, + k_block, + ) + + for stage in cutlass.range(self.stages - 1, unroll_full=True): + if stage < split_count: + load_O_partial(stage, stage) + cute.arch.cp_async_commit_group() + + cute.arch.cp_async_wait_group(self.stages - 1) + cute.arch.sync_threads() + + s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) + ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) + ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE) + cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) + + lse_sum = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32) + ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) + max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32) + assert cute.size(ts2rrLSE, mode=[0]) == 1 + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_max = cute.arch.warp_reduction_max( + ts2rrLSE[None, None, m] + .load() + .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + threads_in_group=threads_per_col, + ) + max_valid_idx = -Int32(1) + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + if ts2rrLSE[0, s, m] != -Float32.inf: + max_valid_idx = ts2rcLSE[0, s, 0][0] + max_valid_split[m] = cute.arch.warp_reduction_max( + max_valid_idx, threads_in_group=threads_per_col + ) + + lse_max_cur = Float32(0.0) if lse_max == -Float32.inf else lse_max + LOG2_E = Float32(math.log2(math.e)) + lse_sum_cur = Float32(0.0) + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + scale = cute.math.exp2( + (ts2rrLSE[0, s, m] - lse_max_cur) * LOG2_E, + fastmath=True, + ) + lse_sum_cur += scale + ts2rrLSE[0, s, m] = scale + lse_sum_cur = cute.arch.warp_reduction_sum( + lse_sum_cur, threads_in_group=threads_per_col + ) + lse_sum[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max + inv_sum = ( + Float32(0.0) + if (lse_sum_cur == Float32(0.0) or lse_sum_cur != lse_sum_cur) + else cute.arch.rcp_approx(lse_sum_cur) + ) + ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) + cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) + + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == Int32(0): + mi = ts2rcLSE[0, 0, m][1] + if mi < Int32(self.tile_m): + sMaxValidSplit[mi] = max_valid_split[m] + + if k_block == Int32(0): + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == Int32(0): + mi = ts2rcLSE[0, 0, m][1] + idx = m_block * Int32(self.tile_m) + mi + if idx < max_idx: + q_idx, q_head = divmod(idx, head_divmod) + q_abs = batch_idx * seqlen_q + q_idx + mLSE[q_abs, q_head] = lse_sum[m] + + cute.arch.sync_threads() + + thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] + for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True): + thr_max_valid_split = max( + thr_max_valid_split, + sMaxValidSplit[tOcO[0, m, 0][0]], + ) + + tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32) + tOrO.fill(Float32(0.0)) + + stage_load = self.stages - 1 + stage_compute = 0 + for s in cutlass.range(thr_max_valid_split + Int32(1), unroll=4): + scale = cute.make_rmem_tensor(num_rows, Float32) + for m in cutlass.range(num_rows, unroll_full=True): + scale[m] = sLSE[s, tOcO[0, m, 0][0]] + + split_to_load = s + Int32(self.stages - 1) + if split_to_load <= thr_max_valid_split: + load_O_partial(split_to_load, stage_load) + cute.arch.cp_async_commit_group() + stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1 + + cute.arch.cp_async_wait_group(self.stages - 1) + cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial) + stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1 + + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= Int32(0) and scale[m] > Float32(0.0): + tOrO[None, m, None].store( + tOrO[None, m, None].load() + + scale[m] * tOrO_partial[None, m, None].load().to(Float32) + ) + + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + + rO = cute.make_rmem_tensor_like(tOrO, self.dtype) + rO.store(tOrO.load().to(self.dtype)) + elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1])) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= Int32(0): + q_abs = batch_idx * seqlen_q + tOqidx[m] + row_ptr = ( + mO.iterator + + ( + (Int64(q_abs) * Int64(head_q) + Int64(tOhidx[m])) + * Int64(self.head_dim) + + Int64(k_block * Int32(self.k_block_size)) + ) + ) + row_gmem_ptr = cute.make_ptr( + mO.element_type, + row_ptr.toint(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + mO_row = cute.make_tensor( + row_gmem_ptr, + cute.make_layout((self.k_block_size,)), + ) + mO_row_copy = cute.tiled_divide(mO_row, (elems_per_store,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_store + cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_row_copy[None, k_idx]) + + @cute.jit + def load_O_partial( + self, + mO_partial: cute.Tensor, + mOIndptr: cute.Tensor, + gmem_tiled_copy_O_partial: cute.TiledCopy, + tOsO_partial: cute.Tensor, + tOqidx: cute.Tensor, + tOhidx: cute.Tensor, + tOcO: cute.Tensor, + batch_idx: Int32, + q_stride: Int32, + split_count: Int32, + head_q: Int32, + k_block: Int32, + split: Int32, + stage: Int32, + ) -> None: + elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1])) + tOsO_partial_cur = tOsO_partial[None, None, None, stage] + for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True): + if tOhidx[m] >= Int32(0): + if split < split_count: + partial_row = mOIndptr[batch_idx] + split * q_stride + tOqidx[m] + row_ptr = ( + mO_partial.iterator + + ( + (Int64(partial_row) * Int64(head_q) + Int64(tOhidx[m])) + * Int64(self.head_dim) + + Int64(k_block * Int32(self.k_block_size)) + ) + ) + row_gmem_ptr = cute.make_ptr( + mO_partial.element_type, + row_ptr.toint(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + mO_partial_row = cute.make_tensor( + row_gmem_ptr, + cute.make_layout((self.k_block_size,)), + ) + mO_partial_row_copy = cute.tiled_divide( + mO_partial_row, (elems_per_load,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_load + cute.copy( + gmem_tiled_copy_O_partial, + mO_partial_row_copy[None, k_idx], + tOsO_partial_cur[None, m, k], + ) + else: + tOsO_partial_cur[None, m, None].fill(Float32(0.0)) + + +_combine_compile_cache: dict[tuple[object, ...], object] = {} + + +def _next_power_of_2(x: int) -> int: + return 1 << (max(int(x), 1) - 1).bit_length() + + +def run_decode_combine( + O_partial: torch.Tensor, + LSE_partial: torch.Tensor, + split_counts: torch.Tensor, + o_indptr: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + *, + seqlen_q: int, + q_tokens_per_group: int, + max_split_count: int, +) -> None: + """Launch LDGSTS decode split-KV combine.""" + + if O_partial.dtype != torch.float32: + raise TypeError(f"O_partial must be torch.float32, got {O_partial.dtype}") + if LSE_partial.dtype != torch.float32: + raise TypeError(f"LSE_partial must be torch.float32, got {LSE_partial.dtype}") + if lse.dtype != torch.float32: + raise TypeError(f"lse must be torch.float32, got {lse.dtype}") + if split_counts.dtype != torch.int32: + raise TypeError(f"split_counts must be torch.int32, got {split_counts.dtype}") + if o_indptr.dtype != torch.int32: + raise TypeError(f"o_indptr must be torch.int32, got {o_indptr.dtype}") + if out.ndim != 3 or O_partial.ndim != 3: + raise ValueError("decode combine expects O tensors with shape [rows, heads, D]") + if LSE_partial.ndim != 2 or lse.ndim != 2: + raise ValueError("decode combine expects LSE tensors with shape [rows, heads]") + if out.shape[1:] != O_partial.shape[1:]: + raise ValueError(f"O shape mismatch: out={out.shape}, O_partial={O_partial.shape}") + if lse.shape != out.shape[:2]: + raise ValueError(f"lse shape {lse.shape} must match out[:2] {out.shape[:2]}") + if LSE_partial.shape != O_partial.shape[:2]: + raise ValueError( + f"LSE_partial shape {LSE_partial.shape} must match O_partial[:2] {O_partial.shape[:2]}" + ) + if split_counts.ndim != 1 or o_indptr.ndim != 1: + raise ValueError("split_counts and o_indptr must be rank-1 tensors") + if o_indptr.shape != (split_counts.shape[0] + 1,): + raise ValueError( + f"o_indptr shape {o_indptr.shape} must be ({split_counts.shape[0] + 1},)" + ) + seqlen_q = int(seqlen_q) + q_tokens_per_group = int(q_tokens_per_group) + if seqlen_q <= 0: + raise ValueError("seqlen_q must be positive") + if q_tokens_per_group <= 0: + raise ValueError("q_tokens_per_group must be positive") + if out.shape[0] != split_counts.shape[0] * seqlen_q: + raise ValueError( + f"out rows {out.shape[0]} must equal batch*seqlen_q " + f"{split_counts.shape[0]}*{seqlen_q}" + ) + + max_split_count = int(max_split_count) + if max_split_count <= 0: + raise ValueError("max_split_count must be positive") + if max_split_count > 256: + raise NotImplementedError( + f"LDGSTS decode combine supports at most 256 splits, got {max_split_count}" + ) + max_splits = max(4, _next_power_of_2(max_split_count)) + tile_m = 64 + k_block_size = int(out.shape[-1]) + stages = 2 + + dtype = torch2cute_dtype_map[out.dtype] + key = ( + "decode_combine_ldgsts", + out.shape[-1], + dtype, + O_partial.dtype, + seqlen_q, + q_tokens_per_group, + tile_m, + k_block_size, + max_splits, + stages, + ) + if key not in _combine_compile_cache: + from ....quack.compile_utils import make_fake_tensor + + total_q = cute.sym_int64() + batch = cute.sym_int64() + batch_plus_one = cute.sym_int64() + partial_rows = cute.sym_int64() + head_q = cute.sym_int64() + head_dim = int(out.shape[-1]) + kernel = SparseDecodeForwardCombine( + dtype=dtype, + dtype_partial=Float32, + head_dim=head_dim, + tile_m=tile_m, + k_block_size=k_block_size, + max_splits=max_splits, + stages=stages, + ) + _combine_compile_cache[key] = cute.compile( + kernel, + make_fake_tensor(Float32, (partial_rows, head_q, head_dim), divisibility=4), + make_fake_tensor(Float32, (partial_rows, head_q), divisibility=1, leading_dim=1), + make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (batch_plus_one,), divisibility=1, leading_dim=0), + make_fake_tensor(dtype, (total_q, head_q, head_dim), divisibility=128 // dtype.width), + make_fake_tensor(Float32, (total_q, head_q), divisibility=1, leading_dim=1), + Int32(seqlen_q), + Int32(q_tokens_per_group), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + with torch.cuda.nvtx.range("Decode_Combine_LDGSTS"): + _combine_compile_cache[key]( + O_partial, + LSE_partial, + split_counts, + o_indptr, + out, + lse, + seqlen_q, + q_tokens_per_group, + ) + + +__all__ = ["SparseDecodeForwardCombine", "run_decode_combine"] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..13b487402bf52d008b7ff7edbe9d584f366256b9 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py @@ -0,0 +1,300 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Decode-specific tile scheduler for paged fp8 attention. + +The pre-schedule step builds a dense worklist over decode KV chunks. Static +persistent scheduling walks a flattened ``(work_idx, head_kv_idx)`` task id. +CLC scheduling keeps BSA's hardware grid shape, ``(work_idx, head_kv_idx, 1)``, +and maps the canceled CTA coordinate back to the same logical task space. +""" + +from dataclasses import dataclass +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr +from cutlass.cute import FastDivmodDivisor + +from ....quack.cute_dsl_utils import ParamsBase + +from ....src.common.tile_scheduler import SchedulingMode, WorkTileInfo + + +@dataclass +class DecodeTileSchedulerArguments(ParamsBase): + work_capacity: Int32 + num_heads_kv: Int32 + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + + +class DecodeTileScheduler: + """Persistent scheduler over decode ``(work_idx, head_kv_idx)`` tasks.""" + + @dataclass + class Params(ParamsBase): + work_capacity: Int32 + num_heads_kv: Int32 + num_heads_kv_divmod: FastDivmodDivisor + total_tasks: Int32 + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + + def __init__( + self, + params: Params, + task_idx: Int32, + clc_scheduler=None, + clc_pipeline=None, + clc_consumer_state=None, + clc_response_ptr=None, + *, + loc=None, + ip=None, + ): + self.params = params + self._task_idx = task_idx + self._clc_scheduler = clc_scheduler + self._clc_pipeline = clc_pipeline + self._clc_consumer_state = clc_consumer_state + self._clc_response_ptr = clc_response_ptr + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: DecodeTileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert args.cluster_shape_mn[1] == 1, "Decode scheduler requires cluster N == 1" + total_tasks = args.work_capacity * args.num_heads_kv + return DecodeTileScheduler.Params( + args.work_capacity, + args.num_heads_kv, + FastDivmodDivisor(args.num_heads_kv), + total_tasks, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + ) + + @staticmethod + def _clc_grid_shape(params: Params): + return ( + cute.round_up(params.work_capacity, params.cluster_shape_m), + params.num_heads_kv, + Int32(1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, + clc_response_ptr=None, + *, + loc=None, + ip=None, + ) -> "DecodeTileScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + from cutlass.utils import ( + ClcDynamicPersistentTileScheduler, + ClcDynamicPersistentTileSchedulerParams, + ) + + cutlass_params = ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=DecodeTileScheduler._clc_grid_shape(params), + cluster_shape_mnk=(params.cluster_shape_m, 1, 1), + ) + block_idx = cute.arch.block_idx() + grid_dim = cute.arch.grid_dim() + clc_scheduler = ClcDynamicPersistentTileScheduler.create( + cutlass_params, + block_idx, + grid_dim, + clc_response_ptr, + ) + return DecodeTileScheduler( + params, + block_idx[0], + clc_scheduler, + clc_response_ptr=clc_response_ptr, + loc=loc, + ip=ip, + ) + + if const_expr(params.cluster_shape_m == 1): + task_idx = cute.arch.block_idx()[0] + else: + task_idx = cute.arch.cluster_idx()[0] + return DecodeTileScheduler(params, task_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return DecodeTileScheduler._clc_grid_shape(params) + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m + grid_x = cutlass.min(max_ctas, params.total_tasks * params.cluster_shape_m) + return (grid_x, Int32(1), Int32(1)) + + @cute.jit + def _task_to_work(self, task_idx: Int32, is_valid) -> WorkTileInfo: + work_idx, head_kv_idx = divmod(task_idx, self.params.num_heads_kv_divmod) + return WorkTileInfo( + (Int32(work_idx), Int32(head_kv_idx), Int32(0), Int32(0)), + is_valid, + ) + + @cute.jit + def _clc_work_to_coords(self, work) -> WorkTileInfo: + work_idx = work.tile_idx[0] + if const_expr(self.params.cluster_shape_m > 1): + work_idx = work_idx // self.params.cluster_shape_m + return WorkTileInfo( + ( + Int32(work_idx), + Int32(work.tile_idx[1]), + Int32(0), + Int32(0), + ), + work.is_valid_tile, + ) + + @cute.jit + def _clc_response_to_work( + self, + response_stage: Int32, + *, + loc=None, + ip=None, + ) -> WorkTileInfo: + # CLC responses are 16B opaque records. The scheduler warp can query + # the next stage before all consumer warps have read the current one, + # so each pipeline stage needs its own response slot. + response_ptr = self._clc_response_ptr + response_stage * Int32(4) + m_idx, n_idx, l_idx, is_valid = cute.arch.clc_response( + response_ptr, loc=loc, ip=ip) + cute.arch.fence_proxy("async.shared", space="cta") + cta_idx_in_cluster = cute.arch.block_idx()[0] % Int32( + self.params.cluster_shape_m) + return WorkTileInfo( + ( + Int32(m_idx) + cta_idx_in_cluster, + Int32(n_idx), + Int32(l_idx), + Int32(0), + ), + is_valid, + ) + + @cute.jit + def get_current_work( + self, + response_stage: Int32 = Int32(0), + *, + loc=None, + ip=None, + ) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self._clc_response_to_work( + response_stage, loc=loc, ip=ip) + self._task_idx = ( + work.tile_idx[0] * self.params.num_heads_kv + + work.tile_idx[1] + ) + return self._clc_work_to_coords(work) + is_valid = self._task_idx < self.params.total_tasks + return self._task_to_work(self._task_idx, is_valid) + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self._clc_scheduler.initial_work_tile_info() + self._task_idx = ( + work.tile_idx[0] * self.params.num_heads_kv + + work.tile_idx[1] + ) + return self._clc_work_to_coords(work) + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work( + self, + *, + loc=None, + ip=None, + mbarrier_addr=None, + response_stage: Int32 = Int32(0), + ): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + assert mbarrier_addr is not None + response_ptr = self._clc_response_ptr + response_stage * Int32(4) + with cute.arch.elect_one(): + cute.arch.issue_clc_query( + mbarrier_addr, response_ptr, loc=loc, ip=ip) + else: + assert mbarrier_addr is None + if const_expr(self.params.cluster_shape_m == 1): + self._task_idx += cute.arch.grid_dim()[0] + else: + self._task_idx += cute.arch.cluster_dim()[0] + + def consumer_advance(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + response_stage = self._clc_consumer_state.index + self._clc_pipeline.consumer_wait(self._clc_consumer_state) + work_tile = self.get_current_work(response_stage=response_stage) + self._clc_pipeline.consumer_release(self._clc_consumer_state) + self._clc_consumer_state.advance() + return work_tile + self.advance_to_next_work() + return self.get_current_work() + + def set_clc_pipeline(self, clc_pipeline, clc_consumer_state): + self._clc_pipeline = clc_pipeline + self._clc_consumer_state = clc_consumer_state + + def producer_tail(self, *, loc=None, ip=None): + pass + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + objs = [self.params, self._task_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [ + self._clc_scheduler, + self._clc_pipeline, + self._clc_consumer_state, + self._clc_response_ptr, + ] + for obj in objs: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + objs = [self.params, self._task_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [ + self._clc_scheduler, + self._clc_pipeline, + self._clc_consumer_state, + self._clc_response_ptr, + ] + for obj, n_items in zip(objs, self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return DecodeTileScheduler(*obj_list, loc=self._loc) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/prepare_k2q_csr.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/prepare_k2q_csr.py new file mode 100644 index 0000000000000000000000000000000000000000..8e59b3d55bd3e9b164dac1e474dd648501c1aa51 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/prepare_k2q_csr.py @@ -0,0 +1,227 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Sparse k2q CSR builder for SM100. + +Thin dispatcher that calls the CUDA C++ kernel pipeline in +``src.sm100.build_k2q_csr``. Supports ``topK in {4, 8, 16, 32}`` and +``blk_kv == 128`` only — other shapes raise ``ValueError`` rather than +silently falling back to a torch-reference path. +""" + +from __future__ import annotations + +from typing import Optional + +import torch + +from ...src.sm100.prepare_scheduler import SparseAttentionSchedule, SPARSE_SCHEDULE_MODEL + + +_SUPPORTED_TOPK = (4, 8, 16, 32) +_SUPPORTED_BLK_KV = 128 + + +def _ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +class SparseK2qCsrBuilderSm100: + """Build the k2q CSR reverse index for sparse attention on SM100. + + The public API matches the historical CUTE DSL builder so callers + (``sparse_index_utils.build_k2q_csr``, attention kernels) need no + changes. Internally the kernel pipeline runs five CUDA C++ kernels: + ``build_row_map`` -> ``hist`` -> ``row_prefix`` -> ``tile_prefix_smem`` + -> ``scatter`` (5 kernels + 2 ``cudaMemsetAsync``). + """ + + def __init__(self) -> None: + # No persistent state — the JIT-compiled extension is loaded + # lazily by ``src.sm100.build_k2q_csr`` on first call. + self._run = None + self._run_with_schedule = None + + def _ensure_loaded(self) -> None: + if self._run is None: + from ...src.sm100.build_k2q_csr import ( + run_build_k2q_csr, + run_build_k2q_csr_with_schedule, + ) + self._run = run_build_k2q_csr + self._run_with_schedule = run_build_k2q_csr_with_schedule + + def __call__( + self, + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + *, + total_k: int, + blk_kv: int = 128, + max_seqlen_k: Optional[int] = None, + max_seqlen_q: Optional[int] = None, + total_rows: Optional[int] = None, + qhead_per_kv: int = 1, + return_schedule: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, SparseAttentionSchedule]: + # ---- Validation ---------------------------------------------------- + if blk_kv != _SUPPORTED_BLK_KV: + raise ValueError( + f"SparseK2qCsrBuilderSm100 only supports blk_kv == " + f"{_SUPPORTED_BLK_KV}, got {blk_kv}" + ) + if q2k_indices.dtype != torch.int32: + raise TypeError( + f"q2k_indices must be torch.int32, got {q2k_indices.dtype}" + ) + if q2k_indices.ndim != 3: + raise ValueError( + f"q2k_indices must be rank-3 [head_kv, total_q, topK], " + f"got shape {tuple(q2k_indices.shape)}" + ) + if not q2k_indices.is_contiguous(): + raise ValueError("q2k_indices must be contiguous") + if cu_seqlens_q.dtype != torch.int32 or cu_seqlens_k.dtype != torch.int32: + raise TypeError("cu_seqlens_q and cu_seqlens_k must be torch.int32") + if cu_seqlens_q.ndim != 1 or cu_seqlens_k.ndim != 1: + raise ValueError("cu_seqlens_q and cu_seqlens_k must be rank-1") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError( + "cu_seqlens_q and cu_seqlens_k must share shape [B + 1]" + ) + if not (q2k_indices.is_cuda and cu_seqlens_q.is_cuda and cu_seqlens_k.is_cuda): + raise ValueError("all inputs must be CUDA tensors") + if ( + q2k_indices.device != cu_seqlens_q.device + or q2k_indices.device != cu_seqlens_k.device + ): + raise ValueError("all inputs must share a device") + if not cu_seqlens_q.is_contiguous() or not cu_seqlens_k.is_contiguous(): + raise ValueError("cu_seqlens_q and cu_seqlens_k must be contiguous") + + total_k = int(total_k) + if total_k < 0: + raise ValueError(f"total_k must be non-negative, got {total_k}") + + head_kv, total_q, topk = (int(v) for v in q2k_indices.shape) + if topk not in _SUPPORTED_TOPK: + raise ValueError( + f"SparseK2qCsrBuilderSm100 only supports topK in " + f"{_SUPPORTED_TOPK}, got {topk}" + ) + + batch = int(cu_seqlens_q.shape[0] - 1) + if batch < 0: + raise ValueError("cu_seqlens tensors must have shape [B + 1]") + if return_schedule and max_seqlen_k is None: + raise ValueError("build_k2q_csr requires max_seqlen_k when return_schedule=True") + max_k_tokens = int(max_seqlen_k) if max_seqlen_k is not None else total_k + max_kv_blocks = _ceil_div(max(max_k_tokens, blk_kv), blk_kv) + if total_rows is not None: + total_rows = int(total_rows) + elif total_k % blk_kv == 0: + total_rows = total_k // blk_kv + else: + total_rows = _ceil_div(total_k + batch * (blk_kv - 1), blk_kv) + if total_rows < 0: + raise ValueError(f"total_rows must be non-negative, got {total_rows}") + total_rows = max(total_rows, 0) + nnz_upper_bound = total_q * topk + qhead_per_kv = int(qhead_per_kv) + if qhead_per_kv <= 0: + raise ValueError(f"qhead_per_kv must be positive, got {qhead_per_kv}") + if return_schedule: + if max_seqlen_q is None: + raise ValueError("build_k2q_csr requires max_seqlen_q when return_schedule=True") + max_seqlen_q = int(max_seqlen_q) + + # ---- Output tensors ------------------------------------------------ + device = q2k_indices.device + k2q_row_ptr = torch.empty( + (head_kv, total_rows + 1), dtype=torch.int32, device=device, + ) + k2q_q_indices = torch.empty( + (head_kv, nnz_upper_bound), dtype=torch.int32, device=device, + ) + schedule = None + if return_schedule: + target_q_per_cta = SPARSE_SCHEDULE_MODEL.balanced_target_q_per_cta( + total_q=total_q, + topk=topk, + blk_kv=blk_kv, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + ) + scheduler_metadata_capacity = SPARSE_SCHEDULE_MODEL.flat_schedule_capacity( + total_rows=total_rows, + total_q=total_q, + topk=topk, + head_kv=head_kv, + target_q_per_cta=target_q_per_cta, + ) + scheduler_metadata = torch.empty( + (scheduler_metadata_capacity, 6), dtype=torch.int32, device=device + ) + work_count = torch.empty((1,), dtype=torch.int32, device=device) + qsplit_indices = torch.empty_like(k2q_q_indices) + split_counts = torch.empty( + (total_q, head_kv), dtype=torch.int32, device=device + ) + schedule = SparseAttentionSchedule( + enabled=True, + scheduler_metadata=scheduler_metadata, + work_count=work_count, + qsplit_indices=qsplit_indices, + split_counts=split_counts, + target_q_per_cta=target_q_per_cta, + ) + + # Empty workload short-circuit (the CUDA path also handles this, + # but doing it here saves a JIT load for trivial calls). + if total_rows == 0 or total_q == 0 or head_kv == 0 or topk == 0: + k2q_row_ptr.zero_() + k2q_q_indices.fill_(-1) + if schedule is not None: + schedule.work_count.zero_() + schedule.split_counts.zero_() + return k2q_row_ptr, k2q_q_indices, schedule + return k2q_row_ptr, k2q_q_indices + + self._ensure_loaded() + with torch.cuda.nvtx.range("SparseK2qCsr_Pipeline"): + if schedule is None: + self._run( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + k2q_row_ptr, + k2q_q_indices, + topk, + blk_kv, + total_rows, + max_kv_blocks, + ) + else: + self._run_with_schedule( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + k2q_row_ptr, + k2q_q_indices, + schedule.scheduler_metadata, + schedule.work_count, + schedule.qsplit_indices, + schedule.split_counts, + topk, + blk_kv, + total_rows, + max_kv_blocks, + schedule.target_q_per_cta, + schedule.work_capacity, + max_seqlen_q, + ) + if schedule is not None: + return k2q_row_ptr, k2q_q_indices, schedule + return k2q_row_ptr, k2q_q_indices diff --git a/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/prepare_scheduler.py b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/prepare_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..662e48f905249913a381f5d11a3f0c49626e98bd --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/src/sm100/prepare_scheduler.py @@ -0,0 +1,752 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Prepare scheduler for SM100 sparse attention. + +The scheduler converts uneven CSR k2q row fanout into a flat worklist consumed +by sparse attention kernels. Each work item covers a contiguous q-index range +within one (head_kv, csr row) and carries the decoded batch/KV-block coordinate. +""" + +from dataclasses import dataclass +from typing import Optional + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Int32, const_expr + +from ...src.common import copy_utils, utils +from ...src.common.cute_dsl_utils import ( + assume_tensor_aligned, + to_cute_tensor as to_cute_tensor_kvouter, +) + + +_PREPARE_COMPILE_CACHE: dict = {} + + +@dataclass +class SparseAttentionSchedule: + enabled: bool + scheduler_metadata: Optional[torch.Tensor] + work_count: Optional[torch.Tensor] + qsplit_indices: Optional[torch.Tensor] = None + split_counts: Optional[torch.Tensor] = None + target_q_per_cta: int = 0 + + @property + def work_capacity(self) -> int: + return 0 if self.scheduler_metadata is None else int(self.scheduler_metadata.shape[0]) + + +SparseSchedulePlan = SparseAttentionSchedule + + +class SparseAttentionScheduleModel: + """Host-side helpers for sparse attention schedule sizing.""" + + @staticmethod + def _round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + @staticmethod + def _ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + def _target_q_per_cta( + self, + *, + total_q: int, + topk: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + usable_SM_count: int = -1, + ) -> int: + num_sm = torch.cuda.get_device_properties(device).multi_processor_count + if usable_SM_count > 0: + num_sm = min(int(usable_SM_count), num_sm) + q_tokens_per_group = 128 // qhead_per_kv + total_refs_upper = total_q * topk * head_kv + desired_work_items = max(num_sm * 2, 1) + total_groups_upper = self._ceil_div(max(total_refs_upper, 1), q_tokens_per_group) + target_groups_per_cta = min( + 512, + max(1, self._ceil_div(total_groups_upper, desired_work_items)), + ) + return target_groups_per_cta * q_tokens_per_group + + def balanced_target_q_per_cta( + self, + *, + total_q: int, + topk: int, + blk_kv: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + usable_SM_count: int = -1, + ) -> int: + q_tokens_per_group = 128 // qhead_per_kv + occupancy_target = self._target_q_per_cta( + total_q=total_q, + topk=topk, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + usable_SM_count=usable_SM_count, + ) + sink_balance_cap = max(q_tokens_per_group, int(topk) * int(blk_kv) * 2) + target = min(max(occupancy_target, q_tokens_per_group), sink_balance_cap) + return self._round_up(target, q_tokens_per_group) + + def flat_schedule_capacity( + self, + *, + total_rows: int, + total_q: int, + topk: int, + head_kv: int, + target_q_per_cta: int, + ) -> int: + row_upper = max(total_rows, 0) * max(head_kv, 1) + refs_upper = max(total_q, 0) * max(topk, 1) * max(head_kv, 1) + split_upper = self._ceil_div(max(refs_upper, 1), max(target_q_per_cta, 1)) + return max(1, row_upper + split_upper) + + +SPARSE_SCHEDULE_MODEL = SparseAttentionScheduleModel() + + +class SparseAttentionPrepareFlatScheduleSm100: + """Build a compact flat worklist by splitting each CSR row into chunks.""" + + def __init__( + self, + *, + num_threads: int = 128, + ): + if num_threads % 32 != 0: + raise ValueError(f"num_threads must be a multiple of 32, got {num_threads}") + self.num_threads = num_threads + self.warps_per_cta = num_threads // 32 + + @cute.jit + def _emit_work( + self, + mSchedulerMetadata: cute.Tensor, + work_idx: Int32, + work_capacity: Int32, + head_kv_idx: Int32, + row_linear: Int32, + q_begin: Int32, + q_count: Int32, + batch_idx: Int32, + kv_block_idx: Int32, + ): + if work_idx < work_capacity: + mSchedulerMetadata[work_idx, Int32(0)] = head_kv_idx + mSchedulerMetadata[work_idx, Int32(1)] = row_linear + mSchedulerMetadata[work_idx, Int32(2)] = q_begin + mSchedulerMetadata[work_idx, Int32(3)] = q_count + mSchedulerMetadata[work_idx, Int32(4)] = batch_idx + mSchedulerMetadata[work_idx, Int32(5)] = kv_block_idx + + @cute.jit + def _rows_in_batch( + self, + mCuSeqlensK: cute.Tensor, + batch_idx: Int32, + blk_kv: Int32, + ) -> Int32: + seqlen = mCuSeqlensK[batch_idx + Int32(1)] - mCuSeqlensK[batch_idx] + return (seqlen + blk_kv - Int32(1)) // blk_kv + + @cute.jit + def _rows_before_level( + self, + mCuSeqlensK: cute.Tensor, + level: Int32, + blk_kv: Int32, + ) -> Int32: + total = Int32(0) + batch = mCuSeqlensK.shape[0] - Int32(1) + for b in cutlass.range(batch, unroll=1): + rows = self._rows_in_batch(mCuSeqlensK, b, blk_kv) + total += cutlass.min(rows, level) + return total + + @cute.jit + def _max_rows_per_batch( + self, + mCuSeqlensK: cute.Tensor, + blk_kv: Int32, + ) -> Int32: + max_rows = Int32(0) + batch = mCuSeqlensK.shape[0] - Int32(1) + for b in cutlass.range(batch, unroll=1): + rows = self._rows_in_batch(mCuSeqlensK, b, blk_kv) + max_rows = cutlass.max(max_rows, rows) + return max_rows + + @cute.jit + def _decode_sparse_row_linear( + self, + mCuSeqlensK: cute.Tensor, + row_linear: Int32, + blk_kv: Int32, + ) -> tuple[Int32, Int32]: + lo = Int32(0) + hi = self._max_rows_per_batch(mCuSeqlensK, blk_kv) + while lo < hi: + mid = (lo + hi) // Int32(2) + rows_before_next = self._rows_before_level( + mCuSeqlensK, + mid + Int32(1), + blk_kv, + ) + if rows_before_next <= row_linear: + lo = mid + Int32(1) + else: + hi = mid + + level = lo + offset = row_linear - self._rows_before_level(mCuSeqlensK, level, blk_kv) + active_idx = Int32(0) + batch_idx = Int32(0) + found = Int32(0) + batch = mCuSeqlensK.shape[0] - Int32(1) + for b in cutlass.range(batch, unroll=1): + if found == Int32(0): + rows = self._rows_in_batch(mCuSeqlensK, b, blk_kv) + if rows > level: + if active_idx == offset: + batch_idx = b + found = Int32(1) + active_idx += Int32(1) + return batch_idx, level + + @cute.jit + def __call__( + self, + mK2qCounts: cute.Tensor, + mCuSeqlensK: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + target_q_per_cta: Int32, + work_capacity: Int32, + num_heads_kv: Int32, + blk_kv: Int32, + stream: cuda.CUstream = None, + ): + if const_expr(mK2qCounts.element_type != Int32): + raise TypeError("mK2qCounts must be Int32") + if const_expr(mCuSeqlensK.element_type != Int32): + raise TypeError("mCuSeqlensK must be Int32") + if const_expr(mSchedulerMetadata.element_type != Int32): + raise TypeError("mSchedulerMetadata must be Int32") + if const_expr(mWorkCount.element_type != Int32): + raise TypeError("mWorkCount must be Int32") + mK2qCounts, mCuSeqlensK, mSchedulerMetadata, mWorkCount = [ + assume_tensor_aligned(t) + for t in (mK2qCounts, mCuSeqlensK, mSchedulerMetadata, mWorkCount) + ] + total_rows = mK2qCounts.shape[1] - Int32(1) + total_row_heads = total_rows * num_heads_kv + grid_ctas = cute.ceil_div(total_row_heads, self.warps_per_cta) + + self.kernel( + mK2qCounts, + mCuSeqlensK, + mSchedulerMetadata, + mWorkCount, + target_q_per_cta, + work_capacity, + num_heads_kv, + total_rows, + blk_kv, + ).launch( + grid=(grid_ctas,), + block=[self.num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mK2qCounts: cute.Tensor, + mCuSeqlensK: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + target_q_per_cta: Int32, + work_capacity: Int32, + num_heads_kv: Int32, + total_rows: Int32, + blk_kv: Int32, + ): + tidx = cute.arch.thread_idx()[0] + block_idx = cute.arch.block_idx()[0] + lane_idx = tidx % Int32(32) + warp_idx = tidx // Int32(32) + row_head_idx = block_idx * Int32(self.warps_per_cta) + warp_idx + total_row_heads = total_rows * num_heads_kv + + head_kv_idx = Int32(0) + row_linear = Int32(0) + row_count = Int32(0) + num_chunks = Int32(0) + batch_idx = Int32(0) + kv_block_idx = Int32(0) + if row_head_idx < total_row_heads: + row_linear = row_head_idx // num_heads_kv + head_kv_idx = row_head_idx - row_linear * num_heads_kv + if lane_idx == Int32(0): + row_start = mK2qCounts[head_kv_idx, row_linear] + row_end = mK2qCounts[head_kv_idx, row_linear + Int32(1)] + row_count = row_end - row_start + batch_idx, kv_block_idx = self._decode_sparse_row_linear( + mCuSeqlensK, + row_linear, + blk_kv, + ) + if row_count > Int32(0): + num_chunks = ( + row_count + target_q_per_cta - Int32(1) + ) // target_q_per_cta + row_count = cute.arch.shuffle_sync(row_count, offset=0) + num_chunks = cute.arch.shuffle_sync(num_chunks, offset=0) + batch_idx = cute.arch.shuffle_sync(batch_idx, offset=0) + kv_block_idx = cute.arch.shuffle_sync(kv_block_idx, offset=0) + + chunk_idx = lane_idx + while chunk_idx < num_chunks: + work_idx = cute.arch.atomic_add( + mWorkCount.iterator.llvm_ptr, + Int32(1), + sem="relaxed", + scope="gpu", + ) + q_begin = chunk_idx * target_q_per_cta + q_count = cutlass.min(target_q_per_cta, row_count - q_begin) + self._emit_work( + mSchedulerMetadata, + work_idx, + work_capacity, + head_kv_idx, + row_linear, + q_begin, + q_count, + batch_idx, + kv_block_idx, + ) + chunk_idx += Int32(32) + + +class SparseAttentionPrepareFwdSplitAtomicSm100: + """Build packed q_idx/split_slot metadata for fwd K1 without K1 atomics.""" + + def __init__( + self, + *, + num_threads: int = 256, + ): + if num_threads % 32 != 0: + raise ValueError(f"num_threads must be a multiple of 32, got {num_threads}") + self.num_threads = num_threads + + @cute.struct + class SharedStorage: + sRow: cute.struct.MemRange[Int32, 3] + + self.shared_storage = SharedStorage + + @cute.jit + def __call__( + self, + mK2qCounts: cute.Tensor, + mK2qIndices: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mSplitCounts: cute.Tensor, + mCuSeqlensQ: cute.Tensor, + work_capacity: Int32, + max_seqlen_q: Int32, + topk: Int32, + stream: cuda.CUstream = None, + ): + if const_expr(mK2qCounts.element_type != Int32): + raise TypeError("mK2qCounts must be Int32") + if const_expr(mK2qIndices.element_type != Int32): + raise TypeError("mK2qIndices must be Int32") + if const_expr(mSchedulerMetadata.element_type != Int32): + raise TypeError("mSchedulerMetadata must be Int32") + if const_expr(mWorkCount.element_type != Int32): + raise TypeError("mWorkCount must be Int32") + if const_expr(mK2qQSplitIndices.element_type != Int32): + raise TypeError("mK2qQSplitIndices must be Int32") + if const_expr(mSplitCounts.element_type != Int32): + raise TypeError("mSplitCounts must be Int32") + if const_expr(mCuSeqlensQ.element_type != Int32): + raise TypeError("mCuSeqlensQ must be Int32") + ( + mK2qCounts, + mK2qIndices, + mSchedulerMetadata, + mWorkCount, + mK2qQSplitIndices, + mSplitCounts, + mCuSeqlensQ, + ) = [ + assume_tensor_aligned(t) + for t in ( + mK2qCounts, + mK2qIndices, + mSchedulerMetadata, + mWorkCount, + mK2qQSplitIndices, + mSplitCounts, + mCuSeqlensQ, + ) + ] + self.kernel( + mK2qCounts, + mK2qIndices, + mSchedulerMetadata, + mWorkCount, + mK2qQSplitIndices, + mSplitCounts, + mCuSeqlensQ, + max_seqlen_q, + topk, + ).launch( + grid=(work_capacity,), + block=[self.num_threads, 1, 1], + smem=self.shared_storage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mK2qCounts: cute.Tensor, + mK2qIndices: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mSplitCounts: cute.Tensor, + mCuSeqlensQ: cute.Tensor, + max_seqlen_q: Int32, + topk: Int32, + ): + tidx = cute.arch.thread_idx()[0] + block_idx = cute.arch.block_idx()[0] + if block_idx < mWorkCount[Int32(0)]: + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sRow = storage.sRow.get_tensor(cute.make_layout((3,))) + head_kv_idx = mSchedulerMetadata[block_idx, Int32(0)] + row_linear = mSchedulerMetadata[block_idx, Int32(1)] + q_begin = mSchedulerMetadata[block_idx, Int32(2)] + q_count = mSchedulerMetadata[block_idx, Int32(3)] + batch_idx_t0 = mSchedulerMetadata[block_idx, Int32(4)] + + if tidx == Int32(0): + row_start_t0 = mK2qCounts[head_kv_idx, row_linear] + q_begin + sRow[0] = row_start_t0 + sRow[1] = q_count + sRow[2] = batch_idx_t0 + cute.arch.barrier() + row_start = sRow[0] + row_count = sRow[1] + batch_idx = sRow[2] + qi = tidx + while qi < row_count: + edge = row_start + qi + q_idx = mK2qIndices[head_kv_idx, edge] + if q_idx >= Int32(0) and q_idx < max_seqlen_q: + q_abs = mCuSeqlensQ[batch_idx] + q_idx + split_ptr = utils.elem_pointer( + mSplitCounts, + (q_abs, head_kv_idx), + ) + split_slot = copy_utils.atomic_add_i32(split_ptr) + if split_slot < topk: + mK2qQSplitIndices[head_kv_idx, edge] = ( + q_idx | ((split_slot & Int32(0xFF)) << Int32(24)) + ) + qi += Int32(self.num_threads) + + +def _get_sparse_prepare_fwd_split_atomic( + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + scheduler_metadata: torch.Tensor, + work_count: torch.Tensor, + k2q_qsplit_indices: torch.Tensor, + split_counts: torch.Tensor, + cu_seqlens_q: torch.Tensor, + work_capacity: int, + max_seqlen_q: int, + topk: int, +): + key = ( + "sparse_prepare_fwd_split_atomic_sm100_csr_varlen", + ) + if key not in _PREPARE_COMPILE_CACHE: + from ...src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _PREPARE_COMPILE_CACHE[key] = loaded + else: + kernel = SparseAttentionPrepareFwdSplitAtomicSm100() + _PREPARE_COMPILE_CACHE[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k2q_row_ptr), + to_cute_tensor_kvouter(k2q_q_indices), + to_cute_tensor_kvouter(scheduler_metadata), + to_cute_tensor_kvouter(work_count), + to_cute_tensor_kvouter(k2q_qsplit_indices), + to_cute_tensor_kvouter(split_counts), + to_cute_tensor_kvouter(cu_seqlens_q), + Int32(work_capacity), + Int32(max_seqlen_q), + Int32(topk), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _PREPARE_COMPILE_CACHE[key]) + return _PREPARE_COMPILE_CACHE[key] + + +def _get_sparse_prepare_flat_schedule( + k2q_row_ptr: torch.Tensor, + cu_seqlens_k: torch.Tensor, + scheduler_metadata: torch.Tensor, + work_count: torch.Tensor, + target_q_per_cta: int, + scheduler_metadata_capacity: int, + head_kv: int, + blk_kv: int, +): + key = ( + "sparse_prepare_flat_schedule_sm100_csr_varlen", + ) + if key not in _PREPARE_COMPILE_CACHE: + from ...src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _PREPARE_COMPILE_CACHE[key] = loaded + else: + kernel = SparseAttentionPrepareFlatScheduleSm100() + _PREPARE_COMPILE_CACHE[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k2q_row_ptr), + to_cute_tensor_kvouter(cu_seqlens_k), + to_cute_tensor_kvouter(scheduler_metadata), + to_cute_tensor_kvouter(work_count), + Int32(target_q_per_cta), + Int32(scheduler_metadata_capacity), + Int32(head_kv), + Int32(blk_kv), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _PREPARE_COMPILE_CACHE[key]) + return _PREPARE_COMPILE_CACHE[key] + + +def prepare_sparse_flat_schedule( + *, + k2q_row_ptr: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_q: int, + topk: int, + blk_kv: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + enabled: bool, + usable_SM_count: int = -1, +) -> SparseSchedulePlan: + if not enabled: + return SparseSchedulePlan(enabled=False, scheduler_metadata=None, work_count=None) + + total_rows = int(k2q_row_ptr.shape[1] - 1) + if total_rows <= 0 or head_kv <= 0: + return SparseSchedulePlan(enabled=False, scheduler_metadata=None, work_count=None) + if cu_seqlens_k.dtype != torch.int32: + raise TypeError("cu_seqlens_k must be torch.int32") + + target_q_per_cta = SPARSE_SCHEDULE_MODEL.balanced_target_q_per_cta( + total_q=total_q, + topk=topk, + blk_kv=blk_kv, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + usable_SM_count=usable_SM_count, + ) + scheduler_metadata_capacity = SPARSE_SCHEDULE_MODEL.flat_schedule_capacity( + total_rows=total_rows, + total_q=total_q, + topk=topk, + head_kv=head_kv, + target_q_per_cta=target_q_per_cta, + ) + scheduler_metadata = torch.empty( + (scheduler_metadata_capacity, 6), + dtype=torch.int32, + device=device, + ) + work_count = torch.zeros((1,), dtype=torch.int32, device=device) + scheduler_metadata.zero_() + + compiled_prepare = _get_sparse_prepare_flat_schedule( + k2q_row_ptr, + cu_seqlens_k, + scheduler_metadata, + work_count, + target_q_per_cta, + scheduler_metadata_capacity, + head_kv, + blk_kv, + ) + with torch.cuda.nvtx.range("SparseAttention_PrepareFlatSchedule"): + compiled_prepare( + k2q_row_ptr, + cu_seqlens_k, + scheduler_metadata, + work_count, + target_q_per_cta, + scheduler_metadata_capacity, + head_kv, + blk_kv, + ) + + return SparseSchedulePlan( + enabled=True, + scheduler_metadata=scheduler_metadata, + work_count=work_count, + target_q_per_cta=target_q_per_cta, + ) + +def prepare_sparse_fwd_schedule_and_split( + *, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + k2q_qsplit_indices: torch.Tensor, + split_counts: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_q: int, + max_seqlen_q: int, + topk: int, + head_kv: int, + qhead_per_kv: int, + blk_kv: int, + device: torch.device, + enabled: bool, + usable_SM_count: int = -1, +) -> SparseSchedulePlan: + plan = prepare_sparse_fwd_schedule( + k2q_row_ptr=k2q_row_ptr, + cu_seqlens_k=cu_seqlens_k, + total_q=total_q, + topk=topk, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + blk_kv=blk_kv, + device=device, + enabled=enabled, + usable_SM_count=usable_SM_count, + ) + if not plan.enabled: + return plan + if plan.scheduler_metadata is None or plan.work_count is None: + raise RuntimeError("fwd GPU schedule requires metadata") + if topk > 255: + raise ValueError(f"packed qsplit metadata supports topK <= 255, got {topk}") + if max_seqlen_q >= (1 << 24): + raise ValueError( + "packed qsplit metadata supports batch-local q_idx < 2^24, " + f"got max_seqlen_q={max_seqlen_q}" + ) + if k2q_qsplit_indices.shape != k2q_q_indices.shape: + raise ValueError("k2q_qsplit_indices shape must match k2q_q_indices") + if split_counts.dtype != torch.int32 or k2q_qsplit_indices.dtype != torch.int32: + raise TypeError("split metadata tensors must be torch.int32") + if split_counts.shape != (total_q, head_kv): + raise ValueError( + f"split_counts must have shape ({total_q}, {head_kv}), got {tuple(split_counts.shape)}" + ) + if cu_seqlens_q.dtype != torch.int32: + raise TypeError("cu_seqlens_q must be torch.int32") + if cu_seqlens_q.ndim != 1 or not cu_seqlens_q.is_contiguous(): + raise ValueError("cu_seqlens_q must be a contiguous rank-1 tensor") + if cu_seqlens_k.dtype != torch.int32: + raise TypeError("cu_seqlens_k must be torch.int32") + + with torch.cuda.nvtx.range("SparseAttention_InitFwdSplitState"): + split_counts.zero_() + + compiled_split = _get_sparse_prepare_fwd_split_atomic( + k2q_row_ptr, + k2q_q_indices, + plan.scheduler_metadata, + plan.work_count, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + plan.work_capacity, + max_seqlen_q, + topk, + ) + with torch.cuda.nvtx.range("SparseAttention_PrepareFwdSplit_Atomic"): + compiled_split( + k2q_row_ptr, + k2q_q_indices, + plan.scheduler_metadata, + plan.work_count, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + plan.work_capacity, + max_seqlen_q, + topk, + ) + plan.qsplit_indices = k2q_qsplit_indices + plan.split_counts = split_counts + return plan + + +def prepare_sparse_fwd_schedule( + *, + k2q_row_ptr: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_q: int, + topk: int, + blk_kv: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + enabled: bool, + usable_SM_count: int = -1, +) -> SparseSchedulePlan: + return prepare_sparse_flat_schedule( + k2q_row_ptr=k2q_row_ptr, + cu_seqlens_k=cu_seqlens_k, + total_q=int(total_q), + topk=int(topk), + blk_kv=int(blk_kv), + head_kv=int(head_kv), + qhead_per_kv=int(qhead_per_kv), + device=device, + enabled=bool(enabled), + usable_SM_count=int(usable_SM_count), + ) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7d5e4ade468de366bb73eed0ccb38d4e358cf8 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""MiniMax Sparse Attention (MSA) CuTe-DSL kernels for NVIDIA SM100. + +Hub-kernel packaging of the CuTe-DSL sparse attention stack from +https://github.com/MiniMax-AI/MSA (``python/fmha_sm100/cute``). The +host-side helper kernels (CSR builder, decode scheduler) are precompiled +Torch ops; the attention kernels are compiled at runtime through +nvidia-cutlass-dsl. +""" + +# Sparse attention forward / decode. +from .interface import ( + SparseDecodePagedAttentionWrapper, + sparse_atten_func, + sparse_atten_nvfp4_kv_func, + sparse_decode_atten_func, +) + +# CSR + schedule construction. +from .sparse_index_utils import build_k2q_csr + +# SM100 fused CSR builder. +from .src.sm100.prepare_k2q_csr import SparseK2qCsrBuilderSm100 + +# FP4 block-score indexer. Returns per-(Hq, kv_block, q) max scores; topK +# selection + q2k construction remain caller-owned downstream steps. +from .fp4_indexer_interface import fp4_indexer_block_scores + +# NVFP4 quantization helpers used to feed the FP4 indexer / NVFP4 attention. +from .quantize import ( + Nvfp4QuantizedTensor, + dequantize_nvfp4_128x4_to_bf16, + nvfp4_global_scale_from_amax, + quantize_bf16_to_nvfp4_128x4, + quantize_kv_bf16_to_nvfp4_128x4, + swizzle_nvfp4_scale_to_128x4, +) + +__version__ = "0.1.1" + +__all__ = [ + # attention + "sparse_atten_func", + "sparse_atten_nvfp4_kv_func", + "sparse_decode_atten_func", + "SparseDecodePagedAttentionWrapper", + # indexing / CSR + "fp4_indexer_block_scores", + "build_k2q_csr", + "SparseK2qCsrBuilderSm100", + # nvfp4 quantization helpers + "Nvfp4QuantizedTensor", + "quantize_bf16_to_nvfp4_128x4", + "quantize_kv_bf16_to_nvfp4_128x4", + "dequantize_nvfp4_128x4_to_bf16", + "swizzle_nvfp4_scale_to_128x4", + "nvfp4_global_scale_from_amax", +] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_msa_cuda_09d7851.abi3.so b/build/torch211-cxx11-cu130-x86_64-linux/_msa_cuda_09d7851.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..1c4b3c9ad43725738ab93fba902ca3d506a807d7 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_msa_cuda_09d7851.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57ddba49bfab6c891491975df75662cfbd54b4aa379af6f75d91c11b90b70b31 +size 1383480 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6be2da4d5d784683e9e2fb8bfe08e93847dc6640 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _msa_cuda_09d7851 +ops = torch.ops._msa_cuda_09d7851 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_msa_cuda_09d7851::{op_name}" diff --git a/build/torch211-cxx11-cu130-x86_64-linux/fp4_indexer_interface.py b/build/torch211-cxx11-cu130-x86_64-linux/fp4_indexer_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..48dc1d05480355d2af4f4e47142ae4cd692184b0 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/fp4_indexer_interface.py @@ -0,0 +1,1061 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Public FP4 sparse-attention indexer block-score interface.""" + +from __future__ import annotations + +from typing import Optional + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Int32 +from cutlass.cute.runtime import make_ptr + +from .src.sm100.fp4_indexer import ( + Fp4FormatSpec, + Fp4IndexerDecodePackedQSm100, + Fp4IndexerDecodeQPackSm100, + Fp4IndexerScaleReorderSm100, + Fp4IndexerStagedMmaSm100, + _BLOCK_K, + _DECODE_K_TILES_PER_CTA, + _DECODE_PACK_Q_LEN, + _DECODE_QHEAD_PER_KV, + _FP4_PACKED_D_BYTES, + _HEAD_DIM, + _MMA_TILER_MN, + _PAGE_SIZE, + ceil_div, + k_tiles_per_cta_for, + normalize_fp4_format, +) + + +_PUBLIC_SCALE_LAYOUT = "public" +_PREORDERED_MMA_SCALE_LAYOUT = "preordered_mma" +_FP4_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +def _device_arch(device: torch.device) -> tuple[int, int]: + major, minor = torch.cuda.get_device_capability(device) + return int(major), int(minor) + + +def _supports_tmem_load_red(device_arch: tuple[int, int]) -> bool: + return device_arch >= (10, 3) + + +def normalize_scale_layout(scale_layout: str) -> str: + """Normalize and validate FP4 indexer scale layout mode. + + Parameters + ---------- + scale_layout : str + Either ``"public"`` for logical scale tensors or ``"preordered_mma"`` + for tensors already laid out with ``fp4_indexer_mma_scale_storage_*``. + + Returns + ------- + str + The normalized scale layout string. + """ + + scale_layout = str(scale_layout) + if scale_layout not in (_PUBLIC_SCALE_LAYOUT, _PREORDERED_MMA_SCALE_LAYOUT): + raise ValueError(f"scale_layout must be 'public' or 'preordered_mma', got {scale_layout!r}") + return scale_layout + + +def _causal_compact_task_count(q_len: int, k_len: int, k_tiles_per_cta: int) -> int: + if q_len <= 0 or k_len <= 0: + return 0 + q_tile_count = ceil_div(q_len, _MMA_TILER_MN[0]) + k_group_count = ceil_div(ceil_div(k_len, _PAGE_SIZE), k_tiles_per_cta) + group_tokens = k_tiles_per_cta * _BLOCK_K + causal_offset = int(k_len) - int(q_len) + tasks = 0 + for q_tile_idx in range(q_tile_count): + q_tile_start = q_tile_idx * _MMA_TILER_MN[0] + q_tile_last = min(q_tile_start + _MMA_TILER_MN[0] - 1, int(q_len) - 1) + visible_limit = q_tile_last + causal_offset + if visible_limit >= 0: + tasks += min(k_group_count, visible_limit // group_tokens + 1) + return tasks + + +def _causal_compact_task_bound(max_q_len: int, max_k_len: int, k_tiles_per_cta: int) -> int: + """Conservative X-grid bound for per-batch causal prefill compact mapping.""" + + if max_q_len <= 0 or max_k_len <= 0: + return 0 + q_tile_count = ceil_div(max_q_len, _MMA_TILER_MN[0]) + candidates = {int(max_q_len)} + for q_tile_idx in range(q_tile_count): + q_len = q_tile_idx * _MMA_TILER_MN[0] + 1 + if q_len <= max_q_len: + candidates.add(q_len) + return max(_causal_compact_task_count(q_len, max_k_len, k_tiles_per_cta) for q_len in candidates) + + +def _require_cuda_tensor(tensor: torch.Tensor, *, name: str) -> None: + if not tensor.is_cuda: + raise ValueError(f"{name} must be a CUDA tensor") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _require_int32_vector(tensor: torch.Tensor, *, name: str, device: torch.device) -> None: + if tensor.device != device: + raise ValueError(f"{name} must be on the same CUDA device") + if tensor.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32") + if tensor.ndim != 1: + raise ValueError(f"{name} must be rank-1") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _require_fp4_packed_dtype(tensor: torch.Tensor, *, name: str) -> None: + fp4_x2_dtype = getattr(torch, "float4_e2m1fn_x2", None) + allowed = {torch.uint8, torch.int8} + if fp4_x2_dtype is not None: + allowed.add(fp4_x2_dtype) + if tensor.dtype not in allowed: + raise TypeError(f"{name} must use packed FP4 storage dtype, got {tensor.dtype}") + + +def _as_fp4_thd_bytes(tensor: torch.Tensor, *, name: str) -> torch.Tensor: + if tensor.ndim != 3: + raise ValueError(f"{name} must have shape [total_q, Hq, 64]") + if int(tensor.shape[-1]) != _FP4_PACKED_D_BYTES: + raise ValueError(f"{name}.shape[-1] must be 64 packed bytes for D=128") + _require_fp4_packed_dtype(tensor, name=name) + if tensor.dtype == torch.uint8: + return tensor + return tensor.view(torch.uint8) + + +def _as_fp4_paged_hnd_bytes(tensor: torch.Tensor, *, name: str) -> torch.Tensor: + if tensor.ndim != 4: + raise ValueError(f"{name} must have shape [total_pages, Hk, 128, 64]") + if int(tensor.shape[-2]) != _PAGE_SIZE: + raise ValueError(f"{name}.shape[-2] must be 128") + if int(tensor.shape[-1]) != _FP4_PACKED_D_BYTES: + raise ValueError(f"{name}.shape[-1] must be 64 packed bytes for D=128") + _require_fp4_packed_dtype(tensor, name=name) + if tensor.dtype == torch.uint8: + return tensor + return tensor.view(torch.uint8) + + +def validate_q_scale_thg( + scale: torch.Tensor, + *, + name: str, + fmt: Fp4FormatSpec, + total_q: int, + heads: int, +) -> None: + """Validate public Q FP4 scale layout ``[total_q, Hq, G]``. + + Parameters + ---------- + scale : torch.Tensor + Logical Q scale tensor. + name : str + Name used in validation error messages. + fmt : Fp4FormatSpec + FP4 format specification from ``normalize_fp4_format``. + total_q : int + Total query token count. + heads : int + Number of Q heads. + """ + + expected = (int(total_q), int(heads), fmt.scale_groups) + if tuple(scale.shape) != expected: + raise ValueError(f"{name} must have shape {expected}, got {tuple(scale.shape)}") + if scale.dtype != fmt.torch_scale_dtype: + raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}") + if not scale.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def validate_k_scale_phsg( + scale: torch.Tensor, + *, + name: str, + fmt: Fp4FormatSpec, + page_count: int, + heads: int, +) -> None: + """Validate public K FP4 scale layout ``[page_count, Hk, 128, G]``. + + Parameters + ---------- + scale : torch.Tensor + Logical K scale tensor. + name : str + Name used in validation error messages. + fmt : Fp4FormatSpec + FP4 format specification from ``normalize_fp4_format``. + page_count : int + Number of physical KV pages. + heads : int + Number of KV heads. + """ + + expected = (int(page_count), int(heads), _PAGE_SIZE, fmt.scale_groups) + if tuple(scale.shape) != expected: + raise ValueError(f"{name} must have shape {expected}, got {tuple(scale.shape)}") + if scale.dtype != fmt.torch_scale_dtype: + raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}") + if not scale.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def fp4_indexer_mma_scale_shape(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return semantic MMA scale view shape ``(32,4,restM,4,restG,L)``.""" + + spec = normalize_fp4_format(fp4_format) + return (32, 4, ceil_div(mn, 128), 4, ceil_div(spec.scale_groups, 4), int(l)) + + +def fp4_indexer_mma_scale_stride(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return element strides for ``fp4_indexer_mma_scale_shape``.""" + + spec = normalize_fp4_format(fp4_format) + rest_m = ceil_div(mn, 128) + rest_g = ceil_div(spec.scale_groups, 4) + return (16, 4, 512 * rest_g, 1, 512, 512 * rest_m * rest_g) + + +def fp4_indexer_mma_scale_storage_shape(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return contiguous storage shape for preordered MMA scales.""" + + spec = normalize_fp4_format(fp4_format) + return (int(l), ceil_div(mn, 128), ceil_div(spec.scale_groups, 4), 32, 4, 4) + + +def fp4_indexer_mma_scale_storage_stride(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return element strides for ``fp4_indexer_mma_scale_storage_shape``.""" + + spec = normalize_fp4_format(fp4_format) + rest_m = ceil_div(mn, 128) + rest_g = ceil_div(spec.scale_groups, 4) + return (512 * rest_m * rest_g, 512 * rest_g, 512, 16, 4, 1) + + +def validate_mma_scale_storage( + scale: torch.Tensor, + *, + name: str, + fmt: Fp4FormatSpec, + mn: int, + l: int, +) -> None: + """Validate preordered MMA scale storage expected by the FP4 indexer. + + Parameters + ---------- + scale : torch.Tensor + Tensor view whose shape/stride should match + ``fp4_indexer_mma_scale_storage_shape`` and + ``fp4_indexer_mma_scale_storage_stride``. + name : str + Name used in validation error messages. + fmt : Fp4FormatSpec + FP4 format specification from ``normalize_fp4_format``. + mn : int + Logical M/N extent of the scale domain. + l : int + Logical batch/head extent folded into the final layout dimension. + """ + + expected_shape = fp4_indexer_mma_scale_storage_shape(mn, l, fp4_format=fmt.name) + expected_stride = fp4_indexer_mma_scale_storage_stride(mn, l, fp4_format=fmt.name) + if tuple(scale.shape) != expected_shape: + raise ValueError(f"{name} must have MMA storage shape {expected_shape}, got {tuple(scale.shape)}") + if tuple(scale.stride()) != expected_stride: + raise ValueError(f"{name} must have MMA storage stride {expected_stride}, got {tuple(scale.stride())}") + if scale.dtype != fmt.torch_scale_dtype: + raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}") + + +def _empty_mma_scale_tensor( + *, + mn: int, + l: int, + spec: Fp4FormatSpec, + device: torch.device, +) -> torch.Tensor: + rest_m = ceil_div(mn, 128) + rest_g = ceil_div(spec.scale_groups, 4) + storage = torch.empty( + (int(l), rest_m, rest_g, 32, 4, 4), + dtype=spec.torch_scale_dtype, + device=device, + ) + return storage.permute(3, 4, 1, 5, 2, 0) + + +def _compile_fp4_scale_reorder_kernel( + *, + fmt: Fp4FormatSpec, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + q_scale_mma_ptr: cute.Pointer, + k_scale_mma_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_scale_reorder_sm100_1cta", + fmt.name, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerScaleReorderSm100(fmt=fmt.name) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_scale_ptr, + k_scale_ptr, + q_scale_mma_ptr, + k_scale_mma_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def fp4_indexer_reorder_scales_for_mma_cute( + q_scale: torch.Tensor, + k_scale: torch.Tensor, + *, + fp4_format: str, +) -> tuple[torch.Tensor, torch.Tensor]: + """Reorder public Q/K FP4 scales to MMA-friendly storage. + + Parameters + ---------- + q_scale : torch.Tensor + Public Q scale tensor with shape ``[total_q, Hq, G]``. + k_scale : torch.Tensor + Public K scale tensor with shape ``[page_count, Hk, 128, G]``. + fp4_format : str + ``"mxfp4"`` or ``"nvfp4"``. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(q_scale_mma, k_scale_mma)`` views in the storage layout validated by + ``validate_mma_scale_storage``. These tensors can be passed to + ``fp4_indexer_block_scores`` with ``scale_layout="preordered_mma"``. + """ + + spec = normalize_fp4_format(fp4_format) + if q_scale.device != k_scale.device: + raise ValueError("q_scale and k_scale must be on the same CUDA device") + _require_cuda_tensor(q_scale, name="q_scale") + _require_cuda_tensor(k_scale, name="k_scale") + if q_scale.ndim != 3: + raise ValueError(f"q_scale must have shape [total_q, Hq, G], got {tuple(q_scale.shape)}") + if k_scale.ndim != 4: + raise ValueError(f"k_scale must have shape [page_count, Hk, 128, G], got {tuple(k_scale.shape)}") + total_q, heads_q, _ = (int(v) for v in q_scale.shape) + page_count, heads_k, _, _ = (int(v) for v in k_scale.shape) + validate_q_scale_thg(q_scale, name="q_scale", fmt=spec, total_q=total_q, heads=heads_q) + validate_k_scale_phsg(k_scale, name="k_scale", fmt=spec, page_count=page_count, heads=heads_k) + + q_scale_mma = _empty_mma_scale_tensor( + mn=total_q, + l=heads_q, + spec=spec, + device=q_scale.device, + ) + k_scale_mma = _empty_mma_scale_tensor( + mn=_PAGE_SIZE, + l=page_count * heads_k, + spec=spec, + device=k_scale.device, + ) + + q_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + q_scale.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + k_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + k_scale.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + q_scale_mma_ptr = make_ptr( + spec.cutlass_scale_dtype, + q_scale_mma.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + k_scale_mma_ptr = make_ptr( + spec.cutlass_scale_dtype, + k_scale_mma.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + problem_size = ( + Int32(total_q), + Int32(heads_q), + Int32(page_count), + Int32(heads_k), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_scale.device).cuda_stream) + compiled = _compile_fp4_scale_reorder_kernel( + fmt=spec, + q_scale_ptr=q_scale_ptr, + k_scale_ptr=k_scale_ptr, + q_scale_mma_ptr=q_scale_mma_ptr, + k_scale_mma_ptr=k_scale_mma_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_scale_ptr, + k_scale_ptr, + q_scale_mma_ptr, + k_scale_mma_ptr, + problem_size, + stream, + ) + return q_scale_mma, k_scale_mma + + +def _compile_fp4_decode_q_pack_kernel( + *, + fmt: Fp4FormatSpec, + q_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + q_pack_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_decode_q_pack_sm100", + fmt.name, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerDecodeQPackSm100(fmt=fmt.name) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_ptr, + q_scale_ptr, + q_pack_ptr, + q_scale_pack_ptr, + cu_seqlens_q_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def _pack_decode_q_for_mma( + q_bytes: torch.Tensor, + q_scale_storage: torch.Tensor, + cu_seqlens_q: torch.Tensor, + *, + fmt: Fp4FormatSpec, + heads_q: int, + heads_k: int, + batch: int, +) -> tuple[torch.Tensor, torch.Tensor]: + q_pack = torch.empty( + (batch * heads_k, _PAGE_SIZE, _FP4_PACKED_D_BYTES), + dtype=torch.uint8, + device=q_bytes.device, + ) + q_scale_pack = torch.empty( + fp4_indexer_mma_scale_storage_shape(_PAGE_SIZE, batch * heads_k, fp4_format=fmt.name), + dtype=fmt.torch_scale_dtype, + device=q_bytes.device, + ) + if q_pack.data_ptr() % 128 != 0: + raise ValueError("internal decode q_pack data pointer must be 128B aligned for TMA") + if q_scale_pack.data_ptr() % 32 != 0: + raise ValueError("internal decode q_scale_pack data pointer must be 32B aligned") + q_ptr = make_ptr(cutlass.Uint8, q_bytes.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + q_scale_ptr = make_ptr( + fmt.cutlass_scale_dtype, + q_scale_storage.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + q_pack_ptr = make_ptr(cutlass.Uint8, q_pack.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + q_scale_pack_ptr = make_ptr( + fmt.cutlass_scale_dtype, + q_scale_pack.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + cu_seqlens_q_ptr = make_ptr( + cutlass.Int32, + cu_seqlens_q.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + problem_size = ( + Int32(q_bytes.shape[0]), + Int32(heads_q), + Int32(heads_k), + Int32(batch), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_bytes.device).cuda_stream) + compiled = _compile_fp4_decode_q_pack_kernel( + fmt=fmt, + q_ptr=q_ptr, + q_scale_ptr=q_scale_ptr, + q_pack_ptr=q_pack_ptr, + q_scale_pack_ptr=q_scale_pack_ptr, + cu_seqlens_q_ptr=cu_seqlens_q_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_ptr, + q_scale_ptr, + q_pack_ptr, + q_scale_pack_ptr, + cu_seqlens_q_ptr, + problem_size, + stream, + ) + return q_pack, q_scale_pack + + +def _compile_fp4_decode_packed_q_kernel( + *, + fmt: Fp4FormatSpec, + causal: bool, + compact_schedule: bool, + device_arch: tuple[int, int], + use_tmem_load_red: bool, + q_pack_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_decode_packed_q_sm100", + fmt.name, + bool(causal), + bool(compact_schedule), + device_arch, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerDecodePackedQSm100( + fmt=fmt.name, + causal=causal, + compact_schedule=compact_schedule, + use_tmem_load_red=use_tmem_load_red, + ) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_pack_ptr, + k_ptr, + q_scale_pack_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def _run_fp4_decode_packed_q_scores( + q_pack: torch.Tensor, + k_bytes: torch.Tensor, + q_scale_pack: torch.Tensor, + k_scale_storage: torch.Tensor, + scores: torch.Tensor, + kv_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_page_offsets: torch.Tensor, + qo_offset_arg: torch.Tensor, + *, + fmt: Fp4FormatSpec, + causal: bool, + has_qo_offset: int, + heads_q: int, + heads_k: int, + batch: int, + max_k_tiles: int, + total_q: int, + device_arch: tuple[int, int], + use_tmem_load_red: bool, +) -> None: + page_count = int(k_bytes.shape[0]) + rectangular_groups = batch * ceil_div(max_k_tiles, _DECODE_K_TILES_PER_CTA) + compact_groups = ceil_div(page_count + batch * (_DECODE_K_TILES_PER_CTA - 1), _DECODE_K_TILES_PER_CTA) + compact_schedule = compact_groups < rectangular_groups + if compact_schedule: + scores.fill_(float("-inf")) + + q_pack_ptr = make_ptr(cutlass.Uint8, q_pack.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + k_ptr = make_ptr(cutlass.Uint8, k_bytes.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + q_scale_pack_ptr = make_ptr( + fmt.cutlass_scale_dtype, + q_scale_pack.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + k_scale_ptr = make_ptr( + fmt.cutlass_scale_dtype, + k_scale_storage.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + scores_ptr = make_ptr(cutlass.Float32, scores.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) + kv_indices_ptr = make_ptr(cutlass.Int32, kv_indices.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + cu_seqlens_q_ptr = make_ptr(cutlass.Int32, cu_seqlens_q.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + cu_seqlens_k_ptr = make_ptr(cutlass.Int32, cu_seqlens_k.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + cu_page_offsets_ptr = make_ptr(cutlass.Int32, cu_page_offsets.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + qo_offset_ptr = make_ptr(cutlass.Int32, qo_offset_arg.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + problem_size = ( + Int32(_PAGE_SIZE), + Int32(max_k_tiles * _PAGE_SIZE), + Int32(_HEAD_DIM), + Int32(batch * heads_k), + Int32(page_count * heads_k), + Int32(heads_q), + Int32(heads_k), + Int32(batch), + Int32(max_k_tiles), + Int32(total_q), + Int32(has_qo_offset), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_pack.device).cuda_stream) + compiled = _compile_fp4_decode_packed_q_kernel( + fmt=fmt, + causal=causal, + compact_schedule=compact_schedule, + device_arch=device_arch, + use_tmem_load_red=use_tmem_load_red, + q_pack_ptr=q_pack_ptr, + k_ptr=k_ptr, + q_scale_pack_ptr=q_scale_pack_ptr, + k_scale_ptr=k_scale_ptr, + scores_ptr=scores_ptr, + kv_indices_ptr=kv_indices_ptr, + cu_seqlens_q_ptr=cu_seqlens_q_ptr, + cu_seqlens_k_ptr=cu_seqlens_k_ptr, + cu_page_offsets_ptr=cu_page_offsets_ptr, + qo_offset_ptr=qo_offset_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_pack_ptr, + k_ptr, + q_scale_pack_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + + +def _compile_fp4_qk_kernel( + *, + fmt: Fp4FormatSpec, + causal: bool, + preordered_q_scale_tma: bool, + compact_schedule: bool, + device_arch: tuple[int, int], + use_tmem_load_red: bool, + q_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_staged_mma_sm100", + fmt.name, + bool(causal), + bool(preordered_q_scale_tma), + bool(compact_schedule), + device_arch, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerStagedMmaSm100( + fmt=fmt.name, + causal=causal, + preordered_q_scale_tma=preordered_q_scale_tma, + compact_schedule=compact_schedule, + use_tmem_load_red=use_tmem_load_red, + ) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_ptr, + k_ptr, + q_scale_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def fp4_indexer_block_scores( + q_fp4: torch.Tensor, + k_fp4: torch.Tensor, + q_scale: torch.Tensor, + k_scale: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_page_offsets: torch.Tensor, + *, + max_seqlen_q: int, + max_seqlen_k: int, + kv_indices: torch.Tensor, + fp4_format: str, + causal: bool = False, + qo_offset: Optional[torch.Tensor] = None, + scale_layout: str = _PREORDERED_MMA_SCALE_LAYOUT, +) -> torch.Tensor: + """Return FP4 QK max scores per 128-token KV page. + + Parameters + ---------- + q_fp4 : torch.Tensor + Packed FP4 Q tensor with shape ``[total_qo_len, Hq, 64]``. The last + dimension stores two FP4 values per byte for logical head dimension + 128. + k_fp4 : torch.Tensor + Packed paged FP4 K tensor with shape ``[total_pages, Hk, 128, 64]``. + q_scale : torch.Tensor + Q scale tensor. With ``scale_layout="public"``, shape is + ``[total_qo_len, Hq, G]``. With ``"preordered_mma"``, use + ``fp4_indexer_reorder_scales_for_mma_cute`` output layout. + k_scale : torch.Tensor + K scale tensor. With ``scale_layout="public"``, shape is + ``[total_pages, Hk, 128, G]``. With ``"preordered_mma"``, use the + preordered MMA scale layout. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + cu_page_offsets : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of per-request + page counts. + max_seqlen_q : int + Maximum Q sequence length. + max_seqlen_k : int + Maximum KV sequence length. + kv_indices : torch.Tensor + Flattened physical page indices with shape ``[sum_pages]`` and dtype + int32. + fp4_format : str + ``"mxfp4"`` or ``"nvfp4"``. + causal : bool, optional + Whether to apply causal masking. + qo_offset : torch.Tensor, optional + Shape ``[batch_size]``, dtype int32. Per-request causal offset. Valid + only when ``causal=True``. + scale_layout : str, optional + ``"public"`` accepts logical public scale tensors and launches a scale + reorder kernel. ``"preordered_mma"`` expects preordered MMA scale + tensors and skips the reorder. + + Returns + ------- + torch.Tensor + Shape ``[Hq, ceil(max_seqlen_k / 128), total_qo_len]``, dtype float32. + Entries beyond the valid KV page range are ``-inf``. + """ + + spec = normalize_fp4_format(fp4_format) + causal = bool(causal) + scale_layout = normalize_scale_layout(scale_layout) + use_preordered_q_scale_tma = int(max_seqlen_q) >= _PAGE_SIZE + q_bytes = _as_fp4_thd_bytes(q_fp4, name="q_fp4") + k_bytes = _as_fp4_paged_hnd_bytes(k_fp4, name="k_fp4") + total_q, heads_q, _ = (int(v) for v in q_bytes.shape) + page_count, heads_k, page_size, _ = (int(v) for v in k_bytes.shape) + if page_size != _PAGE_SIZE: + raise ValueError(f"k_fp4 page_size must be 128, got {page_size}") + if heads_q % heads_k != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads") + _require_cuda_tensor(q_fp4, name="q_fp4") + _require_cuda_tensor(k_fp4, name="k_fp4") + device_arch = _device_arch(q_fp4.device) + use_tmem_load_red = _supports_tmem_load_red(device_arch) + _require_int32_vector(cu_seqlens_q, name="cu_seqlens_q", device=q_fp4.device) + _require_int32_vector(cu_seqlens_k, name="cu_seqlens_k", device=q_fp4.device) + _require_int32_vector(cu_page_offsets, name="cu_page_offsets", device=q_fp4.device) + if q_scale.device != q_fp4.device or k_scale.device != q_fp4.device: + raise ValueError("q_scale and k_scale must be on the same CUDA device as q_fp4") + if scale_layout == _PUBLIC_SCALE_LAYOUT: + validate_q_scale_thg(q_scale, name="q_scale", fmt=spec, total_q=total_q, heads=heads_q) + validate_k_scale_phsg(k_scale, name="k_scale", fmt=spec, page_count=page_count, heads=heads_k) + else: + validate_mma_scale_storage(q_scale, name="q_scale", fmt=spec, mn=total_q, l=heads_q) + validate_mma_scale_storage(k_scale, name="k_scale", fmt=spec, mn=_PAGE_SIZE, l=page_count * heads_k) + batch = int(cu_seqlens_q.shape[0]) - 1 + if batch < 0: + raise ValueError("cu_seqlens_q must have shape [B + 1]") + if cu_seqlens_q.shape != cu_seqlens_k.shape or cu_seqlens_q.shape != cu_page_offsets.shape: + raise ValueError("cu_seqlens_q, cu_seqlens_k, and cu_page_offsets must have shape [B + 1]") + if q_bytes.data_ptr() % 128 != 0: + raise ValueError("q_fp4 data pointer must be 128B aligned for TMA") + if k_bytes.data_ptr() % 128 != 0: + raise ValueError("k_fp4 data pointer must be 128B aligned for TMA") + if kv_indices is None: + raise ValueError("kv_indices is required") + if kv_indices.device != q_fp4.device or kv_indices.dtype != torch.int32 or kv_indices.ndim != 1: + raise ValueError("kv_indices must have shape [sum_pages], dtype torch.int32, and match q_fp4.device") + if not kv_indices.is_contiguous(): + raise ValueError("kv_indices must be contiguous") + if qo_offset is not None: + if not causal: + raise ValueError("qo_offset is only valid when causal=True") + if qo_offset.device != q_fp4.device or qo_offset.dtype != torch.int32 or qo_offset.shape != (batch,): + raise ValueError("qo_offset must have shape [B], dtype torch.int32, and match q_fp4.device") + if not qo_offset.is_contiguous(): + raise ValueError("qo_offset must be contiguous") + + m_extent = int(max_seqlen_q) + max_k_tiles = ceil_div(int(max_seqlen_k), _PAGE_SIZE) + n_aligned = max_k_tiles * _PAGE_SIZE + if max_k_tiles == 0: + return torch.full( + (heads_q, 0, total_q), + float("-inf"), + dtype=torch.float32, + device=q_fp4.device, + ) + + scores = torch.empty( + (heads_q, max_k_tiles, total_q), + dtype=torch.float32, + device=q_fp4.device, + ) + if qo_offset is None: + qo_offset_arg = torch.empty((batch,), dtype=torch.int32, device=q_fp4.device) + has_qo_offset = 0 + else: + qo_offset_arg = qo_offset + has_qo_offset = 1 + if scale_layout == _PUBLIC_SCALE_LAYOUT: + q_scale_arg, k_scale_arg = fp4_indexer_reorder_scales_for_mma_cute( + q_scale, + k_scale, + fp4_format=spec.name, + ) + else: + q_scale_arg = q_scale + k_scale_arg = k_scale + scale_assumed_align = 32 + if q_scale_arg.data_ptr() % scale_assumed_align != 0: + raise ValueError(f"q_scale data pointer must be {scale_assumed_align}B aligned for MMA storage scale") + if k_scale_arg.data_ptr() % scale_assumed_align != 0: + raise ValueError(f"k_scale data pointer must be {scale_assumed_align}B aligned for MMA storage scale") + use_decode_packed_q = int(max_seqlen_q) <= _DECODE_PACK_Q_LEN and heads_q // heads_k == _DECODE_QHEAD_PER_KV + if use_decode_packed_q: + q_pack, q_scale_pack = _pack_decode_q_for_mma( + q_bytes, + q_scale_arg, + cu_seqlens_q, + fmt=spec, + heads_q=heads_q, + heads_k=heads_k, + batch=batch, + ) + _run_fp4_decode_packed_q_scores( + q_pack, + k_bytes, + q_scale_pack, + k_scale_arg, + scores, + kv_indices, + cu_seqlens_q, + cu_seqlens_k, + cu_page_offsets, + qo_offset_arg, + fmt=spec, + causal=causal, + has_qo_offset=has_qo_offset, + heads_q=heads_q, + heads_k=heads_k, + batch=batch, + max_k_tiles=max_k_tiles, + total_q=total_q, + device_arch=device_arch, + use_tmem_load_red=use_tmem_load_red, + ) + return scores + prefill_compact_task_count = 0 + prefill_compact_schedule = False + if causal and has_qo_offset == 0: + k_tiles_per_cta = k_tiles_per_cta_for(causal) + q_tile_count = ceil_div(m_extent, _MMA_TILER_MN[0]) + k_group_count = ceil_div(max_k_tiles, k_tiles_per_cta) + rectangular_task_count = q_tile_count * k_group_count + prefill_compact_task_count = min( + rectangular_task_count, + _causal_compact_task_bound(m_extent, int(max_seqlen_k), k_tiles_per_cta), + ) + prefill_compact_schedule = prefill_compact_task_count * 20 <= rectangular_task_count * 19 + if prefill_compact_schedule: + scores.fill_(float("-inf")) + q_ptr = make_ptr( + cutlass.Uint8, + q_bytes.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=128, + ) + k_ptr = make_ptr( + cutlass.Uint8, + k_bytes.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=128, + ) + q_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + q_scale_arg.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=scale_assumed_align, + ) + k_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + k_scale_arg.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=scale_assumed_align, + ) + scores_ptr = make_ptr( + cutlass.Float32, + scores.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + kv_indices_ptr = make_ptr( + cutlass.Int32, + kv_indices.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + cu_seqlens_q_ptr = make_ptr( + cutlass.Int32, + cu_seqlens_q.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + cu_seqlens_k_ptr = make_ptr( + cutlass.Int32, + cu_seqlens_k.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + cu_page_offsets_ptr = make_ptr( + cutlass.Int32, + cu_page_offsets.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + qo_offset_ptr = make_ptr( + cutlass.Int32, + qo_offset_arg.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + problem_size = ( + Int32(m_extent), + Int32(n_aligned), + Int32(_HEAD_DIM), + Int32(batch * heads_q), + Int32(page_count * heads_k), + Int32(heads_q), + Int32(heads_k), + Int32(batch), + Int32(max_k_tiles), + Int32(total_q), + Int32(has_qo_offset), + Int32(prefill_compact_task_count), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_fp4.device).cuda_stream) + compiled = _compile_fp4_qk_kernel( + fmt=spec, + causal=causal, + preordered_q_scale_tma=use_preordered_q_scale_tma, + compact_schedule=prefill_compact_schedule, + device_arch=device_arch, + use_tmem_load_red=use_tmem_load_red, + q_ptr=q_ptr, + k_ptr=k_ptr, + q_scale_ptr=q_scale_ptr, + k_scale_ptr=k_scale_ptr, + scores_ptr=scores_ptr, + kv_indices_ptr=kv_indices_ptr, + cu_seqlens_q_ptr=cu_seqlens_q_ptr, + cu_seqlens_k_ptr=cu_seqlens_k_ptr, + cu_page_offsets_ptr=cu_page_offsets_ptr, + qo_offset_ptr=qo_offset_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_ptr, + k_ptr, + q_scale_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + return scores + + +__all__ = [ + "fp4_indexer_block_scores", +] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/interface.py b/build/torch211-cxx11-cu130-x86_64-linux/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..9e507961840b3322238646ffffe3e97cf5d13130 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/interface.py @@ -0,0 +1,2011 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Sparse attention interface. + +Current delivery scope: + - head dimension is supported only for D=128 + +Public API: + sparse_atten_func(...) + sparse_decode_atten_func(...) + SparseDecodePagedAttentionWrapper + +Internal forward core: + _sparse_atten_csr_varlen_forward(...) + +Preprocessing (external, done once): + q2k_indices [head_kv, total_q, topK] -> sparse_index_utils.build_k2q_csr() + -> k2q_row_ptr [head_kv, total_rows + 1] int32 + -> k2q_q_indices [head_kv, total_q * topK] int32 +""" + +import math +import os +from typing import Optional + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32 +from cutlass.cute.runtime import from_dlpack + +from .src.sm100.fwd.combine import combine +from .src.sm100.fwd.atten_fwd import SparseAttentionForwardSm100 +from .src.sm100.fwd.atten_fwd_nvfp4_kv import SparseAttentionForwardNvfp4KvSm100 +from .src.sm100.prepare_scheduler import ( + SparseAttentionSchedule, + prepare_sparse_fwd_schedule_and_split, +) +from .src.sm100.decode_schedule import ( + DecodeAttentionSchedule, + prepare_decode_schedule, +) +from .src.common.cute_dsl_utils import to_cute_tensor as to_cute_tensor_kvouter +from .src.common.tma_utils import ( + create_q_gather4_tma_desc, +) + +_compile_cache: dict = {} +_TEMPERATURE_LSE_FAST_PATH_ABS_TOL = 1e-12 +_SUPPORTED_SPARSE_TOPK = (4, 8, 16, 32) +_SUPPORTED_FWD_DTYPES = (torch.bfloat16, torch.float8_e4m3fn) +_SUPPORTED_FWD_MMA_DTYPES = (torch.bfloat16, torch.float8_e4m3fn) +_SUPPORTED_DECODE_QHEAD_PER_KV = 16 + + +def _normalize_partial_dtype(partial_dtype: torch.dtype) -> torch.dtype: + supported = {torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn} + if partial_dtype not in supported: + raise TypeError( + "partial_dtype must be one of torch.float32 / torch.bfloat16 / " + "torch.float16 / torch.float8_e4m3fn, " + f"got {partial_dtype}" + ) + return partial_dtype + + +def _normalize_forward_mma_dtype(dtype: Optional[torch.dtype], fallback: torch.dtype, name: str) -> torch.dtype: + dtype = fallback if dtype is None else dtype + if dtype not in _SUPPORTED_FWD_MMA_DTYPES: + raise TypeError( + f"{name} must be one of torch.bfloat16 / torch.float8_e4m3fn, got {dtype}" + ) + return dtype + + +def _resolve_forward_mma_dtypes( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qk_dtype: Optional[torch.dtype], + pv_dtype: Optional[torch.dtype], +) -> tuple[torch.dtype, torch.dtype]: + qk_dtype = _normalize_forward_mma_dtype(qk_dtype, q.dtype, "qk_dtype") + if pv_dtype is None: + # Preserve the historical fp8 KV-cache path: BF16 Q with FP8 K/V + # stages both K and V as BF16 compute operands. + if ( + q.dtype == torch.bfloat16 + and k.dtype == torch.float8_e4m3fn + and v.dtype == torch.float8_e4m3fn + ): + pv_dtype = torch.bfloat16 + else: + pv_dtype = v.dtype + pv_dtype = _normalize_forward_mma_dtype(pv_dtype, pv_dtype, "pv_dtype") + + if q.dtype != qk_dtype: + raise ValueError( + "qk_dtype must match q storage dtype; Q fp8->bf16 staging is not supported" + ) + if k.dtype != qk_dtype: + if not (k.dtype == torch.float8_e4m3fn and qk_dtype == torch.bfloat16): + raise ValueError( + "unsupported K storage/qk_dtype combination; only fp8 K -> bf16 QK staging is supported" + ) + if v.dtype != pv_dtype: + if not (v.dtype == torch.float8_e4m3fn and pv_dtype == torch.bfloat16): + raise ValueError( + "unsupported V storage/pv_dtype combination; only fp8 V -> bf16 PV staging is supported" + ) + return qk_dtype, pv_dtype + + +def _to_cute_tensor_meta(t: torch.Tensor, assumed_align: int = 4): + tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=True) + return tensor.mark_layout_dynamic(leading_dim=0) + + +def _torch_dtype_to_cutlass_dtype(dtype: torch.dtype): + if dtype == torch.bfloat16: + return cutlass.BFloat16 + if dtype == torch.float16: + return cutlass.Float16 + if dtype == torch.float8_e4m3fn: + return cutlass.Float8E4M3FN + raise TypeError( + f"Only torch.bfloat16, torch.float16, torch.float8_e4m3fn supported, got {dtype}" + ) + + +def _prepare_paged_kv_for_tma(k, v, blk_kv: int): + page_size = int(k.shape[2]) + if page_size != blk_kv: + raise ValueError(f"Sparse Page Attention requires page_size == blk_kv, got {page_size} vs {blk_kv}") + return k, v + + +def _validate_cu_seqlens( + cu_seqlens: torch.Tensor, + *, + name: str, + device: torch.device, +) -> None: + if cu_seqlens.device != device: + raise ValueError(f"{name} must be on the same device as q") + if cu_seqlens.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32") + if cu_seqlens.ndim != 1: + raise ValueError(f"{name} must have shape [B + 1]") + if cu_seqlens.shape[0] < 1: + raise ValueError(f"{name} must have at least one element") + if not cu_seqlens.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _csr_row_capacity(k2q_row_ptr: torch.Tensor) -> int: + return int(k2q_row_ptr.shape[1] - 1) + + +def _validate_csr_varlen_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + blk_kv: int, + page_table: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], +) -> tuple[int, int]: + if q.ndim != 3: + raise ValueError("CSR sparse forward requires q to have shape [total_q, Hq, D]") + if q.dtype not in _SUPPORTED_FWD_DTYPES: + raise TypeError( + "CSR sparse forward supports only torch.bfloat16 and " + f"torch.float8_e4m3fn Q/K/V, got {q.dtype}" + ) + if q.device != k.device or q.device != v.device: + raise ValueError("q, k, v must be on the same device") + mixed_fp8_kv_bf16_q = ( + q.dtype == torch.bfloat16 + and k.dtype == torch.float8_e4m3fn + and v.dtype == torch.float8_e4m3fn + ) + if not mixed_fp8_kv_bf16_q and (q.dtype != k.dtype or q.dtype != v.dtype): + raise ValueError( + "q, k, v must have the same dtype, except q=bf16 with fp8_e4m3 K/V cache" + ) + if q.shape[-1] != k.shape[-1] or q.shape[-1] != v.shape[-1]: + raise ValueError("q, k, v must have the same head dimension") + dim = q.shape[-1] + if dim != 128: + raise NotImplementedError( + f"CSR sparse forward currently supports only D=128, got D={dim}" + ) + if page_table is None: + if k.shape[-2] != v.shape[-2] or k.shape[-1] != v.shape[-1]: + raise ValueError("k and v must have the same [Hkv, D] tail dimensions") + head_kv = k.shape[-2] + else: + if k.ndim != 4 or v.ndim != 4: + raise ValueError( + "Sparse Page Attention requires k and v to have shape " + "[num_pages, Hkv, page_size, D]" + ) + if k.shape[1] != v.shape[1] or k.shape[-1] != v.shape[-1]: + raise ValueError( + "Sparse Page Attention k and v must have the same Hkv and D" + ) + head_kv = k.shape[1] + if ( + q.device != k2q_row_ptr.device + or q.device != k2q_q_indices.device + ): + raise ValueError("CSR metadata must be on the same device as q") + if ( + k2q_row_ptr.dtype != torch.int32 + or k2q_q_indices.dtype != torch.int32 + ): + raise TypeError("CSR metadata tensors must be torch.int32") + if k2q_row_ptr.ndim != 2 or k2q_q_indices.ndim != 2: + raise ValueError("k2q_row_ptr and k2q_q_indices must be rank-2") + + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q", device=q.device) + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k", device=q.device) + if cu_seqlens_k.shape != cu_seqlens_q.shape: + raise ValueError("cu_seqlens_k must have shape [B + 1] matching cu_seqlens_q") + batch = int(cu_seqlens_q.shape[0] - 1) + total_q = q.shape[0] + + head_q = q.shape[1] + if head_q % head_kv != 0: + raise ValueError("q.shape[1] must be divisible by Hkv") + qhead_per_kv = head_q // head_kv + if qhead_per_kv not in (1, 2, 4, 8, 16): + raise NotImplementedError( + "CSR forward is currently supported only for qhead_per_kv in {1, 2, 4, 8, 16}" + ) + if k2q_row_ptr.shape[0] != head_kv or k2q_q_indices.shape[0] != head_kv: + raise ValueError("CSR metadata head dimension must match KV head count") + if k2q_q_indices.shape[1] < total_q * topK: + raise ValueError( + f"k2q_q_indices.shape[1] ({k2q_q_indices.shape[1]}) must be >= total_q * topK ({total_q * topK})" + ) + if k2q_row_ptr.shape[1] < 1: + raise ValueError("k2q_row_ptr must contain at least one row pointer column") + + if page_table is None: + if seqused_k is not None: + raise ValueError("seqused_k is only supported together with page_table") + total_k = k.shape[0] + if k.ndim != 3 or v.ndim != 3: + raise ValueError("Sparse Attention requires k and v to have shape [total_k, Hkv, D]") + if k.shape != (total_k, head_kv, q.shape[-1]) or v.shape != (total_k, head_kv, q.shape[-1]): + raise ValueError("Sparse Attention k and v must match [total_k, Hkv, D]") + else: + if page_table.device != q.device: + raise ValueError("page_table must be on the same device as q") + if page_table.dtype != torch.int32: + raise TypeError("page_table must be torch.int32") + if page_table.ndim != 2 or page_table.shape[0] != batch: + raise ValueError("page_table must have shape [B, max_num_pages_per_seq]") + if page_table.stride(-1) != 1: + raise ValueError("page_table must be contiguous in the last dimension") + if k.ndim != 4 or v.ndim != 4: + raise ValueError( + "Sparse Page Attention requires k and v to have shape " + "[num_pages, Hkv, page_size, D]" + ) + if k.shape != v.shape: + raise ValueError(f"k and v must have the same shape, got {k.shape} and {v.shape}") + if k.shape[1] != head_kv or k.shape[3] != q.shape[-1]: + raise ValueError( + "Sparse Page Attention k and v must match " + "[num_pages, Hkv, page_size, D]" + ) + page_size = int(k.shape[2]) + if page_size != blk_kv: + raise ValueError( + f"Unsupported Sparse Page Attention page_size={page_size} for blk_kv={blk_kv}; " + "require page_size == blk_kv" + ) + if seqused_k is not None: + if seqused_k.device != q.device: + raise ValueError("seqused_k must be on the same device as q") + if seqused_k.dtype != torch.int32: + raise TypeError("seqused_k must be torch.int32") + if seqused_k.shape != (batch,): + raise ValueError("seqused_k must have shape [B]") + if not seqused_k.is_contiguous(): + raise ValueError("seqused_k must be contiguous") + if topK not in _SUPPORTED_SPARSE_TOPK: + raise ValueError( + f"CSR sparse forward supports topK in {_SUPPORTED_SPARSE_TOPK}, got {topK}" + ) + return batch, head_kv + + +def _validate_csr_varlen_nvfp4_kv_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_scale_128x4: torch.Tensor, + v_scale_128x4: torch.Tensor, + k_global_scale: Optional[torch.Tensor], + v_global_scale: Optional[torch.Tensor], + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + blk_kv: int, + page_table: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], +) -> tuple[int, int]: + if q.ndim != 3: + raise ValueError("KVFP4 CSR sparse forward requires q to have shape [total_q, Hq, D]") + if q.dtype not in (torch.bfloat16, torch.float8_e4m3fn): + raise TypeError(f"KVFP4 CSR sparse forward requires BF16 or FP8 E4M3 q, got {q.dtype}") + if q.shape[-1] != 128: + raise NotImplementedError( + f"KVFP4 CSR sparse forward currently supports only D=128, got {q.shape[-1]}" + ) + if k.dtype != torch.uint8 or v.dtype != torch.uint8: + raise TypeError(f"KVFP4 k/v must be torch.uint8, got {k.dtype} and {v.dtype}") + if k_scale_128x4.dtype != torch.uint8 or v_scale_128x4.dtype != torch.uint8: + raise TypeError( + "KVFP4 block scales must be torch.uint8 E4M3 tensors, got " + f"{k_scale_128x4.dtype} and {v_scale_128x4.dtype}" + ) + if k_global_scale is not None and k_global_scale.dtype != torch.float32: + raise TypeError("KVFP4 K global scale must be a torch.float32 tensor or None") + if v_global_scale is not None and v_global_scale.dtype != torch.float32: + raise TypeError("KVFP4 V global scale must be a torch.float32 tensor or None") + tensors = ( + k, + v, + k_scale_128x4, + v_scale_128x4, + k2q_row_ptr, + k2q_q_indices, + cu_seqlens_q, + cu_seqlens_k, + ) + optional_tensors = tuple(t for t in (k_global_scale, v_global_scale) if t is not None) + if any(t.device != q.device for t in tensors + optional_tensors): + raise ValueError("KVFP4 inputs and metadata must be on the same device as q") + if k.shape != v.shape: + raise ValueError(f"KVFP4 k and v must have the same shape, got {k.shape} and {v.shape}") + packed_dim = q.shape[-1] // 2 + scale_cols = q.shape[-1] // 16 + if k_scale_128x4.ndim != 2 or v_scale_128x4.ndim != 2: + raise ValueError("KVFP4 block scales must be rank-2 128x4 tiled tensors") + if k_scale_128x4.shape[1] < scale_cols or v_scale_128x4.shape[1] < scale_cols: + raise ValueError( + "KVFP4 block scales must have at least D/16 columns; " + f"need {scale_cols}, got {k_scale_128x4.shape[1]} and {v_scale_128x4.shape[1]}" + ) + if k_global_scale is not None and k_global_scale.numel() < 1: + raise ValueError("KVFP4 K global scale must contain at least one element") + if v_global_scale is not None and v_global_scale.numel() < 1: + raise ValueError("KVFP4 V global scale must contain at least one element") + + if page_table is None: + if seqused_k is not None: + raise ValueError("seqused_k is only supported together with page_table") + if k.ndim != 3: + raise ValueError("KVFP4 Sparse Attention requires k/v shape [total_k, Hkv, D/2]") + if k.shape[-1] != packed_dim: + raise ValueError(f"KVFP4 packed K/V last dimension must be D/2={packed_dim}") + total_k = int(k.shape[0]) + head_kv = int(k.shape[1]) + required_scale_rows = total_k * head_kv + else: + if k.ndim != 4: + raise ValueError( + "KVFP4 Sparse Page Attention requires k/v shape " + "[num_pages, Hkv, page_size, D/2]" + ) + if k.shape[-1] != packed_dim: + raise ValueError(f"KVFP4 packed K/V last dimension must be D/2={packed_dim}") + page_size = int(k.shape[2]) + if page_size != int(blk_kv): + raise ValueError( + f"KVFP4 Sparse Page Attention requires page_size == blk_kv, got {page_size} vs {blk_kv}" + ) + head_kv = int(k.shape[1]) + required_scale_rows = int(k.shape[0]) * head_kv * page_size + if page_table.device != q.device: + raise ValueError("page_table must be on the same device as q") + if page_table.dtype != torch.int32: + raise TypeError("page_table must be torch.int32") + if page_table.ndim != 2: + raise ValueError("page_table must have shape [B, max_num_pages_per_seq]") + if page_table.stride(-1) != 1: + raise ValueError("page_table must be contiguous in the last dimension") + if seqused_k is not None: + if seqused_k.device != q.device: + raise ValueError("seqused_k must be on the same device as q") + if seqused_k.dtype != torch.int32: + raise TypeError("seqused_k must be torch.int32") + if not seqused_k.is_contiguous(): + raise ValueError("seqused_k must be contiguous") + + padded_scale_rows = ((required_scale_rows + 127) // 128) * 128 + padded_scale_cols = ((scale_cols + 3) // 4) * 4 + for name, scale in (("k_scale_128x4", k_scale_128x4), ("v_scale_128x4", v_scale_128x4)): + if scale.shape[0] < padded_scale_rows or scale.shape[1] < padded_scale_cols: + raise ValueError( + f"{name} is too small for 128x4 layout: got {tuple(scale.shape)}, " + f"need at least {(padded_scale_rows, padded_scale_cols)}" + ) + + if k2q_row_ptr.device != q.device or k2q_q_indices.device != q.device: + raise ValueError("CSR metadata must be on the same device as q") + if k2q_row_ptr.dtype != torch.int32 or k2q_q_indices.dtype != torch.int32: + raise TypeError("CSR metadata tensors must be torch.int32") + if k2q_row_ptr.ndim != 2 or k2q_q_indices.ndim != 2: + raise ValueError("k2q_row_ptr and k2q_q_indices must be rank-2") + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q", device=q.device) + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k", device=q.device) + if cu_seqlens_k.shape != cu_seqlens_q.shape: + raise ValueError("cu_seqlens_k must have shape [B + 1] matching cu_seqlens_q") + batch = int(cu_seqlens_q.shape[0] - 1) + if page_table is not None and page_table.shape[0] != batch: + raise ValueError("page_table must have shape [B, max_num_pages_per_seq]") + if seqused_k is not None and seqused_k.shape != (batch,): + raise ValueError("seqused_k must have shape [B]") + head_q = int(q.shape[1]) + if head_q % head_kv != 0: + raise ValueError("q.shape[1] must be divisible by Hkv") + qhead_per_kv = head_q // head_kv + if qhead_per_kv not in (1, 2, 4, 8, 16): + raise NotImplementedError( + "KVFP4 CSR forward is currently supported only for qhead_per_kv in {1, 2, 4, 8, 16}" + ) + if k2q_row_ptr.shape[0] != head_kv or k2q_q_indices.shape[0] != head_kv: + raise ValueError("CSR metadata head dimension must match KV head count") + if k2q_q_indices.shape[1] < q.shape[0] * topK: + raise ValueError( + f"k2q_q_indices.shape[1] ({k2q_q_indices.shape[1]}) must be >= total_q * topK ({q.shape[0] * topK})" + ) + if k2q_row_ptr.shape[1] < 1: + raise ValueError("k2q_row_ptr must contain at least one row pointer column") + if topK not in _SUPPORTED_SPARSE_TOPK: + raise ValueError( + f"KVFP4 CSR sparse forward supports topK in {_SUPPORTED_SPARSE_TOPK}, got {topK}" + ) + return batch, head_kv + + +def _validate_sparse_decode_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k_indices: Optional[torch.Tensor], + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + blk_kv: int, + causal: bool, +) -> tuple[int, int]: + if q.ndim != 3: + raise ValueError("decode attention requires q to have shape [total_q, Hq, D]") + if k.ndim != 4 or v.ndim != 4: + raise ValueError( + "decode attention requires paged k/v with shape [num_pages, Hkv, page_size, D]" + ) + if q.device != k.device or q.device != v.device: + raise ValueError("decode q, k, and v must be on the same device") + if q.dtype != torch.float8_e4m3fn or k.dtype != q.dtype or v.dtype != q.dtype: + raise TypeError( + "decode attention currently supports only torch.float8_e4m3fn Q/K/V" + ) + if k.shape != v.shape: + raise ValueError(f"decode k and v must have the same shape, got {k.shape} and {v.shape}") + if q.shape[-1] != 128 or k.shape[-1] != 128: + raise NotImplementedError( + f"decode attention currently supports only D=128, got q={q.shape[-1]} k={k.shape[-1]}" + ) + if not bool(causal): + raise NotImplementedError("decode attention currently supports only causal=True") + page_size = int(k.shape[2]) + if page_size != int(blk_kv): + raise ValueError(f"decode attention requires page_size == blk_kv, got {page_size} vs {blk_kv}") + + head_kv = int(k.shape[1]) + head_q = int(q.shape[1]) + if head_q % head_kv != 0: + raise ValueError("decode q.shape[1] must be divisible by Hkv") + qhead_per_kv = head_q // head_kv + if qhead_per_kv != _SUPPORTED_DECODE_QHEAD_PER_KV: + raise NotImplementedError( + "decode attention currently supports only " + f"qhead_per_kv={_SUPPORTED_DECODE_QHEAD_PER_KV}, got {qhead_per_kv}" + ) + + if page_table is None: + raise ValueError("decode attention requires page_table") + if page_table.device != q.device: + raise ValueError("decode page_table must be on the same device as q") + if page_table.dtype != torch.int32: + raise TypeError("decode page_table must be torch.int32") + if page_table.ndim != 2: + raise ValueError("decode page_table must have shape [B, max_num_pages_per_seq]") + batch = int(page_table.shape[0]) + if page_table.stride(-1) != 1: + raise ValueError("decode page_table must be contiguous in the last dimension") + + if seqused_k is None: + raise ValueError("decode attention requires seqused_k") + if seqused_k.device != q.device: + raise ValueError("decode seqused_k must be on the same device as q") + if seqused_k.dtype != torch.int32: + raise TypeError("decode seqused_k must be torch.int32") + if seqused_k.shape != (batch,): + raise ValueError("decode seqused_k must have shape [B]") + if not seqused_k.is_contiguous(): + raise ValueError("decode seqused_k must be contiguous") + + seqlen_q = int(seqlen_q) + max_seqlen_k = int(max_seqlen_k) + if seqlen_q <= 0 or max_seqlen_k <= 0: + raise ValueError("decode seqlen_q and max_seqlen_k must be positive") + if int(q.shape[0]) != batch * seqlen_q: + raise ValueError("decode q.shape[0] must equal batch * seqlen_q") + + if q2k_indices is not None: + if q2k_indices.device != q.device: + raise ValueError("decode q2k_indices must be on the same device as q") + if q2k_indices.dtype != torch.int32: + raise TypeError("decode q2k_indices must be torch.int32") + if q2k_indices.ndim != 3: + raise ValueError("decode q2k_indices must have shape [Hkv, total_q, topK]") + if q2k_indices.shape[0] != head_kv or q2k_indices.shape[1] != q.shape[0]: + raise ValueError("decode q2k_indices must match [Hkv, total_q, topK]") + if not q2k_indices.is_contiguous(): + raise ValueError("decode q2k_indices must be contiguous") + return batch, head_kv + + +def _validate_schedule_common( + schedule: SparseAttentionSchedule, + *, + device: torch.device, +) -> None: + if schedule.scheduler_metadata is None: + raise ValueError("schedule.scheduler_metadata is required") + if schedule.work_count is None: + raise ValueError("schedule.work_count is required") + metadata = schedule.scheduler_metadata + work_count = schedule.work_count + if metadata.device != device or work_count.device != device: + raise ValueError("schedule tensors must be on the same device as q") + if metadata.dtype != torch.int32 or work_count.dtype != torch.int32: + raise TypeError("schedule.scheduler_metadata and schedule.work_count must be torch.int32") + if metadata.ndim != 2 or metadata.shape[1] != 6: + raise ValueError("schedule.scheduler_metadata must have shape [capacity, 6]") + if work_count.shape != (1,): + raise ValueError("schedule.work_count must have shape [1]") + if not metadata.is_contiguous() or not work_count.is_contiguous(): + raise ValueError("schedule.scheduler_metadata and schedule.work_count must be contiguous") + + +def _validate_fwd_schedule( + schedule: SparseAttentionSchedule, + *, + q: torch.Tensor, + k2q_q_indices: torch.Tensor, + head_kv: int, +) -> None: + _validate_schedule_common(schedule, device=q.device) + if schedule.qsplit_indices is None: + raise ValueError("schedule.qsplit_indices is required for forward") + if schedule.split_counts is None: + raise ValueError("schedule.split_counts is required for forward") + qsplit = schedule.qsplit_indices + split_counts = schedule.split_counts + if qsplit.device != q.device or split_counts.device != q.device: + raise ValueError("forward schedule tensors must be on the same device as q") + if qsplit.dtype != torch.int32 or split_counts.dtype != torch.int32: + raise TypeError("schedule.qsplit_indices and schedule.split_counts must be torch.int32") + if qsplit.shape != k2q_q_indices.shape: + raise ValueError("schedule.qsplit_indices shape must match k2q_q_indices") + total_q = q.shape[0] + if split_counts.shape != (total_q, head_kv): + raise ValueError( + "schedule.split_counts must have shape " + f"({total_q}, {head_kv}), got {tuple(split_counts.shape)}" + ) + if not qsplit.is_contiguous() or not split_counts.is_contiguous(): + raise ValueError("schedule.qsplit_indices and schedule.split_counts must be contiguous") + + +def sparse_atten_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + *, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = False, + softmax_scale: Optional[float] = None, + lse_temperature_scale: float = 1.0, + return_temperature_lse: bool = False, + partial_dtype: torch.dtype = torch.bfloat16, + return_softmax_lse: bool = False, + page_table: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + schedule: Optional[SparseAttentionSchedule] = None, + usable_SM_count: int = -1, + qk_dtype: Optional[torch.dtype] = None, + pv_dtype: Optional[torch.dtype] = None, +): + """Run SM100 CSR block-sparse varlen attention. + + This is the public forward-only sparse attention API. It consumes + query-to-key block selections converted to CSR metadata by + ``build_k2q_csr`` and supports both dense KV layout and paged KV layout. + + Parameters + ---------- + q : torch.Tensor + Shape ``[total_q, Hq, 128]`` on CUDA. Supported dtypes are BF16 and + FP8 E4M3. + k : torch.Tensor + Dense layout ``[total_k, Hkv, 128]`` or paged layout + ``[num_pages, Hkv, blk_kv, 128]``. For BF16 Q with FP8 K/V cache, K + may be FP8 E4M3 while QK compute uses BF16 staging. + v : torch.Tensor + Same layout and head count as ``k``. + k2q_row_ptr : torch.Tensor + CSR row pointers with shape ``[Hkv, total_rows + 1]`` and dtype int32. + k2q_q_indices : torch.Tensor + CSR query indices with shape ``[Hkv, >= total_q * topK]`` and dtype + int32. + topK : int + Number of selected KV blocks per query. Supported values are + ``4, 8, 16, 32``. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + max_seqlen_q : int + Maximum Q sequence length in the batch. + max_seqlen_k : int + Maximum KV sequence length in the batch. + blk_kv : int, optional + KV block size. Paged KV requires ``k.shape[2] == blk_kv``. + causal : bool, optional + Whether to apply causal masking. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + lse_temperature_scale : float, optional + Extra divisor used only for temperature-scaled LSE output. + return_temperature_lse : bool, optional + If True, also return LSE computed with logits scaled by + ``softmax_scale / lse_temperature_scale``. Requires + ``return_softmax_lse=True``. + partial_dtype : torch.dtype, optional + Accumulation dtype for per-block partial O. Supported values are + FP32, BF16, FP16, and FP8 E4M3. + return_softmax_lse : bool, optional + If True, return ``(out, softmax_lse)`` or + ``(out, softmax_lse, temperature_lse)``. + page_table : torch.Tensor, optional + Paged-KV physical page table with shape + ``[batch_size, max_num_pages_per_seq]`` and dtype int32. + seqused_k : torch.Tensor, optional + Shape ``[batch_size]``, dtype int32. Effective KV length per request + for paged causal attention. + schedule : SparseAttentionSchedule, optional + Prebuilt sparse forward schedule. If omitted, the schedule is built + during the call. + usable_SM_count : int, optional + Maximum number of SMs used by the scheduler. ``-1`` uses all SMs. + qk_dtype : torch.dtype, optional + Compile-time MMA operand dtype for QK. Defaults to Q storage dtype, + except supported FP8 K/V cache staging modes. + pv_dtype : torch.dtype, optional + Compile-time MMA operand dtype for PV. Defaults to V storage dtype, + except supported FP8 K/V cache staging modes. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + Output shape ``[total_q, Hq, 128]`` with BF16 dtype. Optional LSE + outputs have shape ``[total_q, Hq]`` and dtype float32. + + Notes + ----- + ``Hq / Hkv`` must be one of ``1, 2, 4, 8, 16``. Current kernels support + head dimension 128 only. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse and not return_softmax_lse: + raise ValueError("return_temperature_lse=True requires return_softmax_lse=True") + partial_dtype = _normalize_partial_dtype(partial_dtype) + qk_dtype, pv_dtype = _resolve_forward_mma_dtypes(q, k, v, qk_dtype, pv_dtype) + + if cu_seqlens_q is None or cu_seqlens_k is None: + raise ValueError( + "sparse_atten_func requires CSR varlen metadata: pass cu_seqlens_q and cu_seqlens_k" + ) + batch, head_kv = _validate_csr_varlen_inputs( + q, + k, + v, + k2q_row_ptr, + k2q_q_indices, + topK, + blk_kv, + page_table, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + ) + max_seqlen_q = int(max_seqlen_q) + max_seqlen_k = int(max_seqlen_k) + + return _sparse_atten_csr_varlen_forward( + q.contiguous(), + k.contiguous(), + v.contiguous(), + k2q_row_ptr.contiguous(), + k2q_q_indices.contiguous(), + int(topK), + int(blk_kv), + bool(causal), + float(softmax_scale), + lse_temperature_scale, + return_temperature_lse, + partial_dtype, + bool(return_softmax_lse), + cu_seqlens_q.contiguous(), + cu_seqlens_k.contiguous(), + None if page_table is None else page_table.contiguous(), + None if seqused_k is None else seqused_k.contiguous(), + schedule, + int(usable_SM_count), + int(batch), + int(head_kv), + int(max_seqlen_q), + int(max_seqlen_k), + qk_dtype, + pv_dtype, + ) + + +def sparse_atten_nvfp4_kv_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_scale_128x4: torch.Tensor, + v_scale_128x4: torch.Tensor, + k_global_scale: Optional[torch.Tensor], + v_global_scale: Optional[torch.Tensor], + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + *, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = False, + softmax_scale: Optional[float] = None, + lse_temperature_scale: float = 1.0, + return_temperature_lse: bool = False, + partial_dtype: torch.dtype = torch.bfloat16, + return_softmax_lse: bool = False, + page_table: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + schedule: Optional[SparseAttentionSchedule] = None, +): + """Run SM100 CSR sparse attention with packed NVFP4 K/V. + + Parameters + ---------- + q : torch.Tensor + Shape ``[total_q, Hq, 128]`` on CUDA. Supported dtypes are BF16 and + FP8 E4M3. + k : torch.Tensor + Packed NVFP4 K data. Dense layout is ``[total_k, Hkv, 64]``; paged + layout is ``[num_pages, Hkv, blk_kv, 64]``. Dtype must be uint8 + because each byte packs two FP4 values. + v : torch.Tensor + Packed NVFP4 V data with the same shape as ``k``. + k_scale_128x4 : torch.Tensor + K block scales in cuBLAS/cuDNN 128x4 tiled storage. Dtype uint8 + containing FP8 E4M3 scale values. + v_scale_128x4 : torch.Tensor + V block scales in the same 128x4 tiled storage. + k_global_scale : torch.Tensor, optional + FP32 tensor/global dequant scale for K. May be ``None``. + v_global_scale : torch.Tensor, optional + FP32 tensor/global dequant scale for V. May be ``None``. The V global + scale is applied in the combine stage. + k2q_row_ptr : torch.Tensor + CSR row pointers with shape ``[Hkv, total_rows + 1]`` and dtype int32. + k2q_q_indices : torch.Tensor + CSR query indices with shape ``[Hkv, >= total_q * topK]`` and dtype + int32. + topK : int + Number of selected KV blocks per query. Supported values are + ``4, 8, 16, 32``. + cu_seqlens_q, cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q and KV + lengths. + max_seqlen_q, max_seqlen_k : int + Maximum Q and KV sequence lengths in the batch. + blk_kv : int, optional + KV block/page size. Paged KV requires ``k.shape[2] == blk_kv``. + causal : bool, optional + Whether to apply causal masking. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + lse_temperature_scale : float, optional + Extra divisor used only for temperature-scaled LSE output. + return_temperature_lse : bool, optional + If True, also return temperature-scaled LSE. Requires + ``return_softmax_lse=True``. + partial_dtype : torch.dtype, optional + Accumulation dtype for per-block partial O. + return_softmax_lse : bool, optional + If True, return LSE together with the output. + page_table : torch.Tensor, optional + Paged-KV physical page table with shape + ``[batch_size, max_num_pages_per_seq]`` and dtype int32. + seqused_k : torch.Tensor, optional + Effective KV length per request for paged causal attention. + schedule : SparseAttentionSchedule, optional + Prebuilt sparse forward schedule. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + Output shape ``[total_q, Hq, 128]`` with BF16 dtype. Optional LSE + outputs have shape ``[total_q, Hq]`` and dtype float32. + """ + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse and not return_softmax_lse: + raise ValueError("return_temperature_lse=True requires return_softmax_lse=True") + partial_dtype = _normalize_partial_dtype(partial_dtype) + + if cu_seqlens_q is None or cu_seqlens_k is None: + raise ValueError( + "sparse_atten_nvfp4_kv_func requires CSR varlen metadata: pass cu_seqlens_q and cu_seqlens_k" + ) + batch, head_kv = _validate_csr_varlen_nvfp4_kv_inputs( + q, + k, + v, + k_scale_128x4, + v_scale_128x4, + k_global_scale, + v_global_scale, + k2q_row_ptr, + k2q_q_indices, + topK, + blk_kv, + page_table, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + ) + total_q, head_q, dim = q.shape + max_num_kv_blocks = _csr_row_capacity(k2q_row_ptr) + temperature_lse_fast_path = ( + return_temperature_lse + and math.isclose( + float(lse_temperature_scale), + 1.0, + rel_tol=0.0, + abs_tol=_TEMPERATURE_LSE_FAST_PATH_ABS_TOL, + ) + ) + kernel_return_temperature_lse = ( + return_temperature_lse and not temperature_lse_fast_path + ) + + O_partial = torch.empty( + topK, total_q, head_q, dim, dtype=partial_dtype, device=q.device + ) + LSE_partial = torch.empty( + topK, total_q, head_q, dtype=torch.float32, device=q.device + ) + LSE_temperature_partial = ( + torch.empty(topK, total_q, head_q, dtype=torch.float32, device=q.device) + if kernel_return_temperature_lse + else None + ) + O_out = torch.empty(total_q, head_q, dim, dtype=torch.bfloat16, device=q.device) + LSE_out = torch.empty(total_q, head_q, dtype=torch.float32, device=q.device) + LSE_temperature_out = ( + torch.empty_like(LSE_out) if kernel_return_temperature_lse else None + ) + if schedule is None: + k2q_qsplit_indices = torch.empty_like(k2q_q_indices) + split_counts = torch.zeros( + (total_q, head_kv), + dtype=torch.int32, + device=q.device, + ) + else: + _validate_fwd_schedule( + schedule, + q=q, + k2q_q_indices=k2q_q_indices, + head_kv=head_kv, + ) + k2q_qsplit_indices = schedule.qsplit_indices + split_counts = schedule.split_counts + + schedule = _call_sparse_forward_sm100_csr_varlen_nvfp4_kv( + q.contiguous(), + k.contiguous(), + v.contiguous(), + k_scale_128x4.contiguous(), + v_scale_128x4.contiguous(), + None if k_global_scale is None else k_global_scale.contiguous(), + None if v_global_scale is None else v_global_scale.contiguous(), + k2q_row_ptr.contiguous(), + k2q_q_indices.contiguous(), + k2q_qsplit_indices.contiguous(), + split_counts.contiguous(), + cu_seqlens_q.contiguous(), + cu_seqlens_k.contiguous(), + None if page_table is None else page_table.contiguous(), + None if seqused_k is None else seqused_k.contiguous(), + O_partial, + LSE_partial, + LSE_temperature_partial, + float(softmax_scale), + lse_temperature_scale, + kernel_return_temperature_lse, + max_num_kv_blocks, + int(blk_kv), + head_kv, + int(max_seqlen_q), + causal=bool(causal), + schedule=schedule, + ) + + combine( + O_partial, + LSE_partial, + O_out, + LSE_out, + lse_temperature_partial=LSE_temperature_partial, + lse_temperature_out=LSE_temperature_out, + cu_seqlens=cu_seqlens_q, + split_counts=split_counts, + output_scale=v_global_scale, + use_pdl=True, + ) + if temperature_lse_fast_path: + LSE_temperature_out = LSE_out + + if return_softmax_lse: + if return_temperature_lse: + return O_out, LSE_out, LSE_temperature_out + return O_out, LSE_out + return O_out + + +def sparse_decode_atten_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k_indices: Optional[torch.Tensor] = None, + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = True, + softmax_scale: Optional[float] = None, + return_softmax_lse: bool = False, + schedule: Optional[DecodeAttentionSchedule] = None, + O_partial: Optional[torch.Tensor] = None, + LSE_partial: Optional[torch.Tensor] = None, +): + """Run forward-only paged FP8 decode attention. + + Parameters + ---------- + q : torch.Tensor + Shape ``[batch_size * seqlen_q, Hq, 128]``. Dtype must be FP8 E4M3. + k : torch.Tensor + Paged K cache with shape ``[num_pages, Hkv, blk_kv, 128]`` and FP8 + E4M3 dtype. + v : torch.Tensor + Paged V cache with the same shape and dtype as ``k``. + q2k_indices : torch.Tensor, optional + Sparse selected KV blocks with shape ``[Hkv, total_q, topK]`` and dtype + int32. ``None`` selects the dense all-KV decode path. + page_table : torch.Tensor + Physical page table with shape ``[batch_size, max_num_pages_per_seq]`` + and dtype int32. + seqused_k : torch.Tensor + Shape ``[batch_size]``, dtype int32. Effective KV length per request. + seqlen_q : int + Uniform query length per request. Ragged Q lengths should use prefill + or append paths instead. + max_seqlen_k : int + Maximum KV sequence length in the batch. + blk_kv : int, optional + Page size. Must match ``k.shape[2]``. + causal : bool, optional + Whether to apply causal masking. Current decode kernel requires True. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + return_softmax_lse : bool, optional + If True, return ``(out, lse)``. + schedule : DecodeAttentionSchedule, optional + Prebuilt decode schedule. + O_partial, LSE_partial : torch.Tensor, optional + Optional split-KV partial workspaces. Normally owned by + ``SparseDecodePagedAttentionWrapper``. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + BF16 output with shape ``q.shape``. Optional LSE has shape + ``[batch_size * seqlen_q, Hq]`` and dtype float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + batch, head_kv = _validate_sparse_decode_inputs( + q, + k, + v, + q2k_indices, + page_table=page_table, + seqused_k=seqused_k, + seqlen_q=seqlen_q, + max_seqlen_k=max_seqlen_k, + blk_kv=blk_kv, + causal=causal, + ) + head_q = int(q.shape[1]) + head_dim = int(q.shape[2]) + if schedule is None: + schedule = prepare_decode_schedule( + seqused_k=seqused_k.contiguous(), + page_size=int(blk_kv), + seqlen_q=int(seqlen_q), + num_qo_heads=head_q, + num_kv_heads=head_kv, + head_dim=head_dim, + max_seqlen_k=int(max_seqlen_k), + ) + if schedule.split_kv: + if O_partial is None: + O_partial = torch.empty( + (schedule.partial_rows, head_q, head_dim), + dtype=torch.float32, + device=q.device, + ) + if LSE_partial is None: + LSE_partial = torch.empty( + (schedule.partial_rows, head_q), + dtype=torch.float32, + device=q.device, + ) + out = torch.empty(q.shape, dtype=torch.bfloat16, device=q.device) + lse = torch.empty( + q.shape[:2] if (return_softmax_lse or schedule.split_kv) else (1, head_q), + dtype=torch.float32, + device=q.device, + ) + _call_sparse_decode_forward_sm100_paged_fp8( + q.contiguous(), + k.contiguous(), + v.contiguous(), + None if q2k_indices is None else q2k_indices.contiguous(), + page_table.contiguous(), + seqused_k.contiguous(), + out, + lse, + schedule, + O_partial, + LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=int(seqlen_q), + max_seqlen_k=int(max_seqlen_k), + blk_kv=int(blk_kv), + causal=bool(causal), + return_lse=bool(return_softmax_lse), + ) + if return_softmax_lse: + return out, lse + return out + + +class SparseDecodePagedAttentionWrapper: + """Plan/run helper for paged FP8 decode attention. + + Use this wrapper when the same page table shape and sequence metadata are + reused across multiple decode layers. ``plan`` validates metadata and + allocates persistent schedules/workspaces; ``run`` then launches the decode + kernel with lower per-call overhead than ``sparse_decode_atten_func``. + """ + + def __init__(self, *, blk_kv: int = 128, causal: bool = True): + self.blk_kv = int(blk_kv) + self.causal = bool(causal) + self.batch: Optional[int] = None + self.num_qo_heads: Optional[int] = None + self.num_kv_heads: Optional[int] = None + self.head_dim: Optional[int] = None + self.page_table: Optional[torch.Tensor] = None + self.seqused_k: Optional[torch.Tensor] = None + self.q2k_indices: Optional[torch.Tensor] = None + self.seqlen_q: Optional[int] = None + self.max_seqlen_k: Optional[int] = None + self.is_sparse: bool = False + self.decode_schedule: Optional[DecodeAttentionSchedule] = None + self.request_indices: Optional[torch.Tensor] = None + self.qo_tile_indices: Optional[torch.Tensor] = None + self.kv_tile_indices: Optional[torch.Tensor] = None + self.merge_indptr: Optional[torch.Tensor] = None + self.o_indptr: Optional[torch.Tensor] = None + self.block_valid_mask: Optional[torch.Tensor] = None + self.kv_pages: Optional[torch.Tensor] = None + self.split_counts: Optional[torch.Tensor] = None + self.split_kv: bool = False + self.cta_tile_q: int = 0 + self.num_q_tiles: int = 0 + self.kv_chunk_size_pages: int = 0 + self.kv_chunk_size_tokens: int = 0 + self.work_count: int = 0 + self.padded_work_count: int = 0 + self.O_partial: Optional[torch.Tensor] = None + self.LSE_partial: Optional[torch.Tensor] = None + # Cached dummy buffers used in non-split path to satisfy the kernel's + # positional arg signature without per-call torch.empty (saves ~5us + # on every run() for small kv). + self._O_partial_dummy: Optional[torch.Tensor] = None + self._LSE_partial_dummy: Optional[torch.Tensor] = None + # When the caller doesn't ask for LSE, the kernel still needs a valid + # tensor pointer to write to. Cache a small placeholder so run() can + # skip the per-call torch.empty for it as well. + self._lse_dummy: Optional[torch.Tensor] = None + + def plan( + self, + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + q2k_indices: Optional[torch.Tensor] = None, + num_qo_heads: Optional[int] = None, + num_kv_heads: Optional[int] = None, + head_dim: Optional[int] = 128, + enable_cuda_graph: bool = False, + max_grid_size: Optional[int] = None, + fixed_split_size: Optional[int] = None, + disable_split_kv: bool = False, + ) -> "SparseDecodePagedAttentionWrapper": + """Prepare decode scheduling metadata and reusable workspaces. + + Parameters + ---------- + page_table : torch.Tensor + Shape ``[batch_size, max_num_pages_per_seq]``, dtype int32. Maps + logical pages to physical KV-cache pages. + seqused_k : torch.Tensor + Shape ``[batch_size]``, dtype int32. Effective KV length per + request. + seqlen_q : int + Uniform query length per request. + max_seqlen_k : int + Maximum KV sequence length in the batch. + q2k_indices : torch.Tensor, optional + Sparse selected KV blocks with shape ``[Hkv, total_q, topK]`` and + dtype int32. ``None`` selects the dense all-KV path. + num_qo_heads : int + Number of Q/O heads. + num_kv_heads : int + Number of KV heads. Current decode kernel requires + ``num_qo_heads / num_kv_heads == 16`` at run time. + head_dim : int, optional + Head dimension. Must be 128. + enable_cuda_graph : bool, optional + Build schedule metadata compatible with CUDA graph capture. + max_grid_size : int, optional + Override maximum CTA count used by the scheduler. + fixed_split_size : int, optional + Force a fixed split-KV chunk size in pages. + disable_split_kv : bool, optional + Disable split-KV even for long KV sequences. + + Returns + ------- + SparseDecodePagedAttentionWrapper + ``self``, planned and ready for ``run``. + """ + if page_table.ndim != 2: + raise ValueError("decode plan requires page_table with shape [B, max_num_pages_per_seq]") + if page_table.dtype != torch.int32: + raise TypeError("decode plan requires page_table to be torch.int32") + if seqused_k.dtype != torch.int32: + raise TypeError("decode plan requires seqused_k to be torch.int32") + if not page_table.is_cuda or not seqused_k.is_cuda: + raise ValueError("decode plan requires page_table and seqused_k to be CUDA tensors") + if page_table.device != seqused_k.device: + raise ValueError("decode plan requires page_table and seqused_k on the same device") + if page_table.stride(-1) != 1: + raise ValueError("decode plan requires page_table contiguous in the last dimension") + if seqused_k.shape != (int(page_table.shape[0]),): + raise ValueError("decode plan requires seqused_k with shape [B]") + if q2k_indices is not None and q2k_indices.dtype != torch.int32: + raise TypeError("decode plan requires q2k_indices to be torch.int32") + if int(seqlen_q) <= 0 or int(max_seqlen_k) <= 0: + raise ValueError("decode plan requires positive seqlen_q and max_seqlen_k") + if num_qo_heads is None or num_kv_heads is None or head_dim is None: + raise ValueError("decode plan requires num_qo_heads, num_kv_heads, and head_dim") + if head_dim is not None and int(head_dim) != 128: + raise NotImplementedError("decode plan currently supports only head_dim=128") + if int(num_qo_heads) % int(num_kv_heads) != 0: + raise ValueError("decode plan requires num_qo_heads divisible by num_kv_heads") + + self.batch = int(page_table.shape[0]) + self.num_qo_heads = None if num_qo_heads is None else int(num_qo_heads) + self.num_kv_heads = None if num_kv_heads is None else int(num_kv_heads) + self.head_dim = None if head_dim is None else int(head_dim) + self.page_table = page_table.contiguous() + self.seqused_k = seqused_k.contiguous() + self.q2k_indices = None if q2k_indices is None else q2k_indices.contiguous() + self.seqlen_q = int(seqlen_q) + self.max_seqlen_k = int(max_seqlen_k) + self.is_sparse = q2k_indices is not None + + # max_grid_size is hardcoded to num_sms (1 CTA/SM) inside the C++ + # schedule launcher because the decode attn kernel always runs at + # 1 CTA/SM (its register/smem budget saturates the SM). Callers + # can still override via the explicit max_grid_size kwarg. + schedule = prepare_decode_schedule( + seqused_k=self.seqused_k, + page_size=self.blk_kv, + seqlen_q=self.seqlen_q, + num_qo_heads=self.num_qo_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_seqlen_k=self.max_seqlen_k, + enable_cuda_graph=bool(enable_cuda_graph), + max_grid_size=max_grid_size, + fixed_split_size=fixed_split_size, + disable_split_kv=bool(disable_split_kv), + ) + self.decode_schedule = schedule + self.request_indices = schedule.request_indices + self.qo_tile_indices = schedule.qo_tile_indices + self.kv_tile_indices = schedule.kv_tile_indices + self.merge_indptr = schedule.merge_indptr + self.o_indptr = schedule.o_indptr + self.block_valid_mask = schedule.block_valid_mask + self.kv_pages = schedule.kv_pages + self.split_counts = schedule.split_counts + self.split_kv = schedule.split_kv + self.cta_tile_q = schedule.cta_tile_q + self.num_q_tiles = schedule.num_q_tiles + self.kv_chunk_size_pages = schedule.kv_chunk_size_pages + self.kv_chunk_size_tokens = schedule.kv_chunk_size_tokens + self.work_count = schedule.work_count + self.padded_work_count = schedule.padded_work_count + if schedule.split_kv: + self.O_partial = torch.empty( + (schedule.partial_rows, self.num_qo_heads, self.head_dim), + dtype=torch.float32, + device=page_table.device, + ) + self.LSE_partial = torch.empty( + (schedule.partial_rows, self.num_qo_heads), + dtype=torch.float32, + device=page_table.device, + ) + self._O_partial_dummy = None + self._LSE_partial_dummy = None + else: + self.O_partial = None + self.LSE_partial = None + # decode_forward_paged_fp8 always wants non-None partial buffers + # for the kernel's positional arg layout (compile keeps the slot + # alive even when split_kv=False). Allocate once here and reuse. + self._O_partial_dummy = torch.empty( + (1, self.head_dim), + dtype=torch.float32, + device=page_table.device, + ) + self._LSE_partial_dummy = torch.empty( + (1, self.num_qo_heads), + dtype=torch.float32, + device=page_table.device, + ) + # LSE dummy is shape (1, head_q) — used when caller doesn't request + # LSE and the schedule isn't split-KV (split-KV always writes LSE). + self._lse_dummy = torch.empty( + (1, self.num_qo_heads), + dtype=torch.float32, + device=page_table.device, + ) + return self + + def run( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + softmax_scale: Optional[float] = None, + return_softmax_lse: bool = False, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + ): + """Launch decode using metadata cached by ``plan``. + + Parameters + ---------- + q : torch.Tensor + Shape ``[batch_size * seqlen_q, Hq, 128]`` and dtype FP8 E4M3. + k : torch.Tensor + Paged K cache with shape ``[num_pages, Hkv, blk_kv, 128]``. + v : torch.Tensor + Paged V cache with the same shape as ``k``. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + return_softmax_lse : bool, optional + If True, return ``(out, lse)``. + out : torch.Tensor, optional + Preallocated BF16 output buffer with shape ``q.shape``. + lse : torch.Tensor, optional + Preallocated float32 LSE buffer with shape ``[total_q, Hq]``. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + BF16 output, optionally with float32 LSE. + """ + if self.decode_schedule is None: + raise RuntimeError("decode wrapper must be planned before run") + if self.is_sparse: + # Sparse path still goes through the validating wrapper for now; + # only the dense fast path is collapsed. + return sparse_decode_atten_func( + q, k, v, self.q2k_indices, + page_table=self.page_table, seqused_k=self.seqused_k, + seqlen_q=self.seqlen_q, max_seqlen_k=self.max_seqlen_k, + blk_kv=self.blk_kv, causal=self.causal, + softmax_scale=softmax_scale, return_softmax_lse=return_softmax_lse, + schedule=self.decode_schedule, + O_partial=self.O_partial, LSE_partial=self.LSE_partial, + ) + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + if out is None: + out = torch.empty(q.shape, dtype=torch.bfloat16, device=q.device) + if lse is None: + if return_softmax_lse or self.split_kv: + # Real LSE needed — must allocate per-call (shape depends on q). + lse = torch.empty( + q.shape[:2], dtype=torch.float32, device=q.device, + ) + else: + # Kernel only needs a valid pointer; reuse cached dummy. + lse = self._lse_dummy + from .src.sm100.fwd_decode import decode_forward_paged_fp8 + schedule = self.decode_schedule + decode_forward_paged_fp8( + q, k, v, + self.page_table, self.seqused_k, + out, lse, + schedule.request_indices, schedule.qo_tile_indices, + schedule.kv_tile_indices, schedule.block_valid_mask, + schedule.split_counts, schedule.o_indptr, schedule.merge_indptr, + self.O_partial, self.LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=self.seqlen_q, + page_size=self.blk_kv, + kv_chunk_size_pages=int(schedule.kv_chunk_size_pages), + max_split_count=int(schedule.max_split_count), + split_kv=bool(schedule.split_kv), + causal=self.causal, + return_lse=bool(return_softmax_lse), + # cached dummies — avoid per-call torch.empty inside run_decode_attention + O_partial_dummy=self._O_partial_dummy, + LSE_partial_dummy=self._LSE_partial_dummy, + ) + if return_softmax_lse: + return out, lse + return out + + +def _sparse_atten_csr_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + blk_kv: int, + causal: bool, + softmax_scale: float, + lse_temperature_scale: float, + return_temperature_lse: bool, + partial_dtype: torch.dtype, + return_softmax_lse: bool, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + page_table: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + schedule: Optional[SparseAttentionSchedule], + usable_SM_count: int, + batch: int, + head_kv: int, + max_seqlen_q: int, + max_seqlen_k: int, + qk_dtype: torch.dtype, + pv_dtype: torch.dtype, +): + total_q, head_q, dim = q.shape + if head_q % head_kv != 0: + raise ValueError("q.shape[1] must be divisible by head_kv") + max_num_kv_blocks = _csr_row_capacity(k2q_row_ptr) + temperature_lse_fast_path = ( + return_temperature_lse + and math.isclose( + float(lse_temperature_scale), + 1.0, + rel_tol=0.0, + abs_tol=_TEMPERATURE_LSE_FAST_PATH_ABS_TOL, + ) + ) + kernel_return_temperature_lse = ( + return_temperature_lse and not temperature_lse_fast_path + ) + + O_partial = torch.empty( + topK, total_q, head_q, dim, dtype=partial_dtype, device=q.device + ) + LSE_partial = torch.empty( + topK, total_q, head_q, dtype=torch.float32, device=q.device + ) + LSE_temperature_partial = ( + torch.empty(topK, total_q, head_q, dtype=torch.float32, device=q.device) + if kernel_return_temperature_lse + else None + ) + O_out = torch.empty(total_q, head_q, dim, dtype=torch.bfloat16, device=q.device) + LSE_out = torch.empty(total_q, head_q, dtype=torch.float32, device=q.device) + LSE_temperature_out = ( + torch.empty_like(LSE_out) if kernel_return_temperature_lse else None + ) + if schedule is None: + k2q_qsplit_indices = torch.empty_like(k2q_q_indices) + split_counts = torch.zeros( + (total_q, head_kv), + dtype=torch.int32, + device=q.device, + ) + else: + _validate_fwd_schedule( + schedule, + q=q, + k2q_q_indices=k2q_q_indices, + head_kv=head_kv, + ) + k2q_qsplit_indices = schedule.qsplit_indices + split_counts = schedule.split_counts + schedule = _call_sparse_forward_sm100_csr_varlen( + q, + k, + v, + k2q_row_ptr, + k2q_q_indices, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + cu_seqlens_k, + page_table, + seqused_k, + O_partial, + LSE_partial, + LSE_temperature_partial, + softmax_scale, + lse_temperature_scale, + kernel_return_temperature_lse, + max_num_kv_blocks, + blk_kv, + head_kv, + max_seqlen_q, + usable_SM_count, + causal=causal, + schedule=schedule, + qk_dtype=qk_dtype, + pv_dtype=pv_dtype, + ) + # Sparse Attention and Sparse Page Attention both use the varlen-Q + # combine path; the kernel-written LSE_out is the final contract. + combine( + O_partial, + LSE_partial, + O_out, + LSE_out, + lse_temperature_partial=LSE_temperature_partial, + lse_temperature_out=LSE_temperature_out, + cu_seqlens=cu_seqlens_q, + split_counts=split_counts, + use_pdl=True, + ) + if temperature_lse_fast_path: + LSE_temperature_out = LSE_out + + if return_softmax_lse: + if return_temperature_lse: + return O_out, LSE_out, LSE_temperature_out + return O_out, LSE_out + return O_out + + +def _call_sparse_decode_forward_sm100_paged_fp8( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k_indices: Optional[torch.Tensor], + page_table: torch.Tensor, + seqused_k: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + schedule: DecodeAttentionSchedule, + O_partial: Optional[torch.Tensor], + LSE_partial: Optional[torch.Tensor], + *, + softmax_scale: float, + seqlen_q: int, + max_seqlen_k: int, + blk_kv: int, + causal: bool, + return_lse: bool = True, +) -> None: + """Compile and launch the SM100 paged fp8 decode forward kernel. + + Dense decode is selected by ``q2k_indices=None``. Sparse decode will reuse + the same schedule wrapper but needs a separate q2k gather path. + """ + if q2k_indices is not None: + raise NotImplementedError("SM100 paged fp8 sparse decode forward is not implemented yet") + if schedule.cta_tile_q != 128: + raise NotImplementedError(f"decode forward requires cta_tile_q=128, got {schedule.cta_tile_q}") + if schedule.split_kv: + if O_partial is None or LSE_partial is None: + raise ValueError("split decode forward requires O_partial and LSE_partial") + if O_partial.dtype != torch.float32: + raise TypeError(f"O_partial must be torch.float32, got {O_partial.dtype}") + if LSE_partial.dtype != torch.float32: + raise TypeError(f"LSE_partial must be torch.float32, got {LSE_partial.dtype}") + + from .src.sm100.fwd_decode import decode_forward_paged_fp8 + + decode_forward_paged_fp8( + q, + k, + v, + page_table, + seqused_k, + out, + lse, + schedule.request_indices, + schedule.qo_tile_indices, + schedule.kv_tile_indices, + schedule.block_valid_mask, + schedule.split_counts, + schedule.o_indptr, + schedule.merge_indptr, + O_partial, + LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=int(seqlen_q), + page_size=int(blk_kv), + kv_chunk_size_pages=int(schedule.kv_chunk_size_pages), + max_split_count=int(schedule.max_split_count), + split_kv=bool(schedule.split_kv), + causal=bool(causal), + return_lse=bool(return_lse), + ) + + +def _call_sparse_forward_sm100_csr_varlen( + q, + k, + v, + k2q_row_ptr, + k2q_q_indices, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + cu_seqlens_k, + page_table, + seqused_k, + O_partial, + LSE_partial, + LSE_temperature_partial, + softmax_scale, + lse_temperature_scale, + return_temperature_lse, + max_num_kv_blocks, + blk_kv, + head_kv, + max_seqlen_q, + usable_SM_count=-1, + *, + causal=False, + use_prepare_scheduler=True, + schedule: Optional[SparseAttentionSchedule] = None, + qk_dtype: torch.dtype, + pv_dtype: torch.dtype, +): + """Compile and launch the SM100 sparse forward K1 kernel on CSR metadata.""" + head_dim = q.shape[-1] + dtype = q.dtype + qk_dtype = _normalize_forward_mma_dtype(qk_dtype, q.dtype, "qk_dtype") + pv_dtype = _normalize_forward_mma_dtype(pv_dtype, v.dtype, "pv_dtype") + partial_dtype = O_partial.dtype + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse != (LSE_temperature_partial is not None): + raise ValueError( + "return_temperature_lse must match LSE_temperature_partial presence" + ) + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + lse_temperature_inv_scale = 1.0 / lse_temperature_scale + n_block_size = int(blk_kv) + head_q = q.shape[1] + qhead_per_kv = head_q // head_kv + paged_kv = page_table is not None + if not bool(use_prepare_scheduler): + raise RuntimeError("sparse forward requires prepare scheduler") + schedule_enabled = k2q_row_ptr.shape[1] > 1 + page_size = int(k.shape[2]) if paged_kv else None + if paged_kv: + k_kernel, v_kernel = _prepare_paged_kv_for_tma(k, v, n_block_size) + else: + k_kernel = k + v_kernel = v + O_partial_flat = O_partial.reshape(-1, head_dim).contiguous() + Q_flat = q.reshape(-1, head_dim).contiguous() + Q_gather4_desc = ( + create_q_gather4_tma_desc( + Q_flat, + box_x=128 if q.dtype == torch.float8_e4m3fn else 64, + ) + if qhead_per_kv in (1, 2, 4) + else None + ) + if schedule is None: + schedule = prepare_sparse_fwd_schedule_and_split( + k2q_row_ptr=k2q_row_ptr, + k2q_q_indices=k2q_q_indices, + k2q_qsplit_indices=k2q_qsplit_indices, + split_counts=split_counts, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + total_q=int(q.shape[0]), + max_seqlen_q=max_seqlen_q, + topk=int(O_partial.shape[0]), + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + blk_kv=n_block_size, + device=q.device, + enabled=schedule_enabled, + ) + use_prepare_scheduler = schedule.enabled + scheduler_metadata = schedule.scheduler_metadata + work_count = schedule.work_count + work_capacity = schedule.work_capacity + if not use_prepare_scheduler or scheduler_metadata is None or work_count is None or work_capacity <= 0: + raise RuntimeError("sparse forward requires a non-empty prepared schedule") + + key = ( + "sparse_forward_sm100_csr_varlen", + head_dim, + n_block_size, + qhead_per_kv, + dtype, + k.dtype, + v.dtype, + qk_dtype, + pv_dtype, + partial_dtype, + bool(causal), + bool(paged_kv), + bool(use_prepare_scheduler), + page_size, + bool(seqused_k is not None), + bool(return_temperature_lse), + ) + if key not in _compile_cache: + from .src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _compile_cache[key] = loaded + else: + kernel = SparseAttentionForwardSm100( + head_dim=head_dim, + qheadperkv=qhead_per_kv, + n_block_size=n_block_size, + paged_kv=paged_kv, + page_size=page_size, + has_seqused_k=seqused_k is not None, + causal=bool(causal), + use_prepare_scheduler=use_prepare_scheduler, + qk_dtype=_torch_dtype_to_cutlass_dtype(qk_dtype), + pv_dtype=_torch_dtype_to_cutlass_dtype(pv_dtype), + ) + _compile_cache[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k_kernel), + to_cute_tensor_kvouter(v_kernel), + to_cute_tensor_kvouter(k2q_q_indices), + to_cute_tensor_kvouter(k2q_qsplit_indices), + to_cute_tensor_kvouter(k2q_row_ptr), + None if scheduler_metadata is None else to_cute_tensor_kvouter(scheduler_metadata), + None if work_count is None else to_cute_tensor_kvouter(work_count), + to_cute_tensor_kvouter(O_partial_flat), + to_cute_tensor_kvouter(LSE_partial), + None + if LSE_temperature_partial is None + else to_cute_tensor_kvouter(LSE_temperature_partial), + to_cute_tensor_kvouter(Q_flat), + None if Q_gather4_desc is None else to_cute_tensor_kvouter(Q_gather4_desc), + None if page_table is None else to_cute_tensor_kvouter(page_table), + None if seqused_k is None else to_cute_tensor_kvouter(seqused_k), + to_cute_tensor_kvouter(cu_seqlens_q), + to_cute_tensor_kvouter(cu_seqlens_k), + Float32(softmax_scale), + Float32(lse_temperature_inv_scale), + Int32(max_num_kv_blocks), + Int32(head_kv), + Int32(max_seqlen_q), + Int32(work_capacity), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _compile_cache[key]) + + with torch.cuda.nvtx.range("Fwd_SparseAttn_Sm100_CsrVarlen"): + _compile_cache[key]( + k_kernel, + v_kernel, + k2q_q_indices, + k2q_qsplit_indices, + k2q_row_ptr, + scheduler_metadata, + work_count, + O_partial_flat, + LSE_partial, + LSE_temperature_partial, + Q_flat, + Q_gather4_desc, + page_table, + seqused_k, + cu_seqlens_q, + cu_seqlens_k, + softmax_scale, + lse_temperature_inv_scale, + max_num_kv_blocks, + head_kv, + max_seqlen_q, + work_capacity, + ) + return schedule + + +def _call_sparse_forward_sm100_csr_varlen_nvfp4_kv( + q, + k, + v, + k_scale_128x4, + v_scale_128x4, + k_global_scale, + v_global_scale, + k2q_row_ptr, + k2q_q_indices, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + cu_seqlens_k, + page_table, + seqused_k, + O_partial, + LSE_partial, + LSE_temperature_partial, + softmax_scale, + lse_temperature_scale, + return_temperature_lse, + max_num_kv_blocks, + blk_kv, + head_kv, + max_seqlen_q, + *, + causal=False, + use_prepare_scheduler=True, + schedule: Optional[SparseAttentionSchedule] = None, +): + """Compile and launch the SM100 sparse forward K1 kernel with NVFP4 K/V.""" + + head_dim = q.shape[-1] + dtype = q.dtype + partial_dtype = O_partial.dtype + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse != (LSE_temperature_partial is not None): + raise ValueError( + "return_temperature_lse must match LSE_temperature_partial presence" + ) + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + lse_temperature_inv_scale = 1.0 / lse_temperature_scale + n_block_size = int(blk_kv) + head_q = q.shape[1] + qhead_per_kv = head_q // head_kv + fp8_pair_dequant = os.environ.get("MINIMAX_KVFP4_FP8_PAIR_DEQUANT", "1") != "0" + k_global_scale_kernel = k_global_scale + # V global scale is linear in the final output. Keep K1 on block-scale-only V + # and apply the tensor scale once in K2 combine. + v_global_scale_kernel = None + has_k_global_scale = k_global_scale_kernel is not None + has_v_global_scale = v_global_scale_kernel is not None + paged_kv = page_table is not None + if not bool(use_prepare_scheduler): + raise RuntimeError("KVFP4 sparse forward requires prepare scheduler") + schedule_enabled = k2q_row_ptr.shape[1] > 1 + page_size = int(k.shape[2]) if paged_kv else None + if paged_kv: + _prepare_paged_kv_for_tma(k, v, n_block_size) + k_kernel = k + v_kernel = v + O_partial_flat = O_partial.reshape(-1, head_dim).contiguous() + Q_flat = q.reshape(-1, head_dim).contiguous() + Q_gather4_desc = ( + create_q_gather4_tma_desc( + Q_flat, + box_x=128 if q.dtype == torch.float8_e4m3fn else 64, + ) + if qhead_per_kv in (1, 2, 4) + else None + ) + if schedule is None: + schedule = prepare_sparse_fwd_schedule_and_split( + k2q_row_ptr=k2q_row_ptr, + k2q_q_indices=k2q_q_indices, + k2q_qsplit_indices=k2q_qsplit_indices, + split_counts=split_counts, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + total_q=int(q.shape[0]), + max_seqlen_q=max_seqlen_q, + topk=int(O_partial.shape[0]), + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + blk_kv=n_block_size, + device=q.device, + enabled=schedule_enabled, + ) + use_prepare_scheduler = schedule.enabled + scheduler_metadata = schedule.scheduler_metadata + work_count = schedule.work_count + work_capacity = schedule.work_capacity + if not use_prepare_scheduler or scheduler_metadata is None or work_count is None or work_capacity <= 0: + raise RuntimeError("KVFP4 sparse forward requires a non-empty prepared schedule") + + key = ( + "sparse_forward_sm100_csr_varlen_nvfp4_kv", + head_dim, + n_block_size, + qhead_per_kv, + dtype, + partial_dtype, + bool(causal), + bool(paged_kv), + bool(use_prepare_scheduler), + page_size, + bool(seqused_k is not None), + bool(return_temperature_lse), + bool(fp8_pair_dequant), + bool(has_k_global_scale), + bool(has_v_global_scale), + ) + if key not in _compile_cache: + from .src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _compile_cache[key] = loaded + else: + kernel = SparseAttentionForwardNvfp4KvSm100( + head_dim=head_dim, + qheadperkv=qhead_per_kv, + n_block_size=n_block_size, + paged_kv=paged_kv, + page_size=page_size, + has_seqused_k=seqused_k is not None, + causal=bool(causal), + use_prepare_scheduler=use_prepare_scheduler, + fp8_pair_dequant=bool(fp8_pair_dequant), + has_k_global_scale=bool(has_k_global_scale), + has_v_global_scale=bool(has_v_global_scale), + ) + _compile_cache[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k_kernel), + to_cute_tensor_kvouter(v_kernel), + to_cute_tensor_kvouter(k_scale_128x4), + to_cute_tensor_kvouter(v_scale_128x4), + None if k_global_scale_kernel is None else to_cute_tensor_kvouter(k_global_scale_kernel), + None if v_global_scale_kernel is None else to_cute_tensor_kvouter(v_global_scale_kernel), + to_cute_tensor_kvouter(k2q_q_indices), + to_cute_tensor_kvouter(k2q_qsplit_indices), + to_cute_tensor_kvouter(k2q_row_ptr), + None if scheduler_metadata is None else to_cute_tensor_kvouter(scheduler_metadata), + None if work_count is None else to_cute_tensor_kvouter(work_count), + to_cute_tensor_kvouter(O_partial_flat), + to_cute_tensor_kvouter(LSE_partial), + None + if LSE_temperature_partial is None + else to_cute_tensor_kvouter(LSE_temperature_partial), + to_cute_tensor_kvouter(Q_flat), + None if Q_gather4_desc is None else to_cute_tensor_kvouter(Q_gather4_desc), + None if page_table is None else to_cute_tensor_kvouter(page_table), + None if seqused_k is None else to_cute_tensor_kvouter(seqused_k), + to_cute_tensor_kvouter(cu_seqlens_q), + to_cute_tensor_kvouter(cu_seqlens_k), + Float32(softmax_scale), + Float32(lse_temperature_inv_scale), + Int32(max_num_kv_blocks), + Int32(head_kv), + Int32(max_seqlen_q), + Int32(work_capacity), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _compile_cache[key]) + + with torch.cuda.nvtx.range("Fwd_SparseAttn_Sm100_CsrVarlen_KVFP4"): + _compile_cache[key]( + k_kernel, + v_kernel, + k_scale_128x4, + v_scale_128x4, + k_global_scale_kernel, + v_global_scale_kernel, + k2q_q_indices, + k2q_qsplit_indices, + k2q_row_ptr, + scheduler_metadata, + work_count, + O_partial_flat, + LSE_partial, + LSE_temperature_partial, + Q_flat, + Q_gather4_desc, + page_table, + seqused_k, + cu_seqlens_q, + cu_seqlens_k, + softmax_scale, + lse_temperature_inv_scale, + max_num_kv_blocks, + head_kv, + max_seqlen_q, + work_capacity, + ) + return schedule diff --git a/build/torch211-cxx11-cu130-x86_64-linux/metadata.json b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..912eb247b3c5840959dfe2b4dd340ddcfc2ea7d5 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,71 @@ +{ + "name": "msa", + "id": "_msa_cuda_09d7851", + "version": 0, + "license": "other", + "upstream": "https://github.com/MiniMax-AI/MSA", + "python-depends": [ + "tvm-ffi", + "nvidia-cutlass-dsl" + ], + "backend": { + "type": "cuda", + "archs": [ + "10.0" + ] + }, + "digest": { + "algorithm": "sha256", + "files": { + "__init__.py": "+W+3U1Z5ZKc/dTA+JUG+6dMjfe9H/d9J+8fN+936wbI=", + "_msa_cuda_09d7851.abi3.so": "V926Sb+rbIkUkZdd91Ziz71UtKo3mvb3XZHBG5C3CzE=", + "_ops.py": "o9RBC1FB95LP9Sp+GkBILumbSek9oEtxb8F7XXO0F0g=", + "fp4_indexer_interface.py": "M+0e93gWG8CGOrhY5bm1hEQJU+TT5PrCmwJzTofaDAg=", + "interface.py": "B4AHQfNyO+vl6MdyMAHW0GhArl7HGufAEa0ATxsWorY=", + "msa/__init__.py": "DFYPlrhXwYjEqCl/8n0SmWGZV8NFml5DPhMjKfv98GY=", + "quack/__init__.py": "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=", + "quack/activation.py": "T/ypcXoz6a4wPPNZW2gKZuEj8JeucaKtKxQiQl5XrXc=", + "quack/compile_utils.py": "qJ3oTsDlbAiddrJHtEO7LPYVqn/s+neNfiw+/KvfXZU=", + "quack/copy_utils.py": "rdohXm9bKXqDHkMHf8lWQJQnCb0hMLvhzIudkj0Bxeg=", + "quack/cute_dsl_utils.py": "4uQx5aYDG9UvVzbWwJTjjJLrnoympz70/CD8b37FQWo=", + "quack/layout_utils.py": "69N1aTy+840X3seMuLfLxiV3BW8SaVsM3Tf0Vf4NCSI=", + "quantize.py": "1jePLbJngji8ANfnDK6PCG829AMSd+XOMqYVuJ5pXyY=", + "sparse_index_utils.py": "kzYMdtFPRBfaL6Vfw9xLLre7ph8svtEQrB/txC+52Fc=", + "src/__init__.py": "ZdCpznblq9UgGSNgI0hJoDXpk+evcvGjv3GGthxD/nM=", + "src/common/__init__.py": "ZdCpznblq9UgGSNgI0hJoDXpk+evcvGjv3GGthxD/nM=", + "src/common/aot_cache.py": "ya1OHE6Lqx/pb9UhH++Bu8a98Huhmdl084C6cgWdH1s=", + "src/common/barrier.py": "Godvhwwaf9iyDA/A78VoQMMRRn6ZSnq2YPosr7K2SVE=", + "src/common/blackwell_helpers.py": "BYJYCeNQ9cYVhWZlfjv0IgNaNqlnoD21nX3gAA5pRB4=", + "src/common/block_info.py": "U7qL3AZ5ROkNZdL6RTPlLlnLJ6tZ4b2VFVufZLyuuq8=", + "src/common/copy_utils.py": "bEtyb8O7Z7jIKNjN5ESlnh4WVvdf8vr5ZfQxA6vS6zA=", + "src/common/cute_dsl_utils.py": "nd8vII+r49Kk185ja3+VM6dwJlvMqCkjMBRh0WEHakw=", + "src/common/fast_math.py": "nqt6MtzAt7uplC4+kczgBfin4gHNs+QSoufR1TuMZ88=", + "src/common/mask.py": "l9v4End+9k3ZHRO6DCnuOD9K9iOCiN81osRATKvK41k=", + "src/common/mma_sm100_desc.py": "C1PqBdp6CNPA9xadQ2xBnf4wvQlE93SS/7CU+LZBQkA=", + "src/common/named_barrier.py": "5ktJiO+hP80fjTR797CslUGfm2PyhpcW6WJZrNyI5bQ=", + "src/common/pack_gqa.py": "UrAAIge5XLmilqXWGtCZJobgpuA6B0N1Vw3tDhyUi7s=", + "src/common/paged_kv.py": "j0/6stT1A5uEVALEX/GaQhYWIie+6LpGseAW8aQiHbk=", + "src/common/pipeline.py": "MIFfoDDD8Fs//SQSR+JzI/0MJ1qPGml297RtbC2qPRU=", + "src/common/seqlen_info.py": "EX2W8MTGcnAZ+J60tGG9D7IzvdLeIVQshztntGDkPMQ=", + "src/common/softmax.py": "ePjb2TUcr4fHLmw0zx9Lt+vvR6hSm2mQwiENf2J/AoQ=", + "src/common/tile_scheduler.py": "f8UknoE0j9BfPomRI/QCsDJoRk+1IpJrLfBOAh2mlls=", + "src/common/tma_utils.py": "gpAmBh58VOfHRghZTCbQ5SQpbAYy0lFnpvIcFSLBNb8=", + "src/common/utils.py": "eGGo5Ul+0XpKtiw6JLofVdFDj6s2xe4LWqDmlqp9AKk=", + "src/sm100/__init__.py": "JQpQtL58fso8B2Xwvn0XVevVqIjnk15wVQE0UUGGLCs=", + "src/sm100/build_k2q_csr/__init__.py": "75ICu6BIZir0OeyEgZ1TEYNY7pn+lA4P6McCSSC20rI=", + "src/sm100/decode_schedule.py": "/VRAmvrMX+oYLzWK1sqve86tprXkqX0/f4o5WMVeU4I=", + "src/sm100/fp4_indexer.py": "1lc9/rgU09wwF08WBRaXIE0CE2b19pBRwXekDduFs0o=", + "src/sm100/fwd/__init__.py": "A0uq2t4n5Y34mEgxb9Nzxk9sKsYr2FZ4sF+RoEilOmo=", + "src/sm100/fwd/atten_fwd.py": "4LJaUh2pn3QiwcMr+8QOVUJjNIAQqYal1xFJ/1takQY=", + "src/sm100/fwd/atten_fwd_nvfp4_kv.py": "EqU+ehJasAa9NvpDWipMPxaptOw+vcojprVas+b+x18=", + "src/sm100/fwd/combine.py": "7rQW4rUpzy0M19u+/iLfHHGMbAIQhi4HEnYeLu/qmi4=", + "src/sm100/fwd_decode/__init__.py": "XQJdwvLQm29RwVqVZvCstEnTx+dhUrwmH6RcW675pR8=", + "src/sm100/fwd_decode/atten_fwd.py": "3S4iE9h6fXUBjas51fRbakqnOzN79f0QUJ/EBRm+Ckg=", + "src/sm100/fwd_decode/build_decode_schedule/__init__.py": "qUElKK/HC03N9ntOA0sc8LB08jF5MWd7wq3MUnu4wgM=", + "src/sm100/fwd_decode/combine.py": "wIvKZzHissMLe83PUbybUoM39HTMIAexHw5I1yfJH94=", + "src/sm100/fwd_decode/tile_scheduler.py": "OWdID5fCFmwXqz6RtseFphfJtezOOQ091K+bJFcD6bc=", + "src/sm100/prepare_k2q_csr.py": "nCeG6m24dLNwJeQDFppjqR3wVCDxMY0we+20zEEeMy8=", + "src/sm100/prepare_scheduler.py": "CQuJI6Fn0uR0oMcfzmlIH+bjg+2uKTzqCXbw5H0YgSw=" + } + } +} \ No newline at end of file diff --git a/build/torch211-cxx11-cu130-x86_64-linux/metadata.json.sigstore b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json.sigstore new file mode 100644 index 0000000000000000000000000000000000000000..43ee11d211992759dfd7733018752cf3d0828d42 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json.sigstore @@ -0,0 +1 @@ +{"mediaType":"application/vnd.dev.sigstore.bundle.v0.3+json","verificationMaterial":{"certificate":{"rawBytes":"MIIHTDCCBtKgAwIBAgIUVVP4M4gmTl/ZRCYCkyL6hrHrfFowCgYIKoZIzj0EAwMwNzEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MR4wHAYDVQQDExVzaWdzdG9yZS1pbnRlcm1lZGlhdGUwHhcNMjYwNjMwMTc0NDA4WhcNMjYwNjMwMTc1NDA4WjAAMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE1g3lt+NIVbPFRc82LW1gyIDDa5k1Ee8LidW1EfmmOhgabGHRtDQeNNYXVEmRZm7/A/7cDuwuAE60+3NEbtoNL6OCBfEwggXtMA4GA1UdDwEB/wQEAwIHgDATBgNVHSUEDDAKBggrBgEFBQcDAzAdBgNVHQ4EFgQUDS+MoOyWzeRZKmVrckA5xixJ2PswHwYDVR0jBBgwFoAU39Ppz1YkEZb5qNjpKFWixi4YZD8wawYDVR0RAQH/BGEwX4ZdaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5Ly5naXRodWIvd29ya2Zsb3dzL2J1aWxkLnlhbWxAcmVmcy9oZWFkcy9tYWluMDkGCisGAQQBg78wAQEEK2h0dHBzOi8vdG9rZW4uYWN0aW9ucy5naXRodWJ1c2VyY29udGVudC5jb20wHwYKKwYBBAGDvzABAgQRd29ya2Zsb3dfZGlzcGF0Y2gwNgYKKwYBBAGDvzABAwQoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTATBgorBgEEAYO/MAEEBAVCdWlsZDArBgorBgEEAYO/MAEFBB1odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eTAdBgorBgEEAYO/MAEGBA9yZWZzL2hlYWRzL21haW4wOwYKKwYBBAGDvzABCAQtDCtodHRwczovL3Rva2VuLmFjdGlvbnMuZ2l0aHVidXNlcmNvbnRlbnQuY29tMG0GCisGAQQBg78wAQkEXwxdaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5Ly5naXRodWIvd29ya2Zsb3dzL2J1aWxkLnlhbWxAcmVmcy9oZWFkcy9tYWluMDgGCisGAQQBg78wAQoEKgwoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTAbBgorBgEEAYO/MAELBA0MC3NlbGYtaG9zdGVkMEAGCisGAQQBg78wAQwEMgwwaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5MDgGCisGAQQBg78wAQ0EKgwoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTAfBgorBgEEAYO/MAEOBBEMD3JlZnMvaGVhZHMvbWFpbjAaBgorBgEEAYO/MAEPBAwMCjEwNzE0NzU1MjkwLgYKKwYBBAGDvzABEAQgDB5odHRwczovL2dpdGh1Yi5jb20vaHVnZ2luZ2ZhY2UwGAYKKwYBBAGDvzABEQQKDAgyNTcyMDc0MzBtBgorBgEEAYO/MAESBF8MXWh0dHBzOi8vZ2l0aHViLmNvbS9odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eS8uZ2l0aHViL3dvcmtmbG93cy9idWlsZC55YW1sQHJlZnMvaGVhZHMvbWFpbjA4BgorBgEEAYO/MAETBCoMKDA5ZDc4NTE1YzU1MzJlNzAwMjcwZTllMTM1NTZhMmFkMDJlOWY1ZjkwIQYKKwYBBAGDvzABFAQTDBF3b3JrZmxvd19kaXNwYXRjaDBkBgorBgEEAYO/MAEVBFYMVGh0dHBzOi8vZ2l0aHViLmNvbS9odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eS9hY3Rpb25zL3J1bnMvMjg0NjM5NjE5NTUvYXR0ZW1wdHMvMTAWBgorBgEEAYO/MAEWBAgMBnB1YmxpYzBGBgorBgEEAYO/MAEYBDgMNnJlcG86aHVnZ2luZ2ZhY2Uva2VybmVscy1jb21tdW5pdHk6cmVmOnJlZnMvaGVhZHMvbWFpbjCBiwYKKwYBBAHWeQIEAgR9BHsAeQB3AN09MGrGxxEyYxkeHJlnNwKiSl643jyt/4eKcoAvKe6OAAABnxmhmjwAAAQDAEgwRgIhAPQX2iwR1v4JZhAId6Hp1jokY/sGjA+YAAPAKGYyEn83AiEA3EVC3Qe0rTg3kZBS8GV1GE1WrEMYJZpFiKxgw0tsm7owCgYIKoZIzj0EAwMDaAAwZQIwMInW/0wVz0LCkC4XkpYv0D+rmFVSX7QhPmFcU75eqDKuM9zWjsjyA4J7VI2JQ9kCAjEA418fgIyoISriIi08+3oujNlNDxdJdr3xH6rjdchgsXuWWxkX+RHQB/yT7HmRAvU+"},"tlogEntries":[{"logIndex":"2024793428","logId":{"keyId":"wNI9atQGlz+VWfO6LRygH4QUfY/8W4RFwiT5i5WRgB0="},"kindVersion":{"kind":"hashedrekord","version":"0.0.1"},"integratedTime":"1782841449","inclusionPromise":{"signedEntryTimestamp":"MEQCIHwVjgjqMooWGrQZwBQWds3oaBax/EWLh7hv8GUkvYQDAiAEilhfTfnZnwOuRoPz1ni7/FhIeMP/t19OS935qAwthg=="},"inclusionProof":{"logIndex":"1902889166","rootHash":"PAS83ozl6aPIQR05MLeOs0DVUmDpUz8fv9kxpV0qmU4=","treeSize":"1902889170","hashes":["q4a6ZMHnYHWHnttum84EAY+XdabO++pGcTLha6Qj9sA=","hC+Bq91LQ9Xq6tYWzOLgH0iAgjRG9u0pQDUfp7gLeqY=","X/c6hmAA2zEKQ1y1eZTcTVaW3qGpio2fMZCsOZtkqkg=","c+RxtTibysKYzSurSzL0S5arsdYvjjw/nlAE7obI8J8=","lOlr8plq1OiOmAwap+xmNZYYZAQ8m/i5u9knC5Ej+Kk=","hMZERm0o1e7Zeq+6+/Uswm7K85VIy79GXIJNmHvHIkM=","QDXnrDViAE4d1fmZDEBGBWa0x8ebVY/zivZduxoii4g=","5DB/VRMbICRg24kfvBoq+aFOMwCKvhr1zQj5SpDh5Ck=","NRxwUF55kxkZUtVui8nzfzj4LLT960XpxpXnY6C7pqs=","KTak07KIu/wsxelNu7DaqjZg2G0WnevWjQkjflcCfjI=","o03232Stm2HWKs2uG6lq2NP4O1Zym1pjI+LbQCbPISY=","nGtXNKgDUZj+ZjPgQKuKFp9orlBq81iSk8yjysQUTIU=","+/rlNRIrSvbSLthLGxHY8saYzo8HTl12uoWcFuXbbE0=","tC4XX6tUr8g/3yF+0T8f2DfrTWQmbDBfMxTOmNuWyzI=","E8u2TYaBleTNUd9vupjpxhOMu+bExC1kpTjfOk2GAUA=","cJbCQtmuzzN6T9df9SuhiY4cyCN7ezf1n+yFrgRkcgE=","+/VZ56MsIPxMiyLAodzKXo5TEWdQp36z89qLhpzloAo=","daxmZaajRpZV+JxHiOYZhJBiSKN5ucqjh2WnGbHhirw=","DOCeoSMovIvLExkhIvisow9AuNXgeWs4ECkyR6EcqYU="],"checkpoint":{"envelope":"rekor.sigstore.dev - 1193050959916656506\n1902889170\nPAS83ozl6aPIQR05MLeOs0DVUmDpUz8fv9kxpV0qmU4=\n\n— rekor.sigstore.dev wNI9ajBFAiAH3GA0IRxXyVLBSOJqQuhNL8Lfq8W7xlL+gVGZPzNnKQIhAI05TRHG+jAQbsGKBGmjPuU5SoHZDluTPWrbLFPUJ8t0\n"}},"canonicalizedBody":"eyJhcGlWZXJzaW9uIjoiMC4wLjEiLCJraW5kIjoiaGFzaGVkcmVrb3JkIiwic3BlYyI6eyJkYXRhIjp7Imhhc2giOnsiYWxnb3JpdGhtIjoic2hhMjU2IiwidmFsdWUiOiJkNDczMzc1MDg0NDk0YWVjYWQ5ZjNhMzEyMTYxMWJmNjM3MzhmNWEwZmU2YjAyOGI0ZTNiYjUzYzQ0ZGViZDAwIn19LCJzaWduYXR1cmUiOnsiY29udGVudCI6Ik1FUUNJSFhnUUgwOWh3UHpybkQ1UHVISnQ4UWJZK293N1duT1k2TWVDdk5XUnc1SEFpQlp1MDFsd0ZBU1ozb1NqeGUyQThSMWJZRjR6enFObHdwZHlHVEhNcFhuOWc9PSIsInB1YmxpY0tleSI6eyJjb250ZW50IjoiTFMwdExTMUNSVWRKVGlCRFJWSlVTVVpKUTBGVVJTMHRMUzB0Q2sxSlNVaFVSRU5EUW5STFowRjNTVUpCWjBsVlZsWlFORTAwWjIxVWJDOWFVa05aUTJ0NVREWm9ja2h5WmtadmQwTm5XVWxMYjFwSmVtb3dSVUYzVFhjS1RucEZWazFDVFVkQk1WVkZRMmhOVFdNeWJHNWpNMUoyWTIxVmRWcEhWakpOVWpSM1NFRlpSRlpSVVVSRmVGWjZZVmRrZW1SSE9YbGFVekZ3WW01U2JBcGpiVEZzV2tkc2FHUkhWWGRJYUdOT1RXcFpkMDVxVFhkTlZHTXdUa1JCTkZkb1kwNU5hbGwzVG1wTmQwMVVZekZPUkVFMFYycEJRVTFHYTNkRmQxbElDa3R2V2tsNmFqQkRRVkZaU1V0dldrbDZhakJFUVZGalJGRm5RVVV4WnpOc2RDdE9TVlppVUVaU1l6Z3lURmN4WjNsSlJFUmhOV3N4UldVNFRHbGtWekVLUldadGJVOW9aMkZpUjBoU2RFUlJaVTVPV1ZoV1JXMVNXbTAzTDBFdk4yTkVkWGQxUVVVMk1Dc3pUa1ZpZEc5T1REWlBRMEptUlhkbloxaDBUVUUwUndwQk1WVmtSSGRGUWk5M1VVVkJkMGxJWjBSQlZFSm5UbFpJVTFWRlJFUkJTMEpuWjNKQ1owVkdRbEZqUkVGNlFXUkNaMDVXU0ZFMFJVWm5VVlZFVXl0TkNtOVBlVmQ2WlZKYVMyMVdjbU5yUVRWNGFYaEtNbEJ6ZDBoM1dVUldVakJxUWtKbmQwWnZRVlV6T1ZCd2VqRlphMFZhWWpWeFRtcHdTMFpYYVhocE5Ga0tXa1E0ZDJGM1dVUldVakJTUVZGSUwwSkhSWGRZTkZwa1lVaFNNR05JVFRaTWVUbHVZVmhTYjJSWFNYVlpNamwwVERKb01Wb3laSEJpYldSdFdWZE9iQXBNTW5Sc1kyMDFiR0pJVFhSWk1qbDBZbGhXZFdGWVVqVk1lVFZ1WVZoU2IyUlhTWFprTWpsNVlUSmFjMkl6WkhwTU1rb3hZVmQ0YTB4dWJHaGlWM2hCQ21OdFZtMWplVGx2V2xkR2EyTjVPWFJaVjJ4MVRVUnJSME5wYzBkQlVWRkNaemM0ZDBGUlJVVkxNbWd3WkVoQ2VrOXBPSFprUnpseVdsYzBkVmxYVGpBS1lWYzVkV041Tlc1aFdGSnZaRmRLTVdNeVZubFpNamwxWkVkV2RXUkROV3BpTWpCM1NIZFpTMHQzV1VKQ1FVZEVkbnBCUWtGblVWSmtNamw1WVRKYWN3cGlNMlJtV2tkc2VtTkhSakJaTW1kM1RtZFpTMHQzV1VKQ1FVZEVkbnBCUWtGM1VXOU5SR3hyVG5wbk1VMVVWbXBPVkZWNlRXMVZNMDFFUVhsT2VrSnNDazlYVlhoTmVsVXhUbTFGZVZsWFVYZE5iVlUxV21wV2JVOVVRVlJDWjI5eVFtZEZSVUZaVHk5TlFVVkZRa0ZXUTJSWGJITmFSRUZ5UW1kdmNrSm5SVVVLUVZsUEwwMUJSVVpDUWpGdlpGZGtibUZYTlc1YWJVWnFXbE01Y2xwWVNuVmFWM2g2VEZkT2RtSlhNVEZpYld3d1pWUkJaRUpuYjNKQ1owVkZRVmxQTHdwTlFVVkhRa0U1ZVZwWFducE1NbWhzV1ZkU2Vrd3lNV2hoVnpSM1QzZFpTMHQzV1VKQ1FVZEVkbnBCUWtOQlVYUkVRM1J2WkVoU2QyTjZiM1pNTTFKMkNtRXlWblZNYlVacVpFZHNkbUp1VFhWYU1td3dZVWhXYVdSWVRteGpiVTUyWW01U2JHSnVVWFZaTWpsMFRVY3dSME5wYzBkQlVWRkNaemM0ZDBGUmEwVUtXSGQ0WkdGSVVqQmpTRTAyVEhrNWJtRllVbTlrVjBsMVdUSTVkRXd5YURGYU1tUndZbTFrYlZsWFRteE1NblJzWTIwMWJHSklUWFJaTWpsMFlsaFdkUXBoV0ZJMVRIazFibUZZVW05a1YwbDJaREk1ZVdFeVduTmlNMlI2VERKS01XRlhlR3RNYm14b1lsZDRRV050Vm0xamVUbHZXbGRHYTJONU9YUlpWMngxQ2sxRVowZERhWE5IUVZGUlFtYzNPSGRCVVc5RlMyZDNiMDFFYkd0T2VtY3hUVlJXYWs1VVZYcE5iVlV6VFVSQmVVNTZRbXhQVjFWNFRYcFZNVTV0UlhrS1dWZFJkMDF0VlRWYWFsWnRUMVJCWWtKbmIzSkNaMFZGUVZsUEwwMUJSVXhDUVRCTlF6Tk9iR0pIV1hSaFJ6bDZaRWRXYTAxRlFVZERhWE5IUVZGUlFncG5OemgzUVZGM1JVMW5kM2RoU0ZJd1kwaE5Oa3g1T1c1aFdGSnZaRmRKZFZreU9YUk1NbWd4V2pKa2NHSnRaRzFaVjA1c1RESjBiR050Tld4aVNFMTBDbGt5T1hSaVdGWjFZVmhTTlUxRVowZERhWE5IUVZGUlFtYzNPSGRCVVRCRlMyZDNiMDFFYkd0T2VtY3hUVlJXYWs1VVZYcE5iVlV6VFVSQmVVNTZRbXdLVDFkVmVFMTZWVEZPYlVWNVdWZFJkMDF0VlRWYWFsWnRUMVJCWmtKbmIzSkNaMFZGUVZsUEwwMUJSVTlDUWtWTlJETktiRnB1VFhaaFIxWm9Xa2hOZGdwaVYwWndZbXBCWVVKbmIzSkNaMFZGUVZsUEwwMUJSVkJDUVhkTlEycEZkMDU2UlRCT2VsVXhUV3ByZDB4bldVdExkMWxDUWtGSFJIWjZRVUpGUVZGbkNrUkNOVzlrU0ZKM1kzcHZka3d5WkhCa1IyZ3hXV2sxYW1JeU1IWmhTRlp1V2pKc2RWb3lXbWhaTWxWM1IwRlpTMHQzV1VKQ1FVZEVkbnBCUWtWUlVVc0tSRUZuZVU1VVkzbE5SR013VFhwQ2RFSm5iM0pDWjBWRlFWbFBMMDFCUlZOQ1JqaE5XRmRvTUdSSVFucFBhVGgyV2pKc01HRklWbWxNYlU1MllsTTVid3BrVjJSdVlWYzFibHB0Um1wYVV6bHlXbGhLZFZwWGVIcE1WMDUyWWxjeE1XSnRiREJsVXpoMVdqSnNNR0ZJVm1sTU0yUjJZMjEwYldKSE9UTmplVGxwQ21SWGJITmFRelUxV1ZjeGMxRklTbXhhYmsxMllVZFdhRnBJVFhaaVYwWndZbXBCTkVKbmIzSkNaMFZGUVZsUEwwMUJSVlJDUTI5TlMwUkJOVnBFWXpRS1RsUkZNVmw2VlRGTmVrcHNUbnBCZDAxcVkzZGFWR3hzVFZSTk1VNVVXbWhOYlVaclRVUktiRTlYV1RGYWFtdDNTVkZaUzB0M1dVSkNRVWRFZG5wQlFncEdRVkZVUkVKR00ySXpTbkphYlhoMlpERTVhMkZZVG5kWldGSnFZVVJDYTBKbmIzSkNaMFZGUVZsUEwwMUJSVlpDUmxsTlZrZG9NR1JJUW5wUGFUaDJDbG95YkRCaFNGWnBURzFPZG1KVE9XOWtWMlJ1WVZjMWJscHRSbXBhVXpseVdsaEtkVnBYZUhwTVYwNTJZbGN4TVdKdGJEQmxVemxvV1ROU2NHSXlOWG9LVEROS01XSnVUWFpOYW1jd1RtcE5OVTVxUlRWT1ZGVjJXVmhTTUZwWE1YZGtTRTEyVFZSQlYwSm5iM0pDWjBWRlFWbFBMMDFCUlZkQ1FXZE5RbTVDTVFwWmJYaHdXWHBDUjBKbmIzSkNaMFZGUVZsUEwwMUJSVmxDUkdkTlRtNUtiR05IT0RaaFNGWnVXakpzZFZveVdtaFpNbFYyWVRKV2VXSnRWbk5qZVRGcUNtSXlNWFJrVnpWd1pFaHJObU50Vm0xUGJrcHNXbTVOZG1GSFZtaGFTRTEyWWxkR2NHSnFRMEpwZDFsTFMzZFpRa0pCU0ZkbFVVbEZRV2RTT1VKSWMwRUtaVkZDTTBGT01EbE5SM0pIZUhoRmVWbDRhMlZJU214dVRuZExhVk5zTmpRemFubDBMelJsUzJOdlFYWkxaVFpQUVVGQlFtNTRiV2h0YW5kQlFVRlJSQXBCUldkM1VtZEphRUZRVVZneWFYZFNNWFkwU2xwb1FVbGtOa2h3TVdwdmExa3ZjMGRxUVN0WlFVRlFRVXRIV1hsRmJqZ3pRV2xGUVRORlZrTXpVV1V3Q25KVVp6TnJXa0pUT0VkV01VZEZNVmR5UlUxWlNscHdSbWxMZUdkM01IUnpiVGR2ZDBObldVbExiMXBKZW1vd1JVRjNUVVJoUVVGM1dsRkpkMDFKYmxjS0x6QjNWbm93VEVOclF6UllhM0JaZGpCRUszSnRSbFpUV0RkUmFGQnRSbU5WTnpWbGNVUkxkVTA1ZWxkcWMycDVRVFJLTjFaSk1rcFJPV3REUVdwRlFRbzBNVGhtWjBsNWIwbFRjbWxKYVRBNEt6TnZkV3BPYkU1RWVHUktaSEl6ZUVnMmNtcGtZMmhuYzFoMVYxZDRhMWdyVWtoUlFpOTVWRGRJYlZKQmRsVXJDaTB0TFMwdFJVNUVJRU5GVWxSSlJrbERRVlJGTFMwdExTMEsifX19fQ=="}],"timestampVerificationData":{"rfc3161Timestamps":[{"signedTimestamp":"MIICyTADAgEAMIICwAYJKoZIhvcNAQcCoIICsTCCAq0CAQMxDTALBglghkgBZQMEAgEwgbcGCyqGSIb3DQEJEAEEoIGnBIGkMIGhAgEBBgkrBgEEAYO/MAIwMTANBglghkgBZQMEAgEFAAQglcg+ZS9nFSD28yuBKPMOo6PT19qhNuf1jiDZWDjuRDsCFA/ufLVwoeNKsUE/mGLnBCmsAetmGA8yMDI2MDYzMDE3NDQwOVowAwIBAaAypDAwLjEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MRUwEwYDVQQDEwxzaWdzdG9yZS10c2GgADGCAdswggHXAgEBMFEwOTEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MSAwHgYDVQQDExdzaWdzdG9yZS10c2Etc2VsZnNpZ25lZAIUOhNULwyQYe68wUMvy4qOiyojiwwwCwYJYIZIAWUDBAIBoIH8MBoGCSqGSIb3DQEJAzENBgsqhkiG9w0BCRABBDAcBgkqhkiG9w0BCQUxDxcNMjYwNjMwMTc0NDA5WjAvBgkqhkiG9w0BCQQxIgQg3xy5GhdNzcu3rdvs8ij1v0JfzL6XRtqCzrYGKJAGy5IwgY4GCyqGSIb3DQEJEAIvMX8wfTB7MHkEIIX5J7wHq2LKw7RDVsEO/IGyxog/2nq55thw2dE6zQW3MFUwPaQ7MDkxFTATBgNVBAoTDHNpZ3N0b3JlLmRldjEgMB4GA1UEAxMXc2lnc3RvcmUtdHNhLXNlbGZzaWduZWQCFDoTVC8MkGHuvMFDL8uKjosqI4sMMAoGCCqGSM49BAMCBGcwZQIxAIUN/YeV+4MCr8TNSDfIGzTJ5mU0RbyUHkDCdFwX4KrCZXwa+MgFHwy8o/JYfQ3BhwIwMXOx+EZSAx4tMKpqFLugv8INtbNCIUwOA5mbhI8ag+5312K+dJgUFvh5Amu1Lx9V"}]}},"messageSignature":{"messageDigest":{"algorithm":"SHA2_256","digest":"1HM3UIRJSuytnzoxIWEb9jc49aD+awKLTju1PETevQA="},"signature":"MEQCIHXgQH09hwPzrnD5PuHJt8QbY+ow7WnOY6MeCvNWRw5HAiBZu01lwFASZ3oSjxe2A8R1bYF4zzqNlwpdyGTHMpXn9g=="}} \ No newline at end of file diff --git a/build/torch211-cxx11-cu130-x86_64-linux/msa/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/msa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/msa/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/quack/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/quack/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/quack/activation.py b/build/torch211-cxx11-cu130-x86_64-linux/quack/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..e8cbeb29242b92b7cc336cd336604e58c36f4459 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/quack/activation.py @@ -0,0 +1,532 @@ +# Copyright (c) 2025, Tri Dao. + +import math +from typing import Tuple +from functools import partial + +import cutlass.cute as cute +from cutlass import Float32, Boolean, const_expr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm, nvvm + + +F32_or_F32x2 = Float32 | Tuple[Float32, Float32] + + +sub_packed_f32x2 = partial( + cute.arch.calc_packed_f32x2_op, + src_c=None, + calc_func=nvvm.sub_packed_f32x2, +) + + +@dsl_user_op +def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "tanh.approx.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True) + return 0.5 + 0.5 * tanh(0.5 * x) + else: + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) + tanh_x_half = (tanh(x_half[0]), tanh(x_half[1])) + return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5)) + + +@dsl_user_op +def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32: + # return dout * out * (1.0 - out) + return dout * (out - out * out) + + +@dsl_user_op +def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) + else: + return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)) + + +@dsl_user_op +@cute.jit +def drelu( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + if const_expr(not isinstance(x, tuple)): + x_pos = Boolean(x > 0) + return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0)) + else: + x0_pos = Boolean(x[0] > 0) + x1_pos = Boolean(x[1] > 0) + dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0)) + return dx, relu(x) + + +@dsl_user_op +def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) * x + else: + relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))) + return cute.arch.mul_packed_f32x2(relu_x, x) + + +@dsl_user_op +@cute.jit +def drelu_sq( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + """ + ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward + Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out + Returns: (dx, relu_sq_out) where: + - dx = dout * 2 * x if x > 0, else 0 + - relu_sq_out = max(x, 0) * x + """ + if const_expr(not isinstance(x, tuple)): + relu_x = relu(x) + relu_sq_out = relu_x * x + # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0 + dx = 2.0 * (dout * relu_x) + return dx, relu_sq_out + else: + relu_x = relu(x) + relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x) + dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x)) + return dx, relu_sq_out + + +@dsl_user_op +def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """ + gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x))) + """ + sqrt_2_over_pi = math.sqrt(2 / math.pi) # ~0.797885 + sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # ~0.0356774 + if const_expr(not isinstance(x, tuple)): + return 0.5 * ( + x + # Currently cute.math.tanh(x, fastmath=True) generates very slow code + # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True)) + * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)))) + ) + else: + x_sq = cute.arch.mul_packed_f32x2(x, x) + x_sq_scaled = cute.arch.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + z = cute.arch.mul_packed_f32x2(x, x_sq_scaled) + tanh_z = (tanh(z[0]), tanh(z[1])) + x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x) + return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z) + + +@dsl_user_op +def dgelu_tanh_approx( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + """ + GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward + Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out + Returns: (dx, gelu_out) + + Derivative uses the chain rule: + d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2 + and sech^2(z) = 1 - tanh^2(z) + """ + sqrt_2_over_pi = math.sqrt(2 / math.pi) # c1 ~0.797885 + sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # c2 ~0.0356774 + sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff # c3 ~0.01070322 + + if const_expr(not isinstance(x, tuple)): + # Compute z = x * (c1 + c2 * x^2) + x_sq = x * x + # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True) + tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq)) + half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z + gelu_out = x * half_tanh_z_plus_one + + # Compute gradient + # sech^2(z) = 1 - tanh^2(z) + sech2_z = 1 - tanh_z * tanh_z + # dz/dx = c1 + 3 * c2 * x^2 + dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq + # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx)) + + dx = dout * dgelu + return dx, gelu_out + else: + # Compute z = x * (c1 + c2 * x^2) + x_sq = cute.arch.mul_packed_f32x2(x, x) + x_sq_scaled = cute.arch.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + z = cute.arch.mul_packed_f32x2(x, x_sq_scaled) + tanh_z = (tanh(z[0]), tanh(z[1])) + half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5)) + gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one) + + # Compute gradient + # sech^2(z) = 1 - tanh^2(z) + sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0)) + # dz/dx = c1 + 3 * c2 * x^2 + dz_dx = cute.arch.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx) + x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx) + dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one) + + dx = cute.arch.mul_packed_f32x2(dout, dgelu) + return dx, gelu_out + + +@dsl_user_op +@cute.jit +def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + use_linear = Boolean(x > 20.0) + return ( + cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True) + if not use_linear + else x + ) + else: + log2_e = math.log2(math.e) + x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e)) + x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True)) + x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0)) + log_x_exp_p1 = ( + cute.math.log2(x_exp_p1[0], fastmath=True), + cute.math.log2(x_exp_p1[1], fastmath=True), + ) + ln2 = math.log(2.0) + softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2)) + use_linear_0 = Boolean(x[0] > 20.0) + use_linear_1 = Boolean(x[1] > 20.0) + return ( + softplus_x[0] if not use_linear_0 else x[0], + softplus_x[1] if not use_linear_1 else x[1], + ) + + +@dsl_user_op +@cute.jit +def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32: + use_linear = Boolean(out > 20.0) + # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout + dx = dout - dout * cute.math.exp(-out, fastmath=True) + return dx if not use_linear else dout + + +@dsl_user_op +def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2: + """ + silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x) + This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA. + """ + if const_expr(not isinstance(x, tuple)): + x_half = 0.5 * x if const_expr(not already_halved) else x + # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half + return x_half * tanh(x_half) + x_half + else: + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x + tanh_x_half = (tanh(x_half[0]), tanh(x_half[1])) + return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half) + + +@dsl_user_op +def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return silu(x) * y + else: + return cute.arch.mul_packed_f32x2(silu(x), y) + + +@dsl_user_op +def dswiglu( + x: F32_or_F32x2, + y: F32_or_F32x2, + dout: F32_or_F32x2, + *, + already_halved: bool = False, + loc=None, + ip=None, +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out + Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x) + + d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))) + + This has been optimized to use fewer instructions (i.e. we expand things out + to use FFMA instead of FADD and FMUL). + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x)) + # FMUL, MUFU.TANH, then FFMA + if const_expr(not already_halved): + sigmoid_x = sigmoid(x) + silu_x = x * sigmoid_x # FMUL + else: + tanh_x = tanh(x) # MUFU.TANH + sigmoid_x = 0.5 * tanh_x + 0.5 # FFMA + silu_x = x * tanh_x + x # FFMA + silu_x_dout = silu_x * dout # FMUL + # d_silu(x) * dout + # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout + # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout + d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout # FFMA, FFMA + dx = d_silu_x_dout * y # FMUL + dy = silu_x_dout + swiglu_out = silu_x * y # FMUL + # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA + return dx, dy, swiglu_out + else: + # Compute sigmoid(x) and silu(x) + if const_expr(not already_halved): + sigmoid_x = sigmoid(x) + silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x) + else: + tanh_x = (tanh(x[0]), tanh(x[1])) + sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5)) + silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x) + silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout) + # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout + sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2( + sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x + ) + d_silu_x_dout = cute.arch.fma_packed_f32x2( + sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout + ) + dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y) + dy = silu_x_dout + swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y) + return dx, dy, swiglu_out + + +@dsl_user_op +def swiglu_oai( + x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None +) -> F32_or_F32x2: + """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y. + https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249 + x * sigmoid(alpha * x) * (y + 1) + Compile down to FMUL, FMUL, TANH, FFMA, FFMA + """ + # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2)) + if const_expr(not isinstance(x, tuple)): + x_half = 0.5 * x + # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half + silu_x = x_half * tanh(alpha * x_half) + x_half + return silu_x * y + silu_x + else: + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) + alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half) + tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1])) + silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half) + return cute.arch.fma_packed_f32x2(silu_x, y, silu_x) + + +@dsl_user_op +def dswiglu_oai( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + Swiglu OAI backward pass: computes gradients w.r.t. x and y + Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out + Returns: (dx, dy, swiglu_oai_out) + + Derivative of x * sigmoid(alpha * x) w.r.t. x: + d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x)) + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2)) + alpha_x_half = (0.5 * alpha) * x # FMUL + # MUFU.TANH, then FFMA + # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True) + sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half) + silu_x = x * sigmoid_alpha_x # FMUL + silu_x_dout = silu_x * dout # FMUL + # FFMA, FFMA, FMUL + d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout + dx = d_silu_x_dout * y + d_silu_x_dout # FFMA, instead of multiply by y + 1 + dy = silu_x_dout + swiglu_out = silu_x * y + silu_x # FFMA, instead of multiply by y + 1 + # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA + return dx, dy, swiglu_out + else: + # Compute sigmoid(alpha * x) + alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x) + tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1])) + sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5)) + silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x) + silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout) + # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout + silu_x_minus_product = cute.arch.fma_packed_f32x2( + silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x + ) + sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2( + (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x + ) + d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout) + dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout) + dy = silu_x_dout + swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x) + return dx, dy, swiglu_out + + +@dsl_user_op +def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """GLU: Gated Linear Unit + glu(x, y) = sigmoid(x) * y + Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2)) + """ + if const_expr(not isinstance(x, tuple)): + sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA + return sigmoid_x * y # FMUL + else: + sigmoid_x = sigmoid(x) + return cute.arch.mul_packed_f32x2(sigmoid_x, y) + + +@dsl_user_op +def dglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out + Returns: (dx, dy, glu_out) where: + - dx = dout * y * sigmoid(x) * (1 - sigmoid(x)) + - dy = dout * sigmoid(x) + - glu_out = sigmoid(x) * y + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2)) + sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA + sigmoid_x_dout = sigmoid_x * dout # FMUL + glu_out = sigmoid_x * y # FMUL + # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout + # = y * (1 - sigmoid(x)) * sigmoid_x_dout + # = (y - y * sigmoid(x)) * sigmoid_x_dout + # = (y - glu_out) * sigmoid_x_dout + dx = (y - glu_out) * sigmoid_x_dout # FADD, FMUL + dy = sigmoid_x_dout + # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA + return dx, dy, glu_out + else: + sigmoid_x = sigmoid(x) + sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout) + glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y) + # dx = (y - glu_out) * sigmoid_x_dout + y_minus_glu_out = sub_packed_f32x2(y, glu_out) + dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout) + dy = sigmoid_x_dout + return dx, dy, glu_out + + +@dsl_user_op +def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """ReGLU: ReLU Gated Linear Unit + reglu(x, y) = relu(x) * y = max(x, 0) * y + """ + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) * y + else: + relu_x = relu(x) + return cute.arch.mul_packed_f32x2(relu_x, y) + + +@dsl_user_op +@cute.jit +def dreglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out + Returns: (dx, dy, reglu_out) where: + - dx = dout * y if x > 0, else 0 + - dy = dout * relu(x) + - reglu_out = relu(x) * y + """ + if const_expr(not isinstance(x, tuple)): + x_pos = Boolean(x > 0) + relu_x = cute.arch.fmax(x, Float32(0.0)) + dx = (dout * y) if x_pos else Float32(0.0) + dy = dout * relu_x + reglu_out = relu_x * y + return dx, dy, reglu_out + else: + x0_pos = Boolean(x[0] > 0) + x1_pos = Boolean(x[1] > 0) + relu_x = relu(x) + dout_y = cute.arch.mul_packed_f32x2(dout, y) + dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0))) + dy = cute.arch.mul_packed_f32x2(dout, relu_x) + reglu_out = cute.arch.mul_packed_f32x2(relu_x, y) + return dx, dy, reglu_out + + +@dsl_user_op +def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """GeGLU: GELU Gated Linear Unit + geglu(x, y) = gelu(x) * y + Uses the tanh approximation of GELU + """ + if const_expr(not isinstance(x, tuple)): + return gelu_tanh_approx(x) * y + else: + return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y) + + +@dsl_user_op +def dgeglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out + Returns: (dx, dy, geglu_out) where: + - dx = dout * y * d_gelu(x) + - dy = dout * gelu(x) + - geglu_out = gelu(x) * y + """ + if const_expr(not isinstance(x, tuple)): + # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x) + dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout) + # Compute gradients for geglu + dx = dgelu_x_dout * y + dy = gelu_x * dout + geglu_out = gelu_x * y + return dx, dy, geglu_out + else: + # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x) + dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout) + # Compute gradients for geglu + dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y) + dy = cute.arch.mul_packed_f32x2(gelu_x, dout) + geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y) + return dx, dy, geglu_out diff --git a/build/torch211-cxx11-cu130-x86_64-linux/quack/compile_utils.py b/build/torch211-cxx11-cu130-x86_64-linux/quack/compile_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4375594669c8f12d6a79d8878316271cb819568a --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/quack/compile_utils.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +from typing import Optional + +import cutlass.cute as cute + + +def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]: + if leading_dim < 0: + leading_dim = len(shape) + leading_dim + if dtype is None: + return None + stride = tuple( + cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1 + for i in range(len(shape)) + ) + return cute.runtime.make_fake_tensor( + dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8 + ) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/quack/copy_utils.py b/build/torch211-cxx11-cu130-x86_64-linux/quack/copy_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad989559766d6ee6e8ece9d322bf08980706dfa --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/quack/copy_utils.py @@ -0,0 +1,890 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +import re +from typing import Optional, Type, Tuple, Callable, Sequence +from functools import partial + +import cutlass +import cutlass.cute as cute + +from cutlass import Int32, Int16, Boolean, const_expr +from cutlass.cute.nvgpu import cpasync, warp, warpgroup +from cutlass.cute.nvgpu.tcgen05.mma import CtaGroup # noqa +from cutlass.cutlass_dsl import dsl_user_op +import cutlass.pipeline +from cutlass._mlir.dialects import llvm +from cutlass._mlir import ir +from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir + + +Sm100MmaPeerBitMask = 0xFEFFFFFF + + +@dsl_user_op +def cvt_copy( + tiled_copy: cute.TiledCopy, + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + retile: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + if const_expr(src.element_type != dst.element_type): + src_cvt = cute.make_fragment_like(src, dst.element_type) + src_cvt.store(src.load().to(dst.element_type)) + src = src_cvt + if const_expr(retile): + src = tiled_copy.retile(src) + cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +@dsl_user_op +def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip) + cute.autovec_copy(src, dst, loc=loc, ip=ip) + return dst + + +@dsl_user_op +def load_s2r_retile( + tiled_copy: cute.TiledCopy, + src: cute.Tensor, + dst_shape: cute.Tensor | cute.Shape, + *, + loc=None, + ip=None, +) -> cute.Tensor: + # Will also accept dst_shape being a tensor, in which case we write into that tensor + if const_expr(not isinstance(dst_shape, cute.Tensor)): + dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip) + else: + dst = dst_shape + cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip) + return dst + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None +) -> cute.CopyAtom: + num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + num_copy_elems = src.shape[0][0] + copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +def tiled_copy_1d( + dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + thr_layout = cute.make_layout(num_threads) + val_layout = cute.make_layout(num_copy_elems) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def tiled_copy_2d( + dtype: Type[cutlass.Numeric], + threads_per_row: int, + num_threads: int, + num_copy_elems: int = 1, + is_async: bool = False, +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + assert num_threads % threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // threads_per_row, threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, num_copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +@cute.jit +def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor: + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" + tApA = cute.make_rmem_tensor( + cute.make_layout( + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + return tApA + + +# def tiled_copy_2d( +# dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +# ) -> cute.TiledCopy: +# num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width +# copy_elems = num_copy_bits // dtype.width +# copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() +# copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) +# gmem_threads_per_row = major_mode_size // copy_elems +# assert num_threads % gmem_threads_per_row == 0 +# thr_layout = cute.make_ordered_layout( +# (num_threads // gmem_threads_per_row, gmem_threads_per_row), +# order=(1, 0), +# ) +# val_layout = cute.make_layout((1, copy_elems)) +# return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]: + """Extract swizzle parameters from a pointer's swizzle_type. + + The swizzle_type string has the form '!cute.swizzle<"S">' where + b, m, s are the swizzle parameters (bits, base, shift). + + Returns: + A cute.Swizzle object constructed from the extracted parameters + + Raises: + ValueError: If the swizzle_type string cannot be parsed + """ + # Ideally there should be a better API to get swizzle parameters, but we'll just parse + # the string here. + swizzle_str = str(ptr.type.swizzle_type) + # Extract the inner part "S" + match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str) + if match: + b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3)) + return b, m, s + else: + raise ValueError(f"Could not parse swizzle_type: {swizzle_str}") + + +def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32: + bit_msk = (1 << b) - 1 + yyy_msk = bit_msk << (m + s) + return ptr_int ^ ((ptr_int & yyy_msk) >> s) + + +def swizzle_ptr(ptr: cute.Pointer): + b, m, s = parse_swizzle_from_pointer(ptr) + ptr_int = swizzle_int(ptr.toint(), b, m, s) + return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment) + + +def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor: + outer = tensor.layout + width = tensor.element_type.width + inner = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator)) + # Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for + # for 16 bits and <3, 2, 3> for 32 bits) + new_layout = cute.recast_layout( + width, 8, cute.make_composed_layout(inner, 0, cute.recast_layout(8, width, outer)) + ) + # recast_ptr to remove the pointer swizzle + return cute.make_tensor(cute.recast_ptr(tensor.iterator, dtype=tensor.element_type), new_layout) + + +def partition_D_position_independent( + thr_copy: cute.core.ThrCopy, tensor: cute.Tensor +) -> cute.Tensor: + return cute.make_tensor( + swizzle_ptr(thr_copy.partition_D(tensor).iterator), + thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout, + ) + + +def partition_S_position_independent( + thr_copy: cute.core.ThrCopy, tensor: cute.Tensor +) -> cute.Tensor: + return cute.make_tensor( + swizzle_ptr(thr_copy.partition_S(tensor).iterator), + thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout, + ) + + +@dsl_user_op +def sm90_get_smem_load_op( + layout_c: cutlass.utils.LayoutEnum, + elem_ty_c: Type[cutlass.Numeric], + *, + loc=None, + ip=None, +) -> cute.CopyAtom: + """ + Selects the largest vectorized smem load atom available subject to constraint of gmem layout. + + Parameters: + ----------- + layout_c : LayoutEnum + The layout enum of the output tensor D. + + elem_ty_c : Type[Numeric] + The element type for output tensor D. + + Returns: + -------- + Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters. + """ + + if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta): + raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}") + is_m_major = layout_c.is_m_major_c() + if elem_ty_c.width == 16: + return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip) + else: + return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip) + + +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=(2 if not transpose else 1) * element_type.width, + ) + else: + return cute.make_copy_atom( + warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +def get_smem_load_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=(2 if not transpose else 1) * element_type.width, + ) + else: + return cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +def get_smem_store_C( + tiled_mma: cute.TiledMma, + sC: cute.Tensor, + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sC.element_type + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tRS_sC = thr_copy.partition_D(sC) + else: + tRS_sC = partition_D_position_independent(thr_copy, sC) + + def copy_fn(src: cute.Tensor, dst_idx: Optional[Int32] = None, **new_kwargs): + dst_tensor = tRS_sC if const_expr(dst_idx is None) else tRS_sC[None, None, None, dst_idx] + cvt_copy(tiled_copy, src, dst_tensor, retile=True, **new_kwargs) + + return copy_fn, thr_copy, tRS_sC + + +def get_smem_load_C( + tiled_mma: cute.TiledMma, + sC: cute.Tensor, + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sC.element_type + copy_atom = get_smem_load_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tSR_sC = thr_copy.partition_S(sC) + else: + tSR_sC = partition_S_position_independent(thr_copy, sC) + copy_atom_RS = get_smem_store_atom(arch, dtype, transpose) + thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx) + tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape + + def copy_fn(src_idx: Optional[Int32] = None, **new_kwargs): + src_tensor = tSR_sC if const_expr(src_idx is None) else tSR_sC[None, None, None, src_idx] + return load_s2r_retile(tiled_copy, src_tensor, dst_shape=tRS_shape, **new_kwargs) + + return copy_fn, thr_copy, tSR_sC + + +def epilog_smem_copy_atom( + tiled_mma: cute.TiledMma, epi_tile: cute.Shape, transpose: bool = False +) -> cute.TiledCopy: + copy_atom_C = cute.make_copy_atom( + warp.StMatrix8x8x16bOp(transpose, num_matrices=4 if epi_tile[1] % 16 == 0 else 2), + cutlass.Float16, # this is just to get the right source layout + ) + tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + return tiled_copy_C_atom + + +def get_smem_store_epi( + tiled_mma: cute.TiledMma, + epi_tile: cute.Shape, + sC: Optional[cute.Tensor], + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor, cute.Tensor]: + dtype = sC.element_type if const_expr(sC is not None) else cutlass.Float16 + tiled_copy_C_atom = epilog_smem_copy_atom(tiled_mma, epi_tile) + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_S(copy_atom, tiled_copy_C_atom) + thr_copy = tiled_copy.get_slice(tidx) + tRS_sC = None + if const_expr(sC is not None): + if const_expr(not position_independent): + tRS_sC = thr_copy.partition_D(sC) + else: + tRS_sC = partition_D_position_independent(thr_copy, sC) + sC_shape = sC.shape[:2] if sC is not None else epi_tile + # (R2S, R2S_M, R2S_N, PIPE_C) + tRS_rC_shape = thr_copy.partition_S(cute.make_identity_tensor(sC_shape)).shape + tRS_rC = cute.make_rmem_tensor(tRS_rC_shape, tiled_mma.op.acc_dtype) + + def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs): + cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], **new_kwargs) + + return copy_fn if const_expr(sC is not None) else None, thr_copy, tRS_sC, tRS_rC + + +def get_smem_store_A( + tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sA.element_type + transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tRS_sA = thr_copy.partition_D(sA) + else: + tRS_sA = partition_D_position_independent(thr_copy, sA) + + def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs): + cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs) + + return copy_fn, thr_copy, tRS_sA + + +def get_smem_load_A( + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + tidx: Int32, + arch: int, + with_dst_tensor: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sA.element_type + transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN + copy_atom = get_smem_load_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tSR_sA = thr_copy.partition_S(sA) + else: + tSR_sA = partition_S_position_independent(thr_copy, sA) + tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2]) + + def copy_fn(src_idx: Int32, **new_kwargs): + return load_s2r_retile( + tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs + ) + + def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs): + return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs) + + return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA + + +@dsl_user_op +def cpasync_reduce_bulk_add_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: int | Int32, + *, + loc=None, + ip=None, +): + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()], + # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;", + # "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def get_tma_desc_addr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer: + """ + Get the address of the TMA descriptor embedded in a TMA Copy Atom. + + Extracts the constant memory address of the TMA descriptor for use with + custom PTX instructions. + + :param tma_atom: TMA Copy Atom from make_tiled_tma_atom + :return: Pointer to TMA descriptor in constant memory + + Example: + >>> desc_ptr = get_tma_descriptor_address(tma_atom) + """ + exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip) + tma_desc_ptr_type = ir.Type.parse( + "!cute.ptr>" + ) + return _cute_nvgpu_ir.get_tma_desc_addr(tma_desc_ptr_type, exec_atom, loc=loc, ip=ip) + + +@dsl_user_op +def tma_gather4_load( + tma_desc_ptr: cute.Pointer, + dst_smem_ptr: cute.Pointer, + mbarrier_ptr: cute.Pointer, + col_idx: Int32, + row_indices: Sequence[Int32], + *, + num_cta: int = 1, + multicast_mask=None, + loc=None, + ip=None, +) -> None: + """ + Perform TMA gather4 load from global memory to shared memory. + + Issues PTX instruction: + cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes + [dstMem], [tensorMap, {col_idx, row0, row1, row2, row3}], [smem_bar]; + + This loads 4 rows (specified by row_indices) from a 2D tensor at the given + column index into shared memory, using the TMA descriptor. + + :param tma_desc_ptr: Pointer to TMA descriptor in constant memory (128-byte aligned) + :type tma_desc_ptr: Pointer + :param dst_smem_ptr: Destination address in shared memory + :type dst_smem_ptr: Pointer + :param mbarrier_ptr: Pointer to mbarrier in shared memory for completion tracking + :type mbarrier_ptr: Pointer + :param col_idx: Column index + :type col_idx: Int32 + :param row_indices: Sequence of exactly 4 row indices + :type row_indices: Sequence[Int32] + :param num_cta: Number of CTAs participating (default: 1) + :type num_cta: int + :param multicast_mask: Optional multicast mask + :type multicast_mask: Int16 + + Requirements: + - row_indices must contain exactly 4 elements + - Compute capability >= SM_100 (Blackwell) + - TMA descriptor must be properly initialized for 2D tensor + + Example: + >>> from cutlass.cute.nvgpu import cpasync + >>> from cutlass.cute import core + >>> + >>> # Create TMA descriptor + >>> tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(...) + >>> tma_desc_ptr = get_tma_descriptor_address(tma_atom) + >>> + >>> # Compute indices (typically from kernel logic) + >>> col_idx = core.get(...) or 5 # Int32 value + >>> row_indices = [core.get(...) for _ in range(4)] # 4 Int32 values + >>> + >>> # Gather 4 rows at computed column + >>> tma_gather4_load( + ... tma_desc_ptr=tma_desc_ptr, + ... dst_smem_ptr=smem_ptr, + ... mbarrier_ptr=barrier_ptr, + ... col_idx=col_idx, + ... row_indices=row_indices + ... ) + """ + if len(row_indices) != 4: + raise ValueError(f"gather4 requires exactly 4 row indices, got {len(row_indices)}") + col_val = Int32(col_idx).ir_value() + row_vals = [Int32(row_idx).ir_value() for row_idx in row_indices] + # Convert pointers to integer addresses + desc_addr = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value() + dst_addr = dst_smem_ptr.toint(loc=loc, ip=ip).ir_value() + mbar_addr = mbarrier_ptr.toint(loc=loc, ip=ip) + if num_cta > 1: + # Executed by both CTAs. Set peer bit to 0 so that the + # transaction bytes will update CTA0's barrier. + mbar_addr = mbar_addr & Sm100MmaPeerBitMask + mbar_addr = mbar_addr.ir_value() + # Handle multicast_mask - may already be ir.Value or Python int + multicast_mask_val = None + if multicast_mask is not None: + multicast_mask_val = Int16(multicast_mask).ir_value() + assert multicast_mask_val is None, "multicast is not supported yet" + # Emit inline PTX for TMA gather4 + # PTX: cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes + # [dstMem], [tensorMap, {col, row0, row1, row2, row3}], [smem_bar]; + ptx = ( + f"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::{num_cta} " + "[$0], [$1, {$2, $3, $4, $5, $6}], [$7];" + ) + + llvm.inline_asm( + None, + [ + dst_addr, + desc_addr, + col_val, + row_vals[0], + row_vals[1], + row_vals[2], + row_vals[3], + mbar_addr, + ], + ptx, + "r,l,r,r,r,r,r,r", # constraints: register, long, 6x register + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +def cpasync_bulk_get_copy_fn( + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + single_stage: bool = False, + **kwargs, +) -> Callable: + group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0)) + group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + src = cute.group_modes(src_tensor, 0, group_rank_src) + dst = cute.group_modes(dst_tensor, 0, group_rank_dst) + + def copy_bulk(src_idx, dst_idx, tma_bar_ptr: cute.Pointer, **new_kwargs): + atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type) + with cute.arch.elect_one(): + cute.copy( + atom, + src[None, src_idx], + dst[None, dst_idx], + mbar_ptr=tma_bar_ptr, + **new_kwargs, + **kwargs, + ) + + def copy_bulk_single_stage(tma_bar_ptr: cute.Pointer, **new_kwargs): + atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type) + with cute.arch.elect_one(): + cute.copy(atom, src, dst, mbar_ptr=tma_bar_ptr, **new_kwargs, **kwargs) + + return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage + + +def tma_get_copy_fn( + atom: cute.CopyAtom, + cta_coord: cute.Coord, + cta_layout: cute.Layout, + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + filter_zeros: bool = False, + single_stage: bool = False, + **kwargs, +) -> Callable: + src_is_smem = const_expr( + isinstance(src_tensor.iterator, cute.Pointer) + and src_tensor.memspace == cute.AddressSpace.smem + ) + smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) + group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0)) + group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + s, g = cpasync.tma_partition( + atom, + cta_coord, + cta_layout, + cute.group_modes(smem_tensor, 0, group_rank_smem), + cute.group_modes(gmem_tensor, 0, group_rank_gmem), + ) + if const_expr(filter_zeros): + s = cute.filter_zeros(s) + g = cute.filter_zeros(g) + src, dst = (s, g) if src_is_smem else (g, s) + + def copy_tma(src_idx, dst_idx, **new_kwargs): + cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) + + def copy_tma_single_stage(**new_kwargs): + cute.copy(atom, src, dst, **new_kwargs, **kwargs) + + return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g + + +def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): + def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs): + copy( + src_idx=src_idx, + dst_idx=producer_state.index, + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), + **new_kwargs, + ) + + return copy_fn + + +@cute.jit +def gather_m_get_copy_fn( + thr_copy_A: cute.ThrCopy, + mA: cute.Tensor, # (whatever, K) + sA: cute.Tensor, # (tile_M, tile_K, STAGE) + gsAIdx: cute.Tensor, # (tile_M), either gmem or smem + limit_m: Int32, + limit_k: Int32, +) -> Callable: + tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1])) + tAsA = thr_copy_A.partition_D(sA) + # k-major + assert tAsA.shape[2] == 1 + tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2) + + is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0 + if const_expr(not is_even_m_smem): + limit_m = min(limit_m, tile_shape_mk[0]) + elems_per_load = cute.size(tAsA.shape[0][0]) + cA = cute.make_identity_tensor(tile_shape_mk) + tAcA = thr_copy_A.partition_S(cA) + t0AcA = thr_copy_A.get_slice(0).partition_S(cA) + # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0] + # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0]. + # This is so that when we do the comparison, t0AcA is known at compile time. + limit_m = limit_m - tAcA[0][0] + limit_k = limit_k - tAcA[0][1] + # Read and cache indices for A + rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1])) + cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2])) + tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean) + for m in cutlass.range(rows_per_thread, unroll_full=True): + tApA_m[m] = t0AcA[0, m, 0][0] < limit_m + m_idx = cute.make_rmem_tensor(rows_per_thread, Int32) + for m in cutlass.range(rows_per_thread, unroll_full=True): + row_idx = tAcA[0, m, 0][0] + if tApA_m[m]: + m_idx[m] = gsAIdx[row_idx] + else: + m_idx[m] = 0 # It's ok to load row 0 in the case of OOB + + mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1])) + + def copy_fn(src_idx, dst_idx, pred: bool = False): + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + mA_cur = mA_k[None, (None, src_idx)] + for m in cutlass.range_constexpr(tAcA.shape[1]): + # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape + # ((elems_per_load), thread_per_row) + # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA + # So we append 1s to the last dimension and then do tiled_divide, then slice. + mA_row = cute.tiled_divide( + cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1) + )[None, None, 0] + if const_expr(is_even_m_smem) or tApA_m[m]: + # There's only 1 load per row + assert cute.size(tAcA.shape, mode=[2]) == 1 + ki = tAcA[0, 0, 0][1] // elems_per_load + cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k) + + return copy_fn + + +@cute.jit +def gather_k_get_copy_fn( + thr_copy_A: cute.ThrCopy, + mA: cute.Tensor, # (tile_M, whatever) + sA: cute.Tensor, # (tile_M, tile_K, STAGE) + gsAIdx: cute.Tensor, # (tile_K, RestK), either gmem or smem + limit_m: Int32, + limit_k: Int32, +) -> Callable: + gAIdx, sAIdx = None, None + if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem): + gAIdx = gsAIdx + else: + assert gsAIdx.memspace == cute.AddressSpace.smem + sAIdx = gsAIdx + tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1])) + # (atom_v, CPY_M, 1, STAGE) + tAsA = thr_copy_A.partition_D(sA) + # m-major + tAsA = cute.group_modes(tAsA, 0, 3) + + is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0 + if const_expr(not is_even_m_smem): + limit_m = min(limit_m, tile_shape_mk[0]) + elems_per_load = cute.size(tAsA.shape[0][0]) + cA = cute.make_identity_tensor(tile_shape_mk) + tAcA = thr_copy_A.partition_S(cA) + t0AcA = thr_copy_A.get_slice(0).partition_S(cA) + # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0] + # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0]. + # This is so that when we do the comparison, t0AcA is known at compile time. + limit_m = limit_m - tAcA[0][0] + limit_k = limit_k - tAcA[0][1] + # Read and cache indices for A + rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1])) + cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2])) + tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean) + for m in cutlass.range(rows_per_thread, unroll_full=True): + tApA_m[m] = t0AcA[0, m, 0][0] < limit_m + threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load) + # This is very convoluted but idk a better way + # for tile_M=128, flat_divide gives (8, 16, K), + # then logical_divide gives ((8, 1), (8, 2), K). + tidx = thr_copy_A.thr_idx + tAmA = cute.logical_divide( + cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col) + )[None, (tidx % threads_per_col, None), None] # ((8, 1), 2, K) + + def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]: + # Prefetch mAIdx early, even before smem is free + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + gAIdx_cur = gAIdx[None, src_idx] + k_idx = cute.make_rmem_tensor(cols_per_thread, Int32) + for k in cutlass.range(cols_per_thread): + col_idx = tAcA[0, 0, k][1] + if const_expr(not pred): + k_idx[k] = gAIdx_cur[col_idx] + else: + if tApA_k[k]: + k_idx[k] = gAIdx_cur[col_idx] + else: + k_idx[k] = -1 + return k_idx, tApA_k + + def prefetch_from_smem_fn( + a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False + ) -> Tuple[cute.Tensor, cute.Tensor]: + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state) + sAIdx_cur = sAIdx[None, dst_idx] + k_idx = cute.make_rmem_tensor(cols_per_thread, Int32) + for k in cutlass.range(cols_per_thread): + col_idx = tAcA[0, 0, k][1] + k_idx[k] = sAIdx_cur[col_idx] + cute.arch.sync_warp() + with cute.arch.elect_one(): + a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state) + return k_idx, tApA_k + + def copy_fn( + src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False + ): + k_idx, tApA_k = k_idx_tApA_k + tApA_k_pred = None + if const_expr(pred): + tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2) # (1, cols_per_thread) + for k in cutlass.range_constexpr(tAcA.shape[2]): + # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2)) + for m in cutlass.range_constexpr(tAcA.shape[1]): + if tApA_m[m]: + cute.copy( + thr_copy_A, + tAmA[None, m, k_idx[k]], + tAsA[(None, m, k), dst_idx], + pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k], + ) + + return copy_fn, prefetch_from_gmem_fn if const_expr( + gAIdx is not None + ) else prefetch_from_smem_fn + + +@cute.jit +def gather_m_get_tma_copy_fn( + tma_atom: cute.CopyAtom, + mA: cute.Tensor, # (whatever, K) + sA: cute.Tensor, # ((4, 32), (64, 1), STAGE) + sAIdx: cute.Tensor, # (tile_M), + warp_idx: Int32, + num_warps: int, + num_cta: int = 1, +) -> Callable: + tile_M = cute.size(sAIdx, mode=[0]) + tile_K = cute.size(sA[None, None, 0]) // tile_M + assert tile_M % 4 == 0 + # cta_group = 1 if tma_atom.op.cta_group == CtaGroup.ONE else 2 + cta_group = num_cta # Somehow all tma_atom has CtaGroup.ONE inside the kernel + + copy_AIdx_s2r = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128), + cute.make_layout(num_warps), # thr_layout + cute.make_layout(4), # val_layout + ) + warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx) + tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx) + # ((4, 1), 8, (64, 1), STAGE) + tSR_sA = warp_copy_AIdx_s2r.partition_S(sA) + tSR_rAIdx = load_s2r(tSR_sAIdx) + tma_desc_ptr = get_tma_desc_addr(tma_atom) + tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group) + + def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer): + col_idx = tile_K * src_idx + for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True): + row_indices = [tSR_rAIdx[v, m] for v in range(4)] + smem_ptr = tSR_sA[None, m, None, dst_idx].iterator + with cute.arch.elect_one(): + tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices) + + return copy_fn diff --git a/build/torch211-cxx11-cu130-x86_64-linux/quack/cute_dsl_utils.py b/build/torch211-cxx11-cu130-x86_64-linux/quack/cute_dsl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c92cf39ac08b92245316da46526494d7d8370e1 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/quack/cute_dsl_utils.py @@ -0,0 +1,104 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Tuple +from functools import lru_cache +from dataclasses import dataclass, fields + +import torch + +try: + from triton.tools.disasm import extract +except ImportError: + extract = None + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Int64, Float16, BFloat16, Float32 +from cutlass.base_dsl.typing import JitArgument +from cutlass.cutlass_dsl import NumericMeta + + +StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) + + +load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data +cute_compile_og = cute.compile + + +torch2cute_dtype_map = { + torch.float16: Float16, + torch.bfloat16: BFloat16, + torch.float32: Float32, + torch.int32: Int32, + torch.int64: Int64, +} + + +@lru_cache +def get_max_active_clusters(cluster_size): + return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) + + +@lru_cache +def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: + return torch.cuda.get_device_capability(device) + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class ArgumentsBase(JitArgument): + def __c_pointers__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + c_ptrs = [] + for obj in non_constexpr_fields: + if hasattr(obj, "__c_pointers__"): + c_ptrs.extend(obj.__c_pointers__()) + return c_ptrs + + def __get_mlir_types__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + types, self._values_pos = [], [] + for obj in non_constexpr_fields: + if hasattr(obj, "__get_mlir_types__"): + obj_types = obj.__get_mlir_types__() + types.extend(obj_types) + self._values_pos.append(len(obj_types)) + else: + self._values_pos.append(0) + return types + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/quack/layout_utils.py b/build/torch211-cxx11-cu130-x86_64-linux/quack/layout_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..099e0daf54cdac4b25b6d96f01b35451c810249b --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/quack/layout_utils.py @@ -0,0 +1,295 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + + +import cutlass +import cutlass.cute as cute + +from cutlass import Int32, const_expr + + +def transpose_view(a: cute.Tensor) -> cute.Tensor: + """Transpose the first two dimensions of a tensor on smem.""" + shape = (a.shape[1], a.shape[0], *a.shape[2:]) + order = (1, 0, *range(2, cute.rank(a))) + return cute.composition(a, cute.make_ordered_layout(shape, order=order)) + + +def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor: + return cute.make_tensor(a.iterator, cute.select(a.layout, mode)) + + +def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor: + shape = (*a.shape[:dim], size, *a.shape[dim:]) + stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:]) + return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride)) + + +@cute.jit +def permute_gated_Cregs_b16(t: cute.Tensor) -> None: + assert t.element_type.width == 16 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation" + t_u32 = cute.recast_tensor(t, Int32) + + quad_idx = cute.arch.lane_idx() % 4 + lane_03 = quad_idx == 0 or quad_idx == 3 + selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054) + selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276) + # upper_map = [0, 3, 1, 2] + # lower_map = [1, 2, 0, 3] + # upper_idx = upper_map[quad_idx] + # indexing isn't supported so we have to do arithmetic + upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2 + lower_idx = upper_idx ^ 1 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True): + upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1] + upper0 = upper if lane_03 else lower + lower0 = lower if lane_03 else upper + upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp) + lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp) + t_u32[i * 2 + 0] = cute.arch.prmt(upper0, lower0, selector_upper) + t_u32[i * 2 + 1] = cute.arch.prmt(upper0, lower0, selector_lower) + + +@cute.jit +def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None: + """Permute and shuffle within 4 threads to change the layout from + T0 | T1 | T2 | T3 + a b | c d | e f | g h + to + T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3 + a | b | c | d | e | f | g | h + This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict. + """ + + assert t.element_type.width == 32 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation" + + quad_idx = cute.arch.lane_idx() % 4 + # left_map = [0, 2, 1, 3] + # right_map = [2, 0, 3, 1] + # indexing isn't supported so we have to do arithmetic + left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2 + right_idx = left_idx ^ 0b10 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True): + for r in cutlass.range(2, unroll_full=True): + left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1] + # a b | c d | e f | g h -> a b | c d | f e | h g + left0 = left if quad_idx < 2 else right + right0 = right if quad_idx < 2 else left + # a b | c d | f e | h g -> a b | f d | c e | h g + left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp) + # a b | f d | c e | h g -> a e | f b | c g | h d + right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp) + # a e | f b | c g | h d -> a e | b f | c g | d h + t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0 + t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0 + t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1] + + +@cute.jit +def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None: + """Permute and shuffle within 4 threads to change the layout from + T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3 + a | b | c | d | e | f | g | h + to + T0 | T1 | T2 | T3 + a b | c d | e f | g h + This is so that we can use LDSM (instead of LDS.64) to store C registers without bank conflict. + """ + + assert t.element_type.width == 32 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation" + + quad_idx = cute.arch.lane_idx() % 4 + # left_map = [0, 2, 1, 3] + # right_map = [1, 3, 0, 2] + # indexing isn't supported so we have to do arithmetic + left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2 + right_idx = left_idx ^ 0b01 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + # This is just the inverse of permute_Cregs_b32_for_stsm + for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True): + t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1] + for r in cutlass.range(2, unroll_full=True): + left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1] + # a e | b f | c g | d h -> a e | f b | c g | h d + left0 = left if quad_idx % 2 == 0 else right + right0 = right if quad_idx % 2 == 0 else left + # a e | f b | c g | h d -> a b | f d | c e | h g + right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp) + # a b | f d | c e | h g -> a b | c d | f e | h g + left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp) + # a b | c d | f e | h g -> a b | c d | e f | g h + t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0 + t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0 + + +@cute.jit +def concat_layout(*layouts: cute.Layout) -> cute.Layout: + return cute.make_layout( + tuple(l.shape for l in layouts), + stride=tuple(l.stride for l in layouts), + ) + + +def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout: + """ + For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...). + For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...). + """ + acc_layout_col_major = cute.make_layout(acc_layout.shape) + shape = ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M + ( + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), # MMA_N + *acc_layout_col_major.shape[3:], + ) + stride = ( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), # MMA_N + *acc_layout_col_major.stride[3:], + ) + if const_expr(transpose): + shape = (shape[1], shape[0], *shape[2:]) + stride = (stride[1], stride[0], *stride[2:]) + acc_layout_mn = cute.make_layout(shape, stride=stride) + return cute.composition(acc_layout, acc_layout_mn) + + +def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose)) + + +def reshape_acc_to_mn(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose)) + + +@cute.jit +def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: + # For back to back gemm, convert layout of acc0 to gemm 1 accept layout. + # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) + # TODO: Sm90 FP8 + if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90 + l = cute.logical_divide( + acc_layout, ((None, None, 2), None, None) + ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N) + rA_mma_view = cute.make_layout( + ( + (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]), + l.shape[1], + (l.shape[0][2][1], l.shape[2]), + ), + stride=( + (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]), + l.stride[1], + (l.stride[0][2][1], l.stride[2]), + ), + ) + else: # Sm80 + # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) + l = cute.logical_divide(acc_layout, (None, None, 2)) + rA_mma_view = cute.make_layout( + ( + (l.shape[0], l.shape[2][0]), + l.shape[1], + l.shape[2][1], + ), + stride=( + (l.stride[0], l.stride[2][0]), + l.stride[1], + l.stride[2][1], + ), + ) + return rA_mma_view + + +def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout)) + + +def convert_layout_zero_stride( + input: cute.Tensor | cute.Layout, ref_layout: cute.Layout +) -> cute.Layout: + layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input + # Group the modes with non-zero stride in the ref_layout together, + # and the modes with zero stride together + layout_flat = cute.flatten(layout) + ref_layout_flat = cute.flatten(ref_layout) + nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0] + zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0] + # There's an edge case when all modes are zero stride + new_shape = ( + tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,), + tuple(layout_flat[i].shape for i in zero_modes), + ) + new_stride = ( + tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,), + tuple(layout_flat[i].stride for i in zero_modes), + ) + out_layout = cute.make_layout(new_shape, stride=new_stride) + if const_expr(isinstance(input, cute.Tensor)): + return cute.make_tensor(input.iterator, out_layout) + else: + return out_layout + + +def mma_partition_C_vec( + sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_C(sVec_mma)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] + + +def mma_partition_A_vec( + sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/quantize.py b/build/torch211-cxx11-cu130-x86_64-linux/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..4719a4854bc9388b2a866598f9e21c1f14921181 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/quantize.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Transformer Engine NVFP4 quantization helper. + +This file is intended as a customer-facing example for preparing KV tensors +for the KVFP4 attention kernel: + - BF16/FP16 K/V input + - packed E2M1 FP4 data from Transformer Engine + - E4M3 block scales in cuBLAS/cuDNN 128x4 tiled layout + - one FP32 tensor/global scale per tensor +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Tuple + +import torch + + +NVFP4_BLOCK_SIZE = 16 +NVFP4_FP4_MAX = 6.0 +NVFP4_FP8_E4M3_MAX = 448.0 + + +@dataclass(frozen=True) +class Nvfp4QuantizedTensor: + """Packed NVFP4 tensor plus dequantization metadata. + + Attributes + ---------- + data : torch.Tensor + Packed E2M1 FP4 data from Transformer Engine. The last dimension is + half of the original logical last dimension because each byte stores + two FP4 values. + scale_128x4 : torch.Tensor + E4M3 block scales in cuBLAS/cuDNN 128x4 tiled rowwise storage. + global_scale : torch.Tensor + FP32 tensor/global dequant scale. + logical_scale_shape : tuple[int, int] + Logical 2D scale shape ``(rows, cols)`` before 128x4 swizzling. + original_shape : tuple[int, ...] + Original BF16/FP16 tensor shape before quantization. + """ + + data: torch.Tensor + scale_128x4: torch.Tensor + global_scale: torch.Tensor + logical_scale_shape: Tuple[int, int] + original_shape: Tuple[int, ...] + + +def _round_up(x: int, multiple: int) -> int: + return ((int(x) + multiple - 1) // multiple) * multiple + + +def nvfp4_scale_128x4_offset( + row: torch.Tensor, + col: torch.Tensor, + scale_cols: int, +) -> torch.Tensor: + """Return flat offsets for cuBLAS/cuDNN 128x4 rowwise scale storage. + + Parameters + ---------- + row : torch.Tensor + Logical row indices. + col : torch.Tensor + Logical scale-column indices. + scale_cols : int + Logical number of scale columns before padding to a multiple of 4. + + Returns + ------- + torch.Tensor + Flat offsets into the padded 128x4 tiled storage. + """ + + tiles_n = _round_up(scale_cols, 4) // 4 + tile_m = row // 128 + tile_n = col // 4 + outer = row % 128 + inner = col % 4 + return ( + (tile_m * tiles_n + tile_n) * 512 + + (outer % 32) * 16 + + (outer // 32) * 4 + + inner + ) + + +def swizzle_nvfp4_scale_to_128x4( + scale: torch.Tensor, + *, + rows: int, + cols: int, +) -> torch.Tensor: + """Convert TE logical rowwise scales to cuBLAS/cuDNN 128x4 tiled layout. + + Parameters + ---------- + scale : torch.Tensor + Logical rowwise scale tensor with at least shape ``[rows, cols]``. + rows : int + Number of logical rows to convert. + cols : int + Number of logical scale columns to convert. + + Returns + ------- + torch.Tensor + Scale tensor padded to ``round_up(rows, 128)`` by ``round_up(cols, 4)`` + and swizzled into 128x4 tiled storage. + """ + + if scale.ndim != 2: + raise ValueError(f"scale must be 2D, got shape {tuple(scale.shape)}") + + rows = int(rows) + cols = int(cols) + padded_rows = _round_up(rows, 128) + padded_cols = _round_up(cols, 4) + if scale.shape[0] < rows or scale.shape[1] < cols: + raise ValueError( + "scale is smaller than the requested logical shape: " + f"got {tuple(scale.shape)}, need at least {(rows, cols)}" + ) + + logical = scale[:rows, :cols].contiguous() + if logical.shape != (padded_rows, padded_cols): + logical = torch.nn.functional.pad( + logical.to(torch.float32), + (0, padded_cols - cols, 0, padded_rows - rows), + ).to(scale.dtype) + swizzled = torch.empty_like(logical) + + row = torch.arange(padded_rows, device=scale.device, dtype=torch.int64)[:, None] + col = torch.arange(padded_cols, device=scale.device, dtype=torch.int64)[None, :] + offset = nvfp4_scale_128x4_offset(row, col, padded_cols).reshape(-1) + swizzled.reshape(-1)[offset] = logical.reshape(-1) + return swizzled + + +def nvfp4_global_scale_from_amax(amax: torch.Tensor) -> torch.Tensor: + """Compute TE NVFP4 tensor/global dequant scale from rowwise amax. + + Parameters + ---------- + amax : torch.Tensor + Rowwise absolute maxima returned by Transformer Engine. + + Returns + ------- + torch.Tensor + FP32 global scale equal to ``amax / (448 * 6)``. + """ + + return amax.to(torch.float32) / (NVFP4_FP8_E4M3_MAX * NVFP4_FP4_MAX) + + +def _import_te_nvfp4_quantizer(): + try: + from transformer_engine.pytorch.tensor import NVFP4Quantizer + except Exception as exc: # pragma: no cover - environment dependent + raise RuntimeError( + "Transformer Engine NVFP4 quantization is unavailable. Install a " + "Transformer Engine build with its PyTorch dependencies, including " + "FlashAttention v3 when required by that TE build." + ) from exc + return NVFP4Quantizer + + +def quantize_bf16_to_nvfp4_128x4(x: torch.Tensor) -> Nvfp4QuantizedTensor: + """Quantize a BF16/FP16 tensor to NVFP4 using Transformer Engine. + + TE returns rowwise scales in logical padded layout. This helper returns + the scales in physical 128x4 tiled storage, so the attention kernel can + load them with ``nvfp4_scale_128x4_offset``. + + Parameters + ---------- + x : torch.Tensor + CUDA BF16 or FP16 tensor. The last dimension must be divisible by 16, + and the flattened row dimension ``prod(x.shape[:-1])`` must also be + divisible by 16. + + Returns + ------- + Nvfp4QuantizedTensor + Packed FP4 data, 128x4-swizzled block scales, global scale, and shape + metadata needed by the KVFP4 attention kernel or by reference + dequantization. + """ + + if not x.is_cuda: + raise ValueError("NVFP4 quantization requires a CUDA tensor") + if x.dtype not in (torch.bfloat16, torch.float16): + raise TypeError(f"x must be bf16 or fp16, got {x.dtype}") + if x.ndim < 2: + raise ValueError(f"x must have at least 2 dimensions, got {x.ndim}") + if x.shape[-1] % NVFP4_BLOCK_SIZE != 0: + raise ValueError( + f"last dimension must be divisible by {NVFP4_BLOCK_SIZE}, got {x.shape[-1]}" + ) + + rows = 1 + for dim in x.shape[:-1]: + rows *= int(dim) + if rows % NVFP4_BLOCK_SIZE != 0: + raise ValueError( + "flattened row dimension must be divisible by " + f"{NVFP4_BLOCK_SIZE}, got {rows}" + ) + + NVFP4Quantizer = _import_te_nvfp4_quantizer() + quantizer = NVFP4Quantizer(rowwise=True, columnwise=False) + qx = quantizer.quantize(x.contiguous()) + meta = qx.get_metadata() + + data = meta["rowwise_data"] + if data.dtype is not torch.uint8: + data = data.view(torch.uint8) + logical_scale = meta["rowwise_scale_inv"] + amax = meta["amax_rowwise"] + scale_cols = int(x.shape[-1]) // NVFP4_BLOCK_SIZE + scale_128x4 = swizzle_nvfp4_scale_to_128x4( + logical_scale, + rows=rows, + cols=scale_cols, + ) + global_scale = nvfp4_global_scale_from_amax(amax).contiguous() + + return Nvfp4QuantizedTensor( + data=data, + scale_128x4=scale_128x4, + global_scale=global_scale, + logical_scale_shape=(rows, scale_cols), + original_shape=tuple(int(v) for v in x.shape), + ) + + +def quantize_kv_bf16_to_nvfp4_128x4( + k: torch.Tensor, + v: torch.Tensor, +) -> tuple[Nvfp4QuantizedTensor, Nvfp4QuantizedTensor]: + """Quantize BF16/FP16 K and V tensors independently for KVFP4 attention. + + Parameters + ---------- + k : torch.Tensor + CUDA BF16 or FP16 K tensor. + v : torch.Tensor + CUDA BF16 or FP16 V tensor. + + Returns + ------- + tuple[Nvfp4QuantizedTensor, Nvfp4QuantizedTensor] + Quantized K and V tensors with independent scales. + """ + + return quantize_bf16_to_nvfp4_128x4(k), quantize_bf16_to_nvfp4_128x4(v) + + +def dequantize_nvfp4_128x4_to_bf16( + qx: Nvfp4QuantizedTensor, + *, + include_global_scale: bool = True, +) -> torch.Tensor: + """Reference dequantization for validation. + + This mirrors the kernel contract: + x = e2m1 * E4M3_block_scale_1x16 * FP32_global_scale + + Parameters + ---------- + qx : Nvfp4QuantizedTensor + Quantized tensor returned by ``quantize_bf16_to_nvfp4_128x4``. + include_global_scale : bool, optional + If True, multiply by ``qx.global_scale`` after applying per-block + scales. + + Returns + ------- + torch.Tensor + BF16 tensor with shape ``qx.original_shape``. + """ + + data = qx.data if qx.data.dtype is torch.uint8 else qx.data.view(torch.uint8) + if data.shape[-1] * 2 != qx.original_shape[-1]: + raise ValueError( + "packed data last dimension does not match original shape: " + f"{data.shape[-1]} packed vs {qx.original_shape[-1]} logical" + ) + + rows, scale_cols = qx.logical_scale_shape + logical_dim = int(qx.original_shape[-1]) + if scale_cols * NVFP4_BLOCK_SIZE != logical_dim: + raise ValueError( + "logical scale columns do not match original last dimension: " + f"{scale_cols} scale cols vs dim {logical_dim}" + ) + + fp4_lut = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float32, + device=data.device, + ) + packed = data.reshape(rows, logical_dim // 2) + lo = packed & 0x0F + hi = packed >> 4 + values = torch.empty((rows, logical_dim), dtype=torch.float32, device=data.device) + values[:, 0::2] = fp4_lut[lo.long()] + values[:, 1::2] = fp4_lut[hi.long()] + + row = torch.arange(rows, device=data.device, dtype=torch.int64)[:, None] + col = torch.arange(scale_cols, device=data.device, dtype=torch.int64)[None, :] + offset = nvfp4_scale_128x4_offset(row, col, scale_cols) + scale_u8 = qx.scale_128x4.reshape(-1)[offset.reshape(-1)].reshape(rows, scale_cols) + scale = scale_u8.view(torch.float8_e4m3fn).to(torch.float32) + scale = scale.repeat_interleave(NVFP4_BLOCK_SIZE, dim=1) + out = values * scale + if include_global_scale: + global_scale = qx.global_scale.reshape(-1)[0].to(torch.float32) + out = out * global_scale + return out.reshape(qx.original_shape).to(torch.bfloat16) + + +def _example() -> None: + device = torch.device("cuda") + k = torch.randn(128, 2, 128, device=device, dtype=torch.bfloat16) + v = torch.randn_like(k) + k_q, v_q = quantize_kv_bf16_to_nvfp4_128x4(k, v) + print("K FP4 data:", tuple(k_q.data.shape), k_q.data.dtype) + print("K scale 128x4:", tuple(k_q.scale_128x4.shape), k_q.scale_128x4.dtype) + print("K global scale:", tuple(k_q.global_scale.shape), k_q.global_scale.dtype) + print("V FP4 data:", tuple(v_q.data.shape), v_q.data.dtype) + print("V scale 128x4:", tuple(v_q.scale_128x4.shape), v_q.scale_128x4.dtype) + print("V global scale:", tuple(v_q.global_scale.shape), v_q.global_scale.dtype) + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + raise RuntimeError("quantize.py requires CUDA") + _example() diff --git a/build/torch211-cxx11-cu130-x86_64-linux/sparse_index_utils.py b/build/torch211-cxx11-cu130-x86_64-linux/sparse_index_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a54c982c9230b189051e3a0bdf76d22b397dd62a --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/sparse_index_utils.py @@ -0,0 +1,411 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Host-side q2k <-> k2q index conversion for sparse attention. + +These utilities prepare sparse metadata on the Python side for tests, +benchmarks, and other offline preprocessing flows. They are not kernel +runtime helpers, so they intentionally live outside `src/common`. + +Sparse attention pattern: + - Each Q token independently selects up to topK KV blocks (blk_kv tokens each). + - Under GQA, all Q heads in one group share the same sparsity pattern, + so indices are defined at the head_kv level. + +Shapes: + q2k_indices: [batch, head_kv, Sq, topK] int32, valid values in [0, num_kv_blocks), + trailing unused slots padded with -1 + k2q_indices: [batch, head_kv, Nkv, Sq] int32, padded with -1 + k2q_counts: [batch, head_kv, Nkv] int32 + +CSR reverse-index format: + q2k_indices: [head_kv, total_q, topK] int32, values are batch-local kv_block indices + k2q_row_ptr: [head_kv, total_rows + 1] int32 + k2q_q_indices: [head_kv, total_q * topK] int32, values are batch-local q_idx +""" + +from typing import Optional, Tuple + +import torch + +from .src.sm100.prepare_k2q_csr import SparseK2qCsrBuilderSm100 + + +def q2k_to_k2q( + q2k_indices: torch.Tensor, + num_kv_blocks: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert q2k sparse indices to k2q representation. + + For each KV block, find which Q tokens attend to it. + + Args: + q2k_indices: [batch, head_kv, Sq, topK] int32. + For each Q token, the KV blocks it attends to. Unused slots must + be padded with -1. + num_kv_blocks: Total number of KV blocks (= Skv / blk_kv). + + Returns: + k2q_indices: [batch, head_kv, num_kv_blocks, Sq] int32. + For each KV block, the Q token indices that attend to it, + left-packed and padded with -1. Last dim fixed to Sq (upper bound). + k2q_counts: [batch, head_kv, num_kv_blocks] int32. + Actual number of Q tokens per KV block. + """ + B, H, Sq, topK = q2k_indices.shape + device = q2k_indices.device + N = Sq * topK + + kv_flat = q2k_indices.reshape(B, H, N).long() + valid_flat = kv_flat >= 0 + q_flat = ( + torch.arange(Sq, device=device) + .unsqueeze(-1) + .expand(Sq, topK) + .reshape(N) + .unsqueeze(0) + .unsqueeze(0) + .expand(B, H, N) + ) + + k2q_counts = torch.zeros(B, H, num_kv_blocks, dtype=torch.int32, device=device) + safe_kv_flat = torch.where(valid_flat, kv_flat, torch.zeros_like(kv_flat)) + k2q_counts.scatter_add_( + 2, + safe_kv_flat, + valid_flat.to(torch.int32), + ) + + sort_keys = torch.where( + valid_flat, + kv_flat, + torch.full_like(kv_flat, num_kv_blocks), + ) + sorted_kv, sort_idx = sort_keys.sort(dim=-1, stable=True) + sorted_q = q_flat.gather(-1, sort_idx) + sorted_valid = valid_flat.gather(-1, sort_idx) + + offsets = torch.zeros(B, H, num_kv_blocks, dtype=torch.int64, device=device) + offsets[:, :, 1:] = k2q_counts[:, :, :-1].cumsum(dim=-1).long() + + global_pos = torch.arange(N, device=device).unsqueeze(0).unsqueeze(0).expand(B, H, N) + group_offset = offsets.gather(2, sorted_kv.clamp(max=num_kv_blocks - 1)) + pos_in_group = global_pos - group_offset + + k2q_indices = torch.full( + (B, H, num_kv_blocks, Sq), -1, dtype=torch.int32, device=device + ) + flat_k2q = k2q_indices.reshape(B, H, -1) + flat_idx = sorted_kv.clamp(max=num_kv_blocks - 1) * Sq + pos_in_group + for b in range(B): + for h in range(H): + valid = sorted_valid[b, h] + flat_k2q[b, h, flat_idx[b, h, valid]] = sorted_q[b, h, valid].int() + + return k2q_indices, k2q_counts + + +def k2q_to_q2k( + k2q_indices: torch.Tensor, + k2q_counts: torch.Tensor, + Sq: int, + topK: int, +) -> torch.Tensor: + """Convert dense k2q indices back to q2k representation. + + Parameters + ---------- + k2q_indices : torch.Tensor + Shape ``[batch, head_kv, num_kv_blocks, Sq]`` and dtype int32. Values + are Q token indices padded with ``-1``. + k2q_counts : torch.Tensor + Shape ``[batch, head_kv, num_kv_blocks]`` and dtype int32. Number of + valid Q indices per KV block. + Sq : int + Q sequence length per batch item in this dense reference format. + topK : int + Maximum number of KV blocks selected per Q token. + + Returns + ------- + torch.Tensor + Shape ``[batch, head_kv, Sq, topK]``, dtype int32. Entries are sorted + by KV block index with ``-1`` padding at the tail. + """ + B, H, Nkv, _ = k2q_indices.shape + device = k2q_indices.device + + q2k = torch.full((B, H, Sq, topK), -1, dtype=torch.int32, device=device) + counters = torch.zeros(B, H, Sq, dtype=torch.int64, device=device) + + for b in range(B): + for h in range(H): + for kv_blk in range(Nkv): + count = k2q_counts[b, h, kv_blk].item() + for j in range(count): + qt = k2q_indices[b, h, kv_blk, j].item() + if qt < 0: + continue + p = counters[b, h, qt].item() + if p < topK: + q2k[b, h, qt, p] = kv_blk + counters[b, h, qt] += 1 + + q2k_sort_key = torch.where(q2k < 0, torch.full_like(q2k, Nkv), q2k) + _, sort_idx = q2k_sort_key.sort(dim=-1) + q2k = q2k.gather(-1, sort_idx) + return q2k + + +def _validate_cu_seqlens(cu_seqlens: torch.Tensor, *, name: str) -> None: + if cu_seqlens.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32, got {cu_seqlens.dtype}") + if cu_seqlens.ndim != 1: + raise ValueError(f"{name} must be rank-1, got shape {tuple(cu_seqlens.shape)}") + if cu_seqlens.numel() < 1: + raise ValueError(f"{name} must have at least one element") + if not cu_seqlens.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _rows_per_batch(cu_seqlens_k: torch.Tensor, kv_block_size: int) -> torch.Tensor: + seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1] + return (seqlens_k + kv_block_size - 1) // kv_block_size + + +def _build_packed_row_map(rows_per_batch: torch.Tensor) -> tuple[torch.Tensor, int]: + rows_per_batch_cpu = rows_per_batch.to("cpu", non_blocking=False).tolist() + batch = len(rows_per_batch_cpu) + max_rows = max(rows_per_batch_cpu, default=0) + row_dtype = ( + torch.int32 + if sum(rows_per_batch_cpu) < torch.iinfo(torch.int32).max + else torch.int64 + ) + row_map_cpu = torch.full((batch, max_rows), -1, dtype=row_dtype) + row_linear = 0 + for kv_block_idx in range(max_rows): + for batch_idx, row_count in enumerate(rows_per_batch_cpu): + if kv_block_idx < row_count: + row_map_cpu[batch_idx, kv_block_idx] = row_linear + row_linear += 1 + return row_map_cpu.to(rows_per_batch.device), row_linear + + +def build_k2q_csr_torch_reference( + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + kv_block_size: int, +) -> tuple: + """Torch reference for q2k -> k2q CSR conversion. + + Parameters + ---------- + q2k_indices : torch.Tensor + Shape ``[head_kv, total_q, topK]``, dtype int32. Values are + batch-local KV block indices padded with ``-1``. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + kv_block_size : int + Number of KV tokens per sparse block. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(k2q_row_ptr, k2q_q_indices)`` where ``k2q_row_ptr`` has shape + ``[head_kv, total_rows + 1]`` and ``k2q_q_indices`` has shape + ``[head_kv, total_q * topK]``. + """ + if kv_block_size <= 0: + raise ValueError(f"kv_block_size must be > 0, got {kv_block_size}") + if q2k_indices.dtype != torch.int32: + raise TypeError(f"q2k_indices must be torch.int32, got {q2k_indices.dtype}") + if q2k_indices.ndim != 3: + raise ValueError( + "q2k_indices must have shape [head_kv, total_q, topK], " + f"got {tuple(q2k_indices.shape)}" + ) + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q") + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError("cu_seqlens_q and cu_seqlens_k must have the same shape [B + 1]") + if q2k_indices.device != cu_seqlens_q.device or q2k_indices.device != cu_seqlens_k.device: + raise ValueError("q2k_indices, cu_seqlens_q, and cu_seqlens_k must be on the same device") + + head_kv, total_q, topk = q2k_indices.shape + if total_q != int(cu_seqlens_q[-1].item()): + raise ValueError( + f"q2k_indices.shape[1] ({total_q}) must equal cu_seqlens_q[-1] " + f"({int(cu_seqlens_q[-1].item())})" + ) + + rows_per_batch = _rows_per_batch(cu_seqlens_k, kv_block_size) + row_map, total_rows = _build_packed_row_map(rows_per_batch) + nnz_upper_bound = total_q * topk + + k2q_row_ptr = torch.zeros((head_kv, total_rows + 1), dtype=torch.int32, device=q2k_indices.device) + k2q_q_indices = torch.full( + (head_kv, nnz_upper_bound), -1, dtype=torch.int32, device=q2k_indices.device + ) + if total_rows == 0 or total_q == 0 or topk == 0: + return k2q_row_ptr, k2q_q_indices + + counts = torch.zeros((head_kv, total_rows), dtype=torch.int32, device=q2k_indices.device) + total_entries = total_q * topk + row_dtype = torch.int32 if total_rows < torch.iinfo(torch.int32).max else torch.int64 + row_all = torch.empty((head_kv, total_entries), dtype=row_dtype, device=q2k_indices.device) + q_all = torch.empty((head_kv, total_entries), dtype=torch.int32, device=q2k_indices.device) + valid_all = torch.empty((head_kv, total_entries), dtype=torch.bool, device=q2k_indices.device) + rows_per_batch_cpu = rows_per_batch.to("cpu", non_blocking=False).tolist() + q_cu_cpu = cu_seqlens_q.to("cpu", non_blocking=False).tolist() + entry_cursor = 0 + + for batch_idx, kv_rows in enumerate(rows_per_batch_cpu): + q_start = q_cu_cpu[batch_idx] + q_end = q_cu_cpu[batch_idx + 1] + q_len = q_end - q_start + if q_len == 0: + continue + num_entries = q_len * topk + q2k_batch = q2k_indices[:, q_start:q_end, :] + valid_batch = q2k_batch >= 0 + if valid_batch.any(): + max_valid_kv = int(q2k_batch[valid_batch].max().item()) + if max_valid_kv >= kv_rows: + raise ValueError( + f"q2k_indices references kv_block {max_valid_kv} for batch {batch_idx}, " + f"but that batch only has {kv_rows} logical kv blocks" + ) + kv_flat = q2k_batch.reshape(head_kv, num_entries).long() + valid_flat = valid_batch.reshape(head_kv, num_entries) + safe_kv_flat = torch.where(valid_flat, kv_flat, torch.zeros_like(kv_flat)) + row_map_batch = row_map[batch_idx] + row_flat = row_map_batch[safe_kv_flat] + q_flat = ( + torch.arange(q_len, device=q2k_indices.device, dtype=torch.int32) + .view(1, q_len, 1) + .expand(head_kv, q_len, topk) + .reshape(head_kv, num_entries) + ) + row_all[:, entry_cursor : entry_cursor + num_entries] = row_flat + q_all[:, entry_cursor : entry_cursor + num_entries] = q_flat + valid_all[:, entry_cursor : entry_cursor + num_entries] = valid_flat + counts.scatter_add_(1, row_flat.to(torch.int64), valid_flat.to(torch.int32)) + entry_cursor += num_entries + + k2q_row_ptr[:, 1:] = counts.cumsum(dim=1, dtype=torch.int32) + + sort_stride = max(total_q, 1) + invalid_key = total_rows * sort_stride + max_sort_key = invalid_key + max(total_q - 1, 0) + if max_sort_key < torch.iinfo(torch.int32).max: + sort_keys = torch.full_like(row_all, invalid_key, dtype=torch.int32) + sort_keys[valid_all] = row_all[valid_all] * sort_stride + q_all[valid_all] + else: + sort_keys = torch.full_like(row_all, invalid_key, dtype=torch.int64) + sort_keys[valid_all] = ( + row_all[valid_all].to(torch.int64) * sort_stride + + q_all[valid_all].to(torch.int64) + ) + _, sort_idx = sort_keys.sort(dim=1, stable=True) + sorted_q = q_all.gather(1, sort_idx) + + valid_counts = valid_all.sum(dim=1) + write_mask = ( + torch.arange(total_entries, device=q2k_indices.device) + .unsqueeze(0) + .expand(head_kv, -1) + < valid_counts.unsqueeze(1) + ) + k2q_q_indices[write_mask] = sorted_q[write_mask] + + return k2q_row_ptr, k2q_q_indices + + +_K2Q_CSR_BUILDER = SparseK2qCsrBuilderSm100() + + +def build_k2q_csr( + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + kv_block_size: int, + *, + total_k: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + max_seqlen_q: Optional[int] = None, + total_rows: Optional[int] = None, + qhead_per_kv: int = 1, + return_schedule: bool = False, +) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, object]: + """Build the public k2q CSR reverse index on GPU. + + Runtime construction does not read device-side ``cu_seqlens`` on the host, + so callers must provide size hints such as ``total_k`` from already-known + tensor shapes. + + Parameters + ---------- + q2k_indices : torch.Tensor + Shape ``[head_kv, total_q, topK]``, dtype int32, contiguous. Values are + batch-local KV block indices with trailing ``-1`` padding. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + kv_block_size : int + Number of KV tokens per sparse block. + total_k : int + Total KV token count. Required; normally ``k.shape[0]`` for dense KV + or ``sum(kv_segment_lens)`` for paged KV. + max_seqlen_k : int, optional + Maximum KV sequence length. Passing this avoids recomputing a bound. + max_seqlen_q : int, optional + Maximum Q sequence length. + total_rows : int, optional + Total number of packed KV-block rows across the batch. If omitted, + the builder derives it from ``cu_seqlens_k`` and ``kv_block_size``. + qhead_per_kv : int, optional + Number of Q heads per KV head under GQA. + return_schedule : bool, optional + If True, also return the sparse forward schedule object produced by the + SM100 builder. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] or tuple[torch.Tensor, torch.Tensor, object] + ``(k2q_row_ptr, k2q_q_indices)`` or + ``(k2q_row_ptr, k2q_q_indices, schedule)``. CSR tensors are int32 on + the same CUDA device as ``q2k_indices``. + """ + if total_k is None: + raise ValueError("build_k2q_csr requires total_k from k.shape[0]") + if kv_block_size <= 0: + raise ValueError(f"kv_block_size must be > 0, got {kv_block_size}") + if q2k_indices.dtype != torch.int32: + raise TypeError(f"q2k_indices must be torch.int32, got {q2k_indices.dtype}") + if q2k_indices.ndim != 3: + raise ValueError(f"q2k_indices must be rank-3, got shape {tuple(q2k_indices.shape)}") + if not q2k_indices.is_contiguous(): + raise ValueError("q2k_indices must be contiguous with layout [head_kv, total_q, topK]") + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q") + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError("cu_seqlens_q and cu_seqlens_k must have the same shape [B + 1]") + if q2k_indices.device != cu_seqlens_q.device or q2k_indices.device != cu_seqlens_k.device: + raise ValueError("q2k_indices, cu_seqlens_q, and cu_seqlens_k must be on the same device") + return _K2Q_CSR_BUILDER( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + total_k=int(total_k), + blk_kv=int(kv_block_size), + max_seqlen_k=max_seqlen_k, + max_seqlen_q=max_seqlen_q, + total_rows=total_rows, + qhead_per_kv=qhead_per_kv, + return_schedule=return_schedule, + ) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15dfab1c4fc45dcec26ba7489b35624f5c46a698 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15dfab1c4fc45dcec26ba7489b35624f5c46a698 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/aot_cache.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/aot_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..99fd0b4da4ddb6fba21bcb18c924f5e9e8b583e6 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/aot_cache.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Persistent AOT cache for CuTe DSL compiled kernels. + +Saves compiled TVM FFI kernels as .o files on first compile, +loads them on subsequent runs to skip JIT compilation. + +Environment variables: + MM_SPARSE_ATTN_AOT_CACHE: Override cache directory + (default: ~/.cache/minfer/mm_sparse_attn) + MM_SPARSE_ATTN_AOT_DISABLE=1: Disable AOT cache entirely +""" + +import hashlib +import os +import time + +import cutlass.cute as cute + +_AOT_CACHE_DIR = os.environ.get( + "MM_SPARSE_ATTN_AOT_CACHE", + os.path.expanduser("~/.cache/minfer/mm_sparse_attn"), +) +_AOT_DISABLE = os.environ.get("MM_SPARSE_ATTN_AOT_DISABLE", "0") == "1" + +_loaded_modules: dict[str, object] = {} + + +def _key_to_path(key: tuple) -> str: + h = hashlib.sha256(repr(key).encode()).hexdigest()[:16] + name = str(key[0]).replace("/", "_") + return os.path.join(_AOT_CACHE_DIR, f"{name}_{h}") + + +def try_load_aot(key: tuple): + if _AOT_DISABLE: + return None + obj_path = _key_to_path(key) + ".o" + if not os.path.isfile(obj_path): + return None + func_name = str(key[0]) + try: + if obj_path not in _loaded_modules: + _loaded_modules[obj_path] = cute.runtime.load_module( + obj_path, enable_tvm_ffi=True + ) + return getattr(_loaded_modules[obj_path], func_name) + except Exception as e: + print(f"[aot_cache] Failed to load {obj_path}: {e}") + return None + + +def save_aot(key: tuple, compiled) -> None: + if _AOT_DISABLE: + return + if not hasattr(compiled, "export_to_c"): + return + obj_path = _key_to_path(key) + ".o" + os.makedirs(_AOT_CACHE_DIR, exist_ok=True) + tmp_path = obj_path + f".tmp.{os.getpid()}" + func_name = str(key[0]) + try: + t0 = time.time() + compiled.export_to_c(tmp_path, function_name=func_name) + os.replace(tmp_path, obj_path) + dt = time.time() - t0 + print(f"[aot_cache] Saved {func_name} -> {obj_path} ({dt:.1f}s)") + except Exception as e: + print(f"[aot_cache] Failed to save {func_name}: {e}") + if os.path.exists(tmp_path): + os.remove(tmp_path) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/barrier.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/barrier.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5753a8a175b529567e0be238f47fd4cc8401bf --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/barrier.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm + + +@dsl_user_op +def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + state = llvm.inline_asm( + T.i32(), + [lock_ptr_i64], + "ld.global.acquire.gpu.b32 $0, [$1];", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + return cutlass.Int32(state) + + +@dsl_user_op +def red_relaxed( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.relaxed.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def red_release( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.release.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + read_val = Int32(0) + while read_val != val: + read_val = ld_acquire(flag_ptr) + + +@cute.jit +def arrive_inc( + lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32] +) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + red_release(flag_ptr, val) + # red_relaxed(flag_ptr, val) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/blackwell_helpers.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/blackwell_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..fd22f7efa3cef9988b4036c2d00fc1d3b9c816e8 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/blackwell_helpers.py @@ -0,0 +1,1093 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Optional, Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Boolean, const_expr +from cutlass.cute.nvgpu import tcgen05 +from cutlass._mlir.dialects import llvm + +from . import mma_sm100_desc as sm100_desc + + +@cute.jit +def gemm_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, + swap_AB: bool = False, + num_unroll_groups: int = 1, +) -> None: + if const_expr(swap_AB): + return gemm_w_idx( + tiled_mma, acc, tCrB, tCrA, B_idx, A_idx, zero_init=zero_init, swap_AB=False + ) + else: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + + mma_atom = cute.make_mma_atom(tiled_mma.op) + for k in cutlass.range( + cute.size(tCrA.shape[2]), unroll=cute.size(tCrA.shape[2]) // num_unroll_groups + ): + mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc) + + +@cute.jit +def gemm_ptx_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, + cta_group: int = 1, + **kwargs, +) -> None: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + sA_cur = None + if const_expr(sA is not None): + sA_cur = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx] + sB_cur = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx] + mma_atom = cute.make_mma_atom(tiled_mma.op) + acc_tmem_addr = acc.iterator.toint() + gemm_ptx_partial( + mma_atom.op, + acc_tmem_addr, + rA, + rB, + sA_cur, + sB_cur, + zero_init=zero_init, + cta_group=cta_group, + **kwargs, + ) + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + mma_atom = cute.make_mma_atom(tiled_mma.op) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + + +def i64_to_i32x2(i: int) -> Tuple[int, int]: + """Convert a 64-bit integer to a tuple of two 32-bit integers.""" + return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF + + +@cute.jit +def gemm_ptx( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else None + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = sA.iterator.type.swizzle_type + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = sB.iterator.type.swizzle_type + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr( + sA[None, None, 0].iterator + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr( + sB[None, None, 0].iterator + ) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + if const_expr(not is_ts): + smem_desc_a_lo = smem_desc_start_a_lo + ( + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + ) + smem_desc_b_lo = smem_desc_start_b_lo + ( + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + ) + # with cute.arch.elect_one(): + # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo) + # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct) + with cute.arch.elect_one(): + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + smem_desc_a_lo.ir_value(), + smem_desc_b_lo.ir_value(), + Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + ".reg .b32 idesc;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + tCrA[None, None, k].iterator.toint().ir_value(), + smem_desc_b_lo.ir_value(), + Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_b;\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_loop( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = sA.iterator.type.swizzle_type + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = sB.iterator.type.swizzle_type + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + if const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [ + offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [ + offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2])) + ] + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 smem_desc_a_lo, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 tmem_a, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: Int32, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[Int32] = None, + split_arrive: Optional[int] = None, + zero_init: bool | Boolean = False, + # sA_offset: Int32 = 0, + # acc_offset: Int32 = 0, + tA_addr: Optional[Int32] = None, + cta_group: int = 1, + mma_kind: str = "f16", +) -> None: + # acc_tmem_addr += acc_offset + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = sA.iterator.type.swizzle_type + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = sB.iterator.type.swizzle_type + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + tCrA_layout = ( + tCrA.layout + if const_expr(not is_ts) + else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + # ) + sA_offset + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + "mov.b32 smem_desc_a_lo_start, $0;\n\t" + "mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + # "r,r,r", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + # For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to + # explicitly pass in the tA_addr for correctness. + tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr + input_args = [ + # Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(), + Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ] + if const_expr(mbar_ptr is not None): + assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + assert split_arrive is not None, ( + "split_arrive must be provided when mbar_ptr is not None" + ) + split_arrive_idx = split_arrive // op.shape_mnk[2] + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" + llvm.inline_asm( + None, + # [ + # # acc.iterator.toint().ir_value(), + # Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + # Int32(smem_desc_start_b_lo).ir_value(), + # Int32(not zero_init).ir_value(), + # ], + input_args, + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range( + 1, + cute.size(tCrA.shape[2]) if const_expr(mbar_ptr is None) else split_arrive_idx, + ) + ) + + mbar_wait_str + + ( + "".join( + ( + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(split_arrive_idx, cute.size(tCrA.shape[2])) + ) + if const_expr(mbar_ptr is not None) + else "" + ) + + "}\n", + "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial1( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: cutlass.Constexpr[int], + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA_base_addr_for_desc: Int32, + sA_addr_offset_for_desc: cutlass.Constexpr[int], + sA_stage: Int32, + sB_base_addr_for_desc: Int32, + sB_addr_offset_for_desc: cutlass.Constexpr[int], + sB_stage: Int32, + sA_layout: Optional[cute.Layout], + sB_layout: Optional[cute.Layout], + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + mask = [Int32(0)] * 4 + + if const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 + for k in range(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in range(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4 + for k in range(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if const_expr(not is_ts): + # smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + smem_desc_start_a_lo = const_expr(smem_desc_base_a_lo) + else: + smem_desc_start_a_lo = None + # smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + smem_desc_start_b_lo = const_expr(smem_desc_base_b_lo) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(sA_base_addr_for_desc).ir_value(), + Int32(sA_stage).ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(sB_base_addr_for_desc).ir_value(), + Int32(sB_stage).ir_value(), + Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + # "mov.b32 smem_desc_a_lo, $0;\n\t" + # f"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\n\t" + f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t" + # "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $4, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_a, $1;\n\t" + f"mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_precomputed( + acc_tmem_addr: Int32, + smem_desc_start_a: Int32, # If TS, then this is the tmem start address for A + smem_desc_start_b: Int32, + idesc: int, + smem_desc_base_a: Optional[int], + smem_desc_base_b: int, + tCrA_layout: cute.Layout, + tCrB_layout: cute.Layout, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[Int32] = None, + zero_init: bool | Boolean = False, + cta_group: int = 1, +) -> None: + # acc_tmem_addr += acc_offset + is_ts = const_expr(smem_desc_base_a is None) + num_k_tile = cute.size(tCrA_layout.shape[2]) + if const_expr(not is_ts): + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + else: + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + + tCrA_layout = ( + tCrA_layout + if const_expr(not is_ts) + # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout) + # currently hard-coding the width to 16 + else cute.recast_layout(32, 16, tCrA_layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, num_k_tile)] + offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, num_k_tile)] + + smem_desc_start_a_lo = None + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a) + # smem_desc_start_a_lo = smem_desc_start_a + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + "mov.b32 smem_desc_a_lo_start, $0;\n\t" + "mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t" + f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, num_k_tile) + ) + + "}\n", + # "r,r,r", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + input_args = [ + Int32(cute.arch.make_warp_uniform(smem_desc_start_a)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ] + if const_expr(mbar_ptr is not None): + assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" + llvm.inline_asm( + None, + # [ + # # acc.iterator.toint().ir_value(), + # Int32(tCrA_layout[None, None, 0].iterator.toint()).ir_value(), + # Int32(smem_desc_start_b_lo).ir_value(), + # Int32(not zero_init).ir_value(), + # ], + input_args, + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range( + 1, + num_k_tile if const_expr(mbar_ptr is None) else num_k_tile // 4 * 3, + ) + ) + + mbar_wait_str + + ( + "".join( + ( + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(num_k_tile // 4 * 3, num_k_tile) + ) + if const_expr(mbar_ptr is not None) + else "" + ) + + "}\n", + "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def declare_ptx_smem_desc( + smem_desc_start_a: Int32, # If TS, then this is the tmem start address for A + smem_desc_base_a: Optional[int], + tCrA_layout: cute.Layout, + var_name_prefix: str = "smem_desc", +) -> None: + is_ts = const_expr(smem_desc_base_a is None) + num_k_tile = cute.size(tCrA_layout.shape[2]) + smem_desc_base_a_lo, smem_desc_a_hi = None, None + if const_expr(not is_ts): + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + tCrA_layout = ( + tCrA_layout + if const_expr(not is_ts) + # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout) + # currently hard-coding the width to 16 + else cute.recast_layout(32, 16, tCrA_layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)] + smem_desc_start_a_lo = None + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a) + if const_expr(not is_ts): + llvm.inline_asm( + None, + [Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value()], + f".reg .b32 {var_name_prefix}_lo;\n\t" + f".reg .b64 {var_name_prefix}_<{num_k_tile}>;\n\t" + f"mov.b64 {var_name_prefix}_0, {{$0, {hex(smem_desc_a_hi)}}};\n\t" + + "".join( + ( + f"add.s32 {var_name_prefix}_lo, $0, {hex(offset_a[k])};\n\t" + f"mov.b64 {var_name_prefix}_{k}, {{{var_name_prefix}_lo, {hex(smem_desc_a_hi)}}};\n\t" + ) + for k in range(1, num_k_tile) + ), + "r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def declare_ptx_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp, var_name: str = "idesc") -> None: + idesc = const_expr(sm100_desc.mma_op_to_idesc(op)) + llvm.inline_asm( + None, + [], + f".reg .b32 {var_name};\n\t" # noqa + f"mov.b32 {var_name}, {hex(idesc)};\n\t", + constraints="", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_precomputed_varname( + acc_tmem_addr: Int32, + smem_desc_start_b: Int32, + # idesc: int, + smem_desc_base_b: int, + tCrB_layout: cute.Layout, + smem_var_name_prefix: str, + idesc_var_name: str, + smem_offset: int, + zero_init: bool | Boolean = False, + cta_group: int = 1, + mma_kind: str = "f16", +) -> None: + is_ts = False + num_k_tile = cute.size(tCrB_layout.shape[2]) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)] + + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + # ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + # ".reg .b64 smem_desc_b;\n\t" + f".reg .b64 smem_desc_b_<{num_k_tile}>;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + # f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $2;\n\t" + "mov.b32 smem_desc_b_lo_start, $0;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_0;\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + f"mov.b64 {smem_var_name_prefix}_0, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b_0, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + + "".join( + ( + f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b_{k}, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + ) + for k in range(1, num_k_tile) + ) + + "setp.ne.b32 p, $1, 0;\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t" + + "".join( + ( + # f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t" + # f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + # f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + # f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + # f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t" + ) + for k in range(1, num_k_tile) + ) + + "}\n", + "r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/block_info.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/block_info.py new file mode 100644 index 0000000000000000000000000000000000000000..463290ab3b022a8883e7d40b84ff1ab31827e5dc --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/block_info.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Tuple +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + +from ...src.common.seqlen_info import SeqlenInfoQK + + +@dataclass(frozen=True) +class BlockInfo: + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] + is_causal: cutlass.Constexpr[bool] + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + + @cute.jit + def get_n_block_min_max( + self, + seqlen_info: SeqlenInfoQK, + m_block: Int32, + split_idx: Int32 = 0, + num_splits: Int32 = 1, + ) -> Tuple[Int32, Int32]: + n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) + if const_expr(self.is_causal): + m_idx_max = (m_block + 1) * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_block_max = min(n_block_max, cute.ceil_div(n_idx, self.tile_n)) + n_block_min = 0 + if num_splits > 1: + num_n_blocks_per_split = ( + Int32(0) + if n_block_max <= n_block_min + else (n_block_max - n_block_min + num_splits - 1) // num_splits + ) + n_block_min = n_block_min + split_idx * num_n_blocks_per_split + n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max) + return n_block_min, n_block_max + + @cute.jit + def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]: + m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m) + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_block_max = cute.ceil_div( + seqlen_info.seqlen_q * self.qhead_per_kvhead_packgqa, self.tile_m + ) + m_block_min = 0 + if const_expr(self.is_causal): + n_idx_min = n_block * self.tile_n + m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx *= self.qhead_per_kvhead_packgqa + m_block_min = cutlass.max(m_block_min, m_idx // self.tile_m) + return m_block_min, m_block_max diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/copy_utils.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/copy_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..98ba5f40b7b9543744e663a96bcdf637c7e2a146 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/copy_utils.py @@ -0,0 +1,1179 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Copy, store, and layout execution helpers. + +`copy_utils.py` is the canonical owner for generic copy primitives, async +bulk copy orchestration, TMA copy adapters, and non-TMA store/layout helpers. +""" + +import math +from typing import Optional, Type, Callable + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm +import cutlass.pipeline + + +# Generic Copy Primitives + +@dsl_user_op +def cvt_copy( + atom: cute.CopyAtom, + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + if const_expr(src.element_type != dst.element_type): + src_cvt = cute.make_rmem_tensor_like(src, dst.element_type, loc=loc, ip=ip) + src_cvt.store(src.load().to(dst.element_type)) + src = src_cvt + cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +@dsl_user_op +def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip) + cute.autovec_copy(src, dst, loc=loc, ip=ip) + return dst + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None +) -> cute.CopyAtom: + num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@dsl_user_op +def make_tmem_copy( + tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None +) -> cute.CopyAtom: + num_dp, num_bits, num_rep, _ = sm100_utils.get_tmem_copy_properties(tmem_copy_atom) + assert num_dp == 32 + assert num_bits == 32 + tiler_mn = (cute.make_layout((128 * num_rep * num_wg // 32, 32), stride=(32, 1)),) + layout_tv = cute.make_layout( + ((32, 4, num_wg), (num_rep, 32)), stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg)) + ) + return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + num_copy_elems: int = 1, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +def tiled_copy_1d( + dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + thr_layout = cute.make_layout(num_threads) + val_layout = cute.make_layout(num_copy_elems) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def tiled_copy_2d( + dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width + copy_elems = num_copy_bits // dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + gmem_threads_per_row = major_mode_size // copy_elems + assert num_threads % gmem_threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +@dsl_user_op +def atomic_add_fp32x4( + a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None +) -> None: + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x12F0000000000000) + llvm.inline_asm( + None, + [ + gmem_ptr_i64, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], + "{\n\t" + # ".reg .b128 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + ".reg .v4 .f32 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + "mov.f32 abcd.x, $1;\n\t" + "mov.f32 abcd.y, $2;\n\t" + "mov.f32 abcd.z, $3;\n\t" + "mov.f32 abcd.w, $4;\n\t" + "red.global.add.v4.f32 [$0], abcd;\n\t" + # "red.global.add.L2::cache_hint.v4.f32 [$0], abcd, 0x14F0000000000000;\n\t" + "}\n", + # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;", + # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;", + "l,f,f,f,f", + # "l,f,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +# Store/Layout Helpers + +@dsl_user_op +def atomic_add_i32(gmem_ptr, *, loc=None, ip=None): + """Simple atomicAdd. Intended for use under a single-thread guard.""" + result = llvm.inline_asm( + T.i32(), + [gmem_ptr.toint().ir_value(loc=loc, ip=ip)], + "atom.global.add.u32 $0, [$1], 1;\n", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + return Int32(result) + + +@dsl_user_op +def atomic_add_broadcast_i32(gmem_ptr, *, loc=None, ip=None): + """Lane-0 atomicAdd broadcast to the whole warp via shfl.""" + result = llvm.inline_asm( + T.i32(), + [gmem_ptr.toint().ir_value(loc=loc, ip=ip)], + "{\n" + ".reg .pred p;\n" + ".reg .u32 lane, r;\n" + "mov.u32 lane, %laneid;\n" + "mov.u32 r, 0;\n" + "setp.eq.u32 p, lane, 0;\n" + "@p atom.global.add.u32 r, [$1], 1;\n" + "shfl.sync.idx.b32 r, r, 0, 31, 0xffffffff;\n" + "mov.u32 $0, r;\n" + "}\n", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + return Int32(result) + + +@dsl_user_op +def stg_128( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "st.global.v4.f32 [$4], {$5, $6, $7, $8}; " + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "st.global.cs.v4.f32 [$4], {$5, $6, $7, $8}; " + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_64_bf16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3;\n" + ".reg .b32 p0, p1;\n" + "cvt.rn.bf16.f32 h0, $5;\n" + "cvt.rn.bf16.f32 h1, $6;\n" + "cvt.rn.bf16.f32 h2, $7;\n" + "cvt.rn.bf16.f32 h3, $8;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "st.global.v2.b32 [$4], {p0, p1};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_64_f16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .f16 h0, h1, h2, h3;\n" + ".reg .b32 p0, p1;\n" + "cvt.rn.f16.f32 h0, $5;\n" + "cvt.rn.f16.f32 h1, $6;\n" + "cvt.rn.f16.f32 h2, $7;\n" + "cvt.rn.f16.f32 h3, $8;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "st.global.v2.b32 [$4], {p0, p1};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_32_fp8_e4m3( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1;\n" + ".reg .b32 p0;\n" + "cvt.rn.satfinite.e4m3x2.f32 h0, $6, $5;\n" + "cvt.rn.satfinite.e4m3x2.f32 h1, $8, $7;\n" + "mov.b32 p0, {h0, h1};\n" + "st.global.b32 [$4], p0;\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def sts_32_bf16( + smem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + *, + loc=None, + ip=None, +): + """Store two bf16 values to shared memory as one 32-bit transaction.""" + llvm.inline_asm( + None, + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa;\n" + ".reg .b16 h0, h1;\n" + ".reg .b32 p0;\n" + "cvt.u32.u64 sa, $0;\n" + "cvt.rn.bf16.f32 h0, $1;\n" + "cvt.rn.bf16.f32 h1, $2;\n" + "mov.b32 p0, {h0, h1};\n" + "st.shared.b32 [sa], p0;\n" + "}\n", + "l,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def sts_32_f16( + smem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + *, + loc=None, + ip=None, +): + """Store two fp16 values to shared memory as one 32-bit transaction.""" + llvm.inline_asm( + None, + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa;\n" + ".reg .f16 h0, h1;\n" + ".reg .b32 p0;\n" + "cvt.u32.u64 sa, $0;\n" + "cvt.rn.f16.f32 h0, $1;\n" + "cvt.rn.f16.f32 h1, $2;\n" + "mov.b32 p0, {h0, h1};\n" + "st.shared.b32 [sa], p0;\n" + "}\n", + "l,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_bf16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.bf16.f32 h0, $9;\n" + "cvt.rn.bf16.f32 h1, $10;\n" + "cvt.rn.bf16.f32 h2, $11;\n" + "cvt.rn.bf16.f32 h3, $12;\n" + "cvt.rn.bf16.f32 h4, $13;\n" + "cvt.rn.bf16.f32 h5, $14;\n" + "cvt.rn.bf16.f32 h6, $15;\n" + "cvt.rn.bf16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_bf16_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.bf16.f32 h0, $9;\n" + "cvt.rn.bf16.f32 h1, $10;\n" + "cvt.rn.bf16.f32 h2, $11;\n" + "cvt.rn.bf16.f32 h3, $12;\n" + "cvt.rn.bf16.f32 h4, $13;\n" + "cvt.rn.bf16.f32 h5, $14;\n" + "cvt.rn.bf16.f32 h6, $15;\n" + "cvt.rn.bf16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.cs.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_f16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .f16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.f16.f32 h0, $9;\n" + "cvt.rn.f16.f32 h1, $10;\n" + "cvt.rn.f16.f32 h2, $11;\n" + "cvt.rn.f16.f32 h3, $12;\n" + "cvt.rn.f16.f32 h4, $13;\n" + "cvt.rn.f16.f32 h5, $14;\n" + "cvt.rn.f16.f32 h6, $15;\n" + "cvt.rn.f16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_f16_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .f16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.f16.f32 h0, $9;\n" + "cvt.rn.f16.f32 h1, $10;\n" + "cvt.rn.f16.f32 h2, $11;\n" + "cvt.rn.f16.f32 h3, $12;\n" + "cvt.rn.f16.f32 h4, $13;\n" + "cvt.rn.f16.f32 h5, $14;\n" + "cvt.rn.f16.f32 h6, $15;\n" + "cvt.rn.f16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.cs.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_fp8_e4m3_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + v8: Float32, + v9: Float32, + v10: Float32, + v11: Float32, + v12: Float32, + v13: Float32, + v14: Float32, + v15: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [ + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + ] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + Float32(v8).ir_value(loc=loc, ip=ip), + Float32(v9).ir_value(loc=loc, ip=ip), + Float32(v10).ir_value(loc=loc, ip=ip), + Float32(v11).ir_value(loc=loc, ip=ip), + Float32(v12).ir_value(loc=loc, ip=ip), + Float32(v13).ir_value(loc=loc, ip=ip), + Float32(v14).ir_value(loc=loc, ip=ip), + Float32(v15).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.satfinite.e4m3x2.f32 h0, $18, $17;\n" + "cvt.rn.satfinite.e4m3x2.f32 h1, $20, $19;\n" + "cvt.rn.satfinite.e4m3x2.f32 h2, $22, $21;\n" + "cvt.rn.satfinite.e4m3x2.f32 h3, $24, $23;\n" + "cvt.rn.satfinite.e4m3x2.f32 h4, $26, $25;\n" + "cvt.rn.satfinite.e4m3x2.f32 h5, $28, $27;\n" + "cvt.rn.satfinite.e4m3x2.f32 h6, $30, $29;\n" + "cvt.rn.satfinite.e4m3x2.f32 h7, $32, $31;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.cs.v4.b32 [$16], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000; " + "mov.f32 $8, 0f00000000; mov.f32 $9, 0f00000000; " + "mov.f32 $10, 0f00000000; mov.f32 $11, 0f00000000; " + "mov.f32 $12, 0f00000000; mov.f32 $13, 0f00000000; " + "mov.f32 $14, 0f00000000; mov.f32 $15, 0f00000000;", + ( + "=f,=f,=f,=f,=f,=f,=f,=f," + "=f,=f,=f,=f,=f,=f,=f,=f," + "l,f,f,f,f,f,f,f,f,f,f,f,f,f,f,f,f" + ), + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +def convert_layout_from_tmem16x256b_to_acc_sm90(acc_layout: cute.Layout) -> cute.Layout: + acc_layout_col_major = cute.make_layout(acc_layout.shape) + acc_layout_mn = cute.make_layout( + ( + acc_layout_col_major.shape[0][0], + acc_layout_col_major.shape[0][1], + acc_layout_col_major.shape[1], + *acc_layout_col_major.shape[2:], + ), + stride=( + acc_layout_col_major.stride[0][0], + acc_layout_col_major.stride[0][1], + acc_layout_col_major.stride[1], + *acc_layout_col_major.stride[2:], + ), + ) + return cute.composition(acc_layout, acc_layout_mn) + + +def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout: + acc_layout_col_major = cute.make_layout(acc_layout.shape) + acc_layout_mn = cute.make_layout( + ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), + ( + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), + *acc_layout_col_major.shape[3:], + ), + stride=( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), + *acc_layout_col_major.stride[3:], + ), + ) + return cute.composition(acc_layout, acc_layout_mn) + + +def make_16x256b_tensor_mn_view(tensor: cute.Tensor) -> cute.Tensor: + layout = convert_layout_acc_mn( + convert_layout_from_tmem16x256b_to_acc_sm90(tensor.layout) + ) + return cute.make_tensor(tensor.iterator, layout) + + +def real_col_to_stg128_fake_col(col: Int32) -> Int32: + nt = col // Int32(16) + col16 = col - nt * Int32(16) + pair = col16 // Int32(2) + rank = pair % Int32(4) + kv = (pair // Int32(4)) * Int32(2) + (col16 % Int32(2)) + return nt * Int32(16) + rank * Int32(4) + kv + + +def stg128_fake_col_to_real_col(fake_col: Int32) -> Int32: + nt = fake_col // Int32(16) + fake16 = fake_col - nt * Int32(16) + rank = fake16 // Int32(4) + kv = fake16 % Int32(4) + return nt * Int32(16) + rank * Int32(2) + (kv // Int32(2)) * Int32(8) + (kv % Int32(2)) + + +def real_col_to_stg128_half_fake_col(col: Int32) -> Int32: + nt = col // Int32(32) + col32 = col - nt * Int32(32) + lane = (col32 % Int32(8)) // Int32(2) + group = col32 // Int32(8) + elem = col32 % Int32(2) + return nt * Int32(32) + lane * Int32(8) + group * Int32(2) + elem + + +def stg128_half_fake_col_to_real_col(fake_col: Int32) -> Int32: + nt = fake_col // Int32(32) + fake32 = fake_col - nt * Int32(32) + lane = fake32 // Int32(8) + lane_slot = fake32 - lane * Int32(8) + group = lane_slot // Int32(2) + elem = lane_slot - group * Int32(2) + return nt * Int32(32) + group * Int32(8) + lane * Int32(2) + elem + + +def real_col_to_stg128_fp8_fake_col(col: Int32) -> Int32: + nt = col // Int32(64) + col64 = col - nt * Int32(64) + lane = (col64 % Int32(8)) // Int32(2) + group = col64 // Int32(8) + elem = col64 % Int32(2) + return nt * Int32(64) + lane * Int32(16) + group * Int32(2) + elem + + +def stg128_fp8_fake_col_to_real_col(fake_col: Int32) -> Int32: + nt = fake_col // Int32(64) + fake64 = fake_col - nt * Int32(64) + lane = fake64 // Int32(16) + lane_slot = fake64 - lane * Int32(16) + group = lane_slot // Int32(2) + elem = lane_slot - group * Int32(2) + return nt * Int32(64) + group * Int32(8) + lane * Int32(2) + elem + + +# Cluster & Bulk Async Ops + +@dsl_user_op +def set_block_rank( + smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None +) -> Int32: + """Map the given smem pointer to the address at another CTA rank in the cluster.""" + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + return Int32( + llvm.inline_asm( + T.i32(), + [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def store_shared_remote_fp32x4( + a: Float32, + b: Float32, + c: Float32, + d: Float32, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: Int32, + *, + loc=None, + ip=None, +) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + llvm.inline_asm( + None, + [ + remote_smem_ptr_i32, + remote_mbar_ptr_i32, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .v4 .f32 abcd;\n\t" + "mov.f32 abcd.x, $2;\n\t" + "mov.f32 abcd.y, $3;\n\t" + "mov.f32 abcd.z, $4;\n\t" + "mov.f32 abcd.w, $5;\n\t" + "st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, [$1];\n\t" + "}\n", + "r,r,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_bulk_s2cluster( + smem_src_ptr: cute.Pointer, + smem_dst_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + size: int | Int32, + peer_cta_rank_in_cluster: Int32, + *, + loc=None, + ip=None, +): + smem_src_ptr_i32 = smem_src_ptr.toint(loc=loc, ip=ip).ir_value() + smem_dst_ptr_i32 = set_block_rank( + smem_dst_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + mbar_ptr_i32 = set_block_rank(mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [ + smem_dst_ptr_i32, + smem_src_ptr_i32, + mbar_ptr_i32, + Int32(size).ir_value(loc=loc, ip=ip), + ], + "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [$0], [$1], $3, [$2];", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_bulk_g2s( + gmem_ptr: cute.Pointer, + smem_ptr: cute.Pointer, + tma_bar_ptr: cute.Pointer, + size: int | Int32, + *, + loc=None, + ip=None, +): + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + mbar_ptr_i32 = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [gmem_ptr_i64, smem_ptr_i32, mbar_ptr_i32, Int32(size).ir_value()], + "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];", + "l,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_reduce_bulk_add_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: int | Int32, + *, + loc=None, + ip=None, +): + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()], + # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;", + # "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +def cpasync_bulk_get_copy_fn( + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + single_stage: bool = False, + **kwargs, +) -> Callable: + # src_is_smem = const_expr( + # isinstance(src_tensor.iterator, cute.Pointer) + # and src_tensor.memspace == cute.AddressSpace.smem + # ) + group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0)) + group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + src = cute.group_modes(src_tensor, 0, group_rank_src) + dst = cute.group_modes(dst_tensor, 0, group_rank_dst) + + def copy_bulk(src_idx, dst_idx, **new_kwargs): + size = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8) + cpasync_bulk_g2s( + src[None, src_idx].iterator, + dst[None, dst_idx].iterator, + size=size, + **new_kwargs, + **kwargs, + ) + + def copy_bulk_single_stage(**new_kwargs): + size = const_expr(cute.size(src.shape) * src.element_type.width // 8) + cpasync_bulk_g2s(src.iterator, dst.iterator, size=size, **new_kwargs, **kwargs) + + return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage + + +# TMA Copy Adapters + +def tma_get_copy_fn( + atom: cute.CopyAtom, + cta_coord: cute.Coord, + cta_layout: cute.Layout, + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + filter_zeros: bool = False, + single_stage: bool = False, + **kwargs, +) -> Callable: + src_is_smem = const_expr( + isinstance(src_tensor.iterator, cute.Pointer) + and src_tensor.memspace == cute.AddressSpace.smem + ) + smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) + group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0)) + group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + s, g = cpasync.tma_partition( + atom, + cta_coord, + cta_layout, + cute.group_modes(smem_tensor, 0, group_rank_smem), + cute.group_modes(gmem_tensor, 0, group_rank_gmem), + ) + if const_expr(filter_zeros): + s = cute.filter_zeros(s) + g = cute.filter_zeros(g) + src, dst = (s, g) if src_is_smem else (g, s) + + def copy_tma(src_idx, dst_idx, **new_kwargs): + cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) + + def copy_tma_single_stage(**new_kwargs): + cute.copy(atom, src, dst, **new_kwargs, **kwargs) + + return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g + + +def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): + def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs): + copy( + src_idx=src_idx, + dst_idx=producer_state.index, + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), + **new_kwargs, + ) + + return copy_fn + + +__all__ = [ + "atomic_add_broadcast_i32", + "atomic_add_fp32x4", + "atomic_add_i32", + "convert_layout_acc_mn", + "convert_layout_from_tmem16x256b_to_acc_sm90", + "copy", + "cpasync_bulk_g2s", + "cpasync_bulk_get_copy_fn", + "cpasync_bulk_s2cluster", + "cpasync_reduce_bulk_add_f32", + "cvt_copy", + "get_copy_atom", + "load_s2r", + "make_16x256b_tensor_mn_view", + "make_tmem_copy", + "real_col_to_stg128_fake_col", + "real_col_to_stg128_fp8_fake_col", + "real_col_to_stg128_half_fake_col", + "set_block_rank", + "stg128_fake_col_to_real_col", + "stg128_fp8_fake_col_to_real_col", + "stg128_half_fake_col_to_real_col", + "stg_128", + "stg_128_cs", + "stg_128_bf16", + "stg_128_bf16_cs", + "stg_128_f16", + "stg_128_f16_cs", + "stg_128_fp8_e4m3_cs", + "stg_32_fp8_e4m3", + "stg_64_bf16", + "stg_64_f16", + "sts_32_bf16", + "sts_32_f16", + "store_shared_remote_fp32x4", + "tiled_copy_1d", + "tiled_copy_2d", + "tma_get_copy_fn", + "tma_producer_copy_fn", +] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/cute_dsl_utils.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/cute_dsl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3473fbbf77fa1261abfc8fd960102c70d3e64bd --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/cute_dsl_utils.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import logging +import os +import pathlib +import time +from typing import Tuple +from functools import partial, lru_cache +from dataclasses import dataclass, fields + +import torch + +logger = logging.getLogger("minimax") + +try: + from triton.tools.disasm import extract +except ImportError: + extract = None + +import cutlass +import cutlass.cute as cute +from cutlass.base_dsl.typing import JitArgument +from cutlass.cutlass_dsl import NumericMeta +from cutlass.cute.runtime import from_dlpack + +StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) + + +load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data +cute_compile_og = cute.compile + + +torch2cute_dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, +} + + +@lru_cache +def get_max_active_clusters(cluster_size): + return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) + + +@lru_cache +def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: + return torch.cuda.get_device_capability(device) + + +@dataclass +class ArgumentsBase(JitArgument): + def __c_pointers__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + c_ptrs = [] + for obj in non_constexpr_fields: + if hasattr(obj, "__c_pointers__"): + c_ptrs.extend(obj.__c_pointers__()) + return c_ptrs + + def __get_mlir_types__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + types, self._values_pos = [], [] + for obj in non_constexpr_fields: + if hasattr(obj, "__get_mlir_types__"): + obj_types = obj.__get_mlir_types__() + types.extend(obj_types) + self._values_pos.append(len(obj_types)) + else: + self._values_pos.append(0) + return types + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +def load_cubin_module_data_patched(cubin_data, filepath): + pathlib.Path(filepath).write_bytes(cubin_data) + return load_cubin_module_data_og(cubin_data) + + +def cute_compile_patched(*args, **kwargs): + """A patched version of cute.compile. + + Behaviour: + - Dumps SASS to a file if ``CUTE_CUBIN_PATH`` is set. + - Logs JIT compile wall time at DEBUG level via the ``minimax`` logger, + tagged with the kernel's class name when available. Enable with + ``logging.getLogger("minimax").setLevel(logging.DEBUG)`` or env + ``MINIMAX_LOG_COMPILE=1``; this is how we distinguish a slow JIT + (~2-10s) from a kernel hang (>30s = deadlock, see CLAUDE.md). + """ + cubin_path = os.getenv("CUTE_CUBIN_PATH", None) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( + load_cubin_module_data_patched, filepath=cubin_path + ) + kernel_obj = args[0] if args else kwargs.get("op") + kernel_name = type(kernel_obj).__name__ if kernel_obj is not None else "" + t0 = time.time() + output = cute_compile_og(*args, **kwargs) + dt = time.time() - t0 + logger.debug("[%s] compiled in %.1fs", kernel_name, dt) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og + if extract is not None: + sass = extract(cubin_path, None) + pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) + return output + + +if os.getenv("MINIMAX_LOG_COMPILE", "0") == "1": + if not logger.handlers: + _h = logging.StreamHandler() + _h.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")) + logger.addHandler(_h) + logger.setLevel(logging.DEBUG) + + +# Monkey-patch cute.compile so every JIT compile across the repo gets timed +# without touching individual call sites. Idempotent: only patches once. +if cute.compile is not cute_compile_patched: + cute.compile = cute_compile_patched + + +def assume_strides_aligned(t): + """Assume all strides except the last are divisible by 128 bits. + + Python int strides (e.g., stride=0 from GQA expand) are kept as-is + since they're static and don't need alignment assumptions. + """ + divby = 128 // t.element_type.width + strides = tuple(s if isinstance(s, int) else cute.assume(s, divby=divby) for s in t.stride[:-1]) + return (*strides, t.stride[-1]) + + +def assume_tensor_aligned(t): + """Rebuild a tensor with 128-bit aligned stride assumptions. Passes through None.""" + if t is None: + return None + return cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=assume_strides_aligned(t))) + + +def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True): + """Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1.""" + tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi) + if fully_dynamic: + return tensor.mark_layout_dynamic() + if leading_dim == -1: + leading_dim = t.ndim - 1 + return tensor.mark_layout_dynamic(leading_dim=leading_dim) + + +def to_cute_aux_tensor(t, enable_tvm_ffi=True): + """Convert torch tensor to cute tensor for TVM FFI, tailored to FlexAttention aux tensors. + This allows the user to specify alignment and leading dimension for aux tensors used in + custom score_mod callables. + """ + assumed_align: int = getattr(t, "__assumed_align__", None) + leading_dim: int = getattr(t, "__leading_dim__", None) + fully_dynamic: bool = leading_dim is None + + return to_cute_tensor( + t, + assumed_align=assumed_align, + leading_dim=leading_dim, + fully_dynamic=fully_dynamic, + enable_tvm_ffi=enable_tvm_ffi, + ) + + +def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]: + """Return tuple of bools indicating which dims have stride=0 (broadcast). + + This is useful for compile keys since CuTe's mark_layout_dynamic() keeps + stride=0 as static, meaning kernels compiled with different broadcast + patterns are not interchangeable. + """ + return tuple(s == 0 for s in tensor.stride()) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/fast_math.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/fast_math.py new file mode 100644 index 0000000000000000000000000000000000000000..63a8b4a501ac499e372056a07d499832c830b474 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/fast_math.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 + + +@cute.jit +def clz(x: Int32) -> Int32: + # for i in cutlass.range_constexpr(32): + # if (1 << (31 - i)) & x: + # return Int32(i) + # return Int32(32) + # Early exit is not supported yet + res = Int32(32) + done = False + for i in cutlass.range(32): + if ((1 << (31 - i)) & x) and not done: + res = Int32(i) + done = True + return res diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/mask.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..0da42c3be9bf1c3dcff81ccde579b54131bfa4c6 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/mask.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Callable, Optional, TypeAlias +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, Uint32, const_expr + +from ...src.common import utils as utils +from ...src.common.seqlen_info import SeqlenInfoQK + +MaskGenFn: TypeAlias = Callable[[int], Uint32] +MASK_R2P_CHUNK_SIZE: int = 32 + + +@cute.jit +def r2p_bitmask_below(limit: Int32, s: int) -> Uint32: + m = max((s + 1) * MASK_R2P_CHUNK_SIZE - limit, 0) + return utils.shr_u32(Uint32(0xFFFFFFFF), Uint32(m)) + + +@cute.jit +def r2p_bitmask_above(limit: Int32, s: int) -> Uint32: + n = max(limit - s * MASK_R2P_CHUNK_SIZE, 0) + return utils.shl_u32(Uint32(0xFFFFFFFF), Uint32(n)) + + +@cute.jit +def mask_r2p_lambda( + X: cute.Tensor, + mask_gen_fn: cutlass.Constexpr[MaskGenFn], + rank1: bool = False, +) -> None: + ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape)) + for s in cutlass.range_constexpr(cute.ceil_div(ncol, MASK_R2P_CHUNK_SIZE)): + mask = mask_gen_fn(s) + for i in cutlass.range_constexpr(min(MASK_R2P_CHUNK_SIZE, ncol - s * MASK_R2P_CHUNK_SIZE)): + in_bound = cutlass.Boolean(mask & (Uint32(1) << i)) + c = s * MASK_R2P_CHUNK_SIZE + i + if const_expr(rank1): + X[c] = X[c] if in_bound else -Float32.inf + else: + for r in cutlass.range_constexpr(cute.size(X.shape[0])): + X[r, c] = X[r, c] if in_bound else -Float32.inf + + +@cute.jit +def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32: + return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep) + + +@dataclass(frozen=True) +class AttentionMask: + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] + seqlen_info: SeqlenInfoQK + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + swap_AB: cutlass.Constexpr[bool] = False + + @property + def seqlen_q(self) -> Int32: + return self.seqlen_info.seqlen_q + + @property + def seqlen_k(self) -> Int32: + return self.seqlen_info.seqlen_k + + @cute.jit + def apply_mask_sm100( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + m_block: Int32, + n_block: Int32, + mask_seqlen: cutlass.Constexpr[bool], + mask_causal: cutlass.Constexpr[bool], + row_idx: Optional[Int32] = None, + kv_valid_cols: Optional[Int32] = None, + kv_block_col_start: Optional[Int32] = None, + ) -> None: + if const_expr(not mask_seqlen and not mask_causal): + return + + col_limit = Int32(self.tile_n) + if const_expr(mask_seqlen): + if const_expr(kv_valid_cols is not None): + col_limit = kv_valid_cols + else: + col_limit = self.seqlen_k - n_block * Int32(self.tile_n) + + if const_expr(mask_causal): + if const_expr(row_idx is None): + row_axis = 0 if const_expr(not self.swap_AB) else 1 + row_idx_cur = tScS_t2r[0][row_axis] + m_block * Int32(self.tile_m) + if const_expr(self.qhead_per_kvhead_packgqa > 1): + row_idx_cur = row_idx_cur // Int32(self.qhead_per_kvhead_packgqa) + else: + row_idx_cur = row_idx + if const_expr(kv_block_col_start is not None): + block_col_start = kv_block_col_start + else: + block_col_start = n_block * Int32(self.tile_n) + causal_col_limit = ( + row_idx_cur + self.seqlen_k - self.seqlen_q + - block_col_start + Int32(1) + ) + col_limit = ( + cutlass.min(col_limit, causal_col_limit) + if const_expr(mask_seqlen) + else causal_col_limit + ) + + if col_limit < Int32(self.tile_n): + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_below(col_limit, s), + rank1=True, + ) + + @cute.jit + def apply_mask_sm100_transposed( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + t0ScS_t2r: cute.Tensor, + m_block: cutlass.Int32, + n_block: cutlass.Int32, + mask_seqlen: cutlass.Constexpr, + mask_causal: cutlass.Constexpr, + is_full_block: bool = False, + check_m_boundary: bool = True, + valid_tok_count: Optional[Int32] = None, + q_idx_tile: Optional[cute.Tensor] = None, + masked_tok_count: Optional[Int32] = None, + ) -> None: + del is_full_block, check_m_boundary + del t0ScS_t2r + row_axis = 0 if const_expr(not self.swap_AB) else 1 + col_axis = 1 if const_expr(not self.swap_AB) else 0 + + if const_expr(valid_tok_count is not None): + kv_block_col_start = n_block * Int32(self.tile_n) + causal_q_offset = self.seqlen_k - self.seqlen_q + nfrag = const_expr(cute.size(acc_S.shape)) + for i in cutlass.range(nfrag, unroll_full=True): + row_idx = tScS_t2r[i][row_axis] + tok_idx = row_idx // Int32(self.qhead_per_kvhead_packgqa) + acc_S[i] = -Float32.inf if tok_idx >= valid_tok_count else acc_S[i] + if const_expr(mask_seqlen): + kv_idx = kv_block_col_start + tScS_t2r[i][col_axis] + acc_S[i] = -Float32.inf if kv_idx >= self.seqlen_k else acc_S[i] + if const_expr(mask_causal): + if const_expr(q_idx_tile is not None): + causal_tok_count = ( + masked_tok_count + if const_expr(masked_tok_count is not None) + else Int32(0) + ) + if tok_idx < causal_tok_count: + q_idx = q_idx_tile[tok_idx] + kv_idx = kv_block_col_start + tScS_t2r[i][col_axis] + acc_S[i] = ( + -Float32.inf if kv_idx > q_idx + causal_q_offset else acc_S[i] + ) + return + + thr_col_offset = tScS_t2r[0][col_axis] + seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset + + if const_expr(not mask_causal): + if const_expr(mask_seqlen) and seqlenk_col_limit <= 0: + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + acc_S[i] = -cutlass.Float32.inf + return + + thr_row_offset = tScS_t2r[0][row_axis] + seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset + row_limit_top = seqlenq_row_limit - seqlenk_col_limit + if const_expr(mask_seqlen) and seqlenk_col_limit <= 0: + row_limit_top = self.tile_m + num_rep = cute.size(tScS_t2r, mode=[0]) + row_limit = row_to_r2p_idx(row_limit_top, num_rep, 2) + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_above(row_limit, s), + rank1=True, + ) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/mma_sm100_desc.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/mma_sm100_desc.py new file mode 100644 index 0000000000000000000000000000000000000000..53c58d17f5085d207f2a1d7b6b45d627ff3322e3 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/mma_sm100_desc.py @@ -0,0 +1,304 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT +# +# The bit-field encodings, enum values, and descriptor layout below mirror the +# SM100 tcgen05 MMA instruction descriptor as documented and +# implemented in NVIDIA CUTLASS (BSD-3-Clause). The numeric values MUST stay +# identical to the hardware/ISA encodings; see the "Third-party licenses" +# section of README.md at the repo root for attribution. + +from enum import IntEnum + +import cutlass +import cutlass.cute as cute + +# --------------------------------------------------------------------------- +# Enumerations that match the HW encodings (values MUST stay identical) +# --------------------------------------------------------------------------- + + +class Major(IntEnum): # matrix "layout" in the ISA docs + K = 0 + MN = 1 + + +class ScaleIn(IntEnum): # negate flags + One = 0 + Neg = 1 + + +class Saturate(IntEnum): + False_ = 0 + True_ = 1 + + +class CFormat(IntEnum): # 2-bit field (bits 4-5) + F16 = 0 + F32 = 1 + S32 = 2 + + +class F16F32Format(IntEnum): # 3-bit field (A/B element type) + F16 = 0 + BF16 = 1 + TF32 = 2 + + +class S8Format(IntEnum): + UINT8 = 0 + INT8 = 1 + + +class MXF8F6F4Format(IntEnum): + E4M3 = 0 + E5M2 = 1 + E2M3 = 3 + E3M2 = 4 + E2M1 = 5 + + +class MaxShift(IntEnum): + NoShift = 0 + MaxShift8 = 1 + MaxShift16 = 2 + MaxShift32 = 3 + + +# --------------------------------------------------------------------------- +# CUTLASS-type -> encoding helpers +# --------------------------------------------------------------------------- + + +def to_UMMA_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B. + """ + if cutlass_type is cutlass.Int8: + return S8Format.INT8 + # Unsigned 8-bit (if available in your CUTLASS build) + if cutlass_type is cutlass.Uint8: + return S8Format.UINT8 + # FP-16 / BF-16 + if cutlass_type is cutlass.Float16: + return F16F32Format.F16 + if cutlass_type is cutlass.BFloat16: + return F16F32Format.BF16 + # TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits) + if cutlass_type is cutlass.TFloat32: + return F16F32Format.TF32 + # Float-8 / Float-6 / Float-4 + if cutlass_type is cutlass.Float8E4M3FN: + return MXF8F6F4Format.E4M3 + if cutlass_type is cutlass.Float8E5M2: + return MXF8F6F4Format.E5M2 + raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}") + + +def to_C_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 2-bit accumulator encoding. + """ + if cutlass_type is cutlass.Float16: + return CFormat.F16 + if cutlass_type is cutlass.Float32: + return CFormat.F32 + if cutlass_type is cutlass.Int32: + return CFormat.S32 + raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}") + + +# --------------------------------------------------------------------------- +# The constructor – accepts only CUTLASS scalar classes +# --------------------------------------------------------------------------- + + +def make_instr_desc( + a_type, # CUTLASS scalar class, e.g. cutlass.Int8 + b_type, + c_type, + M: int, # 64, 128 or 256 + N: int, # 8 … 256 (multiple of 8) + a_major: Major, + b_major: Major, + a_neg: ScaleIn = ScaleIn.One, + b_neg: ScaleIn = ScaleIn.One, + c_sat: Saturate = Saturate.False_, + is_sparse: bool = False, + max_shift: MaxShift = MaxShift.NoShift, +) -> int: + """ + Build the 32-bit instruction descriptor for SM100 MMA. + All matrix/accumulator **types must be CUTLASS scalar classes** – + passing integers is forbidden. + """ + # --- encode element formats ------------------------------------------------- + a_fmt = int(to_UMMA_format(a_type)) + b_fmt = int(to_UMMA_format(b_type)) + c_fmt = int(to_C_format(c_type)) + is_f8f6f4 = a_type in (cutlass.Float8E4M3FN, cutlass.Float8E5M2) + + # --- range checks on M/N ----------------------------------------------------- + if M not in (64, 128, 256): + raise ValueError("M must be 64, 128 or 256") + if N < 8 or N > 256 or (N & 7): + raise ValueError("N must be a multiple of 8 in the range 8…256") + + m_dim = M >> 4 # 5-bit field + n_dim = N >> 3 # 6-bit field + + # fmt: off + # --- pack the bit-fields ----------------------------------------------------- + desc = 0 + desc |= (0 & 0x3) << 0 # sparse_id2 (always 0 here) + desc |= (int(is_sparse) & 0x1) << 2 # sparse_flag + desc |= (int(c_sat) & 0x1) << 3 # saturate + desc |= (c_fmt & 0x3) << 4 # c_format + desc |= (a_fmt & 0x7) << 7 # a_format + desc |= (b_fmt & 0x7) << 10 # b_format + desc |= (int(a_neg) & 0x1) << 13 # a_negate + desc |= (int(b_neg) & 0x1) << 14 # b_negate + desc |= (int(a_major) & 0x1) << 15 # a_major + desc |= (int(b_major) & 0x1) << 16 # b_major + desc |= (n_dim & 0x3F) << 17 # n_dim (6 bits) + # CUTLASS' tcgen05 lowering sets bit 23 for dense f8f6f4 MMAs; keep this + # descriptor aligned with generated/reference SM100 FP8 kernels. + desc |= (int(is_f8f6f4) & 0x1) << 23 + desc |= (m_dim & 0x1F) << 24 # m_dim (5 bits) + desc |= (int(max_shift) & 0x3) << 30 # max_shift (2 bits) + # fmt: on + + return desc & 0xFFFF_FFFF # ensure 32-bit result + + +def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp): + return make_instr_desc( + op.a_dtype, + op.b_dtype, + op.acc_dtype, + op.shape_mnk[0], + op.shape_mnk[1], + Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + ) + + +class LayoutType(IntEnum): # occupies the top-3 bits [61:64) + SWIZZLE_NONE = 0 # (a.k.a. "INTERLEAVE" in older docs) + SWIZZLE_128B_BASE32B = 1 + SWIZZLE_128B = 2 + SWIZZLE_64B = 4 + SWIZZLE_32B = 6 + # values 3,5,7 are reserved / illegal for UMMA + + +# --------------------------------------------------------------------------- +# Helpers – figure out the SWIZZLE_* family from the tensor layout +# --------------------------------------------------------------------------- + + +def _layout_type(swizzle: cute.Swizzle) -> LayoutType: + B, M, S = swizzle.num_bits, swizzle.num_base, swizzle.num_shift + + if M == 4: # Swizzle<*,4,3> + if S != 3: + raise ValueError("Unexpected swizzle shift – want S==3 for M==4") + return { + 0: LayoutType.SWIZZLE_NONE, + 1: LayoutType.SWIZZLE_32B, + 2: LayoutType.SWIZZLE_64B, + 3: LayoutType.SWIZZLE_128B, + }[B] # KeyError ⇒ invalid B→ raise + if M == 5: # Swizzle<2,5,2> (the only legal triple for M==5) + if (B, S) != (2, 2): + raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B") + return LayoutType.SWIZZLE_128B_BASE32B + + # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout + raise ValueError("Unsupported swizzle triple for UMMA smem descriptor") + + +def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int: + """ + Convert a 2-D *shared-memory* Cute layout into the SM100 64-bit + smem-descriptor, without the smem start address. + layout must correspond to layout of an uint128 tensor. + """ + # ------------------------------------------------------------------ meta + layout_type = _layout_type(swizzle) # resolve SWIZZLE_* family + + VERSION = 1 # bits 46–47 + LBO_MODE = 0 # bit 52 + BASE_OFFSET = 0 # bits 49–51 (CUTLASS always 0) + + # ---------------------------------------------------------- strides (units: uint128_t = 16 B) + swizzle_atom_mn_size = { + LayoutType.SWIZZLE_NONE: 1, + LayoutType.SWIZZLE_32B: 2, + LayoutType.SWIZZLE_64B: 4, + LayoutType.SWIZZLE_128B: 8, + LayoutType.SWIZZLE_128B_BASE32B: 8, + }[layout_type] + + if major is Major.MN: + swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8 + canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if stride_10 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1] + if layout_type is LayoutType.SWIZZLE_NONE: + stride_byte_offset, leading_byte_offset = stride_01, stride_11 + else: + stride_byte_offset, leading_byte_offset = stride_11, stride_01 + else: + if layout_type == LayoutType.SWIZZLE_128B_BASE32B: + raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K") + if not cute.size(layout.shape[0]) % 8 == 0: + raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.") + canonical_layout = cute.logical_divide(layout, (8, 2)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if stride_00 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_01 = canonical_layout.stride[0][1] + stride_byte_offset, leading_byte_offset = stride_01, stride_10 + + # ------------------------------------------------------------------ pack + desc = 0 + # leading_byte_offset_ [16:30) + desc |= (leading_byte_offset & 0x3FFF) << 16 + # stride_byte_offset_ [32:46) + desc |= (stride_byte_offset & 0x3FFF) << 32 + # version_ [46:48) + desc |= (VERSION & 0x3) << 46 + # base_offset_ [49:52) + desc |= (BASE_OFFSET & 0x7) << 49 + # lbo_mode_ [52:53) + desc |= (LBO_MODE & 0x1) << 52 + # layout_type_ [61:64) + desc |= (int(layout_type) & 0x7) << 61 + + return desc & 0xFFFF_FFFF_FFFF_FFFF # force 64-bit width + + +def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32: + # 14 bits, remove 4 LSB (bits 0-13 in desc) + return (start_addr.toint() & 0x3FFFF) >> 4 + + +def smem_desc_base_from_tensor(sA: cute.Tensor, major: Major) -> int: + sA_swizzle = sA.iterator.type.swizzle_type + return make_smem_desc_base( + cute.recast_layout(128, sA.element_type.width, sA.layout[0]), + sA_swizzle, + major, + ) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/named_barrier.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/named_barrier.py new file mode 100644 index 0000000000000000000000000000000000000000..a7722a471ca011a94d5fd7774224906001979b78 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/named_barrier.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import enum + + +class NamedBarrierFwdSm100(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() + TmemPtr = enum.auto() + SoftmaxStatsW0 = enum.auto() + SoftmaxStatsW1 = enum.auto() + SoftmaxStatsW2 = enum.auto() + SoftmaxStatsW3 = enum.auto() + SoftmaxStatsW4 = enum.auto() + SoftmaxStatsW5 = enum.auto() + SoftmaxStatsW6 = enum.auto() + SoftmaxStatsW7 = enum.auto() + LoadWG = enum.auto() + StoreEpilogue = enum.auto() + KvLoad = enum.auto() + KvDequantK = enum.auto() + KvDequantV = enum.auto() diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/pack_gqa.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/pack_gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5dc25edd3f48fbe2c77ec94c8ab3f1ea417507 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/pack_gqa.py @@ -0,0 +1,320 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""PackGQA primitives for GQA (grouped-query attention) tile layouts. + +Contains: +- ``pack_gqa_layout`` / ``unpack_gqa_layout``: fold/unfold ``qhead_per_kvhead`` + into the seqlen dimension of a tensor layout (zero-copy view). +- ``PackGQA``: base class with ``compute_ptr`` / ``load_Q`` / ``store_LSE`` / + ``store_O`` helpers for kernels that treat ``(qhead_per_kvhead × seqlen_q)`` + as a single packed row dimension. +- ``PackGQAComb``: subclass used by the K2 combine kernel; adds ``load_LSE`` + for coalesced GMEM→SMEM async copies when LSE_partial is laid out with H_q + innermost (stride-1). +""" + +from dataclasses import dataclass +from typing import Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute import FastDivmodDivisor + +from ...quack import layout_utils + +from . import utils + + +def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): + """Reshape a tensor to fold qhead_per_kvhead into the seqlen dimension (mode 0). + + The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) + are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept + as-is (e.g. batch). + + For Q/O tensors (head_idx=2): + (seqlen_q, headdim, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) + For LSE tensors (head_idx=1): + (seqlen_q, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) + """ + head_stride = T.stride[head_idx] + shape_packed = ( + (qhead_per_kvhead, T.shape[0]), + *[T.shape[i] for i in range(1, head_idx)], + nheads_kv, + *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], + ) + stride_packed = ( + (head_stride, T.stride[0]), + *[T.stride[i] for i in range(1, head_idx)], + head_stride * qhead_per_kvhead, + *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], + ) + return cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed)) + + +def unpack_gqa_layout(T, qhead_per_kvhead, head_idx): + """Reverse of pack_gqa_layout: unfold qhead_per_kvhead from the seqlen dimension (mode 0). + + The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) + are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept + as-is (e.g. batch). + + For Q/O tensors (head_idx=2): + ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) -> (seqlen_q, headdim, nheads, batch, ...) + For LSE tensors (head_idx=1): + ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) -> (seqlen_q, nheads, batch, ...) + """ + seqlen_stride = T.stride[0][1] + head_stride = T.stride[0][0] + shape_unpacked = ( + T.shape[0][1], + *[T.shape[i] for i in range(1, head_idx)], + T.shape[head_idx] * qhead_per_kvhead, + *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], + ) + stride_unpacked = ( + seqlen_stride, + *[T.stride[i] for i in range(1, head_idx)], + head_stride, + *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], + ) + return cute.make_tensor(T.iterator, cute.make_layout(shape_unpacked, stride=stride_unpacked)) + + +@dataclass +class PackGQA: + m_block_size: cutlass.Constexpr[int] + head_dim_padded: cutlass.Constexpr[int] + check_hdim_oob: cutlass.Constexpr[bool] + qhead_per_kvhead: cutlass.Constexpr[bool] + + @cute.jit + def compute_ptr( + self, + tensor: cute.Tensor, + cRows: cute.Tensor, + tidx: cutlass.Int32, + block: cutlass.Int32, + threads_per_row: cutlass.Constexpr[int], + num_threads: cutlass.Constexpr[int], + ): + num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row) + tPrPtr = cute.make_rmem_tensor(num_ptr_per_thread, cutlass.Int64) + for i in cutlass.range_constexpr(num_ptr_per_thread): + row = i * num_threads + cRows[tidx % threads_per_row][0] + idx = block * self.m_block_size + row + m_idx = idx // self.qhead_per_kvhead + h_idx = idx - m_idx * self.qhead_per_kvhead + tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint() + return tPrPtr + + @cute.jit + def load_Q( + self, + mQ: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + sQ: cute.Tensor, # (m_block_size, head_dim_padded) + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tQsQ = gmem_thr_copy.partition_D(sQ) + tQcQ = gmem_thr_copy.partition_S(cQ) + t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) + tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1]) + tQcQ_row = tQcQ[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + q_ptr_i64 = utils.shuffle_sync( + tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row + ) + q_gmem_ptr = cute.make_ptr( + mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0QcQ[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0] + ): + mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tQsQ.shape[0][0]) + mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])): + ki = tQcQ[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + mQ_cur_copy[None, ki], + tQsQ[None, m, k], + pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + + @cute.jit + def store_LSE( + self, + mLSE: cute.Tensor, # (qhead_per_kvhead, seqlen_q) + tLSErLSE: cute.Tensor, # (m_block_size, head_dim_padded) + tiled_mma: cute.TiledMma, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + thr_mma = tiled_mma.get_slice(tidx) + caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + taccOcO = thr_mma.partition_C(caccO) + taccOcO_row = layout_utils.reshape_acc_to_mn(taccOcO)[None, 0] + assert cute.size(tLSErLSE) == cute.size(taccOcO_row) + threads_per_row = tiled_mma.tv_layout_C.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + assert cute.size(tLSErLSE) <= threads_per_row + num_threads = tiled_mma.size + tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tLSErLSE)): + lse_ptr_i64 = utils.shuffle_sync( + tPrLSEPtr[m // threads_per_row], + m % threads_per_row, + width=threads_per_row, + ) + lse_gmem_ptr = cute.make_ptr( + mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4 + ) + row = block * self.m_block_size + taccOcO_row[m][0] + # Only the thread corresponding to column 0 writes out the lse to gmem + if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead: + mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,)) + mLSE_copy[0] = tLSErLSE[m] + + @cute.jit + def store_O( + self, + mO: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + tOrO: cute.Tensor, # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tOcO = gmem_thr_copy.partition_S(cO) + t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + tOcO_row = tOcO[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + o_ptr_i64 = utils.shuffle_sync( + tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row + ) + o_gmem_ptr = cute.make_ptr( + mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0OcO[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0] + ): + mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tOrO.shape[0][0]) + mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])): + ki = tOcO[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + tOrO[None, m, k], + mO_cur_copy[None, ki], + pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + + +@dataclass +class PackGQAComb(PackGQA): + """PackGQA subclass for the K2 combine kernel. + + Inherits ``compute_ptr`` / ``load_Q`` / ``store_LSE`` / ``store_O`` from + ``PackGQA``. Adds ``load_LSE`` for coalesced GMEM→SMEM async copies when + LSE_partial is laid out with H_q innermost. + + K2 combine treats each query head independently (no GQA grouping in combine + itself), so ``qhead_per_kvhead`` is set to ``num_heads_q`` by the caller — + all heads are folded into one "group" per Sq position. + """ + + @cute.jit + def load_LSE( + self, + mLSE_partial: cute.Tensor, + # Packed layout after caller-side reshape: + # shape ((qhead_per_kvhead, seqlen_q), num_splits) + # stride ((1, qhead_per_kvhead), ...) + # — H_q is the innermost (stride-1) element of the packed first dim. + sLSE: cute.Tensor, + # SMEM destination: ``(topk, m_block_size)`` fp32. + topk: cutlass.Constexpr[int], + # Explicit topk so the identity tensor shape is a plain int, + # avoiding compound-shape traps from sLSE.shape[0] after tile_to_shape. + gmem_tiled_copy: cute.TiledCopy, + tidx: Int32, + block: Int32, + num_splits: Int32, + seqlen: Int32, + num_heads_divmod: FastDivmodDivisor, + mCounter: Optional[cute.Tensor] = None, + batch_idx: Optional[Int32] = None, + qhead_per_kvhead: Int32 = Int32(1), + # divmod for ``m_pos = idx // qhead_per_kvhead``; passed explicitly so + # caller controls whether the divisor is constexpr or a runtime value. + ): + """Coalesced GMEM→SMEM async load of LSE_partial for one tile. + + For each (split, row) slot this thread owns in the tile, compute the + GMEM coordinate ``(h_pos, m_pos)`` via PackGQA divmod and copy one fp32. + Out-of-bounds rows (``m_pos >= seqlen``) and splits (``si >= num_splits``) + are filled with ``-inf`` so they flow cleanly through downstream reductions. + + Coalescing: adjacent thread rows correspond to adjacent ``h_pos`` values + (head varies fast under ``divmod(idx, qhead_per_kvhead)``), which map to + adjacent GMEM addresses when H_q is stride-1 — one sector per warp. + """ + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cLSE = cute.make_identity_tensor((topk, self.m_block_size)) + tLSEcLSE = gmem_thr_copy.partition_S(cLSE) + tLSEsLSE = gmem_thr_copy.partition_D(sLSE) + + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] + idx = block * self.m_block_size + mi + m_pos, h_pos = divmod(idx, num_heads_divmod) + + if m_pos < seqlen: + row_count = ( + mCounter[batch_idx, m_pos, h_pos // qhead_per_kvhead] + if const_expr(mCounter is not None) + else num_splits + ) + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] + if si < num_splits and si < row_count: + # Build a 1-element GMEM tensor at ((h_pos, m_pos), si), + # matching PackGQA.store_LSE's ptr pattern so cute.copy + # receives a proper Tensor, not a scalar. + src_ptr_i64 = utils.elem_pointer( + mLSE_partial, ((h_pos, m_pos), si)).toint() + src_ptr = cute.make_ptr( + Float32, src_ptr_i64, + cute.AddressSpace.gmem, assumed_align=4, + ) + src_t = cute.make_tensor(src_ptr, (1,)) + cute.copy(gmem_thr_copy, src_t, tLSEsLSE[None, s, m]) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + else: + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + tLSEsLSE[None, s, m].fill(-Float32.inf) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/paged_kv.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/paged_kv.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5f6923c42a826d4f3dd1f192ce2fdb38eefbf5 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/paged_kv.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + + +@dataclass(frozen=True) +class PagedKVManager: + mPageTable: cute.Tensor + page_size: cutlass.Constexpr[int] + n_block_size: cutlass.Constexpr[int] + + @staticmethod + def create( + mPageTable: cute.Tensor, + *, + page_size: int, + n_block_size: int, + ): + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal blk_kv ({n_block_size})" + ) + return PagedKVManager( + mPageTable, + page_size=page_size, + n_block_size=n_block_size, + ) + + @cute.jit + def logical_length( + self, + batch_idx: Int32, + num_kv_blocks: Int32, + mSeqUsedK=None, + ) -> Int32: + if const_expr(mSeqUsedK is not None): + return mSeqUsedK[batch_idx] + return num_kv_blocks * Int32(self.n_block_size) + + @cute.jit + def valid_cols_in_block( + self, + batch_idx: Int32, + kv_block_idx: Int32, + num_kv_blocks: Int32, + mSeqUsedK=None, + ) -> Int32: + seqlen_k = self.logical_length(batch_idx, num_kv_blocks, mSeqUsedK) + block_start = kv_block_idx * Int32(self.n_block_size) + remaining = seqlen_k - block_start + remaining = cutlass.max(remaining, Int32(0)) + return cutlass.min(remaining, Int32(self.n_block_size)) + + @cute.jit + def physical_block_index( + self, + batch_idx: Int32, + kv_block_idx: Int32, + ) -> Int32: + return self.mPageTable[batch_idx, kv_block_idx] + +__all__ = ["PagedKVManager"] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/pipeline.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..27f711772f5c6fa16a86f4aa305f42a0ca9322eb --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/pipeline.py @@ -0,0 +1,372 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +# import math +from typing import Optional +from dataclasses import dataclass + +import cutlass.cute as cute +from cutlass import Boolean, Int32, const_expr +from cutlass.cutlass_dsl import if_generate, dsl_user_op +from cutlass.pipeline import PipelineState +from cutlass.pipeline import PipelineUserType +from cutlass.pipeline import NamedBarrier as NamedBarrierOg +from cutlass.pipeline import PipelineAsync as PipelineAsyncOg +from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg +from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg +from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg +from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg +import cutlass.pipeline as cutlass_pipeline + + +def make_pipeline_state(type: PipelineUserType, stages: int): + """Compatibility wrapper for FA-style helpers now vendored into src.common.""" + return cutlass_pipeline.make_pipeline_state(type, stages) + +@dataclass(frozen=True) +class NamedBarrier(NamedBarrierOg): + @staticmethod + def create(*args, **kwargs): + obj = NamedBarrierOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", NamedBarrier) + return obj + + @dsl_user_op + def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None: + """ + The aligned flavor of arrive is used when all threads in the CTA will execute the + same instruction. See PTX documentation. + """ + cute.arch.barrier_arrive( + barrier_id=self.barrier_id + index, + number_of_threads=self.num_threads, + loc=loc, + ip=ip, + ) + + @dsl_user_op + def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None: + cute.arch.barrier( + barrier_id=self.barrier_id + index, + number_of_threads=self.num_threads, + loc=loc, + ip=ip, + ) + + +@dataclass(frozen=True) +class PipelineAsync(PipelineAsyncOg): + @staticmethod + def create(*args, **kwargs): + obj = PipelineAsyncOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + # obj.__class__ = PipelineAsync + object.__setattr__(obj, "__class__", PipelineAsync) + return obj + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_try_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + *, + loc=None, + ip=None, + ): + return self.sync_object_empty.try_wait(index, phase, loc=loc, ip=ip) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineTmaAsync(PipelineTmaAsyncOg): + """ + Override producer_acquire to take in extra_tx_count parameter. + """ + + @staticmethod + def create(*args, **kwargs): + obj = PipelineTmaAsyncOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", PipelineTmaAsync) + return obj + + @dsl_user_op + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + if const_expr(extra_tx_count == 0): + self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineTmaUmma(PipelineTmaUmmaOg): + """ + Override producer_acquire to take in extra_tx_count parameter. + """ + + @staticmethod + def create(*args, **kwargs): + obj = PipelineTmaUmmaOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + # obj.__class__ = PipelineTmaUmma + object.__setattr__(obj, "__class__", PipelineTmaUmma) + return obj + + @dsl_user_op + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + if const_expr(extra_tx_count == 0): + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive( + state.index, self.producer_mask, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive_and_expect_tx( + state.index, tx_count, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineUmmaAsync(PipelineUmmaAsyncOg): + @staticmethod + def create(*args, **kwargs): + obj = PipelineUmmaAsyncOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", PipelineUmmaAsync) + return obj + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + """ + UMMA producer commit buffer full, cta_group needs to be provided. + """ + self.sync_object_full.arrive(index, self.producer_mask, self.cta_group, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineAsyncUmma(PipelineAsyncUmmaOg): + @staticmethod + def create(*args, **kwargs): + obj = PipelineAsyncUmmaOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", PipelineAsyncUmma) + return obj + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/seqlen_info.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/seqlen_info.py new file mode 100644 index 0000000000000000000000000000000000000000..873304f71c2cb47ffdd1453fe771c754783f51a4 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/seqlen_info.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Optional +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + +from ...quack import copy_utils + +""" +This consolidates all the info related to sequence length. This is so that we can do all +the gmem reads once at the beginning of each tile, rather than having to repeat these reads +to compute various things like n_block_min, n_block_max, etc. +""" + + +@dataclass(frozen=True) +class SeqlenInfo: + offset: Int32 + offset_padded: Int32 + seqlen: Int32 + has_cu_seqlens: cutlass.Constexpr[bool] = False + + @staticmethod + def create( + batch_idx: Int32, + seqlen_static: Int32, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + tile: cutlass.Constexpr[int] = 128, + ): + offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] + offset_padded = ( + 0 + if const_expr(cu_seqlens is None) + # Add divby so that the compiler knows the alignment when moving by offset_padded + else cute.assume((offset + batch_idx * tile) // tile * tile, divby=tile) + ) + if const_expr(seqused is not None): + seqlen = seqused[batch_idx] + elif const_expr(cu_seqlens is not None): + seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] + else: + seqlen = seqlen_static + return SeqlenInfo(offset, offset_padded, seqlen, has_cu_seqlens=cu_seqlens is not None) + + def offset_batch( + self, + mT: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + multiple: int = 1, + ) -> cute.Tensor: + """Offset a tensor by batch index. batch dim is at position `dim`, seqlen is at dim=0.""" + if const_expr(not self.has_cu_seqlens): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mT) - 1 - dim) + return mT[idx] + else: + off = multiple * (self.offset if const_expr(not padded) else self.offset_padded) + offset = off if const_expr(cute.rank(mT.shape[0]) == 1) else (0, off) + idx = (offset,) + (None,) * (cute.rank(mT) - 1) + return cute.domain_offset(idx, mT) + + +@dataclass(frozen=True) +class SeqlenInfoQK: + offset_q: Int32 + offset_k: Int32 + padded_offset_q: Int32 + padded_offset_k: Int32 + seqlen_q: Int32 + seqlen_k: Int32 + has_cu_seqlens_q: cutlass.Constexpr[bool] + has_cu_seqlens_k: cutlass.Constexpr[bool] + has_seqused_q: cutlass.Constexpr[bool] + has_seqused_k: cutlass.Constexpr[bool] + + @staticmethod + def create( + batch_idx: Int32, + seqlen_q_static: Int32, + seqlen_k_static: Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + tile_m: cutlass.Constexpr[Int32] = 128, + tile_n: cutlass.Constexpr[Int32] = 128, + ): + offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] + offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] + padded_offset_q = ( + 0 + if const_expr(mCuSeqlensQ is None) + else cute.assume((offset_q + batch_idx * tile_m) // tile_m * tile_m, divby=tile_m) + ) + padded_offset_k = ( + 0 + if const_expr(mCuSeqlensK is None) + else cute.assume((offset_k + batch_idx * tile_n) // tile_n * tile_n, divby=tile_n) + ) + if const_expr(mSeqUsedQ is not None): + seqlen_q = mSeqUsedQ[batch_idx] + else: + seqlen_q = ( + seqlen_q_static + if const_expr(mCuSeqlensQ is None) + else mCuSeqlensQ[batch_idx + 1] - offset_q + ) + if const_expr(mSeqUsedK is not None): + seqlen_k = mSeqUsedK[batch_idx] + else: + seqlen_k = ( + seqlen_k_static + if const_expr(mCuSeqlensK is None) + else mCuSeqlensK[batch_idx + 1] - offset_k + ) + return SeqlenInfoQK( + offset_q, + offset_k, + padded_offset_q, + padded_offset_k, + seqlen_q, + seqlen_k, + has_cu_seqlens_q=mCuSeqlensQ is not None, + has_cu_seqlens_k=mCuSeqlensK is not None, + has_seqused_q=mSeqUsedQ is not None, + has_seqused_k=mSeqUsedK is not None, + ) + + def offset_batch_Q( + self, + mQ: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + ragged: cutlass.Constexpr[bool] = False, + ) -> cute.Tensor: + """Seqlen must be the first dimension of mQ""" + if const_expr(not ragged): + if const_expr(not self.has_cu_seqlens_q): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + return mQ[idx] + else: + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q + offset_q = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (None, offset_q) + idx = (offset_q,) + (None,) * (cute.rank(mQ) - 1) + return cute.domain_offset(idx, mQ) + else: + if const_expr(not self.has_cu_seqlens_q): + offset_q = 0 + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + mQ = mQ[idx] + else: + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q + if const_expr(cute.rank(mQ.shape[0]) == 1): + return copy_utils.offset_ragged_tensor( + mQ, offset_q, self.seqlen_q, ragged_dim=0, ptr_shift=True + ) + else: # PackGQA + assert cute.rank(mQ.shape[0]) == 2 + # Unpack before calling offset_ragged_tensor, then pack + idx = ((None, None),) + (None,) * (cute.rank(mQ) - 1) + mQ = mQ[idx] + mQ = copy_utils.offset_ragged_tensor( + mQ, offset_q, self.seqlen_q, ragged_dim=1, ptr_shift=True + ) + return cute.group_modes(mQ, 0, 2) + + def offset_batch_K( + self, + mK: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + ragged: cutlass.Constexpr[bool] = False, + multiple: int = 1, + ) -> cute.Tensor: + """Seqlen must be the first dimension of mK""" + if const_expr(not ragged): + if const_expr(not self.has_cu_seqlens_k): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + return mK[idx] + else: + offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k + offset_k *= multiple + idx = (offset_k,) + (None,) * (cute.rank(mK) - 1) + return cute.domain_offset(idx, mK) + else: + if const_expr(not self.has_cu_seqlens_k): + offset_k = 0 + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + mK = mK[idx] + else: + offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k + offset_k *= multiple + return copy_utils.offset_ragged_tensor( + mK, offset_k, self.seqlen_k, ragged_dim=0, ptr_shift=True + ) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/softmax.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..8f94c1c9e40aeb44c0a128165d90a502feb04afd --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/softmax.py @@ -0,0 +1,498 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Online softmax primitives. + +Contains: +- ``Softmax``: SM80/90 base class with online softmax + finalize + rescale_O. + The ``rescale_O`` path branches on ``arch >= 100`` to emit SM100 packed + ``fmul.f32x2`` (2× CUDA-core throughput) when available. +- ``SoftmaxSm100``: SM100-specific subclass exposing fused ``update_row_max``, + ``scale_apply_exp2_convert`` etc. used by the UTCMMA warp-specialized kernel. +""" + +import math +import operator +from dataclasses import dataclass +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Float32 + +from ...quack import layout_utils +from ...quack.cute_dsl_utils import ParamsBase + +from . import utils + + +@dataclass +class Softmax(ParamsBase): + scale_log2: Float32 + num_rows: cutlass.Constexpr[int] + row_max: cute.Tensor + row_sum: cute.Tensor + arch: cutlass.Constexpr[int] = 80 + softmax_scale: Float32 | None = None + + @staticmethod + def create( + scale_log2: Float32, + num_rows: cutlass.Constexpr[int], + arch: cutlass.Constexpr[int] = 80, + softmax_scale: Float32 | None = None, + ): + row_max = cute.make_rmem_tensor(num_rows, Float32) + row_sum = cute.make_rmem_tensor(num_rows, Float32) + return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale) + + def reset(self) -> None: + self.row_max.fill(-Float32.inf) + self.row_sum.fill(0.0) + + def _compute_row_max( + self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch) + + def _compute_row_sum( + self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch) + + @cute.jit + def online_softmax( + self, + acc_S: cute.Tensor, + is_first: cutlass.Constexpr[bool] = False, + check_inf: cutlass.Constexpr[bool] = True, + ) -> cute.Tensor: + """Apply online softmax and return the row_scale to rescale O. + + On SM100+ the inner ``acc_S_row * scale_log2 - row_max_scaled`` is + rewritten as explicit ``fma_packed_f32x2`` intrinsics — the DSL + compiler does not fuse TensorSSA ``mul + sub`` into FFMA2 (NCU + confirms: FFMA2 count is 0 for the TensorSSA path). The packed + rewrite issues one FFMA.F32X2 per pair, halving the scalar FFMA + instruction count for the softmax scale/subtract stage. + """ + acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S) + row_scale = cute.make_rmem_tensor_like(self.row_max, Float32) + + row_max = self.row_max + row_sum = self.row_sum + scale_log2 = self.scale_log2 + arch = self.arch + + for r in cutlass.range(cute.size(row_max), unroll_full=True): + acc_S_row_slice = acc_S_mn[r, None] + acc_S_row = acc_S_row_slice.load() + + row_max_cur = utils.fmax_reduce( + acc_S_row, + init_val=row_max[r] if cutlass.const_expr(not is_first) else None, + arch=arch, + ) + + row_max_cur = cute.arch.warp_reduction_max(row_max_cur, threads_in_group=4) + row_max_prev = row_max[r] + row_max[r] = row_max_cur + + if cutlass.const_expr(check_inf): + row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur + + row_max_cur_scaled = row_max_cur * scale_log2 + minus_row_max_scaled = -row_max_cur_scaled + n = cute.size(acc_S_row_slice) + + if cutlass.const_expr(arch >= 100 and n % 2 == 0): + # SM100 packed f32x2 FMA path: scale + subtract in one pass. + for i in cutlass.range(0, n, 2, unroll_full=True): + acc_S_row_slice[i], acc_S_row_slice[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row_slice[i], acc_S_row_slice[i + 1]), + (scale_log2, scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + for i in cutlass.range(n, unroll_full=True): + acc_S_row_slice[i] = cute.math.exp2(acc_S_row_slice[i], fastmath=True) + acc_S_row_exp = acc_S_row_slice.load() + else: + acc_S_row_exp = cute.math.exp2( + acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True + ) + acc_S_row_slice.store(acc_S_row_exp) + + if cutlass.const_expr(is_first): + acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch) + row_scale[r] = 1.0 + else: + row_scale[r] = cute.math.exp2( + (row_max_prev - row_max_cur) * scale_log2, fastmath=True + ) + acc_S_row_sum = utils.fadd_reduce( + acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch + ) + + row_sum[r] = acc_S_row_sum + + return row_scale + + @cute.jit + def finalize( + self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None + ) -> cute.Tensor: + """Finalize the online softmax by computing the scale and logsumexp. + + On SM100+ with an even ``num_rows`` and no sink_val, the loop is + unrolled in pairs so the key per-row arithmetic ― rcp*final_scale, + max*scale_log2 + log2(sum), and the final *LN2 ― collapses into one + ``mul_packed_f32x2`` + one ``fma_packed_f32x2`` + one more + ``mul_packed_f32x2`` per row pair. Sink_val path stays scalar (rare). + """ + if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)): + assert cute.size(sink_val) == cute.size(self.row_sum) + row_sum = self.row_sum + row_max = self.row_max + scale_log2 = self.scale_log2 + + row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4)) + row_scale = cute.make_rmem_tensor_like(row_max, Float32) + + LN2 = math.log(2.0) + num_rows = cute.size(row_sum) + use_packed = cutlass.const_expr( + self.arch >= 100 and num_rows % 2 == 0 and sink_val is None + ) + + if use_packed: + for r in cutlass.range(0, num_rows, 2, unroll_full=True): + s0 = row_sum[r] + s1 = row_sum[r + 1] + m0 = row_max[r] + m1 = row_max[r + 1] + bad0 = s0 == 0.0 or s0 != s0 + bad1 = s1 == 0.0 or s1 != s1 + + # row_scale = rcp_approx(safe_sum) * final_scale — rcp is scalar + # (no packed rcp intrinsic); the trailing multiply packs. + rcp0 = cute.arch.rcp_approx(1.0 if bad0 else s0) + rcp1 = cute.arch.rcp_approx(1.0 if bad1 else s1) + row_scale[r], row_scale[r + 1] = cute.arch.mul_packed_f32x2( + (rcp0, rcp1), (final_scale, final_scale) + ) + + # LSE = (row_max * scale_log2 + log2(row_sum)) * LN2 + # packed FMA for (max*scale_log2 + log2_sum), packed MUL for *LN2. + log0 = cute.math.log2(s0, fastmath=True) + log1 = cute.math.log2(s1, fastmath=True) + lse_pre_0, lse_pre_1 = cute.arch.fma_packed_f32x2( + (m0, m1), (scale_log2, scale_log2), (log0, log1) + ) + lse_0, lse_1 = cute.arch.mul_packed_f32x2( + (lse_pre_0, lse_pre_1), (LN2, LN2) + ) + row_sum[r] = -Float32.inf if bad0 else lse_0 + row_sum[r + 1] = -Float32.inf if bad1 else lse_1 + else: + for r in cutlass.range(num_rows, unroll_full=True): + if cutlass.const_expr(sink_val is not None): + sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r] + LOG2_E = math.log2(math.e) + row_sum[r] += cute.math.exp2( + sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True + ) + + acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r] + row_scale[r] = ( + cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0) + ) * final_scale + row_sum_cur = row_sum[r] + row_sum[r] = ( + (row_max[r] * scale_log2 + cute.math.log2(row_sum_cur, fastmath=True)) * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + return row_scale + + @cute.jit + def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: + """Scale each row of acc_O by the given scale tensor.""" + acc_O_mn = layout_utils.reshape_acc_to_mn(acc_O) + assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0]) + n = cute.size(acc_O_mn, mode=[1]) + if cutlass.const_expr(self.arch >= 100 and n % 2 == 0): + # SM100: pack adjacent pairs into fmul.f32x2 (2× CUDA-core throughput). + for r in cutlass.range(cute.size(row_scale), unroll_full=True): + scale = row_scale[r] + for j in cutlass.range(0, n, 2, unroll_full=True): + acc_O_mn[r, j], acc_O_mn[r, j + 1] = cute.arch.mul_packed_f32x2( + (acc_O_mn[r, j], acc_O_mn[r, j + 1]), (scale, scale) + ) + else: + for r in cutlass.range(cute.size(row_scale), unroll_full=True): + acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r]) + + +@dataclass +class SoftmaxSm100(Softmax): + """SM100-specific softmax: single-row, explicit f32x2 pack for FMA/exp2 paths.""" + + rescale_threshold: cutlass.Constexpr[float] = 0.0 + + @staticmethod + def create( + scale_log2: Float32, + rescale_threshold: cutlass.Constexpr[float] = 0.0, + softmax_scale: Float32 | None = None, + ): + num_rows = 1 + arch = 100 + row_max = cute.make_rmem_tensor(num_rows, Float32) + row_sum = cute.make_rmem_tensor(num_rows, Float32) + return SoftmaxSm100( + scale_log2, + num_rows, + row_max, + row_sum, + arch, + softmax_scale, + rescale_threshold=rescale_threshold, + ) + + @cute.jit + def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + acc_scale = cute.math.exp2(acc_scale_, fastmath=True) + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale = 1.0 + self.row_max[0] = row_max_new + return row_max_safe, acc_scale + + @cute.jit + def update_row_max_deferred_exp2( + self, + acc_S_row: cute.TensorSSA, + is_first: int, + ) -> Tuple[Float32, Float32]: + """update_row_max variant that publishes the log2-delta (un-exp2'd) so + the consumer can do the exp2 only when an actual rescale fires. + + Returns ``(row_max_safe, acc_scale_log2_or_zero)`` where: + - ``row_max_safe`` is the same row-max as ``update_row_max`` (with + ``rescale_threshold`` rollback applied). + - ``acc_scale_log2_or_zero`` is ``0.0`` for the first iteration or when + the threshold rollback fired (consumer treats as no rescale), else + the raw log2-domain value ``(row_max_old - row_max_safe)*scale_log2`` + (consumer computes ``cute.math.exp2`` and rescales). + + This keeps MUFU.EX2 off the sm_stats publication critical path that + gates the correction WG's consumer wait. + """ + publish = Float32(0.0) + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + # publish stays 0.0 (signal: no rescale needed) + else: + publish = acc_scale_ + else: + publish = acc_scale_ + self.row_max[0] = row_max_new + return row_max_safe, publish + + @cute.jit + def update_row_max_only(self, acc_S_row: cute.TensorSSA, is_first: int) -> None: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + else: + row_max_new = self._compute_row_max(acc_S_row, init_val=self.row_max[0]) + self.row_max[0] = row_max_new + + def update_row_sum( + self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False + ) -> None: + init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None + self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val) + + @cute.jit + def compute_scaled_exp2_row_sum( + self, + acc_S_row: cute.Tensor, + scale: Float32, + ) -> Float32: + return utils.fadd_exp2_scaled_reduce(acc_S_row, scale, arch=self.arch) + + @cute.jit + def scale_subtract_rowmax( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + row_max_scaled = row_max * self.scale_log2 + for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (-row_max_scaled, -row_max_scaled), + ) + + @cute.jit + def apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + acc_S_row_converted: cute.Tensor, + ex2_emu_freq: cutlass.Constexpr[int] = 0, + ex2_emu_res: cutlass.Constexpr[int] = 4, + ex2_emu_start_frg: cutlass.Constexpr[int] = 0, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + if cutlass.const_expr(ex2_emu_freq == 0): + acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) + else: + if cutlass.const_expr( + k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res + or j >= frg_cnt - 1 + or j < ex2_emu_start_frg + ): + acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2( + acc_S_row_frg[k + 1, j], fastmath=True + ) + else: + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2( + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] + ) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + @cute.jit + def scale_apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + acc_S_row_converted: cute.Tensor, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + minus_row_max_scaled = -row_max * self.scale_log2 + for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + @cute.jit + def scale_apply_exp2_convert_sum( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + acc_S_row_converted: cute.Tensor, + init_sum: Float32, + ex2_emu_freq: cutlass.Constexpr[int] = 0, + ex2_emu_res: cutlass.Constexpr[int] = 4, + ex2_emu_start_frg: cutlass.Constexpr[int] = 0, + ) -> Float32: + # When ex2_emu_freq > 0, the (k % ex2_emu_freq) >= ex2_emu_freq - ex2_emu_res + # pairs in the inner loop use the FFMA2-based polynomial ex2 emulation + # (ex2_emulation_2) instead of MUFU exp2 — mirrors prefill's + # apply_exp2_convert. This removes the MUFU "wait" stall that dominates + # the second-largest stall bucket in decode (~22% of total). + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + minus_row_max_scaled = -row_max * self.scale_log2 + acc_sum = (init_sum, Float32(0.0)) + + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = cute.arch.fma_packed_f32x2( + (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + if cutlass.const_expr(ex2_emu_freq == 0): + acc_S_row_frg[k, j] = cute.math.exp2( + acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2( + acc_S_row_frg[k + 1, j], fastmath=True) + else: + use_real = cutlass.const_expr( + k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res + or j >= frg_cnt - 1 + or j < ex2_emu_start_frg + ) + if cutlass.const_expr(use_real): + acc_S_row_frg[k, j] = cute.math.exp2( + acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2( + acc_S_row_frg[k + 1, j], fastmath=True) + else: + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( + utils.ex2_emulation_2( + acc_S_row_frg[k, j], + acc_S_row_frg[k + 1, j], + ) + ) + acc_sum = cute.arch.add_packed_f32x2( + acc_sum, + (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + ) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + return acc_sum[0] + acc_sum[1] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/tile_scheduler.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/tile_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..985b4289e146288355dfecd7169383eb64df4f09 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/tile_scheduler.py @@ -0,0 +1,967 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from enum import IntEnum, auto +from typing import Optional, Tuple, Protocol, runtime_checkable +from dataclasses import dataclass + +try: + from typing import override +except ImportError: # Python < 3.12 + from typing_extensions import override + +import cutlass +from cutlass.pipeline import PipelineClcFetchAsync, PipelineState +from cutlass._mlir import ir +import cutlass.cute as cute +from cutlass import Int32, const_expr +from cutlass.cute import FastDivmodDivisor +from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams + +from ...quack.cute_dsl_utils import ParamsBase + +from ...src.common import utils as utils +from ...src.common.fast_math import clz + + +class SchedulingMode(IntEnum): + NONE = auto() + STATIC = auto() + DYNAMIC = auto() + CLC = auto() + + +@dataclass +class ClcState(ParamsBase): + """Owns the runtime state shared by CLC-capable tile schedulers. + + `SparseAttentionForwardSm100` constructs this state because it owns the CLC + response buffer, mbarrier storage, and launch geometry needed to initialize + the hardware scheduler and async pipeline. Individual tile schedulers then + consume this state and map the returned hardware work tiles into their own + logical `WorkTileInfo` coordinates. + + To add CLC support to a scheduler: + - implement `clc_problem_shape(params)` so the kernel can create the hardware scheduler + - accept `clc: ClcState | None` in `create(...)` / `__init__` + - map `clc.initial_work_tile_info()` and `clc.get_current_work()` into scheduler coordinates + """ + + _hw_scheduler: ClcDynamicPersistentTileScheduler + _pipeline: PipelineClcFetchAsync + _consumer_state: PipelineState + _producer_state: PipelineState + + @staticmethod + def create( + *, + hw_scheduler: ClcDynamicPersistentTileScheduler, + pipeline: PipelineClcFetchAsync, + consumer_state: PipelineState, + producer_state: PipelineState, + ) -> "ClcState": + return ClcState(hw_scheduler, pipeline, consumer_state, producer_state) + + def initial_work_tile_info(self): + return self._hw_scheduler.initial_work_tile_info() + + def get_current_work(self): + return self._hw_scheduler.get_current_work() + + def prefetch_next_work(self, *, loc=None, ip=None): + self._pipeline.producer_acquire(self._producer_state, loc=loc, ip=ip) + mbarrier_addr = self._pipeline.producer_get_barrier(self._producer_state, loc=loc, ip=ip) + self._hw_scheduler.advance_to_next_work(mbarrier_addr, loc=loc, ip=ip) + self._producer_state.advance(loc=loc, ip=ip) + + def consumer_wait(self, *, loc=None, ip=None): + self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip) + + def consumer_release(self, *, loc=None, ip=None): + self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip) + self._consumer_state.advance(loc=loc, ip=ip) + + def producer_tail(self, *, loc=None, ip=None): + self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip) + + +class WorkTileInfo(cutlass.utils.WorkTileInfo): + """Altered WorkTileInfo which includes four axes: (block, head, batch, split)""" + + @override + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": + assert len(values) == 5 + new_tile_idx = cutlass.new_from_mlir_values(self._tile_idx, values[:-1]) + new_is_valid_tile = cutlass.new_from_mlir_values(self._is_valid_tile, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + +@runtime_checkable +class TileSchedulerProtocol(Protocol): + """Protocol defining the interface all tile schedulers must implement. + + Schedulers are responsible for: + 1. Coordinate mapping: linear tile index -> (m_block, head, batch, split) + 2. Work distribution: how to get the next tile (static grid-stride vs CLC dynamic) + """ + + def get_current_work(self) -> WorkTileInfo: + """Get the current work tile coordinates.""" + ... + + def initial_work_tile_info(self) -> WorkTileInfo: + """Get the initial work tile for this CTA.""" + ... + + def advance_to_next_work(self, *, loc=None, ip=None): + """Consumer-side advance: move to next tile and return it. + + For static schedulers: grid-stride increment + get_current_work. + For CLC schedulers: consumer wait + get_current_work + consumer release + state advance. + """ + ... + + def prefetch_next_work(self, *, loc=None, ip=None) -> None: + """Producer-side prefetch of next work tile (no-op for static schedulers). + + For CLC schedulers: producer acquire + issue CLC query + producer state advance. + Only called by the scheduler warp. + """ + ... + + def producer_tail(self, *, loc=None, ip=None) -> None: + """Producer-side cleanup after the last tile. + + No-op for static schedulers. For CLC schedulers: pipeline producer_tail. + """ + ... + + +@dataclass +class TileSchedulerArguments(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + seqlen_k: Int32 + headdim: Int32 + headdim_v: Int32 + total_q: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + element_size: cutlass.Constexpr[int] = 2 + is_persistent: cutlass.Constexpr[bool] = False + lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False + use_cluster_idx: cutlass.Constexpr[bool] = False + + +class SingleTileScheduler: + @dataclass + class Params(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + num_splits_divmod: FastDivmodDivisor + is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + use_cluster_idx: cutlass.Constexpr[bool] = False + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileScheduler.Params": + return SingleTileScheduler.Params( + args.num_block, + args.num_head, + args.num_batch, + args.num_splits, + FastDivmodDivisor(args.num_splits), + args.is_split_kv, + args.cluster_shape_mn, + args.use_cluster_idx, + ) + + def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): + self.params = params + self._blk_coord = blk_coord + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"SingleTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) + return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileScheduler": + if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx): + blk_coord = cute.arch.block_idx() + else: + blk_coord = cute.arch.cluster_idx() + return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1) + assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + if const_expr(params.use_cluster_idx): + # Grid must have num_block * cluster_m physical blocks so that there are num_block clusters + grid_x = params.num_block * params.cluster_shape_mn[0] + else: + grid_x = cute.round_up(params.num_block, params.cluster_shape_mn[0]) + return ( + grid_x, + params.num_head * params.num_splits, + params.num_batch, + ) + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + block_idx, head_idx, batch_idx = self._blk_coord + if const_expr(self.params.is_split_kv): + head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod) + else: + split_idx = Int32(0) + return WorkTileInfo( + (block_idx, head_idx, batch_idx, split_idx), + self._is_first_block, + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._blk_coord]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._blk_coord], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class StaticPersistentTileScheduler: + @dataclass + class Params(ParamsBase): + num_block_cluster_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor + total_blocks_cluster: Int32 + cluster_shape_m: cutlass.Constexpr[int] = 1 + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler.Params": + num_block_cluster = cute.ceil_div(args.num_block, cute.size(args.cluster_shape_mn)) + total_blocks_cluster = num_block_cluster * args.num_head * args.num_batch + return StaticPersistentTileScheduler.Params( + FastDivmodDivisor(num_block_cluster), + FastDivmodDivisor(args.num_head), + total_blocks_cluster, + cluster_shape_m=args.cluster_shape_mn[0], + ) + + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"StaticPersistentTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) + return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler": + if const_expr(cute.size(params.cluster_shape_m) == 1): + tile_idx = cute.arch.block_idx()[0] + else: + tile_idx = cute.arch.cluster_idx()[0] + return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + usable_SM_count=0, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + hardware_info = cutlass.utils.HardwareInfo() + cluster_shape_m = int(params.cluster_shape_m) + if usable_SM_count > 0: + sm_count = usable_SM_count + else: + sm_count = hardware_info.get_device_multiprocessor_count() + max_ctas = (sm_count // cluster_shape_m) * cluster_shape_m + max_ctas = max(max_ctas, cluster_shape_m) + grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * cluster_shape_m) + return (grid_x, Int32(1), Int32(1)) + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod) + batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) + is_valid = self._tile_idx < self.params.total_blocks_cluster + return WorkTileInfo( + (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.cluster_shape_m == 1): + self._tile_idx += cute.arch.grid_dim()[0] + else: + self._tile_idx += cute.arch.cluster_dim()[0] + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.params, self._tile_idx], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileLPTScheduler: + @dataclass + class Params(ParamsBase): + total_blocks: Int32 + num_splits: Int32 + num_block: Int32 + num_head: Int32 + num_batch: Int32 + l2_minor: Int32 + num_head_divmod: FastDivmodDivisor + l2_minor_divmod: FastDivmodDivisor + l2_major_divmod: FastDivmodDivisor + l2_minor_residual_divmod: FastDivmodDivisor + num_hb_quotient: Int32 + num_splits_divmod: FastDivmodDivisor + is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + lpt: cutlass.Constexpr[bool] = True + use_cluster_idx: cutlass.Constexpr[bool] = True + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> "SingleTileLPTScheduler.Params": + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) + size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size + size_one_head = size_one_kv_head + size_l2 = 50 * 1024 * 1024 # 40 MB for K & V + # Swizzle is the size of each "section". Round swizzle to a power of 2 + # Need to be careful about the case where only one head will fit + # swizzle is how many heads can fit in L2 + # Seems faster if swizzle is a power of 2 + log2_floor = lambda n: 31 - clz(n) + swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + num_hb_quotient = (args.num_head * args.num_batch) // swizzle + num_hb_remainder = (args.num_head * args.num_batch) % swizzle + return SingleTileLPTScheduler.Params( + total_blocks=args.num_block * args.num_head * args.num_batch, + num_block=args.num_block, + num_head=args.num_head, + num_batch=args.num_batch, + l2_minor=Int32(swizzle), + num_head_divmod=FastDivmodDivisor(args.num_head), + l2_minor_divmod=FastDivmodDivisor(swizzle), + l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block), + l2_minor_residual_divmod=FastDivmodDivisor(max(num_hb_remainder, 1)), + num_hb_quotient=Int32(num_hb_quotient), + num_splits=args.num_splits, + num_splits_divmod=FastDivmodDivisor(args.num_splits), + is_split_kv=args.is_split_kv, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + lpt=args.lpt, + use_cluster_idx=args.use_cluster_idx, + ) + + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): + self.params = params + self._tile_idx = tile_idx + self._split_idx = split_idx + self.clc = clc + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileLPTScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) + + @staticmethod + def _clc_grid_shape(params: Params): + num_batch_splits = ( + params.num_batch * params.num_splits + if const_expr(params.is_split_kv) + else params.num_batch + ) + return ( + cute.round_up(params.num_block, params.cluster_shape_m), + params.num_head, + num_batch_splits, + ) + + @staticmethod + @cute.jit + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileLPTScheduler._clc_grid_shape(params), + cluster_shape_mnk=(params.cluster_shape_m, 1, 1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileLPTScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler( + params, cute.arch.block_idx()[0], Int32(0), clc, loc=loc, ip=ip + ) + tile_idx, split_idx, _ = cute.arch.block_idx() + return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler._clc_grid_shape(params) + return (params.total_blocks, params.num_splits, Int32(1)) + + @cute.jit + def clc_work_to_coords(self, work) -> WorkTileInfo: + """Convert CLC response (block, head, batch_split) to WorkTileInfo. + + CLC returns raw grid coordinates — no L2 swizzle (hardware decides order). + We only apply cluster division, optional LPT block reversal, and split_kv unpacking. + """ + block_idx = work.tile_idx[0] + if const_expr(self.params.cluster_shape_m > 1): + block_idx = block_idx // self.params.cluster_shape_m + if const_expr(self.params.lpt): + # Longest-processing-time-first: reverse block order + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + num_block = self.params.num_block // self.params.cluster_shape_m + else: + num_block = self.params.num_block + block_idx = num_block - 1 - block_idx + split_idx = Int32(0) + if const_expr(self.params.is_split_kv): + batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod) + else: + batch_idx = work.tile_idx[2] + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block_idx = block_idx * self.params.cluster_shape_m + bidx_in_cluster[0] + return WorkTileInfo( + (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)), + work.is_valid_tile, + ) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.get_current_work() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) + # Static path: L2-swizzled coordinate mapping + params = self.params + # Implement LPT scheduling coordinate calculation + bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + block, bidhb_residual = 0, 0 + if bidhb < params.num_hb_quotient: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod) + else: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod) + bidhb_actual = bidhb * params.l2_minor + bidhb_residual + batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) + # Longest-processing-time-first + if const_expr(params.lpt): + block = params.num_block - 1 - block + is_valid = self._tile_idx < params.total_blocks + return WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid + ) + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.initial_work_tile_info() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._tile_idx = self.params.total_blocks + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*obj_list, loc=self._loc) + + +class SingleTileVarlenScheduler: + @dataclass + class Params(ParamsBase): + num_head: Int32 + num_batch: Int32 + total_q: Int32 + num_splits: Int32 + max_kvblock_in_l2: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> "SingleTileVarlenScheduler.Params": + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + kv_block_size = ( + (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + ) + if args.head_swizzle: + kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] + max_kvblock_in_l2 = size_l2 // kv_block_size + assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( + "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" + ) + assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + # TODO: Support varlen CLC with cluster_shape_m > 1 by refactoring the + # flattened-tile decode so cluster unpacking semantics are explicit. + assert scheduling_mode != SchedulingMode.CLC or args.cluster_shape_mn[0] == 1, ( + "Varlen CLC currently requires cluster_shape_mn[0] == 1" + ) + return SingleTileVarlenScheduler.Params( + num_head=args.num_head, + num_batch=args.num_batch, + total_q=args.total_q, + num_splits=args.num_splits, + max_kvblock_in_l2=max_kvblock_in_l2, + tile_shape_mn=args.tile_shape_mn, + mCuSeqlensQ=args.mCuSeqlensQ, + mSeqUsedQ=args.mSeqUsedQ, + qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, + lpt=args.lpt, + is_split_kv=args.is_split_kv, + head_swizzle=args.head_swizzle, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + ) + + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): + self.params = params + self._tile_idx = tile_idx + self._split_idx = split_idx + self._is_first_block = True + self.clc = clc + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileVarlenScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) + + @staticmethod + @cute.jit + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileVarlenScheduler.get_grid_shape(params), + cluster_shape_mnk=(1, 1, 1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileVarlenScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + block_idx = cute.arch.block_idx() + split_idx = Int32(0) + if const_expr(params.is_split_kv): + split_idx = block_idx[1] + return SingleTileVarlenScheduler( + params, + block_idx[0], + split_idx, + clc, + loc=loc, + ip=ip, + ) + tile_idx, split_idx, _ = cute.arch.block_idx() + return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + total_blocks_max = ( + params.total_q + + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1) + ) // params.tile_shape_mn[0] + # Round down to nearest multiple of cluster since odd excess is always padding. + total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m + return (total_blocks_max * params.num_head, params.num_splits, Int32(1)) + + @cute.jit + def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + params = self.params + batch_idx = lane + bidb_start + if cutlass.const_expr(params.mSeqUsedQ is not None): + seqlen = Int32(0) + if batch_idx < params.num_batch: + seqlen = params.mSeqUsedQ[batch_idx] + else: + assert params.mCuSeqlensQ is not None + cur_cu_seqlen = Int32(0) + if batch_idx <= params.num_batch: + cur_cu_seqlen = params.mCuSeqlensQ[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): + seqlen *= params.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), params.cluster_shape_m) + if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 + else Int32(0) + ) + + @cute.jit + def _varlen_coord_map(self) -> WorkTileInfo: + """Map self._tile_idx to (block, head, batch) via warp-level prefix sums.""" + params = self.params + lane_idx = cute.arch.lane_idx() + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + # Total number of blocks for the next 31 batches + m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1) + # Same for all lanes + group_end_tile = m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group) + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0) + next_tile_idx = self._tile_idx // params.cluster_shape_m + while group_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= params.num_batch: + batch_idx = Int32(params.num_batch) + group_end_tile = next_tile_idx + 1 + else: + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync( + num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 + ) + group_end_tile += m_blocks_in_group * params.num_head + is_valid = False + if batch_idx >= params.num_batch: + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch) + else: + group_start_tile = group_end_tile - m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx) + # The next problem to process is the first one that does not have ending tile position + # that is greater than or equal to tile index. + batch_idx_in_group = cute.arch.popc( + cute.arch.vote_ballot_sync( + group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx + ) + ) + batch_idx += batch_idx_in_group + num_m_blocks_prev_lane = ( + 0 + if batch_idx_in_group == 0 + else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1) + ) + num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) + mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head + if cutlass.const_expr(params.lpt or params.head_swizzle): + # This is a version of the SingleTileLPTScheduler, complicated by the fact that + # the seqlen can vary per batch. + # TODO: is there any case where num_m_blocks is 0? + # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here + num_n_blocks = ( + num_m_blocks + * params.tile_shape_mn[0] + * params.cluster_shape_m + // params.qhead_per_kvhead_packgqa + // params.tile_shape_mn[1] + ) + # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) + # Seems faster to have this be a power of 2 + nheads_in_l2 = ( + 16 + if num_n_blocks * 16 <= params.max_kvblock_in_l2 + else ( + 8 + if num_n_blocks * 8 <= params.max_kvblock_in_l2 + else ( + 4 + if num_n_blocks * 4 <= params.max_kvblock_in_l2 + else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1) + ) + ) + ) + nheads_in_l2 = min(nheads_in_l2, params.num_head) + mh_in_l2 = nheads_in_l2 * num_m_blocks + section_idx = mh_block // mh_in_l2 + l2_mod = mh_block - section_idx * mh_in_l2 + # Deal with tail section + nheads_in_this_section = ( + nheads_in_l2 + if nheads_in_l2 * (section_idx + 1) <= params.num_head + else params.num_head - section_idx * nheads_in_l2 + ) + block = l2_mod // nheads_in_this_section + head_idx_residual = l2_mod - block * nheads_in_this_section + head_idx = section_idx * nheads_in_l2 + head_idx_residual + if cutlass.const_expr(params.lpt): + block = num_m_blocks - 1 - block + else: + head_idx = mh_block // num_m_blocks + block = mh_block - head_idx * num_m_blocks + is_valid = self._is_first_block and batch_idx < params.num_batch + if cutlass.const_expr(params.cluster_shape_m > 1): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block = block * params.cluster_shape_m + bidx_in_cluster[0] + # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) + split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) + return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.get_current_work() + # Default to grid_dim (one past last valid flat index) so _varlen_coord_map + # returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when + # invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural + # mismatch on self inside the runtime if. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.initial_work_tile_info() + # See get_current_work for why grid_dim and local-then-assign. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() + + def prefetch_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work + self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*obj_list, loc=self._loc) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/tma_utils.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/tma_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5bdc19a08eacf9a060f2c0a7a4d50a4adb735094 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/tma_utils.py @@ -0,0 +1,515 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Raw TMA ops and descriptor builders. + +`tma_utils.py` is the canonical owner for raw TMA inline-asm helpers and TMA +descriptor construction. Non-TMA store/layout helpers are re-exported from +`copy_utils.py` for backward compatibility. +""" + +import ctypes + +from cutlass import Int32, Int64 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm +import cutlass._mlir.dialects.cute as cute_ir +import cutlass._mlir.dialects.cute_nvgpu as cute_nvgpu_ir +from cutlass._mlir.dialects import _cute_nvgpu_ops_gen as cute_nvgpu_gen + + +# Raw TMA Ops + +TMA_CACHE_EVICT_FIRST = 0x12F0000000000000 +TMA_CACHE_EVICT_LAST = 0x14F0000000000000 + + +@dsl_user_op +def tma_tile_load( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row_idx, + mbar_ptr, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile with mbar completion.""" + llvm.inline_asm( + T.i32(), + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $1;\n" + "add.u32 sa, sa, $2;\n" + "cvt.u32.u64 ma, $6;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile" + ".mbarrier::complete_tx::bytes [sa], [$3, {$4, $5}], [ma];\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_gather4( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row0, + row1, + row2, + row3, + mbar_ptr, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4 with mbar.""" + llvm.inline_asm( + T.i32(), + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row0).ir_value(loc=loc, ip=ip), + Int32(row1).ir_value(loc=loc, ip=ip), + Int32(row2).ir_value(loc=loc, ip=ip), + Int32(row3).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $1;\n" + "add.u32 sa, sa, $2;\n" + "cvt.u32.u64 ma, $9;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4" + ".mbarrier::complete_tx::bytes [sa], [$3, {$4, $5, $6, $7, $8}], [ma];\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,l,r,r,r,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def prefetch_tma_desc_raw(tma_desc_ptr, *, loc=None, ip=None): + """Prefetch a raw TMA descriptor pointer into the descriptor cache.""" + ptr_i64 = tma_desc_ptr.toint().ir_value(loc=loc, ip=ip) + ptr_i64_align_ty = cute_ir.ConstrainedIntType.get(128, ptr_i64.type.width) + ptr_i64_align = cute_ir.assume(ptr_i64_align_ty, ptr_i64, loc=loc, ip=ip) + ptr_ty = cute_ir.PtrType.get( + cute_nvgpu_ir.TmaDescriptorTiledType.get(), + cute_ir.AddressSpace.gmem, + 128, + ) + desc_ptr = cute_ir.inttoptr(ptr_ty, ptr_i64_align, loc=loc, ip=ip) + cute_nvgpu_gen.arch_prefetch_tma_desc(desc_ptr.value, loc=loc, ip=ip) + + +@dsl_user_op +def tma_tile_prefetch( + tma_desc_ptr, + col_idx, + row_idx, + cache_hint=TMA_CACHE_EVICT_FIRST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.prefetch.tensor.2d.L2.global.tile with cache hint.""" + llvm.inline_asm( + None, + [ + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "cp.async.bulk.prefetch.tensor.2d.L2.global.tile.L2::cache_hint " + "[$0, {$1, $2}], $3;\n", + "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_gather4_prefetch( + tma_desc_ptr, + col_idx, + row0, + row1, + row2, + row3, + cache_hint=TMA_CACHE_EVICT_LAST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4 with cache hint.""" + llvm.inline_asm( + None, + [ + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row0).ir_value(loc=loc, ip=ip), + Int32(row1).ir_value(loc=loc, ip=ip), + Int32(row2).ir_value(loc=loc, ip=ip), + Int32(row3).ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4.L2::cache_hint " + "[$0, {$1, $2, $3, $4, $5}], $6;\n", + "l,r,r,r,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_tile_load_cached( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row_idx, + mbar_ptr, + cache_hint=TMA_CACHE_EVICT_FIRST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile with cache hint and mbar.""" + llvm.inline_asm( + T.i32(), + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $1;\n" + "add.u32 sa, sa, $2;\n" + "cvt.u32.u64 ma, $6;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile" + ".mbarrier::complete_tx::bytes.L2::cache_hint " + "[sa], [$3, {$4, $5}], [ma], $7;\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,l,r,r,l,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_gather4_cached( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row0, + row1, + row2, + row3, + mbar_ptr, + cache_hint=TMA_CACHE_EVICT_LAST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4 with cache hint.""" + llvm.inline_asm( + None, + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row0).ir_value(loc=loc, ip=ip), + Int32(row1).ir_value(loc=loc, ip=ip), + Int32(row2).ir_value(loc=loc, ip=ip), + Int32(row3).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $0;\n" + "add.u32 sa, sa, $1;\n" + "cvt.u32.u64 ma, $8;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4" + ".mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint " + "[sa], [$2, {$3, $4, $5, $6, $7}], [ma], $9;\n" + "}\n", + "l,r,l,r,r,r,r,r,l,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_tile_store( + tma_desc_ptr, + col_idx, + row_idx, + smem_ptr, + smem_byte_offset, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.global.shared::cta.bulk_group store.""" + llvm.inline_asm( + T.i32(), + [ + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa;\n" + "cvt.u32.u64 sa, $4;\n" + "add.u32 sa, sa, $5;\n" + "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group" + " [$1, {$2, $3}], [sa];\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,r,l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +# Descriptor Builders + +_TMA_DESC_BYTES = 128 + + +def _encode_tma_desc_2d_bytes(tensor_2d, *, box_x, box_y, context: str) -> bytes: + import torch + import cuda.bindings.driver as cuda + + if tensor_2d.ndim != 2: + raise ValueError(f"{context} tensor must be rank-2, got {tuple(tensor_2d.shape)}") + rows, cols = tensor_2d.shape + if tensor_2d.stride(-1) != 1: + raise ValueError(f"{context} tensor must be contiguous in the last dimension") + dtype_map = { + torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, + torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + } + if tensor_2d.dtype not in dtype_map: + raise TypeError(f"Unsupported dtype for {context} TMA descriptor: {tensor_2d.dtype}") + + sizes = [cuda.cuuint64_t(cols), cuda.cuuint64_t(rows)] + strides = [cuda.cuuint64_t(tensor_2d.stride(0) * tensor_2d.element_size())] + box = [cuda.cuuint32_t(box_x), cuda.cuuint32_t(box_y)] + elem_stride = [cuda.cuuint32_t(1), cuda.cuuint32_t(1)] + err, tm = cuda.cuTensorMapEncodeTiled( + dtype_map[tensor_2d.dtype], + 2, + tensor_2d.data_ptr(), + sizes, + strides, + box, + elem_stride, + cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, + cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, + cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, + ) + assert err == cuda.CUresult.CUDA_SUCCESS, f"TMA encode failed: {err}" + buf = (ctypes.c_uint8 * _TMA_DESC_BYTES).from_address(tm.getPtr()) + return bytes(buf) + + +def _desc_bytes_to_device_tensor(desc_bytes: bytes | bytearray, *, device): + import torch + + desc_bytes = bytes(desc_bytes) + device = torch.device(device) + if device.type != "cuda": + raise ValueError(f"TMA descriptors require a CUDA device, got {device}") + + host_desc = torch.empty((len(desc_bytes),), dtype=torch.uint8, pin_memory=True) + host_desc.copy_(torch.frombuffer(bytearray(desc_bytes), dtype=torch.uint8)) + device_desc = torch.empty((len(desc_bytes),), dtype=torch.uint8, device=device) + stream = torch.cuda.current_stream(device) + with torch.cuda.stream(stream): + device_desc.copy_(host_desc, non_blocking=True) + device_desc.record_stream(stream) + # Keep the staging buffer alive for the async copy without caching descriptors. + device_desc._tma_host_desc = host_desc + return device_desc + + +def create_flat_gather4_tma_desc(tensor_2d, box_x=64): + """Create a gather4 CUtensorMap descriptor for a flat 2D row-major tensor.""" + if tensor_2d.ndim != 2: + raise ValueError( + f"tensor_2d must be rank-2 [rows, dim], got {tuple(tensor_2d.shape)}" + ) + desc = _encode_tma_desc_2d_bytes( + tensor_2d, + box_x=box_x, + box_y=1, + context="gather4", + ) + return _desc_bytes_to_device_tensor(desc, device=tensor_2d.device) + + +def create_q_gather4_tma_desc(q_flat, box_x=64): + return create_flat_gather4_tma_desc(q_flat, box_x=box_x) + + +def create_strided_2d_tma_desc(tensor_2d, *, box_x, box_y): + """Create a CUtensorMap descriptor for a rank-2 tensor with arbitrary row stride.""" + desc = _encode_tma_desc_2d_bytes( + tensor_2d, + box_x=box_x, + box_y=box_y, + context="strided 2D", + ) + return _desc_bytes_to_device_tensor(desc, device=tensor_2d.device) + + +def create_flat_kv_tma_descs(kv_flat, *, box_x=64, box_y=128): + """Create per-KV-head token-major TMA descriptors for flat [total_k, H, D] storage.""" + import torch + + if kv_flat.ndim != 3: + raise ValueError( + f"kv_flat must be rank-3 [total_k, H, D], got {tuple(kv_flat.shape)}" + ) + total_k, head_kv, dim = kv_flat.shape + row_stride = head_kv * dim + desc_table = bytearray() + for h in range(head_kv): + head_view = torch.as_strided( + kv_flat, + size=(total_k, dim), + stride=(row_stride, 1), + storage_offset=h * dim, + ) + desc_table.extend( + _encode_tma_desc_2d_bytes( + head_view, + box_x=box_x, + box_y=box_y, + context="flat KV", + ) + ) + return _desc_bytes_to_device_tensor(desc_table, device=kv_flat.device).reshape( + head_kv, _TMA_DESC_BYTES + ) + + +# Compatibility Re-exports + +from .copy_utils import ( + atomic_add_broadcast_i32, + atomic_add_i32, + convert_layout_acc_mn, + convert_layout_from_tmem16x256b_to_acc_sm90, + make_16x256b_tensor_mn_view, + real_col_to_stg128_fake_col, + real_col_to_stg128_fp8_fake_col, + real_col_to_stg128_half_fake_col, + stg128_fake_col_to_real_col, + stg128_fp8_fake_col_to_real_col, + stg128_half_fake_col_to_real_col, + stg_128, + stg_128_cs, + stg_128_bf16, + stg_128_bf16_cs, + stg_128_f16, + stg_128_f16_cs, + stg_128_fp8_e4m3_cs, + stg_32_fp8_e4m3, + stg_64_bf16, + stg_64_f16, +) + + +__all__ = [ + "TMA_CACHE_EVICT_FIRST", + "TMA_CACHE_EVICT_LAST", + "atomic_add_broadcast_i32", + "atomic_add_i32", + "convert_layout_acc_mn", + "convert_layout_from_tmem16x256b_to_acc_sm90", + "create_flat_gather4_tma_desc", + "create_flat_kv_tma_descs", + "create_q_gather4_tma_desc", + "create_strided_2d_tma_desc", + "make_16x256b_tensor_mn_view", + "prefetch_tma_desc_raw", + "real_col_to_stg128_fake_col", + "real_col_to_stg128_fp8_fake_col", + "real_col_to_stg128_half_fake_col", + "stg128_fake_col_to_real_col", + "stg128_fp8_fake_col_to_real_col", + "stg128_half_fake_col_to_real_col", + "stg_128", + "stg_128_cs", + "stg_128_bf16", + "stg_128_bf16_cs", + "stg_128_f16", + "stg_128_f16_cs", + "stg_128_fp8_e4m3_cs", + "stg_32_fp8_e4m3", + "stg_64_bf16", + "stg_64_f16", + "tma_gather4", + "tma_gather4_cached", + "tma_gather4_prefetch", + "tma_tile_load", + "tma_tile_load_cached", + "tma_tile_prefetch", + "tma_tile_store", +] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/common/utils.py b/build/torch211-cxx11-cu130-x86_64-linux/src/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e1bd0ba76b532cb54c159eba5e82320266c80c63 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/common/utils.py @@ -0,0 +1,1088 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import math +import hashlib +import inspect +from typing import Type, Callable, Optional, Tuple, overload + +import cutlass +import cutlass.cute as cute + +from cutlass import Float32, const_expr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import nvvm, llvm +from cutlass.cute.runtime import from_dlpack + + +from ...quack import activation +_MIXER_ATTRS = ("__vec_size__",) + +# Obtained from sollya: +# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative); +POLY_EX2 = { + 0: (1.0), + 1: ( + 1.0, + 0.922497093677520751953125, + ), + 2: ( + 1.0, + 0.6657850742340087890625, + 0.330107033252716064453125, + ), + 3: ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ), + 4: ( + 1.0, + 0.693042695522308349609375, + 0.2412912547588348388671875, + 5.2225358784198760986328125e-2, + 1.3434938155114650726318359375e-2, + ), + 5: ( + 1.0, + 0.693151414394378662109375, + 0.24016360938549041748046875, + 5.5802188813686370849609375e-2, + 9.01452265679836273193359375e-3, + 1.86810153536498546600341796875e-3, + ), +} + + +def _compute_base_hash(func: Callable) -> str: + """Compute hash from source code or bytecode and closure values.""" + try: + data = inspect.getsource(func).encode() + except (OSError, TypeError): + if hasattr(func, "__code__") and func.__code__ is not None: + data = func.__code__.co_code + else: + data = repr(func).encode() + + hasher = hashlib.sha256(data) + + if hasattr(func, "__closure__") and func.__closure__ is not None: + for cell in func.__closure__: + hasher.update(repr(cell.cell_contents).encode()) + + return hasher.hexdigest() + + +def hash_callable( + func: Callable, mixer_attrs: Tuple[str] = _MIXER_ATTRS, set_cute_hash: bool = True +) -> str: + """Hash a callable based on the source code or bytecode and closure values. + Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__`` + attribute, that value is returned immediately as the base hash, then + metadata dunders are mixed in to produce the final dict-key hash. + set_cute_hash: whether or not to set func.__cute_hash__ + """ + # Resolve base hash + if hasattr(func, "__cute_hash__"): + base_hash = func.__cute_hash__ + else: + # Unwrap decorated functions (e.g., cute.jit wrappers). + base_func = getattr(func, "__wrapped__", func) + + if hasattr(base_func, "__cute_hash__"): + base_hash = base_func.__cute_hash__ + else: + base_hash = _compute_base_hash(base_func) + + if set_cute_hash: + base_func.__cute_hash__ = base_hash + + # Mix in mutable metadata dunders + mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs) + + if all(v is None for v in mixer_values): + return base_hash + + hasher = hashlib.sha256(base_hash.encode()) + + for attr, val in zip(_MIXER_ATTRS, mixer_values): + hasher.update(f"{attr}={val!r}".encode()) + + return hasher.hexdigest() + + +LOG2_E = math.log2(math.e) + + +def compute_softmax_scale_log2(softmax_scale): + """Compute softmax_scale_log2 from softmax_scale. + + Returns (softmax_scale_log2, None). + """ + return softmax_scale * LOG2_E, None + + +def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: + return ( + from_dlpack(x, assumed_align=alignment) + .mark_layout_dynamic(leading_dim=leading_dim) + .mark_compact_shape_dynamic( + mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility + ) + ) + + +def make_tiled_copy_A( + copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.TiledCopy: + if const_expr(swapAB): + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + + +def make_tiled_copy_B( + copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.TiledCopy: + if const_expr(swapAB): + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + + +def mma_make_fragment_A( + smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.Tensor: + if const_expr(swapAB): + return mma_make_fragment_B(smem, thr_mma) + else: + return thr_mma.make_fragment_A(thr_mma.partition_A(smem)) + + +def mma_make_fragment_B( + smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.Tensor: + if const_expr(swapAB): + return mma_make_fragment_A(smem, thr_mma) + else: + return thr_mma.make_fragment_B(thr_mma.partition_B(smem)) + + +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=2 * element_type.width, + ) + else: + return cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +@cute.jit +def warp_reduce( + val: cute.TensorSSA | cute.Numeric, + op: Callable, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.TensorSSA | cute.Numeric: + if const_expr(isinstance(val, cute.TensorSSA)): + res = cute.make_rmem_tensor(val.shape, val.dtype) + res.store(val) + for i in cutlass.range_constexpr(cute.size(val.shape)): + res[i] = warp_reduce(res[i], op, width) + return res.load() + else: + for i in cutlass.range_constexpr(int(math.log2(width))): + val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i)) + return val + + +@dsl_user_op +def fmax( + a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None +) -> Float32: + from cutlass import CUDA_VERSION + + # * NVVM call based on nvvm version + if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9: + # Old API: requires explicit result type as first positional argument + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + else: + # New API: infers result type automatically + return Float32( + nvvm.fmax( + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + + +@cute.jit +def fmax_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + res = cute.make_rmem_tensor(x.shape, Float32) + res.store(x) + local_max = [res[0], res[1], res[2], res[3]] + for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + local_max[0] = fmax(local_max[0], res[i + 0]) + local_max[1] = fmax(local_max[1], res[i + 1]) + local_max[2] = fmax(local_max[2], res[i + 2]) + local_max[3] = fmax(local_max[3], res[i + 3]) + local_max[0] = fmax(local_max[0], local_max[1]) + local_max[2] = fmax(local_max[2], local_max[3]) + local_max[0] = fmax(local_max[0], local_max[2]) + return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val) + else: + res = cute.make_rmem_tensor(x.shape, Float32) + res.store(x) + local_max_0 = ( + fmax(init_val, res[0], res[1]) + if const_expr(init_val is not None) + else fmax(res[0], res[1]) + ) + local_max = [ + local_max_0, + fmax(res[2], res[3]), + fmax(res[4], res[5]), + fmax(res[6], res[7]), + ] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_max[0] = fmax(local_max[0], res[i], res[i + 1]) + local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3]) + local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5]) + local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7]) + local_max[0] = fmax(local_max[0], local_max[1]) + return fmax(local_max[0], local_max[2], local_max[3]) + + +@cute.jit +def fadd_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + if const_expr(init_val is None): + init_val = Float32.zero + return x.reduce(cute.ReductionOp.ADD, init_val, 0) + else: + res = cute.make_rmem_tensor(x.shape, Float32) + res.store(x) + local_sum_0 = ( + cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1])) + if const_expr(init_val is not None) + else (res[0], res[1]) + ) + local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1])) + local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3])) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5])) + local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7])) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + + +@cute.jit +def fadd_exp2_scaled_reduce( + x: cute.Tensor, scale: Float32, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + assert cute.size(x.shape) % 2 == 0, "x must have an even number of elements" + if const_expr(arch < 100): + return fadd_reduce(cute.math.exp2(x.load() * scale, fastmath=True), arch=arch) + elif const_expr(cute.size(x.shape) % 8 == 0): + local_sum = [ + (Float32(0.0), Float32(0.0)), + (Float32(0.0), Float32(0.0)), + (Float32(0.0), Float32(0.0)), + (Float32(0.0), Float32(0.0)), + ] + for i in cutlass.range_constexpr(0, cute.size(x.shape), 8): + acc0, acc1 = cute.arch.mul_packed_f32x2( + (x[i + 0], x[i + 1]), (scale, scale) + ) + acc2, acc3 = cute.arch.mul_packed_f32x2( + (x[i + 2], x[i + 3]), (scale, scale) + ) + acc4, acc5 = cute.arch.mul_packed_f32x2( + (x[i + 4], x[i + 5]), (scale, scale) + ) + acc6, acc7 = cute.arch.mul_packed_f32x2( + (x[i + 6], x[i + 7]), (scale, scale) + ) + acc0 = cute.math.exp2(acc0, fastmath=True) + acc1 = cute.math.exp2(acc1, fastmath=True) + acc2 = cute.math.exp2(acc2, fastmath=True) + acc3 = cute.math.exp2(acc3, fastmath=True) + acc4 = cute.math.exp2(acc4, fastmath=True) + acc5 = cute.math.exp2(acc5, fastmath=True) + acc6 = cute.math.exp2(acc6, fastmath=True) + acc7 = cute.math.exp2(acc7, fastmath=True) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (acc0, acc1)) + local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (acc2, acc3)) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (acc4, acc5)) + local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (acc6, acc7)) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + else: + row_sum = Float32(0.0) + for i in cutlass.range_constexpr(0, cute.size(x.shape), 2): + acc0, acc1 = cute.arch.mul_packed_f32x2( + (x[i], x[i + 1]), (scale, scale) + ) + acc0 = cute.math.exp2(acc0, fastmath=True) + acc1 = cute.math.exp2(acc1, fastmath=True) + row_sum += acc0 + acc1 + return row_sum + + +@dsl_user_op +def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None: + nvvm.atomicrmw( + res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value() + ) + + +@dsl_user_op +def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + +@cute.jit +def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" + tApA = cute.make_rmem_tensor( + cute.make_layout( + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + return tApA + + +def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: + warp_group_idx = cute.arch.thread_idx()[0] // 128 + if const_expr(sync): + warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx) + return warp_group_idx + + +@cute.jit +def shuffle_sync( + value: cute.Numeric, + offset: cute.typing.Int, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.Numeric: + assert value.width % 32 == 0, "value type must be a multiple of 32 bits" + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + # important: need stride 1 and not 0 for recast_tensor to work + val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value)) + val[0] = value + val_i32 = cute.recast_tensor(val, cutlass.Int32) + for i in cutlass.range_constexpr(cute.size(val_i32)): + val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp) + return val[0] + + +@dsl_user_op +def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + """ + Left-shift val by shift bits using PTX shl.b32 (sign-agnostic). + + Named ``shl_u32`` (not ``shl_b32``) because python type annotations + distinguish signed/unsigned. + + PTX semantics (9.7.8.8): "Shift amounts greater than the register width N + are clamped to N." So ``shl.b32 d, a, 32`` is well-defined and yields 0. + + This differs from C/C++ and LLVM IR, where shifting by >= the type width is + undefined behavior. CuTeDSL compiles through MLIR -> LLVM IR, so a plain + Python-level ``Uint32(x) << Uint32(n)`` inherits LLVM's UB: the optimizer + may treat the result as poison and eliminate dependent code. Inline PTX + bypasses the LLVM IR shift entirely -- the instruction is emitted verbatim + into PTX where clamping makes it safe for all shift amounts. + """ + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], + "shl.b32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + """ + Unsigned right-shift val by shift bits using PTX shr.u32 (zero-fills). + + See ``shl_u32`` docstring for why inline PTX is used instead of plain + CuTeDSL shift operators (LLVM shift-by-type-width UB). + """ + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], + "shr.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32: + if const_expr(lane is None): + lane = cute.arch.lane_idx() + for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))): + offset = 1 << i + # Very important that we set mask_and_clamp to 0 + partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0) + if lane >= offset: + val += partial_sum + return val + + +@dsl_user_op +def cvt_f16x2_f32( + a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None +) -> cutlass.Int32: + assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16" + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)], + f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;", + "=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def cvt_fp8x4_e4m3_f32( + a: float | Float32, + b: float | Float32, + c: float | Float32, + d: float | Float32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [ + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1;\n" + "cvt.rn.satfinite.e4m3x2.f32 h0, $2, $1;\n" + "cvt.rn.satfinite.e4m3x2.f32 h1, $4, $3;\n" + "mov.b32 $0, {h0, h1};\n" + "}\n", + "=r,f,f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def cvt_fp8x4_e4m3_bf16x4( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32]: + """Convert packed e4m3x4 bits into two packed bf16x2 registers.""" + out0 = cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b32 q, mant, out, bias, zero;\n\t" + "prmt.b32 q, $1, $1, 0x1302;\n\t" + "and.b32 out, q, 0x80008000;\n\t" + "and.b32 mant, q, 0x7f007f00;\n\t" + "shr.u32 mant, mant, 4;\n\t" + "or.b32 out, out, mant;\n\t" + "mov.b32 bias, 0x7b807b80;\n\t" + "mov.b32 zero, 0;\n\t" + "fma.rn.bf16x2 $0, out, bias, zero;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + out1 = cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b32 q, qs, mant, out, bias, zero;\n\t" + "prmt.b32 q, $1, $1, 0x1302;\n\t" + "shl.b32 qs, q, 8;\n\t" + "and.b32 out, qs, 0x80008000;\n\t" + "and.b32 mant, qs, 0x7f007f00;\n\t" + "shr.u32 mant, mant, 4;\n\t" + "or.b32 out, out, mant;\n\t" + "mov.b32 bias, 0x7b807b80;\n\t" + "mov.b32 zero, 0;\n\t" + "fma.rn.bf16x2 $0, out, bias, zero;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return out0, out1 + + +@dsl_user_op +def cvt_fp4x2_e2m1_f16x2( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Convert one packed E2M1 byte into one packed f16x2 register.""" + + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b8 byte0;\n\t" + "mov.b32 {byte0, _, _, _}, $1;\n\t" + "cvt.rn.f16x2.e2m1x2 $0, byte0;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def cvt_fp4x8_e2m1_f16x8( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Convert four packed E2M1 bytes into four packed f16x2 registers.""" + + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $4;\n\t" + "cvt.rn.f16x2.e2m1x2 $0, byte0;\n\t" + "cvt.rn.f16x2.e2m1x2 $1, byte1;\n\t" + "cvt.rn.f16x2.e2m1x2 $2, byte2;\n\t" + "cvt.rn.f16x2.e2m1x2 $3, byte3;\n\t" + "}\n", + "=r,=r,=r,=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + out2 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [2], loc=loc, ip=ip)) + out3 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [3], loc=loc, ip=ip)) + return out0, out1, out2, out3 + + +@dsl_user_op +def cvt_fp4x8_e2m1_bf16x8( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Convert four packed E2M1 bytes into four packed bf16x2 registers.""" + + from cutlass import CUDA_VERSION + + if CUDA_VERSION.major > 13 or (CUDA_VERSION.major == 13 and CUDA_VERSION.minor >= 2): + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $4;\n\t" + "cvt.rn.bf16x2.e2m1x2 $0, byte0;\n\t" + "cvt.rn.bf16x2.e2m1x2 $1, byte1;\n\t" + "cvt.rn.bf16x2.e2m1x2 $2, byte2;\n\t" + "cvt.rn.bf16x2.e2m1x2 $3, byte3;\n\t" + "}\n", + "=r,=r,=r,=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + out2 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [2], loc=loc, ip=ip)) + out3 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [3], loc=loc, ip=ip)) + return out0, out1, out2, out3 + + f16_pair0, f16_pair1, f16_pair2, f16_pair3 = cvt_fp4x8_e2m1_f16x8( + src, loc=loc, ip=ip + ) + return ( + cvt_f16x2_to_bf16x2(f16_pair0, loc=loc, ip=ip), + cvt_f16x2_to_bf16x2(f16_pair1, loc=loc, ip=ip), + cvt_f16x2_to_bf16x2(f16_pair2, loc=loc, ip=ip), + cvt_f16x2_to_bf16x2(f16_pair3, loc=loc, ip=ip), + ) + + +@dsl_user_op +def cvt_fp4x8_e2m1_scaled_e4m3x8( + src: cutlass.Int32, + scale_e4m3: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32]: + """Scale eight packed E2M1 values by one E4M3 byte and convert to E4M3.""" + + from cutlass import CUDA_VERSION + + if CUDA_VERSION.major > 13 or (CUDA_VERSION.major == 13 and CUDA_VERSION.minor >= 2): + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32()]), + [ + cutlass.Int32(src).ir_value(loc=loc, ip=ip), + cutlass.Int32(scale_e4m3).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .b32 tmp, ra;\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + "prmt.b32 tmp, $3, 0, 0;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $2;\n\t" + "mov.b32 ra, {byte0, byte1, _, _};\n\t" + "mul.e4m3x4.e2m1x4.e4m3x4.satfinite $0, ra, tmp;\n\t" + "mov.b32 ra, {_, _, byte2, byte3};\n\t" + "mul.e4m3x4.e2m1x4.e4m3x4.satfinite $1, ra, tmp;\n\t" + "}\n", + "=r,=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + return out0, out1 + + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32()]), + [ + cutlass.Int32(src).ir_value(loc=loc, ip=ip), + cutlass.Int32(scale_e4m3).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .b32 sf_bytes, sf_f16x2;\n\t" + ".reg .b16 sf_pair, e0, e1, e2, e3;\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + ".reg .b32 h0, h1, h2, h3;\n\t" + "prmt.b32 sf_bytes, $3, 0, 0;\n\t" + "mov.b32 {sf_pair, _}, sf_bytes;\n\t" + "cvt.rn.f16x2.e4m3x2 sf_f16x2, sf_pair;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $2;\n\t" + "cvt.rn.f16x2.e2m1x2 h0, byte0;\n\t" + "cvt.rn.f16x2.e2m1x2 h1, byte1;\n\t" + "cvt.rn.f16x2.e2m1x2 h2, byte2;\n\t" + "cvt.rn.f16x2.e2m1x2 h3, byte3;\n\t" + "mul.rn.f16x2 h0, h0, sf_f16x2;\n\t" + "mul.rn.f16x2 h1, h1, sf_f16x2;\n\t" + "mul.rn.f16x2 h2, h2, sf_f16x2;\n\t" + "mul.rn.f16x2 h3, h3, sf_f16x2;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e0, h0;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e1, h1;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e2, h2;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e3, h3;\n\t" + "mov.b32 $0, {e0, e1};\n\t" + "mov.b32 $1, {e2, e3};\n\t" + "}\n", + "=r,=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + return out0, out1 + + +@dsl_user_op +def cvt_f16x2_to_bf16x2( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Convert a packed f16x2 register into a packed bf16x2 register.""" + + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b16 h0, h1;\n\t" + ".reg .f32 f0, f1;\n\t" + "mov.b32 {h0, h1}, $1;\n\t" + "cvt.f32.f16 f0, h0;\n\t" + "cvt.f32.f16 f1, h1;\n\t" + "cvt.rn.bf16x2.f32 $0, f1, f0;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def mul_bf16x2( + a: cutlass.Int32, + b: cutlass.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Multiply two packed bf16x2 registers.""" + + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Int32(a).ir_value(loc=loc, ip=ip), + cutlass.Int32(b).ir_value(loc=loc, ip=ip), + ], + "mul.rn.bf16x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def cvt_fp8_e4m3_to_bf16x2_replicated(src: cutlass.Int32) -> cutlass.Int32: + """Decode one E4M3 byte and replicate it into a packed bf16x2 register.""" + + src_u8 = src & cutlass.Int32(0xFF) + packed = src_u8 * cutlass.Int32(0x01010101) + out0, _ = cvt_fp8x4_e4m3_bf16x4(packed) + return out0 + + +@overload +def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ... + + +@overload +def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ... + + +@cute.jit +def cvt_f16(src: cute.Tensor, dst_or_dtype): + """Convert Float32 tensor to Float16/BFloat16. + + Args: + src: Source tensor with Float32 element type + dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16) + + Returns: + None if dst is a tensor, or a new tensor if dtype is provided + """ + if const_expr(isinstance(dst_or_dtype, type)): + # dtype variant: create new tensor and call the tensor variant + dtype = dst_or_dtype + dst = cute.make_rmem_tensor(src.shape, dtype) + cvt_f16(src, dst) + return dst + else: + # tensor variant: write to dst + dst = dst_or_dtype + assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size" + assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements" + assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], ( + "dst must be BFloat16 or Float16" + ) + assert src.element_type is Float32, "src must be Float32" + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) + for i in cutlass.range_constexpr(cute.size(dst_i32)): + dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) + + +@cute.jit +def cvt_f32(src: cute.Tensor, dst: cute.Tensor) -> None: + """Convert a Float32 rmem tensor to dst's element type. + + fp8 path uses the reference fp8 quantize pattern: fragment-by-fragment + ``.store(.load().to(fp8))`` over groups of ``frg_tile=4``. This lets the + DSL emit ``cvt.rn.satfinite.e4m3x2.f32`` pairs and pack the resulting fp8 + bytes within a 32-bit register cell in the order DSL chooses, which is + expected to match the K-adjacency that SM100 fp8 UMMA fragment_A reads. + """ + if const_expr(dst.element_type in [cutlass.BFloat16, cutlass.Float16]): + cvt_f16(src, dst) + elif const_expr(dst.element_type is cutlass.Float8E4M3FN): + assert src.element_type is Float32, "src must be Float32" + assert cute.size(src.shape) == cute.size(dst.shape), "dst and src must have the same size" + assert cute.size(src.shape) % 4 == 0, "src must have a multiple of 4 elements" + frg_tile = 4 + src_frg = cute.logical_divide(src, cute.make_layout(frg_tile)) + dst_frg = cute.logical_divide(dst, cute.make_layout(frg_tile)) + for i in cutlass.range_constexpr(cute.size(src_frg, mode=[1])): + dst_frg[None, i].store(src_frg[None, i].load().to(dst.element_type)) + else: + assert src.element_type is Float32, "src must be Float32" + dst_view = cute.make_tensor(dst.iterator, src.layout) + dst_view.store(src.load().to(dst.element_type)) + + +@dsl_user_op +@cute.jit +def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32: + deg = len(poly) - 1 + out = poly[deg] + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = out * x + poly[i] + return out + + +@dsl_user_op +@cute.jit +def evaluate_polynomial_2( + x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None +) -> Tuple[Float32, Float32]: + deg = len(poly) - 1 + out = (poly[deg], poly[deg]) + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i])) + return out + + +@dsl_user_op +def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32: + # There's probably a way to call llvm or nvvm to do this instead of ptx + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)], + "add.rm.ftz.f32 $0, $1, $2;", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32: + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [ + Float32(x_rounded).ir_value(loc=loc, ip=ip), + Float32(frac_ex2).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t" + "mov.b32 x_rounded_i, $1;\n\t" + "mov.b32 frac_ex_i, $2;\n\t" + "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t" + # add.u32 generates IMAD instruction and add.s32 generates LEA instruction + # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik + "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t" + "mov.b32 $0, out_i;\n\t" + "}\n", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32: + assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported" + # We assume x <= 127.0 + fp32_round_int = float(2**23 + 2**22) + x_clamped = cute.arch.fmax(x, -127.0) + # We want to round down here, so that the fractional part is in [0, 1) + x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip) + # The integer floor of x is now in the last 8 bits of x_rounded + # We assume the next 2 ops round to nearest even. The rounding mode is important. + x_rounded_back = x_rounded - fp32_round_int + x_frac = x_clamped - x_rounded_back + x_frac_ex2 = evaluate_polynomial(x_frac, POLY_EX2[poly_degree], loc=loc, ip=ip) + return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip) + + +@dsl_user_op +def ex2_emulation_2( + x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None +) -> Tuple[Float32, Float32]: + # We assume x <= 127.0 and y <= 127.0 + fp32_round_int = float(2**23 + 2**22) + xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0)) + # We want to round down here, so that the fractional part is in [0, 1) + xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd="rm") + # The integer floor of x & y are now in the last 8 bits of xy_rounded + # We want the next 2 ops to round to nearest even. The rounding mode is important. + xy_rounded_back = activation.sub_packed_f32x2( + xy_rounded, (fp32_round_int, fp32_round_int) + ) + xy_frac = activation.sub_packed_f32x2(xy_clamped, xy_rounded_back) + xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, POLY_EX2[poly_degree], loc=loc, ip=ip) + x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip) + y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip) + return x_out, y_out + + +@dsl_user_op +def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + out_f32x2 = llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32()]), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()], + "{\n\t" + ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t" + ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t" + ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t" + "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t" + "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t" + "mov.b64 l1, {f1, f2};\n\t" + "mov.f32 f3, 0f4B400000;\n\t" + "mov.b64 l2, {f3, f3};\n\t" + "add.rm.ftz.f32x2 l7, l1, l2;\n\t" + "sub.rn.ftz.f32x2 l8, l7, l2;\n\t" + "sub.rn.ftz.f32x2 l9, l1, l8;\n\t" + "mov.f32 f7, 0f3D9DF09D;\n\t" + "mov.b64 l6, {f7, f7};\n\t" + "mov.f32 f6, 0f3E6906A4;\n\t" + "mov.b64 l5, {f6, f6};\n\t" + "mov.f32 f5, 0f3F31F519;\n\t" + "mov.b64 l4, {f5, f5};\n\t" + "mov.f32 f4, 0f3F800000;\n\t" + "mov.b64 l3, {f4, f4};\n\t" + "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t" + "mov.b64 {r1, r2}, l7;\n\t" + "mov.b64 {r3, r4}, l10;\n\t" + "shl.b32 r5, r1, 23;\n\t" + "add.s32 r7, r5, r3;\n\t" + "shl.b32 r6, r2, 23;\n\t" + "add.s32 r8, r6, r4;\n\t" + "mov.b32 $0, r7;\n\t" + "mov.b32 $1, r8;\n\t" + "}\n", + "=r,=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip)) + out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip)) + return out0, out1 + + +@dsl_user_op +def domain_offset_aligned( + coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None +) -> cute.Tensor: + assert isinstance(tensor.iterator, cute.Pointer) + # We assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + elem_pointer(tensor, coord).toint(), + tensor.memspace, + assumed_align=tensor.iterator.alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@cute.jit +def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: + """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype""" + vec = cute.make_rmem_tensor(1, dtype) + vec[0] = a + return vec.load() + + +def ssa_to_scalar(val): + """Could inline but nice for reflecting the above api""" + return val[0] + + +# ------------------------------------------------------------------ +# Host-side Python helpers (not @cute.jit — called from PyTorch host code) +# ------------------------------------------------------------------ + +def default_softmax_scale(dim: int) -> float: + """Default softmax scale: 1 / sqrt(dim).""" + return 1.0 / math.sqrt(dim) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f23267fe73800d35db382a1919bc28196da5aa8c --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 sparse attention kernels.""" diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/build_k2q_csr/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/build_k2q_csr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf19c60a32d2f57595c9666323b47738b878115 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/build_k2q_csr/__init__.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""q2k -> k2q CSR builder backed by the precompiled Torch ops. + +The CUDA implementation lives in ``csrc/build_k2q_csr.cu`` and is built +ahead of time by kernel-builder; it is reached through the ``_ops`` +namespace instead of being JIT-compiled at import time. + +The kernel pipeline is tuned and verified for SM100; other +architectures are not supported. +""" + +from __future__ import annotations + +import torch + +from ...._ops import ops + + +def run_build_k2q_csr( + q2k: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + row_ptr: torch.Tensor, + q_idx: torch.Tensor, + topk: int, + blk_kv: int, + total_rows: int, + max_kv_blocks: int, +) -> None: + """In-place fill of ``row_ptr`` and ``q_idx``. + + Args: + q2k: int32 [H, total_q, topK] contiguous (CUDA). + cu_seqlens_q: int32 [B+1] contiguous (CUDA). + cu_seqlens_k: int32 [B+1] contiguous (CUDA). + row_ptr: int32 [H, total_rows + 1] CUDA, written in place. + q_idx: int32 [H, total_q * topK] CUDA, written in place + (trailing slots set to -1). + topk: must be in {4, 8, 16, 32}. + blk_kv: must equal 128. + total_rows: sum over batches of ceil(seqlen_k / blk_kv). + max_kv_blocks: max over batches of ceil(seqlen_k / blk_kv); upper bound + used to size the row_map workspace and clamp valid kv ids. + """ + ops.run_build_k2q_csr( + q2k, + cu_seqlens_q, + cu_seqlens_k, + row_ptr, + q_idx, + int(topk), + int(blk_kv), + int(total_rows), + int(max_kv_blocks), + ) + + +def run_build_k2q_csr_with_schedule( + q2k: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + row_ptr: torch.Tensor, + q_idx: torch.Tensor, + scheduler_metadata: torch.Tensor, + work_count: torch.Tensor, + qsplit_idx: torch.Tensor, + split_counts: torch.Tensor, + topk: int, + blk_kv: int, + total_rows: int, + max_kv_blocks: int, + target_q_per_cta: int, + work_capacity: int, + max_seqlen_q: int, +) -> None: + """In-place fill of CSR plus fused sparse attention schedule metadata.""" + ops.run_build_k2q_csr_with_schedule( + q2k, + cu_seqlens_q, + cu_seqlens_k, + row_ptr, + q_idx, + scheduler_metadata, + work_count, + qsplit_idx, + split_counts, + int(topk), + int(blk_kv), + int(total_rows), + int(max_kv_blocks), + int(target_q_per_cta), + int(work_capacity), + int(max_seqlen_q), + ) + + +def is_supported(topk: int, blk_kv: int) -> bool: + return int(topk) in (4, 8, 16, 32) and int(blk_kv) == 128 + + +__all__ = ["run_build_k2q_csr", "run_build_k2q_csr_with_schedule", "is_supported"] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/decode_schedule.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/decode_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..037791818feb030a5969ebf6ac3cc3943cdb7dce --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/decode_schedule.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Split-KV schedule for paged fp8 decode attention. + +The public PageKV representation remains this repo's rectangular page table: +``page_table [B, max_pages]`` plus ``seqused_k [B]``. The schedule only +describes how query tiles and KV chunks are split into work items. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class DecodeAttentionSchedule: + split_kv: bool + cta_tile_q: int + num_q_tiles: int + kv_chunk_size_pages: int + kv_chunk_size_tokens: int + work_count: int + padded_work_count: int + partial_rows: int + max_split_count: int + max_grid_size: int + active_blocks_per_sm: int + num_sms: int + base_cta: int + request_indices: torch.Tensor + qo_tile_indices: torch.Tensor + kv_tile_indices: torch.Tensor + merge_indptr: torch.Tensor + o_indptr: torch.Tensor + block_valid_mask: torch.Tensor + kv_pages: torch.Tensor + split_counts: torch.Tensor + + +def _require_i32_cuda_1d(tensor: torch.Tensor, *, name: str) -> None: + if tensor.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32") + if tensor.ndim != 1: + raise ValueError(f"{name} must be rank-1") + if not tensor.is_cuda: + raise ValueError(f"{name} must be a CUDA tensor") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def prepare_decode_schedule( + *, + seqused_k: torch.Tensor, + page_size: int, + seqlen_q: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + max_seqlen_k: int, + enable_cuda_graph: bool = False, + max_grid_size: Optional[int] = None, + fixed_split_size: Optional[int] = None, + disable_split_kv: bool = False, +) -> DecodeAttentionSchedule: + """Build paged decode split-KV schedule on the GPU. + + A single CUDA kernel reads ``seqused_k`` on device and writes all + schedule index arrays. Only a small summary tensor is D2H-synced so + the wrapper can size O_partial / pick the kernel grid / choose the + split-vs-non-split compile path. + + ``max_seqlen_k`` is the host-side worst-case bound used to pad the + work-tile arrays. It must satisfy ``max(seqused_k) <= max_seqlen_k``. + """ + _require_i32_cuda_1d(seqused_k, name="seqused_k") + # Hard cap: current single-CTA schedule kernel stores per-batch state + # in shared memory. Larger batches require a multi-CTA cooperative + # scheduler (unimplemented). Fail fast at the Python boundary so the + # error doesn't surface from inside the CUDA extension. + if int(seqused_k.shape[0]) > 1024: + raise NotImplementedError( + "decode schedule currently supports batch <= 1024 " + f"(got batch={int(seqused_k.shape[0])}). Larger batches need " + "the multi-CTA scheduler — not yet implemented." + ) + # Two API-boundary checks tied to the kernel's packed-GQA layout + # (q_tokens_per_group = m_block_size / qhead_per_kv = 128/16 = 8): + # + # (1) seqused_k[b] >= seqlen_q. The kernel computes the causal mask as + # col_limit = row_idx + seqlen_k - seqlen_q + 1. For row 0 (first + # q-token in the packed group) this is col_limit = seqlen_k - seqlen_q + # + 1, which goes <= 0 whenever seqlen_k < seqlen_q. That all-masked + # row then enters a mask-codegen path with PTX-undefined shift counts + # and the kernel hangs. The condition is also semantically invalid + # in batched-decode: you can't emit seqlen_q new tokens with fewer + # than seqlen_q total context tokens (seqlen_k includes them). + # + # (2) seqused_k[b] % page_size ∈ {0, 8, 16, ..., 120}. Same hang fires + # when the LAST partial page has < q_tokens_per_group=8 valid + # columns, because then the *last MMA tile* hits the same all-masked + # row case for the trailing q-tokens. + # + # Both are tracked as a separate kernel-level TODO (un-pack the + # all-masked row → skip mask call, or saturate causal_col_limit at >= 1 + # in mask.py). Until then, fail fast at the Python boundary with a + # clear message rather than letting the kernel timeout. + seqlen_q_i = int(seqlen_q) + bad_q = seqused_k < seqlen_q_i + if bool(bad_q.any().item()): + bad_idx = int(torch.nonzero(bad_q, as_tuple=True)[0][0].item()) + bad_val = int(seqused_k[bad_idx].item()) + raise ValueError( + f"decode kernel requires seqused_k[b] >= seqlen_q (= {seqlen_q_i}) " + f"for every batch. Got seqused_k[{bad_idx}]={bad_val}. " + f"This is also a batched-decode invariant: seqlen_k must include " + f"the seqlen_q new tokens being emitted." + ) + rem = seqused_k % int(page_size) + bad_rem = (rem > 0) & (rem < seqlen_q_i) + if bool(bad_rem.any().item()): + bad_idx = int(torch.nonzero(bad_rem, as_tuple=True)[0][0].item()) + bad_val = int(seqused_k[bad_idx].item()) + raise ValueError( + f"decode kernel requires seqused_k[b] % page_size ∈ " + f"{{0, {seqlen_q_i}, {seqlen_q_i*2}, ..., {(page_size//seqlen_q_i)*seqlen_q_i}}}. " + f"Got seqused_k[{bad_idx}]={bad_val}, last partial page has " + f"{bad_val % int(page_size)} valid columns (< seqlen_q={seqlen_q_i}). " + f"Round seqused_k up to the next multiple of {seqlen_q_i} OR to " + f"a multiple of {page_size}." + ) + if int(page_size) <= 0: + raise ValueError("page_size must be positive") + if int(seqlen_q) <= 0: + raise ValueError("seqlen_q must be positive") + if int(num_qo_heads) <= 0 or int(num_kv_heads) <= 0: + raise ValueError("head counts must be positive") + if int(num_qo_heads) % int(num_kv_heads) != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads") + if int(num_qo_heads) // int(num_kv_heads) != 16: + raise NotImplementedError("decode schedule currently supports only qhead_per_kv=16") + if int(head_dim) != 128: + raise NotImplementedError("decode schedule currently supports only head_dim=128") + if int(max_seqlen_k) <= 0: + raise ValueError("max_seqlen_k must be positive") + + from ...src.sm100.fwd_decode.build_decode_schedule import build_decode_schedule + + raw = build_decode_schedule( + seqused_k, + page_size=int(page_size), + seqlen_q=int(seqlen_q), + num_qo_heads=int(num_qo_heads), + num_kv_heads=int(num_kv_heads), + head_dim=int(head_dim), + max_seqlen_k=int(max_seqlen_k), + enable_cuda_graph=bool(enable_cuda_graph), + max_grid_size=0 if max_grid_size is None else int(max_grid_size), + fixed_split_size=-1 if fixed_split_size is None else int(fixed_split_size), + disable_split_kv=bool(disable_split_kv), + ) + return DecodeAttentionSchedule( + split_kv=bool(raw["split_kv"]), + cta_tile_q=int(raw["cta_tile_q"]), + num_q_tiles=int(raw["num_q_tiles"]), + kv_chunk_size_pages=int(raw["kv_chunk_size_pages"]), + kv_chunk_size_tokens=int(raw["kv_chunk_size_tokens"]), + work_count=int(raw["work_count"]), + padded_work_count=int(raw["padded_work_count"]), + partial_rows=int(raw["partial_rows"]), + max_split_count=int(raw["max_split_count"]), + max_grid_size=int(raw["max_grid_size"]), + active_blocks_per_sm=int(raw["active_blocks_per_sm"]), + num_sms=int(raw["num_sms"]), + base_cta=int(raw["base_cta"]), + request_indices=raw["request_indices"], + qo_tile_indices=raw["qo_tile_indices"], + kv_tile_indices=raw["kv_tile_indices"], + merge_indptr=raw["merge_indptr"], + o_indptr=raw["o_indptr"], + block_valid_mask=raw["block_valid_mask"], + kv_pages=raw["kv_pages"], + split_counts=raw["split_counts"], + ) + + +__all__ = [ + "DecodeAttentionSchedule", + "prepare_decode_schedule", +] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fp4_indexer.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fp4_indexer.py new file mode 100644 index 0000000000000000000000000000000000000000..aa83e39a5504ac6cf8d732255e495e48b35fa20a --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fp4_indexer.py @@ -0,0 +1,1956 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 FP4 sparse-attention indexer kernels.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +import torch +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 + +from ...src.common import pipeline as common_pipeline + + +FP4_FORMAT = Literal["mxfp4", "nvfp4"] +_FP4_PACKED_D_BYTES = 64 +_HEAD_DIM = 128 +_BLOCK_K = 128 +_PAGE_SIZE = 128 +_MMA_TILER_MN = (128, 128) +_MMA_INST_SHAPE_K = 64 +_NON_CAUSAL_K_TILES_PER_CTA = 16 +_CAUSAL_K_TILES_PER_CTA = 16 +_DECODE_PACK_Q_LEN = 8 +_DECODE_QHEAD_PER_KV = 16 +_DECODE_K_TILES_PER_CTA = 16 +_AB_DTYPE = cutlass.Float4E2M1FN + + +@dataclass(frozen=True) +class Fp4FormatSpec: + name: FP4_FORMAT + sf_vec_size: int + scale_groups: int + torch_scale_dtype: torch.dtype + cutlass_scale_dtype: type + + +_FORMAT_SPECS: dict[str, Fp4FormatSpec] = { + "mxfp4": Fp4FormatSpec( + name="mxfp4", + sf_vec_size=32, + scale_groups=4, + torch_scale_dtype=torch.float8_e8m0fnu, + cutlass_scale_dtype=cutlass.Float8E8M0FNU, + ), + "nvfp4": Fp4FormatSpec( + name="nvfp4", + sf_vec_size=16, + scale_groups=8, + torch_scale_dtype=torch.float8_e4m3fn, + cutlass_scale_dtype=cutlass.Float8E4M3FN, + ), +} + + +def normalize_fp4_format(fmt: str) -> Fp4FormatSpec: + key = str(fmt).lower() + try: + return _FORMAT_SPECS[key] + except KeyError as exc: + raise ValueError(f"format must be one of {sorted(_FORMAT_SPECS)}, got {fmt!r}") from exc + + +def ceil_div(x: int, y: int) -> int: + return (int(x) + int(y) - 1) // int(y) + + +def k_tiles_per_cta_for(causal: bool) -> int: + return _CAUSAL_K_TILES_PER_CTA if bool(causal) else _NON_CAUSAL_K_TILES_PER_CTA + + +class Fp4IndexerScaleReorderSm100: + """Reorder public FP4 indexer scales to the 1CTA blockscaled MMA layout.""" + + def __init__(self, *, fmt: str): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.sf_dtype = spec.cutlass_scale_dtype + self.scale_groups = spec.scale_groups + self.threads_per_cta = 256 + + @cute.jit + def __call__( + self, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + q_scale_mma_ptr: cute.Pointer, + k_scale_mma_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + total_q, heads_q, page_count, heads_k = problem_size + rest_q_m = cute.ceil_div(total_q, 128) + rest_g = cute.ceil_div(self.scale_groups, 4) + k_l = page_count * heads_k + + q_scale = cute.make_tensor( + q_scale_ptr, + cute.make_layout( + (total_q, heads_q, self.scale_groups), + stride=(heads_q * self.scale_groups, self.scale_groups, 1), + ), + ) + k_scale = cute.make_tensor( + k_scale_ptr, + cute.make_layout( + (page_count, heads_k, _PAGE_SIZE, self.scale_groups), + stride=( + heads_k * _PAGE_SIZE * self.scale_groups, + _PAGE_SIZE * self.scale_groups, + self.scale_groups, + 1, + ), + ), + ) + + q_mma_layout = cute.make_ordered_layout( + (32, 4, rest_q_m, 4, rest_g, heads_q), + order=(2, 1, 4, 0, 3, 5), + ) + k_mma_layout = cute.make_ordered_layout( + (32, 4, 1, 4, rest_g, k_l), + order=(2, 1, 4, 0, 3, 5), + ) + q_scale_mma = cute.make_tensor(q_scale_mma_ptr, q_mma_layout) + k_scale_mma = cute.make_tensor(k_scale_mma_ptr, k_mma_layout) + q_scale_mma = cute.group_modes(q_scale_mma, 0, 3) + q_scale_mma = cute.group_modes(q_scale_mma, 1, 3) + k_scale_mma = cute.group_modes(k_scale_mma, 0, 3) + k_scale_mma = cute.group_modes(k_scale_mma, 1, 3) + + q_scale_count = total_q * heads_q * Int32(self.scale_groups) + k_scale_count = page_count * heads_k * Int32(_PAGE_SIZE * self.scale_groups) + total_scale_count = q_scale_count + k_scale_count + grid_ctas = cute.ceil_div(total_scale_count, self.threads_per_cta) + self.kernel( + q_scale, + k_scale, + q_scale_mma, + k_scale_mma, + heads_q, + heads_k, + q_scale_count, + total_scale_count, + ).launch( + grid=(grid_ctas, 1, 1), + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + q_scale: cute.Tensor, + k_scale: cute.Tensor, + q_scale_mma: cute.Tensor, + k_scale_mma: cute.Tensor, + heads_q: Int32, + heads_k: Int32, + q_scale_count: Int32, + total_scale_count: Int32, + ): + tidx, _, _ = cute.arch.thread_idx() + block_idx, _, _ = cute.arch.block_idx() + grid_dim, _, _ = cute.arch.grid_dim() + linear = block_idx * Int32(self.threads_per_cta) + tidx + stride = grid_dim * Int32(self.threads_per_cta) + + while linear < total_scale_count: + if linear < q_scale_count: + group = linear % Int32(self.scale_groups) + tmp = linear // Int32(self.scale_groups) + head = tmp % heads_q + row = tmp // heads_q + q_scale_mma[row, group, head] = q_scale[row, head, group] + else: + k_linear = linear - q_scale_count + group = k_linear % Int32(self.scale_groups) + tmp = k_linear // Int32(self.scale_groups) + row = tmp % Int32(_PAGE_SIZE) + tmp = tmp // Int32(_PAGE_SIZE) + head = tmp % heads_k + page = tmp // heads_k + scale_l = page * heads_k + head + k_scale_mma[row, group, scale_l] = k_scale[page, head, row, group] + linear += stride + + +class Fp4IndexerStagedMmaSm100: + """Single-kernel FP4 indexer for preordered MMA scale storage.""" + + def __init__( + self, + *, + fmt: str, + causal: bool, + preordered_q_scale_tma: bool = False, + compact_schedule: bool = False, + use_tmem_load_red: bool = False, + ): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.is_causal = bool(causal) + self.preordered_q_scale_tma = bool(preordered_q_scale_tma) + self.compact_schedule = bool(compact_schedule) + self.use_tmem_load_red = bool(use_tmem_load_red) + self.sf_vec_size = spec.sf_vec_size + self.sf_dtype = spec.cutlass_scale_dtype + self.scale_groups = spec.scale_groups + self.use_nvfp4 = spec.name == "nvfp4" + self.epi_threads_per_cta = 128 + self.epi_warps_per_group = 4 + self.num_epi_warpgroups = 2 + self.mma_warp_id = self.epi_warps_per_group * self.num_epi_warpgroups + self.load_warp_id = self.mma_warp_id + 1 + self.threads_per_cta = 384 + self.num_tmem_alloc_cols = 512 + self.num_q_stage = 1 + self.num_acc_stage = 3 + self.num_ab_stage = 3 + self.k_tiles_per_cta = k_tiles_per_cta_for(self.is_causal) + + @cute.jit + def __call__( + self, + q_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + ( + m, + _, + k, + _, + lk, + heads_q, + heads_k, + batch, + max_k_tiles, + total_q, + has_qo_offset, + compact_task_count, + ) = problem_size + page_count = lk // heads_k + self.mma_tiler = (_MMA_TILER_MN[0], _MMA_TILER_MN[1], _MMA_INST_SHAPE_K * 2) + self.cta_tile_shape_mnk = self.mma_tiler + + q_tma_tensor = cute.make_tensor( + cute.recast_ptr(q_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (total_q, _HEAD_DIM, heads_q), + stride=(heads_q * _HEAD_DIM, 1, _HEAD_DIM), + ), + ) + k_tma_tensor = cute.make_tensor( + cute.recast_ptr(k_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (_PAGE_SIZE, _HEAD_DIM, heads_k, page_count), + stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM, heads_k * _PAGE_SIZE * _HEAD_DIM), + ), + ) + q_scale_tensor = cute.make_tensor( + q_scale_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (total_q, _HEAD_DIM, heads_q), + self.sf_vec_size, + ), + ) + k_scale_tensor = cute.make_tensor( + k_scale_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (_PAGE_SIZE, _HEAD_DIM, page_count * heads_k), + self.sf_vec_size, + ), + ) + scores_tensor = cute.make_tensor( + scores_ptr, + cute.make_layout((heads_q, max_k_tiles, total_q), stride=(max_k_tiles * total_q, total_q, 1)), + ) + kv_indices_tensor = cute.make_tensor( + kv_indices_ptr, + cute.make_layout((page_count,), stride=(1,)), + ) + cu_layout = cute.make_layout((batch + 1,), stride=(1,)) + cu_q_tensor = cute.make_tensor(cu_seqlens_q_ptr, cu_layout) + cu_k_tensor = cute.make_tensor(cu_seqlens_k_ptr, cu_layout) + cu_page_offsets_tensor = cute.make_tensor(cu_page_offsets_ptr, cu_layout) + qo_offset_tensor = cute.make_tensor(qo_offset_ptr, cute.make_layout((batch,), stride=(1,))) + + if const_expr(self.use_nvfp4): + mma_op = tcgen05.MmaMXF4NVF4Op( + self.sf_dtype, + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + else: + mma_op = tcgen05.MmaMXF4Op( + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + q_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_q_stage) + k_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_ab_stage) + q_scale_smem_layout = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_q_stage, + ) + k_scale_smem_layout = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + cluster_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + q_smem_layout_stage = cute.slice_(q_smem_layout, (None, None, None, 0)) + k_smem_layout_stage = cute.slice_(k_smem_layout, (None, None, None, 0)) + tma_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_tma_tensor, + q_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + tma_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_tma_tensor, + k_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + if const_expr(self.preordered_q_scale_tma): + tma_qs = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_scale_tensor, + q_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + else: + tma_qs = tma_q + tma_ks = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_scale_tensor, + k_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + grid_q_tiles = cute.ceil_div(m, self.cta_tile_shape_mnk[0]) + grid_k_groups = cute.ceil_div(max_k_tiles, self.k_tiles_per_cta) + if const_expr(self.compact_schedule): + grid_x = compact_task_count + else: + grid_x = grid_q_tiles * grid_k_groups + self.kernel( + tiled_mma, + tma_q, + tma_qs, + tma_k, + tma_ks, + q_scale_tensor, + k_scale_tensor, + scores_tensor, + kv_indices_tensor, + cu_q_tensor, + cu_k_tensor, + cu_page_offsets_tensor, + qo_offset_tensor, + q_smem_layout, + k_smem_layout, + q_scale_smem_layout, + k_scale_smem_layout, + heads_q, + heads_k, + has_qo_offset, + max_k_tiles, + grid_k_groups, + ).launch( + grid=(grid_x, batch * heads_q, 1), + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.jit + def _group_has_visible( + self, + q_tile_start: Int32, + q_tile_last: Int32, + q_len: Int32, + group_first_ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = q_tile_start < q_len and group_first_ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and group_first_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + return visible + + @cute.jit + def _tile_has_visible( + self, + q_tile_start: Int32, + q_tile_last: Int32, + q_len: Int32, + ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = q_tile_start < q_len and ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + return visible + + @cute.jit + def _tile_mask_free(self, q_tile_start: Int32, ktile: Int32, causal_offset: Int32): + if const_expr(self.is_causal): + return ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) <= q_tile_start + causal_offset + return True + + @cute.jit + def _full_tile_coord_visible( + self, + coord_m: Int32, + target_m: Int32, + q_local: Int32, + k_local: Int32, + causal_offset: Int32, + ): + visible = coord_m == target_m + if const_expr(self.is_causal): + visible = visible and k_local <= q_local + causal_offset + return visible + + @cute.jit + def _partial_tile_coord_visible( + self, + coord_m: Int32, + target_m: Int32, + q_local: Int32, + k_local: Int32, + q_len: Int32, + k_len: Int32, + causal_offset: Int32, + ): + visible = coord_m == target_m and q_local < q_len and k_local < k_len + if const_expr(self.is_causal): + visible = visible and k_local <= q_local + causal_offset + return visible + + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_q: cpasync.TmaInfo, + tma_qs: cpasync.TmaInfo, + tma_k: cpasync.TmaInfo, + tma_ks: cpasync.TmaInfo, + mQS: cute.Tensor, + mKS: cute.Tensor, + mScores: cute.Tensor, + mKvIndices: cute.Tensor, + mCuQ: cute.Tensor, + mCuK: cute.Tensor, + mCuPages: cute.Tensor, + mQoOffset: cute.Tensor, + q_smem_layout: cute.ComposedLayout, + k_smem_layout: cute.ComposedLayout, + q_scale_smem_layout: cute.Layout, + k_scale_smem_layout: cute.Layout, + heads_q: Int32, + heads_k: Int32, + has_qo_offset: Int32, + max_k_tiles: Int32, + k_group_count: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + lane_idx = cute.arch.lane_idx() + epi_tidx = tidx % Int32(self.epi_threads_per_cta) + epi_warpgroup_idx = warp_idx // Int32(self.epi_warps_per_group) + task_idx, q_l, _ = cute.arch.block_idx() + batch_idx = q_l // heads_q + hq = q_l - batch_idx * heads_q + hk = hq // (heads_q // heads_k) + q_begin = mCuQ[batch_idx] + q_end = mCuQ[batch_idx + 1] + k_begin = mCuK[batch_idx] + k_end = mCuK[batch_idx + 1] + q_len = q_end - q_begin + k_len = k_end - k_begin + page_begin = mCuPages[batch_idx] + batch_k_tiles = (k_len + Int32(_PAGE_SIZE - 1)) // Int32(_PAGE_SIZE) + causal_offset = Int32(0) + if const_expr(self.is_causal): + causal_offset = k_len - q_len + if has_qo_offset != 0: + causal_offset = mQoOffset[batch_idx] + task_valid = True + q_tile_idx = Int32(0) + ktile_group = Int32(0) + if const_expr(self.compact_schedule): + remaining = task_idx + q_tile_count = (q_len + Int32(self.cta_tile_shape_mnk[0] - 1)) // Int32(self.cta_tile_shape_mnk[0]) + batch_k_group_count = (batch_k_tiles + Int32(self.k_tiles_per_cta - 1)) // Int32(self.k_tiles_per_cta) + q_scan = Int32(0) + task_valid = False + while q_scan < q_tile_count and not task_valid: + q_scan_start = q_scan * Int32(self.cta_tile_shape_mnk[0]) + q_scan_last = q_scan_start + Int32(self.cta_tile_shape_mnk[0] - 1) + if q_scan_last >= q_len: + q_scan_last = q_len - Int32(1) + visible_limit = q_scan_last + causal_offset + visible_group_count = Int32(0) + if visible_limit >= Int32(0): + visible_group_count = visible_limit // Int32(self.k_tiles_per_cta * _BLOCK_K) + Int32(1) + if visible_group_count > batch_k_group_count: + visible_group_count = batch_k_group_count + task_valid = remaining < visible_group_count + if not task_valid: + remaining -= visible_group_count + q_scan += Int32(1) + if task_valid: + q_tile_idx = q_scan + ktile_group = remaining + else: + q_len = Int32(0) + k_len = Int32(0) + else: + q_tile_idx = task_idx // k_group_count + ktile_group = task_idx - q_tile_idx * k_group_count + q_tile_start = q_tile_idx * Int32(self.cta_tile_shape_mnk[0]) + q_tile_last = q_tile_start + Int32(self.cta_tile_shape_mnk[0] - 1) + if q_tile_last >= q_len: + q_tile_last = q_len - Int32(1) + q_tile_full = q_tile_start + Int32(self.cta_tile_shape_mnk[0] - 1) < q_len + q_tile_global_start = q_begin + q_tile_start + q_scale_tma_safe = q_tile_global_start == (q_tile_global_start // Int32(128)) * Int32(128) + group_first_ktile = ktile_group * Int32(self.k_tiles_per_cta) + group_has_visible = self._group_has_visible( + q_tile_start, + q_tile_last, + q_len, + group_first_ktile, + batch_k_tiles, + causal_offset, + ) + + @cute.struct + class SharedStorage: + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + qs_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ_public = smem.allocate_tensor(_AB_DTYPE, q_smem_layout.outer, 128, swizzle=q_smem_layout.inner) + sK_public = smem.allocate_tensor(_AB_DTYPE, k_smem_layout.outer, 128, swizzle=k_smem_layout.inner) + sQS_public = smem.allocate_tensor(self.sf_dtype, q_scale_smem_layout, 128) + sKS_public = smem.allocate_tensor(self.sf_dtype, k_scale_smem_layout, 128) + mQ_tma = tma_q.tma_tensor + mQS_tma = tma_qs.tma_tensor + mK_tma = tma_k.tma_tensor + mKS_tma = tma_ks.tma_tensor + thr_mma = tiled_mma.get_slice(0) + tCsQ = thr_mma.partition_A(sQ_public) + tCsK = thr_mma.partition_B(sK_public) + mQ_tma_cur = cute.domain_offset((q_begin, 0, 0), mQ_tma) + gQ_tma = cute.local_tile( + mQ_tma_cur, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQ_tma = thr_mma.partition_A(gQ_tma) + tQsQ_tma, tQgQ_tma = cpasync.tma_partition( + tma_q.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQ_public, 0, 3), + cute.group_modes(tCgQ_tma, 0, 3), + ) + if const_expr(self.preordered_q_scale_tma): + mQS_tma_cur = cute.domain_offset((q_begin, 0, 0), mQS_tma) + gQS_tma = cute.local_tile( + mQS_tma_cur, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQS_tma = thr_mma.partition_A(gQS_tma) + tQsQS_tma, tQgQS_tma = cpasync.tma_partition( + tma_qs.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQS_public, 0, 3), + cute.group_modes(tCgQS_tma, 0, 3), + ) + tQsQS_tma = cute.filter_zeros(tQsQS_tma) + tQgQS_tma = cute.filter_zeros(tQgQS_tma) + gK_tma = cute.local_tile( + mK_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None, None), + ) + tCgK_tma = thr_mma.partition_B(gK_tma) + tKsK_tma, tKgK_tma = cpasync.tma_partition( + tma_k.atom, + 0, + cute.make_layout(1), + cute.group_modes(sK_public, 0, 3), + cute.group_modes(tCgK_tma, 0, 3), + ) + gKS_tma = cute.local_tile( + mKS_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + tCgKS_tma = thr_mma.partition_B(gKS_tma) + tKsKS_tma, tKgKS_tma = cpasync.tma_partition( + tma_ks.atom, + 0, + cute.make_layout(1), + cute.group_modes(sKS_public, 0, 3), + cute.group_modes(tCgKS_tma, 0, 3), + ) + tKsKS_tma = cute.filter_zeros(tKsKS_tma) + tKgKS_tma = cute.filter_zeros(tKgKS_tma) + sQS = sQS_public + sKS = sKS_public + + tCrQ = tiled_mma.make_fragment_A(sQ_public) + tCrK = tiled_mma.make_fragment_B(sK_public) + tCcC = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler[:2])) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + tmem = utils.TmemAllocator( + storage.tmem_holding_buf.ptr, + barrier_for_retrieve=pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * (self.mma_warp_id + 1), + ), + ) + + acc_pipeline = common_pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, self.epi_threads_per_cta), + defer_sync=True, + ) + acc_producer, _ = acc_pipeline.make_participants() + q_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_q.smem_layout) + k_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_k.smem_layout) + if const_expr(self.preordered_q_scale_tma): + qs_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_qs.smem_layout, mode=[0, 1, 2]), + ) + ks_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_ks.smem_layout, mode=[0, 1, 2]), + ) + k_pair_tma_copy_bytes = k_tma_copy_bytes + ks_tma_copy_bytes + q_producer, q_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.q_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=q_tma_copy_bytes, + defer_sync=True, + ).make_participants() + if const_expr(self.preordered_q_scale_tma): + qs_producer, qs_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.qs_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=qs_tma_copy_bytes, + defer_sync=True, + ).make_participants() + k_producer, k_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.k_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=k_pair_tma_copy_bytes, + defer_sync=True, + ).make_participants() + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + if warp_idx == self.load_warp_id: + if group_has_visible: + q_empty = q_producer.acquire_and_advance() + if const_expr(self.preordered_q_scale_tma): + if q_scale_tma_safe: + qs_empty = qs_producer.acquire_and_advance() + cute.copy( + tma_qs.atom, + tQgQS_tma[(None, q_tile_idx, 0, hq)], + tQsQS_tma[(None, qs_empty.index)], + tma_bar_ptr=qs_empty.barrier, + ) + qs_empty.commit() + else: + for row_base in cutlass.range(0, Int32(self.cta_tile_shape_mnk[0]), 32): + row = row_base + lane_idx + q_local = q_tile_start + row + row_major = row // Int32(32) + row_atom = row - row_major * Int32(32) + for group in cutlass.range_constexpr(self.scale_groups): + group_i = Int32(group) + mma_k = group_i // Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + group_in_mma_k = group_i - mma_k * Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + sf_coord = ((((row_atom, row_major), Int32(0)), (Int32(0), group_in_mma_k)), Int32(0), mma_k, Int32(0)) + q_scale_row = q_begin + q_local + if q_local >= q_len: + q_scale_row = q_begin + sQS[sf_coord] = mQS[q_scale_row, group_i * Int32(self.sf_vec_size), hq] + else: + for row_base in cutlass.range(0, Int32(self.cta_tile_shape_mnk[0]), 32): + row = row_base + lane_idx + q_local = q_tile_start + row + row_major = row // Int32(32) + row_atom = row - row_major * Int32(32) + for group in cutlass.range_constexpr(self.scale_groups): + group_i = Int32(group) + mma_k = group_i // Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + group_in_mma_k = group_i - mma_k * Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + sf_coord = ((((row_atom, row_major), Int32(0)), (Int32(0), group_in_mma_k)), Int32(0), mma_k, Int32(0)) + q_scale_row = q_begin + q_local + if q_local >= q_len: + q_scale_row = q_begin + sQS[sf_coord] = mQS[q_scale_row, group_i * Int32(self.sf_vec_size), hq] + cute.copy( + tma_q.atom, + tQgQ_tma[(None, q_tile_idx, 0, hq)], + tQsQ_tma[(None, q_empty.index)], + tma_bar_ptr=q_empty.barrier, + ) + q_empty.commit() + + if warp_idx == self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + # Move block scales into TMEM and issue one FP4 GEMM per visible K tile. + tCtQS_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(q_scale_smem_layout, (None, None, None, 0)), + ) + tCtKS_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(k_scale_smem_layout, (None, None, None, 0)), + ) + tCtQS = tmem_pool.allocate_tensor(tCtQS_layout, self.sf_dtype) + tCtKS = tmem_pool.allocate_tensor(tCtKS_layout, self.sf_dtype) + copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), self.sf_dtype) + tCsQS_compact = cute.filter_zeros(sQS) + tCtQS_compact = cute.filter_zeros(tCtQS) + tiled_copy_s2t_qs = tcgen05.make_s2t_copy(copy_atom_s2t, tCtQS_compact) + thr_copy_s2t_qs = tiled_copy_s2t_qs.get_slice(0) + tCsQS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_qs, + thr_copy_s2t_qs.partition_S(tCsQS_compact), + ) + tCtQS_compact_s2t = thr_copy_s2t_qs.partition_D(tCtQS_compact) + tCsKS_compact = cute.filter_zeros(sKS) + tCtKS_compact = cute.filter_zeros(tCtKS) + tiled_copy_s2t_ks = tcgen05.make_s2t_copy(copy_atom_s2t, tCtKS_compact) + thr_copy_s2t_ks = tiled_copy_s2t_ks.get_slice(0) + tCsKS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_ks, + thr_copy_s2t_ks.partition_S(tCsKS_compact), + ) + tCtKS_compact_s2t = thr_copy_s2t_ks.partition_D(tCtKS_compact) + if group_has_visible: + q_full = q_consumer.wait_and_advance() + if const_expr(self.preordered_q_scale_tma): + if q_scale_tma_safe: + qs_full = qs_consumer.wait_and_advance() + qs_full.release() + q_full.release() + cute.copy(tiled_copy_s2t_qs, tCsQS_compact_s2t[(None, None, None, None, 0)], tCtQS_compact_s2t) + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + q_tile_crd = (None, None, None, 0) + if const_expr(self.is_causal): + causal_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + causal_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + causal_group_full = causal_group_full and causal_group_last_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + ktile = Int32(0) + if causal_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_tile_start, + q_tile_last, + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + k_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + ktile = Int32(0) + if k_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < batch_k_tiles: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + acc_producer.tail() + + if warp_idx == self.load_warp_id: + if group_has_visible: + load_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + if const_expr(self.is_causal): + load_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + load_group_full = load_group_full and load_group_last_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + ktile = Int32(0) + if load_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_tile_start, + q_tile_last, + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + k_producer.tail() + q_producer.tail() + if const_expr(self.preordered_q_scale_tma): + if q_scale_tma_safe: + qs_producer.tail() + + if warp_idx < self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + # Load accumulators from TMEM, reduce per-row max, and store scores. + if const_expr(self.use_tmem_load_red): + copy_atom_t2r = cute.make_copy_atom( + tcgen05.LdRed32x32bOp( + tcgen05.Repetition.x128, + tcgen05.Pack.NONE, + tcgen05.TmemLoadRedOp.MAX, + ), + Float32, + ) + else: + copy_atom_t2r = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE), + Float32, + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None, None, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx) + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + tTR_cC = thr_copy_t2r.partition_D(tCcC) + tTR_rAcc = cute.make_rmem_tensor(tTR_cC.shape, Float32) + if const_expr(self.use_tmem_load_red): + tTR_rRed = cute.make_rmem_tensor((1,), Float32) + q_local_store0 = q_tile_start + epi_tidx + q_global_store0 = q_begin + q_local_store0 + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + q_local_store1 = q_tile_start + epi_tidx + Int32(self.epi_threads_per_cta) + q_global_store1 = q_begin + q_local_store1 + if group_has_visible: + visible_tile_count = Int32(0) + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_tile_start, + q_tile_last, + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + epilogue_owns_tile = epi_warpgroup_idx == Int32( + ktile_inner % self.num_epi_warpgroups + ) + if epilogue_owns_tile: + acc_stage_index = visible_tile_count % Int32(self.num_acc_stage) + acc_stage_phase = (visible_tile_count // Int32(self.num_acc_stage)) % Int32(2) + tile_mask_free = self._tile_mask_free(q_tile_start, ktile, causal_offset) + k_tile_full = ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) < k_len + tile_full = q_tile_full and k_tile_full + acc_pipeline.consumer_wait_w_index_phase(acc_stage_index, acc_stage_phase) + tTR_tAcc_stage = tTR_tAcc[(None, None, None, None, acc_stage_index)] + if const_expr(self.use_tmem_load_red): + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, [tTR_rAcc, tTR_rRed]) + else: + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, tTR_rAcc) + row_max0 = -Float32.inf + row_max1 = -Float32.inf + if tile_mask_free: + if tile_full: + if const_expr(not self.use_tmem_load_red or self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, _ = tTR_cC[i] + if coord_m == epi_tidx: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if coord_m == epi_tidx + Int32(self.epi_threads_per_cta): + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + else: + row_max0 = tTR_rRed[0] + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + q_local = q_tile_start + coord_m + k_local = ktile * Int32(_BLOCK_K) + coord_n + if coord_m == epi_tidx and q_local < q_len and k_local < k_len: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if coord_m == epi_tidx + Int32(self.epi_threads_per_cta) and q_local < q_len and k_local < k_len: + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + else: + if tile_full: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + q_local = q_tile_start + coord_m + k_local = ktile * Int32(_BLOCK_K) + coord_n + if self._full_tile_coord_visible( + coord_m, + epi_tidx, + q_local, + k_local, + causal_offset, + ): + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if self._full_tile_coord_visible( + coord_m, + epi_tidx + Int32(self.epi_threads_per_cta), + q_local, + k_local, + causal_offset, + ): + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + q_local = q_tile_start + coord_m + k_local = ktile * Int32(_BLOCK_K) + coord_n + if self._partial_tile_coord_visible( + coord_m, + epi_tidx, + q_local, + k_local, + q_len, + k_len, + causal_offset, + ): + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if self._partial_tile_coord_visible( + coord_m, + epi_tidx + Int32(self.epi_threads_per_cta), + q_local, + k_local, + q_len, + k_len, + causal_offset, + ): + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + if q_tile_full: + mScores[hq, ktile, q_global_store0] = row_max0 + elif q_local_store0 < q_len: + mScores[hq, ktile, q_global_store0] = row_max0 + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if q_tile_full: + mScores[hq, ktile, q_global_store1] = row_max1 + elif q_local_store1 < q_len: + mScores[hq, ktile, q_global_store1] = row_max1 + cute.arch.fence_view_async_tmem_load() + acc_pipeline.consumer_release_w_index(acc_stage_index) + visible_tile_count += Int32(1) + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + if q_tile_full: + mScores[hq, ktile, q_global_store0] = -Float32.inf + elif q_local_store0 < q_len: + mScores[hq, ktile, q_global_store0] = -Float32.inf + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = ktile_group * Int32(self.k_tiles_per_cta) + Int32(ktile_inner) + if ktile < max_k_tiles: + if q_tile_full: + mScores[hq, ktile, q_global_store0] = -Float32.inf + elif q_local_store0 < q_len: + mScores[hq, ktile, q_global_store0] = -Float32.inf + cute.arch.barrier() + tmem.free(tmem_pool.base_ptr) + +class Fp4IndexerDecodeQPackSm100: + """Pack decode Q rows as ``[B * Hk, 128, 64]`` and pack Q scales to MMA storage.""" + + def __init__(self, *, fmt: str): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.sf_dtype = spec.cutlass_scale_dtype + self.scale_groups = spec.scale_groups + self.threads_per_cta = 256 + + @cute.jit + def __call__( + self, + q_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + q_pack_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + total_q, heads_q, heads_k, batch = problem_size + rest_q_m = cute.ceil_div(total_q, 128) + rest_g = ceil_div(self.scale_groups, 4) + q = cute.make_tensor( + q_ptr, + cute.make_layout( + (total_q, heads_q, _FP4_PACKED_D_BYTES), + stride=(heads_q * _FP4_PACKED_D_BYTES, _FP4_PACKED_D_BYTES, 1), + ), + ) + q_scale = cute.make_tensor( + q_scale_ptr, + cute.make_layout( + (heads_q, rest_q_m, rest_g, 32, 4, 4), + stride=(512 * rest_q_m * rest_g, 512 * rest_g, 512, 16, 4, 1), + ), + ) + q_pack_l = batch * heads_k + q_pack = cute.make_tensor( + q_pack_ptr, + cute.make_layout( + (q_pack_l, _PAGE_SIZE, _FP4_PACKED_D_BYTES), + stride=(_PAGE_SIZE * _FP4_PACKED_D_BYTES, _FP4_PACKED_D_BYTES, 1), + ), + ) + q_scale_pack = cute.make_tensor( + q_scale_pack_ptr, + cute.make_layout( + (q_pack_l, 1, rest_g, 32, 4, 4), + stride=(512 * rest_g, 512 * rest_g, 512, 16, 4, 1), + ), + ) + cu_q = cute.make_tensor(cu_seqlens_q_ptr, cute.make_layout((batch + 1,), stride=(1,))) + self.kernel(q, q_scale, q_pack, q_scale_pack, cu_q, heads_q, heads_k).launch( + grid=(q_pack_l, 1, 1), + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mQS: cute.Tensor, + mQPack: cute.Tensor, + mQSPack: cute.Tensor, + mCuQ: cute.Tensor, + heads_q: Int32, + heads_k: Int32, + ): + tidx, _, _ = cute.arch.thread_idx() + q_pack_l, _, _ = cute.arch.block_idx() + batch_idx = q_pack_l // heads_k + hk = q_pack_l - batch_idx * heads_k + q_begin = mCuQ[batch_idx] + q_end = mCuQ[batch_idx + 1] + q_len = q_end - q_begin + qhead_per_kv = heads_q // heads_k + + linear = tidx + while linear < Int32(_PAGE_SIZE * _FP4_PACKED_D_BYTES): + row = linear // Int32(_FP4_PACKED_D_BYTES) + byte = linear - row * Int32(_FP4_PACKED_D_BYTES) + h_in_group = row // Int32(_DECODE_PACK_Q_LEN) + q_local = row - h_in_group * Int32(_DECODE_PACK_Q_LEN) + hq = hk * qhead_per_kv + h_in_group + if q_local < q_len and h_in_group < qhead_per_kv: + mQPack[q_pack_l, row, byte] = mQ[q_begin + q_local, hq, byte] + else: + mQPack[q_pack_l, row, byte] = cutlass.Uint8(0) + linear += Int32(self.threads_per_cta) + + scale_linear = tidx + while scale_linear < Int32(_PAGE_SIZE * self.scale_groups): + row = scale_linear // Int32(self.scale_groups) + group = scale_linear - row * Int32(self.scale_groups) + h_in_group = row // Int32(_DECODE_PACK_Q_LEN) + q_local = row - h_in_group * Int32(_DECODE_PACK_Q_LEN) + hq = hk * qhead_per_kv + h_in_group + q_abs = q_begin + q_local + if q_local >= q_len or h_in_group >= qhead_per_kv: + q_abs = q_begin + hq = hk * qhead_per_kv + src_rest_m = q_abs // Int32(128) + src_row = q_abs - src_rest_m * Int32(128) + src_row_atom = src_row % Int32(32) + src_row_major = src_row // Int32(32) + dst_row_atom = row % Int32(32) + dst_row_major = row // Int32(32) + rest_g = group // Int32(4) + group_in_rest = group - rest_g * Int32(4) + mQSPack[q_pack_l, Int32(0), rest_g, dst_row_atom, dst_row_major, group_in_rest] = mQS[ + hq, src_rest_m, rest_g, src_row_atom, src_row_major, group_in_rest + ] + scale_linear += Int32(self.threads_per_cta) + + +class Fp4IndexerDecodePackedQSm100: + """Decode score kernel with M packed as ``qhead_per_kv * q_len == 128``.""" + + def __init__(self, *, fmt: str, causal: bool, compact_schedule: bool, use_tmem_load_red: bool = False): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.is_causal = bool(causal) + self.compact_schedule = bool(compact_schedule) + self.use_tmem_load_red = bool(use_tmem_load_red) + self.sf_vec_size = spec.sf_vec_size + self.sf_dtype = spec.cutlass_scale_dtype + self.use_nvfp4 = spec.name == "nvfp4" + self.epi_threads_per_cta = 128 + self.epi_warps_per_group = 4 + self.num_epi_warpgroups = 2 + self.mma_warp_id = self.epi_warps_per_group * self.num_epi_warpgroups + self.load_warp_id = self.mma_warp_id + 1 + self.threads_per_cta = 384 + self.num_tmem_alloc_cols = 512 + self.num_q_stage = 1 + self.num_acc_stage = 3 + self.num_ab_stage = 3 + self.k_tiles_per_cta = _DECODE_K_TILES_PER_CTA + self.mma_tiler = (_MMA_TILER_MN[0], _MMA_TILER_MN[1], _MMA_INST_SHAPE_K * 2) + self.cta_tile_shape_mnk = self.mma_tiler + + @cute.jit + def __call__( + self, + q_pack_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + ( + _, + _, + _, + _, + lk, + heads_q, + heads_k, + batch, + max_k_tiles, + total_q, + has_qo_offset, + ) = problem_size + page_count = lk // heads_k + q_pack_l = batch * heads_k + q_tma_tensor = cute.make_tensor( + cute.recast_ptr(q_pack_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (_PAGE_SIZE, _HEAD_DIM, q_pack_l), + stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM), + ), + ) + k_tma_tensor = cute.make_tensor( + cute.recast_ptr(k_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (_PAGE_SIZE, _HEAD_DIM, heads_k, page_count), + stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM, heads_k * _PAGE_SIZE * _HEAD_DIM), + ), + ) + q_scale_tensor = cute.make_tensor( + q_scale_pack_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (_PAGE_SIZE, _HEAD_DIM, q_pack_l), + self.sf_vec_size, + ), + ) + k_scale_tensor = cute.make_tensor( + k_scale_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (_PAGE_SIZE, _HEAD_DIM, page_count * heads_k), + self.sf_vec_size, + ), + ) + scores_tensor = cute.make_tensor( + scores_ptr, + cute.make_layout((heads_q, max_k_tiles, total_q), stride=(max_k_tiles * total_q, total_q, 1)), + ) + kv_indices_tensor = cute.make_tensor(kv_indices_ptr, cute.make_layout((page_count,), stride=(1,))) + cu_layout = cute.make_layout((batch + 1,), stride=(1,)) + cu_q_tensor = cute.make_tensor(cu_seqlens_q_ptr, cu_layout) + cu_k_tensor = cute.make_tensor(cu_seqlens_k_ptr, cu_layout) + cu_page_offsets_tensor = cute.make_tensor(cu_page_offsets_ptr, cu_layout) + qo_offset_tensor = cute.make_tensor(qo_offset_ptr, cute.make_layout((batch,), stride=(1,))) + + if const_expr(self.use_nvfp4): + mma_op = tcgen05.MmaMXF4NVF4Op( + self.sf_dtype, + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + else: + mma_op = tcgen05.MmaMXF4Op( + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + q_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_q_stage) + k_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_ab_stage) + q_scale_smem_layout = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_q_stage, + ) + k_scale_smem_layout = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + cluster_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + q_smem_layout_stage = cute.slice_(q_smem_layout, (None, None, None, 0)) + k_smem_layout_stage = cute.slice_(k_smem_layout, (None, None, None, 0)) + tma_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_tma_tensor, + q_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + tma_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_tma_tensor, + k_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + tma_qs = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_scale_tensor, + q_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + tma_ks = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_scale_tensor, + k_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + grid_k_groups = cute.ceil_div(max_k_tiles, self.k_tiles_per_cta) + compact_k_groups = cute.ceil_div(page_count + batch * (self.k_tiles_per_cta - 1), self.k_tiles_per_cta) + if const_expr(self.compact_schedule): + grid = (compact_k_groups, heads_k, 1) + else: + grid = (grid_k_groups, batch * heads_k, 1) + self.kernel( + tiled_mma, + tma_q, + tma_qs, + tma_k, + tma_ks, + scores_tensor, + kv_indices_tensor, + cu_q_tensor, + cu_k_tensor, + cu_page_offsets_tensor, + qo_offset_tensor, + q_smem_layout, + k_smem_layout, + q_scale_smem_layout, + k_scale_smem_layout, + heads_q, + heads_k, + batch, + has_qo_offset, + max_k_tiles, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.jit + def _group_has_visible( + self, + q_len: Int32, + group_first_ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = q_len > Int32(0) and group_first_ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and group_first_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + return visible + + @cute.jit + def _tile_has_visible( + self, + q_len: Int32, + ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + return visible + + @cute.jit + def _tile_mask_free(self, ktile: Int32, causal_offset: Int32): + if const_expr(self.is_causal): + return ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) <= causal_offset + return True + + @cute.jit + def _packed_coord_visible( + self, + coord_m: Int32, + target_m: Int32, + h_in_group: Int32, + qhead_per_kv: Int32, + q_local: Int32, + q_len: Int32, + k_local: Int32, + k_len: Int32, + causal_offset: Int32, + ): + visible = coord_m == target_m and h_in_group < qhead_per_kv and q_local < q_len and k_local < k_len + if const_expr(self.is_causal): + visible = visible and k_local <= q_local + causal_offset + return visible + + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_q: cpasync.TmaInfo, + tma_qs: cpasync.TmaInfo, + tma_k: cpasync.TmaInfo, + tma_ks: cpasync.TmaInfo, + mScores: cute.Tensor, + mKvIndices: cute.Tensor, + mCuQ: cute.Tensor, + mCuK: cute.Tensor, + mCuPages: cute.Tensor, + mQoOffset: cute.Tensor, + q_smem_layout: cute.ComposedLayout, + k_smem_layout: cute.ComposedLayout, + q_scale_smem_layout: cute.Layout, + k_scale_smem_layout: cute.Layout, + heads_q: Int32, + heads_k: Int32, + batch: Int32, + has_qo_offset: Int32, + max_k_tiles: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + epi_tidx = tidx % Int32(self.epi_threads_per_cta) + epi_warpgroup_idx = warp_idx // Int32(self.epi_warps_per_group) + task_x, task_y, _ = cute.arch.block_idx() + task_valid = True + batch_idx = Int32(0) + hk = Int32(0) + ktile_group = Int32(0) + q_l = Int32(0) + if const_expr(self.compact_schedule): + hk = task_y + group_base = Int32(0) + scan_batch = Int32(0) + task_valid = False + while scan_batch < batch and not task_valid: + batch_pages = mCuPages[scan_batch + Int32(1)] - mCuPages[scan_batch] + batch_groups = (batch_pages + Int32(self.k_tiles_per_cta - 1)) // Int32(self.k_tiles_per_cta) + task_valid = task_x < group_base + batch_groups + if not task_valid: + group_base += batch_groups + scan_batch += Int32(1) + if task_valid: + batch_idx = scan_batch + ktile_group = task_x - group_base + q_l = batch_idx * heads_k + hk + else: + ktile_group = task_x + q_l = task_y + batch_idx = q_l // heads_k + hk = q_l - batch_idx * heads_k + qhead_per_kv = heads_q // heads_k + q_begin = mCuQ[batch_idx] + q_end = mCuQ[batch_idx + 1] + k_begin = mCuK[batch_idx] + k_end = mCuK[batch_idx + 1] + q_len = q_end - q_begin + k_len = k_end - k_begin + if const_expr(self.compact_schedule): + if not task_valid: + q_len = Int32(0) + k_len = Int32(0) + page_begin = mCuPages[batch_idx] + batch_k_tiles = (k_len + Int32(_PAGE_SIZE - 1)) // Int32(_PAGE_SIZE) + causal_offset = Int32(0) + if const_expr(self.is_causal): + causal_offset = k_len - q_len + if has_qo_offset != 0: + causal_offset = mQoOffset[batch_idx] + group_first_ktile = ktile_group * Int32(self.k_tiles_per_cta) + group_has_visible = self._group_has_visible( + q_len, + group_first_ktile, + batch_k_tiles, + causal_offset, + ) + + @cute.struct + class SharedStorage: + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ_public = smem.allocate_tensor(_AB_DTYPE, q_smem_layout.outer, 128, swizzle=q_smem_layout.inner) + sK_public = smem.allocate_tensor(_AB_DTYPE, k_smem_layout.outer, 128, swizzle=k_smem_layout.inner) + sQS_public = smem.allocate_tensor(self.sf_dtype, q_scale_smem_layout, 128) + sKS_public = smem.allocate_tensor(self.sf_dtype, k_scale_smem_layout, 128) + mQ_tma = tma_q.tma_tensor + mQS_tma = tma_qs.tma_tensor + mK_tma = tma_k.tma_tensor + mKS_tma = tma_ks.tma_tensor + thr_mma = tiled_mma.get_slice(0) + tCrQ = tiled_mma.make_fragment_A(sQ_public) + tCrK = tiled_mma.make_fragment_B(sK_public) + tCcC = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler[:2])) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + gQ_tma = cute.local_tile( + mQ_tma, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQ_tma = thr_mma.partition_A(gQ_tma) + tQsQ_tma, tQgQ_tma = cpasync.tma_partition( + tma_q.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQ_public, 0, 3), + cute.group_modes(tCgQ_tma, 0, 3), + ) + gQS_tma = cute.local_tile( + mQS_tma, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQS_tma = thr_mma.partition_A(gQS_tma) + tQsQS_tma, tQgQS_tma = cpasync.tma_partition( + tma_qs.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQS_public, 0, 3), + cute.group_modes(tCgQS_tma, 0, 3), + ) + tQsQS_tma = cute.filter_zeros(tQsQS_tma) + tQgQS_tma = cute.filter_zeros(tQgQS_tma) + gK_tma = cute.local_tile( + mK_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None, None), + ) + tCgK_tma = thr_mma.partition_B(gK_tma) + tKsK_tma, tKgK_tma = cpasync.tma_partition( + tma_k.atom, + 0, + cute.make_layout(1), + cute.group_modes(sK_public, 0, 3), + cute.group_modes(tCgK_tma, 0, 3), + ) + gKS_tma = cute.local_tile( + mKS_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + tCgKS_tma = thr_mma.partition_B(gKS_tma) + tKsKS_tma, tKgKS_tma = cpasync.tma_partition( + tma_ks.atom, + 0, + cute.make_layout(1), + cute.group_modes(sKS_public, 0, 3), + cute.group_modes(tCgKS_tma, 0, 3), + ) + tKsKS_tma = cute.filter_zeros(tKsKS_tma) + tKgKS_tma = cute.filter_zeros(tKgKS_tma) + + tmem = utils.TmemAllocator( + storage.tmem_holding_buf.ptr, + barrier_for_retrieve=pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * (self.mma_warp_id + 1), + ), + ) + acc_pipeline = common_pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, self.epi_threads_per_cta), + defer_sync=True, + ) + acc_producer, _ = acc_pipeline.make_participants() + q_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_q.smem_layout) + qs_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_qs.smem_layout, mode=[0, 1, 2]), + ) + k_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_k.smem_layout) + ks_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_ks.smem_layout, mode=[0, 1, 2]), + ) + q_pair_tma_copy_bytes = q_tma_copy_bytes + qs_tma_copy_bytes + k_pair_tma_copy_bytes = k_tma_copy_bytes + ks_tma_copy_bytes + q_producer, q_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.q_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=q_pair_tma_copy_bytes, + defer_sync=True, + ).make_participants() + k_producer, k_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.k_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=k_pair_tma_copy_bytes, + defer_sync=True, + ).make_participants() + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + + if warp_idx == self.load_warp_id: + if group_has_visible: + q_pair_empty = q_producer.acquire_and_advance() + cute.copy( + tma_q.atom, + tQgQ_tma[(None, 0, 0, q_l)], + tQsQ_tma[(None, q_pair_empty.index)], + tma_bar_ptr=q_pair_empty.barrier, + ) + cute.copy( + tma_qs.atom, + tQgQS_tma[(None, 0, 0, q_l)], + tQsQS_tma[(None, q_pair_empty.index)], + tma_bar_ptr=q_pair_empty.barrier, + ) + q_pair_empty.commit() + load_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + if const_expr(self.is_causal): + load_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + load_group_full = load_group_full and load_group_last_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + if load_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + k_producer.tail() + q_producer.tail() + + if warp_idx == self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + tCtQS_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(q_scale_smem_layout, (None, None, None, 0)), + ) + tCtKS_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(k_scale_smem_layout, (None, None, None, 0)), + ) + tCtQS = tmem_pool.allocate_tensor(tCtQS_layout, self.sf_dtype) + tCtKS = tmem_pool.allocate_tensor(tCtKS_layout, self.sf_dtype) + copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), self.sf_dtype) + tCsQS_compact = cute.filter_zeros(sQS_public) + tCtQS_compact = cute.filter_zeros(tCtQS) + tiled_copy_s2t_qs = tcgen05.make_s2t_copy(copy_atom_s2t, tCtQS_compact) + thr_copy_s2t_qs = tiled_copy_s2t_qs.get_slice(0) + tCsQS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_qs, + thr_copy_s2t_qs.partition_S(tCsQS_compact), + ) + tCtQS_compact_s2t = thr_copy_s2t_qs.partition_D(tCtQS_compact) + tCsKS_compact = cute.filter_zeros(sKS_public) + tCtKS_compact = cute.filter_zeros(tCtKS) + tiled_copy_s2t_ks = tcgen05.make_s2t_copy(copy_atom_s2t, tCtKS_compact) + thr_copy_s2t_ks = tiled_copy_s2t_ks.get_slice(0) + tCsKS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_ks, + thr_copy_s2t_ks.partition_S(tCsKS_compact), + ) + tCtKS_compact_s2t = thr_copy_s2t_ks.partition_D(tCtKS_compact) + if group_has_visible: + q_pair_full = q_consumer.wait_and_advance() + q_pair_full.release() + cute.copy(tiled_copy_s2t_qs, tCsQS_compact_s2t[(None, None, None, None, 0)], tCtQS_compact_s2t) + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + q_tile_crd = (None, None, None, 0) + if const_expr(self.is_causal): + causal_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + causal_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + causal_group_full = causal_group_full and causal_group_last_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + if causal_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + k_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + if k_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < batch_k_tiles: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + acc_producer.tail() + + if warp_idx < self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + if const_expr(self.use_tmem_load_red): + copy_atom_t2r = cute.make_copy_atom( + tcgen05.LdRed32x32bOp( + tcgen05.Repetition.x128, + tcgen05.Pack.NONE, + tcgen05.TmemLoadRedOp.MAX, + ), + Float32, + ) + else: + copy_atom_t2r = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE), + Float32, + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None, None, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx) + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + tTR_cC = thr_copy_t2r.partition_D(tCcC) + tTR_rAcc = cute.make_rmem_tensor(tTR_cC.shape, Float32) + if const_expr(self.use_tmem_load_red): + tTR_rRed = cute.make_rmem_tensor((1,), Float32) + h_store = epi_tidx // Int32(_DECODE_PACK_Q_LEN) + q_local_store = epi_tidx - h_store * Int32(_DECODE_PACK_Q_LEN) + h_global_store = hk * qhead_per_kv + h_store + q_global_store = q_begin + q_local_store + if group_has_visible: + visible_tile_count = Int32(0) + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + epilogue_owns_tile = epi_warpgroup_idx == Int32( + ktile_inner % self.num_epi_warpgroups + ) + if epilogue_owns_tile: + acc_stage_index = visible_tile_count % Int32(self.num_acc_stage) + acc_stage_phase = (visible_tile_count // Int32(self.num_acc_stage)) % Int32(2) + tile_mask_free = self._tile_mask_free(ktile, causal_offset) + k_tile_full = ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) < k_len + q_pack_full = q_len == Int32(_DECODE_PACK_Q_LEN) + tile_full = q_pack_full and k_tile_full + acc_pipeline.consumer_wait_w_index_phase(acc_stage_index, acc_stage_phase) + tTR_tAcc_stage = tTR_tAcc[(None, None, None, None, acc_stage_index)] + if const_expr(self.use_tmem_load_red): + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, [tTR_rAcc, tTR_rRed]) + else: + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, tTR_rAcc) + row_max0 = -Float32.inf + if tile_mask_free and tile_full: + if const_expr(self.use_tmem_load_red): + row_max0 = tTR_rRed[0] + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, _ = tTR_cC[i] + if coord_m == epi_tidx: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + h_in_group = coord_m // Int32(_DECODE_PACK_Q_LEN) + q_local = coord_m - h_in_group * Int32(_DECODE_PACK_Q_LEN) + k_local = ktile * Int32(_BLOCK_K) + coord_n + valid = self._packed_coord_visible( + coord_m, + epi_tidx, + h_in_group, + qhead_per_kv, + q_local, + q_len, + k_local, + k_len, + causal_offset, + ) + if valid: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if h_store < qhead_per_kv and q_local_store < q_len: + mScores[h_global_store, ktile, q_global_store] = row_max0 + cute.arch.fence_view_async_tmem_load() + acc_pipeline.consumer_release_w_index(acc_stage_index) + visible_tile_count += Int32(1) + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + if h_store < qhead_per_kv and q_local_store < q_len: + mScores[h_global_store, ktile, q_global_store] = -Float32.inf + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + if h_store < qhead_per_kv and q_local_store < q_len: + mScores[h_global_store, ktile, q_global_store] = -Float32.inf + cute.arch.barrier() + tmem.free(tmem_pool.base_ptr) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18b99aea3f8b4915c03fe8147127374d920970f3 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 forward kernels and combine paths.""" + +from .atten_fwd_nvfp4_kv import SparseAttentionForwardNvfp4KvSm100 + +__all__ = ["SparseAttentionForwardNvfp4KvSm100"] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd/atten_fwd.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd/atten_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..531b27c9e6b4bd8c1bc74fb1f92ed98a192ca0b2 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd/atten_fwd.py @@ -0,0 +1,3020 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 sparse attention forward kernel. + +This kernel implements the delivered Sparse Attention / Sparse Page Attention +forward contract: +- CSR sparse metadata (`k2q_row_ptr`, `k2q_q_indices`) +- varlen Q metadata via `cu_seqlens_q` +- Sparse Attention with flat varlen K/V +- Sparse Page Attention with paged K/V +""" + +import math +from functools import partial +from typing import Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.pipeline as cutlass_pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.cutlass_dsl import BaseDSL + +from ....quack import copy_utils + +from ....src.common.cute_dsl_utils import assume_tensor_aligned +from ....src.common import utils +from ....src.common import pipeline +from ....src.common import mma_sm100_desc as sm100_desc +from ....src.common import blackwell_helpers as sm100_helpers +from ....src.common.softmax import SoftmaxSm100 +from ....src.common.named_barrier import NamedBarrierFwdSm100 +from ....src.common.mask import AttentionMask +from ....src.common.seqlen_info import SeqlenInfoQK +# Shared raw PTX helpers and layout conversions used by the lean kernel. +from ....src.common.paged_kv import PagedKVManager +from ....src.common.tma_utils import ( + tma_gather4_cached, + tma_gather4_prefetch, + prefetch_tma_desc_raw, + TMA_CACHE_EVICT_LAST, + make_16x256b_tensor_mn_view, + real_col_to_stg128_fake_col, + real_col_to_stg128_fp8_fake_col, + real_col_to_stg128_half_fake_col, + stg_128_cs, + stg_128_bf16_cs, + stg_128_f16_cs, + stg_128_fp8_e4m3_cs, +) + + +class SparseAttentionForwardSm100: + """SM100 sparse attention forward kernel.""" + + k_tile = 64 # UTCMMA bf16 K-tile (matches sparse_fwd_utcmma.py) + + def __init__( + self, + head_dim: int = 128, + qheadperkv: int = 16, + m_block_size: int = 128, + n_block_size: int = 128, + paged_kv: bool = False, + page_size: Optional[int] = None, + has_seqused_k: bool = False, + causal: bool = False, + use_prepare_scheduler: bool = True, + qk_dtype=None, + pv_dtype=None, + ): + if head_dim != 128: + raise NotImplementedError( + f"SparseAttentionForwardSm100 currently supports only D=128, got D={head_dim}" + ) + self.head_dim = 128 + self.qheadperkv = qheadperkv + self.use_q_gather4 = qheadperkv in (4, 2, 1) + if qheadperkv not in (16, 8, 4, 2, 1): + raise ValueError( + "SparseAttentionForwardSm100 supports qheadperkv in " + f"{{1, 2, 4, 8, 16}}, got {qheadperkv}" + ) + self.tokens_per_gather4 = 4 // qheadperkv if self.use_q_gather4 else 0 + self.m_block_size = m_block_size # 128 packed Q heads + self.n_block_size = n_block_size # 128 KV-block width + self.paged_kv = paged_kv + self.page_size = page_size + self.has_seqused_k = has_seqused_k + self.causal = causal + self.qk_dtype_param = qk_dtype + self.pv_dtype_param = pv_dtype + if not use_prepare_scheduler: + raise ValueError("SparseAttentionForwardSm100 requires prepare scheduler") + self.use_prepare_scheduler = True + if self.paged_kv: + if page_size is None: + raise ValueError("page_size must be provided when paged_kv=True") + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal blk_kv ({n_block_size})" + ) + else: + self.page_size = n_block_size + self.q_tokens_per_group = m_block_size // qheadperkv # 8 + + self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim) + self.mma_tiler_pv = (m_block_size, self.head_dim, n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + + # Pipeline configuration — deeper Q prefetch ring plus 2-slot S/O rings. + self.q_stage = 2 + self.s_stage = 2 + self.o_stage = 2 + self.kv_stage = 1 + # Sparse q_idx metadata ring bridging load -> epilogue. Sized larger + # than the in-flight group distance so epilogue can reuse q_idx + # without rereading mK2qIndices. + self.qidx_meta_stages = 16 + + self.k_stages = 2 + self.q_stage_stride_bytes = m_block_size * self.head_dim * 2 + self.k_tile_stride_bytes = m_block_size * self.k_tile * 2 + self.token_stride_bytes = qheadperkv * self.k_tile * 2 + + # Warp layout: two softmax WGs, one Q-load/epilogue WG, one + # MMA issue warp, two K/V load warps, and one empty warp. + self.warps_per_group = 4 + self.softmax0_warp_base = 0 + self.softmax1_warp_base = self.softmax0_warp_base + self.warps_per_group + self.store_warp_base = self.softmax1_warp_base + self.warps_per_group + self.mma_warp_id = self.store_warp_base + self.warps_per_group + self.load_warp_base = self.mma_warp_id + 1 + self.q_load_warp_base = self.store_warp_base + self.kv_load_warp_base = self.load_warp_base + self.num_kv_load_warps = 2 + self.num_q_load_warps = self.warps_per_group + self.total_warps = 16 + self.threads_per_cta = cute.arch.WARP_SIZE * self.total_warps # 512 + + # TMEM layout follows FA SM100: + # S0/S1: [0:128], [128:256] + # O0/O1: [256:384], [384:512] for hdim_v=128 + # P dtype follows the PV operand policy and is packed into each S tile. + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + self.tmem_s_offset = 0 + self.tmem_stage_stride = n_block_size + self.tmem_o_stage_stride = self.head_dim + self.tmem_o_offset = self.s_stage * n_block_size + self.tmem_s_to_p_offset = n_block_size // 2 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + raw_tmem_total = self.tmem_o_offset + self.o_stage * self.tmem_o_stage_stride + # SM100 TMEM allocation requires a power-of-two column count. The + # 128-wide path naturally uses 512 columns; 64-wide KV blocks use 384 + # columns and must round the allocation up while keeping the same + # logical offsets. + self.tmem_total = 1 << (raw_tmem_total - 1).bit_length() + + # Let PV start once the first 3/4 of P is visible in TMEM. The final + # split is synchronized by a separate mbarrier consumed inside PV MMA. + self.split_P_arrive = n_block_size // 4 * 3 + self.split_P_arrive = int(self.split_P_arrive / 32) * 32 + assert self.split_P_arrive % 32 == 0 + assert self.split_P_arrive < self.n_block_size + + # Register allocation per role. The causal hdim128 split gives the + # epilogue enough room for partial-O/LSE address generation while the + # two softmax WGs still have enough registers to avoid S/P spills. + self.num_regs_softmax = 176 if causal else 192 + self.num_regs_store = 112 if causal else 80 + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_store + self.num_regs_mma = self.num_regs_other + self.num_regs_load = self.num_regs_other + self.num_regs_empty = self.num_regs_other + self.store_reg_decrease = self.num_regs_store <= 128 + self.ex2_emu_freq = 16 if causal else 0 + self.ex2_emu_start_frg = 1 + self.buffer_align_bytes = 1024 + + # SM100 config. + self.use_2cta_instrs = False + self.cta_group_size = 1 + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnk = (1, 1, 1) + + self.arch = BaseDSL._get_dsl().get_arch_enum() + + @cute.jit + def _batch_q_offset( + self, + batch_idx: Int32, + mCuSeqlensQ, + ) -> Int32: + return mCuSeqlensQ[batch_idx] + + @cute.jit + def _logical_seqlen_k( + self, + batch_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + if const_expr(self.has_seqused_k): + return mSeqUsedK[batch_idx] + if const_expr(self.paged_kv): + return Int32(mPageTable.shape[1]) * Int32(self.page_size) + return mCuSeqlensK[batch_idx + Int32(1)] - mCuSeqlensK[batch_idx] + + @cute.jit + def _valid_cols_in_block( + self, + batch_idx: Int32, + kv_block_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + seqlen_k = self._logical_seqlen_k( + batch_idx, mPageTable, mSeqUsedK, mCuSeqlensK + ) + block_start = kv_block_idx * Int32(self.n_block_size) + remaining = seqlen_k - block_start + remaining = cutlass.max(remaining, Int32(0)) + return cutlass.min(remaining, Int32(self.n_block_size)) + + @cute.jit + def _load_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _load_qsplit_idx( + self, + mK2qQSplitIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qQSplitIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _decode_q_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return qsplit & Int32(0x00FF_FFFF) + + @cute.jit + def _decode_split_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return (qsplit >> Int32(24)) & Int32(0xFF) + + @cute.jit + def _lower_bound_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + count: Int32, + q_value: Int32, + ) -> Int32: + left = Int32(0) + right = count + # k2q_q_indices is sorted by q_idx within each CSR row. A fixed + # 32-step loop covers int32-sized rows and keeps this CTA-level. + for _ in cutlass.range(32, unroll=1): + if left < right: + mid = (left + right) // Int32(2) + q_idx = self._load_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + mid, + ) + if q_idx < q_value: + left = mid + Int32(1) + else: + right = mid + return left + + # ------------------------------------------------------------------ + # Host-side: TMA descriptors, SMEM layout, launch + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + mK: cute.Tensor, # Sparse Attention: [total_k, head_kv, dim] / Sparse Page Attention: prepared paged KV tensor + mV: cute.Tensor, # Sparse Attention: [total_k, head_kv, dim] / Sparse Page Attention: prepared paged KV tensor + mK2qIndices: cute.Tensor, # csr payload: [head_kv, nnz] + mK2qQSplitIndices: cute.Tensor, # csr payload: [head_kv, nnz] packed q_idx/split slot + mK2qCounts: cute.Tensor, # csr row_ptr: [head_kv, total_rows + 1] + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, # fp32 O_partial buffer (kept alive) + mLSE_partial: cute.Tensor, # fp32 LSE_partial + mLSE_temperature_partial: Optional[cute.Tensor], # fp32 temperature-scaled LSE_partial + mQ_flat: cute.Tensor, # [batch*Sq*head_q, dim] bf16, pre-flattened + mQ_gather4_desc: Optional[cute.Tensor], # [128] uint8 tensor map for gather4 Q load + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + softmax_scale: Float32, + lse_temperature_scale: Float32, + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + stream=None, + ): + self.q_dtype = mQ_flat.element_type + self.k_input_dtype = mK.element_type + self.v_input_dtype = mV.element_type + self.qk_dtype = ( + self.q_dtype if const_expr(self.qk_dtype_param is None) else self.qk_dtype_param + ) + if const_expr(self.pv_dtype_param is None): + legacy_fp8_kv_cache = ( + self.q_dtype == cutlass.BFloat16 + and self.k_input_dtype == cutlass.Float8E4M3FN + and self.v_input_dtype == cutlass.Float8E4M3FN + ) + self.pv_dtype = cutlass.BFloat16 if legacy_fp8_kv_cache else self.v_input_dtype + else: + self.pv_dtype = self.pv_dtype_param + self.k_dtype = self.qk_dtype + self.v_dtype = self.pv_dtype + self.p_dtype = self.pv_dtype + if const_expr(self.q_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"Unsupported Q/K/V dtype: {self.q_dtype}") + if const_expr(self.qk_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"Unsupported qk_dtype: {self.qk_dtype}") + if const_expr(self.pv_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"Unsupported pv_dtype: {self.pv_dtype}") + if const_expr(self.q_dtype != self.qk_dtype): + raise TypeError("Q storage dtype must match qk_dtype") + if const_expr( + self.k_input_dtype != self.k_dtype + and not (self.k_input_dtype == cutlass.Float8E4M3FN and self.k_dtype == cutlass.BFloat16) + ): + raise TypeError("Only FP8 K -> BF16 QK staging is supported") + if const_expr( + self.v_input_dtype != self.v_dtype + and not (self.v_input_dtype == cutlass.Float8E4M3FN and self.v_dtype == cutlass.BFloat16) + ): + raise TypeError("Only FP8 V -> BF16 PV staging is supported") + self.k_fp8_to_bf16 = ( + self.k_input_dtype == cutlass.Float8E4M3FN + and self.k_dtype == cutlass.BFloat16 + ) + self.v_fp8_to_bf16 = ( + self.v_input_dtype == cutlass.Float8E4M3FN + and self.v_dtype == cutlass.BFloat16 + ) + self.kv_fp8_to_bf16 = self.k_fp8_to_bf16 or self.v_fp8_to_bf16 + self.qk_mma_kind = "f8f6f4" if const_expr(self.qk_dtype.width == 8) else "f16" + self.pv_mma_kind = "f8f6f4" if const_expr(self.pv_dtype.width == 8) else "f16" + elem_bytes = const_expr(self.q_dtype.width // 8) + self.q_stage_stride_bytes = self.m_block_size * self.head_dim * elem_bytes + self.k_tile_stride_bytes = self.m_block_size * self.k_tile * elem_bytes + self.token_stride_bytes = self.qheadperkv * self.k_tile * elem_bytes + p_cols_as_fp32 = const_expr(self.n_block_size * self.p_dtype.width // Float32.width) + self.tmem_s_to_p_offset = self.n_block_size - p_cols_as_fp32 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + self.o_dtype = mO_partial.element_type + if const_expr( + self.o_dtype + not in [Float32, cutlass.BFloat16, cutlass.Float16, cutlass.Float8E4M3FN] + ): + raise TypeError(f"Unsupported O_partial dtype: {self.o_dtype}") + mK, mV = [assume_tensor_aligned(t) for t in (mK, mV)] + + if const_expr(not self.paged_kv): + # Flat varlen K/V use CUTE-managed TMA descriptors, matching FA: + # K: [total_k, h, d] -> [total_k, d, h]. + # V: [total_k, h, d] -> [d, total_k, h] for MN-major PV. + layout_t = [0, 2, 1] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2])) + else: + # Sparse Page Attention with page-sized blocks can use the blocked + # paged TMA layout directly. Host input is [page, head, token, dim]. + layout_t = [2, 3, 1, 0] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + # V: (s,d,h,b) -> (d,s,h,b) for MN-major + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2, 3])) + + # ------------------------------------------------------------------ + # UTCMMA TiledMma: QK^T and PV + # ------------------------------------------------------------------ + cta_group = tcgen05.CtaGroup.ONE + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, + Float32, cta_group, self.mma_tiler_qk[:2]) + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN, + Float32, cta_group, self.mma_tiler_pv[:2], tcgen05.OperandSource.TMEM) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM layouts: sQ/sK/sV only. O_partial is written directly from + # registers to GMEM in the epilogue. + # ------------------------------------------------------------------ + total_q_stages = self.q_stage + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, total_q_stages) + q_load_tile = ( + self.head_dim + if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_tile + ) + q_load_subtiles_per_token = const_expr(self.head_dim // q_load_tile) + num_subtiles_total = ( + total_q_stages * self.q_tokens_per_group * q_load_subtiles_per_token + ) + sQ_load_layout = sm100_utils.make_smem_layout( + tcgen05.OperandMajorMode.K, + (self.qheadperkv, q_load_tile), self.q_dtype, num_subtiles_total) + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage) + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage) + sK_fp8_layout = cute.append( + cute.make_layout( + (self.n_block_size, self.head_dim), + stride=(self.head_dim, 1), + ), + cute.make_layout((1,)), + ) + sV_fp8_layout = cute.append( + cute.make_layout( + (self.head_dim, self.n_block_size), + stride=(1, self.head_dim), + ), + cute.make_layout((1,)), + ) + # P SMEM layout metadata (no actual SMEM allocation — P lives in TMEM, + # overlaying the S region; this layout is only used to compute the PV + # A-operand TMEM descriptor shape at the MMA issue site.) + tP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, self.p_dtype, self.s_stage) + + # ------------------------------------------------------------------ + # TMA atoms + # ------------------------------------------------------------------ + k_tma_layout = ( + cute.select(sK_fp8_layout, mode=[0, 1]) + if const_expr(self.k_fp8_to_bf16) + else cute.select(sK_layout, mode=[0, 1, 2]) + ) + v_tma_layout = ( + cute.select(sV_fp8_layout, mode=[0, 1]) + if const_expr(self.v_fp8_to_bf16) + else cute.select(sV_layout, mode=[0, 1, 2]) + ) + kv_tma_bytes = ( + cute.size_in_bytes(self.k_input_dtype, k_tma_layout) + + cute.size_in_bytes(self.v_input_dtype, v_tma_layout)) + q_tma_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + if const_expr(self.k_fp8_to_bf16): + tma_atom_K, mK = cpasync.make_tiled_tma_atom( + tma_load_op, + mK, + cute.select(sK_fp8_layout, mode=[0, 1]), + (self.n_block_size, self.head_dim), + ) + else: + tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mK, cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, tiled_mma_qk, cta_layout_vmnk.shape) + if const_expr(self.v_fp8_to_bf16): + tma_atom_V, mV = cpasync.make_tiled_tma_atom( + tma_load_op, + mV, + cute.select(sV_fp8_layout, mode=[0, 1]), + (self.head_dim, self.n_block_size), + ) + else: + tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mV, cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_pv, tiled_mma_pv, cta_layout_vmnk.shape) + + # Q per-sub-tile TMA atom: bf16 uses two 64-element halves; fp8 uses + # one 128-element row because 128 fp8 elements occupy the same 128B + # swizzle span as 64 bf16 elements. + mQ_flat = assume_tensor_aligned(mQ_flat) + mQ_2d = cute.make_tensor( + mQ_flat.iterator, cute.select(mQ_flat.layout, mode=[0, 1])) + if const_expr(self.use_q_gather4): + # Placeholder atom for unified kernel signature. Small-GQA Q load + # uses raw gather4 and keeps mQ_2d as a plain row-major GMEM tensor. + tma_atom_Q = tma_atom_V + else: + tma_atom_Q, mQ_2d = cpasync.make_tiled_tma_atom( + tma_load_op, mQ_2d, + cute.select(sQ_load_layout, mode=[0, 1]), + (self.qheadperkv, q_load_tile)) + q_subtile_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_load_layout, mode=[0, 1])) + + softmax_scale_log2 = softmax_scale * Float32(math.log2(math.e)) + lse_temperature_scale_log2 = softmax_scale_log2 * lse_temperature_scale + + # ------------------------------------------------------------------ + # SharedStorage — lean: just the mbars and tiles we actually use. + # + # Mbarriers (all storage rings stay below the 64-per-CTA limit): + # mbar_kv [2] one-shot K/V load handshake (full + empty) + # mbar_q [q_stage * 2] Q producer/consumer ring + # mbar_s [2] QK UTCMMA -> softmax (full + empty) + # mbar_o [2] PV UTCMMA -> epilogue (full + empty) + # mbar_p [s_stage * 2] softmax early-P arrive -> PV + # mbar_p_lastsplit [s_stage * 2] softmax final-P arrive -> PV + # (used only when ``self.split_P_arrive > 0``) + # mbar_sm_stats [s_stage * 2] softmax row_sum/row_max + # publish -> epilogue consumer read. In lean + # 1-WG topology the producer and consumer are + # the same 128 WG_C threads, but we keep the + # barrier for structural parity with FA so the + # softmax body reads identically. + # ------------------------------------------------------------------ + @cute.struct + class SharedStorage: + mbar_k: cute.struct.MemRange[Int64, 2] + mbar_v: cute.struct.MemRange[Int64, 2] + if const_expr(self.k_fp8_to_bf16): + mbar_k_tma: cute.struct.MemRange[Int64, 2] + if const_expr(self.v_fp8_to_bf16): + mbar_v_tma: cute.struct.MemRange[Int64, 2] + mbar_q: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_s: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p_lastsplit: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_o: cute.struct.MemRange[Int64, self.o_stage * 2] + mbar_sm_stats: cute.struct.MemRange[Int64, self.o_stage * 2] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: Int32 + # Per-row softmax stats cache (for epilogue LSE + rescale): + # [0 : m_block_size) row_sum + # [m_block_size : 2*m_block_size) row_max + sScale: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size * 2] + # Per-row temperature LSE row_sum cache. The row_max is shared with + # sScale because lse_temperature_scale is positive. + sScaleTemperature: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size] + # Per-token split_id from prepare-time per-edge metadata. + sSplitIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Per-token q_idx cache to avoid reloading sparse indices in epilogue. + sQIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Prefix length of q_idx-sorted row entries that may need causal + # masking for this KV block. This is CTA-level metadata, not a + # token-count cap. + sDiagQCount: cute.struct.MemRange[Int32, 1] + # CTA-wide row metadata, published once by tidx 0 and reused by + # all warp-specialized roles: + # [0] batch_idx + # [1] kv_block_idx + # [2] row_start + # [3] count_raw + # [4] kv_valid_cols + # [5] q_batch_offset + # [6] k_batch_offset + # [7] causal_q_offset = seqlen_k - seqlen_q + sRowMeta: cute.struct.MemRange[Int32, 8] + sPagedKvIdx: cute.struct.MemRange[Int32, 1] + sQLoadMIdx: cute.struct.MemRange[ + Int32, self.q_stage * self.q_tokens_per_group] + # Packed per-edge q/split metadata: + # low 24 bits = q_idx, high 8 bits = split slot. + sQIdxMeta: cute.struct.MemRange[ + Int32, self.qidx_meta_stages * self.q_tokens_per_group] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes] + if const_expr(self.k_fp8_to_bf16): + sKFp8: cute.struct.Align[ + cute.struct.MemRange[ + self.k_input_dtype, cute.cosize(sK_fp8_layout) + ], + self.buffer_align_bytes] + if const_expr(self.v_fp8_to_bf16): + sVFp8: cute.struct.Align[ + cute.struct.MemRange[ + self.v_input_dtype, cute.cosize(sV_fp8_layout) + ], + self.buffer_align_bytes] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes] + + self.shared_storage = SharedStorage + num_ctas = work_capacity + + self.kernel( + mK, mV, mK2qIndices, mK2qQSplitIndices, mK2qCounts, + mSchedulerMetadata, mWorkCount, + mO_partial, mLSE_partial, mLSE_temperature_partial, mQ_2d, mQ_gather4_desc, + mPageTable, mSeqUsedK, mCuSeqlensQ, mCuSeqlensK, + softmax_scale_log2, lse_temperature_scale_log2, lse_temperature_scale, + sQ_layout, sQ_load_layout, sK_layout, sV_layout, + sK_fp8_layout, sV_fp8_layout, tP_layout, + tma_atom_K, tma_atom_V, tma_atom_Q, + tiled_mma_qk, tiled_mma_pv, + kv_tma_bytes, q_tma_bytes, q_subtile_bytes, + num_kv_blocks, num_heads_kv, seq_len_q, work_capacity, + ).launch( + grid=(num_ctas,), + block=[self.threads_per_cta, 1, 1], + smem=max(SharedStorage.size_in_bytes(), 49152), + stream=stream, + min_blocks_per_mp=1, + ) + + # ------------------------------------------------------------------ + # Device-side: kernel entry, dispatch by warpgroup + # ------------------------------------------------------------------ + + @cute.kernel + def kernel( + self, + # Runtime tensors + tma_K: cute.Tensor, + tma_V: cute.Tensor, + mK2qIndices: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mK2qCounts: cute.Tensor, + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + mQ_2d: cute.Tensor, + mQ_gather4_desc: Optional[cute.Tensor], + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + # Scalars + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + # Layouts + sQ_layout: cute.ComposedLayout, + sQ_load_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sK_fp8_layout: cute.Layout, + sV_fp8_layout: cute.Layout, + tP_layout: cute.ComposedLayout, + # TMA atoms + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_Q: cute.CopyAtom, + # MMA + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + # Transfer sizes + kv_tma_bytes: cutlass.Constexpr[int], + q_tma_bytes: cutlass.Constexpr[int], + q_subtile_bytes: cutlass.Constexpr[int], + # Iteration bounds + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + ): + # ------------------------------------------------------------------ + # Thread / warp identity, CTA coordinate + # ------------------------------------------------------------------ + bidx, _, _ = cute.arch.block_idx() + row_linear = Int32(0) + head_kv_idx = Int32(0) + batch_idx = Int32(0) + kv_block_idx = Int32(0) + work_q_begin = Int32(0) + work_q_count = Int32(0) + cta_valid_work = True + work_idx = bidx + cta_valid_work = work_idx < mWorkCount[Int32(0)] + if cta_valid_work: + head_kv_idx = mSchedulerMetadata[work_idx, Int32(0)] + row_linear = mSchedulerMetadata[work_idx, Int32(1)] + work_q_begin = mSchedulerMetadata[work_idx, Int32(2)] + work_q_count = mSchedulerMetadata[work_idx, Int32(3)] + batch_idx = mSchedulerMetadata[work_idx, Int32(4)] + kv_block_idx = mSchedulerMetadata[work_idx, Int32(5)] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + head_q = num_heads_kv * Int32(self.qheadperkv) + paged_kv_manager = ( + PagedKVManager.create( + mPageTable, + page_size=self.page_size, + n_block_size=self.n_block_size, + ) + if const_expr(self.paged_kv) + else None + ) + + # Prefetch TMA descriptors (warp 0 once). + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + if const_expr(not self.use_q_gather4): + cpasync.prefetch_descriptor(tma_atom_Q) + else: + with cute.arch.elect_one(): + prefetch_tma_desc_raw(mQ_gather4_desc.iterator) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM allocation (all warps — same SharedStorage type from __call__) + # ------------------------------------------------------------------ + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + if const_expr(self.k_fp8_to_bf16): + sKFp8 = storage.sKFp8.get_tensor(sK_fp8_layout) + mbar_k_tma_ptr = storage.mbar_k_tma.data_ptr() + if const_expr(self.v_fp8_to_bf16): + sVFp8 = storage.sVFp8.get_tensor(sV_fp8_layout) + mbar_v_tma_ptr = storage.mbar_v_tma.data_ptr() + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sQ_load = storage.sQ.get_tensor(sQ_load_layout.outer, swizzle=sQ_load_layout.inner) + sScale = storage.sScale.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size * 2)) + sScaleTemperature = storage.sScaleTemperature.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size)) + sSplitIdx = storage.sSplitIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sQIdx = storage.sQIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sDiagQCount = storage.sDiagQCount.get_tensor(cute.make_layout((1,))) + sRowMeta = storage.sRowMeta.get_tensor(cute.make_layout((8,))) + sPagedKvIdx = storage.sPagedKvIdx.get_tensor(cute.make_layout((1,))) + sQLoadMIdx = storage.sQLoadMIdx.get_tensor( + cute.make_layout((self.q_stage * self.q_tokens_per_group,))) + sQIdxMeta = storage.sQIdxMeta.get_tensor( + cute.make_layout((self.qidx_meta_stages * self.q_tokens_per_group,))) + mbar_k_ptr = storage.mbar_k.data_ptr() + mbar_v_ptr = storage.mbar_v.data_ptr() + + # ------------------------------------------------------------------ + # TMEM allocator — allocator warp 0 serves the whole CTA. + # ------------------------------------------------------------------ + tmem_alloc_warps: cutlass.Constexpr[int] = self.warps_per_group * 2 + 1 + tmem_alloc_threads = cute.arch.WARP_SIZE * tmem_alloc_warps + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), + num_threads=tmem_alloc_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # ------------------------------------------------------------------ + # Warp-specialized pipelines. + # ------------------------------------------------------------------ + ThreadCooperativeGroup = partial( + cutlass_pipeline.CooperativeGroup, cutlass_pipeline.Agent.Thread) + tma_thread = ThreadCooperativeGroup(1) + mma_thread = ThreadCooperativeGroup(1) + softmax_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group) + epilogue_threads = softmax_threads + + pipeline_q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=q_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_s = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_s.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=softmax_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p_lastsplit = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p_lastsplit.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_o = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_o.data_ptr(), + num_stages=self.o_stage, + producer_group=mma_thread, + consumer_group=epilogue_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_sm_stats = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_sm_stats.data_ptr(), + num_stages=self.o_stage, + producer_group=softmax_threads, + consumer_group=epilogue_threads, + defer_sync=True, + ) + # Cluster sync (no-op for 1CTA cluster). + pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + # ------------------------------------------------------------------ + # Work count: how many Q tokens reference this CTA's KV block + # ------------------------------------------------------------------ + k_tma_bytes = cute.size_in_bytes( + self.k_input_dtype, + cute.select(sK_fp8_layout, mode=[0, 1]) + if const_expr(self.k_fp8_to_bf16) + else cute.select(sK_layout, mode=[0, 1, 2])) + v_tma_bytes = cute.size_in_bytes( + self.v_input_dtype, + cute.select(sV_fp8_layout, mode=[0, 1]) + if const_expr(self.v_fp8_to_bf16) + else cute.select(sV_layout, mode=[0, 1, 2])) + if tidx == 0: + row_batch_idx = batch_idx + row_kv_block_idx = kv_block_idx + base_row_start = mK2qCounts[head_kv_idx, row_linear] + row_start = base_row_start + count_raw = ( + mK2qCounts[head_kv_idx, row_linear + Int32(1)] + - base_row_start + ) + row_start = base_row_start + work_q_begin + count_raw = work_q_count + kv_valid_cols = ( + self._valid_cols_in_block( + row_batch_idx, + row_kv_block_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + ) + q_batch_offset = self._batch_q_offset( + row_batch_idx, mCuSeqlensQ + ) + k_batch_offset = ( + Int32(0) + if const_expr(self.paged_kv) + else mCuSeqlensK[row_batch_idx] + ) + sRowMeta[0] = row_batch_idx + sRowMeta[1] = row_kv_block_idx + sRowMeta[2] = row_start + sRowMeta[3] = count_raw + sRowMeta[4] = kv_valid_cols + sRowMeta[5] = q_batch_offset + sRowMeta[6] = k_batch_offset + causal_q_offset = Int32(0) + if const_expr(self.causal): + seqlen_q = mCuSeqlensQ[row_batch_idx + Int32(1)] - q_batch_offset + seqlen_k = self._logical_seqlen_k( + row_batch_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + causal_q_offset = seqlen_k - seqlen_q + sRowMeta[7] = causal_q_offset + if const_expr(self.paged_kv): + sPagedKvIdx[0] = paged_kv_manager.physical_block_index( + row_batch_idx, row_kv_block_idx + ) + cute.arch.mbarrier_init(mbar_k_ptr, 1) + cute.arch.mbarrier_init(mbar_v_ptr, 1) + if const_expr(self.k_fp8_to_bf16): + cute.arch.mbarrier_init(mbar_k_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_k_tma_ptr, k_tma_bytes) + else: + cute.arch.mbarrier_expect_tx(mbar_k_ptr, k_tma_bytes) + if const_expr(self.v_fp8_to_bf16): + cute.arch.mbarrier_init(mbar_v_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_v_tma_ptr, v_tma_bytes) + else: + cute.arch.mbarrier_expect_tx(mbar_v_ptr, v_tma_bytes) + diag_q_count = Int32(0) + if const_expr(self.causal): + row_has_visible_cols = (count_raw > Int32(0)) & (kv_valid_cols > Int32(0)) + if row_has_visible_cols: + kv_valid_end = ( + row_kv_block_idx * Int32(self.n_block_size) + + kv_valid_cols + ) + q_threshold = kv_valid_end - causal_q_offset + diag_q_count = self._lower_bound_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + count_raw, + q_threshold, + ) + sDiagQCount[0] = diag_q_count + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + thr_mma_qk = tiled_mma_qk.get_slice(0) + thr_mma_pv = tiled_mma_pv.get_slice(0) + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_base = thr_mma_qk.make_fragment_C(qk_acc_shape) + tStS = cute.make_tensor( + tStS_base.iterator, + cute.append( + tStS_base.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + tP_width_ratio = Float32.width // self.v_dtype.width + tP_stage_stride = self.tmem_stage_stride * tP_width_ratio + tOrP = cute.make_tensor( + tOrP.iterator + self.tmem_p_offset * tP_width_ratio, + cute.append( + tOrP.layout, + cute.make_layout((self.s_stage,), stride=(tP_stage_stride,)), + ), + ) + + tmem_cols = self.tmem_total + + load_wg_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.LoadWG), + num_threads=cute.arch.WARP_SIZE * self.num_q_load_warps, + ) + if const_expr(self.kv_fp8_to_bf16): + kv_load_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvLoad), + num_threads=cute.arch.WARP_SIZE * self.num_kv_load_warps, + ) + if const_expr(self.k_fp8_to_bf16): + kv_dequant_k_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantK), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + if const_expr(self.v_fp8_to_bf16): + kv_dequant_v_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantV), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + sm_stats_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), + num_threads=cute.arch.WARP_SIZE * 2, + ) + epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.StoreEpilogue), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + if warp_idx == Int32(self.total_warps - 1): + cute.arch.setmaxregister_decrease(self.num_regs_empty) + + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + q_load_thread_end = Int32( + (self.q_load_warp_base + self.num_q_load_warps) * cute.arch.WARP_SIZE + ) + is_q_load_thread = tidx >= q_load_thread_base and tidx < q_load_thread_end + if is_q_load_thread and cta_valid_work: + if self.store_reg_decrease: + cute.arch.setmaxregister_decrease(self.num_regs_store) + else: + cute.arch.setmaxregister_increase(self.num_regs_store) + row_start_load = sRowMeta[2] + count_raw_load = sRowMeta[3] + q_batch_offset_load = sRowMeta[5] + # Do not gate on KV validity here; sparse entries past seqused_k + # still need the all-masked path to produce neutral partials. + has_work_load = count_raw_load > Int32(0) + num_q_groups_load = ( + count_raw_load + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + if const_expr(self.use_q_gather4): + self._wg_load_q_gather4( + mQ_2d, + mQ_gather4_desc, + mK2qQSplitIndices, + sQIdxMeta, + sQ, + pipeline_q, + load_wg_barrier, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + ) + else: + self._wg_load_q_tma( + tma_atom_Q, + mQ_2d, + mK2qQSplitIndices, + sQLoadMIdx, + sQIdxMeta, + sQ_load, + pipeline_q, + load_wg_barrier, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + ) + + if ( + warp_idx >= Int32(self.kv_load_warp_base) + and warp_idx < Int32(self.kv_load_warp_base + self.num_kv_load_warps) + and cta_valid_work + ): + cute.arch.setmaxregister_decrease(self.num_regs_load) + kv_block_idx_load = sRowMeta[1] + k_batch_offset_load = sRowMeta[6] + has_work_load = sRowMeta[3] > Int32(0) + if const_expr(self.kv_fp8_to_bf16): + self._wg_load_kv_maybe_cast( + tma_atom_K, tma_atom_V, + tma_K, tma_V, + sPagedKvIdx, + sK, sV, + sKFp8 if const_expr(self.k_fp8_to_bf16) else None, + sVFp8 if const_expr(self.v_fp8_to_bf16) else None, + tiled_mma_qk, tiled_mma_pv, + mbar_k_ptr, mbar_v_ptr, + mbar_k_tma_ptr if const_expr(self.k_fp8_to_bf16) else None, + mbar_v_tma_ptr if const_expr(self.v_fp8_to_bf16) else None, + kv_load_barrier, + has_work_load, + head_kv_idx, kv_block_idx_load, + k_batch_offset_load, + ) + else: + self._wg_load_kv( + tma_atom_K, tma_atom_V, + tma_K, tma_V, + sPagedKvIdx, + sK, sV, + tiled_mma_qk, tiled_mma_pv, + mbar_k_ptr, mbar_v_ptr, + has_work_load, + head_kv_idx, kv_block_idx_load, + k_batch_offset_load, + ) + + if warp_idx == Int32(self.mma_warp_id) and cta_valid_work: + cute.arch.setmaxregister_decrease(self.num_regs_mma) + count_raw_mma = sRowMeta[3] + has_work_mma = count_raw_mma > Int32(0) + num_q_groups_mma = ( + count_raw_mma + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + tmem.allocate(tmem_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_mma_issue( + tiled_mma_qk, tiled_mma_pv, + thr_mma_qk, thr_mma_pv, + tStS, tOrP, + sK, sV, sQ, + pipeline_q, pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + mbar_k_ptr, mbar_v_ptr, + num_q_groups_mma, has_work_mma, + ) + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr, num_columns=tmem_cols) + cute.arch.griddepcontrol_launch_dependents() + + if ( + warp_idx >= Int32(self.softmax0_warp_base) + and warp_idx < Int32(self.softmax1_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + if const_expr(self.k_fp8_to_bf16): + self._wg_convert_fp8_kv_to_bf16_smem( + sKFp8, + sK, + mbar_k_tma_ptr, + mbar_k_ptr, + kv_dequant_k_barrier, + has_work_softmax, + self.softmax0_warp_base, + self.warps_per_group, + False, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 0, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + + if ( + warp_idx >= Int32(self.softmax1_warp_base) + and warp_idx < Int32(self.store_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + if const_expr(self.v_fp8_to_bf16): + self._wg_convert_fp8_kv_to_bf16_smem( + sVFp8, + sV, + mbar_v_tma_ptr, + mbar_v_ptr, + kv_dequant_v_barrier, + has_work_softmax, + self.softmax1_warp_base, + self.warps_per_group, + True, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 1, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + # ------------------------------------------------------------------ + # Warp-specialized helpers + # ------------------------------------------------------------------ + + @cute.jit + def _convert_fp8x16_to_bf16x16( + self, + src: cute.Tensor, + dst: cute.Tensor, + ): + src_i32 = cute.recast_tensor(src, cutlass.Int32) + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + for word_idx in cutlass.range_constexpr(4): + ( + dst_i32[word_idx * 2], + dst_i32[word_idx * 2 + 1], + ) = utils.cvt_fp8x4_e4m3_bf16x4(src_i32[word_idx]) + + @cute.jit + def _convert_fp8_kv_to_bf16_smem( + self, + sFp8: cute.Tensor, + sBf16: cute.Tensor, + lane: Int32, + warp_idx_in_wg: Int32, + num_dequant_warps: cutlass.Constexpr[int], + is_v: cutlass.Constexpr[bool], + ): + elems_per_load: cutlass.Constexpr[int] = 16 + elems_per_store: cutlass.Constexpr[int] = 8 + chunks_per_row: cutlass.Constexpr[int] = self.head_dim // elems_per_load + r_fp8 = cute.make_rmem_tensor((elems_per_load,), cutlass.Float8E4M3FN) + r_bf16 = cute.make_rmem_tensor((elems_per_load,), cutlass.BFloat16) + total_tasks: cutlass.Constexpr[int] = self.n_block_size * chunks_per_row + task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(task, total_tasks, task_stride, unroll=1): + row = task_idx // Int32(chunks_per_row) + chunk = task_idx - row * Int32(chunks_per_row) + col = chunk * Int32(elems_per_load) + smem_offset = row * Int32(self.head_dim) + col + s_fp8_ptr = cute.make_ptr( + cutlass.Float8E4M3FN, + sFp8.iterator.toint() + Int64(smem_offset), + mem_space=sFp8.iterator.memspace, + assumed_align=elems_per_load, + ) + s_fp8_vec = cute.make_tensor( + s_fp8_ptr, + cute.make_layout(elems_per_load), + ) + cute.autovec_copy(s_fp8_vec, r_fp8) + self._convert_fp8x16_to_bf16x16(r_fp8, r_bf16) + if const_expr(is_v): + sBf16_view = sBf16[(None, row % Int32(16)), 0, row // Int32(16), 0] + sBf16_vec = cute.local_tile(sBf16_view, (elems_per_load,), (chunk,)) + else: + sBf16_vec = sBf16[ + (row, None), + 0, + (chunk % Int32(4), chunk // Int32(4)), + 0, + ] + r_tiles = cute.logical_divide(r_bf16, cute.make_layout(elems_per_store)) + s_tiles = cute.logical_divide(sBf16_vec, cute.make_layout(elems_per_store)) + for v in cutlass.range_constexpr(elems_per_load // elems_per_store): + cute.autovec_copy(r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + + @cute.jit + def _wg_load_kv_maybe_cast( + self, + tma_atom_K, + tma_atom_V, + tma_K: cute.Tensor, + tma_V: cute.Tensor, + sPagedKvIdx: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sKFp8: Optional[cute.Tensor], + sVFp8: Optional[cute.Tensor], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mbar_k_ptr, + mbar_v_ptr, + mbar_k_tma_ptr: Optional[cutlass.Pointer], + mbar_v_tma_ptr: Optional[cutlass.Pointer], + kv_load_barrier, + has_work: Int32, + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(self.kv_load_warp_base) + + if has_work: + if warp_idx_in_wg == Int32(0): + if const_expr(self.k_fp8_to_bf16): + if const_expr(self.paged_kv): + mK_cur = tma_K[None, None, head_kv_idx, sPagedKvIdx[0]] + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim), + (None, 0), + ) + src_idx = Int32(0) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + tma_K[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim), + (None, 0), + ) + src_idx = kv_block_idx + load_K_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, + 0, + cute.make_layout(1), + gK, + sKFp8, + ) + load_K_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_k_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_tma_ptr) + else: + thr_mma_qk = tiled_mma_qk.get_slice(0) + if const_expr(self.paged_kv): + mK_cur = tma_K[None, None, head_kv_idx, None] + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0, None), + ) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + tma_K[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0), + ) + tSgK = thr_mma_qk.partition_B(gK) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + gmem_k_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_K, + tKgK[(None, 0, gmem_k_idx)] + if const_expr(self.paged_kv) + else tKgK[(None, gmem_k_idx)], + tKsK[(None, 0)], + tma_bar_ptr=mbar_k_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_ptr) + + if warp_idx_in_wg == Int32(1): + if const_expr(self.v_fp8_to_bf16): + if const_expr(self.paged_kv): + mV_cur = tma_V[None, None, head_kv_idx, sPagedKvIdx[0]] + gV = cute.local_tile( + mV_cur, + (self.head_dim, self.n_block_size), + (0, None), + ) + src_idx = Int32(0) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + tma_V[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + (self.head_dim, self.n_block_size), + (0, None), + ) + src_idx = kv_block_idx + load_V_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, + 0, + cute.make_layout(1), + gV, + sVFp8, + ) + load_V_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_v_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_tma_ptr) + else: + thr_mma_pv = tiled_mma_pv.get_slice(0) + if const_expr(self.paged_kv): + mV_cur = tma_V[None, None, head_kv_idx, None] + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None, None), + ) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + tma_V[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None), + ) + tOgV = thr_mma_pv.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + gmem_v_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_V, + tVgV[(None, 0, gmem_v_idx)] + if const_expr(self.paged_kv) + else tVgV[(None, gmem_v_idx)], + tVsV[(None, 0)], + tma_bar_ptr=mbar_v_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_ptr) + + kv_load_barrier.arrive_and_wait() + + @cute.jit + def _wg_convert_fp8_kv_to_bf16_smem( + self, + sFp8: cute.Tensor, + sBf16: cute.Tensor, + mbar_tma_ptr, + mbar_ready_ptr, + dequant_barrier, + has_work: Int32, + dequant_warp_base: cutlass.Constexpr[int], + num_dequant_warps: cutlass.Constexpr[int], + is_v: cutlass.Constexpr[bool], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(dequant_warp_base) + lane = cute.arch.lane_idx() + if has_work: + cute.arch.mbarrier_wait(mbar_tma_ptr, 0) + self._convert_fp8_kv_to_bf16_smem( + sFp8, + sBf16, + lane, + warp_idx_in_wg, + num_dequant_warps, + is_v, + ) + dequant_barrier.arrive_and_wait() + if warp_idx_in_wg == Int32(0): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_ready_ptr) + + @cute.jit + def _wg_load_kv( + self, + tma_atom_K, + tma_atom_V, + tma_K: cute.Tensor, + tma_V: cute.Tensor, + sPagedKvIdx: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mbar_k_ptr, + mbar_v_ptr, + has_work: Int32, + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(self.kv_load_warp_base) + + if has_work: + if warp_idx_in_wg == Int32(0): + thr_mma_qk = tiled_mma_qk.get_slice(0) + if const_expr(self.paged_kv): + mK_cur = tma_K[None, None, head_kv_idx, None] + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0, None), + ) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + tma_K[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0), + ) + tSgK = thr_mma_qk.partition_B(gK) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + gmem_k_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_K, + tKgK[(None, 0, gmem_k_idx)] + if const_expr(self.paged_kv) + else tKgK[(None, gmem_k_idx)], + tKsK[(None, 0)], + tma_bar_ptr=mbar_k_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_ptr) + + if warp_idx_in_wg == Int32(1): + thr_mma_pv = tiled_mma_pv.get_slice(0) + if const_expr(self.paged_kv): + mV_cur = tma_V[None, None, head_kv_idx, None] + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None, None), + ) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + tma_V[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None), + ) + tOgV = thr_mma_pv.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + gmem_v_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_V, + tVgV[(None, 0, gmem_v_idx)] + if const_expr(self.paged_kv) + else tVgV[(None, gmem_v_idx)], + tVsV[(None, 0)], + tma_bar_ptr=mbar_v_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_ptr) + + @cute.jit + def _wg_load_q_gather4( + self, + mQ_2d: cute.Tensor, + mQ_gather4_desc: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + load_wg_barrier, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + ): + tidx = cute.arch.thread_idx()[0] + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + group_tidx = tidx - q_load_thread_base + producer_warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + gathers_per_warp: cutlass.Constexpr[int] = self.m_block_size // ( + self.num_q_load_warps * 4) + + if has_work: + for qi_group in cutlass.range(num_q_groups, unroll=1): + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if producer_warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + group_tidx = tidx - q_load_thread_base + warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + lane_idx = group_tidx % Int32(cute.arch.WARP_SIZE) + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + meta_iters: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + + self.num_q_load_warps * cute.arch.WARP_SIZE - 1) + // (self.num_q_load_warps * cute.arch.WARP_SIZE) + ) + for meta_iter in cutlass.range_constexpr(meta_iters): + tok_idx_g4 = ( + (Int32(meta_iter) * Int32(self.num_q_load_warps) + + warp_idx_in_wg) + * Int32(cute.arch.WARP_SIZE) + + lane_idx + ) + if tok_idx_g4 < Int32(self.q_tokens_per_group): + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx_g4 + if qi < count_raw: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = Int32(0) + load_wg_barrier.arrive_and_wait() + + with cute.arch.elect_one(): + q_desc_ptr = mQ_gather4_desc.iterator + sQ_ptr = sQ.iterator + for gather_slot in cutlass.range_constexpr(gathers_per_warp): + gather_idx = ( + Int32(gather_slot) * Int32(self.num_q_load_warps) + + warp_idx_in_wg + ) + tok_base = gather_idx * Int32(self.tokens_per_gather4) + if const_expr(self.qheadperkv == 1): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + qi2 = qi0 + Int32(2) + qi3 = qi0 + Int32(3) + row0 = q_oob_m_idx + row1 = q_oob_m_idx + row2 = q_oob_m_idx + row3 = q_oob_m_idx + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row0 = (q_batch_offset + q_idx0) * num_heads_kv + head_kv_idx + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row1 = (q_batch_offset + q_idx1) * num_heads_kv + head_kv_idx + if qi2 < count_raw: + q_idx2 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(2)]) + row2 = (q_batch_offset + q_idx2) * num_heads_kv + head_kv_idx + if qi3 < count_raw: + q_idx3 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(3)]) + row3 = (q_batch_offset + q_idx3) * num_heads_kv + head_kv_idx + elif const_expr(self.qheadperkv == 2): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + row_base1 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row_base1 = ( + (q_batch_offset + q_idx1) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base1 + row3 = row_base1 + Int32(1) + else: + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base0 + Int32(2) + row3 = row_base0 + Int32(3) + group_byte_off = gather_idx * Int32( + 4 * self.k_tile * (self.q_dtype.width // 8) + ) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + stage_byte_off = slot * Int32(self.q_stage_stride_bytes) + full_group_byte_off = gather_idx * Int32( + 4 * self.head_dim * (self.q_dtype.width // 8) + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + full_group_byte_off, + q_desc_ptr, + Int32(0), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + else: + for ks_c in cutlass.range_constexpr(self.k_stages): + stage_idx = slot * Int32(self.k_stages) + Int32(ks_c) + stage_byte_off = stage_idx * Int32(self.k_tile_stride_bytes) + if const_expr(ks_c + 1 < self.k_stages): + tma_gather4_prefetch( + q_desc_ptr, + Int32((ks_c + 1) * self.k_tile), + row0, + row1, + row2, + row3, + TMA_CACHE_EVICT_LAST, + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + group_byte_off, + q_desc_ptr, + Int32(ks_c * self.k_tile), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + load_wg_barrier.arrive_and_wait() + + if producer_warp_idx_in_wg == Int32(0): + next_slot = num_q_groups % Int32(self.q_stage) + next_phase = ( + (num_q_groups // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_load_q_tma( + self, + tma_atom_Q, + mQ_2d: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQLoadMIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ_load: cute.Tensor, + pipeline_q, + load_wg_barrier, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + lane_idx = cute.arch.lane_idx() + warp_idx_in_wg = warp_idx - Int32(self.q_load_warp_base) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + gQ_full = cute.local_tile( + mQ_2d, (self.qheadperkv, self.head_dim), (None, 0)) + load_Q_fn_full, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_full, sQ_load) + load_Q_fn_k0, load_Q_fn_k1 = None, None + else: + gQ_k0 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 0)) + gQ_k1 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 1)) + load_Q_fn_k0, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k0, sQ_load) + load_Q_fn_k1, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k1, sQ_load) + load_Q_fn_full = None + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + tokens_per_warp: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + self.num_q_load_warps - 1) + // self.num_q_load_warps + ) + + if has_work: + for qi_group in cutlass.range(num_q_groups, unroll=1): + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + q_load_subtiles_per_token = ( + 1 if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_stages + ) + sub_stage_base = slot * Int32( + self.q_tokens_per_group * q_load_subtiles_per_token) + load_meta_slot = slot * Int32(self.q_tokens_per_group) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + if warp_idx_in_wg == Int32(0) and lane_idx < Int32(self.q_tokens_per_group): + tok_idx = lane_idx + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx + if qi < count_raw: + qsplit = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + q_idx = self._decode_q_idx_from_qsplit(qsplit) + q_abs = q_batch_offset + q_idx + sQIdxMeta[qidx_meta_slot + tok_idx] = qsplit + sQLoadMIdx[load_meta_slot + tok_idx] = ( + q_abs * num_heads_kv + head_kv_idx + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx] = Int32(0) + sQLoadMIdx[load_meta_slot + tok_idx] = q_oob_m_idx + load_wg_barrier.arrive_and_wait() + + for qi_slot in cutlass.range_constexpr(tokens_per_warp): + tok_idx = ( + warp_idx_in_wg * Int32(tokens_per_warp) + + Int32(qi_slot) + ) + if tok_idx < Int32(self.q_tokens_per_group): + m_tile_idx = sQLoadMIdx[load_meta_slot + tok_idx] + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + load_Q_fn_full( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + else: + load_Q_fn_k0( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + load_Q_fn_k1( + src_idx=m_tile_idx, + dst_idx=( + sub_stage_base + + Int32(self.q_tokens_per_group) + + tok_idx + ), + tma_bar_ptr=mbar_ptr, + ) + + if warp_idx_in_wg == Int32(0): + next_slot = num_q_groups % Int32(self.q_stage) + next_phase = ( + (num_q_groups // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_mma_issue( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + thr0_qk: cute.core.ThrMma, + thr0_pv: cute.core.ThrMma, + tStS: cute.Tensor, + tOrP: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + mbar_k_ptr, + mbar_v_ptr, + num_q_groups: Int32, + has_work: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + is_mma_warp = warp_idx == Int32(self.mma_warp_id) + + if is_mma_warp: + if has_work: + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tSrQ0 = tSrQ[None, None, None, 0] + tSrK0 = tSrK[None, None, None, 0] + tOrV = tiled_mma_pv.make_fragment_B(sV) + tOrV0 = tOrV[None, None, None, 0] + sV0 = sV[None, None, None, 0] + pv_mma_op = tiled_mma_pv.op + qk_mma_op = tiled_mma_qk.op + q_smem_base = sm100_desc.smem_desc_base_from_tensor( + sQ, sm100_desc.Major.K) + k_smem_base = sm100_desc.smem_desc_base_from_tensor( + sK, sm100_desc.Major.K) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[None, None, None, 0].iterator) + q_smem_start = sm100_desc.make_smem_desc_start_addr( + sQ[None, None, None, self.q_stage - 1].iterator) + sm100_helpers.declare_ptx_smem_desc( + q_smem_start, + q_smem_base, + tSrQ0.layout, + var_name_prefix="lean_q_desc", + ) + sm100_helpers.declare_ptx_idesc( + qk_mma_op, var_name="lean_qk_idesc") + sQ_stage_stride = ( + sQ.layout.stride[-1] * sQ.element_type.width // 8) >> 4 + if const_expr(self.q_stage == 1): + sQ_stage_stride = 0 + q_wrap_offset = -(self.q_stage - 1) * sQ_stage_stride + q_advance_offset = sQ_stage_stride + gemm_qk_s0_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_qk_s0_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_qk_s1_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_qk_s1_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_pv_0 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset), + tOrP[None, None, None, 0], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.pv_mma_kind, + ) + gemm_pv_1 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset + self.tmem_o_stage_stride), + tOrP[None, None, None, 1], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_stage_stride + self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.pv_mma_kind, + ) + + cute.arch.mbarrier_wait(mbar_k_ptr, 0) + # Issue order: + # Q0K, Q1K, P0V, Q2K, P1V, Q3K, ... + # This reuses each slot as soon as its previous PV drains, + # instead of batching both PVs after both QKs of a pair. + # The schedule is still 2-slot safe: + # - QK(qi) consumes slot qi&1 + # - PV(qi-2) frees the same slot before QK(qi) reuses it + # - phases still toggle every 2 groups per slot + + # Prologue: issue up to the first two QK tiles. Q slots come + # from the q_stage ring; S slots remain a 2-slot ring. + pipeline_q.consumer_wait_w_index_phase(Int32(0), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(0), Int32(1)) + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(0)) + pipeline_q.consumer_release_w_index(Int32(0)) + + if num_q_groups > Int32(1): + pipeline_q.consumer_wait_w_index_phase(Int32(1), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(1), Int32(1)) + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(1)) + pipeline_q.consumer_release_w_index(Int32(1)) + + cute.arch.mbarrier_wait(mbar_v_ptr, 0) + + # Steady-state: for qi >= 2, reuse the S/P slot as a BLK128- + # style handoff. The MMA warp waits for softmax to release + # the slot after P is visible, issues PV, then immediately + # reuses the same acquired slot for the next QK. + for qi in cutlass.range(Int32(2), num_q_groups, unroll=1): + pv_qi = qi - Int32(2) + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + q_slot = qi % Int32(self.q_stage) + q_phase = (qi // Int32(self.q_stage)) & Int32(1) + s_slot = qi & Int32(1) + s_phase = (qi // Int32(2)) & Int32(1) + pipeline_q.consumer_wait_w_index_phase(q_slot, q_phase) + pipeline_s.producer_acquire_w_index_phase( + s_slot, s_phase ^ Int32(1)) + if s_slot == Int32(0): + if q_slot == Int32(0): + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s0_advance(smem_desc_start_b=k_smem_start) + else: + if q_slot == Int32(0): + gemm_qk_s1_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(s_slot) + pipeline_q.consumer_release_w_index(q_slot) + + # Drain the remaining one or two PV tiles. + drain_begin = Int32(0) if num_q_groups == Int32(1) else num_q_groups - Int32(2) + for pv_qi in cutlass.range(drain_begin, num_q_groups, unroll=1): + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + @cute.jit + def _softmax_step( + self, + slot: cutlass.Constexpr[int], + s_consumer_phase: Int32, + p_producer_phase: Int32, + sm_stats_producer_phase: Int32, + softmax: SoftmaxSm100, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx: Int32, + thr_tmem_load, + thr_tmem_store, + tStS_t2r: cute.Tensor, + tStP_r2t: cute.Tensor, + tScS_t2r: cute.Tensor, + tScP_shape, + sQIdxMeta: cute.Tensor, + qidx_meta_slot: Int32, + group_tidx: Int32, + masked_tok_count: Int32, + kv_block_col_start: Int32, + seq_len_q: Int32, + causal_q_offset: Int32, + kv_valid_cols: Int32, + lse_temperature_scale: Float32, + return_temperature_lse: cutlass.Constexpr[bool], + apply_causal_mask: cutlass.Constexpr[bool] = False, + signal_stats_barrier: cutlass.Constexpr[bool] = True, + ): + slot_rt = Int32(slot) + + pipeline_s.consumer_wait_w_index_phase(slot_rt, s_consumer_phase) + + tSrS_t2r = cute.make_rmem_tensor( + tScS_t2r.shape, + self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + + seqlen_info = SeqlenInfoQK( + Int32(0), + Int32(0), + Int32(0), + Int32(0), + seq_len_q, + seq_len_q + causal_q_offset, + False, + False, + False, + False, + ) + mask = AttentionMask( + self.m_block_size, + self.n_block_size, + seqlen_info, + ) + if const_expr(self.causal and apply_causal_mask): + need_causal_mask = masked_tok_count > Int32(0) + if need_causal_mask: + tok_idx = group_tidx // Int32(self.qheadperkv) + q_idx = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_idx]) + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=True, + row_idx=q_idx, + kv_valid_cols=kv_valid_cols, + kv_block_col_start=kv_block_col_start, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + + # Each sparse CTA computes exactly one KV block for the current Q group, + # so full-tile softmax is always the first and only online-softmax step. + row_max, _ = softmax.update_row_max(tSrS_t2r.load(), True) + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + if const_expr(return_temperature_lse): + lse_temperature_row_sum = softmax.compute_scaled_exp2_row_sum( + tSrS_t2r, + lse_temperature_scale, + ) + + if cutlass.const_expr(self.split_P_arrive > 0): + # This full barrier is the late-P handoff consumed inside + # gemm_ptx_partial after its early PV k-slices are issued. + pipeline_p_lastsplit.producer_acquire_w_index_phase( + slot_rt, p_producer_phase) + pipeline_p.producer_acquire_w_index_phase(slot_rt, p_producer_phase) + tSrP_r2t_f32 = cute.make_rmem_tensor( + thr_tmem_store.partition_S( + cute.make_identity_tensor(tScP_shape)).shape, + Float32) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.p_dtype), + tSrS_t2r.layout) + softmax.apply_exp2_convert( + tSrS_t2r, + tSrP_r2t, + ex2_emu_freq=self.ex2_emu_freq, + ex2_emu_start_frg=self.ex2_emu_start_frg, + ) + + for k in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): + cute.copy( + thr_tmem_store, + tSrP_r2t_f32[None, None, k], + tStP_r2t[None, None, k]) + if cutlass.const_expr(self.split_P_arrive > 0): + split_idx = ( + cute.size(tStP_r2t.shape[2]) + * self.split_P_arrive // self.n_block_size) + if cutlass.const_expr(k + 1 == split_idx): + cute.arch.fence_view_async_tmem_store() + pipeline_p.producer_commit_w_index(slot_rt) + cute.arch.fence_view_async_tmem_store() + if cutlass.const_expr(self.split_P_arrive == 0): + pipeline_p.producer_commit_w_index(slot_rt) + else: + pipeline_p_lastsplit.producer_commit_w_index(slot_rt) + pipeline_sm_stats.producer_acquire_w_index_phase( + slot_rt, sm_stats_producer_phase) + softmax.update_row_sum(tSrS_t2r.load(), Float32(0.0), True) + del tSrS_t2r + sScale_slot = cute.make_tensor( + sScale.iterator + slot_rt * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_slot[group_tidx] = softmax.row_sum[0] + sScale_slot[group_tidx + Int32(self.m_block_size)] = ( + softmax.row_max[0]) + if const_expr(return_temperature_lse): + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot_rt * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sScale_temperature_slot[group_tidx] = lse_temperature_row_sum + cute.arch.fence_view_async_shared() + + if const_expr(signal_stats_barrier): + sm_stats_barrier.arrive_w_index(index=stats_barrier_idx) + pipeline_s.consumer_release_w_index(slot_rt) + + @cute.jit + def _wg_softmax( + self, + stage: cutlass.Constexpr[int], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tStS: cute.Tensor, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + kv_block_idx: Int32, + kv_valid_cols: Int32, + diag_q_count: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + causal_q_offset: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + mQ_2d: cute.Tensor, + ): + tidx = cute.arch.thread_idx()[0] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx % Int32(self.warps_per_group) + group_tidx = ( + warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + stats_barrier_idx = ( + Int32(stage) * Int32(self.warps_per_group) + warp_idx_in_wg) + + thr0_qk = tiled_mma_qk.get_slice(0) + tScS = thr0_qk.partition_C( + cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScS = tScS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_tiler_qk[0] // thr0_qk.thr_id.shape, + self.mma_tiler_qk[1]) + tilePlikeFP32 = ( + self.mma_tiler_qk[1] // Float32.width * self.p_dtype.width) + tScP_shape = (cta_qk_tiler[0], tilePlikeFP32) + tSAcc = tStS[(None, None), 0, 0, stage] + + softmax = SoftmaxSm100.create(softmax_scale_log2) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype) + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tSAcc).get_slice(group_tidx) + tStS_t2r = thr_tmem_load.partition_S(tSAcc) + tScS_t2r = thr_tmem_load.partition_D(tScS) + tStP_layout = cute.composition( + tSAcc.layout, + cute.make_layout((self.m_block_size, tilePlikeFP32))) + tStP = cute.make_tensor( + tSAcc.iterator + self.tmem_s_to_p_offset, tStP_layout) + # P-store Repetition is dtype-aware: each PV MMA K-segment is + # ``32 / (p_dtype.width / 8)`` fp8/bf16 columns wide, which equals + # ``32 * Float32.width / p_dtype.width`` packed fp32 TMEM columns + # ``// (p_dtype.width / 8)``. Concretely, R=16 packs two bf16 PV K + # segments per chunk (shape[2]=4 ⇒ 3/4 publish boundary aligns), + # while fp8 (PV K=32 fp8 ⇒ 8 fp32 cols) needs R=8 so + # shape[2]=4 and split_idx=3 publishes exactly 24 fp32 cols + # (= 96 fp8 cols = 3 PV K segments) at the early-arrive edge. + store_rep = const_expr( + 8 if self.p_dtype == cutlass.Float8E4M3FN else 16 + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(store_rep)), + Float32) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom, tStP).get_slice(group_tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + + total_q = mQ_2d.shape[0] // head_q + thr0_pv = tiled_mma_pv.get_slice(0) + pv_acc_shape = thr0_pv.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO_base = thr0_pv.make_fragment_C(pv_acc_shape) + corr_tile_size = 64 + tOcO = thr0_pv.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + tOcO_i = cute.logical_divide( + tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + o_tmem_copy_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.pv_acc_dtype) + + if has_work: + kv_block_col_start = Int32(0) + if const_expr(self.causal): + kv_block_col_start = kv_block_idx * Int32(self.n_block_size) + + num_stage_groups = ( + num_q_groups + Int32(1 - stage)) // Int32(2) + for qi_iter in cutlass.range(num_stage_groups, unroll=1): + qi_group = qi_iter * Int32(2) + Int32(stage) + phase = qi_iter & Int32(1) + producer_phase = phase ^ Int32(1) + qidx_meta_slot = ( + (qi_group & Int32(self.qidx_meta_stages - 1)) + * Int32(self.q_tokens_per_group) + ) + + softmax.reset() + + if const_expr(self.causal): + qi_group_start = qi_group * Int32(self.q_tokens_per_group) + masked_tok_count = cutlass.max( + Int32(0), + cutlass.min( + Int32(self.q_tokens_per_group), + diag_q_count - qi_group_start, + ), + ) + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + masked_tok_count, + kv_block_col_start, + seq_len_q, + causal_q_offset, + kv_valid_cols, + lse_temperature_scale, + const_expr(mLSE_temperature_partial is not None), + True, + False, + ) + else: + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + Int32(0), + kv_block_col_start, + seq_len_q, + Int32(0), + kv_valid_cols, + lse_temperature_scale, + const_expr(mLSE_temperature_partial is not None), + False, + False, + ) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(stage)) + self._epilogue_step( + qi_group, + group_tidx, + warp_idx_in_wg, + tOtO_base, + tOcO_i, + o_tmem_copy_atom, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, + mLSE_partial, + mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + count_raw, + batch_idx, + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + q_batch_offset, + total_q, + False, + stage, + ) + + @cute.jit + def _store_o_partial_vec4( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + ): + stg_128_cs(ptr, v0, v1, v2, v3) + + @cute.jit + def _store_o_partial_vec8_half( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + ): + if cutlass.const_expr(self.o_dtype is cutlass.BFloat16): + stg_128_bf16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + else: + stg_128_f16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + + @cute.jit + def _store_o_partial_vec16_fp8( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + v8: Float32, + v9: Float32, + v10: Float32, + v11: Float32, + v12: Float32, + v13: Float32, + v14: Float32, + v15: Float32, + ): + stg_128_fp8_e4m3_cs( + ptr, + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7, + v8, + v9, + v10, + v11, + v12, + v13, + v14, + v15, + ) + + @cute.jit + def _epilogue_step( + self, + qi_group: Int32, + group_tidx: Int32, + warp_idx_in_wg: Int32, + tOtO_base: cute.Tensor, + tOcO_i: cute.Tensor, + o_tmem_copy_atom, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + count_raw: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + total_q: Int32, + use_stats_barrier: cutlass.Constexpr[bool], + softmax_stage: cutlass.Constexpr[int], + ): + slot = qi_group & Int32(1) + phase = (qi_group // Int32(2)) & Int32(1) + stage_base = slot * Int32(self.tmem_o_stage_stride) + corr_tile_size = 64 + sScale_slot = cute.make_tensor( + sScale.iterator + slot * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sSplitIdx_slot = cute.make_tensor( + sSplitIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + sQIdx_slot = cute.make_tensor( + sQIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + pipeline_o.consumer_wait_w_index_phase(slot, phase) + if const_expr(use_stats_barrier): + sm_stats_barrier.arrive_and_wait_w_index( + index=slot * Int32(self.warps_per_group) + warp_idx_in_wg) + + if group_tidx < Int32(self.q_tokens_per_group): + tok = group_tidx + qi = qi_group * Int32(self.q_tokens_per_group) + tok + if qi < count_raw: + qsplit = sQIdxMeta[qidx_meta_slot + tok] + q_idx = self._decode_q_idx_from_qsplit(qsplit) + sQIdx_slot[tok] = q_idx + sSplitIdx_slot[tok] = self._decode_split_idx_from_qsplit(qsplit) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + tOtO = cute.make_tensor( + tOtO_base.iterator + stage_base + Int32(self.tmem_o_offset), + tOtO_base.layout) + for col_pass_idx in cutlass.range(Int32(2), unroll=1): + col_pass = col_pass_idx * Int32(corr_tile_size) + tOtO_pass_ptr = cute.make_ptr( + self.pv_acc_dtype, + tOtO.iterator.toint() + col_pass, + cute.AddressSpace.tmem, + assumed_align=8, + ) + tOtO_pass = cute.make_tensor(tOtO_pass_ptr, tOtO.layout) + tOtO_pass_i = cute.logical_divide( + tOtO_pass, + cute.make_layout((self.m_block_size, corr_tile_size))) + tiled_tmem_load_pass = tcgen05.make_tmem_copy( + o_tmem_copy_atom, tOtO_pass_i[(None, None), 0]) + thr_tmem_load_pass = tiled_tmem_load_pass.get_slice(group_tidx) + tOtO_t2r_pass = thr_tmem_load_pass.partition_S( + tOtO_pass_i[(None, None), None]) + tOcO_t2r_pass = thr_tmem_load_pass.partition_D( + tOcO_i[(None, None), None]) + + tOtO_t2r_i = tOtO_t2r_pass[None, None, None, 0] + tOcO_t2r_i = tOcO_t2r_pass[None, None, None, 0] + tOrO_frg = cute.make_rmem_tensor_like( + tOcO_t2r_i, self.pv_acc_dtype) + cute.copy(tiled_tmem_load_pass, tOtO_t2r_i, tOrO_frg) + + tOrO_mn = make_16x256b_tensor_mn_view(tOrO_frg) + tOrO_mn = cute.make_tensor( + tOrO_mn.iterator, + cute.select(tOrO_mn.layout, mode=[0, 1])) + tOcO_mn = make_16x256b_tensor_mn_view(tOcO_t2r_i) + tOcO_mn = cute.make_tensor( + tOcO_mn.iterator, + cute.select(tOcO_mn.layout, mode=[0, 1])) + num_rows = cute.size(tOrO_mn, mode=[0]) + num_cols = cute.size(tOrO_mn, mode=[1]) + + for r in cutlass.range_constexpr(num_rows): + if const_expr(self.o_dtype is Float32): + for c4 in cutlass.range_constexpr(num_cols // 4): + c_base = Int32(c4) * Int32(4) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + fake_col = real_col_to_stg128_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec4( + ptr, + o0, + o1, + o2, + o3, + ) + elif const_expr(self.o_dtype in [cutlass.BFloat16, cutlass.Float16]): + assert num_cols % 8 == 0, ( + "half O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 8" + ) + for c8 in cutlass.range_constexpr(num_cols // 8): + c_base = Int32(c8) * Int32(8) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2( + (o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2( + (o6, o7), scale_pair) + fake_col = real_col_to_stg128_half_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec8_half( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + ) + else: + assert num_cols % 16 == 0, ( + "fp8 O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 16" + ) + for c16 in cutlass.range_constexpr(num_cols // 16): + c_base = Int32(c16) * Int32(16) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = ( + qi_group * Int32(self.q_tokens_per_group) + + tok + ) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) + * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val + ) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0) + ) + row_base_ptr = flat_row * Int64(self.head_dim) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + o8 = tOrO_mn[r, c_base + Int32(8)] + o9 = tOrO_mn[r, c_base + Int32(9)] + o10 = tOrO_mn[r, c_base + Int32(10)] + o11 = tOrO_mn[r, c_base + Int32(11)] + o12 = tOrO_mn[r, c_base + Int32(12)] + o13 = tOrO_mn[r, c_base + Int32(13)] + o14 = tOrO_mn[r, c_base + Int32(14)] + o15 = tOrO_mn[r, c_base + Int32(15)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2((o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2((o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2((o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2((o6, o7), scale_pair) + o8, o9 = cute.arch.mul_packed_f32x2((o8, o9), scale_pair) + o10, o11 = cute.arch.mul_packed_f32x2((o10, o11), scale_pair) + o12, o13 = cute.arch.mul_packed_f32x2((o12, o13), scale_pair) + o14, o15 = cute.arch.mul_packed_f32x2((o14, o15), scale_pair) + fake_col = real_col_to_stg128_fp8_fake_col(col) + ptr = ( + mO_partial.iterator + + row_base_ptr + + Int64(fake_col) + ) + self._store_o_partial_vec16_fp8( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + o8, + o9, + o10, + o11, + o12, + o13, + o14, + o15, + ) + cute.arch.fence_view_async_tmem_load() + + tok_local = Int32(group_tidx) // Int32(self.qheadperkv) + h_local = Int32(group_tidx) % Int32(self.qheadperkv) + qi_lse = qi_group * Int32(self.q_tokens_per_group) + tok_local + if qi_lse < count_raw: + row_sum_val = sScale_slot[group_tidx] + row_max_val = sScale_slot[group_tidx + Int32(self.m_block_size)] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + LN2 = Float32(math.log(2.0)) + lse_cur = ( + (row_max_val * softmax_scale_log2 + + cute.math.log2(row_sum_val, fastmath=True)) * LN2 + if not is_zero_or_nan else -Float32.inf) + q_idx_lse = sQIdx_slot[tok_local] + h_abs = head_kv_idx * Int32(self.qheadperkv) + h_local + split_lse = sSplitIdx_slot[tok_local] + q_abs_lse = q_batch_offset + q_idx_lse + mLSE_partial[split_lse, q_abs_lse, h_abs] = lse_cur + if const_expr(mLSE_temperature_partial is not None): + row_sum_temperature_val = sScale_temperature_slot[group_tidx] + is_temperature_zero_or_nan = ( + row_sum_temperature_val == Float32(0.0) + or row_sum_temperature_val != row_sum_temperature_val) + lse_temperature_cur = ( + (row_max_val * lse_temperature_scale_log2 + + cute.math.log2(row_sum_temperature_val, fastmath=True)) * LN2 + if not is_temperature_zero_or_nan else -Float32.inf) + mLSE_temperature_partial[split_lse, q_abs_lse, h_abs] = ( + lse_temperature_cur) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + pipeline_sm_stats.consumer_release_w_index(slot) + pipeline_o.consumer_release_w_index(slot) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd1a8d6bf92b16d2943aa5e40fd91e26224ac40 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py @@ -0,0 +1,3305 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 sparse attention forward kernel with NVFP4 K/V. + +This kernel implements the delivered Sparse Attention / Sparse Page Attention +forward contract: +- CSR sparse metadata (`k2q_row_ptr`, `k2q_q_indices`) +- varlen Q metadata via `cu_seqlens_q` +- BF16 Q +- packed NVFP4 K/V data +- E4M3 per-1x16 K/V scales in cuBLAS/cuDNN 128x4 tiled layout +- FP32 per-tensor K/V global scales +""" + +import math +from functools import partial +from typing import Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.pipeline as cutlass_pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.cutlass_dsl import BaseDSL + +from ....quack import copy_utils + +from ....src.common.cute_dsl_utils import assume_tensor_aligned +from ....src.common import utils +from ....src.common import pipeline +from ....src.common import mma_sm100_desc as sm100_desc +from ....src.common import blackwell_helpers as sm100_helpers +from ....src.common.softmax import SoftmaxSm100 +from ....src.common.named_barrier import NamedBarrierFwdSm100 +from ....src.common.mask import AttentionMask +from ....src.common.seqlen_info import SeqlenInfoQK +# Shared raw PTX helpers and layout conversions used by the lean kernel. +from ....src.common.paged_kv import PagedKVManager +from ....src.common.tma_utils import ( + tma_gather4_cached, + tma_gather4_prefetch, + prefetch_tma_desc_raw, + TMA_CACHE_EVICT_LAST, + make_16x256b_tensor_mn_view, + real_col_to_stg128_fake_col, + real_col_to_stg128_fp8_fake_col, + real_col_to_stg128_half_fake_col, + stg_128_cs, + stg_128_bf16_cs, + stg_128_f16_cs, + stg_128_fp8_e4m3_cs, +) + + +class SparseAttentionForwardNvfp4KvSm100: + """SM100 sparse attention forward kernel with NVFP4 K/V.""" + + k_tile = 64 # UTCMMA bf16 K-tile (matches sparse_fwd_utcmma.py) + + def __init__( + self, + head_dim: int = 128, + qheadperkv: int = 16, + m_block_size: int = 128, + n_block_size: int = 128, + paged_kv: bool = False, + page_size: Optional[int] = None, + has_seqused_k: bool = False, + causal: bool = False, + use_prepare_scheduler: bool = True, + fp8_pair_dequant: bool = True, + has_k_global_scale: bool = True, + has_v_global_scale: bool = True, + ): + if head_dim != 128: + raise NotImplementedError( + f"SparseAttentionForwardNvfp4KvSm100 currently supports only D=128, got D={head_dim}" + ) + self.head_dim = 128 + self.qheadperkv = qheadperkv + self.use_q_gather4 = qheadperkv in (4, 2, 1) + if qheadperkv not in (16, 8, 4, 2, 1): + raise ValueError( + "SparseAttentionForwardNvfp4KvSm100 supports qheadperkv in " + f"{{1, 2, 4, 8, 16}}, got {qheadperkv}" + ) + self.tokens_per_gather4 = 4 // qheadperkv if self.use_q_gather4 else 0 + self.m_block_size = m_block_size # 128 packed Q heads + self.n_block_size = n_block_size # 128 KV-block width + self.paged_kv = paged_kv + self.page_size = page_size + self.has_seqused_k = has_seqused_k + self.causal = causal + self.fp8_pair_dequant = fp8_pair_dequant + self.has_k_global_scale = has_k_global_scale + self.has_v_global_scale = has_v_global_scale + if not use_prepare_scheduler: + raise ValueError("SparseAttentionForwardNvfp4KvSm100 requires prepare scheduler") + self.use_prepare_scheduler = True + if self.paged_kv: + if page_size is None: + raise ValueError("page_size must be provided when paged_kv=True") + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal blk_kv ({n_block_size})" + ) + else: + self.page_size = n_block_size + self.q_tokens_per_group = m_block_size // qheadperkv # 8 + + self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim) + self.mma_tiler_pv = (m_block_size, self.head_dim, n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + + # Pipeline configuration — deeper Q prefetch ring plus 2-slot S/O rings. + self.q_stage = 2 + self.s_stage = 2 + self.o_stage = 2 + self.kv_stage = 1 + # Sparse q_idx metadata ring bridging load -> epilogue. Sized larger + # than the in-flight group distance so epilogue can reuse q_idx + # without rereading mK2qIndices. + self.qidx_meta_stages = 16 + + self.k_stages = 2 + self.q_stage_stride_bytes = m_block_size * self.head_dim * 2 + self.k_tile_stride_bytes = m_block_size * self.k_tile * 2 + self.token_stride_bytes = qheadperkv * self.k_tile * 2 + + # Warp layout: two softmax WGs, one Q-load/epilogue WG, one + # MMA issue warp, two K/V load warps, and one empty warp. + self.warps_per_group = 4 + self.softmax0_warp_base = 0 + self.softmax1_warp_base = self.softmax0_warp_base + self.warps_per_group + self.store_warp_base = self.softmax1_warp_base + self.warps_per_group + self.mma_warp_id = self.store_warp_base + self.warps_per_group + self.load_warp_base = self.mma_warp_id + 1 + self.q_load_warp_base = self.store_warp_base + self.kv_load_warp_base = self.load_warp_base + self.num_kv_load_warps = 2 + self.num_q_load_warps = self.warps_per_group + self.total_warps = 16 + self.threads_per_cta = cute.arch.WARP_SIZE * self.total_warps # 512 + + # TMEM layout follows FA SM100: + # S0/S1: [0:128], [128:256] + # O0/O1: [256:384], [384:512] for hdim_v=128 + # P is bf16 and starts halfway into each S tile. + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + self.tmem_s_offset = 0 + self.tmem_stage_stride = n_block_size + self.tmem_o_stage_stride = self.head_dim + self.tmem_o_offset = self.s_stage * n_block_size + self.tmem_s_to_p_offset = n_block_size // 2 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + raw_tmem_total = self.tmem_o_offset + self.o_stage * self.tmem_o_stage_stride + # SM100 TMEM allocation requires a power-of-two column count. The + # 128-wide path naturally uses 512 columns; 64-wide KV blocks use 384 + # columns and must round the allocation up while keeping the same + # logical offsets. + self.tmem_total = 1 << (raw_tmem_total - 1).bit_length() + + # Let PV start once the first 3/4 of P is visible in TMEM. The final + # split is synchronized by a separate mbarrier consumed inside PV MMA. + self.split_P_arrive = n_block_size // 4 * 3 + self.split_P_arrive = int(self.split_P_arrive / 32) * 32 + assert self.split_P_arrive % 32 == 0 + assert self.split_P_arrive < self.n_block_size + + # Register allocation per role. The causal hdim128 split gives the + # epilogue enough room for partial-O/LSE address generation while the + # two softmax WGs still have enough registers to avoid S/P spills. + self.num_regs_softmax = 176 if causal else 192 + self.num_regs_store = 112 if causal else 80 + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_store + self.num_regs_mma = self.num_regs_other + self.num_regs_load = self.num_regs_other + self.num_regs_empty = self.num_regs_other + self.store_reg_decrease = self.num_regs_store <= 128 + self.ex2_emu_freq = 16 if causal else 0 + self.ex2_emu_start_frg = 1 + self.buffer_align_bytes = 1024 + + # SM100 config. + self.use_2cta_instrs = False + self.cta_group_size = 1 + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnk = (1, 1, 1) + + self.arch = BaseDSL._get_dsl().get_arch_enum() + + @cute.jit + def _batch_q_offset( + self, + batch_idx: Int32, + mCuSeqlensQ, + ) -> Int32: + return mCuSeqlensQ[batch_idx] + + @cute.jit + def _logical_seqlen_k( + self, + batch_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + if const_expr(self.has_seqused_k): + return mSeqUsedK[batch_idx] + if const_expr(self.paged_kv): + return Int32(mPageTable.shape[1]) * Int32(self.page_size) + return mCuSeqlensK[batch_idx + Int32(1)] - mCuSeqlensK[batch_idx] + + @cute.jit + def _valid_cols_in_block( + self, + batch_idx: Int32, + kv_block_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + seqlen_k = self._logical_seqlen_k( + batch_idx, mPageTable, mSeqUsedK, mCuSeqlensK + ) + block_start = kv_block_idx * Int32(self.n_block_size) + remaining = seqlen_k - block_start + remaining = cutlass.max(remaining, Int32(0)) + return cutlass.min(remaining, Int32(self.n_block_size)) + + @cute.jit + def _load_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _load_qsplit_idx( + self, + mK2qQSplitIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qQSplitIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _decode_q_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return qsplit & Int32(0x00FF_FFFF) + + @cute.jit + def _decode_split_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return (qsplit >> Int32(24)) & Int32(0xFF) + + @cute.jit + def _lower_bound_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + count: Int32, + q_value: Int32, + ) -> Int32: + left = Int32(0) + right = count + # k2q_q_indices is sorted by q_idx within each CSR row. A fixed + # 32-step loop covers int32-sized rows and keeps this CTA-level. + for _ in cutlass.range(32, unroll=1): + if left < right: + mid = (left + right) // Int32(2) + q_idx = self._load_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + mid, + ) + if q_idx < q_value: + left = mid + Int32(1) + else: + right = mid + return left + + # ------------------------------------------------------------------ + # Host-side: TMA descriptors, SMEM layout, launch + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + mK: cute.Tensor, # packed NVFP4: flat [total_k, head_kv, dim/2] or paged [page, head_kv, token, dim/2] + mV: cute.Tensor, # packed NVFP4: flat [total_k, head_kv, dim/2] or paged [page, head_kv, token, dim/2] + mKScale: cute.Tensor, # E4M3 uint8, 128x4 tiled over flattened K rows and dim/16 cols + mVScale: cute.Tensor, # E4M3 uint8, 128x4 tiled over flattened V rows and dim/16 cols + mKGlobalScale: Optional[cute.Tensor], # optional FP32 tensor/global dequant scale + mVGlobalScale: Optional[cute.Tensor], # optional FP32 tensor/global dequant scale + mK2qIndices: cute.Tensor, # csr payload: [head_kv, nnz] + mK2qQSplitIndices: cute.Tensor, # csr payload: [head_kv, nnz] packed q_idx/split slot + mK2qCounts: cute.Tensor, # csr row_ptr: [head_kv, total_rows + 1] + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, # fp32 O_partial buffer (kept alive) + mLSE_partial: cute.Tensor, # fp32 LSE_partial + mLSE_temperature_partial: Optional[cute.Tensor], # fp32 temperature-scaled LSE_partial + mQ_flat: cute.Tensor, # [batch*Sq*head_q, dim] bf16, pre-flattened + mQ_gather4_desc: Optional[cute.Tensor], # [128] uint8 tensor map for gather4 Q load + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + softmax_scale: Float32, + lse_temperature_scale: Float32, + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + stream=None, + ): + self.q_dtype = mQ_flat.element_type + self.k_cache_dtype = mK.element_type + self.v_cache_dtype = mV.element_type + self.k_scale_dtype = mKScale.element_type + self.v_scale_dtype = mVScale.element_type + if const_expr(self.q_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"KVFP4 forward requires BF16 or FP8 E4M3 Q, got {self.q_dtype}") + self.k_dtype = self.q_dtype + self.v_dtype = self.q_dtype + if const_expr(self.k_cache_dtype is not cutlass.Uint8 or self.v_cache_dtype is not cutlass.Uint8): + raise TypeError( + f"KVFP4 forward expects packed uint8 K/V, got {self.k_cache_dtype}, {self.v_cache_dtype}" + ) + if const_expr(self.k_scale_dtype is not cutlass.Uint8 or self.v_scale_dtype is not cutlass.Uint8): + raise TypeError( + f"KVFP4 forward expects uint8 E4M3 scales, got {self.k_scale_dtype}, {self.v_scale_dtype}" + ) + if const_expr(self.has_k_global_scale and mKGlobalScale.element_type is not Float32): + raise TypeError("KVFP4 forward expects FP32 K global scale") + if const_expr(self.has_v_global_scale and mVGlobalScale.element_type is not Float32): + raise TypeError("KVFP4 forward expects FP32 V global scale") + self.mma_kind = "f8f6f4" if const_expr(self.q_dtype.width == 8) else "f16" + elem_bytes = const_expr(self.q_dtype.width // 8) + self.q_stage_stride_bytes = self.m_block_size * self.head_dim * elem_bytes + self.k_tile_stride_bytes = self.m_block_size * self.k_tile * elem_bytes + self.token_stride_bytes = self.qheadperkv * self.k_tile * elem_bytes + p_cols_as_fp32 = const_expr(self.n_block_size * self.q_dtype.width // Float32.width) + self.tmem_s_to_p_offset = self.n_block_size - p_cols_as_fp32 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + self.o_dtype = mO_partial.element_type + if const_expr( + self.o_dtype + not in [Float32, cutlass.BFloat16, cutlass.Float16, cutlass.Float8E4M3FN] + ): + raise TypeError(f"Unsupported O_partial dtype: {self.o_dtype}") + mK, mV, mKScale, mVScale = [ + assume_tensor_aligned(t) for t in (mK, mV, mKScale, mVScale) + ] + + if const_expr(not self.paged_kv): + # Flat varlen K/V: + # K: [total_k, h, d/2] -> [total_k, d/2, h]. + # V: [total_k, h, d/2] -> [d/2, total_k, h] for MN-major PV. + layout_t = [0, 2, 1] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2])) + else: + # Host input is [page, head, token, dim/2]. + layout_t = [2, 3, 1, 0] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + # V: (s,d/2,h,b) -> (d/2,s,h,b) for MN-major + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2, 3])) + + # ------------------------------------------------------------------ + # UTCMMA TiledMma: QK^T and PV + # ------------------------------------------------------------------ + cta_group = tcgen05.CtaGroup.ONE + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, + Float32, cta_group, self.mma_tiler_qk[:2]) + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN, + Float32, cta_group, self.mma_tiler_pv[:2], tcgen05.OperandSource.TMEM) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM layouts: sQ/sK/sV only. O_partial is written directly from + # registers to GMEM in the epilogue. + # ------------------------------------------------------------------ + total_q_stages = self.q_stage + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, total_q_stages) + q_load_tile = ( + self.head_dim + if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_tile + ) + q_load_subtiles_per_token = const_expr(self.head_dim // q_load_tile) + num_subtiles_total = ( + total_q_stages * self.q_tokens_per_group * q_load_subtiles_per_token + ) + sQ_load_layout = sm100_utils.make_smem_layout( + tcgen05.OperandMajorMode.K, + (self.qheadperkv, q_load_tile), self.q_dtype, num_subtiles_total) + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage) + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage) + sK_fp4_layout = cute.append( + cute.make_layout( + (self.n_block_size, self.head_dim // 2), + stride=(self.head_dim // 2, 1), + ), + cute.make_layout((1,)), + ) + sV_fp4_layout = cute.append( + cute.make_layout( + (self.head_dim // 2, self.n_block_size), + stride=(1, self.head_dim // 2), + ), + cute.make_layout((1,)), + ) + # P SMEM layout metadata (no actual SMEM allocation — P lives in TMEM, + # overlaying the S region; this layout is only used to compute the PV + # A-operand TMEM descriptor shape at the MMA issue site.) + tP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.s_stage) + + # ------------------------------------------------------------------ + # TMA atoms. Packed FP4 K/V are staged by TMA, then dequantized into + # BF16 MMA SMEM layout by the KV load warps. + # ------------------------------------------------------------------ + k_fp4_tma_bytes = cute.size_in_bytes( + self.k_cache_dtype, cute.select(sK_fp4_layout, mode=[0, 1])) + v_fp4_tma_bytes = cute.size_in_bytes( + self.v_cache_dtype, cute.select(sV_fp4_layout, mode=[0, 1])) + q_tma_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_atom_K_fp4, mK_tma = cpasync.make_tiled_tma_atom( + tma_load_op, + mK, + cute.select(sK_fp4_layout, mode=[0, 1]), + (self.n_block_size, self.head_dim // 2), + ) + tma_atom_V_fp4, mV_tma = cpasync.make_tiled_tma_atom( + tma_load_op, + mV, + cute.select(sV_fp4_layout, mode=[0, 1]), + (self.head_dim // 2, self.n_block_size), + ) + mK = mK_tma + mV = mV_tma + + # Q per-sub-tile TMA atom: bf16 uses two 64-element halves; fp8 uses + # one 128-element row because 128 fp8 elements occupy the same 128B + # swizzle span as 64 bf16 elements. + mQ_flat = assume_tensor_aligned(mQ_flat) + mQ_2d = cute.make_tensor( + mQ_flat.iterator, cute.select(mQ_flat.layout, mode=[0, 1])) + if const_expr(self.use_q_gather4): + # Placeholder atom for the unified kernel signature. Small-GQA Q + # loading uses raw gather4, so mQ_2d must stay as the plain GMEM + # tensor. The placeholder uses the natural SMEM top-level shape. + tma_atom_Q, _ = cpasync.make_tiled_tma_atom( + tma_load_op, mQ_2d, + cute.select(sQ_load_layout, mode=[0, 1]), + (8, q_load_tile)) + else: + tma_atom_Q, mQ_2d_tma = cpasync.make_tiled_tma_atom( + tma_load_op, mQ_2d, + cute.select(sQ_load_layout, mode=[0, 1]), + (self.qheadperkv, q_load_tile)) + mQ_2d = mQ_2d_tma + q_subtile_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_load_layout, mode=[0, 1])) + + softmax_scale_log2 = softmax_scale * Float32(math.log2(math.e)) + lse_temperature_scale_log2 = softmax_scale_log2 * lse_temperature_scale + + # ------------------------------------------------------------------ + # SharedStorage — lean: just the mbars and tiles we actually use. + # + # Mbarriers (all storage rings stay below the 64-per-CTA limit): + # mbar_kv [2] one-shot K/V load handshake (full + empty) + # mbar_q [q_stage * 2] Q producer/consumer ring + # mbar_s [2] QK UTCMMA -> softmax (full + empty) + # mbar_o [2] PV UTCMMA -> epilogue (full + empty) + # mbar_p [s_stage * 2] softmax early-P arrive -> PV + # mbar_p_lastsplit [s_stage * 2] softmax final-P arrive -> PV + # (used only when ``self.split_P_arrive > 0``) + # mbar_sm_stats [s_stage * 2] softmax row_sum/row_max + # publish -> epilogue consumer read. In lean + # 1-WG topology the producer and consumer are + # the same 128 WG_C threads, but we keep the + # barrier for structural parity with FA so the + # softmax body reads identically. + # ------------------------------------------------------------------ + @cute.struct + class SharedStorage: + mbar_k: cute.struct.MemRange[Int64, 2] + mbar_v: cute.struct.MemRange[Int64, 2] + mbar_k_tma: cute.struct.MemRange[Int64, 2] + mbar_v_tma: cute.struct.MemRange[Int64, 2] + mbar_q: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_s: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p_lastsplit: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_o: cute.struct.MemRange[Int64, self.o_stage * 2] + mbar_sm_stats: cute.struct.MemRange[Int64, self.o_stage * 2] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: Int32 + # Per-row softmax stats cache (for epilogue LSE + rescale): + # [0 : m_block_size) row_sum + # [m_block_size : 2*m_block_size) row_max + sScale: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size * 2] + # Per-row temperature LSE row_sum cache. The row_max is shared with + # sScale because lse_temperature_scale is positive. + sScaleTemperature: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size] + # Per-token split_id from prepare-time per-edge metadata. + sSplitIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Per-token q_idx cache to avoid reloading sparse indices in epilogue. + sQIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Prefix length of q_idx-sorted row entries that may need causal + # masking for this KV block. This is CTA-level metadata, not a + # token-count cap. + sDiagQCount: cute.struct.MemRange[Int32, 1] + # CTA-wide row metadata, published once by tidx 0 and reused by + # all warp-specialized roles: + # [0] batch_idx + # [1] kv_block_idx + # [2] row_start + # [3] count_raw + # [4] kv_valid_cols + # [5] q_batch_offset + # [6] k_batch_offset + # [7] causal_q_offset = seqlen_k - seqlen_q + sRowMeta: cute.struct.MemRange[Int32, 8] + sPagedKvIdx: cute.struct.MemRange[Int32, 1] + sQLoadMIdx: cute.struct.MemRange[ + Int32, self.q_stage * self.q_tokens_per_group] + # Packed per-edge q/split metadata: + # low 24 bits = q_idx, high 8 bits = split slot. + sQIdxMeta: cute.struct.MemRange[ + Int32, self.qidx_meta_stages * self.q_tokens_per_group] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes] + sKFp4: cute.struct.Align[ + cute.struct.MemRange[self.k_cache_dtype, cute.cosize(sK_fp4_layout)], + self.buffer_align_bytes] + sVFp4: cute.struct.Align[ + cute.struct.MemRange[self.v_cache_dtype, cute.cosize(sV_fp4_layout)], + self.buffer_align_bytes] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes] + + self.shared_storage = SharedStorage + num_ctas = work_capacity + + self.kernel( + mK, mV, mKScale, mVScale, mKGlobalScale, mVGlobalScale, + mK2qIndices, mK2qQSplitIndices, mK2qCounts, + mSchedulerMetadata, mWorkCount, + mO_partial, mLSE_partial, mLSE_temperature_partial, mQ_2d, mQ_gather4_desc, + mPageTable, mSeqUsedK, mCuSeqlensQ, mCuSeqlensK, + softmax_scale_log2, lse_temperature_scale_log2, lse_temperature_scale, + sQ_layout, sQ_load_layout, sK_layout, sV_layout, + sK_fp4_layout, sV_fp4_layout, tP_layout, + tma_atom_K_fp4, tma_atom_V_fp4, tma_atom_Q, + tiled_mma_qk, tiled_mma_pv, + k_fp4_tma_bytes, v_fp4_tma_bytes, q_tma_bytes, q_subtile_bytes, + num_kv_blocks, num_heads_kv, seq_len_q, work_capacity, + ).launch( + grid=(num_ctas,), + block=[self.threads_per_cta, 1, 1], + smem=max(SharedStorage.size_in_bytes(), 49152), + stream=stream, + min_blocks_per_mp=1, + ) + + # ------------------------------------------------------------------ + # Device-side: kernel entry, dispatch by warpgroup + # ------------------------------------------------------------------ + + @cute.kernel + def kernel( + self, + # Runtime tensors + mK: cute.Tensor, + mV: cute.Tensor, + mKScale: cute.Tensor, + mVScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + mVGlobalScale: Optional[cute.Tensor], + mK2qIndices: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mK2qCounts: cute.Tensor, + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + mQ_2d: cute.Tensor, + mQ_gather4_desc: Optional[cute.Tensor], + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + # Scalars + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + # Layouts + sQ_layout: cute.ComposedLayout, + sQ_load_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sK_fp4_layout: cute.Layout, + sV_fp4_layout: cute.Layout, + tP_layout: cute.ComposedLayout, + # TMA atom + tma_atom_K_fp4: cute.CopyAtom, + tma_atom_V_fp4: cute.CopyAtom, + tma_atom_Q: cute.CopyAtom, + # MMA + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + # Transfer sizes + k_fp4_tma_bytes: cutlass.Constexpr[int], + v_fp4_tma_bytes: cutlass.Constexpr[int], + q_tma_bytes: cutlass.Constexpr[int], + q_subtile_bytes: cutlass.Constexpr[int], + # Iteration bounds + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + ): + # ------------------------------------------------------------------ + # Thread / warp identity, CTA coordinate + # ------------------------------------------------------------------ + bidx, _, _ = cute.arch.block_idx() + row_linear = Int32(0) + head_kv_idx = Int32(0) + batch_idx = Int32(0) + kv_block_idx = Int32(0) + work_q_begin = Int32(0) + work_q_count = Int32(0) + cta_valid_work = True + work_idx = bidx + cta_valid_work = work_idx < mWorkCount[Int32(0)] + if cta_valid_work: + head_kv_idx = mSchedulerMetadata[work_idx, Int32(0)] + row_linear = mSchedulerMetadata[work_idx, Int32(1)] + work_q_begin = mSchedulerMetadata[work_idx, Int32(2)] + work_q_count = mSchedulerMetadata[work_idx, Int32(3)] + batch_idx = mSchedulerMetadata[work_idx, Int32(4)] + kv_block_idx = mSchedulerMetadata[work_idx, Int32(5)] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + head_q = num_heads_kv * Int32(self.qheadperkv) + paged_kv_manager = ( + PagedKVManager.create( + mPageTable, + page_size=self.page_size, + n_block_size=self.n_block_size, + ) + if const_expr(self.paged_kv) + else None + ) + + # Prefetch TMA descriptors (warp 0 once). + if warp_idx == 0: + if const_expr(not self.use_q_gather4): + cpasync.prefetch_descriptor(tma_atom_Q) + else: + with cute.arch.elect_one(): + prefetch_tma_desc_raw(mQ_gather4_desc.iterator) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM allocation (all warps — same SharedStorage type from __call__) + # ------------------------------------------------------------------ + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sKFp4 = storage.sKFp4.get_tensor(sK_fp4_layout) + sVFp4 = storage.sVFp4.get_tensor(sV_fp4_layout) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sQ_load = storage.sQ.get_tensor(sQ_load_layout.outer, swizzle=sQ_load_layout.inner) + sScale = storage.sScale.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size * 2)) + sScaleTemperature = storage.sScaleTemperature.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size)) + sSplitIdx = storage.sSplitIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sQIdx = storage.sQIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sDiagQCount = storage.sDiagQCount.get_tensor(cute.make_layout((1,))) + sRowMeta = storage.sRowMeta.get_tensor(cute.make_layout((8,))) + sPagedKvIdx = storage.sPagedKvIdx.get_tensor(cute.make_layout((1,))) + sQLoadMIdx = storage.sQLoadMIdx.get_tensor( + cute.make_layout((self.q_stage * self.q_tokens_per_group,))) + sQIdxMeta = storage.sQIdxMeta.get_tensor( + cute.make_layout((self.qidx_meta_stages * self.q_tokens_per_group,))) + mbar_k_ptr = storage.mbar_k.data_ptr() + mbar_v_ptr = storage.mbar_v.data_ptr() + mbar_k_tma_ptr = storage.mbar_k_tma.data_ptr() + mbar_v_tma_ptr = storage.mbar_v_tma.data_ptr() + + # ------------------------------------------------------------------ + # TMEM allocator — allocator warp 0 serves the whole CTA. + # ------------------------------------------------------------------ + tmem_alloc_warps: cutlass.Constexpr[int] = self.warps_per_group * 2 + 1 + tmem_alloc_threads = cute.arch.WARP_SIZE * tmem_alloc_warps + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), + num_threads=tmem_alloc_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # ------------------------------------------------------------------ + # Warp-specialized pipelines. + # ------------------------------------------------------------------ + ThreadCooperativeGroup = partial( + cutlass_pipeline.CooperativeGroup, cutlass_pipeline.Agent.Thread) + tma_thread = ThreadCooperativeGroup(1) + mma_thread = ThreadCooperativeGroup(1) + softmax_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group) + epilogue_threads = softmax_threads + + pipeline_q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=q_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_s = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_s.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=softmax_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p_lastsplit = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p_lastsplit.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_o = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_o.data_ptr(), + num_stages=self.o_stage, + producer_group=mma_thread, + consumer_group=epilogue_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_sm_stats = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_sm_stats.data_ptr(), + num_stages=self.o_stage, + producer_group=softmax_threads, + consumer_group=epilogue_threads, + defer_sync=True, + ) + # Cluster sync (no-op for 1CTA cluster). + pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + # ------------------------------------------------------------------ + # Work count: how many Q tokens reference this CTA's KV block + # ------------------------------------------------------------------ + k_smem_bytes = cute.size_in_bytes( + self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) + v_smem_bytes = cute.size_in_bytes( + self.v_dtype, cute.select(sV_layout, mode=[0, 1, 2])) + if tidx == 0: + row_batch_idx = batch_idx + row_kv_block_idx = kv_block_idx + base_row_start = mK2qCounts[head_kv_idx, row_linear] + row_start = base_row_start + count_raw = ( + mK2qCounts[head_kv_idx, row_linear + Int32(1)] + - base_row_start + ) + row_start = base_row_start + work_q_begin + count_raw = work_q_count + kv_valid_cols = ( + self._valid_cols_in_block( + row_batch_idx, + row_kv_block_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + ) + q_batch_offset = self._batch_q_offset( + row_batch_idx, mCuSeqlensQ + ) + k_batch_offset = ( + Int32(0) + if const_expr(self.paged_kv) + else mCuSeqlensK[row_batch_idx] + ) + sRowMeta[0] = row_batch_idx + sRowMeta[1] = row_kv_block_idx + sRowMeta[2] = row_start + sRowMeta[3] = count_raw + sRowMeta[4] = kv_valid_cols + sRowMeta[5] = q_batch_offset + sRowMeta[6] = k_batch_offset + causal_q_offset = Int32(0) + if const_expr(self.causal): + seqlen_q = mCuSeqlensQ[row_batch_idx + Int32(1)] - q_batch_offset + seqlen_k = self._logical_seqlen_k( + row_batch_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + causal_q_offset = seqlen_k - seqlen_q + sRowMeta[7] = causal_q_offset + if const_expr(self.paged_kv): + sPagedKvIdx[0] = paged_kv_manager.physical_block_index( + row_batch_idx, row_kv_block_idx + ) + cute.arch.mbarrier_init(mbar_k_ptr, 1) + cute.arch.mbarrier_init(mbar_v_ptr, 1) + cute.arch.mbarrier_init(mbar_k_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_k_tma_ptr, k_fp4_tma_bytes) + cute.arch.mbarrier_init(mbar_v_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_v_tma_ptr, v_fp4_tma_bytes) + diag_q_count = Int32(0) + if const_expr(self.causal): + row_has_visible_cols = (count_raw > Int32(0)) & (kv_valid_cols > Int32(0)) + if row_has_visible_cols: + kv_valid_end = ( + row_kv_block_idx * Int32(self.n_block_size) + + kv_valid_cols + ) + q_threshold = kv_valid_end - causal_q_offset + diag_q_count = self._lower_bound_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + count_raw, + q_threshold, + ) + sDiagQCount[0] = diag_q_count + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + thr_mma_qk = tiled_mma_qk.get_slice(0) + thr_mma_pv = tiled_mma_pv.get_slice(0) + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_base = thr_mma_qk.make_fragment_C(qk_acc_shape) + tStS = cute.make_tensor( + tStS_base.iterator, + cute.append( + tStS_base.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + tP_width_ratio = Float32.width // self.v_dtype.width + tP_stage_stride = self.tmem_stage_stride * tP_width_ratio + tOrP = cute.make_tensor( + tOrP.iterator + self.tmem_p_offset * tP_width_ratio, + cute.append( + tOrP.layout, + cute.make_layout((self.s_stage,), stride=(tP_stage_stride,)), + ), + ) + + tmem_cols = self.tmem_total + + load_wg_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.LoadWG), + num_threads=cute.arch.WARP_SIZE * self.num_q_load_warps, + ) + kv_load_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvLoad), + num_threads=cute.arch.WARP_SIZE * self.num_kv_load_warps, + ) + kv_dequant_k_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantK), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + kv_dequant_v_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantV), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + sm_stats_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), + num_threads=cute.arch.WARP_SIZE * 2, + ) + epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.StoreEpilogue), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + if ( + warp_idx == Int32(self.total_warps - 1) + and warp_idx >= Int32(self.kv_load_warp_base + self.num_kv_load_warps) + ): + cute.arch.setmaxregister_decrease(self.num_regs_empty) + + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + q_load_thread_end = Int32( + (self.q_load_warp_base + self.num_q_load_warps) * cute.arch.WARP_SIZE + ) + is_q_load_thread = tidx >= q_load_thread_base and tidx < q_load_thread_end + if is_q_load_thread and cta_valid_work: + if self.store_reg_decrease: + cute.arch.setmaxregister_decrease(self.num_regs_store) + else: + cute.arch.setmaxregister_increase(self.num_regs_store) + row_start_load = sRowMeta[2] + count_raw_load = sRowMeta[3] + q_batch_offset_load = sRowMeta[5] + # Do not gate on KV validity here; sparse entries past seqused_k + # still need the all-masked path to produce neutral partials. + has_work_load = count_raw_load > Int32(0) + num_q_groups_load = ( + count_raw_load + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + q_group_start = Int32(0) + if const_expr(self.use_q_gather4): + self._wg_load_q_gather4( + mQ_2d, + mQ_gather4_desc, + mK2qQSplitIndices, + sQIdxMeta, + sQ, + pipeline_q, + load_wg_barrier, + q_group_start, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + True, + ) + else: + self._wg_load_q_tma( + tma_atom_Q, + mQ_2d, + mK2qQSplitIndices, + sQLoadMIdx, + sQIdxMeta, + sQ_load, + pipeline_q, + load_wg_barrier, + q_group_start, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + True, + ) + + if ( + warp_idx >= Int32(self.kv_load_warp_base) + and warp_idx < Int32(self.kv_load_warp_base + self.num_kv_load_warps) + and cta_valid_work + ): + cute.arch.setmaxregister_decrease(self.num_regs_load) + kv_block_idx_load = sRowMeta[1] + k_batch_offset_load = sRowMeta[6] + has_work_load = sRowMeta[3] > Int32(0) + self._wg_load_kv( + tma_atom_K_fp4, tma_atom_V_fp4, + mK, mV, + mKScale, mVScale, + mKGlobalScale, mVGlobalScale, + sPagedKvIdx, + sKFp4, sVFp4, sK, sV, + mbar_k_tma_ptr, mbar_v_tma_ptr, + mbar_k_ptr, mbar_v_ptr, + kv_load_barrier, + has_work_load, + head_kv_idx, kv_block_idx_load, + k_batch_offset_load, + num_heads_kv, + ) + + if warp_idx == Int32(self.mma_warp_id) and cta_valid_work: + cute.arch.setmaxregister_decrease(self.num_regs_mma) + count_raw_mma = sRowMeta[3] + has_work_mma = count_raw_mma > Int32(0) + num_q_groups_mma = ( + count_raw_mma + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + tmem.allocate(tmem_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_mma_issue( + tiled_mma_qk, tiled_mma_pv, + thr_mma_qk, thr_mma_pv, + tStS, tOrP, + sK, sV, sQ, + pipeline_q, pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + mbar_k_ptr, mbar_v_ptr, + num_q_groups_mma, has_work_mma, + ) + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr, num_columns=tmem_cols) + cute.arch.griddepcontrol_launch_dependents() + + if ( + warp_idx >= Int32(self.softmax0_warp_base) + and warp_idx < Int32(self.softmax1_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + self._wg_dequant_k_from_tma_staging( + mKScale, + mKGlobalScale, + sPagedKvIdx, + sKFp4, sK, + mbar_k_tma_ptr, + mbar_k_ptr, + kv_dequant_k_barrier, + has_work_softmax, + self.softmax0_warp_base, + self.warps_per_group, + head_kv_idx, kv_block_idx_softmax, + sRowMeta[6], + num_heads_kv, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 0, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + mKGlobalScale, + mVGlobalScale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + + if ( + warp_idx >= Int32(self.softmax1_warp_base) + and warp_idx < Int32(self.store_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + self._wg_dequant_v_from_tma_staging( + mVScale, + mVGlobalScale, + sPagedKvIdx, + sVFp4, sV, + mbar_v_tma_ptr, + mbar_v_ptr, + kv_dequant_v_barrier, + has_work_softmax, + self.softmax1_warp_base, + self.warps_per_group, + head_kv_idx, kv_block_idx_softmax, + sRowMeta[6], + num_heads_kv, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 1, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + mKGlobalScale, + mVGlobalScale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + # ------------------------------------------------------------------ + # Warp-specialized helpers + # ------------------------------------------------------------------ + + @cute.jit + def _scale_128x4_offset( + self, + row: Int32, + col: Int32, + scale_cols: cutlass.Constexpr[int], + ) -> Int32: + tiles_n: cutlass.Constexpr[int] = (scale_cols + 3) // 4 + tile_m = row // Int32(128) + tile_n = col // Int32(4) + outer = row % Int32(128) + inner = col % Int32(4) + return ( + (tile_m * Int32(tiles_n) + tile_n) * Int32(512) + + (outer % Int32(32)) * Int32(16) + + (outer // Int32(32)) * Int32(4) + + inner + ) + + @cute.jit + def _load_scale_bf16x2( + self, + scale: cute.Tensor, + logical_row: Int32, + scale_col: Int32, + ) -> Int32: + scale_offset = self._scale_128x4_offset( + logical_row, + scale_col, + self.head_dim // 16, + ) + scale_ptr = cute.make_ptr( + cutlass.Uint8, + scale.iterator.toint() + Int64(scale_offset), + mem_space=scale.iterator.memspace, + assumed_align=1, + ) + scale_byte = cute.make_tensor(scale_ptr, cute.make_layout(1))[0] + return utils.cvt_fp8_e4m3_to_bf16x2_replicated(cutlass.Int32(scale_byte)) + + @cute.jit + def _load_scale_e4m3_u8( + self, + scale: cute.Tensor, + logical_row: Int32, + scale_col: Int32, + ) -> Int32: + scale_offset = self._scale_128x4_offset( + logical_row, + scale_col, + self.head_dim // 16, + ) + scale_ptr = cute.make_ptr( + cutlass.Uint8, + scale.iterator.toint() + Int64(scale_offset), + mem_space=scale.iterator.memspace, + assumed_align=1, + ) + scale_byte = cute.make_tensor(scale_ptr, cute.make_layout(1))[0] + return cutlass.Int32(scale_byte) + + @cute.jit + def _dequant_fp4x16_to_bf16( + self, + src_words: cute.Tensor, + combined_scale_bf16x2: Int32, + dst: cute.Tensor, + ): + r_bf16 = cute.make_rmem_tensor((2,), cutlass.BFloat16) + r_bf16_i32 = cute.recast_tensor(r_bf16, cutlass.Int32) + for word_idx in cutlass.range_constexpr(2): + bf16_pair0, bf16_pair1, bf16_pair2, bf16_pair3 = utils.cvt_fp4x8_e2m1_bf16x8( + src_words[word_idx] + ) + bf16_pairs = (bf16_pair0, bf16_pair1, bf16_pair2, bf16_pair3) + for pair_idx in cutlass.range_constexpr(4): + r_bf16_i32[0] = utils.mul_bf16x2( + bf16_pairs[pair_idx], + combined_scale_bf16x2, + ) + dst[word_idx * 8 + 2 * pair_idx + 0] = r_bf16[0] + dst[word_idx * 8 + 2 * pair_idx + 1] = r_bf16[1] + + @cute.jit + def _dequant_fp4x16_to_fp8( + self, + src_words: cute.Tensor, + scale_e4m3: Int32, + dst: cute.Tensor, + ): + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + for word_idx in cutlass.range_constexpr(2): + fp8_lo, fp8_hi = utils.cvt_fp4x8_e2m1_scaled_e4m3x8( + src_words[word_idx], + scale_e4m3, + ) + dst_i32[word_idx * 2] = fp8_lo + dst_i32[word_idx * 2 + 1] = fp8_hi + + @cute.jit + def _dequant_fp4x32_to_fp8( + self, + src_words: cute.Tensor, + scale_e4m3_lo: Int32, + scale_e4m3_hi: Int32, + dst: cute.Tensor, + ): + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + for word_idx in cutlass.range_constexpr(2): + fp8_lo, fp8_hi = utils.cvt_fp4x8_e2m1_scaled_e4m3x8( + src_words[word_idx], + scale_e4m3_lo, + ) + dst_i32[word_idx * 2] = fp8_lo + dst_i32[word_idx * 2 + 1] = fp8_hi + for word_idx in cutlass.range_constexpr(2): + fp8_lo, fp8_hi = utils.cvt_fp4x8_e2m1_scaled_e4m3x8( + src_words[word_idx + 2], + scale_e4m3_hi, + ) + dst_i32[word_idx * 2 + 4] = fp8_lo + dst_i32[word_idx * 2 + 5] = fp8_hi + + @cute.jit + def _flat_kv_scale_row( + self, + token_idx: Int32, + head_kv_idx: Int32, + num_heads_kv: Int32, + ) -> Int32: + return token_idx * num_heads_kv + head_kv_idx + + @cute.jit + def _paged_kv_scale_row( + self, + page_idx: Int32, + token_idx: Int32, + head_kv_idx: Int32, + num_heads_kv: Int32, + ) -> Int32: + return (page_idx * num_heads_kv + head_kv_idx) * Int32(self.page_size) + token_idx + + @cute.jit + def _load_k_fp4_to_smem( + self, + sKFp4: cute.Tensor, + mKScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sK: cute.Tensor, + lane: Int32, + warp_idx_in_wg: Int32, + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + elems_per_block: cutlass.Constexpr[int] = 16 + bytes_per_block: cutlass.Constexpr[int] = 8 + scale_cols: cutlass.Constexpr[int] = self.head_dim // elems_per_block + token_in_block_base = kv_block_idx * Int32(self.n_block_size) + if const_expr(self.k_dtype == cutlass.Float8E4M3FN and self.fp8_pair_dequant): + elems_per_pair: cutlass.Constexpr[int] = 32 + bytes_per_pair: cutlass.Constexpr[int] = 16 + scale_pairs: cutlass.Constexpr[int] = scale_cols // 2 + r_words_pair = cute.make_rmem_tensor((bytes_per_pair // 4,), cutlass.Int32) + r_vals_pair = cute.make_rmem_tensor((elems_per_pair,), cutlass.Float8E4M3FN) + g2r_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_pair * cutlass.Uint8.width, + ) + r2s_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float8E4M3FN, + num_bits_per_copy=16 * cutlass.Float8E4M3FN.width, + ) + total_pair_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_pairs + pair_task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + pair_task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(pair_task, total_pair_tasks, pair_task_stride, unroll=1): + row = task_idx // Int32(scale_pairs) + pair_col = task_idx - row * Int32(scale_pairs) + scale_col = pair_col * Int32(2) + byte_col = pair_col * Int32(bytes_per_pair) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sKFp4.iterator.toint() + Int64(smem_offset), + mem_space=sKFp4.iterator.memspace, + assumed_align=bytes_per_pair, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_pair // 4)) + cute.copy(g2r_pair_atom, s_vec, r_words_pair) + scale_e4m3_lo = self._load_scale_e4m3_u8(mKScale, scale_row, scale_col) + scale_e4m3_hi = self._load_scale_e4m3_u8( + mKScale, + scale_row, + scale_col + Int32(1), + ) + self._dequant_fp4x32_to_fp8( + r_words_pair, + scale_e4m3_lo, + scale_e4m3_hi, + r_vals_pair, + ) + sK_vec = sK[(row, None), 0, pair_col, 0] + r_tiles = cute.logical_divide(r_vals_pair, cute.make_layout(16)) + s_tiles = cute.logical_divide(sK_vec, cute.make_layout(16)) + for v in cutlass.range_constexpr(elems_per_pair // 16): + cute.copy(r2s_pair_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + return + r_words = cute.make_rmem_tensor((bytes_per_block // 4,), cutlass.Int32) + r_vals = cute.make_rmem_tensor((elems_per_block,), self.k_dtype) + g2r_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_block * cutlass.Uint8.width, + ) + elems_per_store: cutlass.Constexpr[int] = ( + 16 if self.k_dtype == cutlass.Float8E4M3FN else 8 + ) + r2s_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.k_dtype, + num_bits_per_copy=elems_per_store * self.k_dtype.width, + ) + total_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_cols + task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(task, total_tasks, task_stride, unroll=1): + row = task_idx // Int32(scale_cols) + scale_col = task_idx - row * Int32(scale_cols) + byte_col = scale_col * Int32(bytes_per_block) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sKFp4.iterator.toint() + Int64(smem_offset), + mem_space=sKFp4.iterator.memspace, + assumed_align=bytes_per_block, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_block // 4)) + cute.copy(g2r_atom, s_vec, r_words) + if const_expr(self.k_dtype == cutlass.Float8E4M3FN): + scale_e4m3 = self._load_scale_e4m3_u8(mKScale, scale_row, scale_col) + else: + combined_bf16x2 = self._load_scale_bf16x2(mKScale, scale_row, scale_col) + if const_expr(self.has_k_global_scale): + global_bf16x2 = utils.cvt_f16x2_f32( + mKGlobalScale[0], + mKGlobalScale[0], + cutlass.BFloat16, + ) + combined_bf16x2 = utils.mul_bf16x2(combined_bf16x2, global_bf16x2) + if const_expr(self.k_dtype == cutlass.Float8E4M3FN): + sK_cols = sK[(row, None), 0, scale_col // Int32(2), 0] + sK_vec = cute.local_tile( + sK_cols, + (elems_per_block,), + (scale_col % Int32(2),), + ) + else: + sK_vec = sK[ + (row, None), + 0, + (scale_col % Int32(4), scale_col // Int32(4)), + 0, + ] + if const_expr(self.k_dtype == cutlass.Float8E4M3FN): + self._dequant_fp4x16_to_fp8(r_words, scale_e4m3, r_vals) + else: + self._dequant_fp4x16_to_bf16( + r_words, + combined_bf16x2, + r_vals, + ) + r_tiles = cute.logical_divide(r_vals, cute.make_layout(elems_per_store)) + s_tiles = cute.logical_divide(sK_vec, cute.make_layout(elems_per_store)) + for v in cutlass.range_constexpr(elems_per_block // elems_per_store): + cute.copy(r2s_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + + @cute.jit + def _load_v_fp4_to_smem( + self, + sVFp4: cute.Tensor, + mVScale: cute.Tensor, + mVGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sV: cute.Tensor, + lane: Int32, + warp_idx_in_wg: Int32, + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + elems_per_block: cutlass.Constexpr[int] = 16 + bytes_per_block: cutlass.Constexpr[int] = 8 + scale_cols: cutlass.Constexpr[int] = self.head_dim // elems_per_block + token_in_block_base = kv_block_idx * Int32(self.n_block_size) + if const_expr(self.v_dtype == cutlass.Float8E4M3FN and self.fp8_pair_dequant): + elems_per_pair: cutlass.Constexpr[int] = 32 + bytes_per_pair: cutlass.Constexpr[int] = 16 + scale_pairs: cutlass.Constexpr[int] = scale_cols // 2 + r_words_pair = cute.make_rmem_tensor((bytes_per_pair // 4,), cutlass.Int32) + r_vals_pair = cute.make_rmem_tensor((elems_per_pair,), cutlass.Float8E4M3FN) + g2r_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_pair * cutlass.Uint8.width, + ) + r2s_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float8E4M3FN, + num_bits_per_copy=16 * cutlass.Float8E4M3FN.width, + ) + total_pair_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_pairs + pair_task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + pair_task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(pair_task, total_pair_tasks, pair_task_stride, unroll=1): + row = task_idx // Int32(scale_pairs) + pair_col = task_idx - row * Int32(scale_pairs) + scale_col = pair_col * Int32(2) + byte_col = pair_col * Int32(bytes_per_pair) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sVFp4.iterator.toint() + Int64(smem_offset), + mem_space=sVFp4.iterator.memspace, + assumed_align=bytes_per_pair, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_pair // 4)) + cute.copy(g2r_pair_atom, s_vec, r_words_pair) + scale_e4m3_lo = self._load_scale_e4m3_u8(mVScale, scale_row, scale_col) + scale_e4m3_hi = self._load_scale_e4m3_u8( + mVScale, + scale_row, + scale_col + Int32(1), + ) + self._dequant_fp4x32_to_fp8( + r_words_pair, + scale_e4m3_lo, + scale_e4m3_hi, + r_vals_pair, + ) + sV_cols = sV[(None, row % Int32(32)), 0, row // Int32(32), 0] + sV_vec = cute.local_tile(sV_cols, (elems_per_pair,), (pair_col,)) + r_tiles = cute.logical_divide(r_vals_pair, cute.make_layout(16)) + s_tiles = cute.logical_divide(sV_vec, cute.make_layout(16)) + for v in cutlass.range_constexpr(elems_per_pair // 16): + cute.copy(r2s_pair_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + return + r_words = cute.make_rmem_tensor((bytes_per_block // 4,), cutlass.Int32) + r_vals = cute.make_rmem_tensor((elems_per_block,), self.v_dtype) + g2r_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_block * cutlass.Uint8.width, + ) + elems_per_store: cutlass.Constexpr[int] = ( + 16 if self.v_dtype == cutlass.Float8E4M3FN else 8 + ) + r2s_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.v_dtype, + num_bits_per_copy=elems_per_store * self.v_dtype.width, + ) + total_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_cols + task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(task, total_tasks, task_stride, unroll=1): + row = task_idx // Int32(scale_cols) + scale_col = task_idx - row * Int32(scale_cols) + byte_col = scale_col * Int32(bytes_per_block) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sVFp4.iterator.toint() + Int64(smem_offset), + mem_space=sVFp4.iterator.memspace, + assumed_align=bytes_per_block, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_block // 4)) + cute.copy(g2r_atom, s_vec, r_words) + if const_expr(self.v_dtype == cutlass.Float8E4M3FN): + scale_e4m3 = self._load_scale_e4m3_u8(mVScale, scale_row, scale_col) + self._dequant_fp4x16_to_fp8(r_words, scale_e4m3, r_vals) + else: + combined_bf16x2 = self._load_scale_bf16x2(mVScale, scale_row, scale_col) + if const_expr(self.has_v_global_scale): + global_bf16x2 = utils.cvt_f16x2_f32( + mVGlobalScale[0], + mVGlobalScale[0], + cutlass.BFloat16, + ) + combined_bf16x2 = utils.mul_bf16x2(combined_bf16x2, global_bf16x2) + self._dequant_fp4x16_to_bf16( + r_words, + combined_bf16x2, + r_vals, + ) + if const_expr(self.v_dtype == cutlass.Float8E4M3FN): + sV_cols = sV[(None, row % Int32(32)), 0, row // Int32(32), 0] + else: + sV_cols = sV[(None, row % Int32(16)), 0, row // Int32(16), 0] + sV_vec = cute.local_tile(sV_cols, (elems_per_block,), (scale_col,)) + r_tiles = cute.logical_divide(r_vals, cute.make_layout(elems_per_store)) + s_tiles = cute.logical_divide(sV_vec, cute.make_layout(elems_per_store)) + for v in cutlass.range_constexpr(elems_per_block // elems_per_store): + cute.copy(r2s_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + + @cute.jit + def _wg_load_kv( + self, + tma_atom_K_fp4, + tma_atom_V_fp4, + mK: cute.Tensor, + mV: cute.Tensor, + mKScale: cute.Tensor, + mVScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + mVGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sKFp4: cute.Tensor, + sVFp4: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + mbar_k_tma_ptr, + mbar_v_tma_ptr, + mbar_k_ptr, + mbar_v_ptr, + kv_load_barrier, + has_work: Int32, + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(self.kv_load_warp_base) + + if has_work: + if warp_idx_in_wg == Int32(0): + if const_expr(self.paged_kv): + mK_cur = mK[None, None, head_kv_idx, sPagedKvIdx[0]] + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim // 2), + (None, 0), + ) + src_idx = Int32(0) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + mK[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim // 2), + (None, 0), + ) + src_idx = kv_block_idx + load_K_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K_fp4, + 0, + cute.make_layout(1), + gK, + sKFp4, + ) + load_K_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_k_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_tma_ptr) + + if warp_idx_in_wg == Int32(1): + if const_expr(self.paged_kv): + mV_cur = mV[None, None, head_kv_idx, sPagedKvIdx[0]] + gV = cute.local_tile( + mV_cur, + (self.head_dim // 2, self.n_block_size), + (0, None), + ) + src_idx = Int32(0) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + mV[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + (self.head_dim // 2, self.n_block_size), + (0, None), + ) + src_idx = kv_block_idx + load_V_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V_fp4, + 0, + cute.make_layout(1), + gV, + sVFp4, + ) + load_V_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_v_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_tma_ptr) + + kv_load_barrier.arrive_and_wait() + + @cute.jit + def _wg_dequant_k_from_tma_staging( + self, + mKScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sKFp4: cute.Tensor, + sK: cute.Tensor, + mbar_k_tma_ptr, + mbar_k_ptr, + dequant_barrier, + has_work: Int32, + dequant_warp_base: cutlass.Constexpr[int], + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(dequant_warp_base) + lane = cute.arch.lane_idx() + + if has_work: + cute.arch.mbarrier_wait(mbar_k_tma_ptr, 0) + self._load_k_fp4_to_smem( + sKFp4, + mKScale, + mKGlobalScale, + sPagedKvIdx, + sK, + lane, + warp_idx_in_wg, + num_dequant_warps, + head_kv_idx, + kv_block_idx, + k_batch_offset, + num_heads_kv, + ) + dequant_barrier.arrive_and_wait() + if warp_idx_in_wg == Int32(0): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_ptr) + + @cute.jit + def _wg_dequant_v_from_tma_staging( + self, + mVScale: cute.Tensor, + mVGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sVFp4: cute.Tensor, + sV: cute.Tensor, + mbar_v_tma_ptr, + mbar_v_ptr, + dequant_barrier, + has_work: Int32, + dequant_warp_base: cutlass.Constexpr[int], + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(dequant_warp_base) + lane = cute.arch.lane_idx() + + if has_work: + cute.arch.mbarrier_wait(mbar_v_tma_ptr, 0) + self._load_v_fp4_to_smem( + sVFp4, + mVScale, + mVGlobalScale, + sPagedKvIdx, + sV, + lane, + warp_idx_in_wg, + num_dequant_warps, + head_kv_idx, + kv_block_idx, + k_batch_offset, + num_heads_kv, + ) + dequant_barrier.arrive_and_wait() + if warp_idx_in_wg == Int32(0): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_ptr) + + @cute.jit + def _wg_load_q_gather4( + self, + mQ_2d: cute.Tensor, + mQ_gather4_desc: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + load_wg_barrier, + qi_group_start: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + do_final_acquire: cutlass.Constexpr[bool], + ): + tidx = cute.arch.thread_idx()[0] + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + group_tidx = tidx - q_load_thread_base + producer_warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + gathers_per_warp: cutlass.Constexpr[int] = self.m_block_size // ( + self.num_q_load_warps * 4) + + if has_work: + for qi_group_rel in cutlass.range(num_q_groups, unroll=1): + qi_group = qi_group_start + qi_group_rel + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if producer_warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + group_tidx = tidx - q_load_thread_base + warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + lane_idx = group_tidx % Int32(cute.arch.WARP_SIZE) + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + meta_iters: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + + self.num_q_load_warps * cute.arch.WARP_SIZE - 1) + // (self.num_q_load_warps * cute.arch.WARP_SIZE) + ) + for meta_iter in cutlass.range_constexpr(meta_iters): + tok_idx_g4 = ( + (Int32(meta_iter) * Int32(self.num_q_load_warps) + + warp_idx_in_wg) + * Int32(cute.arch.WARP_SIZE) + + lane_idx + ) + if tok_idx_g4 < Int32(self.q_tokens_per_group): + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx_g4 + if qi < count_raw: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = Int32(0) + load_wg_barrier.arrive_and_wait() + + with cute.arch.elect_one(): + q_desc_ptr = mQ_gather4_desc.iterator + sQ_ptr = sQ.iterator + for gather_slot in cutlass.range_constexpr(gathers_per_warp): + gather_idx = ( + Int32(gather_slot) * Int32(self.num_q_load_warps) + + warp_idx_in_wg + ) + tok_base = gather_idx * Int32(self.tokens_per_gather4) + if const_expr(self.qheadperkv == 1): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + qi2 = qi0 + Int32(2) + qi3 = qi0 + Int32(3) + row0 = q_oob_m_idx + row1 = q_oob_m_idx + row2 = q_oob_m_idx + row3 = q_oob_m_idx + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row0 = (q_batch_offset + q_idx0) * num_heads_kv + head_kv_idx + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row1 = (q_batch_offset + q_idx1) * num_heads_kv + head_kv_idx + if qi2 < count_raw: + q_idx2 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(2)]) + row2 = (q_batch_offset + q_idx2) * num_heads_kv + head_kv_idx + if qi3 < count_raw: + q_idx3 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(3)]) + row3 = (q_batch_offset + q_idx3) * num_heads_kv + head_kv_idx + elif const_expr(self.qheadperkv == 2): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + row_base1 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row_base1 = ( + (q_batch_offset + q_idx1) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base1 + row3 = row_base1 + Int32(1) + else: + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base0 + Int32(2) + row3 = row_base0 + Int32(3) + group_byte_off = gather_idx * Int32( + 4 * self.k_tile * (self.q_dtype.width // 8) + ) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + stage_byte_off = slot * Int32(self.q_stage_stride_bytes) + full_group_byte_off = gather_idx * Int32( + 4 * self.head_dim * (self.q_dtype.width // 8) + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + full_group_byte_off, + q_desc_ptr, + Int32(0), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + else: + for ks_c in cutlass.range_constexpr(self.k_stages): + stage_idx = slot * Int32(self.k_stages) + Int32(ks_c) + stage_byte_off = stage_idx * Int32(self.k_tile_stride_bytes) + if const_expr(ks_c + 1 < self.k_stages): + tma_gather4_prefetch( + q_desc_ptr, + Int32((ks_c + 1) * self.k_tile), + row0, + row1, + row2, + row3, + TMA_CACHE_EVICT_LAST, + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + group_byte_off, + q_desc_ptr, + Int32(ks_c * self.k_tile), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + load_wg_barrier.arrive_and_wait() + + if const_expr(do_final_acquire) and producer_warp_idx_in_wg == Int32(0): + next_qi_group = qi_group_start + num_q_groups + next_slot = next_qi_group % Int32(self.q_stage) + next_phase = ( + (next_qi_group // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_load_q_tma( + self, + tma_atom_Q, + mQ_2d: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQLoadMIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ_load: cute.Tensor, + pipeline_q, + load_wg_barrier, + qi_group_start: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + do_final_acquire: cutlass.Constexpr[bool], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + lane_idx = cute.arch.lane_idx() + warp_idx_in_wg = warp_idx - Int32(self.q_load_warp_base) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + gQ_full = cute.local_tile( + mQ_2d, (self.qheadperkv, self.head_dim), (None, 0)) + load_Q_fn_full, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_full, sQ_load) + load_Q_fn_k0, load_Q_fn_k1 = None, None + else: + gQ_k0 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 0)) + gQ_k1 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 1)) + load_Q_fn_k0, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k0, sQ_load) + load_Q_fn_k1, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k1, sQ_load) + load_Q_fn_full = None + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + tokens_per_warp: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + self.num_q_load_warps - 1) + // self.num_q_load_warps + ) + + if has_work: + for qi_group_rel in cutlass.range(num_q_groups, unroll=1): + qi_group = qi_group_start + qi_group_rel + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + q_load_subtiles_per_token = ( + 1 if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_stages + ) + sub_stage_base = slot * Int32( + self.q_tokens_per_group * q_load_subtiles_per_token) + load_meta_slot = slot * Int32(self.q_tokens_per_group) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + if warp_idx_in_wg == Int32(0) and lane_idx < Int32(self.q_tokens_per_group): + tok_idx = lane_idx + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx + if qi < count_raw: + qsplit = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + q_idx = self._decode_q_idx_from_qsplit(qsplit) + q_abs = q_batch_offset + q_idx + sQIdxMeta[qidx_meta_slot + tok_idx] = qsplit + sQLoadMIdx[load_meta_slot + tok_idx] = ( + q_abs * num_heads_kv + head_kv_idx + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx] = Int32(0) + sQLoadMIdx[load_meta_slot + tok_idx] = q_oob_m_idx + load_wg_barrier.arrive_and_wait() + + for qi_slot in cutlass.range_constexpr(tokens_per_warp): + tok_idx = ( + warp_idx_in_wg * Int32(tokens_per_warp) + + Int32(qi_slot) + ) + if tok_idx < Int32(self.q_tokens_per_group): + m_tile_idx = sQLoadMIdx[load_meta_slot + tok_idx] + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + load_Q_fn_full( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + else: + load_Q_fn_k0( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + load_Q_fn_k1( + src_idx=m_tile_idx, + dst_idx=( + sub_stage_base + + Int32(self.q_tokens_per_group) + + tok_idx + ), + tma_bar_ptr=mbar_ptr, + ) + + if const_expr(do_final_acquire) and warp_idx_in_wg == Int32(0): + next_qi_group = qi_group_start + num_q_groups + next_slot = next_qi_group % Int32(self.q_stage) + next_phase = ( + (next_qi_group // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_mma_issue( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + thr0_qk: cute.core.ThrMma, + thr0_pv: cute.core.ThrMma, + tStS: cute.Tensor, + tOrP: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + mbar_k_ptr, + mbar_v_ptr, + num_q_groups: Int32, + has_work: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + is_mma_warp = warp_idx == Int32(self.mma_warp_id) + + if is_mma_warp: + if has_work: + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tSrQ0 = tSrQ[None, None, None, 0] + tSrK0 = tSrK[None, None, None, 0] + tOrV = tiled_mma_pv.make_fragment_B(sV) + tOrV0 = tOrV[None, None, None, 0] + sV0 = sV[None, None, None, 0] + pv_mma_op = tiled_mma_pv.op + qk_mma_op = tiled_mma_qk.op + q_smem_base = sm100_desc.smem_desc_base_from_tensor( + sQ, sm100_desc.Major.K) + k_smem_base = sm100_desc.smem_desc_base_from_tensor( + sK, sm100_desc.Major.K) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[None, None, None, 0].iterator) + q_smem_start = sm100_desc.make_smem_desc_start_addr( + sQ[None, None, None, self.q_stage - 1].iterator) + sm100_helpers.declare_ptx_smem_desc( + q_smem_start, + q_smem_base, + tSrQ0.layout, + var_name_prefix="lean_q_desc", + ) + sm100_helpers.declare_ptx_idesc( + qk_mma_op, var_name="lean_qk_idesc") + sQ_stage_stride = ( + sQ.layout.stride[-1] * sQ.element_type.width // 8) >> 4 + if const_expr(self.q_stage == 1): + sQ_stage_stride = 0 + q_wrap_offset = -(self.q_stage - 1) * sQ_stage_stride + q_advance_offset = sQ_stage_stride + gemm_qk_s0_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_qk_s0_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_qk_s1_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_qk_s1_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_pv_0 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset), + tOrP[None, None, None, 0], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_pv_1 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset + self.tmem_o_stage_stride), + tOrP[None, None, None, 1], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_stage_stride + self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + + cute.arch.mbarrier_wait(mbar_k_ptr, 0) + # Issue order: + # Q0K, Q1K, P0V, Q2K, P1V, Q3K, ... + # This reuses each slot as soon as its previous PV drains, + # instead of batching both PVs after both QKs of a pair. + # The schedule is still 2-slot safe: + # - QK(qi) consumes slot qi&1 + # - PV(qi-2) frees the same slot before QK(qi) reuses it + # - phases still toggle every 2 groups per slot + + # Prologue: issue up to the first two QK tiles. Q slots come + # from the q_stage ring; S slots remain a 2-slot ring. + pipeline_q.consumer_wait_w_index_phase(Int32(0), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(0), Int32(1)) + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(0)) + pipeline_q.consumer_release_w_index(Int32(0)) + + if num_q_groups > Int32(1): + pipeline_q.consumer_wait_w_index_phase(Int32(1), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(1), Int32(1)) + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(1)) + pipeline_q.consumer_release_w_index(Int32(1)) + + cute.arch.mbarrier_wait(mbar_v_ptr, 0) + + # Steady-state: for qi >= 2, reuse the S/P slot as a BLK128- + # style handoff. The MMA warp waits for softmax to release + # the slot after P is visible, issues PV, then immediately + # reuses the same acquired slot for the next QK. + for qi in cutlass.range(Int32(2), num_q_groups, unroll=1): + pv_qi = qi - Int32(2) + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + q_slot = qi % Int32(self.q_stage) + q_phase = (qi // Int32(self.q_stage)) & Int32(1) + s_slot = qi & Int32(1) + s_phase = (qi // Int32(2)) & Int32(1) + pipeline_q.consumer_wait_w_index_phase(q_slot, q_phase) + pipeline_s.producer_acquire_w_index_phase( + s_slot, s_phase ^ Int32(1)) + if s_slot == Int32(0): + if q_slot == Int32(0): + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s0_advance(smem_desc_start_b=k_smem_start) + else: + if q_slot == Int32(0): + gemm_qk_s1_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(s_slot) + pipeline_q.consumer_release_w_index(q_slot) + + # Drain the remaining one or two PV tiles. + drain_begin = Int32(0) if num_q_groups == Int32(1) else num_q_groups - Int32(2) + for pv_qi in cutlass.range(drain_begin, num_q_groups, unroll=1): + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + @cute.jit + def _softmax_step( + self, + slot: cutlass.Constexpr[int], + s_consumer_phase: Int32, + p_producer_phase: Int32, + sm_stats_producer_phase: Int32, + softmax: SoftmaxSm100, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx: Int32, + thr_tmem_load, + thr_tmem_store, + tStS_t2r: cute.Tensor, + tStP_r2t: cute.Tensor, + tScS_t2r: cute.Tensor, + tScP_shape, + sQIdxMeta: cute.Tensor, + qidx_meta_slot: Int32, + group_tidx: Int32, + masked_tok_count: Int32, + kv_block_col_start: Int32, + seq_len_q: Int32, + causal_q_offset: Int32, + kv_valid_cols: Int32, + lse_temperature_scale: Float32, + mKGlobalScale: Optional[cute.Tensor], + return_temperature_lse: cutlass.Constexpr[bool], + apply_causal_mask: cutlass.Constexpr[bool] = False, + signal_stats_barrier: cutlass.Constexpr[bool] = True, + ): + slot_rt = Int32(slot) + + pipeline_s.consumer_wait_w_index_phase(slot_rt, s_consumer_phase) + + tSrS_t2r = cute.make_rmem_tensor( + tScS_t2r.shape, + self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_k_global_scale + ): + k_global = mKGlobalScale[0] + for i in cutlass.range_constexpr(0, cute.size(tSrS_t2r.shape), 2): + tSrS_t2r[i], tSrS_t2r[i + 1] = cute.arch.mul_packed_f32x2( + (tSrS_t2r[i], tSrS_t2r[i + 1]), + (k_global, k_global), + ) + + seqlen_info = SeqlenInfoQK( + Int32(0), + Int32(0), + Int32(0), + Int32(0), + seq_len_q, + seq_len_q + causal_q_offset, + False, + False, + False, + False, + ) + mask = AttentionMask( + self.m_block_size, + self.n_block_size, + seqlen_info, + ) + if const_expr(self.causal and apply_causal_mask): + need_causal_mask = masked_tok_count > Int32(0) + if need_causal_mask: + tok_idx = group_tidx // Int32(self.qheadperkv) + q_idx = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_idx]) + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=True, + row_idx=q_idx, + kv_valid_cols=kv_valid_cols, + kv_block_col_start=kv_block_col_start, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + + # Each sparse CTA computes exactly one KV block for the current Q group, + # so full-tile softmax is always the first and only online-softmax step. + row_max, _ = softmax.update_row_max(tSrS_t2r.load(), True) + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + if const_expr(return_temperature_lse): + lse_temperature_row_sum = softmax.compute_scaled_exp2_row_sum( + tSrS_t2r, + lse_temperature_scale, + ) + + if cutlass.const_expr(self.split_P_arrive > 0): + # This full barrier is the late-P handoff consumed inside + # gemm_ptx_partial after its early PV k-slices are issued. + pipeline_p_lastsplit.producer_acquire_w_index_phase( + slot_rt, p_producer_phase) + pipeline_p.producer_acquire_w_index_phase(slot_rt, p_producer_phase) + tSrP_r2t_f32 = cute.make_rmem_tensor( + thr_tmem_store.partition_S( + cute.make_identity_tensor(tScP_shape)).shape, + Float32) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r.layout) + softmax.apply_exp2_convert( + tSrS_t2r, + tSrP_r2t, + ex2_emu_freq=self.ex2_emu_freq, + ex2_emu_start_frg=self.ex2_emu_start_frg, + ) + + for k in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): + cute.copy( + thr_tmem_store, + tSrP_r2t_f32[None, None, k], + tStP_r2t[None, None, k]) + if cutlass.const_expr(self.split_P_arrive > 0): + split_idx = ( + cute.size(tStP_r2t.shape[2]) + * self.split_P_arrive // self.n_block_size) + if cutlass.const_expr(k + 1 == split_idx): + cute.arch.fence_view_async_tmem_store() + pipeline_p.producer_commit_w_index(slot_rt) + cute.arch.fence_view_async_tmem_store() + if cutlass.const_expr(self.split_P_arrive == 0): + pipeline_p.producer_commit_w_index(slot_rt) + else: + pipeline_p_lastsplit.producer_commit_w_index(slot_rt) + pipeline_sm_stats.producer_acquire_w_index_phase( + slot_rt, sm_stats_producer_phase) + softmax.update_row_sum(tSrS_t2r.load(), Float32(0.0), True) + del tSrS_t2r + sScale_slot = cute.make_tensor( + sScale.iterator + slot_rt * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_slot[group_tidx] = softmax.row_sum[0] + sScale_slot[group_tidx + Int32(self.m_block_size)] = ( + softmax.row_max[0]) + if const_expr(return_temperature_lse): + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot_rt * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sScale_temperature_slot[group_tidx] = lse_temperature_row_sum + cute.arch.fence_view_async_shared() + + if const_expr(signal_stats_barrier): + sm_stats_barrier.arrive_w_index(index=stats_barrier_idx) + pipeline_s.consumer_release_w_index(slot_rt) + + @cute.jit + def _wg_softmax( + self, + stage: cutlass.Constexpr[int], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tStS: cute.Tensor, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + mKGlobalScale: Optional[cute.Tensor], + mVGlobalScale: Optional[cute.Tensor], + kv_block_idx: Int32, + kv_valid_cols: Int32, + diag_q_count: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + causal_q_offset: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + mQ_2d: cute.Tensor, + ): + tidx = cute.arch.thread_idx()[0] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx % Int32(self.warps_per_group) + group_tidx = ( + warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + stats_barrier_idx = ( + Int32(stage) * Int32(self.warps_per_group) + warp_idx_in_wg) + + thr0_qk = tiled_mma_qk.get_slice(0) + tScS = thr0_qk.partition_C( + cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScS = tScS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_tiler_qk[0] // thr0_qk.thr_id.shape, + self.mma_tiler_qk[1]) + tilePlikeFP32 = ( + self.mma_tiler_qk[1] // Float32.width * self.q_dtype.width) + tScP_shape = (cta_qk_tiler[0], tilePlikeFP32) + tSAcc = tStS[(None, None), 0, 0, stage] + + softmax = SoftmaxSm100.create(softmax_scale_log2) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype) + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tSAcc).get_slice(group_tidx) + tStS_t2r = thr_tmem_load.partition_S(tSAcc) + tScS_t2r = thr_tmem_load.partition_D(tScS) + tStP_layout = cute.composition( + tSAcc.layout, + cute.make_layout((self.m_block_size, tilePlikeFP32))) + tStP = cute.make_tensor( + tSAcc.iterator + self.tmem_s_to_p_offset, tStP_layout) + # P-store Repetition is dtype-aware: each PV MMA K-segment is + # ``32 / (q_dtype.width / 8)`` fp8/bf16 columns wide, which equals + # ``32 * Float32.width / q_dtype.width`` packed fp32 TMEM columns + # ``// (q_dtype.width / 8)``. Concretely, R=16 packs two bf16 PV K + # segments per chunk (shape[2]=4 ⇒ 3/4 publish boundary aligns), + # while fp8 (PV K=32 fp8 ⇒ 8 fp32 cols) needs R=8 so + # shape[2]=4 and split_idx=3 publishes exactly 24 fp32 cols + # (= 96 fp8 cols = 3 PV K segments) at the early-arrive edge. + store_rep = const_expr( + 8 if self.q_dtype == cutlass.Float8E4M3FN else 16 + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(store_rep)), + Float32) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom, tStP).get_slice(group_tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + + total_q = mQ_2d.shape[0] // head_q + thr0_pv = tiled_mma_pv.get_slice(0) + pv_acc_shape = thr0_pv.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO_base = thr0_pv.make_fragment_C(pv_acc_shape) + corr_tile_size = 64 + tOcO = thr0_pv.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + tOcO_i = cute.logical_divide( + tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + o_tmem_copy_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.pv_acc_dtype) + + if has_work: + kv_block_col_start = Int32(0) + if const_expr(self.causal): + kv_block_col_start = kv_block_idx * Int32(self.n_block_size) + + num_stage_groups = ( + num_q_groups + Int32(1 - stage)) // Int32(2) + for qi_iter in cutlass.range(num_stage_groups, unroll=1): + qi_group = qi_iter * Int32(2) + Int32(stage) + phase = qi_iter & Int32(1) + producer_phase = phase ^ Int32(1) + qidx_meta_slot = ( + (qi_group & Int32(self.qidx_meta_stages - 1)) + * Int32(self.q_tokens_per_group) + ) + + softmax.reset() + + if const_expr(self.causal): + qi_group_start = qi_group * Int32(self.q_tokens_per_group) + masked_tok_count = cutlass.max( + Int32(0), + cutlass.min( + Int32(self.q_tokens_per_group), + diag_q_count - qi_group_start, + ), + ) + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + masked_tok_count, + kv_block_col_start, + seq_len_q, + causal_q_offset, + kv_valid_cols, + lse_temperature_scale, + mKGlobalScale, + const_expr(mLSE_temperature_partial is not None), + True, + False, + ) + else: + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + Int32(0), + kv_block_col_start, + seq_len_q, + Int32(0), + kv_valid_cols, + lse_temperature_scale, + mKGlobalScale, + const_expr(mLSE_temperature_partial is not None), + False, + False, + ) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(stage)) + self._epilogue_step( + qi_group, + group_tidx, + warp_idx_in_wg, + tOtO_base, + tOcO_i, + o_tmem_copy_atom, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, + mLSE_partial, + mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + mVGlobalScale, + count_raw, + batch_idx, + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + q_batch_offset, + total_q, + False, + stage, + ) + + @cute.jit + def _store_o_partial_vec4( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + ): + stg_128_cs(ptr, v0, v1, v2, v3) + + @cute.jit + def _store_o_partial_vec8_half( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + ): + if cutlass.const_expr(self.o_dtype is cutlass.BFloat16): + stg_128_bf16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + else: + stg_128_f16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + + @cute.jit + def _store_o_partial_vec16_fp8( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + v8: Float32, + v9: Float32, + v10: Float32, + v11: Float32, + v12: Float32, + v13: Float32, + v14: Float32, + v15: Float32, + ): + stg_128_fp8_e4m3_cs( + ptr, + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7, + v8, + v9, + v10, + v11, + v12, + v13, + v14, + v15, + ) + + @cute.jit + def _epilogue_step( + self, + qi_group: Int32, + group_tidx: Int32, + warp_idx_in_wg: Int32, + tOtO_base: cute.Tensor, + tOcO_i: cute.Tensor, + o_tmem_copy_atom, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + mVGlobalScale: Optional[cute.Tensor], + count_raw: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + total_q: Int32, + use_stats_barrier: cutlass.Constexpr[bool], + softmax_stage: cutlass.Constexpr[int], + ): + slot = qi_group & Int32(1) + phase = (qi_group // Int32(2)) & Int32(1) + stage_base = slot * Int32(self.tmem_o_stage_stride) + corr_tile_size = 64 + sScale_slot = cute.make_tensor( + sScale.iterator + slot * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sSplitIdx_slot = cute.make_tensor( + sSplitIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + sQIdx_slot = cute.make_tensor( + sQIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + pipeline_o.consumer_wait_w_index_phase(slot, phase) + if const_expr(use_stats_barrier): + sm_stats_barrier.arrive_and_wait_w_index( + index=slot * Int32(self.warps_per_group) + warp_idx_in_wg) + + if group_tidx < Int32(self.q_tokens_per_group): + tok = group_tidx + qi = qi_group * Int32(self.q_tokens_per_group) + tok + if qi < count_raw: + qsplit = sQIdxMeta[qidx_meta_slot + tok] + q_idx = self._decode_q_idx_from_qsplit(qsplit) + sQIdx_slot[tok] = q_idx + sSplitIdx_slot[tok] = self._decode_split_idx_from_qsplit(qsplit) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + tOtO = cute.make_tensor( + tOtO_base.iterator + stage_base + Int32(self.tmem_o_offset), + tOtO_base.layout) + for col_pass_idx in cutlass.range(Int32(2), unroll=1): + col_pass = col_pass_idx * Int32(corr_tile_size) + tOtO_pass_ptr = cute.make_ptr( + self.pv_acc_dtype, + tOtO.iterator.toint() + col_pass, + cute.AddressSpace.tmem, + assumed_align=8, + ) + tOtO_pass = cute.make_tensor(tOtO_pass_ptr, tOtO.layout) + tOtO_pass_i = cute.logical_divide( + tOtO_pass, + cute.make_layout((self.m_block_size, corr_tile_size))) + tiled_tmem_load_pass = tcgen05.make_tmem_copy( + o_tmem_copy_atom, tOtO_pass_i[(None, None), 0]) + thr_tmem_load_pass = tiled_tmem_load_pass.get_slice(group_tidx) + tOtO_t2r_pass = thr_tmem_load_pass.partition_S( + tOtO_pass_i[(None, None), None]) + tOcO_t2r_pass = thr_tmem_load_pass.partition_D( + tOcO_i[(None, None), None]) + + tOtO_t2r_i = tOtO_t2r_pass[None, None, None, 0] + tOcO_t2r_i = tOcO_t2r_pass[None, None, None, 0] + tOrO_frg = cute.make_rmem_tensor_like( + tOcO_t2r_i, self.pv_acc_dtype) + cute.copy(tiled_tmem_load_pass, tOtO_t2r_i, tOrO_frg) + + tOrO_mn = make_16x256b_tensor_mn_view(tOrO_frg) + tOrO_mn = cute.make_tensor( + tOrO_mn.iterator, + cute.select(tOrO_mn.layout, mode=[0, 1])) + tOcO_mn = make_16x256b_tensor_mn_view(tOcO_t2r_i) + tOcO_mn = cute.make_tensor( + tOcO_mn.iterator, + cute.select(tOcO_mn.layout, mode=[0, 1])) + num_rows = cute.size(tOrO_mn, mode=[0]) + num_cols = cute.size(tOrO_mn, mode=[1]) + + for r in cutlass.range_constexpr(num_rows): + if const_expr(self.o_dtype is Float32): + for c4 in cutlass.range_constexpr(num_cols // 4): + c_base = Int32(c4) * Int32(4) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_v_global_scale + ): + row_scale *= mVGlobalScale[0] + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + fake_col = real_col_to_stg128_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec4( + ptr, + o0, + o1, + o2, + o3, + ) + elif const_expr(self.o_dtype in [cutlass.BFloat16, cutlass.Float16]): + assert num_cols % 8 == 0, ( + "half O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 8" + ) + for c8 in cutlass.range_constexpr(num_cols // 8): + c_base = Int32(c8) * Int32(8) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_v_global_scale + ): + row_scale *= mVGlobalScale[0] + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2( + (o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2( + (o6, o7), scale_pair) + fake_col = real_col_to_stg128_half_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec8_half( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + ) + else: + assert num_cols % 16 == 0, ( + "fp8 O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 16" + ) + for c16 in cutlass.range_constexpr(num_cols // 16): + c_base = Int32(c16) * Int32(16) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = ( + qi_group * Int32(self.q_tokens_per_group) + + tok + ) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) + * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val + ) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0) + ) + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_v_global_scale + ): + row_scale *= mVGlobalScale[0] + row_base_ptr = flat_row * Int64(self.head_dim) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + o8 = tOrO_mn[r, c_base + Int32(8)] + o9 = tOrO_mn[r, c_base + Int32(9)] + o10 = tOrO_mn[r, c_base + Int32(10)] + o11 = tOrO_mn[r, c_base + Int32(11)] + o12 = tOrO_mn[r, c_base + Int32(12)] + o13 = tOrO_mn[r, c_base + Int32(13)] + o14 = tOrO_mn[r, c_base + Int32(14)] + o15 = tOrO_mn[r, c_base + Int32(15)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2((o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2((o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2((o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2((o6, o7), scale_pair) + o8, o9 = cute.arch.mul_packed_f32x2((o8, o9), scale_pair) + o10, o11 = cute.arch.mul_packed_f32x2((o10, o11), scale_pair) + o12, o13 = cute.arch.mul_packed_f32x2((o12, o13), scale_pair) + o14, o15 = cute.arch.mul_packed_f32x2((o14, o15), scale_pair) + fake_col = real_col_to_stg128_fp8_fake_col(col) + ptr = ( + mO_partial.iterator + + row_base_ptr + + Int64(fake_col) + ) + self._store_o_partial_vec16_fp8( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + o8, + o9, + o10, + o11, + o12, + o13, + o14, + o15, + ) + cute.arch.fence_view_async_tmem_load() + + tok_local = Int32(group_tidx) // Int32(self.qheadperkv) + h_local = Int32(group_tidx) % Int32(self.qheadperkv) + qi_lse = qi_group * Int32(self.q_tokens_per_group) + tok_local + if qi_lse < count_raw: + row_sum_val = sScale_slot[group_tidx] + row_max_val = sScale_slot[group_tidx + Int32(self.m_block_size)] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + LN2 = Float32(math.log(2.0)) + lse_cur = ( + (row_max_val * softmax_scale_log2 + + cute.math.log2(row_sum_val, fastmath=True)) * LN2 + if not is_zero_or_nan else -Float32.inf) + q_idx_lse = sQIdx_slot[tok_local] + h_abs = head_kv_idx * Int32(self.qheadperkv) + h_local + split_lse = sSplitIdx_slot[tok_local] + q_abs_lse = q_batch_offset + q_idx_lse + mLSE_partial[split_lse, q_abs_lse, h_abs] = lse_cur + if const_expr(mLSE_temperature_partial is not None): + row_sum_temperature_val = sScale_temperature_slot[group_tidx] + is_temperature_zero_or_nan = ( + row_sum_temperature_val == Float32(0.0) + or row_sum_temperature_val != row_sum_temperature_val) + lse_temperature_cur = ( + (row_max_val * lse_temperature_scale_log2 + + cute.math.log2(row_sum_temperature_val, fastmath=True)) * LN2 + if not is_temperature_zero_or_nan else -Float32.inf) + mLSE_temperature_partial[split_lse, q_abs_lse, h_abs] = ( + lse_temperature_cur) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + pipeline_sm_stats.consumer_release_w_index(slot) + pipeline_o.consumer_release_w_index(slot) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd/combine.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd/combine.py new file mode 100644 index 0000000000000000000000000000000000000000..a3894130432f6483291fe23c064efa7369f6d509 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd/combine.py @@ -0,0 +1,1498 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Sparse forward combine kernel and public launcher. + +This keeps the local fake-layout -> real-layout epilogue needed by the lean +sparse forward path. +""" + +# Modified Step 7: O_out write with SMEM fake->real column permutation. +# O_partial dim is in STG.128 fake layout; O_out dim is real layout. +import math +from typing import Type, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import torch +from cutlass.cute.nvgpu import cpasync +from cutlass import Float32, Int32, Int64, Boolean, const_expr + +from ....src.common import utils +from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map +from ....src.common.seqlen_info import SeqlenInfo +from cutlass.cute import FastDivmodDivisor + +from ....src.common.pack_gqa import PackGQAComb +from ....src.common.tma_utils import ( + stg128_fake_col_to_real_col, + stg128_fp8_fake_col_to_real_col, + stg128_half_fake_col_to_real_col, +) + + +class SparseAttentionForwardCombine: + def __init__( + self, + dtype: Type[cutlass.Numeric], + dtype_partial: Type[cutlass.Numeric], + head_dim: int, + tile_m: int = 8, + k_block_size: int = 64, + topk: int = 16, + num_threads: int = 256, + stages: int = 4, + use_pdl: bool = False, + min_blocks_per_mp: int = 0, + ): + """ + Forward combine kernel for split attention computation. + + :param dtype: output data type + :param dtype_partial: partial accumulation data type + :param head_dim: head dimension + :param tile_m: m block size + :param k_block_size: k block size + :param topk: exact number of split partials + :param num_threads: number of threads + :param varlen: whether using variable length sequences + :param stages: number of pipeline stages + """ + self.dtype = dtype + self.dtype_partial = dtype_partial + self.head_dim = head_dim + self.tile_m = tile_m + self.k_block_size = k_block_size + self.topk = topk + self.num_threads = num_threads + self.is_even_k = head_dim % k_block_size == 0 + self.stages = stages + self.use_pdl = use_pdl + self.min_blocks_per_mp = min_blocks_per_mp + self.use_stg128_half_layout = dtype_partial in (cutlass.BFloat16, cutlass.Float16) + self.use_stg128_fp8_layout = dtype_partial is cutlass.Float8E4M3FN + + @staticmethod + def can_implement( + dtype, + dtype_partial, + head_dim, + tile_m, + k_block_size, + topk, + num_threads, + ) -> bool: + """Check if the kernel can be implemented with the given parameters.""" + if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]: + return False + if dtype_partial not in [ + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + Float32, + ]: + return False + if head_dim % 8 != 0: + return False + if num_threads % 32 != 0: + return False + if tile_m % 8 != 0: + return False + if topk > 256: + return False + if (tile_m * topk) % num_threads != 0: + return False + return True + + def _setup_attributes(self): + # GMEM copy setup for O partial + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype_partial.width + assert self.k_block_size % async_copy_elems == 0 + + k_block_gmem = ( + 128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32) + ) + gmem_threads_per_row = k_block_gmem // async_copy_elems + assert self.num_threads % gmem_threads_per_row == 0 + + # Async copy atom for O partial load + atom_async_copy_partial = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype_partial, + num_bits_per_copy=universal_copy_bits, + ) + tOpartial_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + vOpartial_layout = cute.make_layout((1, async_copy_elems)) + self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv( + atom_async_copy_partial, tOpartial_layout, vOpartial_layout + ) + + # GMEM copy setup for final O (use universal copy for store). + # Keep this independent from O_partial: fp8 partial uses 16 elements + # per 128b transaction, while bf16/fp16 O stores must remain 8-wide. + output_copy_elems = universal_copy_bits // self.dtype.width + assert self.k_block_size % output_copy_elems == 0 + gmem_threads_per_row_o = k_block_gmem // output_copy_elems + assert self.num_threads % gmem_threads_per_row_o == 0 + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + tO_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_o, gmem_threads_per_row_o), + order=(1, 0), + ) + vO_layout = cute.make_layout((1, output_copy_elems)) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, + tO_layout, + vO_layout, + ) + # LSE copy setup with async copy (alignment = 1) + lse_copy_bits = Float32.width # 1 element per copy, width is in bits + m_block_smem = ( + 128 + if self.tile_m % 128 == 0 + else ( + 64 + if self.tile_m % 64 == 0 + else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8)) + ) + ) + gmem_threads_per_row_lse = m_block_smem + assert self.num_threads % gmem_threads_per_row_lse == 0 + + # Async copy atom for LSE load + atom_async_copy_lse = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + Float32, + num_bits_per_copy=lse_copy_bits, + ) + tLSE_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse), + order=(1, 0), + ) + vLSE_layout = cute.make_layout(1) + self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + atom_async_copy_lse, tLSE_layout, vLSE_layout + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory + # /////////////////////////////////////////////////////////////////////////////// + + # Shared memory to register copy for LSE + self.smem_threads_per_col_lse = self.num_threads // m_block_smem + assert 32 % self.smem_threads_per_col_lse == 0 # Must divide warp size + + s2r_layout_atom_lse = cute.make_ordered_layout( + (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse), + order=(0, 1), + ) + self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), + s2r_layout_atom_lse, + cute.make_layout(1), + ) + + # LSE shared memory layout with swizzling to avoid bank conflicts + # This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts + if const_expr(m_block_smem == 8): + smem_lse_swizzle = cute.make_swizzle(5, 0, 5) + elif const_expr(m_block_smem == 16): + smem_lse_swizzle = cute.make_swizzle(4, 0, 4) + else: + smem_lse_swizzle = cute.make_swizzle(3, 2, 3) + lse_atom_splits = min(self.topk, 8) + smem_layout_atom_lse = cute.make_composed_layout( + smem_lse_swizzle, + 0, + cute.make_ordered_layout((lse_atom_splits, m_block_smem), order=(1, 0)), + ) + self.smem_layout_lse = cute.tile_to_shape( + smem_layout_atom_lse, (self.topk, self.tile_m), (0, 1) + ) + + # O_partial staging layout. + if const_expr( + self.dtype_partial + in [cutlass.Float16, cutlass.BFloat16, cutlass.Float8E4M3FN] + ): + smem_layout_atom_o = _get_cpasync_smem_layout_atom( + self.dtype_partial, self.k_block_size + ) + self.smem_layout_o = cute.tile_to_shape( + smem_layout_atom_o, + (self.tile_m, self.k_block_size, self.stages), + (0, 1, 2), + ) + else: + self.smem_layout_o = cute.make_ordered_layout( + (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2) + ) + + @cute.jit + def __call__( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor] = None, + mLSE_temperature_partial: Optional[cute.Tensor] = None, + mLSE_temperature: Optional[cute.Tensor] = None, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + varlen_batch_idx: Optional[cute.Tensor] = None, + semaphore_to_reset: Optional[cute.Tensor] = None, + mSplitCounts: Optional[cute.Tensor] = None, + mOutputScale: Optional[cute.Tensor] = None, + qhead_per_kvhead: Int32 = Int32(1), + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, + ): + # Type checking + if const_expr(not (mO_partial.element_type == self.dtype_partial)): + raise TypeError("O partial tensor must match dtype_partial") + if const_expr(not (mO.element_type == self.dtype)): + raise TypeError("O tensor must match dtype") + if const_expr(mLSE_partial.element_type not in [Float32]): + raise TypeError("LSE partial tensor must be Float32") + if const_expr(mLSE is not None and mLSE.element_type not in [Float32]): + raise TypeError("LSE tensor must be Float32") + if const_expr( + mLSE_temperature_partial is not None + and mLSE_temperature_partial.element_type not in [Float32] + ): + raise TypeError("temperature LSE partial tensor must be Float32") + if const_expr(mLSE_temperature is not None and mLSE_temperature.element_type not in [Float32]): + raise TypeError("temperature LSE tensor must be Float32") + if const_expr((mLSE_temperature_partial is None) != (mLSE_temperature is None)): + raise ValueError( + "temperature LSE partial and output tensors must either both be provided or both be None" + ) + + # Shape validation - input tensors are in user format, need to be converted to kernel format + if const_expr(len(mO_partial.shape) not in [4, 5]): + raise ValueError( + "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)" + ) + if const_expr(len(mLSE_partial.shape) not in [3, 4]): + raise ValueError( + "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)" + ) + if const_expr(len(mO.shape) not in [3, 4]): + raise ValueError( + "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)" + ) + if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]): + raise ValueError( + "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)" + ) + if const_expr(mLSE_temperature_partial is not None and len(mLSE_temperature_partial.shape) not in [3, 4]): + raise ValueError( + "temperature LSE partial tensor must have 3 or 4 dimensions: " + "(num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)" + ) + if const_expr(mLSE_temperature is not None and len(mLSE_temperature.shape) not in [2, 3]): + raise ValueError( + "temperature LSE tensor must have 2 or 3 dimensions: " + "(batch, seqlen, nheads) or (total_q, nheads)" + ) + if const_expr(mSplitCounts is not None): + if const_expr(mSplitCounts.element_type not in [Int32]): + raise TypeError("split_counts tensor must be Int32") + if const_expr(cu_seqlens is not None): + if const_expr(len(mSplitCounts.shape) != 2): + raise ValueError("varlen split_counts tensor must have shape (total_q, nheads_kv)") + elif const_expr(len(mSplitCounts.shape) != 3): + raise ValueError("batched split_counts tensor must have shape (batch, seqlen, nheads_kv)") + if const_expr(mOutputScale is not None and mOutputScale.element_type not in [Float32]): + raise TypeError("output_scale tensor must be Float32") + + mO_partial, mO = [assume_tensor_aligned(t) for t in (mO_partial, mO)] + # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b) + # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h) + O_partial_layout_transpose = ( + [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2] + ) + # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h) + mO_partial = cute.make_tensor( + mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose) + ) + O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1] + mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose)) + # (num_splits, b, h, seqlen) -> (seqlen, num_splits, h, b) + # Input is pre-transposed: [topK, B, Hq, Sq] with Sq innermost for K2-friendly reads. + # or (num_splits, total_q, h) -> (total_q, num_splits, h) + LSE_partial_layout_transpose = [3, 0, 2, 1] if const_expr(cu_seqlens is None) else [1, 0, 2] + mLSE_partial = cute.make_tensor( + mLSE_partial.iterator, + cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose), + ) + # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h) + LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1] + mLSE = ( + cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + if mLSE is not None + else None + ) + mLSE_temperature_partial = ( + cute.make_tensor( + mLSE_temperature_partial.iterator, + cute.select(mLSE_temperature_partial.layout, mode=LSE_partial_layout_transpose), + ) + if mLSE_temperature_partial is not None + else None + ) + mLSE_temperature = ( + cute.make_tensor( + mLSE_temperature.iterator, + cute.select(mLSE_temperature.layout, mode=LSE_layout_transpose), + ) + if mLSE_temperature is not None + else None + ) + + # Determine if we have variable length sequences + varlen = const_expr(cu_seqlens is not None or seqused is not None) + + self._setup_attributes() + + # Output-dtype permutation buffer for Step 7 (tile_m × k_block_size). + # Accumulation stays fp32; the final dtype conversion happens before + # the fake→real SMEM scatter to reduce half-output SMEM pressure. + if const_expr(self.dtype in [cutlass.Float16, cutlass.BFloat16]): + smem_layout_perm = cute.make_layout( + (self.tile_m, self.k_block_size), + stride=(self.k_block_size + 16, 1), + ) + else: + smem_layout_perm = cute.make_ordered_layout( + (self.tile_m, self.k_block_size), order=(1, 0) + ) + + @cute.struct + class SharedStorage: + sLSE: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sLSETemperature: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.tile_m], 128] + sO: cute.struct.Align[ + cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 + ] + sO_perm: cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_perm)], 128 + ] + + smem_size = SharedStorage.size_in_bytes() + + # Grid: (ceil(seqlen/tile_m), ceil(dim/k_block), num_head * batch) + # Head separated from seqlen → enables future TMA (contiguous Sq tiles) + seqlen = mO_partial.shape[0] + num_head = mO_partial.shape[3] + batch_size = ( + mO_partial.shape[4] + if const_expr(cu_seqlens is None) + else Int32(cu_seqlens.shape[0] - 1) + ) + + seqlen_divmod = FastDivmodDivisor(seqlen) + head_divmod = FastDivmodDivisor(num_head) + + grid_dim = ( + cute.ceil_div(seqlen * num_head, self.tile_m), + cute.ceil_div(self.head_dim, self.k_block_size), + batch_size, + ) + + self.kernel( + mO_partial, + mLSE_partial, + mO, + mLSE, + mLSE_temperature_partial, + mLSE_temperature, + cu_seqlens, + seqused, + num_splits_dynamic_ptr, + varlen_batch_idx, + semaphore_to_reset, + mSplitCounts, + mOutputScale, + qhead_per_kvhead, + SharedStorage, + self.smem_layout_lse, + self.smem_layout_o, + smem_layout_perm, + self.gmem_tiled_copy_O_partial, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_LSE, + self.s2r_tiled_copy_LSE, + seqlen_divmod, + head_divmod, + self.use_pdl, + varlen, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=smem_size, + stream=stream, + min_blocks_per_mp=self.min_blocks_per_mp, + use_pdl=self.use_pdl, + ) + + @cute.jit + def decode_flat_row_idx( + self, + idx: Int32, + head_divmod: FastDivmodDivisor, + ): + """Decode flattened tile rows under the H_q-innermost contract.""" + q_idx_local, head_idx = divmod(idx, head_divmod) + return q_idx_local, head_idx + + @cute.kernel + def kernel( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mLSE_temperature_partial: Optional[cute.Tensor], + mLSE_temperature: Optional[cute.Tensor], + cu_seqlens: Optional[cute.Tensor], + seqused: Optional[cute.Tensor], + num_splits_dynamic_ptr: Optional[cute.Tensor], + varlen_batch_idx: Optional[cute.Tensor], + semaphore_to_reset: Optional[cute.Tensor], + mSplitCounts: Optional[cute.Tensor], + mOutputScale: Optional[cute.Tensor], + qhead_per_kvhead: Int32, + SharedStorage: cutlass.Constexpr, + smem_layout_lse: cute.Layout | cute.ComposedLayout, + smem_layout_o: cute.Layout | cute.ComposedLayout, + smem_layout_perm: cute.Layout, + gmem_tiled_copy_O_partial: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + s2r_tiled_copy_LSE: cute.TiledCopy, + seqlen_divmod: FastDivmodDivisor, + head_divmod: FastDivmodDivisor, + use_pdl: cutlass.Constexpr[bool], + varlen: cutlass.Constexpr[bool], + ): + # Thread and block indices + tidx, _, _ = cute.arch.thread_idx() + m_block, k_block, maybe_virtual_batch = cute.arch.block_idx() + + batch_idx = ( + varlen_batch_idx[maybe_virtual_batch] + if const_expr(varlen_batch_idx is not None) + else maybe_virtual_batch + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sLSE = storage.sLSE.get_tensor(smem_layout_lse) + sLSE_temperature = storage.sLSETemperature.get_tensor(smem_layout_lse) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,)) + sO = storage.sO.get_tensor(smem_layout_o) + sO_perm_buf = storage.sO_perm.get_tensor(smem_layout_perm) + + # Handle semaphore reset — wait for dependent grids first + if const_expr(use_pdl and semaphore_to_reset is not None): + if ( + tidx == 0 + and m_block == cute.arch.grid_dim()[0] - 1 + and k_block == cute.arch.grid_dim()[1] - 1 + and maybe_virtual_batch == cute.arch.grid_dim()[2] - 1 + ): + cute.arch.griddepcontrol_wait() + semaphore_to_reset[0] = 0 + + if const_expr(num_splits_dynamic_ptr is not None): + raise ValueError("K2 combine requires compile-time exact topK") + num_splits = Int32(self.topk) + # Handle variable length sequences using SeqlenInfo + seqlen_info = SeqlenInfo.create( + batch_idx=batch_idx, + seqlen_static=mO_partial.shape[0], + cu_seqlens=cu_seqlens, + seqused=seqused, + # Don't need to pass in tile size since we won't use offset_padded + ) + seqlen, offset = seqlen_info.seqlen, seqlen_info.offset + + num_head = mO_partial.shape[3] + max_idx = seqlen * num_head + output_scale = Float32(1.0) + if const_expr(mOutputScale is not None): + output_scale = mOutputScale[0] + + if const_expr(not varlen) or m_block * self.tile_m < max_idx: + # Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial) + if const_expr(use_pdl): + cute.arch.griddepcontrol_wait() + + # =============================== + # Step 1: Load LSE_partial from gmem to shared memory + # =============================== + # `cLSE` (identity tensor for row/split coord tracking) is reused + # later in steps 4-5, so it must be defined on both branches. + cLSE = cute.make_identity_tensor((self.topk, self.tile_m)) + # Reshape mLSE_partial to PackGQA packed layout and delegate the + # tile load to PackGQAComb.load_LSE. The packed form folds (H_q, Sq) + # into one compound dim with H_q innermost (stride 1), so thread + # rows that vary along h_pos produce one-sector coalesced reads. + # Non-varlen path only — varlen keeps the original inline loop. + if const_expr(not varlen): + mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3) + # mLSE_partial_cur: (H_q, topK, Sq) — after initial transpose + # [3,0,2,1] on [topK,B,Sq,H_q] and dropping B. + # Reorder to (H_q, Sq, topK) then group modes 0..1 for packed dim: + mLSE_partial_reord = cute.make_tensor( + mLSE_partial_cur.iterator, + cute.select(mLSE_partial_cur.layout, mode=[0, 2, 1]), + ) + mLSE_partial_packed = cute.group_modes(mLSE_partial_reord, 0, 2) + # shape ((H_q, Sq), topK) with H_q innermost. + packgqa = PackGQAComb( + m_block_size=self.tile_m, + head_dim_padded=0, # unused for LSE load + check_hdim_oob=False, # unused for LSE load + qhead_per_kvhead=1, # unused; num_heads_divmod is passed explicitly + ) + packgqa.load_LSE( + mLSE_partial_packed, + sLSE, + self.topk, + gmem_tiled_copy_LSE, + tidx, + m_block, + num_splits, + seqlen, + head_divmod, + mSplitCounts, + batch_idx, + qhead_per_kvhead, + ) + if const_expr(mLSE_temperature_partial is not None): + mLSE_temperature_partial_cur = seqlen_info.offset_batch( + mLSE_temperature_partial, batch_idx, dim=3) + mLSE_temperature_partial_reord = cute.make_tensor( + mLSE_temperature_partial_cur.iterator, + cute.select(mLSE_temperature_partial_cur.layout, mode=[0, 2, 1]), + ) + mLSE_temperature_partial_packed = cute.group_modes( + mLSE_temperature_partial_reord, 0, 2) + packgqa.load_LSE( + mLSE_temperature_partial_packed, + sLSE_temperature, + self.topk, + gmem_tiled_copy_LSE, + tidx, + m_block, + num_splits, + seqlen, + head_divmod, + mSplitCounts, + batch_idx, + qhead_per_kvhead, + ) + else: + # Varlen path keeps the same H_q-innermost flat-row contract: + # after transpose [1, 0, 2], mLSE_partial_cur is + # (q_local, split, head). + # mSplitCounts is the authoritative valid-split count per + # packed (q_abs, kv_head); masked splits stay at -inf and + # therefore drop out of the final kernel LSE_out reduction. + mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3) + mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,)) + gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) + tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) + tLSEsLSE_temperature = gmem_thr_copy_LSE.partition_D(sLSE_temperature) + tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) + if const_expr(mLSE_temperature_partial is not None): + mLSE_temperature_partial_cur = seqlen_info.offset_batch( + mLSE_temperature_partial, batch_idx, dim=3) + mLSE_temperature_partial_copy = cute.tiled_divide( + mLSE_temperature_partial_cur, (1,)) + + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] + idx = m_block * self.tile_m + mi + if idx < max_idx: + m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod) + row_count = ( + mSplitCounts[offset + m_idx, head_idx // qhead_per_kvhead] + if const_expr(mSplitCounts is not None) + else num_splits + ) + mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx] + if const_expr(mLSE_temperature_partial is not None): + mLSE_temperature_partial_cur_copy = ( + mLSE_temperature_partial_copy[None, m_idx, None, head_idx]) + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] + if si < num_splits and si < row_count: + cute.copy( + gmem_thr_copy_LSE, + mLSE_partial_cur_copy[None, si], + tLSEsLSE[None, s, m], + ) + if const_expr(mLSE_temperature_partial is not None): + cute.copy( + gmem_thr_copy_LSE, + mLSE_temperature_partial_cur_copy[None, si], + tLSEsLSE_temperature[None, s, m], + ) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + if const_expr(mLSE_temperature_partial is not None): + tLSEsLSE_temperature[None, s, m].fill(-Float32.inf) + else: + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + tLSEsLSE[None, s, m].fill(-Float32.inf) + if const_expr(mLSE_temperature_partial is not None): + tLSEsLSE_temperature[None, s, m].fill(-Float32.inf) + cute.arch.cp_async_commit_group() + + # =============================== + # Step 2: Load O_partial for pipeline stages + # =============================== + + gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) + cO = cute.make_identity_tensor((self.tile_m, self.k_block_size)) + tOcO = gmem_thr_copy_O_partial.partition_D(cO) + tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) + mO_partial_cur = seqlen_info.offset_batch(mO_partial, batch_idx, dim=4) + + # Precompute per-row values for flattened (q_local, head) tiles. + num_rows = const_expr(cute.size(tOcO, mode=[1])) + tOmidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOhidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOSplitCount = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOrOptr = cute.make_rmem_tensor(num_rows, cutlass.Int64) + for m in cutlass.range(num_rows, unroll_full=True): + mi = tOcO[0, m, 0][0] # m coordinate in tile + idx = m_block * self.tile_m + mi + if idx >= max_idx: + tOhidx[m] = -1 + tOmidx[m] = 0 + tOSplitCount[m] = 0 + tOrOptr[m] = cutlass.Int64(0) + else: + tOmidx[m], tOhidx[m] = self.decode_flat_row_idx(idx, head_divmod) + if const_expr(mSplitCounts is None): + tOSplitCount[m] = num_splits + elif const_expr(cu_seqlens is None): + tOSplitCount[m] = mSplitCounts[ + batch_idx, tOmidx[m], tOhidx[m] // qhead_per_kvhead + ] + else: + tOSplitCount[m] = mSplitCounts[ + offset + tOmidx[m], tOhidx[m] // qhead_per_kvhead + ] + tOrOptr[m] = utils.elem_pointer( + mO_partial_cur, + (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m]), + ).toint() + + tOpO = None + if const_expr(not self.is_even_k): + tOpO = cute.make_rmem_tensor(cute.size(tOcO, mode=[2]), Boolean) + for k in cutlass.range(cute.size(tOpO), unroll_full=True): + tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size + # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO) + + load_O_partial = partial( + self.load_O_partial, + gmem_tiled_copy_O_partial, + tOrOptr, + tOsO_partial, + tOhidx, + tOSplitCount, + tOpO, + tOcO, + mO_partial_cur.layout, + ) + + # Load first few stages of O_partial + for stage in cutlass.range(self.stages - 1, unroll_full=True): + if stage < num_splits: + load_O_partial(stage, stage) + cute.arch.cp_async_commit_group() + + # =============================== + # Step 3: Load and transpose LSE from smem to registers + # =============================== + + # Wait for LSE and initial O partial stages to complete + cute.arch.cp_async_wait_group(self.stages - 1) + cute.arch.sync_threads() + # if cute.arch.thread_idx()[0] == 0: + # # cute.print_tensor(sLSE) + # for i in range(64): + # cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0]) + # cute.arch.sync_threads() + + s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) + ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) + ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE) + cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) + if const_expr(mLSE_temperature_partial is not None): + ts2rsLSE_temperature = s2r_thr_copy_LSE.partition_S(sLSE_temperature) + ts2rrLSE_temperature = cute.make_rmem_tensor_like(ts2rsLSE_temperature) + cute.copy( + s2r_tiled_copy_LSE, + ts2rsLSE_temperature, + ts2rrLSE_temperature, + ) + + # =============================== + # Step 4: Compute final LSE along split dimension + # =============================== + + final_lse = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32) + ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) + # We compute the max valid split for each row to short-circuit the computation later + max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32) + assert cute.size(ts2rrLSE, mode=[0]) == 1 + # Compute max, scales, and final LSE for each row. Invalid splits + # have already been filled with -inf, so Step 5 can write the + # kernel-native LSE_out directly. + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + # Find max LSE value across splits + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_max = cute.arch.warp_reduction_max( + ts2rrLSE[None, None, m] + .load() + .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + threads_in_group=threads_per_col, + ) + # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max) + # Find max valid split index + max_valid_idx = -1 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + if ts2rrLSE[0, s, m] != -Float32.inf: + max_valid_idx = ts2rcLSE[0, s, 0][0] # Get split coordinate + # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx) + max_valid_split[m] = cute.arch.warp_reduction_max( + max_valid_idx, threads_in_group=threads_per_col + ) + # Compute exp scales and sum + lse_max_cur = ( + 0.0 if lse_max == -Float32.inf else lse_max + ) # In case all local LSEs are -inf + LOG2_E = math.log2(math.e) + lse_sum_cur = 0.0 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + scale = cute.math.exp2( + ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E), fastmath=True + ) + lse_sum_cur += scale + ts2rrLSE[0, s, m] = scale # Store scale for later use + lse_sum_cur = cute.arch.warp_reduction_sum( + lse_sum_cur, threads_in_group=threads_per_col + ) + # Normalize scales + inv_sum = 0.0 + if max_valid_split[m] < 0 or lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur: + final_lse[m] = -Float32.inf + else: + final_lse[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max + inv_sum = 1.0 / lse_sum_cur + ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) + # Store the scales exp(lse - lse_logsum) back to smem + cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) + + if const_expr(mLSE_temperature_partial is not None): + final_lse_temperature = cute.make_rmem_tensor( + cute.size(ts2rrLSE_temperature, mode=[2]), Float32) + for m in cutlass.range(cute.size(ts2rrLSE_temperature, mode=[2]), unroll_full=True): + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_temperature_max = cute.arch.warp_reduction_max( + ts2rrLSE_temperature[None, None, m] + .load() + .reduce( + cute.ReductionOp.MAX, + init_val=-Float32.inf, + reduction_profile=0, + ), + threads_in_group=threads_per_col, + ) + lse_temperature_max_cur = ( + 0.0 if lse_temperature_max == -Float32.inf else lse_temperature_max + ) + LOG2_E = math.log2(math.e) + lse_temperature_sum_cur = 0.0 + for s in cutlass.range( + cute.size(ts2rrLSE_temperature, mode=[1]), unroll_full=True): + scale = cute.math.exp2( + ts2rrLSE_temperature[0, s, m] * LOG2_E + - (lse_temperature_max_cur * LOG2_E), + fastmath=True, + ) + lse_temperature_sum_cur += scale + lse_temperature_sum_cur = cute.arch.warp_reduction_sum( + lse_temperature_sum_cur, threads_in_group=threads_per_col + ) + if ( + max_valid_split[m] < 0 + or lse_temperature_sum_cur == 0.0 + or lse_temperature_sum_cur != lse_temperature_sum_cur + ): + final_lse_temperature[m] = -Float32.inf + else: + final_lse_temperature[m] = ( + cute.math.log(lse_temperature_sum_cur, fastmath=True) + + lse_temperature_max + ) + + # Store max valid split to smem + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + if mi < self.tile_m: + sMaxValidSplit[mi] = max_valid_split[m] + + # =============================== + # Step 5: Store final LSE to gmem + # This writeback is the authoritative LSE_out returned by the + # public Sparse Attention / Sparse Page Attention interface. + # =============================== + + if const_expr(mLSE is not None): + if const_expr(cu_seqlens is None): + mLSE_cur = mLSE[None, None, batch_idx] + else: + mLSE_cur = cute.domain_offset((offset, 0), mLSE) + if const_expr(mLSE_temperature is not None): + if const_expr(cu_seqlens is None): + mLSE_temperature_cur = mLSE_temperature[None, None, batch_idx] + else: + mLSE_temperature_cur = cute.domain_offset( + (offset, 0), mLSE_temperature) + if k_block == 0: # Only first k_block writes LSE when mLSE is provided + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + idx = m_block * self.tile_m + mi + if idx < max_idx: + m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod) + mLSE_cur[m_idx, head_idx] = final_lse[m] + if const_expr(mLSE_temperature is not None): + mLSE_temperature_cur[m_idx, head_idx] = ( + final_lse_temperature[m]) + + # =============================== + # Step 6: Read O_partial and accumulate final O + # =============================== + + cute.arch.sync_threads() + + # Get max valid split for this thread + thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] + for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True): + thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]]) + + tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32) + tOrO.fill(0.0) + + stage_load = self.stages - 1 + stage_compute = 0 + + # Main accumulation loop + for s in cutlass.range(thr_max_valid_split + 1, unroll=4): + # Get scales for this split + scale = cute.make_rmem_tensor(num_rows, Float32) + for m in cutlass.range(num_rows, unroll_full=True): + scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem + + # Load next stage if needed + split_to_load = s + self.stages - 1 + if split_to_load <= thr_max_valid_split: + load_O_partial(split_to_load, stage_load) + cute.arch.cp_async_commit_group() + stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1 + + # Wait for the current stage to be ready + cute.arch.cp_async_wait_group(self.stages - 1) + # We don't need __syncthreads() because each thread is just reading its own data from smem + # Copy from smem to registers + cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial) + stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1 + + # Accumulate scaled partial results + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0 and scale[m] > 0.0: + tOrO[None, m, None].store( + tOrO[None, m, None].load() + + scale[m] * tOrO_partial[None, m, None].load().to(Float32) + ) + + # Flush any outstanding async-copy groups before the local Step-7 + # permutation buffer is read on the tail of the kernel. + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + + # =============================== + # Step 7: Write final O to gmem (fake→real via SMEM) + # =============================== + + mO_cur = seqlen_info.offset_batch(mO, batch_idx, dim=3) + if const_expr(cu_seqlens is None): + mO_cur = mO[None, None, None, batch_idx] + else: + mO_cur = cute.domain_offset((offset, 0, 0), mO) + mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur) + num_vals = const_expr(cute.size(tOcO, mode=[0])) + if const_expr(not use_pdl): + # Direct / standalone calls don't participate in the K1->K2 + # dependency chain. Use a simple per-element real-column store + # path here to keep mixed-shape launches stable. + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0: + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + if const_expr(self.is_even_k) or tOpO[k]: + for v in cutlass.range(num_vals, unroll_full=True): + fake_col = tOcO[v, 0, k][1] + if const_expr(self.use_stg128_fp8_layout): + real_col = stg128_fp8_fake_col_to_real_col(fake_col) + elif const_expr(self.use_stg128_half_layout): + real_col = stg128_half_fake_col_to_real_col(fake_col) + else: + real_col = stg128_fake_col_to_real_col(fake_col) + o_val = tOrO[v, m, k] + if const_expr(mOutputScale is not None): + o_val = o_val * output_scale + mO_cur[tOmidx[m], real_col, tOhidx[m]] = o_val.to(self.dtype) + else: + # 7a: fp32 accumulator -> output dtype SMEM with fake→real + # permutation. The dedicated permutation buffer stays separate + # from the O_partial pipeline staging buffer. + sO_perm = sO_perm_buf + + if const_expr(self.dtype in [cutlass.BFloat16, cutlass.Float16]): + # O_partial uses a dtype-specific STG.128 fake layout, but + # sO_perm is in the final O dtype. For all supported fake + # layouts, adjacent fake pairs map to adjacent real columns, + # so write the final BF16/F16 O pair as one 32-bit SMEM store. + assert num_vals % 2 == 0 + r2s_o_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=32, + ) + rO_pair_word = cute.make_rmem_tensor((1,), cutlass.Int32) + sO_perm_i32_base = cute.make_ptr( + dtype=cutlass.Int32, + value=sO_perm.iterator.toint(), + mem_space=sO_perm.iterator.memspace, + assumed_align=4, + ) + sO_perm_i32_row_stride = Int32((self.k_block_size + 16) // 2) + for m in cutlass.range(num_rows, unroll_full=True): + row_local = tOcO[0, m, 0][0] + if tOhidx[m] >= 0: + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + for v_pair in cutlass.range(num_vals // 2, unroll_full=True): + v = v_pair * 2 + fake_col = tOcO[v, 0, k][1] + if const_expr(self.use_stg128_fp8_layout): + real_col = stg128_fp8_fake_col_to_real_col(fake_col) + elif const_expr(self.use_stg128_half_layout): + real_col = stg128_half_fake_col_to_real_col(fake_col) + else: + real_col = stg128_fake_col_to_real_col(fake_col) + o0 = tOrO[v, m, k] + o1 = tOrO[v + 1, m, k] + if const_expr(mOutputScale is not None): + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), + (output_scale, output_scale), + ) + rO_pair_word[0] = utils.cvt_f16x2_f32(o0, o1, self.dtype) + smem_pair_ptr = cute.make_ptr( + dtype=cutlass.Int32, + value=( + sO_perm_i32_base.toint() + + Int64( + row_local * sO_perm_i32_row_stride + + real_col // Int32(2) + ) + * Int64(4) + ), + mem_space=sO_perm.iterator.memspace, + assumed_align=4, + ) + sO_pair = cute.make_tensor( + smem_pair_ptr, + cute.make_layout((1,), stride=(1,)), + ) + cute.copy(r2s_o_pair_atom, rO_pair_word, sO_pair) + else: + # 7a: iterate over ALL val elements in mode[0]. + # tOcO[v, m, k][1] gives different fake_col for each v. + r2s_o_scalar_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=self.dtype.width, + ) + rO_scalar = cute.make_rmem_tensor((1,), self.dtype) + for m in cutlass.range(num_rows, unroll_full=True): + row_local = tOcO[0, m, 0][0] + if tOhidx[m] >= 0: + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + for v in cutlass.range(num_vals, unroll_full=True): + fake_col = tOcO[v, 0, k][1] + if const_expr(self.use_stg128_fp8_layout): + real_col = stg128_fp8_fake_col_to_real_col(fake_col) + elif const_expr(self.use_stg128_half_layout): + real_col = stg128_half_fake_col_to_real_col(fake_col) + else: + real_col = stg128_fake_col_to_real_col(fake_col) + o_val = tOrO[v, m, k] + if const_expr(mOutputScale is not None): + o_val = o_val * output_scale + rO_scalar[0] = o_val.to(self.dtype) + smem_ptr = utils.elem_pointer(sO_perm, (row_local, real_col)) + smem_scalar_ptr = cute.make_ptr( + dtype=self.dtype, + value=smem_ptr.toint(), + mem_space=sO_perm.iterator.memspace, + assumed_align=self.dtype.width // 8, + ) + sO_scalar = cute.make_tensor( + smem_scalar_ptr, + cute.make_layout((1,), stride=(1,)), + ) + cute.copy(r2s_o_scalar_atom, rO_scalar, sO_scalar) + + cute.arch.sync_threads() + + # 7b: SMEM (real order, output dtype) → GMEM + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOcO_store = gmem_thr_copy_O.partition_D(cO) + tOsO_store = gmem_thr_copy_O.partition_D(sO_perm) + rO = cute.make_rmem_tensor(tOcO_store.shape, self.dtype) + elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1])) + num_store_rows = const_expr(cute.size(tOcO_store, mode=[1])) + num_store_vals = const_expr(cute.size(tOcO_store, mode=[0])) + tOpO_store = None + if const_expr(not self.is_even_k): + tOpO_store = cute.make_rmem_tensor(cute.size(tOcO_store, mode=[2]), Boolean) + for k in cutlass.range(cute.size(tOpO_store), unroll_full=True): + tOpO_store[k] = ( + tOcO_store[0, 0, k][1] + < mO_partial.shape[1] - k_block * self.k_block_size + ) + + # Read output dtype from SMEM (now in real column order). + for m in cutlass.range(num_store_rows, unroll_full=True): + for k in cutlass.range(cute.size(tOcO_store, mode=[2]), unroll_full=True): + if const_expr(self.is_even_k) or tOpO_store[k]: + cute.autovec_copy(tOsO_store[None, m, k], rO[None, m, k]) + + # Write bf16 to GMEM using gmem_tiled_copy_O (same as original FA Step 7) + for m in cutlass.range(num_store_rows, unroll_full=True): + row_local = tOcO_store[0, m, 0][0] + idx = m_block * self.tile_m + row_local + if idx < max_idx: + m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod) + mO_cur_copy = cute.tiled_divide( + mO_cur[m_idx, None, head_idx], (elems_per_store,) + ) + for k in cutlass.range(cute.size(tOcO_store, mode=[2]), unroll_full=True): + k_idx = tOcO_store[0, 0, k][1] // elems_per_store + if const_expr(self.is_even_k) or tOpO_store[k]: + cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx]) + + @cute.jit + def load_O_partial( + self, + gmem_tiled_copy_O_partial: cute.TiledCopy, + tOrOptr: cute.Tensor, + tOsO_partial: cute.Tensor, + tOhidx: cute.Tensor, + tOSplitCount: cute.Tensor, + tOpO: Optional[cute.Tensor], + tOcO: cute.Tensor, + mO_cur_partial_layout: cute.Layout, + split: Int32, + stage: Int32, + ) -> None: + elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1])) + tOsO_partial_cur = tOsO_partial[None, None, None, stage] + for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True): + if tOhidx[m] >= 0: + o_gmem_ptr = cute.make_ptr( + tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16 + ) + mO_partial_cur = cute.make_tensor( + o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)) + ) + mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_load + if split < tOSplitCount[m] and (const_expr(tOpO is None) or tOpO[k]): + cute.copy( + gmem_tiled_copy_O_partial, + mO_partial_cur_copy[None, k_idx, split], + tOsO_partial_cur[None, m, k], + ) + else: + tOsO_partial_cur[None, m, k].fill(0) + + +def _get_cutlass_dtype(torch_dtype: torch.dtype): + if torch_dtype not in torch2cute_dtype_map: + raise TypeError(f"Unsupported dtype: {torch_dtype}") + return torch2cute_dtype_map[torch_dtype] + + +_combine_compile_cache = {} + + +def _get_cpasync_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout: + dtype_byte = const_expr(dtype.width // 8) + bytes_per_row = const_expr(k_dim * dtype_byte) + smem_k_block_size = ( + const_expr( + 128 + if bytes_per_row % 128 == 0 + else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16)) + ) + // dtype_byte + ) + swizzle_bits = ( + 4 + if smem_k_block_size == 128 + else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1)) + ) + swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4) + return cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base), + 0, + cute.make_ordered_layout( + (8 if const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), + order=(1, 0), + ), + ) + + +def combine( + o_partial_fake, + lse_partial, + o_out, + lse_out, + *, + lse_temperature_partial=None, + lse_temperature_out=None, + cu_seqlens=None, + seqused=None, + split_counts=None, + output_scale=None, + use_pdl=False, +): + """K2: merge sparse forward split partials into the final output. + + STG.128 fake-layout handling remains an internal implementation detail. + When lse_out is provided, the kernel writes the final authoritative + log-sum-exp for each query row/head directly into that tensor. + + Args: + o_partial_fake: + Batched: [num_splits, batch, Sq, head_q, dim] + Varlen: [num_splits, total_q, head_q, dim] + lse_partial: + Batched: [num_splits, batch, Sq, head_q] + Varlen: [num_splits, total_q, head_q] + o_out: + Batched: [batch, Sq, head_q, dim] + Varlen: [total_q, head_q, dim] + lse_out: + Batched: [batch, Sq, head_q] + Varlen: [total_q, head_q] + lse_temperature_partial: + Optional temperature-scaled LSE partial with the same shape as + lse_partial. + lse_temperature_out: + Optional temperature-scaled final LSE with the same shape as + lse_out. + cu_seqlens: Optional [batch + 1] int32 for varlen-Q combine. + seqused: Optional [batch] int32 effective lengths for combine. + split_counts: Optional int32 rowwise valid split counts prepared from + q2k metadata. Batched: [batch, seqlen, head_kv]. Varlen: + [total_q, head_kv]. + output_scale: Optional fp32 tensor with at least one element. When + provided, the final O accumulator is multiplied once before store. + use_pdl: When True, wait on PDL dependencies from the producer K1 + kernel. When False, launch without PDL waits. + """ + D = o_partial_fake.shape[-1] + num_splits = o_partial_fake.shape[0] + return_temperature_lse = ( + lse_temperature_partial is not None or lse_temperature_out is not None + ) + if (lse_temperature_partial is None) != (lse_temperature_out is None): + raise ValueError( + "lse_temperature_partial and lse_temperature_out must either both be provided or both be None" + ) + if lse_temperature_partial is not None and lse_temperature_partial.shape != lse_partial.shape: + raise ValueError( + "lse_temperature_partial must have the same shape as lse_partial, " + f"got {lse_temperature_partial.shape} vs {lse_partial.shape}" + ) + if lse_temperature_out is not None: + if lse_out is None: + raise ValueError("lse_temperature_out requires lse_out") + if lse_temperature_out.shape != lse_out.shape: + raise ValueError( + "lse_temperature_out must have the same shape as lse_out, " + f"got {lse_temperature_out.shape} vs {lse_out.shape}" + ) + if lse_temperature_out.dtype != torch.float32 or lse_temperature_partial.dtype != torch.float32: + raise TypeError("temperature LSE tensors must be torch.float32") + + partial_dtype = _get_cutlass_dtype(o_partial_fake.dtype) + out_dtype = _get_cutlass_dtype(o_out.dtype) + if output_scale is not None: + if output_scale.dtype != torch.float32: + raise TypeError(f"output_scale must be torch.float32, got {output_scale.dtype}") + if output_scale.numel() < 1: + raise ValueError("output_scale must contain at least one element") + if output_scale.device != o_out.device: + raise ValueError("output_scale must be on the same device as o_out") + output_scale = output_scale.contiguous() + if split_counts is not None: + if split_counts.dtype != torch.int32: + raise TypeError(f"split_counts must be torch.int32, got {split_counts.dtype}") + if o_out.ndim == 4: + if split_counts.ndim != 3: + raise ValueError( + f"batched split_counts must have shape [batch, seqlen, head_kv], got {split_counts.shape}" + ) + if split_counts.shape[:2] != o_out.shape[:2]: + raise ValueError( + f"split_counts shape {split_counts.shape} must match batch/seqlen of o_out {o_out.shape}" + ) + else: + if cu_seqlens is None: + raise ValueError("split_counts with varlen output requires cu_seqlens") + if split_counts.ndim != 2: + raise ValueError( + f"varlen split_counts must have shape [total_q, head_kv], got {split_counts.shape}" + ) + if split_counts.shape[0] != o_out.shape[0]: + raise ValueError( + f"split_counts total_q ({split_counts.shape[0]}) must match o_out total_q " + f"({o_out.shape[0]})" + ) + if o_out.shape[-2] % split_counts.shape[-1] != 0: + raise ValueError( + f"o_out heads ({o_out.shape[-2]}) must be divisible by split_counts heads ({split_counts.shape[-1]})" + ) + qheadperkv = o_out.shape[-2] // split_counts.shape[-1] + else: + qheadperkv = 1 + if cu_seqlens is not None: + if cu_seqlens.dtype != torch.int32: + raise TypeError(f"cu_seqlens must be torch.int32, got {cu_seqlens.dtype}") + if cu_seqlens.ndim != 1: + raise ValueError(f"cu_seqlens must be rank-1, got {cu_seqlens.shape}") + if not cu_seqlens.is_contiguous(): + raise ValueError("cu_seqlens must be contiguous") + if seqused is not None: + if seqused.dtype != torch.int32: + raise TypeError(f"seqused must be torch.int32, got {seqused.dtype}") + if seqused.ndim != 1: + raise ValueError(f"seqused must be rank-1, got {seqused.shape}") + if not seqused.is_contiguous(): + raise ValueError("seqused must be contiguous") + + k_block_size = 128 if D > 64 else 64 + tile_m = 64 + has_cu_seqlens = cu_seqlens is not None + has_seqused = seqused is not None + has_lse = lse_out is not None + has_split_counts = split_counts is not None + has_output_scale = output_scale is not None + min_blocks_per_mp = 3 if has_output_scale and use_pdl else 0 + + key = ( + "combine", + D, + k_block_size, + tile_m, + num_splits, + partial_dtype, + out_dtype, + has_cu_seqlens, + has_seqused, + has_lse, + bool(return_temperature_lse), + has_split_counts, + has_output_scale, + use_pdl, + min_blocks_per_mp, + ) + if key not in _combine_compile_cache: + from ....src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _combine_compile_cache[key] = loaded + else: + from ....quack.compile_utils import make_fake_tensor + + kernel = SparseAttentionForwardCombine( + dtype=out_dtype, + dtype_partial=partial_dtype, + head_dim=D, + tile_m=tile_m, + k_block_size=k_block_size, + topk=num_splits, + use_pdl=use_pdl, + min_blocks_per_mp=min_blocks_per_mp, + # stages=2 halves per-block SMEM (168 KB -> 103 KB) -> 2 blocks/SM, + # theoretical occupancy 12.5% -> 25%. NCU DRAM throughput 76.35% + # -> 88.64%. Runtime latency within noise (kernel already at HBM + # bandwidth ceiling in practice) but the cleaner SOL profile + # matters for downstream NCU comparison. + stages=2, + ) + div = 128 // partial_dtype.width + if has_cu_seqlens: + total_q, nheads = (cute.sym_int64() for _ in range(2)) + mO_partial = make_fake_tensor( + partial_dtype, (num_splits, total_q, nheads, D), divisibility=div + ) + mLSE_partial = make_fake_tensor( + Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=2 + ) + mO = make_fake_tensor( + out_dtype, (total_q, nheads, D), divisibility=128 // out_dtype.width + ) + mLSE = ( + make_fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=1) + if has_lse + else None + ) + mLSE_temperature_partial = ( + make_fake_tensor( + Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=2 + ) + if return_temperature_lse + else None + ) + mLSE_temperature = ( + make_fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=1) + if return_temperature_lse + else None + ) + else: + batch, sq, nheads = (cute.sym_int64() for _ in range(3)) + mO_partial = make_fake_tensor( + partial_dtype, (num_splits, batch, sq, nheads, D), divisibility=div + ) + mLSE_partial = make_fake_tensor( + Float32, (num_splits, batch, sq, nheads), divisibility=1, leading_dim=3 + ) + mO = make_fake_tensor( + out_dtype, (batch, sq, nheads, D), divisibility=128 // out_dtype.width + ) + mLSE = ( + make_fake_tensor(Float32, (batch, sq, nheads), divisibility=1, leading_dim=2) + if has_lse + else None + ) + mLSE_temperature_partial = ( + make_fake_tensor( + Float32, (num_splits, batch, sq, nheads), divisibility=1, leading_dim=3 + ) + if return_temperature_lse + else None + ) + mLSE_temperature = ( + make_fake_tensor(Float32, (batch, sq, nheads), divisibility=1, leading_dim=2) + if return_temperature_lse + else None + ) + if not has_split_counts: + mSplitCounts = None + elif has_cu_seqlens: + total_q_ctr, nheads_kv = (cute.sym_int64() for _ in range(2)) + mSplitCounts = make_fake_tensor( + Int32, (total_q_ctr, nheads_kv), divisibility=1, leading_dim=1 + ) + else: + nheads_kv = cute.sym_int64() + mSplitCounts = make_fake_tensor( + Int32, (batch, sq, nheads_kv), divisibility=1, leading_dim=2 + ) + mOutputScale = ( + make_fake_tensor(Float32, (cute.sym_int64(),), divisibility=1, leading_dim=0) + if has_output_scale + else None + ) + stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + _combine_compile_cache[key] = cute.compile( + kernel, + mO_partial, + mLSE_partial, + mO, + mLSE, + mLSE_temperature_partial, + mLSE_temperature, + None + if cu_seqlens is None + else make_fake_tensor(Int32, (cute.sym_int64(),), divisibility=1, leading_dim=0), + None + if seqused is None + else make_fake_tensor(Int32, (cute.sym_int64(),), divisibility=1, leading_dim=0), + None, + None, + None, + mSplitCounts, + mOutputScale, + Int32(qheadperkv), + stream, + options="--enable-tvm-ffi", + ) + save_aot(key, _combine_compile_cache[key]) + + with torch.cuda.nvtx.range("K2_Combine"): + _combine_compile_cache[key]( + o_partial_fake, + lse_partial, + o_out, + lse_out, + lse_temperature_partial, + lse_temperature_out, + cu_seqlens, + seqused, + None, + None, + None, + split_counts, + output_scale, + qheadperkv, + ) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d64a0616bd5bb9c987e43b87bcbf9e89001fbb36 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/__init__.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""CUTE DSL launchers for paged fp8 decode forward.""" + +from __future__ import annotations + +import torch + +from .atten_fwd import run_decode_attention +from .combine import run_decode_combine + + +def decode_forward_paged_fp8( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + request_indices: torch.Tensor, + qo_tile_indices: torch.Tensor, + kv_tile_indices: torch.Tensor, + block_valid_mask: torch.Tensor, + split_counts: torch.Tensor, + o_indptr: torch.Tensor, + merge_indptr: torch.Tensor, + O_partial: torch.Tensor | None, + LSE_partial: torch.Tensor | None, + *, + softmax_scale: float, + seqlen_q: int, + page_size: int, + kv_chunk_size_pages: int, + max_split_count: int, + split_kv: bool, + causal: bool, + return_lse: bool = True, + O_partial_dummy: torch.Tensor | None = None, + LSE_partial_dummy: torch.Tensor | None = None, +) -> None: + """Launch dense paged fp8 decode forward and optional compressed combine. + + ``O_partial_dummy`` / ``LSE_partial_dummy`` are caller-provided pre-allocated + placeholder buffers for the non-split path. When supplied, ``run_decode_attention`` + skips the per-call ``torch.empty`` it would otherwise need to satisfy the + kernel's positional arg signature, saving ~5us on small-kv calls. + """ + + run_decode_attention( + q, + k, + v, + page_table, + seqused_k, + request_indices, + qo_tile_indices, + kv_tile_indices, + block_valid_mask, + split_counts, + o_indptr, + out, + lse, + O_partial, + LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=int(seqlen_q), + page_size=int(page_size), + kv_chunk_size_pages=int(kv_chunk_size_pages), + split_kv=bool(split_kv), + causal=bool(causal), + return_lse=bool(return_lse), + O_partial_dummy=O_partial_dummy, + LSE_partial_dummy=LSE_partial_dummy, + ) + if split_kv: + if O_partial is None or LSE_partial is None: + raise ValueError("split decode requires O_partial and LSE_partial") + qhead_per_kv = q.shape[1] // k.shape[1] + q_tokens_per_group = 128 // int(qhead_per_kv) + run_decode_combine( + O_partial, + LSE_partial, + split_counts, + o_indptr, + out, + lse, + seqlen_q=int(seqlen_q), + q_tokens_per_group=q_tokens_per_group, + max_split_count=int(max_split_count), + ) + + +__all__ = ["decode_forward_paged_fp8", "run_decode_attention", "run_decode_combine"] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..9a56bb20363deffd4c850533484427bc128b3c84 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py @@ -0,0 +1,2691 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Dense paged fp8 decode forward path. + +This file owns the CUTE DSL entry point for decode attention via +``SparseDecodeAttentionForwardSm100`` — SM100 UTCMMA + persistent +scheduling, paged fp8 Q/K/V, BSA blk128-style intra-warp overlap pipeline. +Forward only. +""" + +from __future__ import annotations + +import math +from functools import partial +from typing import Callable, Optional + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.pipeline as cutlass_pipeline +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cutlass_dsl import BaseDSL +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +from ....quack import copy_utils, layout_utils + +from ....src.common import pipeline +from ....src.common import blackwell_helpers as sm100_helpers +from ....src.common import mma_sm100_desc as sm100_desc +from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map +from ....src.common.named_barrier import NamedBarrierFwdSm100 +from ....src.common.softmax import SoftmaxSm100 +from ....src.common.mask import AttentionMask +from ....src.common.seqlen_info import SeqlenInfoQK +from ....src.common.pack_gqa import pack_gqa_layout +from ....src.common.tile_scheduler import SchedulingMode +from ....src.sm100.fwd_decode.tile_scheduler import ( + DecodeTileScheduler, + DecodeTileSchedulerArguments, +) + + +class SparseDecodeAttentionForwardSm100: + """SM100 dense paged fp8 decode forward attention (UTCMMA + CLC). + + Scope (Phase 1): + - Dense decode, ``split_kv=False``, single q-tile per work item + (``packed_q = seqlen_q * qhead_per_kv <= tile_m=128``). + - Causal only. KV reverse page loop; first reverse block applies + causal/seqlen mask, the rest is unmasked. + - fp8 Q/K/V, bf16 O, fp32 LSE. P is quantized to fp8_e4m3fn before PV + via ``SoftmaxSm100.apply_exp2_convert`` (mirror of prefill fp8 PV). + - per-batch ``mSeqUsedK[b]`` heterogeneous; no uniform-length assumptions. + + Production scope reached at Phase 4+: + - Multi q-tile (Phase 2), split-KV partial writeback (Phase 3), + CLC persistent scheduling (Phase 4), TC SOL >= 90% (Phase 7). + """ + + # UTCMMA K-tile width (matches prefill SparseAttentionForwardSm100). + k_tile = 64 + + def __init__( + self, + head_dim: int = 128, + qhead_per_kv: int = 16, + m_block_size: int = 128, + n_block_size: int = 128, + page_size: int = 128, + split_kv: bool = False, + causal: bool = True, + write_lse: bool = True, + disable_softmax_exp2: bool = False, + ): + # --- structural constraints (Phase 1 scope) ------------------------- + if head_dim != 128: + raise NotImplementedError( + f"SparseDecodeAttentionForwardSm100 currently supports only D=128, " + f"got D={head_dim}" + ) + if m_block_size != 128: + raise NotImplementedError( + f"decode UMMA forward requires tile_m=128, got {m_block_size}" + ) + if n_block_size != 128: + raise NotImplementedError( + f"decode UMMA forward requires n_block_size=128, got {n_block_size}" + ) + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal n_block_size ({n_block_size})" + ) + if qhead_per_kv not in (16, 8, 4, 2, 1): + raise ValueError( + f"qhead_per_kv must be in {{1, 2, 4, 8, 16}}, got {qhead_per_kv}" + ) + if not causal: + raise NotImplementedError( + "decode UMMA forward currently supports only causal=True" + ) + + self.head_dim = int(head_dim) + self.qhead_per_kv = int(qhead_per_kv) + self.m_block_size = int(m_block_size) + self.n_block_size = int(n_block_size) + self.page_size = int(page_size) + self.tile_m = int(m_block_size) + self.split_kv = bool(split_kv) + self.causal = bool(causal) + self.write_lse = bool(write_lse) + self.disable_softmax_exp2 = bool(disable_softmax_exp2) + # FA fp8 SM100 fwd uses a threshold of 4.0 to avoid rescaling O for + # small row-max movements; correction receives acc_scale directly. + self.rescale_threshold = 4.0 + + # q tokens packed per (m_block_size) row group along M. + self.q_tokens_per_group = self.m_block_size // self.qhead_per_kv + + self.mma_tiler_qk = (self.m_block_size, self.n_block_size, self.head_dim) + self.mma_tiler_pv = (self.m_block_size, self.head_dim, self.n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + + # --- pipeline ring stages (BSA blk128 q_stage=1, s_stage=2) --- + self.q_stage = 1 + self.s_stage = 2 + self.o_stage = 2 + # Keep the fp8 decode KV ring deep enough to cover the K0/Q/K1/V0... + # order. This matches sage's fp8 setting and removes the underfed + # two-stage KV pipeline seen in the q8/16K non-split case. + self.kv_stage = 4 + self.k_stages = 2 + # Match prefill: PV is split at 3/4 of n_block_size for fp8. The + # producer (P store) must publish exactly 3N/4 fp8 columns at the + # signal point; that requires the TMEM-store atom Repetition to be + # ``8`` (one PV ``f8f6f4`` K=32 segment = 8 fp32 packed cols), so + # ``shape[2]=4`` chunks and ``split_idx=3`` lands on the 3N/4 + # boundary exactly. The previous N/2 cap was a workaround for + # ``Repetition(16)`` whose coarser chunk boundary could not + # represent 3N/4. + self.split_P_arrive = self.n_block_size // 4 * 3 + self.split_P_arrive = int(self.split_P_arrive / 32) * 32 + assert self.split_P_arrive % 32 == 0 + assert self.split_P_arrive < self.n_block_size + + # --- warp layout (16 warps / 512 threads) — BSA-aligned (Phase 1.10.6b) + # 0-3 softmax WG 0 + # 4-7 softmax WG 1 + # 8-11 correction WG (acc_O rescale across pages + final epilogue + # write-back; participates in TmemPtr barrier) + # 12 MMA issue warp + # 13 spare / future CLC scheduler + # 14 load warp (serial Q + K + V TMA loads) + # 15 empty / register-budget reserve + self.warps_per_group = 4 + self.softmax0_warp_base = 0 + self.softmax1_warp_base = self.softmax0_warp_base + self.warps_per_group + self.correction_warp_base = ( + self.softmax1_warp_base + self.warps_per_group) + self.mma_warp_id = self.correction_warp_base + self.warps_per_group + self.spare_warp_id = self.mma_warp_id + 1 + self.load_warp_id = self.spare_warp_id + 1 + self.empty_warp_id = self.load_warp_id + 1 + self.total_warps = 16 + self.threads_per_cta = cute.arch.WARP_SIZE * self.total_warps + + # --- TMEM layout (fp8 P width-pack: 4 fp8 lanes per fp32 column) --- + # S0/S1: [0:128], [128:256] + # O0/O1: [256:384], [384:512] for head_dim_v=128 + # P (fp8) overlays the second half of each S tile via recast_ptr. + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + self.tmem_s_offset = 0 + self.tmem_stage_stride = self.n_block_size + self.tmem_o_stage_stride = self.head_dim + self.tmem_o_offset = self.s_stage * self.n_block_size + # fp8 P occupies n_block_size * fp8_width / fp32_width = n/4 fp32 cols. + # P offset is set in __call__ once q_dtype is known (defer to Phase 1.3). + raw_tmem_total = self.tmem_o_offset + self.o_stage * self.tmem_o_stage_stride + # SM100 TMEM allocation requires a power-of-two column count. + self.tmem_total = 1 << (raw_tmem_total - 1).bit_length() + + # --- register budget per role (BSA hdim>=96 default) --- + self.num_regs_softmax = 184 + self.num_regs_correction = 88 + self.num_regs_other = 56 + self.num_regs_mma = self.num_regs_other + self.num_regs_load = self.num_regs_other + self.num_regs_epilogue = self.num_regs_other + self.num_regs_empty = self.num_regs_other + + # exp2 emulation for causal: matches prefill ex2_emu_freq=16. + # disable_softmax_exp2 (Phase 7 SOL gate) bypasses both emulation and + # native exp2 — the convert pass becomes a pure fp32 -> fp8 cast. + self.ex2_emu_freq = 16 if (self.causal and not self.disable_softmax_exp2) else 0 + self.ex2_emu_start_frg = 1 + self.buffer_align_bytes = 1024 + + # --- SM100 cluster config (single-CTA for decode, no 2-CTA pair) - + self.use_2cta_instrs = False + self.cta_group_size = 1 + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnk = (1, 1, 1) + self.use_clc_scheduler = True + self.scheduling_mode = SchedulingMode.CLC + self.sched_stages = 2 + self.clc_scheduler_warp_id = self.empty_warp_id + + self.arch = BaseDSL._get_dsl().get_arch_enum() + + # ------------------------------------------------------------------ + # Host-side: TMA descriptors, SMEM layout, launch + # Phase 1.2+ fills in the body. Phase 1.1 keeps signatures stable so + # the rest of the codepath (run_decode_attention dispatch in 1.10) + # can wire to this class without further churn. + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # [B, Sq, Hq, D] fp8 + mK: cute.Tensor, # [num_pages, Hkv, page_size, D] fp8 + mV: cute.Tensor, # [num_pages, Hkv, page_size, D] fp8 + mPageTable: cute.Tensor, # [B, max_pages] int32 + mSeqUsedK: cute.Tensor, # [B] int32 + mRequestIndices: cute.Tensor, # [work_capacity] int32 + mQoTileIndices: cute.Tensor, # [work_capacity] int32 + mKvTileIndices: cute.Tensor, # [work_capacity] int32 + mBlockValidMask: cute.Tensor, # [work_capacity] int32 + mSplitCounts: cute.Tensor, # [B] int32 + mOIndptr: cute.Tensor, # [B + 1] int32 + mO: cute.Tensor, # [total_q, Hq, D] bf16 + mLSE: cute.Tensor, # [total_q, Hq] fp32 + mO_partial: Optional[cute.Tensor], + mLSE_partial: Optional[cute.Tensor], + softmax_scale: Float32, + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + stream: cuda.CUstream = None, + ): + # --- dtype contract ------------------------------------------------ + if const_expr(mQ.element_type is not cutlass.Float8E4M3FN): + raise TypeError("decode UMMA Q must be Float8E4M3FN") + if const_expr(mK.element_type is not cutlass.Float8E4M3FN): + raise TypeError("decode UMMA K must be Float8E4M3FN") + if const_expr(mV.element_type is not cutlass.Float8E4M3FN): + raise TypeError("decode UMMA V must be Float8E4M3FN") + if const_expr(mO.element_type is not cutlass.BFloat16): + raise TypeError("decode UMMA output O must be BFloat16") + if const_expr(mLSE.element_type is not Float32): + raise TypeError("decode UMMA output LSE must be Float32") + if const_expr(self.split_kv): + if const_expr(mO_partial is None or mO_partial.element_type is not Float32): + raise TypeError("decode UMMA split path requires Float32 O_partial") + if const_expr(mLSE_partial is None or mLSE_partial.element_type is not Float32): + raise TypeError("decode UMMA split path requires Float32 LSE_partial") + + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type + self.o_dtype = ( + mO_partial.element_type if const_expr(self.split_kv) + else mO.element_type + ) + # f8f6f4 MMA descriptor kind for fp8 Q/K/V. + self.mma_kind = "f8f6f4" + # fp8 P width-pack ratio: each fp32 TMEM column holds 4 fp8 P lanes. + # Computed here so __init__ stays dtype-agnostic and the TMEM offsets + # can later be derived from this ratio in Phase 1.3. + elem_bytes = const_expr(self.q_dtype.width // 8) + p_cols_as_fp32 = const_expr( + self.n_block_size * self.q_dtype.width // Float32.width + ) + self.tmem_s_to_p_offset = self.n_block_size - p_cols_as_fp32 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + + mQ, mK, mV, mO, mLSE = [ + assume_tensor_aligned(t) for t in (mQ, mK, mV, mO, mLSE) + ] + if const_expr(mO_partial is not None): + mO_partial = assume_tensor_aligned(mO_partial) + if const_expr(mLSE_partial is not None): + mLSE_partial = assume_tensor_aligned(mLSE_partial) + mO_epilogue = mO_partial if const_expr(self.split_kv) else mO + self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO_epilogue) + self.epi_tile = (self.m_block_size, self.head_dim) + + # ------------------------------------------------------------------ + # UTCMMA TiledMma: QK^T + PV. PV uses MN-major V operand (V already + # transposed in the layout below) and a TMEM operand source for P. + # Phase 1.4 builds tiled_mma_qk; Phase 1.5 adds tiled_mma_pv so sV + # layout can derive the MN-major swizzle. + # ------------------------------------------------------------------ + cta_group = tcgen05.CtaGroup.ONE + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, + Float32, cta_group, self.mma_tiler_qk[:2]) + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN, + Float32, cta_group, self.mma_tiler_pv[:2], tcgen05.OperandSource.TMEM) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # Paged K/V tensor view permutation. + # Input layout [num_pages, Hkv, page_size, D] (nhsd) is permuted to + # [page_size, D, Hkv, num_pages] for the paged TMA descriptor (K). + # V gets an additional (s,d) swap to become MN-major: + # [D, page_size, Hkv, num_pages]. + # ------------------------------------------------------------------ + mK_paged = cute.make_tensor( + mK.iterator, cute.select(mK.layout, mode=[2, 3, 1, 0]) + ) + mV_kv = cute.make_tensor( + mV.iterator, cute.select(mV.layout, mode=[2, 3, 1, 0]) + ) + mV_paged = cute.make_tensor( + mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2, 3]) + ) + + # ------------------------------------------------------------------ + # Q SMEM layout + BSA/FA PackGQA full-tile TMA atom. + # + # Runtime Q is [B, Sq, Hq, D]. We transpose to [Sq, D, Hq, B], then + # fold qhead_per_kv into the M dimension: + # ((qhead_per_kv, Sq), D, Hkv, B) + # This lets one Q TMA load cover the whole packed (tile_m, D) tile + # instead of issuing one TMA per q token. + # ------------------------------------------------------------------ + total_q_stages = self.q_stage + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, total_q_stages) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + mQ = cute.make_tensor( + mQ.iterator, cute.select(mQ.layout, mode=[1, 3, 2, 0])) + nheads_kv = mK.shape[1] + mQ = pack_gqa_layout(mQ, self.qhead_per_kv, nheads_kv, head_idx=2) + tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mQ, + cute.select(sQ_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + cta_layout_vmnk.shape, + ) + + # ------------------------------------------------------------------ + # K / V SMEM layouts + TMA atoms (paged). + # sK uses the QK MMA operand B swizzle; sV uses the PV MMA operand B + # swizzle (MN-major). tP_layout is the TMEM-side P descriptor — no + # SMEM is actually allocated for P, it overlays the S region in TMEM + # via cute.recast_ptr in Phase 1.7. + # ------------------------------------------------------------------ + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage) + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage) + tP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.s_stage) + + tma_atom_K, mK_paged = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mK_paged, cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, tiled_mma_qk, cta_layout_vmnk.shape) + tma_atom_V, mV_paged = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mV_paged, cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_pv, tiled_mma_pv, cta_layout_vmnk.shape) + + # ------------------------------------------------------------------ + # Phase 1.10.6b-B-2: TMA-store atom for the epilogue write-back. + # Non-split writes bf16 final O; split-KV writes fp32 O_partial. + # sO follows FA/BSA epilogue layout: one full m_block x D tile in + # SMEM. Both paths expose global O as a packed-GQA tensor view so the + # final store is a full BSA-style m_block x D TMA tile. + # ------------------------------------------------------------------ + sO_layout = sm100_utils.make_smem_layout_epi( + self.o_dtype, + self.o_layout, + self.epi_tile, + self.q_stage, + ) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + num_heads_kv_tma = mK.shape[1] + total_o_rows_tma = ( + mO_epilogue.shape[0] + // (num_heads_kv_tma * self.qhead_per_kv) + ) + head_stride_tma = self.head_dim + o_row_stride_tma = ( + num_heads_kv_tma * self.qhead_per_kv * self.head_dim) + kv_head_stride_tma = self.qhead_per_kv * self.head_dim + mO_epilogue_tma = cute.make_tensor( + mO_epilogue.iterator, + cute.make_layout( + ((self.qhead_per_kv, total_o_rows_tma), self.head_dim, num_heads_kv_tma), + stride=((head_stride_tma, o_row_stride_tma), 1, kv_head_stride_tma), + ), + ) + tma_atom_O, mO_tma = cpasync.make_tiled_tma_atom( + tma_store_op, + mO_epilogue_tma, + cute.select(sO_layout, mode=[0, 1]), + self.epi_tile, + ) + + # Pre-multiply softmax scale by log2(e) so the inner exp2 path can + # operate without re-scaling at every iteration. Mirrors prefill. + softmax_scale_log2 = softmax_scale * Float32(math.log2(math.e)) + + work_capacity = mRequestIndices.shape[0] + num_heads_kv = mK.shape[1] + tile_sched_args = DecodeTileSchedulerArguments( + Int32(work_capacity), + Int32(num_heads_kv), + cluster_shape_mn=self.cluster_shape_mn, + ) + tile_sched_params = DecodeTileScheduler.to_underlying_arguments( + tile_sched_args, + scheduling_mode=self.scheduling_mode, + ) + self.tile_scheduler_cls = DecodeTileScheduler + grid = DecodeTileScheduler.get_grid_shape(tile_sched_params) + + clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 + clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0 + + # ------------------------------------------------------------------ + # SharedStorage mirrors BSA blk128's pipeline mesh for dense paged + # decode: Q, shared K/V, S/P/O, P-lastsplit, O-acc, O-epilogue and + # softmax stats mbarriers, plus the TMEM allocator state and SMEM + # staging tensors. + # ------------------------------------------------------------------ + @cute.struct + class SharedStorage: + mbar_load_Q: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_load_KV: cute.struct.MemRange[Int64, self.kv_stage * 2] + mbar_S_full_P_full_O_rescaled: cute.struct.MemRange[ + Int64, self.s_stage * 2] + mbar_P_full_lastsplit: cute.struct.MemRange[ + Int64, self.s_stage * 2] + mbar_O_full: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_softmax_stats0: cute.struct.MemRange[Int64, 2] + mbar_softmax_stats1: cute.struct.MemRange[Int64, 2] + mbar_O_epi: cute.struct.MemRange[Int64, self.s_stage * 2] + # Phase 1.10.6b-B-2: bf16 sO SMEM staging buffer for the TMA + # store epilogue. Sized for one full m_block_size × head_dim + # tile (single stage; overlap with sQ left for later perf tune). + sO: cute.struct.Align[ + cute.struct.MemRange[self.o_dtype, cute.cosize(sO_layout)], + self.buffer_align_bytes, + ] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: Int32 + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] + clc_response: cute.struct.MemRange[Int32, clc_response_size] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes, + ] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # ------------------------------------------------------------------ + # Launch — decode tasks are consumed from the + # (work_idx, head_kv_idx) scheduler space. In CLC mode grid is the + # BSA-style hardware problem shape; in static mode it is capped to the + # SM count and each CTA walks the flattened task stream. + # ------------------------------------------------------------------ + # q_tma_bytes (and Phase 1.5+: kv_tma_bytes / q_subtile_bytes) are + # recomputed inside the kernel from the constexpr SMEM layouts. + # Passing them as Constexpr[int] kernel args ended up marshalling + # to dynamic Int32 here, which then tripped MbarrierArray's + # `if tx_count < 0` check inside PipelineTmaUmma.create. + self.kernel( + mQ, mK_paged, mV_paged, + mPageTable, mSeqUsedK, + mRequestIndices, mQoTileIndices, mKvTileIndices, mBlockValidMask, + mSplitCounts, mOIndptr, + mO, mO_tma, mLSE, + mO_partial, mLSE_partial, + softmax_scale_log2, + sQ_layout, sK_layout, sV_layout, tP_layout, sO_layout, + tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O, + tiled_mma_qk, tiled_mma_pv, + tile_sched_params, + seqlen_q, page_size, kv_chunk_size_pages, + Int32(num_heads_kv), + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=( + self.cluster_shape_mnk + if cute.size(self.cluster_shape_mnk) > 1 else None + ), + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + # --- runtime tensors ------------------------------------------------- + mQ: cute.Tensor, # [((qhead_per_kv, Sq), D, Hkv, B)] + mK_paged: cute.Tensor, # [page_size, D, Hkv, num_pages] fp8 + mV_paged: cute.Tensor, # [D, page_size, Hkv, num_pages] fp8 + mPageTable: cute.Tensor, + mSeqUsedK: cute.Tensor, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mBlockValidMask: cute.Tensor, + mSplitCounts: cute.Tensor, + mOIndptr: cute.Tensor, + mO: cute.Tensor, + mO_tma: cute.Tensor, + mLSE: cute.Tensor, + mO_partial: Optional[cute.Tensor], + mLSE_partial: Optional[cute.Tensor], + # --- scalars --------------------------------------------------------- + softmax_scale_log2: Float32, + # --- SMEM layouts ---------------------------------------------------- + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + tP_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + # --- TMA atoms ------------------------------------------------------- + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_O: cute.CopyAtom, + # --- TiledMma -------------------------------------------------------- + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tile_sched_params: DecodeTileScheduler.Params, + # --- Int32 iteration bounds ------------------------------------------ + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + num_heads_kv: Int32, + ): + # ------------------------------------------------------------------ + # Thread / warp identity, work-item dispatch. + # ------------------------------------------------------------------ + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + if warp_idx == Int32(0): + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_O) + + # ------------------------------------------------------------------ + # SMEM allocation — same SharedStorage type was registered on the + # class in __call__ (Phase 1.3). Every warp materialises the same + # storage view; later phases populate sQ/sK/sV/mbar contents. + # ------------------------------------------------------------------ + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + # sQ is the MMA-operand layout and now also the Q TMA load target: + # PackGQA makes the global Q view match the full BSA (tile_m, D) tile. + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) + + # ------------------------------------------------------------------ + # TMEM allocator — MMA warp performs the allocation, all softmax / + # store / MMA warps participate in the TmemPtr named barrier that + # broadcasts the allocator pointer. Spare warp and KV-load warps + # do not touch TMEM directly. + # ------------------------------------------------------------------ + # TmemPtr participants: 2 softmax WGs (8 warps) + correction WG + # (4 warps) + MMA warp = 13 warps × WARP_SIZE. Load / spare / + # empty warps don't touch TMEM and don't arrive on this barrier. + tmem_alloc_warps: cutlass.Constexpr[int] = ( + self.warps_per_group * 3 + 1) + tmem_alloc_threads = cute.arch.WARP_SIZE * tmem_alloc_warps + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), + num_threads=tmem_alloc_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + ) + tmem_cols = self.tmem_total + + # ------------------------------------------------------------------ + # Cluster layout + warp-specialized pipelines. + # Mirrors prefill (src/sm100/fwd/atten_fwd.py:617-683): cta_layout_vmnk + # is rebuilt in-kernel from tiled_mma_qk.thr_id.shape so its size is + # constexpr (the `cute.size(cta_layout_vmnk) == 1` check inside + # PipelineTmaUmma.create folds at compile time). pipeline_q is + # joined by the BSA S/P/O and shared K/V pipelines below. + # ------------------------------------------------------------------ + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + ThreadCooperativeGroup = partial( + cutlass_pipeline.CooperativeGroup, cutlass_pipeline.Agent.Thread) + tma_thread = ThreadCooperativeGroup(1) + mma_thread = ThreadCooperativeGroup(1) + # One softmax WG participates per S/P/O stage; correction and the + # epilogue warp handle O rescale and TMA write-back. + softmax_warps = ThreadCooperativeGroup(self.warps_per_group) + softmax_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group) + + # Recompute TMA byte counts inside the kernel from the constexpr SMEM + # layouts — see note in __call__ above the self.kernel(...) call for + # why these can't be plumbed through as Constexpr[int] kernel args. + q_tma_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + k_tma_bytes = cute.size_in_bytes( + self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) + + pipeline_q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_load_Q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=q_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + # Decode KV follows BSA's single K/V ring: K0 is primed before Q, + # then K1, V0, K2, V1, ... share one PipelineTmaUmma state while + # landing in separate sK/sV SMEM tensors. For fp8 decode K/V TMA + # tiles have the same byte count, so the shared barrier uses K's count. + pipeline_kv = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_load_KV.data_ptr(), + num_stages=self.kv_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=k_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + # ------------------------------------------------------------------ + # BSA pipeline mesh. + # pipeline_s_p_o — MMA→{softmax,correction} (8-warp cluster + # consumer). MMA producer_commit signals + # "S ready"; consumer_release signals "P stored + # and acc_O rescaled — MMA can issue next QK". + # pipeline_o_acc — MMA→correction (acc_O updated by PV). + # pipeline_sm_stats0/1 — softmax→correction stage-local stats. + # This avoids the per-warp NamedBarrier used by + # the BSA reference while preserving the same + # first/rescale/final signal sequence. + # pipeline_o_epi — correction→epilogue warp 13 (final O ready). + # ------------------------------------------------------------------ + softmax_correction_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE + * (self.warps_per_group + self.warps_per_group) # = 256 + ) + correction_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group # = 128 + ) + epilogue_warp_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE # warp 13 = 32 threads + ) + + pipeline_s_p_o = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_S_full_P_full_O_rescaled.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=softmax_correction_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p_lastsplit = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_P_full_lastsplit.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_warps, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_o_acc = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_O_full.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=correction_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_sm_stats0 = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_softmax_stats0.data_ptr(), + num_stages=1, + producer_group=softmax_threads, + consumer_group=correction_threads, + defer_sync=True, + ) + pipeline_sm_stats1 = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_softmax_stats1.data_ptr(), + num_stages=1, + producer_group=softmax_threads, + consumer_group=correction_threads, + defer_sync=True, + ) + pipeline_o_epi = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_O_epi.data_ptr(), + num_stages=self.s_stage, + producer_group=correction_threads, + consumer_group=epilogue_warp_threads, + defer_sync=True, + ) + + # Fence mbar init across all regular pipelines. CLC pipeline setup + # follows the BSA ordering: arrive after mbar init, create scheduler + # state, then wait before TMEM allocation and role dispatch. + pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + + if const_expr(self.use_clc_scheduler): + clc_response_ptr = storage.clc_response.data_ptr() + clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() + clc_pipeline_producer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread + ) + num_clc_consumer_warps = ( + self.threads_per_cta // cute.arch.WARP_SIZE + ) * self.cta_group_size + clc_pipeline_consumer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread, + cute.arch.WARP_SIZE * num_clc_consumer_warps, + ) + clc_pipeline = cutlass_pipeline.PipelineClcFetchAsync.create( + barrier_storage=clc_mbar_ptr, + num_stages=self.sched_stages, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ) + tile_scheduler = self.tile_scheduler_cls.create( + tile_sched_params, clc_response_ptr=clc_response_ptr + ) + clc_consumer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, + self.sched_stages, + ) + tile_scheduler.set_clc_pipeline( + clc_pipeline, clc_consumer_state) + else: + clc_pipeline = None + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + + pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + # Single load warp issues Q + K + V TMA serially; no inter-warp + # broadcast / Q-load WG barrier needed (the BSA-aligned layout + # collapses the previous 4-warp Q-load fan-out into one warp). + + # ------------------------------------------------------------------ + # Phase 1.10.3: pre-dispatch TMEM partitions for softmax read/write. + # Mirrors prefill softmax body setup + # (src/sm100/fwd/atten_fwd.py:807-829, 1891-1921). Built once across + # all warps so each softmax WG can take its stage slice. + # ------------------------------------------------------------------ + thr_mma_qk_pre = tiled_mma_qk.get_slice(0) + qk_acc_shape_pre = thr_mma_qk_pre.partition_shape_C( + self.mma_tiler_qk[:2]) + tStS_base_pre = thr_mma_qk_pre.make_fragment_C(qk_acc_shape_pre) + tStS_pre = cute.make_tensor( + tStS_base_pre.iterator, + cute.append( + tStS_base_pre.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tScS_pre = thr_mma_qk_pre.partition_C( + cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScS_pre = tScS_pre[(None, None), 0, 0] + # fp8 P occupies n_block_size * fp8_width / fp32_width fp32 cols. + tilePlikeFP32 = const_expr( + self.mma_tiler_qk[1] * self.q_dtype.width // Float32.width) + tmem_load_atom_pre = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype, + ) + # Repetition(8) gives ``tStP_r2t.shape[2] = tilePlikeFP32 / 8 = 4`` + # chunks for fp8 (tilePlikeFP32=32), with each chunk publishing + # 8 fp32 cols = 32 fp8 cols = exactly one PV ``f8f6f4`` K=32 + # segment. ``split_idx = 4 * 3N/4 / N = 3`` aligns the early + # publish edge to the producer/consumer K boundary. Larger + # Repetition (e.g. 16) would coarsen shape[2] to 2 and force + # split_idx to floor to 1, publishing only N/2 of P before MMA's + # first three K=32 segments need cols 0..3N/4 — that mismatch is + # the NaN source the workaround used to dodge with split=N/2. + tmem_store_atom_pre = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), + Float32, + ) + tmem_store_vec_atom_pre = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + tmem_load_vec_atom_pre = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + + # ------------------------------------------------------------------ + # Warp role dispatch. Bodies are filled in Phase 1.3-1.9: + # softmax WG 0/1 (warps 0-3, 4-7) — softmax + P fp32->fp8 convert + # store / Q-load WG (warps 8-11) — Q TMA gather + epilogue store + # MMA warp (warp 12) — UTCMMA QK + PV issue + # correction WG (warps 8-11) — per-page acc_O rescale + epilogue + # MMA warp (warp 12) — UTCMMA QK + PV issue + # spare warp (warp 13) — empty / future CLC scheduler + # load warp (warp 14) — serial Q + K + V TMA loads + # empty warp (warp 15) — register-budget reserve + # ------------------------------------------------------------------ + is_softmax0_warp = ( + warp_idx >= Int32(self.softmax0_warp_base) + and warp_idx < Int32(self.softmax1_warp_base) + ) + is_softmax1_warp = ( + warp_idx >= Int32(self.softmax1_warp_base) + and warp_idx < Int32(self.correction_warp_base) + ) + is_correction_warp = ( + warp_idx >= Int32(self.correction_warp_base) + and warp_idx < Int32(self.mma_warp_id) + ) + is_mma_warp = warp_idx == Int32(self.mma_warp_id) + is_spare_warp = warp_idx == Int32(self.spare_warp_id) + is_load_warp = warp_idx == Int32(self.load_warp_id) + is_empty_warp = warp_idx == Int32(self.empty_warp_id) + + if const_expr(self.use_clc_scheduler): + if warp_idx == Int32(self.clc_scheduler_warp_id): + cute.arch.setmaxregister_decrease(self.num_regs_empty) + self.clc_scheduler_warp(clc_pipeline, tile_scheduler) + is_empty_warp = False + + if is_softmax0_warp: + cute.arch.setmaxregister_increase(self.num_regs_softmax) + tmem.wait_for_alloc() + tmem_ptr_wg0 = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr_wg0 + self.softmax_loop( + 0, + self.softmax0_warp_base, + softmax_scale_log2, + tStS_pre, + tScS_pre, + tilePlikeFP32, + tmem_load_atom_pre, + tmem_store_atom_pre, + tmem_store_vec_atom_pre, + thr_mma_qk_pre, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_sm_stats0, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + seqlen_q, + page_size, + kv_chunk_size_pages, + ) + tmem_alloc_barrier.arrive() + + if is_softmax1_warp: + cute.arch.setmaxregister_increase(self.num_regs_softmax) + tmem.wait_for_alloc() + tmem_ptr_wg1 = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr_wg1 + self.softmax_loop( + 1, + self.softmax1_warp_base, + softmax_scale_log2, + tStS_pre, + tScS_pre, + tilePlikeFP32, + tmem_load_atom_pre, + tmem_store_atom_pre, + tmem_store_vec_atom_pre, + thr_mma_qk_pre, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_sm_stats1, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + seqlen_q, + page_size, + kv_chunk_size_pages, + ) + tmem_alloc_barrier.arrive() + + if is_correction_warp: + cute.arch.setmaxregister_decrease(self.num_regs_correction) + # Participate in TmemPtr handshake so the MMA warp can free. + tmem.wait_for_alloc() + tmem_ptr_corr = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr_corr + + self.correction_loop( + tiled_mma_pv, + tStS_pre, + tScS_pre, + tmem_load_vec_atom_pre, + pipeline_s_p_o, + pipeline_sm_stats0, + pipeline_sm_stats1, + pipeline_o_acc, + pipeline_o_epi, + sO, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mSplitCounts, + mOIndptr, + mLSE, + mLSE_partial, + mBlockValidMask, + tile_scheduler, + seqlen_q, + page_size, + kv_chunk_size_pages, + num_heads_kv, + softmax_scale_log2, + ) + tmem_alloc_barrier.arrive() + + if is_spare_warp: + cute.arch.setmaxregister_decrease(self.num_regs_epilogue) + self.epilogue_s2g( + mO_tma, + sO, + tma_atom_O, + pipeline_o_epi, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mOIndptr, + mBlockValidMask, + tile_scheduler, + seqlen_q, + ) + + if is_load_warp: + self.load( + tiled_mma_qk, + tiled_mma_pv, + mQ, + mK_paged, + mV_paged, + sQ, + sK, + sV, + mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_q, + pipeline_kv, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + page_size, + kv_chunk_size_pages, + ) + + if is_empty_warp: + cute.arch.setmaxregister_decrease(self.num_regs_empty) + + if is_mma_warp: + cute.arch.setmaxregister_decrease(self.num_regs_mma) + # ---------------------------------------------------------------- + # MMA warp — Phase 1.6: QK fp8×fp8→fp32 UMMA. Phase 1.10.1 now + # wraps the body in the real TMEM allocator lifecycle: + # tmem.allocate(cols) -> wait_for_alloc -> retrieve_ptr + # -> ... QK work ... + # -> relinquish_alloc_permit -> tmem_alloc_barrier.arrive_and_wait + # -> free(ptr, cols) + # Softmax WG 0/1 participate via wait_for_alloc + retrieve_ptr + + # tmem_alloc_barrier.arrive (4+4+1 = 9 warps). + # ---------------------------------------------------------------- + tmem.allocate(tmem_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr # consumed by gemm_pv via raw TMEM offsets + + self.mma( + sQ, + sK, + sV, + tP_layout, + tiled_mma_qk, + tiled_mma_pv, + pipeline_q, + pipeline_kv, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_o_acc, + mRequestIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + page_size, + kv_chunk_size_pages, + ) + + # Phase 1.10.1: TMEM allocator teardown. + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr, num_columns=tmem_cols) + + @cute.jit + def clc_scheduler_warp( + self, + clc_pipeline: cutlass_pipeline.PipelineClcFetchAsync, + tile_scheduler: DecodeTileScheduler, + ) -> None: + clc_producer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Producer, + self.sched_stages, + ) + clc_consumer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, + self.sched_stages, + ) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + clc_pipeline.producer_acquire(clc_producer_state) + mbarrier_addr = clc_pipeline.producer_get_barrier( + clc_producer_state) + tile_scheduler.advance_to_next_work( + mbarrier_addr=mbarrier_addr, + response_stage=clc_producer_state.index, + ) + clc_producer_state.advance() + + clc_pipeline.consumer_wait(clc_consumer_state) + work_tile = tile_scheduler.get_current_work( + response_stage=clc_consumer_state.index) + clc_pipeline.consumer_release(clc_consumer_state) + clc_consumer_state.advance() + clc_pipeline.producer_tail(clc_producer_state) + + @cute.jit + def correction_loop( + self, + tiled_mma_pv: cute.TiledMma, + tStS_pre: cute.Tensor, + tScS_pre: cute.Tensor, + tmem_load_vec_atom_pre: cute.CopyAtom, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_sm_stats0: pipeline.PipelineAsync, + pipeline_sm_stats1: pipeline.PipelineAsync, + pipeline_o_acc: pipeline.PipelineAsync, + pipeline_o_epi: pipeline.PipelineAsync, + sO: cute.Tensor, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mSplitCounts: cute.Tensor, + mOIndptr: cute.Tensor, + mLSE: cute.Tensor, + mLSE_partial: Optional[cute.Tensor], + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + num_heads_kv: Int32, + softmax_scale_log2: Float32, + ) -> None: + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + warp_idx_in_wg_corr = warp_idx - Int32(self.correction_warp_base) + group_tidx_corr = ( + warp_idx_in_wg_corr * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + + # First iter: no correction is required. Notify MMA that the + # initial O slots are available, matching BSA's correction_loop. + for stage_init in cutlass.range_constexpr(self.s_stage): + pipeline_s_p_o.consumer_release_w_index(Int32(stage_init)) + + o_corr_consumer_phase = Int32(0) + sm_stats0_consumer_phase = Int32(0) + sm_stats1_consumer_phase = Int32(0) + corr_epi_producer_phase = Int32(1) + + thr0_rs = tiled_mma_pv.get_slice(0) + pv_acc_shape_rs_c = thr0_rs.partition_shape_C( + self.mma_tiler_pv[:2]) + tOtO_base_rs_c = thr0_rs.make_fragment_C(pv_acc_shape_rs_c) + tOtO_rs_c = cute.make_tensor( + tOtO_base_rs_c.iterator + Int32(self.tmem_o_offset), + cute.append( + tOtO_base_rs_c.layout, + cute.make_layout( + (self.s_stage,), + stride=(self.tmem_o_stage_stride,), + ), + ), + ) + tScS_vec_layout_corr = cute.composition( + tScS_pre.layout, cute.make_layout((self.m_block_size, 2))) + tScS_vec_corr = cute.make_tensor( + tScS_pre.iterator, tScS_vec_layout_corr) + tSAcc_corr0 = tStS_pre[(None, None), 0, 0, 0] + tSAcc_corr1 = tStS_pre[(None, None), 0, 0, 1] + tStS_vec0_layout_corr = cute.composition( + tSAcc_corr0.layout, cute.make_layout((self.m_block_size, 2))) + tStS_vec1_layout_corr = cute.composition( + tSAcc_corr1.layout, cute.make_layout((self.m_block_size, 2))) + tStStats0_t2r_src = cute.make_tensor( + tSAcc_corr0.iterator, tStS_vec0_layout_corr) + tStStats1_t2r_src = cute.make_tensor( + tSAcc_corr1.iterator, tStS_vec1_layout_corr) + thr_tmem_load_vec = tcgen05.make_tmem_copy( + tmem_load_vec_atom_pre, + tStStats0_t2r_src, + ).get_slice(group_tidx_corr) + tStStats0_t2r = thr_tmem_load_vec.partition_S(tStStats0_t2r_src) + tStStats1_t2r = thr_tmem_load_vec.partition_S(tStStats1_t2r_src) + tScStats_t2r = thr_tmem_load_vec.partition_D(tScS_vec_corr) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, head_kv_idx, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx_corr = mRequestIndices[work_idx] + qo_tile_corr = mQoTileIndices[work_idx] + seqused_k_corr = mSeqUsedK[batch_idx_corr] + split_idx_corr = mKvTileIndices[work_idx] + kv_pages_corr = ( + seqused_k_corr + page_size - Int32(1)) // page_size + kv_page_begin_corr = split_idx_corr * kv_chunk_size_pages + kv_page_end_corr = cutlass.min( + kv_pages_corr, + kv_page_begin_corr + kv_chunk_size_pages, + ) + page_count_corr = kv_page_end_corr - kv_page_begin_corr + block_iter_count_corr = ( + page_count_corr + Int32(1)) & ~Int32(1) + stage0_count_corr = block_iter_count_corr // Int32(2) + stage1_count_corr = block_iter_count_corr // Int32(2) + + if stage0_count_corr > Int32(0): + pipeline_sm_stats0.consumer_wait_w_index_phase( + Int32(0), sm_stats0_consumer_phase) + sm_stats0_consumer_phase = ( + sm_stats0_consumer_phase ^ Int32(1)) + pipeline_sm_stats0.consumer_release_w_index(Int32(0)) + if stage1_count_corr > Int32(0): + pipeline_sm_stats1.consumer_wait_w_index_phase( + Int32(0), sm_stats1_consumer_phase) + sm_stats1_consumer_phase = ( + sm_stats1_consumer_phase ^ Int32(1)) + pipeline_sm_stats1.consumer_release_w_index(Int32(0)) + + for page_rel_corr in cutlass.range( + Int32(self.s_stage), block_iter_count_corr, unroll=1 + ): + # sm_stats[0] now holds the deferred-exp2 log2-delta: + # 0.0 means "no rescale needed", a negative value is the + # raw delta that needs exp2 to become a true scale factor. + if (page_rel_corr & Int32(1)) == Int32(0): + pipeline_sm_stats0.consumer_wait_w_index_phase( + Int32(0), sm_stats0_consumer_phase) + sm_stats0_consumer_phase = ( + sm_stats0_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats0_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + scale_corr_log2 = tSrStats[0] + pipeline_sm_stats0.consumer_release_w_index(Int32(0)) + should_rescale = ( + cute.arch.vote_ballot_sync( + scale_corr_log2 < Float32(0.0)) != 0 + ) + if should_rescale: + scale_corr = cute.math.exp2( + scale_corr_log2, fastmath=True) + self.correction_rescale( + tiled_mma_pv, + tOtO_rs_c[None, None, None, 0], + group_tidx_corr, + scale_corr, + ) + pipeline_s_p_o.consumer_release_w_index(Int32(0)) + else: + pipeline_sm_stats1.consumer_wait_w_index_phase( + Int32(0), sm_stats1_consumer_phase) + sm_stats1_consumer_phase = ( + sm_stats1_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats1_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + scale_corr_log2 = tSrStats[0] + pipeline_sm_stats1.consumer_release_w_index(Int32(0)) + should_rescale = ( + cute.arch.vote_ballot_sync( + scale_corr_log2 < Float32(0.0)) != 0 + ) + if should_rescale: + scale_corr = cute.math.exp2( + scale_corr_log2, fastmath=True) + self.correction_rescale( + tiled_mma_pv, + tOtO_rs_c[None, None, None, 1], + group_tidx_corr, + scale_corr, + ) + pipeline_s_p_o.consumer_release_w_index(Int32(1)) + + for stage_wait in cutlass.range_constexpr(self.s_stage): + stage_count_wait = ( + stage0_count_corr + if const_expr(stage_wait == 0) + else stage1_count_corr + ) + if stage_count_wait > Int32(0): + pipeline_o_acc.consumer_wait_w_index_phase( + Int32(stage_wait), o_corr_consumer_phase) + + row_sum0 = Float32(0.0) + row_sum1 = Float32(0.0) + row_max0 = -Float32.inf + row_max1 = -Float32.inf + for stage_final in cutlass.range_constexpr(self.s_stage): + if const_expr(stage_final == 0): + pipeline_sm_stats0.consumer_wait_w_index_phase( + Int32(0), sm_stats0_consumer_phase) + sm_stats0_consumer_phase = ( + sm_stats0_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats0_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + row_sum0 = tSrStats[0] + row_max0 = tSrStats[1] + pipeline_sm_stats0.consumer_release_w_index(Int32(0)) + else: + pipeline_sm_stats1.consumer_wait_w_index_phase( + Int32(0), sm_stats1_consumer_phase) + sm_stats1_consumer_phase = ( + sm_stats1_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats1_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + row_sum1 = tSrStats[0] + row_max1 = tSrStats[1] + pipeline_sm_stats1.consumer_release_w_index(Int32(0)) + + zero0 = row_sum0 == Float32(0.0) or row_sum0 != row_sum0 + zero1 = row_sum1 == Float32(0.0) or row_sum1 != row_sum1 + rm0 = -Float32.inf if zero0 else row_max0 + rm1 = -Float32.inf if zero1 else row_max1 + row_max_comb = cutlass.max(rm0, rm1) + row_max_safe = ( + Float32(0.0) + if row_max_comb == -Float32.inf + else row_max_comb + ) + scale0 = ( + Float32(0.0) + if zero0 + else cute.math.exp2( + (rm0 - row_max_safe) * softmax_scale_log2, + fastmath=True, + ) + ) + scale1 = ( + Float32(0.0) + if zero1 + else cute.math.exp2( + (rm1 - row_max_safe) * softmax_scale_log2, + fastmath=True, + ) + ) + row_sum_comb = row_sum0 * scale0 + row_sum1 * scale1 + combined_zero_or_nan = ( + row_sum_comb == Float32(0.0) + or row_sum_comb != row_sum_comb + ) + inv_sum = cute.arch.rcp_approx( + Float32(1.0) + if combined_zero_or_nan else row_sum_comb) + final_scale0 = scale0 * inv_sum + final_scale1 = scale1 * inv_sum + + pipeline_o_epi.producer_acquire_w_index_phase( + Int32(0), corr_epi_producer_phase) + self.correction_epilogue_combine( + tiled_mma_pv, + sO[None, None, 0], + group_tidx_corr, + final_scale0, + final_scale1, + ) + + if const_expr(self.write_lse or self.split_kv): + if group_tidx_corr < Int32(self.m_block_size): + is_bad_lse = ( + row_sum_comb == Float32(0.0) + or row_sum_comb != row_sum_comb + ) + LN2 = Float32(math.log(2.0)) + lse_val = ( + -Float32.inf if is_bad_lse + else ( + row_max_safe * softmax_scale_log2 + + cute.math.log2(row_sum_comb, fastmath=True) + ) * LN2 + ) + tok_lse = group_tidx_corr // Int32(self.qhead_per_kv) + if tok_lse < seqlen_q: + h_in_kv_lse = ( + group_tidx_corr + - tok_lse * Int32(self.qhead_per_kv)) + q_idx_lse = ( + qo_tile_corr * Int32(self.q_tokens_per_group) + + tok_lse + ) + h_abs_lse = ( + head_kv_idx * Int32(self.qhead_per_kv) + + h_in_kv_lse + ) + if const_expr(self.split_kv): + q_tokens_per_group = Int32( + self.q_tokens_per_group) + q_stride_partial = ( + (seqlen_q + q_tokens_per_group - Int32(1)) + // q_tokens_per_group + ) * q_tokens_per_group + partial_row_lse = ( + mOIndptr[batch_idx_corr] + + split_idx_corr * q_stride_partial + + q_idx_lse + ) + mLSE_partial[ + partial_row_lse, h_abs_lse] = lse_val + else: + q_abs_lse = ( + batch_idx_corr * seqlen_q + q_idx_lse) + mLSE[q_abs_lse, h_abs_lse] = lse_val + + for stage_release in cutlass.range_constexpr(self.s_stage): + stage_count_release = ( + stage0_count_corr + if const_expr(stage_release == 0) + else stage1_count_corr + ) + if stage_count_release > Int32(0): + pipeline_s_p_o.consumer_release_w_index( + Int32(stage_release)) + pipeline_o_acc.consumer_release_w_index( + Int32(stage_release)) + if block_iter_count_corr > Int32(0): + o_corr_consumer_phase = ( + o_corr_consumer_phase ^ Int32(1)) + + pipeline_o_epi.producer_commit_w_index(Int32(0)) + corr_epi_producer_phase = ( + corr_epi_producer_phase ^ Int32(1)) + + work_tile = tile_scheduler.consumer_advance() + + pipeline_o_epi.producer_acquire_w_index_phase( + Int32(self.q_stage - 1), corr_epi_producer_phase) + + @cute.jit + def epilogue_s2g( + self, + mO_tma: cute.Tensor, + sO: cute.Tensor, + tma_atom_O: cute.CopyAtom, + pipeline_o_epi: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mOIndptr: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + seqlen_q: Int32, + ) -> None: + epi_consumer_phase = Int32(0) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, head_kv_idx, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx = mRequestIndices[work_idx] + qo_tile = mQoTileIndices[work_idx] + split_idx = mKvTileIndices[work_idx] + + pipeline_o_epi.consumer_wait_w_index_phase( + Int32(0), epi_consumer_phase) + q_tokens_per_group = Int32(self.q_tokens_per_group) + gO = cute.local_tile( + mO_tma[None, None, head_kv_idx], + self.epi_tile, + (None, 0), + ) + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, cute.make_layout(1), sO, gO) + if const_expr(not self.split_kv): + q_abs = ( + batch_idx * seqlen_q + + qo_tile * q_tokens_per_group + ) + dst_idx = q_abs // q_tokens_per_group + else: + q_stride_partial = ( + (seqlen_q + q_tokens_per_group - Int32(1)) + // q_tokens_per_group + ) * q_tokens_per_group + partial_row = ( + mOIndptr[batch_idx] + + split_idx * q_stride_partial + + qo_tile * q_tokens_per_group + ) + dst_idx = partial_row // q_tokens_per_group + store_O(src_idx=Int32(0), dst_idx=dst_idx) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0) + pipeline_o_epi.consumer_release_w_index(Int32(0)) + epi_consumer_phase = epi_consumer_phase ^ Int32(1) + + work_tile = tile_scheduler.consumer_advance() + + @cute.jit + def correction_epilogue_combine( + self, + tiled_mma_pv: cute.TiledMma, + sO: cute.Tensor, + tidx: Int32, + scale0: Float32, + scale1: Float32, + ) -> None: + thr_mma = tiled_mma_pv.get_slice(0) + pv_acc_shape = thr_mma.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO_base = thr_mma.make_fragment_C(pv_acc_shape) + tOtO = cute.make_tensor( + tOtO_base.iterator + Int32(self.tmem_o_offset), + cute.append( + tOtO_base.layout, + cute.make_layout( + (self.s_stage,), + stride=(self.tmem_o_stage_stride,), + ), + ), + ) + tOsO = thr_mma.get_slice(0).partition_C(sO) + tOcO_full = thr_mma.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + corr_tile_size: cutlass.Constexpr[int] = ( + 8 * 32 // self.o_dtype.width + ) + tOsO_i = cute.logical_divide( + tOsO, + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + tOcO_i = cute.logical_divide( + tOcO_full, + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + tOtO0_i = cute.logical_divide( + tOtO[None, None, None, 0], + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + tOtO1_i = cute.logical_divide( + tOtO[None, None, None, 1], + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + epi_subtile = (self.epi_tile[0], corr_tile_size) + tmem_load_atom = sm100_utils.get_tmem_load_op( + self.mma_tiler_pv, + self.o_layout, + self.o_dtype, + self.pv_acc_dtype, + epi_subtile, + use_2cta_instrs=self.use_2cta_instrs, + ) + tiled_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tOtO0_i[(None, None), 0]) + thr_tmem_load = tiled_tmem_load.get_slice(tidx) + smem_copy_atom = sm100_utils.get_smem_store_op( + self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load) + tiled_smem_store = cute.make_tiled_copy_D( + smem_copy_atom, tiled_tmem_load) + tOtO0_t2r = thr_tmem_load.partition_S( + tOtO0_i[(None, None), None]) + tOtO1_t2r = thr_tmem_load.partition_S( + tOtO1_i[(None, None), None]) + tOsO_s2r = copy_utils.partition_D_position_independent( + thr_tmem_load, tOsO_i[(None, None), None]) + tOcO_t2r = thr_tmem_load.partition_D( + tOcO_i[(None, None), None]) + + for col_pass_idx in cutlass.range( + self.head_dim // corr_tile_size, unroll_full=True): + tOtO0_t2r_i = tOtO0_t2r[None, 0, 0, col_pass_idx] + tOtO1_t2r_i = tOtO1_t2r[None, 0, 0, col_pass_idx] + tOsO_r2s_i = tOsO_s2r[None, 0, 0, col_pass_idx] + frg_shape = tOcO_t2r[None, 0, 0, col_pass_idx].shape + tOrO0_frg = cute.make_fragment(frg_shape, self.pv_acc_dtype) + tOrO1_frg = cute.make_fragment(frg_shape, self.pv_acc_dtype) + is_zero_output = ( + scale0 == Float32(0.0) and scale1 == Float32(0.0) + ) + if not is_zero_output: + cute.copy(tiled_tmem_load, tOtO0_t2r_i, tOrO0_frg) + cute.copy(tiled_tmem_load, tOtO1_t2r_i, tOrO1_frg) + for j in cutlass.range( + 0, cute.size(tOrO0_frg), 2, unroll_full=True + ): + o0_a, o0_b = cute.arch.mul_packed_f32x2( + (tOrO0_frg[j], tOrO0_frg[j + 1]), + (scale0, scale0), + ) + o1_a, o1_b = cute.arch.mul_packed_f32x2( + (tOrO1_frg[j], tOrO1_frg[j + 1]), + (scale1, scale1), + ) + tOrO0_frg[j], tOrO0_frg[j + 1] = ( + cute.arch.add_packed_f32x2( + (o0_a, o0_b), (o1_a, o1_b)) + ) + else: + tOrO0_frg.fill(Float32(0.0)) + copy_utils.cvt_copy(tiled_smem_store, tOrO0_frg, tOsO_r2s_i) + cute.arch.fence_view_async_shared() + + @cute.jit + def correction_rescale( + self, + tiled_mma_pv: cute.TiledMma, + tOtO: cute.Tensor, + tidx: Int32, + scale: Float32, + ) -> None: + thr_mma = tiled_mma_pv.get_slice(0) + tOcO = thr_mma.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + corr_tile_size: cutlass.Constexpr[int] = 16 + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp( + tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp( + tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tOtO_i = cute.composition( + tOtO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOcO_i = cute.composition( + tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tOtO_i).get_slice(tidx) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom, tOtO_i).get_slice(tidx) + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i) + tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape + tOtO_r2t = thr_tmem_store.partition_D(tOtO_i) + + frg_count: cutlass.Constexpr[int] = self.head_dim // corr_tile_size + for fi in cutlass.range_constexpr(frg_count): + tOrO_frg = cute.make_fragment( + tOrO_t2r_shape, self.pv_acc_dtype) + tOtO_t2r_i = cute.make_tensor( + tOtO_t2r.iterator + fi * corr_tile_size, + tOtO_t2r.layout, + ) + cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg) + for j in cutlass.range( + 0, cute.size(tOrO_frg), 2, unroll_full=True + ): + tOrO_frg[j], tOrO_frg[j + 1] = ( + cute.arch.mul_packed_f32x2( + (tOrO_frg[j], tOrO_frg[j + 1]), + (scale, scale), + ) + ) + tOtO_r2t_i = cute.make_tensor( + tOtO_r2t.iterator + fi * corr_tile_size, + tOtO_r2t.layout, + ) + cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i) + cute.arch.fence_view_async_tmem_store() + + @cute.jit + def mma( + self, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tP_layout: cute.ComposedLayout, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + pipeline_q: pipeline.PipelineAsync, + pipeline_kv: pipeline.PipelineAsync, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_o_acc: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + page_size: Int32, + kv_chunk_size_pages: Int32, + ) -> None: + thr_mma_qk = tiled_mma_qk.get_slice(0) + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tSrQ0_layout = tSrQ[None, None, None, 0].layout + tSrK0_layout = tSrK[None, None, None, 0].layout + qk_mma_op = tiled_mma_qk.op + q_smem_base = sm100_desc.smem_desc_base_from_tensor( + sQ, sm100_desc.Major.K) + k_smem_base = sm100_desc.smem_desc_base_from_tensor( + sK, sm100_desc.Major.K) + q_smem_start = sm100_desc.make_smem_desc_start_addr( + sQ[None, None, None, 0].iterator) + sm100_helpers.declare_ptx_smem_desc( + q_smem_start, q_smem_base, tSrQ0_layout, + var_name_prefix="decode_q_smem_desc", + ) + sm100_helpers.declare_ptx_idesc( + qk_mma_op, var_name="decode_qk_idesc") + gemm_qk = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0_layout, + smem_var_name_prefix="decode_q_smem_desc", + idesc_var_name="decode_qk_idesc", + smem_offset=0, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + + thr_mma_pv = tiled_mma_pv.get_slice(0) + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_base = thr_mma_qk.make_fragment_C(qk_acc_shape) + tStS = cute.make_tensor( + tStS_base.iterator, + cute.append( + tStS_base.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP_base = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + tP_width_ratio = const_expr(Float32.width // self.v_dtype.width) + tP_stage_stride = const_expr( + self.tmem_stage_stride * tP_width_ratio) + tOrP = cute.make_tensor( + tOrP_base.iterator + self.tmem_p_offset * tP_width_ratio, + cute.append( + tOrP_base.layout, + cute.make_layout((self.s_stage,), stride=(tP_stage_stride,)), + ), + ) + tOrV = tiled_mma_pv.make_fragment_B(sV) + pv_mma_op = tiled_mma_pv.op + sm100_helpers.declare_ptx_idesc( + pv_mma_op, var_name="decode_pv_idesc") + + mma_q_consumer_phase = Int32(0) + mma_kv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kv_stage) + phase_s0 = Int32(0) + phase_s1 = Int32(0) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, _, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx_mma = mRequestIndices[work_idx] + split_idx_mma = mKvTileIndices[work_idx] + seqused_k_mma = mSeqUsedK[batch_idx_mma] + kv_pages_mma = ( + seqused_k_mma + page_size - Int32(1)) // page_size + kv_page_begin_mma = split_idx_mma * kv_chunk_size_pages + kv_page_end_mma = cutlass.min( + kv_pages_mma, + kv_page_begin_mma + kv_chunk_size_pages, + ) + page_count_mma = kv_page_end_mma - kv_page_begin_mma + block_iter_count_mma = ( + page_count_mma + Int32(1)) & ~Int32(1) + + pipeline_q.consumer_wait_w_index_phase( + Int32(0), mma_q_consumer_phase) + mma_q_consumer_phase = mma_q_consumer_phase ^ Int32(1) + if block_iter_count_mma > Int32(0): + pipeline_kv.consumer_wait(mma_kv_consumer_state) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[ + None, None, None, + mma_kv_consumer_state.index, + ].iterator + ) + gemm_qk( + Int32(self.tmem_s_offset), + smem_desc_start_b=k_smem_start, + ) + pipeline_s_p_o.producer_commit_w_index(Int32(0)) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + if block_iter_count_mma > Int32(1): + pipeline_kv.consumer_wait(mma_kv_consumer_state) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[ + None, None, None, + mma_kv_consumer_state.index, + ].iterator + ) + gemm_qk( + Int32(self.tmem_s_offset) + + Int32(self.tmem_stage_stride), + smem_desc_start_b=k_smem_start, + ) + pipeline_s_p_o.producer_commit_w_index(Int32(1)) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + if block_iter_count_mma > Int32(self.s_stage): + for page_rel_pv in cutlass.range( + Int32(0), + block_iter_count_mma - Int32(self.s_stage), + unroll=1, + ): + pv_slot = page_rel_pv & Int32(1) + pv_stage_iter = page_rel_pv // Int32(self.s_stage) + pv_phase = phase_s0 + if pv_slot != Int32(0): + pv_phase = phase_s1 + pipeline_s_p_o.producer_acquire_w_index_phase( + pv_slot, pv_phase) + pipeline_kv.consumer_wait(mma_kv_consumer_state) + v_idx = mma_kv_consumer_state.index + sm100_helpers.gemm_ptx_partial( + pv_mma_op, + Int32(self.tmem_o_offset) + + pv_slot * Int32(self.tmem_o_stage_stride), + tOrP[None, None, None, pv_slot], + tOrV[None, None, None, v_idx], + sA=None, + sB=sV[None, None, None, v_idx], + tA_addr=( + Int32(self.tmem_p_offset) + + pv_slot * Int32(self.tmem_stage_stride) + ), + zero_init=pv_stage_iter == Int32(0), + mbar_ptr=( + pipeline_p_lastsplit + .sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None + ), + mbar_phase=( + pv_phase + if self.split_P_arrive > 0 else None), + split_arrive=( + self.split_P_arrive + if self.split_P_arrive > 0 else None + ), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + if pv_slot == Int32(0): + phase_s0 = phase_s0 ^ Int32(1) + else: + phase_s1 = phase_s1 ^ Int32(1) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + page_rel_qk = page_rel_pv + Int32(self.s_stage) + qk_slot = page_rel_qk & Int32(1) + pipeline_kv.consumer_wait(mma_kv_consumer_state) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[ + None, None, None, + mma_kv_consumer_state.index, + ].iterator + ) + gemm_qk( + Int32(self.tmem_s_offset) + + qk_slot * Int32(self.tmem_stage_stride), + smem_desc_start_b=k_smem_start, + ) + pipeline_s_p_o.producer_commit_w_index(qk_slot) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + pipeline_q.consumer_release_w_index(Int32(0)) + + if block_iter_count_mma > Int32(0): + page_rel_epi_begin = cutlass.max( + Int32(0), + block_iter_count_mma - Int32(self.s_stage), + ) + for page_rel_epi in cutlass.range( + page_rel_epi_begin, block_iter_count_mma, unroll=1 + ): + pv_slot = page_rel_epi & Int32(1) + pv_stage_iter = page_rel_epi // Int32(self.s_stage) + pv_phase = phase_s0 + if pv_slot != Int32(0): + pv_phase = phase_s1 + pipeline_s_p_o.producer_acquire_w_index_phase( + pv_slot, pv_phase) + pipeline_kv.consumer_wait(mma_kv_consumer_state) + v_idx = mma_kv_consumer_state.index + sm100_helpers.gemm_ptx_partial( + pv_mma_op, + Int32(self.tmem_o_offset) + + pv_slot * Int32(self.tmem_o_stage_stride), + tOrP[None, None, None, pv_slot], + tOrV[None, None, None, v_idx], + sA=None, + sB=sV[None, None, None, v_idx], + tA_addr=( + Int32(self.tmem_p_offset) + + pv_slot * Int32(self.tmem_stage_stride) + ), + zero_init=pv_stage_iter == Int32(0), + mbar_ptr=( + pipeline_p_lastsplit + .sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None + ), + mbar_phase=( + pv_phase + if self.split_P_arrive > 0 else None), + split_arrive=( + self.split_P_arrive + if self.split_P_arrive > 0 else None + ), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + pipeline_o_acc.producer_commit_w_index(pv_slot) + if pv_slot == Int32(0): + phase_s0 = phase_s0 ^ Int32(1) + else: + phase_s1 = phase_s1 ^ Int32(1) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + work_tile = tile_scheduler.consumer_advance() + + @cute.jit + def softmax_loop( + self, + stage: cutlass.Constexpr[int], + warp_base: cutlass.Constexpr[int], + softmax_scale_log2: Float32, + tStS_pre: cute.Tensor, + tScS_pre: cute.Tensor, + tilePlikeFP32: cutlass.Constexpr[int], + tmem_load_atom_pre: cute.CopyAtom, + tmem_store_atom_pre: cute.CopyAtom, + tmem_store_vec_atom_pre: cute.CopyAtom, + thr_mma_qk_pre: cute.core.ThrMma, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + ) -> None: + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + warp_idx_in_wg = warp_idx - Int32(warp_base) + group_tidx = ( + warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + stage_i32 = Int32(stage) + + tSAcc = tStS_pre[(None, None), 0, 0, stage] + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom_pre, tSAcc).get_slice(group_tidx) + tStS_t2r = thr_tmem_load.partition_S(tSAcc) + tScS_t2r = thr_tmem_load.partition_D(tScS_pre) + tStP_layout = cute.composition( + tSAcc.layout, + cute.make_layout((self.m_block_size, tilePlikeFP32)), + ) + tStP = cute.make_tensor( + tSAcc.iterator + self.tmem_s_to_p_offset, + tStP_layout, + ) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom_pre, tStP).get_slice(group_tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + tScS_vec_layout = cute.composition( + tScS_pre.layout, cute.make_layout((self.m_block_size, 2))) + tScS_vec = cute.make_tensor(tScS_pre.iterator, tScS_vec_layout) + tStS_vec_layout = cute.composition( + tSAcc.layout, cute.make_layout((self.m_block_size, 2))) + tStStats_r2t_dst = cute.make_tensor( + tSAcc.iterator, tStS_vec_layout) + thr_tmem_store_vec = tcgen05.make_tmem_copy( + tmem_store_vec_atom_pre, + tStStats_r2t_dst, + ).get_slice(group_tidx) + tStStats_r2t = thr_tmem_store_vec.partition_D(tStStats_r2t_dst) + tScStats_r2t = thr_tmem_store_vec.partition_S(tScS_vec) + tScP_shape = ( + self.mma_tiler_qk[0] // thr_mma_qk_pre.thr_id.shape, + tilePlikeFP32, + ) + + tSrP_r2t_f32 = cute.make_rmem_tensor( + thr_tmem_store.partition_S( + cute.make_identity_tensor(tScP_shape)).shape, + Float32, + ) + s_consumer_phase = Int32(0) + sm_stats_producer_phase = Int32(1) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, _, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + softmax = SoftmaxSm100.create( + softmax_scale_log2, + rescale_threshold=self.rescale_threshold, + ) + softmax.reset() + batch_idx = mRequestIndices[work_idx] + qo_tile = mQoTileIndices[work_idx] + seqused_k = mSeqUsedK[batch_idx] + split_idx = mKvTileIndices[work_idx] + kv_pages = ( + seqused_k + page_size - Int32(1)) // page_size + kv_page_begin = split_idx * kv_chunk_size_pages + kv_page_end = cutlass.min( + kv_pages, kv_page_begin + kv_chunk_size_pages + ) + page_count = kv_page_end - kv_page_begin + block_iter_count = (page_count + Int32(1)) & ~Int32(1) + if const_expr(stage == 0): + stage_page_count = block_iter_count // Int32(2) + else: + stage_page_count = block_iter_count // Int32(2) + + seqlen_info = SeqlenInfoQK( + Int32(0), + Int32(0), + Int32(0), + Int32(0), + seqlen_q, + seqused_k, + False, + False, + False, + True, + ) + mask = AttentionMask( + self.m_block_size, + self.n_block_size, + seqlen_info, + qhead_per_kvhead_packgqa=self.qhead_per_kv, + ) + wg_count = stage_page_count + if wg_count > Int32(0): + page_rel0 = stage_i32 + page_rel0_clamped = cutlass.min( + page_rel0, page_count - Int32(1)) + page_idx_global = kv_page_end - Int32(1) - page_rel0_clamped + kv_valid_cols = cutlass.min( + Int32(self.n_block_size), + seqused_k - page_idx_global * page_size, + ) + if page_rel0 >= page_count: + kv_valid_cols = Int32(0) + sm_stats_producer_phase = self.softmax_step( + softmax, + mask, + stage_i32, + s_consumer_phase, + page_idx_global, + qo_tile, + kv_valid_cols, + tStS_t2r, + tScS_t2r, + tStP_r2t, + tSrP_r2t_f32, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_sm_stats, + group_tidx, + warp_idx_in_wg, + tStStats_r2t, + tScStats_r2t, + sm_stats_producer_phase, + is_first=True, + ) + s_consumer_phase = s_consumer_phase ^ Int32(1) + + for stage_iter in cutlass.range( + Int32(1), wg_count, unroll=1 + ): + page_rel = ( + stage_iter * Int32(self.s_stage) + stage_i32) + page_rel_clamped = cutlass.min( + page_rel, page_count - Int32(1)) + page_idx_global_n = ( + kv_page_end - Int32(1) - page_rel_clamped) + kv_valid_cols_n = cutlass.min( + Int32(self.n_block_size), + seqused_k - page_idx_global_n * page_size, + ) + # Dummy-iter analysis: with s_stage=2, the WG that + # handles stage_i32=0 only ever sees page_rel ≤ + # block_iter_count - 2 < page_count → NEVER dummy. + # The WG with stage_i32=1 sees page_rel = + # block_iter_count - 1 at its last iter, which + # equals page_count iff page_count is odd → only + # WG1 may need the runtime mask_dummy_only guard. + # Pass None for WG0 so the const_expr branch in + # softmax_step eliminates the runtime check + # entirely (compile-time disappears). + if const_expr(stage == 0): + sm_stats_producer_phase = self.softmax_step( + softmax, mask, stage_i32, s_consumer_phase, + page_idx_global_n, qo_tile, kv_valid_cols_n, + tStS_t2r, tScS_t2r, tStP_r2t, tSrP_r2t_f32, + thr_tmem_load, thr_tmem_store, thr_tmem_store_vec, + pipeline_s_p_o, pipeline_p_lastsplit, + pipeline_sm_stats, + group_tidx, warp_idx_in_wg, + tStStats_r2t, tScStats_r2t, + sm_stats_producer_phase, + is_first=False, + apply_mask=False, + # mask_dummy_only=None → no runtime check + ) + else: + is_dummy = page_rel >= page_count + if is_dummy: + kv_valid_cols_n = Int32(0) + sm_stats_producer_phase = self.softmax_step( + softmax, mask, stage_i32, s_consumer_phase, + page_idx_global_n, qo_tile, kv_valid_cols_n, + tStS_t2r, tScS_t2r, tStP_r2t, tSrP_r2t_f32, + thr_tmem_load, thr_tmem_store, thr_tmem_store_vec, + pipeline_s_p_o, pipeline_p_lastsplit, + pipeline_sm_stats, + group_tidx, warp_idx_in_wg, + tStStats_r2t, tScStats_r2t, + sm_stats_producer_phase, + is_first=False, + apply_mask=False, + mask_dummy_only=is_dummy, + ) + s_consumer_phase = s_consumer_phase ^ Int32(1) + + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + tSrStats = cute.make_rmem_tensor( + tScStats_r2t.shape, self.qk_acc_dtype) + tSrStats[0] = softmax.row_sum[0] + tSrStats[1] = softmax.row_max[0] + cute.copy(thr_tmem_store_vec, tSrStats, tStStats_r2t) + cute.arch.fence_view_async_tmem_store() + else: + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + tSrStats = cute.make_rmem_tensor( + tScStats_r2t.shape, self.qk_acc_dtype) + tSrStats[0] = Float32(0.0) + tSrStats[1] = -Float32.inf + cute.copy(thr_tmem_store_vec, tSrStats, tStStats_r2t) + cute.arch.fence_view_async_tmem_store() + pipeline_sm_stats.producer_commit_w_index(Int32(0)) + sm_stats_producer_phase = sm_stats_producer_phase ^ Int32(1) + + work_tile = tile_scheduler.consumer_advance() + + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + + @cute.jit + def softmax_step( + self, + softmax: SoftmaxSm100, + mask: AttentionMask, + stage: Int32, + s_phase: Int32, + page_idx: Int32, + qo_tile: Int32, + kv_valid_cols: Int32, + tStS_t2r: cute.Tensor, + tScS_t2r: cute.Tensor, + tStP_r2t: cute.Tensor, + tSrP_r2t_f32: cute.Tensor, + thr_tmem_load: cute.CopyAtom, + thr_tmem_store: cute.CopyAtom, + thr_tmem_store_vec: cute.CopyAtom, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + group_tidx: Int32, + warp_idx_in_wg: Int32, + tStStats_r2t: cute.Tensor, + tScStats_r2t: cute.Tensor, + sm_stats_producer_phase: Int32, + is_first: cutlass.Constexpr[bool], + apply_mask: cutlass.Constexpr[bool] = True, + mask_dummy_only: Optional[cutlass.Boolean] = None, + ) -> Int32: + # apply_mask=False is the inner-page fast path: skip both the seqlen + # bounds check and the causal-diagonal check, which together cost ~15 + # cyc per iter on the producer pre-publication critical path that + # gates correction WG's consumer_wait (top long_scoreboard PC in NCU). + # Callers must only set apply_mask=False when they can prove the tile + # is fully unmasked (no partial-page seqlen tail, no causal diagonal + # cut). + # + # mask_dummy_only (runtime bool, used only when apply_mask=False): + # when True the iter is a "dummy" rounded-up iter that needs the + # mask to zero out garbage S — runs the mask at runtime cost. For + # non-dummy iters it stays the fast no-mask path. + pipeline_s_p_o.consumer_wait_w_index_phase(stage, s_phase) + sm_stats_try_acquire = ( + pipeline_sm_stats.producer_try_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + ) + tSrS_t2r = cute.make_rmem_tensor( + tScS_t2r.shape, self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + if const_expr(apply_mask): + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=qo_tile, + n_block=page_idx, + mask_seqlen=True, + mask_causal=self.causal, + kv_valid_cols=kv_valid_cols, + ) + elif const_expr(mask_dummy_only is not None): + if mask_dummy_only: + # Dummy iter — zero everything via mask (kv_valid_cols=0 + # makes mask_r2p_lambda set all positions to -inf). + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=qo_tile, + n_block=page_idx, + mask_seqlen=True, + mask_causal=self.causal, + kv_valid_cols=kv_valid_cols, + ) + # Publish acc_scale in log2-domain (un-exp2'd); correction WG does + # the exp2 only when an actual rescale fires. Removes MUFU.EX2 from + # the sm_stats publication critical path that gates correction's + # consumer_wait (the dominant long_scoreboard hot PC in NCU). + row_max, acc_scale_log2 = softmax.update_row_max_deferred_exp2( + tSrS_t2r.load(), is_first) + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase, sm_stats_try_acquire) + tSrStats = cute.make_rmem_tensor( + tScStats_r2t.shape, self.qk_acc_dtype) + tSrStats[0] = acc_scale_log2 + tSrStats[1] = row_max + cute.copy(thr_tmem_store_vec, tSrStats, tStStats_r2t) + cute.arch.fence_view_async_tmem_store() + pipeline_sm_stats.producer_commit_w_index(Int32(0)) + sm_stats_producer_phase = sm_stats_producer_phase ^ Int32(1) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r.layout, + ) + # exp2 for the internal row_sum carry happens AFTER producer_commit, so + # it no longer extends correction's consumer-wait window. + # acc_scale_log2 == 0.0 in the threshold/first-iter paths makes + # exp2(0)=1.0, which is the no-rescale identity for the row_sum carry — + # semantically equivalent to the original ``acc_scale=1.0`` branch. + if const_expr(is_first): + row_sum_init = Float32(0.0) + else: + acc_scale_mult = cute.math.exp2(acc_scale_log2, fastmath=True) + row_sum_init = softmax.row_sum[0] * acc_scale_mult + # Bulk EX2 emulation parameters. + # + # ex2_emu_freq=16 emulate exp2 with FFMA2 polynomial on + # 15 of every 16 (j, k) positions; the + # remaining 1/16 still issues MUFU.EX2. + # This cuts the MUFU.EX2 throughput bottleneck + # in the softmax inner loop (≈22k cyc + # saved per stage at baseline). + # ex2_emu_res=3 degree-3 polynomial; res=4 broke + # kv=1024 close-tolerance even with + # poly_degree=5 — 3 is the most aggressive + # setting that still passes cos_sim ≥ 0.99 + # against the reference for the fp8 PV path. + # ex2_emu_start_frg=1 skip the emulation for fragment index 0 + # (preserves accuracy on the first iter + # where row_max is least settled). + # + # If you tune these, re-run the variable-kv self-consistency check + # (split vs non-split must stay at cos_min ≥ 0.99). + softmax.row_sum[0] = softmax.scale_apply_exp2_convert_sum( + tSrS_t2r, + row_max, + tSrP_r2t, + row_sum_init, + ex2_emu_freq=16, + ex2_emu_res=3, + ex2_emu_start_frg=1, + ) + for k in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): + cute.copy( + thr_tmem_store, + tSrP_r2t_f32[None, None, k], + tStP_r2t[None, None, k], + ) + if const_expr(self.split_P_arrive > 0): + split_P_arrive_idx = ( + cute.size(tStP_r2t.shape[2]) + * self.split_P_arrive + // self.n_block_size + ) + if const_expr(k + 1 == split_P_arrive_idx): + cute.arch.fence_view_async_tmem_store() + pipeline_s_p_o.consumer_release_w_index(stage) + cute.arch.fence_view_async_tmem_store() + if const_expr(self.split_P_arrive > 0): + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_p_lastsplit.producer_commit_w_index(stage) + else: + pipeline_s_p_o.consumer_release_w_index(stage) + return sm_stats_producer_phase + + @cute.jit + def load( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mQ: cute.Tensor, + mK_paged: cute.Tensor, + mV_paged: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + mPageTable: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + pipeline_q: pipeline.PipelineAsync, + pipeline_kv: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + page_size: Int32, + kv_chunk_size_pages: Int32, + ) -> None: + cute.arch.setmaxregister_decrease(self.num_regs_load) + thr_mma_qk_ld = tiled_mma_qk.get_slice(0) + thr_mma_pv_ld = tiled_mma_pv.get_slice(0) + q_producer_phase = Int32(1) + kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kv_stage) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, head_kv_idx, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx_ld = mRequestIndices[work_idx] + qo_tile_ld = mQoTileIndices[work_idx] + split_idx_ld = mKvTileIndices[work_idx] + seqused_k_ld = mSeqUsedK[batch_idx_ld] + kv_pages_ld = ( + seqused_k_ld + page_size - Int32(1)) // page_size + kv_page_begin_ld = split_idx_ld * kv_chunk_size_pages + kv_page_end_ld = cutlass.min( + kv_pages_ld, kv_page_begin_ld + kv_chunk_size_pages + ) + page_count_ld = kv_page_end_ld - kv_page_begin_ld + block_iter_count_ld = ( + page_count_ld + Int32(1)) & ~Int32(1) + physical_page_v0 = Int32(0) + physical_page_v1 = Int32(0) + + mQ_cur_ld = mQ[None, None, None, batch_idx_ld][ + None, None, head_kv_idx + ] + tiler_gQ_ld = ( + (self.mma_tiler_qk[0] * self.q_stage), + self.head_dim, + ) + gQ_ld = cute.local_tile( + mQ_cur_ld, tiler_gQ_ld, (qo_tile_ld, 0)) + gQ_ld = layout_utils.select( + cute.flat_divide(gQ_ld, (self.mma_tiler_qk[0],)), + mode=[0, 2, 1], + ) + tSgQ_ld = thr_mma_qk_ld.partition_A(gQ_ld) + load_Q_fn_full, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), tSgQ_ld, sQ + ) + mK_cur_ld = mK_paged[None, None, head_kv_idx, None] + gK_ld = cute.local_tile( + mK_cur_ld, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0, None), + ) + tSgK_ld = thr_mma_qk_ld.partition_B(gK_ld) + tKsK_ld, tKgK_ld = cpasync.tma_partition( + tma_atom_K, 0, cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_ld, 0, 3), + ) + mV_cur_ld = mV_paged[None, None, head_kv_idx, None] + gV_ld = cute.local_tile( + mV_cur_ld, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None, None), + ) + tOgV_ld = thr_mma_pv_ld.partition_B(gV_ld) + tVsV_ld, tVgV_ld = cpasync.tma_partition( + tma_atom_V, 0, cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV_ld, 0, 3), + ) + + if block_iter_count_ld > Int32(0): + # Prime K0 before Q; then follow BSA order + # K1, V0, K2, V1, ... + page_idx_ld0 = kv_page_end_ld - Int32(1) + physical_page_v0 = mPageTable[batch_idx_ld, page_idx_ld0] + physical_page_v1 = physical_page_v0 + self.load_KV_physical( + tma_atom_K, + tKgK_ld, + tKsK_ld, + physical_page_v0, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + + self.load_Q( + load_Q_fn_full, + pipeline_q, + Int32(0), + q_producer_phase, + ) + q_producer_phase = q_producer_phase ^ Int32(1) + + if block_iter_count_ld > Int32(0): + if block_iter_count_ld > Int32(1): + page_rel_k1 = cutlass.min( + Int32(1), page_count_ld - Int32(1)) + page_idx_ld1 = kv_page_end_ld - Int32(1) - page_rel_k1 + physical_page_v1 = mPageTable[ + batch_idx_ld, page_idx_ld1] + self.load_KV_physical( + tma_atom_K, + tKgK_ld, + tKsK_ld, + physical_page_v1, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + + if block_iter_count_ld > Int32(2): + for page_rel in cutlass.range( + Int32(0), + block_iter_count_ld - Int32(2), + unroll=1, + ): + page_rel_v_ld = cutlass.min( + page_rel, page_count_ld - Int32(1)) + physical_page_v_ld = physical_page_v0 + if (page_rel & Int32(1)) != Int32(0): + physical_page_v_ld = physical_page_v1 + self.load_KV_physical( + tma_atom_V, + tVgV_ld, + tVsV_ld, + physical_page_v_ld, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + page_rel_k_ld = cutlass.min( + page_rel + Int32(2), + page_count_ld - Int32(1), + ) + page_idx_k_ld = ( + kv_page_end_ld - Int32(1) - page_rel_k_ld) + physical_page_k_ld = mPageTable[ + batch_idx_ld, page_idx_k_ld] + self.load_KV_physical( + tma_atom_K, + tKgK_ld, + tKsK_ld, + physical_page_k_ld, + pipeline_kv, + kv_producer_state, + ) + if (page_rel & Int32(1)) == Int32(0): + physical_page_v0 = physical_page_k_ld + else: + physical_page_v1 = physical_page_k_ld + kv_producer_state.advance() + + page_rel_epi_begin_ld = cutlass.max( + Int32(0), + block_iter_count_ld - Int32(2), + ) + for page_rel_epi in cutlass.range( + page_rel_epi_begin_ld, + block_iter_count_ld, + unroll=1, + ): + page_rel_v_ld = cutlass.min( + page_rel_epi, page_count_ld - Int32(1)) + physical_page_v_ld = physical_page_v0 + if (page_rel_epi & Int32(1)) != Int32(0): + physical_page_v_ld = physical_page_v1 + self.load_KV_physical( + tma_atom_V, + tVgV_ld, + tVsV_ld, + physical_page_v_ld, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + + tile_scheduler.prefetch_next_work() + work_tile = tile_scheduler.consumer_advance() + + pipeline_kv.producer_tail(kv_producer_state) + pipeline_q.producer_acquire_w_index_phase( + Int32(self.q_stage - 1), q_producer_phase) + + @cute.jit + def load_Q( + self, + load_Q_fn: Callable, + pipeline_q: pipeline.PipelineAsync, + stage: Int32, + phase: Int32, + ) -> None: + pipeline_q.producer_acquire_w_index_phase(stage, phase) + load_Q_fn( + src_idx=Int32(0), + dst_idx=stage, + tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(stage), + ) + + @cute.jit + def load_KV_physical( + self, + tma_atom: cute.CopyAtom, + tXgX: cute.Tensor, + tXsX: cute.Tensor, + physical_page: Int32, + pipeline_kv: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + ) -> None: + pipeline_kv.producer_acquire(producer_state) + cute.copy( + tma_atom, + tXgX[(None, 0, physical_page)], + tXsX[(None, producer_state.index)], + tma_bar_ptr=pipeline_kv.producer_get_barrier(producer_state), + ) + +_atten_compile_cache: dict[tuple[object, ...], object] = {} + + +def run_decode_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + request_indices: torch.Tensor, + qo_tile_indices: torch.Tensor, + kv_tile_indices: torch.Tensor, + block_valid_mask: torch.Tensor, + split_counts: torch.Tensor, + o_indptr: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + O_partial: Optional[torch.Tensor], + LSE_partial: Optional[torch.Tensor], + *, + softmax_scale: float, + seqlen_q: int, + page_size: int, + kv_chunk_size_pages: int, + split_kv: bool, + causal: bool, + return_lse: bool = True, + disable_softmax_exp2: bool = False, + O_partial_dummy: Optional[torch.Tensor] = None, + LSE_partial_dummy: Optional[torch.Tensor] = None, +) -> None: + """Launch the SM100 UMMA paged decode attention CUTE DSL kernel. + + qhead_per_kv is derived from input shapes (q.shape[1] // k.shape[1]). + disable_softmax_exp2 toggles the sage-style host flag (decision §1.7); + default False keeps full ex2 emulation. + + ``O_partial_dummy`` / ``LSE_partial_dummy`` let callers pre-allocate the + placeholder buffers for the non-split path, avoiding ~5us of per-call + ``torch.empty`` overhead in tight decoding loops. + """ + + q_dtype = torch2cute_dtype_map[q.dtype] + o_dtype = torch2cute_dtype_map[out.dtype] + qhead_per_kv = q.shape[1] // k.shape[1] + q_tokens_per_group = 128 // int(qhead_per_kv) + write_lse = bool(return_lse) or bool(split_kv) + if int(seqlen_q) != q_tokens_per_group: + raise NotImplementedError( + "decode fp8 currently assumes one full packed-q tile: " + f"seqlen_q must equal {q_tokens_per_group}, got {seqlen_q}" + ) + key = ( + "decode_attention", + q.shape[-1], + q_dtype, + o_dtype, + bool(split_kv), + bool(causal), + int(qhead_per_kv), + int(seqlen_q), + bool(write_lse), + bool(disable_softmax_exp2), + ) + if key not in _atten_compile_cache: + from ....quack.compile_utils import make_fake_tensor + + total_q = cute.sym_int64() + head_q = cute.sym_int64() + num_pages = cute.sym_int64() + head_kv = cute.sym_int64() + batch = cute.sym_int64() + batch_plus_one = cute.sym_int64() + max_pages = cute.sym_int64() + work_capacity = cute.sym_int64() + partial_rows = cute.sym_int64() + partial_rows_flat = cute.sym_int64() + head_dim = int(q.shape[-1]) + kernel = SparseDecodeAttentionForwardSm100( + head_dim=head_dim, + qhead_per_kv=int(qhead_per_kv), + page_size=int(page_size), + split_kv=bool(split_kv), + causal=bool(causal), + write_lse=bool(write_lse), + disable_softmax_exp2=bool(disable_softmax_exp2), + ) + # Always pass non-None fake tensors so the @cute.kernel positional + # arg marshalling stays stable; the kernel only reads these when + # split_kv=True (decision #10 epilogue branch). + fake_O_partial = make_fake_tensor( + Float32, (partial_rows_flat, head_dim), divisibility=4) + fake_LSE_partial = make_fake_tensor( + Float32, (partial_rows, head_q), divisibility=1, leading_dim=1) + # Q is passed as a [B, Sq, Hq, D] view so the kernel can build the same + # PackGQA TMA view used by FA/BSA and issue one full-tile Q TMA. + # O still uses the compact 2D view for the packed-GQA TMA epilogue. + total_q_flat = cute.sym_int64() + _atten_compile_cache[key] = cute.compile( + kernel, + make_fake_tensor( + q_dtype, (batch, int(seqlen_q), head_q, head_dim), + divisibility=16), + make_fake_tensor(q_dtype, (num_pages, head_kv, int(page_size), head_dim), divisibility=16), + make_fake_tensor(q_dtype, (num_pages, head_kv, int(page_size), head_dim), divisibility=16), + make_fake_tensor(Int32, (batch, max_pages), divisibility=1, leading_dim=1), + make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (batch_plus_one,), divisibility=1, leading_dim=0), + make_fake_tensor(o_dtype, (total_q_flat, head_dim), divisibility=128 // o_dtype.width), + make_fake_tensor(Float32, (total_q, head_q), divisibility=1, leading_dim=1), + fake_O_partial, + fake_LSE_partial, + Float32(float(softmax_scale)), + Int32(int(seqlen_q)), + Int32(int(page_size)), + Int32(int(kv_chunk_size_pages)), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + q_4d = q.view( + q.shape[0] // int(seqlen_q), int(seqlen_q), q.shape[1], q.shape[2]) + out_2d = out.view(out.shape[0] * out.shape[1], out.shape[2]) + # Compile keeps non-None fake partial buffers for positional stability + # (see fake_O_partial / fake_LSE_partial above). Runtime callers that + # don't need them (split_kv=False) pass None; allocate small uninitialized + # dummy buffers so the kernel signature still matches without launching + # torch fill kernels. + if O_partial is None: + # Reuse caller-cached dummy when available (e.g. the + # SparseDecodePagedAttentionWrapper plan() pre-allocation), else + # allocate a small placeholder on the fly. + O_partial_kernel = ( + O_partial_dummy + if O_partial_dummy is not None + else torch.empty( + (1, q.shape[2]), dtype=torch.float32, device=q.device) + ) + else: + O_partial_kernel = O_partial.view( + O_partial.shape[0] * O_partial.shape[1], O_partial.shape[2]) + if LSE_partial is None: + LSE_partial = ( + LSE_partial_dummy + if LSE_partial_dummy is not None + else torch.empty( + (1, q.shape[1]), dtype=torch.float32, device=q.device) + ) + with torch.cuda.nvtx.range("Decode_Attention"): + _atten_compile_cache[key]( + q_4d, k, v, page_table, seqused_k, + request_indices, qo_tile_indices, kv_tile_indices, block_valid_mask, + split_counts, o_indptr, + out_2d, lse, O_partial_kernel, LSE_partial, + softmax_scale, seqlen_q, page_size, kv_chunk_size_pages, + ) + + +__all__ = ["SparseDecodeAttentionForwardSm100", "run_decode_attention"] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6bab26c200fff9c62644849b18e55f060fa8783f --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Paged decode split-KV scheduling backed by the precompiled Torch op. + +The CUDA implementation lives in ``csrc/build_decode_schedule.cu`` and is +built ahead of time by kernel-builder. The op returns the schedule arrays +plus a fixed-order scalar summary, which is reassembled into the schedule +dict here. +""" + +from __future__ import annotations + +import torch + +from ....._ops import ops + +# Order of the scalar summary returned by the op; must match +# csrc/build_decode_schedule.cu. +_SCALAR_KEYS = ( + "split_kv", + "cta_tile_q", + "num_q_tiles", + "kv_chunk_size_pages", + "kv_chunk_size_tokens", + "work_count", + "padded_work_count", + "partial_rows", + "max_split_count", + "max_grid_size", + "active_blocks_per_sm", + "num_sms", + "base_cta", +) + + +def build_decode_schedule( + seqused_k: torch.Tensor, + *, + page_size: int, + seqlen_q: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + max_seqlen_k: int, + enable_cuda_graph: bool = False, + max_grid_size: int = 0, + fixed_split_size: int = -1, + disable_split_kv: bool = False, +) -> dict[str, object]: + """GPU-only schedule build: single CUDA kernel produces all schedule + index arrays on device. Only a small summary tensor is D2H'd at the end + so the wrapper can size O_partial, pick the kernel grid, and choose + split/non-split compile path. + + ``max_seqlen_k`` is required as the host-side worst-case bound for + padding the work-tile arrays. + """ + + ( + request_indices, + qo_tile_indices, + kv_tile_indices, + block_valid_mask, + split_counts, + kv_pages, + merge_indptr, + o_indptr, + scalars, + ) = ops.build_decode_schedule( + seqused_k, + int(page_size), + int(seqlen_q), + int(num_qo_heads), + int(num_kv_heads), + int(head_dim), + int(max_seqlen_k), + bool(enable_cuda_graph), + int(max_grid_size), + int(fixed_split_size), + bool(disable_split_kv), + ) + + raw: dict[str, object] = dict(zip(_SCALAR_KEYS, (int(s) for s in scalars))) + raw["split_kv"] = bool(raw["split_kv"]) + raw["request_indices"] = request_indices + raw["qo_tile_indices"] = qo_tile_indices + raw["kv_tile_indices"] = kv_tile_indices + raw["block_valid_mask"] = block_valid_mask + raw["split_counts"] = split_counts + raw["kv_pages"] = kv_pages + raw["merge_indptr"] = merge_indptr + raw["o_indptr"] = o_indptr + + # The CUDA kernel writes into worst-case-padded buffers (size = + # batch * num_q_tiles * max_pages_global) but only the first + # ``padded_work_count`` entries are valid. Downstream consumers + # (tile_scheduler) take grid size from ``request_indices.shape[0]`` + # so we narrow the views to that count; the underlying allocation + # is unchanged so this is a view, no copy. + pad = int(raw["padded_work_count"]) + for key in ( + "request_indices", + "qo_tile_indices", + "kv_tile_indices", + "block_valid_mask", + ): + raw[key] = raw[key].narrow(0, 0, pad) + return raw + + +__all__ = ["build_decode_schedule"] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/combine.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/combine.py new file mode 100644 index 0000000000000000000000000000000000000000..3d308bd26c281e744cc7289b1265d8192c1f39e7 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/combine.py @@ -0,0 +1,680 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""LDGSTS split-KV combine for paged decode attention.""" + +import math +from functools import partial +from typing import Type + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute import FastDivmodDivisor +from cutlass.cute.nvgpu import cpasync + +from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map + + +class SparseDecodeForwardCombine: + """Combine split-KV decode partials with FA-style LDGSTS staging. + + ``mO_partial`` and ``mLSE_partial`` use the split-major padded layout: + ``partial_row = o_indptr[b] + split_idx * q_stride + q_token`` where + ``q_stride = ceil_div(seqlen_q, q_tokens_per_group) * q_tokens_per_group``. + A CTA covers ``tile_m`` flattened ``(q_token, q_head)`` rows and one + ``k_block_size`` slice of D. O_partial and LSE_partial are loaded to SMEM + via ``cpasync.CopyG2SOp`` before the split reduction. + """ + + def __init__( + self, + dtype: Type[cutlass.Numeric], + dtype_partial: Type[cutlass.Numeric], + head_dim: int, + *, + tile_m: int = 64, + k_block_size: int = 128, + max_splits: int = 4, + num_threads: int = 256, + stages: int = 2, + ): + if head_dim != 128: + raise NotImplementedError( + f"SparseDecodeForwardCombine currently supports only D=128, got D={head_dim}" + ) + if dtype not in [cutlass.BFloat16, cutlass.Float16, cutlass.Float32]: + raise TypeError(f"Unsupported output dtype: {dtype}") + if dtype_partial is not Float32: + raise TypeError("decode O_partial must be Float32") + if k_block_size != head_dim: + raise NotImplementedError("decode combine currently uses one D=128 k block") + if tile_m % 8 != 0: + raise ValueError("decode combine tile_m must be divisible by 8") + if max_splits < 1 or max_splits > 256: + raise ValueError("decode combine max_splits must be in [1, 256]") + + self.dtype = dtype + self.dtype_partial = dtype_partial + self.head_dim = head_dim + self.tile_m = tile_m + self.k_block_size = k_block_size + self.max_splits = max_splits + self.num_threads = num_threads + self.stages = stages + self.is_even_k = head_dim % k_block_size == 0 + + def _setup_attributes(self) -> None: + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype_partial.width + assert self.k_block_size % async_copy_elems == 0 + + k_block_gmem = ( + 128 + if self.k_block_size % 128 == 0 + else (64 if self.k_block_size % 64 == 0 else 32) + ) + gmem_threads_per_row = k_block_gmem // async_copy_elems + assert self.num_threads % gmem_threads_per_row == 0 + + atom_async_copy_partial = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype_partial, + num_bits_per_copy=universal_copy_bits, + ) + tOpartial_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + vOpartial_layout = cute.make_layout((1, async_copy_elems)) + self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv( + atom_async_copy_partial, tOpartial_layout, vOpartial_layout + ) + + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=async_copy_elems * self.dtype.width, + ) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, tOpartial_layout, vOpartial_layout + ) + + lse_copy_bits = Float32.width + m_block_smem = ( + 128 + if self.tile_m % 128 == 0 + else ( + 64 + if self.tile_m % 64 == 0 + else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8)) + ) + ) + gmem_threads_per_row_lse = m_block_smem + assert self.num_threads % gmem_threads_per_row_lse == 0 + + atom_async_copy_lse = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + Float32, + num_bits_per_copy=lse_copy_bits, + ) + tLSE_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse), + order=(1, 0), + ) + self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + atom_async_copy_lse, tLSE_layout, cute.make_layout(1) + ) + + self.smem_threads_per_col_lse = self.num_threads // m_block_smem + assert 32 % self.smem_threads_per_col_lse == 0 + s2r_layout_atom_lse = cute.make_ordered_layout( + (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse), + order=(0, 1), + ) + self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), + s2r_layout_atom_lse, + cute.make_layout(1), + ) + + if const_expr(m_block_smem == 8): + smem_lse_swizzle = cute.make_swizzle(5, 0, 5) + elif const_expr(m_block_smem == 16): + smem_lse_swizzle = cute.make_swizzle(4, 0, 4) + else: + smem_lse_swizzle = cute.make_swizzle(3, 2, 3) + lse_atom_splits = min(self.max_splits, 8) + smem_layout_atom_lse = cute.make_composed_layout( + smem_lse_swizzle, + 0, + cute.make_ordered_layout((lse_atom_splits, m_block_smem), order=(1, 0)), + ) + self.smem_layout_lse = cute.tile_to_shape( + smem_layout_atom_lse, (self.max_splits, self.tile_m), (0, 1) + ) + self.smem_layout_o = cute.make_ordered_layout( + (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2) + ) + + @cute.jit + def __call__( + self, + mO_partial: cute.Tensor, # [partial_rows, Hq, D] fp32 + mLSE_partial: cute.Tensor, # [partial_rows, Hq] fp32 + mSplitCounts: cute.Tensor, # [B] int32 + mOIndptr: cute.Tensor, # [B + 1] int32 + mO: cute.Tensor, # [total_q, Hq, D] + mLSE: cute.Tensor, # [total_q, Hq] fp32 + seqlen_q: Int32, + q_tokens_per_group: Int32, + stream: cuda.CUstream = None, + ): + if const_expr(mO_partial.element_type is not Float32): + raise TypeError("decode O_partial tensor must be Float32") + if const_expr(mLSE_partial.element_type is not Float32): + raise TypeError("decode LSE_partial tensor must be Float32") + if const_expr(mLSE.element_type is not Float32): + raise TypeError("decode LSE tensor must be Float32") + if const_expr(mO.element_type != self.dtype): + raise TypeError("decode O tensor dtype must match kernel dtype") + if const_expr(mSplitCounts.element_type is not Int32): + raise TypeError("decode split_counts tensor must be Int32") + if const_expr(mOIndptr.element_type is not Int32): + raise TypeError("decode o_indptr tensor must be Int32") + + mO_partial, mLSE_partial, mSplitCounts, mOIndptr, mO, mLSE = [ + assume_tensor_aligned(t) + for t in (mO_partial, mLSE_partial, mSplitCounts, mOIndptr, mO, mLSE) + ] + self._setup_attributes() + + @cute.struct + class SharedStorage: + sLSE: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sMaxValidSplit: cute.struct.Align[ + cute.struct.MemRange[Int32, self.tile_m], 128 + ] + sO: cute.struct.Align[ + cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 + ] + + total_q = mO.shape[0] + head_q = mO.shape[1] + batch = mSplitCounts.shape[0] + head_divmod = FastDivmodDivisor(head_q) + grid = ( + cute.ceil_div(seqlen_q * head_q, self.tile_m), + cute.ceil_div(self.head_dim, self.k_block_size), + batch, + ) + + self.kernel( + mO_partial, + mLSE_partial, + mSplitCounts, + mOIndptr, + mO, + mLSE, + SharedStorage, + self.smem_layout_lse, + self.smem_layout_o, + self.gmem_tiled_copy_O_partial, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_LSE, + self.s2r_tiled_copy_LSE, + head_divmod, + Int32(total_q), + Int32(head_q), + seqlen_q, + q_tokens_per_group, + ).launch( + grid=grid, + block=[self.num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mSplitCounts: cute.Tensor, + mOIndptr: cute.Tensor, + mO: cute.Tensor, + mLSE: cute.Tensor, + SharedStorage: cutlass.Constexpr, + smem_layout_lse: cute.Layout | cute.ComposedLayout, + smem_layout_o: cute.Layout, + gmem_tiled_copy_O_partial: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + s2r_tiled_copy_LSE: cute.TiledCopy, + head_divmod: FastDivmodDivisor, + total_q: Int32, + head_q: Int32, + seqlen_q: Int32, + q_tokens_per_group: Int32, + ): + tidx, _, _ = cute.arch.thread_idx() + m_block, k_block, batch_idx = cute.arch.block_idx() + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sLSE = storage.sLSE.get_tensor(smem_layout_lse) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,)) + sO = storage.sO.get_tensor(smem_layout_o) + + split_count = mSplitCounts[batch_idx] + q_stride = ( + (seqlen_q + q_tokens_per_group - Int32(1)) + // q_tokens_per_group + ) * q_tokens_per_group + max_idx = seqlen_q * head_q + + if m_block * Int32(self.tile_m) < max_idx: + gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) + tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) + cLSE = cute.make_identity_tensor((self.max_splits, self.tile_m)) + tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) + + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] + idx = m_block * Int32(self.tile_m) + mi + if idx < max_idx: + q_idx, q_head = divmod(idx, head_divmod) + partial_base = mOIndptr[batch_idx] + q_idx + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] + if si < split_count: + partial_row = partial_base + si * q_stride + lse_ptr = ( + mLSE_partial.iterator + + Int64(partial_row) * Int64(head_q) + + Int64(q_head) + ) + lse_gmem_ptr = cute.make_ptr( + Float32, + lse_ptr.toint(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + lse_src = cute.make_tensor(lse_gmem_ptr, (1,)) + cute.copy( + gmem_thr_copy_LSE, + lse_src, + tLSEsLSE[None, s, m], + ) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + else: + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + tLSEsLSE[None, s, m].fill(-Float32.inf) + cute.arch.cp_async_commit_group() + + gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) + cO = cute.make_identity_tensor((self.tile_m, self.k_block_size)) + tOcO = gmem_thr_copy_O_partial.partition_D(cO) + tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) + + num_rows = const_expr(cute.size(tOcO, mode=[1])) + tOqidx = cute.make_rmem_tensor(num_rows, Int32) + tOhidx = cute.make_rmem_tensor(num_rows, Int32) + for m in cutlass.range(num_rows, unroll_full=True): + mi = tOcO[0, m, 0][0] + idx = m_block * Int32(self.tile_m) + mi + if idx >= max_idx: + tOqidx[m] = Int32(0) + tOhidx[m] = -Int32(1) + else: + tOqidx[m], tOhidx[m] = divmod(idx, head_divmod) + + load_O_partial = partial( + self.load_O_partial, + mO_partial, + mOIndptr, + gmem_tiled_copy_O_partial, + tOsO_partial, + tOqidx, + tOhidx, + tOcO, + batch_idx, + q_stride, + split_count, + head_q, + k_block, + ) + + for stage in cutlass.range(self.stages - 1, unroll_full=True): + if stage < split_count: + load_O_partial(stage, stage) + cute.arch.cp_async_commit_group() + + cute.arch.cp_async_wait_group(self.stages - 1) + cute.arch.sync_threads() + + s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) + ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) + ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE) + cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) + + lse_sum = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32) + ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) + max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32) + assert cute.size(ts2rrLSE, mode=[0]) == 1 + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_max = cute.arch.warp_reduction_max( + ts2rrLSE[None, None, m] + .load() + .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + threads_in_group=threads_per_col, + ) + max_valid_idx = -Int32(1) + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + if ts2rrLSE[0, s, m] != -Float32.inf: + max_valid_idx = ts2rcLSE[0, s, 0][0] + max_valid_split[m] = cute.arch.warp_reduction_max( + max_valid_idx, threads_in_group=threads_per_col + ) + + lse_max_cur = Float32(0.0) if lse_max == -Float32.inf else lse_max + LOG2_E = Float32(math.log2(math.e)) + lse_sum_cur = Float32(0.0) + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + scale = cute.math.exp2( + (ts2rrLSE[0, s, m] - lse_max_cur) * LOG2_E, + fastmath=True, + ) + lse_sum_cur += scale + ts2rrLSE[0, s, m] = scale + lse_sum_cur = cute.arch.warp_reduction_sum( + lse_sum_cur, threads_in_group=threads_per_col + ) + lse_sum[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max + inv_sum = ( + Float32(0.0) + if (lse_sum_cur == Float32(0.0) or lse_sum_cur != lse_sum_cur) + else cute.arch.rcp_approx(lse_sum_cur) + ) + ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) + cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) + + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == Int32(0): + mi = ts2rcLSE[0, 0, m][1] + if mi < Int32(self.tile_m): + sMaxValidSplit[mi] = max_valid_split[m] + + if k_block == Int32(0): + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == Int32(0): + mi = ts2rcLSE[0, 0, m][1] + idx = m_block * Int32(self.tile_m) + mi + if idx < max_idx: + q_idx, q_head = divmod(idx, head_divmod) + q_abs = batch_idx * seqlen_q + q_idx + mLSE[q_abs, q_head] = lse_sum[m] + + cute.arch.sync_threads() + + thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] + for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True): + thr_max_valid_split = max( + thr_max_valid_split, + sMaxValidSplit[tOcO[0, m, 0][0]], + ) + + tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32) + tOrO.fill(Float32(0.0)) + + stage_load = self.stages - 1 + stage_compute = 0 + for s in cutlass.range(thr_max_valid_split + Int32(1), unroll=4): + scale = cute.make_rmem_tensor(num_rows, Float32) + for m in cutlass.range(num_rows, unroll_full=True): + scale[m] = sLSE[s, tOcO[0, m, 0][0]] + + split_to_load = s + Int32(self.stages - 1) + if split_to_load <= thr_max_valid_split: + load_O_partial(split_to_load, stage_load) + cute.arch.cp_async_commit_group() + stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1 + + cute.arch.cp_async_wait_group(self.stages - 1) + cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial) + stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1 + + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= Int32(0) and scale[m] > Float32(0.0): + tOrO[None, m, None].store( + tOrO[None, m, None].load() + + scale[m] * tOrO_partial[None, m, None].load().to(Float32) + ) + + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + + rO = cute.make_rmem_tensor_like(tOrO, self.dtype) + rO.store(tOrO.load().to(self.dtype)) + elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1])) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= Int32(0): + q_abs = batch_idx * seqlen_q + tOqidx[m] + row_ptr = ( + mO.iterator + + ( + (Int64(q_abs) * Int64(head_q) + Int64(tOhidx[m])) + * Int64(self.head_dim) + + Int64(k_block * Int32(self.k_block_size)) + ) + ) + row_gmem_ptr = cute.make_ptr( + mO.element_type, + row_ptr.toint(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + mO_row = cute.make_tensor( + row_gmem_ptr, + cute.make_layout((self.k_block_size,)), + ) + mO_row_copy = cute.tiled_divide(mO_row, (elems_per_store,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_store + cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_row_copy[None, k_idx]) + + @cute.jit + def load_O_partial( + self, + mO_partial: cute.Tensor, + mOIndptr: cute.Tensor, + gmem_tiled_copy_O_partial: cute.TiledCopy, + tOsO_partial: cute.Tensor, + tOqidx: cute.Tensor, + tOhidx: cute.Tensor, + tOcO: cute.Tensor, + batch_idx: Int32, + q_stride: Int32, + split_count: Int32, + head_q: Int32, + k_block: Int32, + split: Int32, + stage: Int32, + ) -> None: + elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1])) + tOsO_partial_cur = tOsO_partial[None, None, None, stage] + for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True): + if tOhidx[m] >= Int32(0): + if split < split_count: + partial_row = mOIndptr[batch_idx] + split * q_stride + tOqidx[m] + row_ptr = ( + mO_partial.iterator + + ( + (Int64(partial_row) * Int64(head_q) + Int64(tOhidx[m])) + * Int64(self.head_dim) + + Int64(k_block * Int32(self.k_block_size)) + ) + ) + row_gmem_ptr = cute.make_ptr( + mO_partial.element_type, + row_ptr.toint(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + mO_partial_row = cute.make_tensor( + row_gmem_ptr, + cute.make_layout((self.k_block_size,)), + ) + mO_partial_row_copy = cute.tiled_divide( + mO_partial_row, (elems_per_load,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_load + cute.copy( + gmem_tiled_copy_O_partial, + mO_partial_row_copy[None, k_idx], + tOsO_partial_cur[None, m, k], + ) + else: + tOsO_partial_cur[None, m, None].fill(Float32(0.0)) + + +_combine_compile_cache: dict[tuple[object, ...], object] = {} + + +def _next_power_of_2(x: int) -> int: + return 1 << (max(int(x), 1) - 1).bit_length() + + +def run_decode_combine( + O_partial: torch.Tensor, + LSE_partial: torch.Tensor, + split_counts: torch.Tensor, + o_indptr: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + *, + seqlen_q: int, + q_tokens_per_group: int, + max_split_count: int, +) -> None: + """Launch LDGSTS decode split-KV combine.""" + + if O_partial.dtype != torch.float32: + raise TypeError(f"O_partial must be torch.float32, got {O_partial.dtype}") + if LSE_partial.dtype != torch.float32: + raise TypeError(f"LSE_partial must be torch.float32, got {LSE_partial.dtype}") + if lse.dtype != torch.float32: + raise TypeError(f"lse must be torch.float32, got {lse.dtype}") + if split_counts.dtype != torch.int32: + raise TypeError(f"split_counts must be torch.int32, got {split_counts.dtype}") + if o_indptr.dtype != torch.int32: + raise TypeError(f"o_indptr must be torch.int32, got {o_indptr.dtype}") + if out.ndim != 3 or O_partial.ndim != 3: + raise ValueError("decode combine expects O tensors with shape [rows, heads, D]") + if LSE_partial.ndim != 2 or lse.ndim != 2: + raise ValueError("decode combine expects LSE tensors with shape [rows, heads]") + if out.shape[1:] != O_partial.shape[1:]: + raise ValueError(f"O shape mismatch: out={out.shape}, O_partial={O_partial.shape}") + if lse.shape != out.shape[:2]: + raise ValueError(f"lse shape {lse.shape} must match out[:2] {out.shape[:2]}") + if LSE_partial.shape != O_partial.shape[:2]: + raise ValueError( + f"LSE_partial shape {LSE_partial.shape} must match O_partial[:2] {O_partial.shape[:2]}" + ) + if split_counts.ndim != 1 or o_indptr.ndim != 1: + raise ValueError("split_counts and o_indptr must be rank-1 tensors") + if o_indptr.shape != (split_counts.shape[0] + 1,): + raise ValueError( + f"o_indptr shape {o_indptr.shape} must be ({split_counts.shape[0] + 1},)" + ) + seqlen_q = int(seqlen_q) + q_tokens_per_group = int(q_tokens_per_group) + if seqlen_q <= 0: + raise ValueError("seqlen_q must be positive") + if q_tokens_per_group <= 0: + raise ValueError("q_tokens_per_group must be positive") + if out.shape[0] != split_counts.shape[0] * seqlen_q: + raise ValueError( + f"out rows {out.shape[0]} must equal batch*seqlen_q " + f"{split_counts.shape[0]}*{seqlen_q}" + ) + + max_split_count = int(max_split_count) + if max_split_count <= 0: + raise ValueError("max_split_count must be positive") + if max_split_count > 256: + raise NotImplementedError( + f"LDGSTS decode combine supports at most 256 splits, got {max_split_count}" + ) + max_splits = max(4, _next_power_of_2(max_split_count)) + tile_m = 64 + k_block_size = int(out.shape[-1]) + stages = 2 + + dtype = torch2cute_dtype_map[out.dtype] + key = ( + "decode_combine_ldgsts", + out.shape[-1], + dtype, + O_partial.dtype, + seqlen_q, + q_tokens_per_group, + tile_m, + k_block_size, + max_splits, + stages, + ) + if key not in _combine_compile_cache: + from ....quack.compile_utils import make_fake_tensor + + total_q = cute.sym_int64() + batch = cute.sym_int64() + batch_plus_one = cute.sym_int64() + partial_rows = cute.sym_int64() + head_q = cute.sym_int64() + head_dim = int(out.shape[-1]) + kernel = SparseDecodeForwardCombine( + dtype=dtype, + dtype_partial=Float32, + head_dim=head_dim, + tile_m=tile_m, + k_block_size=k_block_size, + max_splits=max_splits, + stages=stages, + ) + _combine_compile_cache[key] = cute.compile( + kernel, + make_fake_tensor(Float32, (partial_rows, head_q, head_dim), divisibility=4), + make_fake_tensor(Float32, (partial_rows, head_q), divisibility=1, leading_dim=1), + make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (batch_plus_one,), divisibility=1, leading_dim=0), + make_fake_tensor(dtype, (total_q, head_q, head_dim), divisibility=128 // dtype.width), + make_fake_tensor(Float32, (total_q, head_q), divisibility=1, leading_dim=1), + Int32(seqlen_q), + Int32(q_tokens_per_group), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + with torch.cuda.nvtx.range("Decode_Combine_LDGSTS"): + _combine_compile_cache[key]( + O_partial, + LSE_partial, + split_counts, + o_indptr, + out, + lse, + seqlen_q, + q_tokens_per_group, + ) + + +__all__ = ["SparseDecodeForwardCombine", "run_decode_combine"] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..13b487402bf52d008b7ff7edbe9d584f366256b9 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py @@ -0,0 +1,300 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Decode-specific tile scheduler for paged fp8 attention. + +The pre-schedule step builds a dense worklist over decode KV chunks. Static +persistent scheduling walks a flattened ``(work_idx, head_kv_idx)`` task id. +CLC scheduling keeps BSA's hardware grid shape, ``(work_idx, head_kv_idx, 1)``, +and maps the canceled CTA coordinate back to the same logical task space. +""" + +from dataclasses import dataclass +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr +from cutlass.cute import FastDivmodDivisor + +from ....quack.cute_dsl_utils import ParamsBase + +from ....src.common.tile_scheduler import SchedulingMode, WorkTileInfo + + +@dataclass +class DecodeTileSchedulerArguments(ParamsBase): + work_capacity: Int32 + num_heads_kv: Int32 + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + + +class DecodeTileScheduler: + """Persistent scheduler over decode ``(work_idx, head_kv_idx)`` tasks.""" + + @dataclass + class Params(ParamsBase): + work_capacity: Int32 + num_heads_kv: Int32 + num_heads_kv_divmod: FastDivmodDivisor + total_tasks: Int32 + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + + def __init__( + self, + params: Params, + task_idx: Int32, + clc_scheduler=None, + clc_pipeline=None, + clc_consumer_state=None, + clc_response_ptr=None, + *, + loc=None, + ip=None, + ): + self.params = params + self._task_idx = task_idx + self._clc_scheduler = clc_scheduler + self._clc_pipeline = clc_pipeline + self._clc_consumer_state = clc_consumer_state + self._clc_response_ptr = clc_response_ptr + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: DecodeTileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert args.cluster_shape_mn[1] == 1, "Decode scheduler requires cluster N == 1" + total_tasks = args.work_capacity * args.num_heads_kv + return DecodeTileScheduler.Params( + args.work_capacity, + args.num_heads_kv, + FastDivmodDivisor(args.num_heads_kv), + total_tasks, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + ) + + @staticmethod + def _clc_grid_shape(params: Params): + return ( + cute.round_up(params.work_capacity, params.cluster_shape_m), + params.num_heads_kv, + Int32(1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, + clc_response_ptr=None, + *, + loc=None, + ip=None, + ) -> "DecodeTileScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + from cutlass.utils import ( + ClcDynamicPersistentTileScheduler, + ClcDynamicPersistentTileSchedulerParams, + ) + + cutlass_params = ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=DecodeTileScheduler._clc_grid_shape(params), + cluster_shape_mnk=(params.cluster_shape_m, 1, 1), + ) + block_idx = cute.arch.block_idx() + grid_dim = cute.arch.grid_dim() + clc_scheduler = ClcDynamicPersistentTileScheduler.create( + cutlass_params, + block_idx, + grid_dim, + clc_response_ptr, + ) + return DecodeTileScheduler( + params, + block_idx[0], + clc_scheduler, + clc_response_ptr=clc_response_ptr, + loc=loc, + ip=ip, + ) + + if const_expr(params.cluster_shape_m == 1): + task_idx = cute.arch.block_idx()[0] + else: + task_idx = cute.arch.cluster_idx()[0] + return DecodeTileScheduler(params, task_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return DecodeTileScheduler._clc_grid_shape(params) + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m + grid_x = cutlass.min(max_ctas, params.total_tasks * params.cluster_shape_m) + return (grid_x, Int32(1), Int32(1)) + + @cute.jit + def _task_to_work(self, task_idx: Int32, is_valid) -> WorkTileInfo: + work_idx, head_kv_idx = divmod(task_idx, self.params.num_heads_kv_divmod) + return WorkTileInfo( + (Int32(work_idx), Int32(head_kv_idx), Int32(0), Int32(0)), + is_valid, + ) + + @cute.jit + def _clc_work_to_coords(self, work) -> WorkTileInfo: + work_idx = work.tile_idx[0] + if const_expr(self.params.cluster_shape_m > 1): + work_idx = work_idx // self.params.cluster_shape_m + return WorkTileInfo( + ( + Int32(work_idx), + Int32(work.tile_idx[1]), + Int32(0), + Int32(0), + ), + work.is_valid_tile, + ) + + @cute.jit + def _clc_response_to_work( + self, + response_stage: Int32, + *, + loc=None, + ip=None, + ) -> WorkTileInfo: + # CLC responses are 16B opaque records. The scheduler warp can query + # the next stage before all consumer warps have read the current one, + # so each pipeline stage needs its own response slot. + response_ptr = self._clc_response_ptr + response_stage * Int32(4) + m_idx, n_idx, l_idx, is_valid = cute.arch.clc_response( + response_ptr, loc=loc, ip=ip) + cute.arch.fence_proxy("async.shared", space="cta") + cta_idx_in_cluster = cute.arch.block_idx()[0] % Int32( + self.params.cluster_shape_m) + return WorkTileInfo( + ( + Int32(m_idx) + cta_idx_in_cluster, + Int32(n_idx), + Int32(l_idx), + Int32(0), + ), + is_valid, + ) + + @cute.jit + def get_current_work( + self, + response_stage: Int32 = Int32(0), + *, + loc=None, + ip=None, + ) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self._clc_response_to_work( + response_stage, loc=loc, ip=ip) + self._task_idx = ( + work.tile_idx[0] * self.params.num_heads_kv + + work.tile_idx[1] + ) + return self._clc_work_to_coords(work) + is_valid = self._task_idx < self.params.total_tasks + return self._task_to_work(self._task_idx, is_valid) + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self._clc_scheduler.initial_work_tile_info() + self._task_idx = ( + work.tile_idx[0] * self.params.num_heads_kv + + work.tile_idx[1] + ) + return self._clc_work_to_coords(work) + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work( + self, + *, + loc=None, + ip=None, + mbarrier_addr=None, + response_stage: Int32 = Int32(0), + ): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + assert mbarrier_addr is not None + response_ptr = self._clc_response_ptr + response_stage * Int32(4) + with cute.arch.elect_one(): + cute.arch.issue_clc_query( + mbarrier_addr, response_ptr, loc=loc, ip=ip) + else: + assert mbarrier_addr is None + if const_expr(self.params.cluster_shape_m == 1): + self._task_idx += cute.arch.grid_dim()[0] + else: + self._task_idx += cute.arch.cluster_dim()[0] + + def consumer_advance(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + response_stage = self._clc_consumer_state.index + self._clc_pipeline.consumer_wait(self._clc_consumer_state) + work_tile = self.get_current_work(response_stage=response_stage) + self._clc_pipeline.consumer_release(self._clc_consumer_state) + self._clc_consumer_state.advance() + return work_tile + self.advance_to_next_work() + return self.get_current_work() + + def set_clc_pipeline(self, clc_pipeline, clc_consumer_state): + self._clc_pipeline = clc_pipeline + self._clc_consumer_state = clc_consumer_state + + def producer_tail(self, *, loc=None, ip=None): + pass + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + objs = [self.params, self._task_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [ + self._clc_scheduler, + self._clc_pipeline, + self._clc_consumer_state, + self._clc_response_ptr, + ] + for obj in objs: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + objs = [self.params, self._task_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [ + self._clc_scheduler, + self._clc_pipeline, + self._clc_consumer_state, + self._clc_response_ptr, + ] + for obj, n_items in zip(objs, self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return DecodeTileScheduler(*obj_list, loc=self._loc) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/prepare_k2q_csr.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/prepare_k2q_csr.py new file mode 100644 index 0000000000000000000000000000000000000000..8e59b3d55bd3e9b164dac1e474dd648501c1aa51 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/prepare_k2q_csr.py @@ -0,0 +1,227 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Sparse k2q CSR builder for SM100. + +Thin dispatcher that calls the CUDA C++ kernel pipeline in +``src.sm100.build_k2q_csr``. Supports ``topK in {4, 8, 16, 32}`` and +``blk_kv == 128`` only — other shapes raise ``ValueError`` rather than +silently falling back to a torch-reference path. +""" + +from __future__ import annotations + +from typing import Optional + +import torch + +from ...src.sm100.prepare_scheduler import SparseAttentionSchedule, SPARSE_SCHEDULE_MODEL + + +_SUPPORTED_TOPK = (4, 8, 16, 32) +_SUPPORTED_BLK_KV = 128 + + +def _ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +class SparseK2qCsrBuilderSm100: + """Build the k2q CSR reverse index for sparse attention on SM100. + + The public API matches the historical CUTE DSL builder so callers + (``sparse_index_utils.build_k2q_csr``, attention kernels) need no + changes. Internally the kernel pipeline runs five CUDA C++ kernels: + ``build_row_map`` -> ``hist`` -> ``row_prefix`` -> ``tile_prefix_smem`` + -> ``scatter`` (5 kernels + 2 ``cudaMemsetAsync``). + """ + + def __init__(self) -> None: + # No persistent state — the JIT-compiled extension is loaded + # lazily by ``src.sm100.build_k2q_csr`` on first call. + self._run = None + self._run_with_schedule = None + + def _ensure_loaded(self) -> None: + if self._run is None: + from ...src.sm100.build_k2q_csr import ( + run_build_k2q_csr, + run_build_k2q_csr_with_schedule, + ) + self._run = run_build_k2q_csr + self._run_with_schedule = run_build_k2q_csr_with_schedule + + def __call__( + self, + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + *, + total_k: int, + blk_kv: int = 128, + max_seqlen_k: Optional[int] = None, + max_seqlen_q: Optional[int] = None, + total_rows: Optional[int] = None, + qhead_per_kv: int = 1, + return_schedule: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, SparseAttentionSchedule]: + # ---- Validation ---------------------------------------------------- + if blk_kv != _SUPPORTED_BLK_KV: + raise ValueError( + f"SparseK2qCsrBuilderSm100 only supports blk_kv == " + f"{_SUPPORTED_BLK_KV}, got {blk_kv}" + ) + if q2k_indices.dtype != torch.int32: + raise TypeError( + f"q2k_indices must be torch.int32, got {q2k_indices.dtype}" + ) + if q2k_indices.ndim != 3: + raise ValueError( + f"q2k_indices must be rank-3 [head_kv, total_q, topK], " + f"got shape {tuple(q2k_indices.shape)}" + ) + if not q2k_indices.is_contiguous(): + raise ValueError("q2k_indices must be contiguous") + if cu_seqlens_q.dtype != torch.int32 or cu_seqlens_k.dtype != torch.int32: + raise TypeError("cu_seqlens_q and cu_seqlens_k must be torch.int32") + if cu_seqlens_q.ndim != 1 or cu_seqlens_k.ndim != 1: + raise ValueError("cu_seqlens_q and cu_seqlens_k must be rank-1") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError( + "cu_seqlens_q and cu_seqlens_k must share shape [B + 1]" + ) + if not (q2k_indices.is_cuda and cu_seqlens_q.is_cuda and cu_seqlens_k.is_cuda): + raise ValueError("all inputs must be CUDA tensors") + if ( + q2k_indices.device != cu_seqlens_q.device + or q2k_indices.device != cu_seqlens_k.device + ): + raise ValueError("all inputs must share a device") + if not cu_seqlens_q.is_contiguous() or not cu_seqlens_k.is_contiguous(): + raise ValueError("cu_seqlens_q and cu_seqlens_k must be contiguous") + + total_k = int(total_k) + if total_k < 0: + raise ValueError(f"total_k must be non-negative, got {total_k}") + + head_kv, total_q, topk = (int(v) for v in q2k_indices.shape) + if topk not in _SUPPORTED_TOPK: + raise ValueError( + f"SparseK2qCsrBuilderSm100 only supports topK in " + f"{_SUPPORTED_TOPK}, got {topk}" + ) + + batch = int(cu_seqlens_q.shape[0] - 1) + if batch < 0: + raise ValueError("cu_seqlens tensors must have shape [B + 1]") + if return_schedule and max_seqlen_k is None: + raise ValueError("build_k2q_csr requires max_seqlen_k when return_schedule=True") + max_k_tokens = int(max_seqlen_k) if max_seqlen_k is not None else total_k + max_kv_blocks = _ceil_div(max(max_k_tokens, blk_kv), blk_kv) + if total_rows is not None: + total_rows = int(total_rows) + elif total_k % blk_kv == 0: + total_rows = total_k // blk_kv + else: + total_rows = _ceil_div(total_k + batch * (blk_kv - 1), blk_kv) + if total_rows < 0: + raise ValueError(f"total_rows must be non-negative, got {total_rows}") + total_rows = max(total_rows, 0) + nnz_upper_bound = total_q * topk + qhead_per_kv = int(qhead_per_kv) + if qhead_per_kv <= 0: + raise ValueError(f"qhead_per_kv must be positive, got {qhead_per_kv}") + if return_schedule: + if max_seqlen_q is None: + raise ValueError("build_k2q_csr requires max_seqlen_q when return_schedule=True") + max_seqlen_q = int(max_seqlen_q) + + # ---- Output tensors ------------------------------------------------ + device = q2k_indices.device + k2q_row_ptr = torch.empty( + (head_kv, total_rows + 1), dtype=torch.int32, device=device, + ) + k2q_q_indices = torch.empty( + (head_kv, nnz_upper_bound), dtype=torch.int32, device=device, + ) + schedule = None + if return_schedule: + target_q_per_cta = SPARSE_SCHEDULE_MODEL.balanced_target_q_per_cta( + total_q=total_q, + topk=topk, + blk_kv=blk_kv, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + ) + scheduler_metadata_capacity = SPARSE_SCHEDULE_MODEL.flat_schedule_capacity( + total_rows=total_rows, + total_q=total_q, + topk=topk, + head_kv=head_kv, + target_q_per_cta=target_q_per_cta, + ) + scheduler_metadata = torch.empty( + (scheduler_metadata_capacity, 6), dtype=torch.int32, device=device + ) + work_count = torch.empty((1,), dtype=torch.int32, device=device) + qsplit_indices = torch.empty_like(k2q_q_indices) + split_counts = torch.empty( + (total_q, head_kv), dtype=torch.int32, device=device + ) + schedule = SparseAttentionSchedule( + enabled=True, + scheduler_metadata=scheduler_metadata, + work_count=work_count, + qsplit_indices=qsplit_indices, + split_counts=split_counts, + target_q_per_cta=target_q_per_cta, + ) + + # Empty workload short-circuit (the CUDA path also handles this, + # but doing it here saves a JIT load for trivial calls). + if total_rows == 0 or total_q == 0 or head_kv == 0 or topk == 0: + k2q_row_ptr.zero_() + k2q_q_indices.fill_(-1) + if schedule is not None: + schedule.work_count.zero_() + schedule.split_counts.zero_() + return k2q_row_ptr, k2q_q_indices, schedule + return k2q_row_ptr, k2q_q_indices + + self._ensure_loaded() + with torch.cuda.nvtx.range("SparseK2qCsr_Pipeline"): + if schedule is None: + self._run( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + k2q_row_ptr, + k2q_q_indices, + topk, + blk_kv, + total_rows, + max_kv_blocks, + ) + else: + self._run_with_schedule( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + k2q_row_ptr, + k2q_q_indices, + schedule.scheduler_metadata, + schedule.work_count, + schedule.qsplit_indices, + schedule.split_counts, + topk, + blk_kv, + total_rows, + max_kv_blocks, + schedule.target_q_per_cta, + schedule.work_capacity, + max_seqlen_q, + ) + if schedule is not None: + return k2q_row_ptr, k2q_q_indices, schedule + return k2q_row_ptr, k2q_q_indices diff --git a/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/prepare_scheduler.py b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/prepare_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..662e48f905249913a381f5d11a3f0c49626e98bd --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/src/sm100/prepare_scheduler.py @@ -0,0 +1,752 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Prepare scheduler for SM100 sparse attention. + +The scheduler converts uneven CSR k2q row fanout into a flat worklist consumed +by sparse attention kernels. Each work item covers a contiguous q-index range +within one (head_kv, csr row) and carries the decoded batch/KV-block coordinate. +""" + +from dataclasses import dataclass +from typing import Optional + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Int32, const_expr + +from ...src.common import copy_utils, utils +from ...src.common.cute_dsl_utils import ( + assume_tensor_aligned, + to_cute_tensor as to_cute_tensor_kvouter, +) + + +_PREPARE_COMPILE_CACHE: dict = {} + + +@dataclass +class SparseAttentionSchedule: + enabled: bool + scheduler_metadata: Optional[torch.Tensor] + work_count: Optional[torch.Tensor] + qsplit_indices: Optional[torch.Tensor] = None + split_counts: Optional[torch.Tensor] = None + target_q_per_cta: int = 0 + + @property + def work_capacity(self) -> int: + return 0 if self.scheduler_metadata is None else int(self.scheduler_metadata.shape[0]) + + +SparseSchedulePlan = SparseAttentionSchedule + + +class SparseAttentionScheduleModel: + """Host-side helpers for sparse attention schedule sizing.""" + + @staticmethod + def _round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + @staticmethod + def _ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + def _target_q_per_cta( + self, + *, + total_q: int, + topk: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + usable_SM_count: int = -1, + ) -> int: + num_sm = torch.cuda.get_device_properties(device).multi_processor_count + if usable_SM_count > 0: + num_sm = min(int(usable_SM_count), num_sm) + q_tokens_per_group = 128 // qhead_per_kv + total_refs_upper = total_q * topk * head_kv + desired_work_items = max(num_sm * 2, 1) + total_groups_upper = self._ceil_div(max(total_refs_upper, 1), q_tokens_per_group) + target_groups_per_cta = min( + 512, + max(1, self._ceil_div(total_groups_upper, desired_work_items)), + ) + return target_groups_per_cta * q_tokens_per_group + + def balanced_target_q_per_cta( + self, + *, + total_q: int, + topk: int, + blk_kv: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + usable_SM_count: int = -1, + ) -> int: + q_tokens_per_group = 128 // qhead_per_kv + occupancy_target = self._target_q_per_cta( + total_q=total_q, + topk=topk, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + usable_SM_count=usable_SM_count, + ) + sink_balance_cap = max(q_tokens_per_group, int(topk) * int(blk_kv) * 2) + target = min(max(occupancy_target, q_tokens_per_group), sink_balance_cap) + return self._round_up(target, q_tokens_per_group) + + def flat_schedule_capacity( + self, + *, + total_rows: int, + total_q: int, + topk: int, + head_kv: int, + target_q_per_cta: int, + ) -> int: + row_upper = max(total_rows, 0) * max(head_kv, 1) + refs_upper = max(total_q, 0) * max(topk, 1) * max(head_kv, 1) + split_upper = self._ceil_div(max(refs_upper, 1), max(target_q_per_cta, 1)) + return max(1, row_upper + split_upper) + + +SPARSE_SCHEDULE_MODEL = SparseAttentionScheduleModel() + + +class SparseAttentionPrepareFlatScheduleSm100: + """Build a compact flat worklist by splitting each CSR row into chunks.""" + + def __init__( + self, + *, + num_threads: int = 128, + ): + if num_threads % 32 != 0: + raise ValueError(f"num_threads must be a multiple of 32, got {num_threads}") + self.num_threads = num_threads + self.warps_per_cta = num_threads // 32 + + @cute.jit + def _emit_work( + self, + mSchedulerMetadata: cute.Tensor, + work_idx: Int32, + work_capacity: Int32, + head_kv_idx: Int32, + row_linear: Int32, + q_begin: Int32, + q_count: Int32, + batch_idx: Int32, + kv_block_idx: Int32, + ): + if work_idx < work_capacity: + mSchedulerMetadata[work_idx, Int32(0)] = head_kv_idx + mSchedulerMetadata[work_idx, Int32(1)] = row_linear + mSchedulerMetadata[work_idx, Int32(2)] = q_begin + mSchedulerMetadata[work_idx, Int32(3)] = q_count + mSchedulerMetadata[work_idx, Int32(4)] = batch_idx + mSchedulerMetadata[work_idx, Int32(5)] = kv_block_idx + + @cute.jit + def _rows_in_batch( + self, + mCuSeqlensK: cute.Tensor, + batch_idx: Int32, + blk_kv: Int32, + ) -> Int32: + seqlen = mCuSeqlensK[batch_idx + Int32(1)] - mCuSeqlensK[batch_idx] + return (seqlen + blk_kv - Int32(1)) // blk_kv + + @cute.jit + def _rows_before_level( + self, + mCuSeqlensK: cute.Tensor, + level: Int32, + blk_kv: Int32, + ) -> Int32: + total = Int32(0) + batch = mCuSeqlensK.shape[0] - Int32(1) + for b in cutlass.range(batch, unroll=1): + rows = self._rows_in_batch(mCuSeqlensK, b, blk_kv) + total += cutlass.min(rows, level) + return total + + @cute.jit + def _max_rows_per_batch( + self, + mCuSeqlensK: cute.Tensor, + blk_kv: Int32, + ) -> Int32: + max_rows = Int32(0) + batch = mCuSeqlensK.shape[0] - Int32(1) + for b in cutlass.range(batch, unroll=1): + rows = self._rows_in_batch(mCuSeqlensK, b, blk_kv) + max_rows = cutlass.max(max_rows, rows) + return max_rows + + @cute.jit + def _decode_sparse_row_linear( + self, + mCuSeqlensK: cute.Tensor, + row_linear: Int32, + blk_kv: Int32, + ) -> tuple[Int32, Int32]: + lo = Int32(0) + hi = self._max_rows_per_batch(mCuSeqlensK, blk_kv) + while lo < hi: + mid = (lo + hi) // Int32(2) + rows_before_next = self._rows_before_level( + mCuSeqlensK, + mid + Int32(1), + blk_kv, + ) + if rows_before_next <= row_linear: + lo = mid + Int32(1) + else: + hi = mid + + level = lo + offset = row_linear - self._rows_before_level(mCuSeqlensK, level, blk_kv) + active_idx = Int32(0) + batch_idx = Int32(0) + found = Int32(0) + batch = mCuSeqlensK.shape[0] - Int32(1) + for b in cutlass.range(batch, unroll=1): + if found == Int32(0): + rows = self._rows_in_batch(mCuSeqlensK, b, blk_kv) + if rows > level: + if active_idx == offset: + batch_idx = b + found = Int32(1) + active_idx += Int32(1) + return batch_idx, level + + @cute.jit + def __call__( + self, + mK2qCounts: cute.Tensor, + mCuSeqlensK: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + target_q_per_cta: Int32, + work_capacity: Int32, + num_heads_kv: Int32, + blk_kv: Int32, + stream: cuda.CUstream = None, + ): + if const_expr(mK2qCounts.element_type != Int32): + raise TypeError("mK2qCounts must be Int32") + if const_expr(mCuSeqlensK.element_type != Int32): + raise TypeError("mCuSeqlensK must be Int32") + if const_expr(mSchedulerMetadata.element_type != Int32): + raise TypeError("mSchedulerMetadata must be Int32") + if const_expr(mWorkCount.element_type != Int32): + raise TypeError("mWorkCount must be Int32") + mK2qCounts, mCuSeqlensK, mSchedulerMetadata, mWorkCount = [ + assume_tensor_aligned(t) + for t in (mK2qCounts, mCuSeqlensK, mSchedulerMetadata, mWorkCount) + ] + total_rows = mK2qCounts.shape[1] - Int32(1) + total_row_heads = total_rows * num_heads_kv + grid_ctas = cute.ceil_div(total_row_heads, self.warps_per_cta) + + self.kernel( + mK2qCounts, + mCuSeqlensK, + mSchedulerMetadata, + mWorkCount, + target_q_per_cta, + work_capacity, + num_heads_kv, + total_rows, + blk_kv, + ).launch( + grid=(grid_ctas,), + block=[self.num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mK2qCounts: cute.Tensor, + mCuSeqlensK: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + target_q_per_cta: Int32, + work_capacity: Int32, + num_heads_kv: Int32, + total_rows: Int32, + blk_kv: Int32, + ): + tidx = cute.arch.thread_idx()[0] + block_idx = cute.arch.block_idx()[0] + lane_idx = tidx % Int32(32) + warp_idx = tidx // Int32(32) + row_head_idx = block_idx * Int32(self.warps_per_cta) + warp_idx + total_row_heads = total_rows * num_heads_kv + + head_kv_idx = Int32(0) + row_linear = Int32(0) + row_count = Int32(0) + num_chunks = Int32(0) + batch_idx = Int32(0) + kv_block_idx = Int32(0) + if row_head_idx < total_row_heads: + row_linear = row_head_idx // num_heads_kv + head_kv_idx = row_head_idx - row_linear * num_heads_kv + if lane_idx == Int32(0): + row_start = mK2qCounts[head_kv_idx, row_linear] + row_end = mK2qCounts[head_kv_idx, row_linear + Int32(1)] + row_count = row_end - row_start + batch_idx, kv_block_idx = self._decode_sparse_row_linear( + mCuSeqlensK, + row_linear, + blk_kv, + ) + if row_count > Int32(0): + num_chunks = ( + row_count + target_q_per_cta - Int32(1) + ) // target_q_per_cta + row_count = cute.arch.shuffle_sync(row_count, offset=0) + num_chunks = cute.arch.shuffle_sync(num_chunks, offset=0) + batch_idx = cute.arch.shuffle_sync(batch_idx, offset=0) + kv_block_idx = cute.arch.shuffle_sync(kv_block_idx, offset=0) + + chunk_idx = lane_idx + while chunk_idx < num_chunks: + work_idx = cute.arch.atomic_add( + mWorkCount.iterator.llvm_ptr, + Int32(1), + sem="relaxed", + scope="gpu", + ) + q_begin = chunk_idx * target_q_per_cta + q_count = cutlass.min(target_q_per_cta, row_count - q_begin) + self._emit_work( + mSchedulerMetadata, + work_idx, + work_capacity, + head_kv_idx, + row_linear, + q_begin, + q_count, + batch_idx, + kv_block_idx, + ) + chunk_idx += Int32(32) + + +class SparseAttentionPrepareFwdSplitAtomicSm100: + """Build packed q_idx/split_slot metadata for fwd K1 without K1 atomics.""" + + def __init__( + self, + *, + num_threads: int = 256, + ): + if num_threads % 32 != 0: + raise ValueError(f"num_threads must be a multiple of 32, got {num_threads}") + self.num_threads = num_threads + + @cute.struct + class SharedStorage: + sRow: cute.struct.MemRange[Int32, 3] + + self.shared_storage = SharedStorage + + @cute.jit + def __call__( + self, + mK2qCounts: cute.Tensor, + mK2qIndices: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mSplitCounts: cute.Tensor, + mCuSeqlensQ: cute.Tensor, + work_capacity: Int32, + max_seqlen_q: Int32, + topk: Int32, + stream: cuda.CUstream = None, + ): + if const_expr(mK2qCounts.element_type != Int32): + raise TypeError("mK2qCounts must be Int32") + if const_expr(mK2qIndices.element_type != Int32): + raise TypeError("mK2qIndices must be Int32") + if const_expr(mSchedulerMetadata.element_type != Int32): + raise TypeError("mSchedulerMetadata must be Int32") + if const_expr(mWorkCount.element_type != Int32): + raise TypeError("mWorkCount must be Int32") + if const_expr(mK2qQSplitIndices.element_type != Int32): + raise TypeError("mK2qQSplitIndices must be Int32") + if const_expr(mSplitCounts.element_type != Int32): + raise TypeError("mSplitCounts must be Int32") + if const_expr(mCuSeqlensQ.element_type != Int32): + raise TypeError("mCuSeqlensQ must be Int32") + ( + mK2qCounts, + mK2qIndices, + mSchedulerMetadata, + mWorkCount, + mK2qQSplitIndices, + mSplitCounts, + mCuSeqlensQ, + ) = [ + assume_tensor_aligned(t) + for t in ( + mK2qCounts, + mK2qIndices, + mSchedulerMetadata, + mWorkCount, + mK2qQSplitIndices, + mSplitCounts, + mCuSeqlensQ, + ) + ] + self.kernel( + mK2qCounts, + mK2qIndices, + mSchedulerMetadata, + mWorkCount, + mK2qQSplitIndices, + mSplitCounts, + mCuSeqlensQ, + max_seqlen_q, + topk, + ).launch( + grid=(work_capacity,), + block=[self.num_threads, 1, 1], + smem=self.shared_storage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mK2qCounts: cute.Tensor, + mK2qIndices: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mSplitCounts: cute.Tensor, + mCuSeqlensQ: cute.Tensor, + max_seqlen_q: Int32, + topk: Int32, + ): + tidx = cute.arch.thread_idx()[0] + block_idx = cute.arch.block_idx()[0] + if block_idx < mWorkCount[Int32(0)]: + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sRow = storage.sRow.get_tensor(cute.make_layout((3,))) + head_kv_idx = mSchedulerMetadata[block_idx, Int32(0)] + row_linear = mSchedulerMetadata[block_idx, Int32(1)] + q_begin = mSchedulerMetadata[block_idx, Int32(2)] + q_count = mSchedulerMetadata[block_idx, Int32(3)] + batch_idx_t0 = mSchedulerMetadata[block_idx, Int32(4)] + + if tidx == Int32(0): + row_start_t0 = mK2qCounts[head_kv_idx, row_linear] + q_begin + sRow[0] = row_start_t0 + sRow[1] = q_count + sRow[2] = batch_idx_t0 + cute.arch.barrier() + row_start = sRow[0] + row_count = sRow[1] + batch_idx = sRow[2] + qi = tidx + while qi < row_count: + edge = row_start + qi + q_idx = mK2qIndices[head_kv_idx, edge] + if q_idx >= Int32(0) and q_idx < max_seqlen_q: + q_abs = mCuSeqlensQ[batch_idx] + q_idx + split_ptr = utils.elem_pointer( + mSplitCounts, + (q_abs, head_kv_idx), + ) + split_slot = copy_utils.atomic_add_i32(split_ptr) + if split_slot < topk: + mK2qQSplitIndices[head_kv_idx, edge] = ( + q_idx | ((split_slot & Int32(0xFF)) << Int32(24)) + ) + qi += Int32(self.num_threads) + + +def _get_sparse_prepare_fwd_split_atomic( + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + scheduler_metadata: torch.Tensor, + work_count: torch.Tensor, + k2q_qsplit_indices: torch.Tensor, + split_counts: torch.Tensor, + cu_seqlens_q: torch.Tensor, + work_capacity: int, + max_seqlen_q: int, + topk: int, +): + key = ( + "sparse_prepare_fwd_split_atomic_sm100_csr_varlen", + ) + if key not in _PREPARE_COMPILE_CACHE: + from ...src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _PREPARE_COMPILE_CACHE[key] = loaded + else: + kernel = SparseAttentionPrepareFwdSplitAtomicSm100() + _PREPARE_COMPILE_CACHE[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k2q_row_ptr), + to_cute_tensor_kvouter(k2q_q_indices), + to_cute_tensor_kvouter(scheduler_metadata), + to_cute_tensor_kvouter(work_count), + to_cute_tensor_kvouter(k2q_qsplit_indices), + to_cute_tensor_kvouter(split_counts), + to_cute_tensor_kvouter(cu_seqlens_q), + Int32(work_capacity), + Int32(max_seqlen_q), + Int32(topk), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _PREPARE_COMPILE_CACHE[key]) + return _PREPARE_COMPILE_CACHE[key] + + +def _get_sparse_prepare_flat_schedule( + k2q_row_ptr: torch.Tensor, + cu_seqlens_k: torch.Tensor, + scheduler_metadata: torch.Tensor, + work_count: torch.Tensor, + target_q_per_cta: int, + scheduler_metadata_capacity: int, + head_kv: int, + blk_kv: int, +): + key = ( + "sparse_prepare_flat_schedule_sm100_csr_varlen", + ) + if key not in _PREPARE_COMPILE_CACHE: + from ...src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _PREPARE_COMPILE_CACHE[key] = loaded + else: + kernel = SparseAttentionPrepareFlatScheduleSm100() + _PREPARE_COMPILE_CACHE[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k2q_row_ptr), + to_cute_tensor_kvouter(cu_seqlens_k), + to_cute_tensor_kvouter(scheduler_metadata), + to_cute_tensor_kvouter(work_count), + Int32(target_q_per_cta), + Int32(scheduler_metadata_capacity), + Int32(head_kv), + Int32(blk_kv), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _PREPARE_COMPILE_CACHE[key]) + return _PREPARE_COMPILE_CACHE[key] + + +def prepare_sparse_flat_schedule( + *, + k2q_row_ptr: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_q: int, + topk: int, + blk_kv: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + enabled: bool, + usable_SM_count: int = -1, +) -> SparseSchedulePlan: + if not enabled: + return SparseSchedulePlan(enabled=False, scheduler_metadata=None, work_count=None) + + total_rows = int(k2q_row_ptr.shape[1] - 1) + if total_rows <= 0 or head_kv <= 0: + return SparseSchedulePlan(enabled=False, scheduler_metadata=None, work_count=None) + if cu_seqlens_k.dtype != torch.int32: + raise TypeError("cu_seqlens_k must be torch.int32") + + target_q_per_cta = SPARSE_SCHEDULE_MODEL.balanced_target_q_per_cta( + total_q=total_q, + topk=topk, + blk_kv=blk_kv, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + usable_SM_count=usable_SM_count, + ) + scheduler_metadata_capacity = SPARSE_SCHEDULE_MODEL.flat_schedule_capacity( + total_rows=total_rows, + total_q=total_q, + topk=topk, + head_kv=head_kv, + target_q_per_cta=target_q_per_cta, + ) + scheduler_metadata = torch.empty( + (scheduler_metadata_capacity, 6), + dtype=torch.int32, + device=device, + ) + work_count = torch.zeros((1,), dtype=torch.int32, device=device) + scheduler_metadata.zero_() + + compiled_prepare = _get_sparse_prepare_flat_schedule( + k2q_row_ptr, + cu_seqlens_k, + scheduler_metadata, + work_count, + target_q_per_cta, + scheduler_metadata_capacity, + head_kv, + blk_kv, + ) + with torch.cuda.nvtx.range("SparseAttention_PrepareFlatSchedule"): + compiled_prepare( + k2q_row_ptr, + cu_seqlens_k, + scheduler_metadata, + work_count, + target_q_per_cta, + scheduler_metadata_capacity, + head_kv, + blk_kv, + ) + + return SparseSchedulePlan( + enabled=True, + scheduler_metadata=scheduler_metadata, + work_count=work_count, + target_q_per_cta=target_q_per_cta, + ) + +def prepare_sparse_fwd_schedule_and_split( + *, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + k2q_qsplit_indices: torch.Tensor, + split_counts: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_q: int, + max_seqlen_q: int, + topk: int, + head_kv: int, + qhead_per_kv: int, + blk_kv: int, + device: torch.device, + enabled: bool, + usable_SM_count: int = -1, +) -> SparseSchedulePlan: + plan = prepare_sparse_fwd_schedule( + k2q_row_ptr=k2q_row_ptr, + cu_seqlens_k=cu_seqlens_k, + total_q=total_q, + topk=topk, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + blk_kv=blk_kv, + device=device, + enabled=enabled, + usable_SM_count=usable_SM_count, + ) + if not plan.enabled: + return plan + if plan.scheduler_metadata is None or plan.work_count is None: + raise RuntimeError("fwd GPU schedule requires metadata") + if topk > 255: + raise ValueError(f"packed qsplit metadata supports topK <= 255, got {topk}") + if max_seqlen_q >= (1 << 24): + raise ValueError( + "packed qsplit metadata supports batch-local q_idx < 2^24, " + f"got max_seqlen_q={max_seqlen_q}" + ) + if k2q_qsplit_indices.shape != k2q_q_indices.shape: + raise ValueError("k2q_qsplit_indices shape must match k2q_q_indices") + if split_counts.dtype != torch.int32 or k2q_qsplit_indices.dtype != torch.int32: + raise TypeError("split metadata tensors must be torch.int32") + if split_counts.shape != (total_q, head_kv): + raise ValueError( + f"split_counts must have shape ({total_q}, {head_kv}), got {tuple(split_counts.shape)}" + ) + if cu_seqlens_q.dtype != torch.int32: + raise TypeError("cu_seqlens_q must be torch.int32") + if cu_seqlens_q.ndim != 1 or not cu_seqlens_q.is_contiguous(): + raise ValueError("cu_seqlens_q must be a contiguous rank-1 tensor") + if cu_seqlens_k.dtype != torch.int32: + raise TypeError("cu_seqlens_k must be torch.int32") + + with torch.cuda.nvtx.range("SparseAttention_InitFwdSplitState"): + split_counts.zero_() + + compiled_split = _get_sparse_prepare_fwd_split_atomic( + k2q_row_ptr, + k2q_q_indices, + plan.scheduler_metadata, + plan.work_count, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + plan.work_capacity, + max_seqlen_q, + topk, + ) + with torch.cuda.nvtx.range("SparseAttention_PrepareFwdSplit_Atomic"): + compiled_split( + k2q_row_ptr, + k2q_q_indices, + plan.scheduler_metadata, + plan.work_count, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + plan.work_capacity, + max_seqlen_q, + topk, + ) + plan.qsplit_indices = k2q_qsplit_indices + plan.split_counts = split_counts + return plan + + +def prepare_sparse_fwd_schedule( + *, + k2q_row_ptr: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_q: int, + topk: int, + blk_kv: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + enabled: bool, + usable_SM_count: int = -1, +) -> SparseSchedulePlan: + return prepare_sparse_flat_schedule( + k2q_row_ptr=k2q_row_ptr, + cu_seqlens_k=cu_seqlens_k, + total_q=int(total_q), + topk=int(topk), + blk_kv=int(blk_kv), + head_kv=int(head_kv), + qhead_per_kv=int(qhead_per_kv), + device=device, + enabled=bool(enabled), + usable_SM_count=int(usable_SM_count), + ) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7d5e4ade468de366bb73eed0ccb38d4e358cf8 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""MiniMax Sparse Attention (MSA) CuTe-DSL kernels for NVIDIA SM100. + +Hub-kernel packaging of the CuTe-DSL sparse attention stack from +https://github.com/MiniMax-AI/MSA (``python/fmha_sm100/cute``). The +host-side helper kernels (CSR builder, decode scheduler) are precompiled +Torch ops; the attention kernels are compiled at runtime through +nvidia-cutlass-dsl. +""" + +# Sparse attention forward / decode. +from .interface import ( + SparseDecodePagedAttentionWrapper, + sparse_atten_func, + sparse_atten_nvfp4_kv_func, + sparse_decode_atten_func, +) + +# CSR + schedule construction. +from .sparse_index_utils import build_k2q_csr + +# SM100 fused CSR builder. +from .src.sm100.prepare_k2q_csr import SparseK2qCsrBuilderSm100 + +# FP4 block-score indexer. Returns per-(Hq, kv_block, q) max scores; topK +# selection + q2k construction remain caller-owned downstream steps. +from .fp4_indexer_interface import fp4_indexer_block_scores + +# NVFP4 quantization helpers used to feed the FP4 indexer / NVFP4 attention. +from .quantize import ( + Nvfp4QuantizedTensor, + dequantize_nvfp4_128x4_to_bf16, + nvfp4_global_scale_from_amax, + quantize_bf16_to_nvfp4_128x4, + quantize_kv_bf16_to_nvfp4_128x4, + swizzle_nvfp4_scale_to_128x4, +) + +__version__ = "0.1.1" + +__all__ = [ + # attention + "sparse_atten_func", + "sparse_atten_nvfp4_kv_func", + "sparse_decode_atten_func", + "SparseDecodePagedAttentionWrapper", + # indexing / CSR + "fp4_indexer_block_scores", + "build_k2q_csr", + "SparseK2qCsrBuilderSm100", + # nvfp4 quantization helpers + "Nvfp4QuantizedTensor", + "quantize_bf16_to_nvfp4_128x4", + "quantize_kv_bf16_to_nvfp4_128x4", + "dequantize_nvfp4_128x4_to_bf16", + "swizzle_nvfp4_scale_to_128x4", + "nvfp4_global_scale_from_amax", +] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_msa_cuda_09d7851.abi3.so b/build/torch212-cxx11-cu130-x86_64-linux/_msa_cuda_09d7851.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..d694ef1868e197758b3b3d3f22869d5ab0123420 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/_msa_cuda_09d7851.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4455123d032155679a0babc8680bd748e01f082d067086fb40b2a1e0f9feb83 +size 1379256 diff --git a/build/torch212-cxx11-cu130-x86_64-linux/_ops.py b/build/torch212-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6be2da4d5d784683e9e2fb8bfe08e93847dc6640 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _msa_cuda_09d7851 +ops = torch.ops._msa_cuda_09d7851 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_msa_cuda_09d7851::{op_name}" diff --git a/build/torch212-cxx11-cu130-x86_64-linux/fp4_indexer_interface.py b/build/torch212-cxx11-cu130-x86_64-linux/fp4_indexer_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..48dc1d05480355d2af4f4e47142ae4cd692184b0 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/fp4_indexer_interface.py @@ -0,0 +1,1061 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Public FP4 sparse-attention indexer block-score interface.""" + +from __future__ import annotations + +from typing import Optional + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Int32 +from cutlass.cute.runtime import make_ptr + +from .src.sm100.fp4_indexer import ( + Fp4FormatSpec, + Fp4IndexerDecodePackedQSm100, + Fp4IndexerDecodeQPackSm100, + Fp4IndexerScaleReorderSm100, + Fp4IndexerStagedMmaSm100, + _BLOCK_K, + _DECODE_K_TILES_PER_CTA, + _DECODE_PACK_Q_LEN, + _DECODE_QHEAD_PER_KV, + _FP4_PACKED_D_BYTES, + _HEAD_DIM, + _MMA_TILER_MN, + _PAGE_SIZE, + ceil_div, + k_tiles_per_cta_for, + normalize_fp4_format, +) + + +_PUBLIC_SCALE_LAYOUT = "public" +_PREORDERED_MMA_SCALE_LAYOUT = "preordered_mma" +_FP4_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +def _device_arch(device: torch.device) -> tuple[int, int]: + major, minor = torch.cuda.get_device_capability(device) + return int(major), int(minor) + + +def _supports_tmem_load_red(device_arch: tuple[int, int]) -> bool: + return device_arch >= (10, 3) + + +def normalize_scale_layout(scale_layout: str) -> str: + """Normalize and validate FP4 indexer scale layout mode. + + Parameters + ---------- + scale_layout : str + Either ``"public"`` for logical scale tensors or ``"preordered_mma"`` + for tensors already laid out with ``fp4_indexer_mma_scale_storage_*``. + + Returns + ------- + str + The normalized scale layout string. + """ + + scale_layout = str(scale_layout) + if scale_layout not in (_PUBLIC_SCALE_LAYOUT, _PREORDERED_MMA_SCALE_LAYOUT): + raise ValueError(f"scale_layout must be 'public' or 'preordered_mma', got {scale_layout!r}") + return scale_layout + + +def _causal_compact_task_count(q_len: int, k_len: int, k_tiles_per_cta: int) -> int: + if q_len <= 0 or k_len <= 0: + return 0 + q_tile_count = ceil_div(q_len, _MMA_TILER_MN[0]) + k_group_count = ceil_div(ceil_div(k_len, _PAGE_SIZE), k_tiles_per_cta) + group_tokens = k_tiles_per_cta * _BLOCK_K + causal_offset = int(k_len) - int(q_len) + tasks = 0 + for q_tile_idx in range(q_tile_count): + q_tile_start = q_tile_idx * _MMA_TILER_MN[0] + q_tile_last = min(q_tile_start + _MMA_TILER_MN[0] - 1, int(q_len) - 1) + visible_limit = q_tile_last + causal_offset + if visible_limit >= 0: + tasks += min(k_group_count, visible_limit // group_tokens + 1) + return tasks + + +def _causal_compact_task_bound(max_q_len: int, max_k_len: int, k_tiles_per_cta: int) -> int: + """Conservative X-grid bound for per-batch causal prefill compact mapping.""" + + if max_q_len <= 0 or max_k_len <= 0: + return 0 + q_tile_count = ceil_div(max_q_len, _MMA_TILER_MN[0]) + candidates = {int(max_q_len)} + for q_tile_idx in range(q_tile_count): + q_len = q_tile_idx * _MMA_TILER_MN[0] + 1 + if q_len <= max_q_len: + candidates.add(q_len) + return max(_causal_compact_task_count(q_len, max_k_len, k_tiles_per_cta) for q_len in candidates) + + +def _require_cuda_tensor(tensor: torch.Tensor, *, name: str) -> None: + if not tensor.is_cuda: + raise ValueError(f"{name} must be a CUDA tensor") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _require_int32_vector(tensor: torch.Tensor, *, name: str, device: torch.device) -> None: + if tensor.device != device: + raise ValueError(f"{name} must be on the same CUDA device") + if tensor.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32") + if tensor.ndim != 1: + raise ValueError(f"{name} must be rank-1") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _require_fp4_packed_dtype(tensor: torch.Tensor, *, name: str) -> None: + fp4_x2_dtype = getattr(torch, "float4_e2m1fn_x2", None) + allowed = {torch.uint8, torch.int8} + if fp4_x2_dtype is not None: + allowed.add(fp4_x2_dtype) + if tensor.dtype not in allowed: + raise TypeError(f"{name} must use packed FP4 storage dtype, got {tensor.dtype}") + + +def _as_fp4_thd_bytes(tensor: torch.Tensor, *, name: str) -> torch.Tensor: + if tensor.ndim != 3: + raise ValueError(f"{name} must have shape [total_q, Hq, 64]") + if int(tensor.shape[-1]) != _FP4_PACKED_D_BYTES: + raise ValueError(f"{name}.shape[-1] must be 64 packed bytes for D=128") + _require_fp4_packed_dtype(tensor, name=name) + if tensor.dtype == torch.uint8: + return tensor + return tensor.view(torch.uint8) + + +def _as_fp4_paged_hnd_bytes(tensor: torch.Tensor, *, name: str) -> torch.Tensor: + if tensor.ndim != 4: + raise ValueError(f"{name} must have shape [total_pages, Hk, 128, 64]") + if int(tensor.shape[-2]) != _PAGE_SIZE: + raise ValueError(f"{name}.shape[-2] must be 128") + if int(tensor.shape[-1]) != _FP4_PACKED_D_BYTES: + raise ValueError(f"{name}.shape[-1] must be 64 packed bytes for D=128") + _require_fp4_packed_dtype(tensor, name=name) + if tensor.dtype == torch.uint8: + return tensor + return tensor.view(torch.uint8) + + +def validate_q_scale_thg( + scale: torch.Tensor, + *, + name: str, + fmt: Fp4FormatSpec, + total_q: int, + heads: int, +) -> None: + """Validate public Q FP4 scale layout ``[total_q, Hq, G]``. + + Parameters + ---------- + scale : torch.Tensor + Logical Q scale tensor. + name : str + Name used in validation error messages. + fmt : Fp4FormatSpec + FP4 format specification from ``normalize_fp4_format``. + total_q : int + Total query token count. + heads : int + Number of Q heads. + """ + + expected = (int(total_q), int(heads), fmt.scale_groups) + if tuple(scale.shape) != expected: + raise ValueError(f"{name} must have shape {expected}, got {tuple(scale.shape)}") + if scale.dtype != fmt.torch_scale_dtype: + raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}") + if not scale.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def validate_k_scale_phsg( + scale: torch.Tensor, + *, + name: str, + fmt: Fp4FormatSpec, + page_count: int, + heads: int, +) -> None: + """Validate public K FP4 scale layout ``[page_count, Hk, 128, G]``. + + Parameters + ---------- + scale : torch.Tensor + Logical K scale tensor. + name : str + Name used in validation error messages. + fmt : Fp4FormatSpec + FP4 format specification from ``normalize_fp4_format``. + page_count : int + Number of physical KV pages. + heads : int + Number of KV heads. + """ + + expected = (int(page_count), int(heads), _PAGE_SIZE, fmt.scale_groups) + if tuple(scale.shape) != expected: + raise ValueError(f"{name} must have shape {expected}, got {tuple(scale.shape)}") + if scale.dtype != fmt.torch_scale_dtype: + raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}") + if not scale.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def fp4_indexer_mma_scale_shape(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return semantic MMA scale view shape ``(32,4,restM,4,restG,L)``.""" + + spec = normalize_fp4_format(fp4_format) + return (32, 4, ceil_div(mn, 128), 4, ceil_div(spec.scale_groups, 4), int(l)) + + +def fp4_indexer_mma_scale_stride(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return element strides for ``fp4_indexer_mma_scale_shape``.""" + + spec = normalize_fp4_format(fp4_format) + rest_m = ceil_div(mn, 128) + rest_g = ceil_div(spec.scale_groups, 4) + return (16, 4, 512 * rest_g, 1, 512, 512 * rest_m * rest_g) + + +def fp4_indexer_mma_scale_storage_shape(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return contiguous storage shape for preordered MMA scales.""" + + spec = normalize_fp4_format(fp4_format) + return (int(l), ceil_div(mn, 128), ceil_div(spec.scale_groups, 4), 32, 4, 4) + + +def fp4_indexer_mma_scale_storage_stride(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return element strides for ``fp4_indexer_mma_scale_storage_shape``.""" + + spec = normalize_fp4_format(fp4_format) + rest_m = ceil_div(mn, 128) + rest_g = ceil_div(spec.scale_groups, 4) + return (512 * rest_m * rest_g, 512 * rest_g, 512, 16, 4, 1) + + +def validate_mma_scale_storage( + scale: torch.Tensor, + *, + name: str, + fmt: Fp4FormatSpec, + mn: int, + l: int, +) -> None: + """Validate preordered MMA scale storage expected by the FP4 indexer. + + Parameters + ---------- + scale : torch.Tensor + Tensor view whose shape/stride should match + ``fp4_indexer_mma_scale_storage_shape`` and + ``fp4_indexer_mma_scale_storage_stride``. + name : str + Name used in validation error messages. + fmt : Fp4FormatSpec + FP4 format specification from ``normalize_fp4_format``. + mn : int + Logical M/N extent of the scale domain. + l : int + Logical batch/head extent folded into the final layout dimension. + """ + + expected_shape = fp4_indexer_mma_scale_storage_shape(mn, l, fp4_format=fmt.name) + expected_stride = fp4_indexer_mma_scale_storage_stride(mn, l, fp4_format=fmt.name) + if tuple(scale.shape) != expected_shape: + raise ValueError(f"{name} must have MMA storage shape {expected_shape}, got {tuple(scale.shape)}") + if tuple(scale.stride()) != expected_stride: + raise ValueError(f"{name} must have MMA storage stride {expected_stride}, got {tuple(scale.stride())}") + if scale.dtype != fmt.torch_scale_dtype: + raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}") + + +def _empty_mma_scale_tensor( + *, + mn: int, + l: int, + spec: Fp4FormatSpec, + device: torch.device, +) -> torch.Tensor: + rest_m = ceil_div(mn, 128) + rest_g = ceil_div(spec.scale_groups, 4) + storage = torch.empty( + (int(l), rest_m, rest_g, 32, 4, 4), + dtype=spec.torch_scale_dtype, + device=device, + ) + return storage.permute(3, 4, 1, 5, 2, 0) + + +def _compile_fp4_scale_reorder_kernel( + *, + fmt: Fp4FormatSpec, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + q_scale_mma_ptr: cute.Pointer, + k_scale_mma_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_scale_reorder_sm100_1cta", + fmt.name, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerScaleReorderSm100(fmt=fmt.name) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_scale_ptr, + k_scale_ptr, + q_scale_mma_ptr, + k_scale_mma_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def fp4_indexer_reorder_scales_for_mma_cute( + q_scale: torch.Tensor, + k_scale: torch.Tensor, + *, + fp4_format: str, +) -> tuple[torch.Tensor, torch.Tensor]: + """Reorder public Q/K FP4 scales to MMA-friendly storage. + + Parameters + ---------- + q_scale : torch.Tensor + Public Q scale tensor with shape ``[total_q, Hq, G]``. + k_scale : torch.Tensor + Public K scale tensor with shape ``[page_count, Hk, 128, G]``. + fp4_format : str + ``"mxfp4"`` or ``"nvfp4"``. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(q_scale_mma, k_scale_mma)`` views in the storage layout validated by + ``validate_mma_scale_storage``. These tensors can be passed to + ``fp4_indexer_block_scores`` with ``scale_layout="preordered_mma"``. + """ + + spec = normalize_fp4_format(fp4_format) + if q_scale.device != k_scale.device: + raise ValueError("q_scale and k_scale must be on the same CUDA device") + _require_cuda_tensor(q_scale, name="q_scale") + _require_cuda_tensor(k_scale, name="k_scale") + if q_scale.ndim != 3: + raise ValueError(f"q_scale must have shape [total_q, Hq, G], got {tuple(q_scale.shape)}") + if k_scale.ndim != 4: + raise ValueError(f"k_scale must have shape [page_count, Hk, 128, G], got {tuple(k_scale.shape)}") + total_q, heads_q, _ = (int(v) for v in q_scale.shape) + page_count, heads_k, _, _ = (int(v) for v in k_scale.shape) + validate_q_scale_thg(q_scale, name="q_scale", fmt=spec, total_q=total_q, heads=heads_q) + validate_k_scale_phsg(k_scale, name="k_scale", fmt=spec, page_count=page_count, heads=heads_k) + + q_scale_mma = _empty_mma_scale_tensor( + mn=total_q, + l=heads_q, + spec=spec, + device=q_scale.device, + ) + k_scale_mma = _empty_mma_scale_tensor( + mn=_PAGE_SIZE, + l=page_count * heads_k, + spec=spec, + device=k_scale.device, + ) + + q_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + q_scale.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + k_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + k_scale.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + q_scale_mma_ptr = make_ptr( + spec.cutlass_scale_dtype, + q_scale_mma.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + k_scale_mma_ptr = make_ptr( + spec.cutlass_scale_dtype, + k_scale_mma.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + problem_size = ( + Int32(total_q), + Int32(heads_q), + Int32(page_count), + Int32(heads_k), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_scale.device).cuda_stream) + compiled = _compile_fp4_scale_reorder_kernel( + fmt=spec, + q_scale_ptr=q_scale_ptr, + k_scale_ptr=k_scale_ptr, + q_scale_mma_ptr=q_scale_mma_ptr, + k_scale_mma_ptr=k_scale_mma_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_scale_ptr, + k_scale_ptr, + q_scale_mma_ptr, + k_scale_mma_ptr, + problem_size, + stream, + ) + return q_scale_mma, k_scale_mma + + +def _compile_fp4_decode_q_pack_kernel( + *, + fmt: Fp4FormatSpec, + q_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + q_pack_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_decode_q_pack_sm100", + fmt.name, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerDecodeQPackSm100(fmt=fmt.name) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_ptr, + q_scale_ptr, + q_pack_ptr, + q_scale_pack_ptr, + cu_seqlens_q_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def _pack_decode_q_for_mma( + q_bytes: torch.Tensor, + q_scale_storage: torch.Tensor, + cu_seqlens_q: torch.Tensor, + *, + fmt: Fp4FormatSpec, + heads_q: int, + heads_k: int, + batch: int, +) -> tuple[torch.Tensor, torch.Tensor]: + q_pack = torch.empty( + (batch * heads_k, _PAGE_SIZE, _FP4_PACKED_D_BYTES), + dtype=torch.uint8, + device=q_bytes.device, + ) + q_scale_pack = torch.empty( + fp4_indexer_mma_scale_storage_shape(_PAGE_SIZE, batch * heads_k, fp4_format=fmt.name), + dtype=fmt.torch_scale_dtype, + device=q_bytes.device, + ) + if q_pack.data_ptr() % 128 != 0: + raise ValueError("internal decode q_pack data pointer must be 128B aligned for TMA") + if q_scale_pack.data_ptr() % 32 != 0: + raise ValueError("internal decode q_scale_pack data pointer must be 32B aligned") + q_ptr = make_ptr(cutlass.Uint8, q_bytes.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + q_scale_ptr = make_ptr( + fmt.cutlass_scale_dtype, + q_scale_storage.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + q_pack_ptr = make_ptr(cutlass.Uint8, q_pack.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + q_scale_pack_ptr = make_ptr( + fmt.cutlass_scale_dtype, + q_scale_pack.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + cu_seqlens_q_ptr = make_ptr( + cutlass.Int32, + cu_seqlens_q.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + problem_size = ( + Int32(q_bytes.shape[0]), + Int32(heads_q), + Int32(heads_k), + Int32(batch), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_bytes.device).cuda_stream) + compiled = _compile_fp4_decode_q_pack_kernel( + fmt=fmt, + q_ptr=q_ptr, + q_scale_ptr=q_scale_ptr, + q_pack_ptr=q_pack_ptr, + q_scale_pack_ptr=q_scale_pack_ptr, + cu_seqlens_q_ptr=cu_seqlens_q_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_ptr, + q_scale_ptr, + q_pack_ptr, + q_scale_pack_ptr, + cu_seqlens_q_ptr, + problem_size, + stream, + ) + return q_pack, q_scale_pack + + +def _compile_fp4_decode_packed_q_kernel( + *, + fmt: Fp4FormatSpec, + causal: bool, + compact_schedule: bool, + device_arch: tuple[int, int], + use_tmem_load_red: bool, + q_pack_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_decode_packed_q_sm100", + fmt.name, + bool(causal), + bool(compact_schedule), + device_arch, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerDecodePackedQSm100( + fmt=fmt.name, + causal=causal, + compact_schedule=compact_schedule, + use_tmem_load_red=use_tmem_load_red, + ) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_pack_ptr, + k_ptr, + q_scale_pack_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def _run_fp4_decode_packed_q_scores( + q_pack: torch.Tensor, + k_bytes: torch.Tensor, + q_scale_pack: torch.Tensor, + k_scale_storage: torch.Tensor, + scores: torch.Tensor, + kv_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_page_offsets: torch.Tensor, + qo_offset_arg: torch.Tensor, + *, + fmt: Fp4FormatSpec, + causal: bool, + has_qo_offset: int, + heads_q: int, + heads_k: int, + batch: int, + max_k_tiles: int, + total_q: int, + device_arch: tuple[int, int], + use_tmem_load_red: bool, +) -> None: + page_count = int(k_bytes.shape[0]) + rectangular_groups = batch * ceil_div(max_k_tiles, _DECODE_K_TILES_PER_CTA) + compact_groups = ceil_div(page_count + batch * (_DECODE_K_TILES_PER_CTA - 1), _DECODE_K_TILES_PER_CTA) + compact_schedule = compact_groups < rectangular_groups + if compact_schedule: + scores.fill_(float("-inf")) + + q_pack_ptr = make_ptr(cutlass.Uint8, q_pack.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + k_ptr = make_ptr(cutlass.Uint8, k_bytes.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + q_scale_pack_ptr = make_ptr( + fmt.cutlass_scale_dtype, + q_scale_pack.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + k_scale_ptr = make_ptr( + fmt.cutlass_scale_dtype, + k_scale_storage.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + scores_ptr = make_ptr(cutlass.Float32, scores.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) + kv_indices_ptr = make_ptr(cutlass.Int32, kv_indices.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + cu_seqlens_q_ptr = make_ptr(cutlass.Int32, cu_seqlens_q.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + cu_seqlens_k_ptr = make_ptr(cutlass.Int32, cu_seqlens_k.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + cu_page_offsets_ptr = make_ptr(cutlass.Int32, cu_page_offsets.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + qo_offset_ptr = make_ptr(cutlass.Int32, qo_offset_arg.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + problem_size = ( + Int32(_PAGE_SIZE), + Int32(max_k_tiles * _PAGE_SIZE), + Int32(_HEAD_DIM), + Int32(batch * heads_k), + Int32(page_count * heads_k), + Int32(heads_q), + Int32(heads_k), + Int32(batch), + Int32(max_k_tiles), + Int32(total_q), + Int32(has_qo_offset), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_pack.device).cuda_stream) + compiled = _compile_fp4_decode_packed_q_kernel( + fmt=fmt, + causal=causal, + compact_schedule=compact_schedule, + device_arch=device_arch, + use_tmem_load_red=use_tmem_load_red, + q_pack_ptr=q_pack_ptr, + k_ptr=k_ptr, + q_scale_pack_ptr=q_scale_pack_ptr, + k_scale_ptr=k_scale_ptr, + scores_ptr=scores_ptr, + kv_indices_ptr=kv_indices_ptr, + cu_seqlens_q_ptr=cu_seqlens_q_ptr, + cu_seqlens_k_ptr=cu_seqlens_k_ptr, + cu_page_offsets_ptr=cu_page_offsets_ptr, + qo_offset_ptr=qo_offset_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_pack_ptr, + k_ptr, + q_scale_pack_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + + +def _compile_fp4_qk_kernel( + *, + fmt: Fp4FormatSpec, + causal: bool, + preordered_q_scale_tma: bool, + compact_schedule: bool, + device_arch: tuple[int, int], + use_tmem_load_red: bool, + q_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_staged_mma_sm100", + fmt.name, + bool(causal), + bool(preordered_q_scale_tma), + bool(compact_schedule), + device_arch, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerStagedMmaSm100( + fmt=fmt.name, + causal=causal, + preordered_q_scale_tma=preordered_q_scale_tma, + compact_schedule=compact_schedule, + use_tmem_load_red=use_tmem_load_red, + ) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_ptr, + k_ptr, + q_scale_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def fp4_indexer_block_scores( + q_fp4: torch.Tensor, + k_fp4: torch.Tensor, + q_scale: torch.Tensor, + k_scale: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_page_offsets: torch.Tensor, + *, + max_seqlen_q: int, + max_seqlen_k: int, + kv_indices: torch.Tensor, + fp4_format: str, + causal: bool = False, + qo_offset: Optional[torch.Tensor] = None, + scale_layout: str = _PREORDERED_MMA_SCALE_LAYOUT, +) -> torch.Tensor: + """Return FP4 QK max scores per 128-token KV page. + + Parameters + ---------- + q_fp4 : torch.Tensor + Packed FP4 Q tensor with shape ``[total_qo_len, Hq, 64]``. The last + dimension stores two FP4 values per byte for logical head dimension + 128. + k_fp4 : torch.Tensor + Packed paged FP4 K tensor with shape ``[total_pages, Hk, 128, 64]``. + q_scale : torch.Tensor + Q scale tensor. With ``scale_layout="public"``, shape is + ``[total_qo_len, Hq, G]``. With ``"preordered_mma"``, use + ``fp4_indexer_reorder_scales_for_mma_cute`` output layout. + k_scale : torch.Tensor + K scale tensor. With ``scale_layout="public"``, shape is + ``[total_pages, Hk, 128, G]``. With ``"preordered_mma"``, use the + preordered MMA scale layout. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + cu_page_offsets : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of per-request + page counts. + max_seqlen_q : int + Maximum Q sequence length. + max_seqlen_k : int + Maximum KV sequence length. + kv_indices : torch.Tensor + Flattened physical page indices with shape ``[sum_pages]`` and dtype + int32. + fp4_format : str + ``"mxfp4"`` or ``"nvfp4"``. + causal : bool, optional + Whether to apply causal masking. + qo_offset : torch.Tensor, optional + Shape ``[batch_size]``, dtype int32. Per-request causal offset. Valid + only when ``causal=True``. + scale_layout : str, optional + ``"public"`` accepts logical public scale tensors and launches a scale + reorder kernel. ``"preordered_mma"`` expects preordered MMA scale + tensors and skips the reorder. + + Returns + ------- + torch.Tensor + Shape ``[Hq, ceil(max_seqlen_k / 128), total_qo_len]``, dtype float32. + Entries beyond the valid KV page range are ``-inf``. + """ + + spec = normalize_fp4_format(fp4_format) + causal = bool(causal) + scale_layout = normalize_scale_layout(scale_layout) + use_preordered_q_scale_tma = int(max_seqlen_q) >= _PAGE_SIZE + q_bytes = _as_fp4_thd_bytes(q_fp4, name="q_fp4") + k_bytes = _as_fp4_paged_hnd_bytes(k_fp4, name="k_fp4") + total_q, heads_q, _ = (int(v) for v in q_bytes.shape) + page_count, heads_k, page_size, _ = (int(v) for v in k_bytes.shape) + if page_size != _PAGE_SIZE: + raise ValueError(f"k_fp4 page_size must be 128, got {page_size}") + if heads_q % heads_k != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads") + _require_cuda_tensor(q_fp4, name="q_fp4") + _require_cuda_tensor(k_fp4, name="k_fp4") + device_arch = _device_arch(q_fp4.device) + use_tmem_load_red = _supports_tmem_load_red(device_arch) + _require_int32_vector(cu_seqlens_q, name="cu_seqlens_q", device=q_fp4.device) + _require_int32_vector(cu_seqlens_k, name="cu_seqlens_k", device=q_fp4.device) + _require_int32_vector(cu_page_offsets, name="cu_page_offsets", device=q_fp4.device) + if q_scale.device != q_fp4.device or k_scale.device != q_fp4.device: + raise ValueError("q_scale and k_scale must be on the same CUDA device as q_fp4") + if scale_layout == _PUBLIC_SCALE_LAYOUT: + validate_q_scale_thg(q_scale, name="q_scale", fmt=spec, total_q=total_q, heads=heads_q) + validate_k_scale_phsg(k_scale, name="k_scale", fmt=spec, page_count=page_count, heads=heads_k) + else: + validate_mma_scale_storage(q_scale, name="q_scale", fmt=spec, mn=total_q, l=heads_q) + validate_mma_scale_storage(k_scale, name="k_scale", fmt=spec, mn=_PAGE_SIZE, l=page_count * heads_k) + batch = int(cu_seqlens_q.shape[0]) - 1 + if batch < 0: + raise ValueError("cu_seqlens_q must have shape [B + 1]") + if cu_seqlens_q.shape != cu_seqlens_k.shape or cu_seqlens_q.shape != cu_page_offsets.shape: + raise ValueError("cu_seqlens_q, cu_seqlens_k, and cu_page_offsets must have shape [B + 1]") + if q_bytes.data_ptr() % 128 != 0: + raise ValueError("q_fp4 data pointer must be 128B aligned for TMA") + if k_bytes.data_ptr() % 128 != 0: + raise ValueError("k_fp4 data pointer must be 128B aligned for TMA") + if kv_indices is None: + raise ValueError("kv_indices is required") + if kv_indices.device != q_fp4.device or kv_indices.dtype != torch.int32 or kv_indices.ndim != 1: + raise ValueError("kv_indices must have shape [sum_pages], dtype torch.int32, and match q_fp4.device") + if not kv_indices.is_contiguous(): + raise ValueError("kv_indices must be contiguous") + if qo_offset is not None: + if not causal: + raise ValueError("qo_offset is only valid when causal=True") + if qo_offset.device != q_fp4.device or qo_offset.dtype != torch.int32 or qo_offset.shape != (batch,): + raise ValueError("qo_offset must have shape [B], dtype torch.int32, and match q_fp4.device") + if not qo_offset.is_contiguous(): + raise ValueError("qo_offset must be contiguous") + + m_extent = int(max_seqlen_q) + max_k_tiles = ceil_div(int(max_seqlen_k), _PAGE_SIZE) + n_aligned = max_k_tiles * _PAGE_SIZE + if max_k_tiles == 0: + return torch.full( + (heads_q, 0, total_q), + float("-inf"), + dtype=torch.float32, + device=q_fp4.device, + ) + + scores = torch.empty( + (heads_q, max_k_tiles, total_q), + dtype=torch.float32, + device=q_fp4.device, + ) + if qo_offset is None: + qo_offset_arg = torch.empty((batch,), dtype=torch.int32, device=q_fp4.device) + has_qo_offset = 0 + else: + qo_offset_arg = qo_offset + has_qo_offset = 1 + if scale_layout == _PUBLIC_SCALE_LAYOUT: + q_scale_arg, k_scale_arg = fp4_indexer_reorder_scales_for_mma_cute( + q_scale, + k_scale, + fp4_format=spec.name, + ) + else: + q_scale_arg = q_scale + k_scale_arg = k_scale + scale_assumed_align = 32 + if q_scale_arg.data_ptr() % scale_assumed_align != 0: + raise ValueError(f"q_scale data pointer must be {scale_assumed_align}B aligned for MMA storage scale") + if k_scale_arg.data_ptr() % scale_assumed_align != 0: + raise ValueError(f"k_scale data pointer must be {scale_assumed_align}B aligned for MMA storage scale") + use_decode_packed_q = int(max_seqlen_q) <= _DECODE_PACK_Q_LEN and heads_q // heads_k == _DECODE_QHEAD_PER_KV + if use_decode_packed_q: + q_pack, q_scale_pack = _pack_decode_q_for_mma( + q_bytes, + q_scale_arg, + cu_seqlens_q, + fmt=spec, + heads_q=heads_q, + heads_k=heads_k, + batch=batch, + ) + _run_fp4_decode_packed_q_scores( + q_pack, + k_bytes, + q_scale_pack, + k_scale_arg, + scores, + kv_indices, + cu_seqlens_q, + cu_seqlens_k, + cu_page_offsets, + qo_offset_arg, + fmt=spec, + causal=causal, + has_qo_offset=has_qo_offset, + heads_q=heads_q, + heads_k=heads_k, + batch=batch, + max_k_tiles=max_k_tiles, + total_q=total_q, + device_arch=device_arch, + use_tmem_load_red=use_tmem_load_red, + ) + return scores + prefill_compact_task_count = 0 + prefill_compact_schedule = False + if causal and has_qo_offset == 0: + k_tiles_per_cta = k_tiles_per_cta_for(causal) + q_tile_count = ceil_div(m_extent, _MMA_TILER_MN[0]) + k_group_count = ceil_div(max_k_tiles, k_tiles_per_cta) + rectangular_task_count = q_tile_count * k_group_count + prefill_compact_task_count = min( + rectangular_task_count, + _causal_compact_task_bound(m_extent, int(max_seqlen_k), k_tiles_per_cta), + ) + prefill_compact_schedule = prefill_compact_task_count * 20 <= rectangular_task_count * 19 + if prefill_compact_schedule: + scores.fill_(float("-inf")) + q_ptr = make_ptr( + cutlass.Uint8, + q_bytes.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=128, + ) + k_ptr = make_ptr( + cutlass.Uint8, + k_bytes.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=128, + ) + q_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + q_scale_arg.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=scale_assumed_align, + ) + k_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + k_scale_arg.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=scale_assumed_align, + ) + scores_ptr = make_ptr( + cutlass.Float32, + scores.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + kv_indices_ptr = make_ptr( + cutlass.Int32, + kv_indices.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + cu_seqlens_q_ptr = make_ptr( + cutlass.Int32, + cu_seqlens_q.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + cu_seqlens_k_ptr = make_ptr( + cutlass.Int32, + cu_seqlens_k.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + cu_page_offsets_ptr = make_ptr( + cutlass.Int32, + cu_page_offsets.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + qo_offset_ptr = make_ptr( + cutlass.Int32, + qo_offset_arg.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + problem_size = ( + Int32(m_extent), + Int32(n_aligned), + Int32(_HEAD_DIM), + Int32(batch * heads_q), + Int32(page_count * heads_k), + Int32(heads_q), + Int32(heads_k), + Int32(batch), + Int32(max_k_tiles), + Int32(total_q), + Int32(has_qo_offset), + Int32(prefill_compact_task_count), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_fp4.device).cuda_stream) + compiled = _compile_fp4_qk_kernel( + fmt=spec, + causal=causal, + preordered_q_scale_tma=use_preordered_q_scale_tma, + compact_schedule=prefill_compact_schedule, + device_arch=device_arch, + use_tmem_load_red=use_tmem_load_red, + q_ptr=q_ptr, + k_ptr=k_ptr, + q_scale_ptr=q_scale_ptr, + k_scale_ptr=k_scale_ptr, + scores_ptr=scores_ptr, + kv_indices_ptr=kv_indices_ptr, + cu_seqlens_q_ptr=cu_seqlens_q_ptr, + cu_seqlens_k_ptr=cu_seqlens_k_ptr, + cu_page_offsets_ptr=cu_page_offsets_ptr, + qo_offset_ptr=qo_offset_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_ptr, + k_ptr, + q_scale_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + return scores + + +__all__ = [ + "fp4_indexer_block_scores", +] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/interface.py b/build/torch212-cxx11-cu130-x86_64-linux/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..9e507961840b3322238646ffffe3e97cf5d13130 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/interface.py @@ -0,0 +1,2011 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Sparse attention interface. + +Current delivery scope: + - head dimension is supported only for D=128 + +Public API: + sparse_atten_func(...) + sparse_decode_atten_func(...) + SparseDecodePagedAttentionWrapper + +Internal forward core: + _sparse_atten_csr_varlen_forward(...) + +Preprocessing (external, done once): + q2k_indices [head_kv, total_q, topK] -> sparse_index_utils.build_k2q_csr() + -> k2q_row_ptr [head_kv, total_rows + 1] int32 + -> k2q_q_indices [head_kv, total_q * topK] int32 +""" + +import math +import os +from typing import Optional + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32 +from cutlass.cute.runtime import from_dlpack + +from .src.sm100.fwd.combine import combine +from .src.sm100.fwd.atten_fwd import SparseAttentionForwardSm100 +from .src.sm100.fwd.atten_fwd_nvfp4_kv import SparseAttentionForwardNvfp4KvSm100 +from .src.sm100.prepare_scheduler import ( + SparseAttentionSchedule, + prepare_sparse_fwd_schedule_and_split, +) +from .src.sm100.decode_schedule import ( + DecodeAttentionSchedule, + prepare_decode_schedule, +) +from .src.common.cute_dsl_utils import to_cute_tensor as to_cute_tensor_kvouter +from .src.common.tma_utils import ( + create_q_gather4_tma_desc, +) + +_compile_cache: dict = {} +_TEMPERATURE_LSE_FAST_PATH_ABS_TOL = 1e-12 +_SUPPORTED_SPARSE_TOPK = (4, 8, 16, 32) +_SUPPORTED_FWD_DTYPES = (torch.bfloat16, torch.float8_e4m3fn) +_SUPPORTED_FWD_MMA_DTYPES = (torch.bfloat16, torch.float8_e4m3fn) +_SUPPORTED_DECODE_QHEAD_PER_KV = 16 + + +def _normalize_partial_dtype(partial_dtype: torch.dtype) -> torch.dtype: + supported = {torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn} + if partial_dtype not in supported: + raise TypeError( + "partial_dtype must be one of torch.float32 / torch.bfloat16 / " + "torch.float16 / torch.float8_e4m3fn, " + f"got {partial_dtype}" + ) + return partial_dtype + + +def _normalize_forward_mma_dtype(dtype: Optional[torch.dtype], fallback: torch.dtype, name: str) -> torch.dtype: + dtype = fallback if dtype is None else dtype + if dtype not in _SUPPORTED_FWD_MMA_DTYPES: + raise TypeError( + f"{name} must be one of torch.bfloat16 / torch.float8_e4m3fn, got {dtype}" + ) + return dtype + + +def _resolve_forward_mma_dtypes( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qk_dtype: Optional[torch.dtype], + pv_dtype: Optional[torch.dtype], +) -> tuple[torch.dtype, torch.dtype]: + qk_dtype = _normalize_forward_mma_dtype(qk_dtype, q.dtype, "qk_dtype") + if pv_dtype is None: + # Preserve the historical fp8 KV-cache path: BF16 Q with FP8 K/V + # stages both K and V as BF16 compute operands. + if ( + q.dtype == torch.bfloat16 + and k.dtype == torch.float8_e4m3fn + and v.dtype == torch.float8_e4m3fn + ): + pv_dtype = torch.bfloat16 + else: + pv_dtype = v.dtype + pv_dtype = _normalize_forward_mma_dtype(pv_dtype, pv_dtype, "pv_dtype") + + if q.dtype != qk_dtype: + raise ValueError( + "qk_dtype must match q storage dtype; Q fp8->bf16 staging is not supported" + ) + if k.dtype != qk_dtype: + if not (k.dtype == torch.float8_e4m3fn and qk_dtype == torch.bfloat16): + raise ValueError( + "unsupported K storage/qk_dtype combination; only fp8 K -> bf16 QK staging is supported" + ) + if v.dtype != pv_dtype: + if not (v.dtype == torch.float8_e4m3fn and pv_dtype == torch.bfloat16): + raise ValueError( + "unsupported V storage/pv_dtype combination; only fp8 V -> bf16 PV staging is supported" + ) + return qk_dtype, pv_dtype + + +def _to_cute_tensor_meta(t: torch.Tensor, assumed_align: int = 4): + tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=True) + return tensor.mark_layout_dynamic(leading_dim=0) + + +def _torch_dtype_to_cutlass_dtype(dtype: torch.dtype): + if dtype == torch.bfloat16: + return cutlass.BFloat16 + if dtype == torch.float16: + return cutlass.Float16 + if dtype == torch.float8_e4m3fn: + return cutlass.Float8E4M3FN + raise TypeError( + f"Only torch.bfloat16, torch.float16, torch.float8_e4m3fn supported, got {dtype}" + ) + + +def _prepare_paged_kv_for_tma(k, v, blk_kv: int): + page_size = int(k.shape[2]) + if page_size != blk_kv: + raise ValueError(f"Sparse Page Attention requires page_size == blk_kv, got {page_size} vs {blk_kv}") + return k, v + + +def _validate_cu_seqlens( + cu_seqlens: torch.Tensor, + *, + name: str, + device: torch.device, +) -> None: + if cu_seqlens.device != device: + raise ValueError(f"{name} must be on the same device as q") + if cu_seqlens.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32") + if cu_seqlens.ndim != 1: + raise ValueError(f"{name} must have shape [B + 1]") + if cu_seqlens.shape[0] < 1: + raise ValueError(f"{name} must have at least one element") + if not cu_seqlens.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _csr_row_capacity(k2q_row_ptr: torch.Tensor) -> int: + return int(k2q_row_ptr.shape[1] - 1) + + +def _validate_csr_varlen_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + blk_kv: int, + page_table: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], +) -> tuple[int, int]: + if q.ndim != 3: + raise ValueError("CSR sparse forward requires q to have shape [total_q, Hq, D]") + if q.dtype not in _SUPPORTED_FWD_DTYPES: + raise TypeError( + "CSR sparse forward supports only torch.bfloat16 and " + f"torch.float8_e4m3fn Q/K/V, got {q.dtype}" + ) + if q.device != k.device or q.device != v.device: + raise ValueError("q, k, v must be on the same device") + mixed_fp8_kv_bf16_q = ( + q.dtype == torch.bfloat16 + and k.dtype == torch.float8_e4m3fn + and v.dtype == torch.float8_e4m3fn + ) + if not mixed_fp8_kv_bf16_q and (q.dtype != k.dtype or q.dtype != v.dtype): + raise ValueError( + "q, k, v must have the same dtype, except q=bf16 with fp8_e4m3 K/V cache" + ) + if q.shape[-1] != k.shape[-1] or q.shape[-1] != v.shape[-1]: + raise ValueError("q, k, v must have the same head dimension") + dim = q.shape[-1] + if dim != 128: + raise NotImplementedError( + f"CSR sparse forward currently supports only D=128, got D={dim}" + ) + if page_table is None: + if k.shape[-2] != v.shape[-2] or k.shape[-1] != v.shape[-1]: + raise ValueError("k and v must have the same [Hkv, D] tail dimensions") + head_kv = k.shape[-2] + else: + if k.ndim != 4 or v.ndim != 4: + raise ValueError( + "Sparse Page Attention requires k and v to have shape " + "[num_pages, Hkv, page_size, D]" + ) + if k.shape[1] != v.shape[1] or k.shape[-1] != v.shape[-1]: + raise ValueError( + "Sparse Page Attention k and v must have the same Hkv and D" + ) + head_kv = k.shape[1] + if ( + q.device != k2q_row_ptr.device + or q.device != k2q_q_indices.device + ): + raise ValueError("CSR metadata must be on the same device as q") + if ( + k2q_row_ptr.dtype != torch.int32 + or k2q_q_indices.dtype != torch.int32 + ): + raise TypeError("CSR metadata tensors must be torch.int32") + if k2q_row_ptr.ndim != 2 or k2q_q_indices.ndim != 2: + raise ValueError("k2q_row_ptr and k2q_q_indices must be rank-2") + + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q", device=q.device) + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k", device=q.device) + if cu_seqlens_k.shape != cu_seqlens_q.shape: + raise ValueError("cu_seqlens_k must have shape [B + 1] matching cu_seqlens_q") + batch = int(cu_seqlens_q.shape[0] - 1) + total_q = q.shape[0] + + head_q = q.shape[1] + if head_q % head_kv != 0: + raise ValueError("q.shape[1] must be divisible by Hkv") + qhead_per_kv = head_q // head_kv + if qhead_per_kv not in (1, 2, 4, 8, 16): + raise NotImplementedError( + "CSR forward is currently supported only for qhead_per_kv in {1, 2, 4, 8, 16}" + ) + if k2q_row_ptr.shape[0] != head_kv or k2q_q_indices.shape[0] != head_kv: + raise ValueError("CSR metadata head dimension must match KV head count") + if k2q_q_indices.shape[1] < total_q * topK: + raise ValueError( + f"k2q_q_indices.shape[1] ({k2q_q_indices.shape[1]}) must be >= total_q * topK ({total_q * topK})" + ) + if k2q_row_ptr.shape[1] < 1: + raise ValueError("k2q_row_ptr must contain at least one row pointer column") + + if page_table is None: + if seqused_k is not None: + raise ValueError("seqused_k is only supported together with page_table") + total_k = k.shape[0] + if k.ndim != 3 or v.ndim != 3: + raise ValueError("Sparse Attention requires k and v to have shape [total_k, Hkv, D]") + if k.shape != (total_k, head_kv, q.shape[-1]) or v.shape != (total_k, head_kv, q.shape[-1]): + raise ValueError("Sparse Attention k and v must match [total_k, Hkv, D]") + else: + if page_table.device != q.device: + raise ValueError("page_table must be on the same device as q") + if page_table.dtype != torch.int32: + raise TypeError("page_table must be torch.int32") + if page_table.ndim != 2 or page_table.shape[0] != batch: + raise ValueError("page_table must have shape [B, max_num_pages_per_seq]") + if page_table.stride(-1) != 1: + raise ValueError("page_table must be contiguous in the last dimension") + if k.ndim != 4 or v.ndim != 4: + raise ValueError( + "Sparse Page Attention requires k and v to have shape " + "[num_pages, Hkv, page_size, D]" + ) + if k.shape != v.shape: + raise ValueError(f"k and v must have the same shape, got {k.shape} and {v.shape}") + if k.shape[1] != head_kv or k.shape[3] != q.shape[-1]: + raise ValueError( + "Sparse Page Attention k and v must match " + "[num_pages, Hkv, page_size, D]" + ) + page_size = int(k.shape[2]) + if page_size != blk_kv: + raise ValueError( + f"Unsupported Sparse Page Attention page_size={page_size} for blk_kv={blk_kv}; " + "require page_size == blk_kv" + ) + if seqused_k is not None: + if seqused_k.device != q.device: + raise ValueError("seqused_k must be on the same device as q") + if seqused_k.dtype != torch.int32: + raise TypeError("seqused_k must be torch.int32") + if seqused_k.shape != (batch,): + raise ValueError("seqused_k must have shape [B]") + if not seqused_k.is_contiguous(): + raise ValueError("seqused_k must be contiguous") + if topK not in _SUPPORTED_SPARSE_TOPK: + raise ValueError( + f"CSR sparse forward supports topK in {_SUPPORTED_SPARSE_TOPK}, got {topK}" + ) + return batch, head_kv + + +def _validate_csr_varlen_nvfp4_kv_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_scale_128x4: torch.Tensor, + v_scale_128x4: torch.Tensor, + k_global_scale: Optional[torch.Tensor], + v_global_scale: Optional[torch.Tensor], + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + blk_kv: int, + page_table: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], +) -> tuple[int, int]: + if q.ndim != 3: + raise ValueError("KVFP4 CSR sparse forward requires q to have shape [total_q, Hq, D]") + if q.dtype not in (torch.bfloat16, torch.float8_e4m3fn): + raise TypeError(f"KVFP4 CSR sparse forward requires BF16 or FP8 E4M3 q, got {q.dtype}") + if q.shape[-1] != 128: + raise NotImplementedError( + f"KVFP4 CSR sparse forward currently supports only D=128, got {q.shape[-1]}" + ) + if k.dtype != torch.uint8 or v.dtype != torch.uint8: + raise TypeError(f"KVFP4 k/v must be torch.uint8, got {k.dtype} and {v.dtype}") + if k_scale_128x4.dtype != torch.uint8 or v_scale_128x4.dtype != torch.uint8: + raise TypeError( + "KVFP4 block scales must be torch.uint8 E4M3 tensors, got " + f"{k_scale_128x4.dtype} and {v_scale_128x4.dtype}" + ) + if k_global_scale is not None and k_global_scale.dtype != torch.float32: + raise TypeError("KVFP4 K global scale must be a torch.float32 tensor or None") + if v_global_scale is not None and v_global_scale.dtype != torch.float32: + raise TypeError("KVFP4 V global scale must be a torch.float32 tensor or None") + tensors = ( + k, + v, + k_scale_128x4, + v_scale_128x4, + k2q_row_ptr, + k2q_q_indices, + cu_seqlens_q, + cu_seqlens_k, + ) + optional_tensors = tuple(t for t in (k_global_scale, v_global_scale) if t is not None) + if any(t.device != q.device for t in tensors + optional_tensors): + raise ValueError("KVFP4 inputs and metadata must be on the same device as q") + if k.shape != v.shape: + raise ValueError(f"KVFP4 k and v must have the same shape, got {k.shape} and {v.shape}") + packed_dim = q.shape[-1] // 2 + scale_cols = q.shape[-1] // 16 + if k_scale_128x4.ndim != 2 or v_scale_128x4.ndim != 2: + raise ValueError("KVFP4 block scales must be rank-2 128x4 tiled tensors") + if k_scale_128x4.shape[1] < scale_cols or v_scale_128x4.shape[1] < scale_cols: + raise ValueError( + "KVFP4 block scales must have at least D/16 columns; " + f"need {scale_cols}, got {k_scale_128x4.shape[1]} and {v_scale_128x4.shape[1]}" + ) + if k_global_scale is not None and k_global_scale.numel() < 1: + raise ValueError("KVFP4 K global scale must contain at least one element") + if v_global_scale is not None and v_global_scale.numel() < 1: + raise ValueError("KVFP4 V global scale must contain at least one element") + + if page_table is None: + if seqused_k is not None: + raise ValueError("seqused_k is only supported together with page_table") + if k.ndim != 3: + raise ValueError("KVFP4 Sparse Attention requires k/v shape [total_k, Hkv, D/2]") + if k.shape[-1] != packed_dim: + raise ValueError(f"KVFP4 packed K/V last dimension must be D/2={packed_dim}") + total_k = int(k.shape[0]) + head_kv = int(k.shape[1]) + required_scale_rows = total_k * head_kv + else: + if k.ndim != 4: + raise ValueError( + "KVFP4 Sparse Page Attention requires k/v shape " + "[num_pages, Hkv, page_size, D/2]" + ) + if k.shape[-1] != packed_dim: + raise ValueError(f"KVFP4 packed K/V last dimension must be D/2={packed_dim}") + page_size = int(k.shape[2]) + if page_size != int(blk_kv): + raise ValueError( + f"KVFP4 Sparse Page Attention requires page_size == blk_kv, got {page_size} vs {blk_kv}" + ) + head_kv = int(k.shape[1]) + required_scale_rows = int(k.shape[0]) * head_kv * page_size + if page_table.device != q.device: + raise ValueError("page_table must be on the same device as q") + if page_table.dtype != torch.int32: + raise TypeError("page_table must be torch.int32") + if page_table.ndim != 2: + raise ValueError("page_table must have shape [B, max_num_pages_per_seq]") + if page_table.stride(-1) != 1: + raise ValueError("page_table must be contiguous in the last dimension") + if seqused_k is not None: + if seqused_k.device != q.device: + raise ValueError("seqused_k must be on the same device as q") + if seqused_k.dtype != torch.int32: + raise TypeError("seqused_k must be torch.int32") + if not seqused_k.is_contiguous(): + raise ValueError("seqused_k must be contiguous") + + padded_scale_rows = ((required_scale_rows + 127) // 128) * 128 + padded_scale_cols = ((scale_cols + 3) // 4) * 4 + for name, scale in (("k_scale_128x4", k_scale_128x4), ("v_scale_128x4", v_scale_128x4)): + if scale.shape[0] < padded_scale_rows or scale.shape[1] < padded_scale_cols: + raise ValueError( + f"{name} is too small for 128x4 layout: got {tuple(scale.shape)}, " + f"need at least {(padded_scale_rows, padded_scale_cols)}" + ) + + if k2q_row_ptr.device != q.device or k2q_q_indices.device != q.device: + raise ValueError("CSR metadata must be on the same device as q") + if k2q_row_ptr.dtype != torch.int32 or k2q_q_indices.dtype != torch.int32: + raise TypeError("CSR metadata tensors must be torch.int32") + if k2q_row_ptr.ndim != 2 or k2q_q_indices.ndim != 2: + raise ValueError("k2q_row_ptr and k2q_q_indices must be rank-2") + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q", device=q.device) + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k", device=q.device) + if cu_seqlens_k.shape != cu_seqlens_q.shape: + raise ValueError("cu_seqlens_k must have shape [B + 1] matching cu_seqlens_q") + batch = int(cu_seqlens_q.shape[0] - 1) + if page_table is not None and page_table.shape[0] != batch: + raise ValueError("page_table must have shape [B, max_num_pages_per_seq]") + if seqused_k is not None and seqused_k.shape != (batch,): + raise ValueError("seqused_k must have shape [B]") + head_q = int(q.shape[1]) + if head_q % head_kv != 0: + raise ValueError("q.shape[1] must be divisible by Hkv") + qhead_per_kv = head_q // head_kv + if qhead_per_kv not in (1, 2, 4, 8, 16): + raise NotImplementedError( + "KVFP4 CSR forward is currently supported only for qhead_per_kv in {1, 2, 4, 8, 16}" + ) + if k2q_row_ptr.shape[0] != head_kv or k2q_q_indices.shape[0] != head_kv: + raise ValueError("CSR metadata head dimension must match KV head count") + if k2q_q_indices.shape[1] < q.shape[0] * topK: + raise ValueError( + f"k2q_q_indices.shape[1] ({k2q_q_indices.shape[1]}) must be >= total_q * topK ({q.shape[0] * topK})" + ) + if k2q_row_ptr.shape[1] < 1: + raise ValueError("k2q_row_ptr must contain at least one row pointer column") + if topK not in _SUPPORTED_SPARSE_TOPK: + raise ValueError( + f"KVFP4 CSR sparse forward supports topK in {_SUPPORTED_SPARSE_TOPK}, got {topK}" + ) + return batch, head_kv + + +def _validate_sparse_decode_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k_indices: Optional[torch.Tensor], + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + blk_kv: int, + causal: bool, +) -> tuple[int, int]: + if q.ndim != 3: + raise ValueError("decode attention requires q to have shape [total_q, Hq, D]") + if k.ndim != 4 or v.ndim != 4: + raise ValueError( + "decode attention requires paged k/v with shape [num_pages, Hkv, page_size, D]" + ) + if q.device != k.device or q.device != v.device: + raise ValueError("decode q, k, and v must be on the same device") + if q.dtype != torch.float8_e4m3fn or k.dtype != q.dtype or v.dtype != q.dtype: + raise TypeError( + "decode attention currently supports only torch.float8_e4m3fn Q/K/V" + ) + if k.shape != v.shape: + raise ValueError(f"decode k and v must have the same shape, got {k.shape} and {v.shape}") + if q.shape[-1] != 128 or k.shape[-1] != 128: + raise NotImplementedError( + f"decode attention currently supports only D=128, got q={q.shape[-1]} k={k.shape[-1]}" + ) + if not bool(causal): + raise NotImplementedError("decode attention currently supports only causal=True") + page_size = int(k.shape[2]) + if page_size != int(blk_kv): + raise ValueError(f"decode attention requires page_size == blk_kv, got {page_size} vs {blk_kv}") + + head_kv = int(k.shape[1]) + head_q = int(q.shape[1]) + if head_q % head_kv != 0: + raise ValueError("decode q.shape[1] must be divisible by Hkv") + qhead_per_kv = head_q // head_kv + if qhead_per_kv != _SUPPORTED_DECODE_QHEAD_PER_KV: + raise NotImplementedError( + "decode attention currently supports only " + f"qhead_per_kv={_SUPPORTED_DECODE_QHEAD_PER_KV}, got {qhead_per_kv}" + ) + + if page_table is None: + raise ValueError("decode attention requires page_table") + if page_table.device != q.device: + raise ValueError("decode page_table must be on the same device as q") + if page_table.dtype != torch.int32: + raise TypeError("decode page_table must be torch.int32") + if page_table.ndim != 2: + raise ValueError("decode page_table must have shape [B, max_num_pages_per_seq]") + batch = int(page_table.shape[0]) + if page_table.stride(-1) != 1: + raise ValueError("decode page_table must be contiguous in the last dimension") + + if seqused_k is None: + raise ValueError("decode attention requires seqused_k") + if seqused_k.device != q.device: + raise ValueError("decode seqused_k must be on the same device as q") + if seqused_k.dtype != torch.int32: + raise TypeError("decode seqused_k must be torch.int32") + if seqused_k.shape != (batch,): + raise ValueError("decode seqused_k must have shape [B]") + if not seqused_k.is_contiguous(): + raise ValueError("decode seqused_k must be contiguous") + + seqlen_q = int(seqlen_q) + max_seqlen_k = int(max_seqlen_k) + if seqlen_q <= 0 or max_seqlen_k <= 0: + raise ValueError("decode seqlen_q and max_seqlen_k must be positive") + if int(q.shape[0]) != batch * seqlen_q: + raise ValueError("decode q.shape[0] must equal batch * seqlen_q") + + if q2k_indices is not None: + if q2k_indices.device != q.device: + raise ValueError("decode q2k_indices must be on the same device as q") + if q2k_indices.dtype != torch.int32: + raise TypeError("decode q2k_indices must be torch.int32") + if q2k_indices.ndim != 3: + raise ValueError("decode q2k_indices must have shape [Hkv, total_q, topK]") + if q2k_indices.shape[0] != head_kv or q2k_indices.shape[1] != q.shape[0]: + raise ValueError("decode q2k_indices must match [Hkv, total_q, topK]") + if not q2k_indices.is_contiguous(): + raise ValueError("decode q2k_indices must be contiguous") + return batch, head_kv + + +def _validate_schedule_common( + schedule: SparseAttentionSchedule, + *, + device: torch.device, +) -> None: + if schedule.scheduler_metadata is None: + raise ValueError("schedule.scheduler_metadata is required") + if schedule.work_count is None: + raise ValueError("schedule.work_count is required") + metadata = schedule.scheduler_metadata + work_count = schedule.work_count + if metadata.device != device or work_count.device != device: + raise ValueError("schedule tensors must be on the same device as q") + if metadata.dtype != torch.int32 or work_count.dtype != torch.int32: + raise TypeError("schedule.scheduler_metadata and schedule.work_count must be torch.int32") + if metadata.ndim != 2 or metadata.shape[1] != 6: + raise ValueError("schedule.scheduler_metadata must have shape [capacity, 6]") + if work_count.shape != (1,): + raise ValueError("schedule.work_count must have shape [1]") + if not metadata.is_contiguous() or not work_count.is_contiguous(): + raise ValueError("schedule.scheduler_metadata and schedule.work_count must be contiguous") + + +def _validate_fwd_schedule( + schedule: SparseAttentionSchedule, + *, + q: torch.Tensor, + k2q_q_indices: torch.Tensor, + head_kv: int, +) -> None: + _validate_schedule_common(schedule, device=q.device) + if schedule.qsplit_indices is None: + raise ValueError("schedule.qsplit_indices is required for forward") + if schedule.split_counts is None: + raise ValueError("schedule.split_counts is required for forward") + qsplit = schedule.qsplit_indices + split_counts = schedule.split_counts + if qsplit.device != q.device or split_counts.device != q.device: + raise ValueError("forward schedule tensors must be on the same device as q") + if qsplit.dtype != torch.int32 or split_counts.dtype != torch.int32: + raise TypeError("schedule.qsplit_indices and schedule.split_counts must be torch.int32") + if qsplit.shape != k2q_q_indices.shape: + raise ValueError("schedule.qsplit_indices shape must match k2q_q_indices") + total_q = q.shape[0] + if split_counts.shape != (total_q, head_kv): + raise ValueError( + "schedule.split_counts must have shape " + f"({total_q}, {head_kv}), got {tuple(split_counts.shape)}" + ) + if not qsplit.is_contiguous() or not split_counts.is_contiguous(): + raise ValueError("schedule.qsplit_indices and schedule.split_counts must be contiguous") + + +def sparse_atten_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + *, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = False, + softmax_scale: Optional[float] = None, + lse_temperature_scale: float = 1.0, + return_temperature_lse: bool = False, + partial_dtype: torch.dtype = torch.bfloat16, + return_softmax_lse: bool = False, + page_table: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + schedule: Optional[SparseAttentionSchedule] = None, + usable_SM_count: int = -1, + qk_dtype: Optional[torch.dtype] = None, + pv_dtype: Optional[torch.dtype] = None, +): + """Run SM100 CSR block-sparse varlen attention. + + This is the public forward-only sparse attention API. It consumes + query-to-key block selections converted to CSR metadata by + ``build_k2q_csr`` and supports both dense KV layout and paged KV layout. + + Parameters + ---------- + q : torch.Tensor + Shape ``[total_q, Hq, 128]`` on CUDA. Supported dtypes are BF16 and + FP8 E4M3. + k : torch.Tensor + Dense layout ``[total_k, Hkv, 128]`` or paged layout + ``[num_pages, Hkv, blk_kv, 128]``. For BF16 Q with FP8 K/V cache, K + may be FP8 E4M3 while QK compute uses BF16 staging. + v : torch.Tensor + Same layout and head count as ``k``. + k2q_row_ptr : torch.Tensor + CSR row pointers with shape ``[Hkv, total_rows + 1]`` and dtype int32. + k2q_q_indices : torch.Tensor + CSR query indices with shape ``[Hkv, >= total_q * topK]`` and dtype + int32. + topK : int + Number of selected KV blocks per query. Supported values are + ``4, 8, 16, 32``. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + max_seqlen_q : int + Maximum Q sequence length in the batch. + max_seqlen_k : int + Maximum KV sequence length in the batch. + blk_kv : int, optional + KV block size. Paged KV requires ``k.shape[2] == blk_kv``. + causal : bool, optional + Whether to apply causal masking. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + lse_temperature_scale : float, optional + Extra divisor used only for temperature-scaled LSE output. + return_temperature_lse : bool, optional + If True, also return LSE computed with logits scaled by + ``softmax_scale / lse_temperature_scale``. Requires + ``return_softmax_lse=True``. + partial_dtype : torch.dtype, optional + Accumulation dtype for per-block partial O. Supported values are + FP32, BF16, FP16, and FP8 E4M3. + return_softmax_lse : bool, optional + If True, return ``(out, softmax_lse)`` or + ``(out, softmax_lse, temperature_lse)``. + page_table : torch.Tensor, optional + Paged-KV physical page table with shape + ``[batch_size, max_num_pages_per_seq]`` and dtype int32. + seqused_k : torch.Tensor, optional + Shape ``[batch_size]``, dtype int32. Effective KV length per request + for paged causal attention. + schedule : SparseAttentionSchedule, optional + Prebuilt sparse forward schedule. If omitted, the schedule is built + during the call. + usable_SM_count : int, optional + Maximum number of SMs used by the scheduler. ``-1`` uses all SMs. + qk_dtype : torch.dtype, optional + Compile-time MMA operand dtype for QK. Defaults to Q storage dtype, + except supported FP8 K/V cache staging modes. + pv_dtype : torch.dtype, optional + Compile-time MMA operand dtype for PV. Defaults to V storage dtype, + except supported FP8 K/V cache staging modes. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + Output shape ``[total_q, Hq, 128]`` with BF16 dtype. Optional LSE + outputs have shape ``[total_q, Hq]`` and dtype float32. + + Notes + ----- + ``Hq / Hkv`` must be one of ``1, 2, 4, 8, 16``. Current kernels support + head dimension 128 only. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse and not return_softmax_lse: + raise ValueError("return_temperature_lse=True requires return_softmax_lse=True") + partial_dtype = _normalize_partial_dtype(partial_dtype) + qk_dtype, pv_dtype = _resolve_forward_mma_dtypes(q, k, v, qk_dtype, pv_dtype) + + if cu_seqlens_q is None or cu_seqlens_k is None: + raise ValueError( + "sparse_atten_func requires CSR varlen metadata: pass cu_seqlens_q and cu_seqlens_k" + ) + batch, head_kv = _validate_csr_varlen_inputs( + q, + k, + v, + k2q_row_ptr, + k2q_q_indices, + topK, + blk_kv, + page_table, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + ) + max_seqlen_q = int(max_seqlen_q) + max_seqlen_k = int(max_seqlen_k) + + return _sparse_atten_csr_varlen_forward( + q.contiguous(), + k.contiguous(), + v.contiguous(), + k2q_row_ptr.contiguous(), + k2q_q_indices.contiguous(), + int(topK), + int(blk_kv), + bool(causal), + float(softmax_scale), + lse_temperature_scale, + return_temperature_lse, + partial_dtype, + bool(return_softmax_lse), + cu_seqlens_q.contiguous(), + cu_seqlens_k.contiguous(), + None if page_table is None else page_table.contiguous(), + None if seqused_k is None else seqused_k.contiguous(), + schedule, + int(usable_SM_count), + int(batch), + int(head_kv), + int(max_seqlen_q), + int(max_seqlen_k), + qk_dtype, + pv_dtype, + ) + + +def sparse_atten_nvfp4_kv_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_scale_128x4: torch.Tensor, + v_scale_128x4: torch.Tensor, + k_global_scale: Optional[torch.Tensor], + v_global_scale: Optional[torch.Tensor], + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + *, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = False, + softmax_scale: Optional[float] = None, + lse_temperature_scale: float = 1.0, + return_temperature_lse: bool = False, + partial_dtype: torch.dtype = torch.bfloat16, + return_softmax_lse: bool = False, + page_table: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + schedule: Optional[SparseAttentionSchedule] = None, +): + """Run SM100 CSR sparse attention with packed NVFP4 K/V. + + Parameters + ---------- + q : torch.Tensor + Shape ``[total_q, Hq, 128]`` on CUDA. Supported dtypes are BF16 and + FP8 E4M3. + k : torch.Tensor + Packed NVFP4 K data. Dense layout is ``[total_k, Hkv, 64]``; paged + layout is ``[num_pages, Hkv, blk_kv, 64]``. Dtype must be uint8 + because each byte packs two FP4 values. + v : torch.Tensor + Packed NVFP4 V data with the same shape as ``k``. + k_scale_128x4 : torch.Tensor + K block scales in cuBLAS/cuDNN 128x4 tiled storage. Dtype uint8 + containing FP8 E4M3 scale values. + v_scale_128x4 : torch.Tensor + V block scales in the same 128x4 tiled storage. + k_global_scale : torch.Tensor, optional + FP32 tensor/global dequant scale for K. May be ``None``. + v_global_scale : torch.Tensor, optional + FP32 tensor/global dequant scale for V. May be ``None``. The V global + scale is applied in the combine stage. + k2q_row_ptr : torch.Tensor + CSR row pointers with shape ``[Hkv, total_rows + 1]`` and dtype int32. + k2q_q_indices : torch.Tensor + CSR query indices with shape ``[Hkv, >= total_q * topK]`` and dtype + int32. + topK : int + Number of selected KV blocks per query. Supported values are + ``4, 8, 16, 32``. + cu_seqlens_q, cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q and KV + lengths. + max_seqlen_q, max_seqlen_k : int + Maximum Q and KV sequence lengths in the batch. + blk_kv : int, optional + KV block/page size. Paged KV requires ``k.shape[2] == blk_kv``. + causal : bool, optional + Whether to apply causal masking. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + lse_temperature_scale : float, optional + Extra divisor used only for temperature-scaled LSE output. + return_temperature_lse : bool, optional + If True, also return temperature-scaled LSE. Requires + ``return_softmax_lse=True``. + partial_dtype : torch.dtype, optional + Accumulation dtype for per-block partial O. + return_softmax_lse : bool, optional + If True, return LSE together with the output. + page_table : torch.Tensor, optional + Paged-KV physical page table with shape + ``[batch_size, max_num_pages_per_seq]`` and dtype int32. + seqused_k : torch.Tensor, optional + Effective KV length per request for paged causal attention. + schedule : SparseAttentionSchedule, optional + Prebuilt sparse forward schedule. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + Output shape ``[total_q, Hq, 128]`` with BF16 dtype. Optional LSE + outputs have shape ``[total_q, Hq]`` and dtype float32. + """ + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse and not return_softmax_lse: + raise ValueError("return_temperature_lse=True requires return_softmax_lse=True") + partial_dtype = _normalize_partial_dtype(partial_dtype) + + if cu_seqlens_q is None or cu_seqlens_k is None: + raise ValueError( + "sparse_atten_nvfp4_kv_func requires CSR varlen metadata: pass cu_seqlens_q and cu_seqlens_k" + ) + batch, head_kv = _validate_csr_varlen_nvfp4_kv_inputs( + q, + k, + v, + k_scale_128x4, + v_scale_128x4, + k_global_scale, + v_global_scale, + k2q_row_ptr, + k2q_q_indices, + topK, + blk_kv, + page_table, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + ) + total_q, head_q, dim = q.shape + max_num_kv_blocks = _csr_row_capacity(k2q_row_ptr) + temperature_lse_fast_path = ( + return_temperature_lse + and math.isclose( + float(lse_temperature_scale), + 1.0, + rel_tol=0.0, + abs_tol=_TEMPERATURE_LSE_FAST_PATH_ABS_TOL, + ) + ) + kernel_return_temperature_lse = ( + return_temperature_lse and not temperature_lse_fast_path + ) + + O_partial = torch.empty( + topK, total_q, head_q, dim, dtype=partial_dtype, device=q.device + ) + LSE_partial = torch.empty( + topK, total_q, head_q, dtype=torch.float32, device=q.device + ) + LSE_temperature_partial = ( + torch.empty(topK, total_q, head_q, dtype=torch.float32, device=q.device) + if kernel_return_temperature_lse + else None + ) + O_out = torch.empty(total_q, head_q, dim, dtype=torch.bfloat16, device=q.device) + LSE_out = torch.empty(total_q, head_q, dtype=torch.float32, device=q.device) + LSE_temperature_out = ( + torch.empty_like(LSE_out) if kernel_return_temperature_lse else None + ) + if schedule is None: + k2q_qsplit_indices = torch.empty_like(k2q_q_indices) + split_counts = torch.zeros( + (total_q, head_kv), + dtype=torch.int32, + device=q.device, + ) + else: + _validate_fwd_schedule( + schedule, + q=q, + k2q_q_indices=k2q_q_indices, + head_kv=head_kv, + ) + k2q_qsplit_indices = schedule.qsplit_indices + split_counts = schedule.split_counts + + schedule = _call_sparse_forward_sm100_csr_varlen_nvfp4_kv( + q.contiguous(), + k.contiguous(), + v.contiguous(), + k_scale_128x4.contiguous(), + v_scale_128x4.contiguous(), + None if k_global_scale is None else k_global_scale.contiguous(), + None if v_global_scale is None else v_global_scale.contiguous(), + k2q_row_ptr.contiguous(), + k2q_q_indices.contiguous(), + k2q_qsplit_indices.contiguous(), + split_counts.contiguous(), + cu_seqlens_q.contiguous(), + cu_seqlens_k.contiguous(), + None if page_table is None else page_table.contiguous(), + None if seqused_k is None else seqused_k.contiguous(), + O_partial, + LSE_partial, + LSE_temperature_partial, + float(softmax_scale), + lse_temperature_scale, + kernel_return_temperature_lse, + max_num_kv_blocks, + int(blk_kv), + head_kv, + int(max_seqlen_q), + causal=bool(causal), + schedule=schedule, + ) + + combine( + O_partial, + LSE_partial, + O_out, + LSE_out, + lse_temperature_partial=LSE_temperature_partial, + lse_temperature_out=LSE_temperature_out, + cu_seqlens=cu_seqlens_q, + split_counts=split_counts, + output_scale=v_global_scale, + use_pdl=True, + ) + if temperature_lse_fast_path: + LSE_temperature_out = LSE_out + + if return_softmax_lse: + if return_temperature_lse: + return O_out, LSE_out, LSE_temperature_out + return O_out, LSE_out + return O_out + + +def sparse_decode_atten_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k_indices: Optional[torch.Tensor] = None, + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = True, + softmax_scale: Optional[float] = None, + return_softmax_lse: bool = False, + schedule: Optional[DecodeAttentionSchedule] = None, + O_partial: Optional[torch.Tensor] = None, + LSE_partial: Optional[torch.Tensor] = None, +): + """Run forward-only paged FP8 decode attention. + + Parameters + ---------- + q : torch.Tensor + Shape ``[batch_size * seqlen_q, Hq, 128]``. Dtype must be FP8 E4M3. + k : torch.Tensor + Paged K cache with shape ``[num_pages, Hkv, blk_kv, 128]`` and FP8 + E4M3 dtype. + v : torch.Tensor + Paged V cache with the same shape and dtype as ``k``. + q2k_indices : torch.Tensor, optional + Sparse selected KV blocks with shape ``[Hkv, total_q, topK]`` and dtype + int32. ``None`` selects the dense all-KV decode path. + page_table : torch.Tensor + Physical page table with shape ``[batch_size, max_num_pages_per_seq]`` + and dtype int32. + seqused_k : torch.Tensor + Shape ``[batch_size]``, dtype int32. Effective KV length per request. + seqlen_q : int + Uniform query length per request. Ragged Q lengths should use prefill + or append paths instead. + max_seqlen_k : int + Maximum KV sequence length in the batch. + blk_kv : int, optional + Page size. Must match ``k.shape[2]``. + causal : bool, optional + Whether to apply causal masking. Current decode kernel requires True. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + return_softmax_lse : bool, optional + If True, return ``(out, lse)``. + schedule : DecodeAttentionSchedule, optional + Prebuilt decode schedule. + O_partial, LSE_partial : torch.Tensor, optional + Optional split-KV partial workspaces. Normally owned by + ``SparseDecodePagedAttentionWrapper``. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + BF16 output with shape ``q.shape``. Optional LSE has shape + ``[batch_size * seqlen_q, Hq]`` and dtype float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + batch, head_kv = _validate_sparse_decode_inputs( + q, + k, + v, + q2k_indices, + page_table=page_table, + seqused_k=seqused_k, + seqlen_q=seqlen_q, + max_seqlen_k=max_seqlen_k, + blk_kv=blk_kv, + causal=causal, + ) + head_q = int(q.shape[1]) + head_dim = int(q.shape[2]) + if schedule is None: + schedule = prepare_decode_schedule( + seqused_k=seqused_k.contiguous(), + page_size=int(blk_kv), + seqlen_q=int(seqlen_q), + num_qo_heads=head_q, + num_kv_heads=head_kv, + head_dim=head_dim, + max_seqlen_k=int(max_seqlen_k), + ) + if schedule.split_kv: + if O_partial is None: + O_partial = torch.empty( + (schedule.partial_rows, head_q, head_dim), + dtype=torch.float32, + device=q.device, + ) + if LSE_partial is None: + LSE_partial = torch.empty( + (schedule.partial_rows, head_q), + dtype=torch.float32, + device=q.device, + ) + out = torch.empty(q.shape, dtype=torch.bfloat16, device=q.device) + lse = torch.empty( + q.shape[:2] if (return_softmax_lse or schedule.split_kv) else (1, head_q), + dtype=torch.float32, + device=q.device, + ) + _call_sparse_decode_forward_sm100_paged_fp8( + q.contiguous(), + k.contiguous(), + v.contiguous(), + None if q2k_indices is None else q2k_indices.contiguous(), + page_table.contiguous(), + seqused_k.contiguous(), + out, + lse, + schedule, + O_partial, + LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=int(seqlen_q), + max_seqlen_k=int(max_seqlen_k), + blk_kv=int(blk_kv), + causal=bool(causal), + return_lse=bool(return_softmax_lse), + ) + if return_softmax_lse: + return out, lse + return out + + +class SparseDecodePagedAttentionWrapper: + """Plan/run helper for paged FP8 decode attention. + + Use this wrapper when the same page table shape and sequence metadata are + reused across multiple decode layers. ``plan`` validates metadata and + allocates persistent schedules/workspaces; ``run`` then launches the decode + kernel with lower per-call overhead than ``sparse_decode_atten_func``. + """ + + def __init__(self, *, blk_kv: int = 128, causal: bool = True): + self.blk_kv = int(blk_kv) + self.causal = bool(causal) + self.batch: Optional[int] = None + self.num_qo_heads: Optional[int] = None + self.num_kv_heads: Optional[int] = None + self.head_dim: Optional[int] = None + self.page_table: Optional[torch.Tensor] = None + self.seqused_k: Optional[torch.Tensor] = None + self.q2k_indices: Optional[torch.Tensor] = None + self.seqlen_q: Optional[int] = None + self.max_seqlen_k: Optional[int] = None + self.is_sparse: bool = False + self.decode_schedule: Optional[DecodeAttentionSchedule] = None + self.request_indices: Optional[torch.Tensor] = None + self.qo_tile_indices: Optional[torch.Tensor] = None + self.kv_tile_indices: Optional[torch.Tensor] = None + self.merge_indptr: Optional[torch.Tensor] = None + self.o_indptr: Optional[torch.Tensor] = None + self.block_valid_mask: Optional[torch.Tensor] = None + self.kv_pages: Optional[torch.Tensor] = None + self.split_counts: Optional[torch.Tensor] = None + self.split_kv: bool = False + self.cta_tile_q: int = 0 + self.num_q_tiles: int = 0 + self.kv_chunk_size_pages: int = 0 + self.kv_chunk_size_tokens: int = 0 + self.work_count: int = 0 + self.padded_work_count: int = 0 + self.O_partial: Optional[torch.Tensor] = None + self.LSE_partial: Optional[torch.Tensor] = None + # Cached dummy buffers used in non-split path to satisfy the kernel's + # positional arg signature without per-call torch.empty (saves ~5us + # on every run() for small kv). + self._O_partial_dummy: Optional[torch.Tensor] = None + self._LSE_partial_dummy: Optional[torch.Tensor] = None + # When the caller doesn't ask for LSE, the kernel still needs a valid + # tensor pointer to write to. Cache a small placeholder so run() can + # skip the per-call torch.empty for it as well. + self._lse_dummy: Optional[torch.Tensor] = None + + def plan( + self, + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + q2k_indices: Optional[torch.Tensor] = None, + num_qo_heads: Optional[int] = None, + num_kv_heads: Optional[int] = None, + head_dim: Optional[int] = 128, + enable_cuda_graph: bool = False, + max_grid_size: Optional[int] = None, + fixed_split_size: Optional[int] = None, + disable_split_kv: bool = False, + ) -> "SparseDecodePagedAttentionWrapper": + """Prepare decode scheduling metadata and reusable workspaces. + + Parameters + ---------- + page_table : torch.Tensor + Shape ``[batch_size, max_num_pages_per_seq]``, dtype int32. Maps + logical pages to physical KV-cache pages. + seqused_k : torch.Tensor + Shape ``[batch_size]``, dtype int32. Effective KV length per + request. + seqlen_q : int + Uniform query length per request. + max_seqlen_k : int + Maximum KV sequence length in the batch. + q2k_indices : torch.Tensor, optional + Sparse selected KV blocks with shape ``[Hkv, total_q, topK]`` and + dtype int32. ``None`` selects the dense all-KV path. + num_qo_heads : int + Number of Q/O heads. + num_kv_heads : int + Number of KV heads. Current decode kernel requires + ``num_qo_heads / num_kv_heads == 16`` at run time. + head_dim : int, optional + Head dimension. Must be 128. + enable_cuda_graph : bool, optional + Build schedule metadata compatible with CUDA graph capture. + max_grid_size : int, optional + Override maximum CTA count used by the scheduler. + fixed_split_size : int, optional + Force a fixed split-KV chunk size in pages. + disable_split_kv : bool, optional + Disable split-KV even for long KV sequences. + + Returns + ------- + SparseDecodePagedAttentionWrapper + ``self``, planned and ready for ``run``. + """ + if page_table.ndim != 2: + raise ValueError("decode plan requires page_table with shape [B, max_num_pages_per_seq]") + if page_table.dtype != torch.int32: + raise TypeError("decode plan requires page_table to be torch.int32") + if seqused_k.dtype != torch.int32: + raise TypeError("decode plan requires seqused_k to be torch.int32") + if not page_table.is_cuda or not seqused_k.is_cuda: + raise ValueError("decode plan requires page_table and seqused_k to be CUDA tensors") + if page_table.device != seqused_k.device: + raise ValueError("decode plan requires page_table and seqused_k on the same device") + if page_table.stride(-1) != 1: + raise ValueError("decode plan requires page_table contiguous in the last dimension") + if seqused_k.shape != (int(page_table.shape[0]),): + raise ValueError("decode plan requires seqused_k with shape [B]") + if q2k_indices is not None and q2k_indices.dtype != torch.int32: + raise TypeError("decode plan requires q2k_indices to be torch.int32") + if int(seqlen_q) <= 0 or int(max_seqlen_k) <= 0: + raise ValueError("decode plan requires positive seqlen_q and max_seqlen_k") + if num_qo_heads is None or num_kv_heads is None or head_dim is None: + raise ValueError("decode plan requires num_qo_heads, num_kv_heads, and head_dim") + if head_dim is not None and int(head_dim) != 128: + raise NotImplementedError("decode plan currently supports only head_dim=128") + if int(num_qo_heads) % int(num_kv_heads) != 0: + raise ValueError("decode plan requires num_qo_heads divisible by num_kv_heads") + + self.batch = int(page_table.shape[0]) + self.num_qo_heads = None if num_qo_heads is None else int(num_qo_heads) + self.num_kv_heads = None if num_kv_heads is None else int(num_kv_heads) + self.head_dim = None if head_dim is None else int(head_dim) + self.page_table = page_table.contiguous() + self.seqused_k = seqused_k.contiguous() + self.q2k_indices = None if q2k_indices is None else q2k_indices.contiguous() + self.seqlen_q = int(seqlen_q) + self.max_seqlen_k = int(max_seqlen_k) + self.is_sparse = q2k_indices is not None + + # max_grid_size is hardcoded to num_sms (1 CTA/SM) inside the C++ + # schedule launcher because the decode attn kernel always runs at + # 1 CTA/SM (its register/smem budget saturates the SM). Callers + # can still override via the explicit max_grid_size kwarg. + schedule = prepare_decode_schedule( + seqused_k=self.seqused_k, + page_size=self.blk_kv, + seqlen_q=self.seqlen_q, + num_qo_heads=self.num_qo_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_seqlen_k=self.max_seqlen_k, + enable_cuda_graph=bool(enable_cuda_graph), + max_grid_size=max_grid_size, + fixed_split_size=fixed_split_size, + disable_split_kv=bool(disable_split_kv), + ) + self.decode_schedule = schedule + self.request_indices = schedule.request_indices + self.qo_tile_indices = schedule.qo_tile_indices + self.kv_tile_indices = schedule.kv_tile_indices + self.merge_indptr = schedule.merge_indptr + self.o_indptr = schedule.o_indptr + self.block_valid_mask = schedule.block_valid_mask + self.kv_pages = schedule.kv_pages + self.split_counts = schedule.split_counts + self.split_kv = schedule.split_kv + self.cta_tile_q = schedule.cta_tile_q + self.num_q_tiles = schedule.num_q_tiles + self.kv_chunk_size_pages = schedule.kv_chunk_size_pages + self.kv_chunk_size_tokens = schedule.kv_chunk_size_tokens + self.work_count = schedule.work_count + self.padded_work_count = schedule.padded_work_count + if schedule.split_kv: + self.O_partial = torch.empty( + (schedule.partial_rows, self.num_qo_heads, self.head_dim), + dtype=torch.float32, + device=page_table.device, + ) + self.LSE_partial = torch.empty( + (schedule.partial_rows, self.num_qo_heads), + dtype=torch.float32, + device=page_table.device, + ) + self._O_partial_dummy = None + self._LSE_partial_dummy = None + else: + self.O_partial = None + self.LSE_partial = None + # decode_forward_paged_fp8 always wants non-None partial buffers + # for the kernel's positional arg layout (compile keeps the slot + # alive even when split_kv=False). Allocate once here and reuse. + self._O_partial_dummy = torch.empty( + (1, self.head_dim), + dtype=torch.float32, + device=page_table.device, + ) + self._LSE_partial_dummy = torch.empty( + (1, self.num_qo_heads), + dtype=torch.float32, + device=page_table.device, + ) + # LSE dummy is shape (1, head_q) — used when caller doesn't request + # LSE and the schedule isn't split-KV (split-KV always writes LSE). + self._lse_dummy = torch.empty( + (1, self.num_qo_heads), + dtype=torch.float32, + device=page_table.device, + ) + return self + + def run( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + softmax_scale: Optional[float] = None, + return_softmax_lse: bool = False, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + ): + """Launch decode using metadata cached by ``plan``. + + Parameters + ---------- + q : torch.Tensor + Shape ``[batch_size * seqlen_q, Hq, 128]`` and dtype FP8 E4M3. + k : torch.Tensor + Paged K cache with shape ``[num_pages, Hkv, blk_kv, 128]``. + v : torch.Tensor + Paged V cache with the same shape as ``k``. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + return_softmax_lse : bool, optional + If True, return ``(out, lse)``. + out : torch.Tensor, optional + Preallocated BF16 output buffer with shape ``q.shape``. + lse : torch.Tensor, optional + Preallocated float32 LSE buffer with shape ``[total_q, Hq]``. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + BF16 output, optionally with float32 LSE. + """ + if self.decode_schedule is None: + raise RuntimeError("decode wrapper must be planned before run") + if self.is_sparse: + # Sparse path still goes through the validating wrapper for now; + # only the dense fast path is collapsed. + return sparse_decode_atten_func( + q, k, v, self.q2k_indices, + page_table=self.page_table, seqused_k=self.seqused_k, + seqlen_q=self.seqlen_q, max_seqlen_k=self.max_seqlen_k, + blk_kv=self.blk_kv, causal=self.causal, + softmax_scale=softmax_scale, return_softmax_lse=return_softmax_lse, + schedule=self.decode_schedule, + O_partial=self.O_partial, LSE_partial=self.LSE_partial, + ) + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + if out is None: + out = torch.empty(q.shape, dtype=torch.bfloat16, device=q.device) + if lse is None: + if return_softmax_lse or self.split_kv: + # Real LSE needed — must allocate per-call (shape depends on q). + lse = torch.empty( + q.shape[:2], dtype=torch.float32, device=q.device, + ) + else: + # Kernel only needs a valid pointer; reuse cached dummy. + lse = self._lse_dummy + from .src.sm100.fwd_decode import decode_forward_paged_fp8 + schedule = self.decode_schedule + decode_forward_paged_fp8( + q, k, v, + self.page_table, self.seqused_k, + out, lse, + schedule.request_indices, schedule.qo_tile_indices, + schedule.kv_tile_indices, schedule.block_valid_mask, + schedule.split_counts, schedule.o_indptr, schedule.merge_indptr, + self.O_partial, self.LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=self.seqlen_q, + page_size=self.blk_kv, + kv_chunk_size_pages=int(schedule.kv_chunk_size_pages), + max_split_count=int(schedule.max_split_count), + split_kv=bool(schedule.split_kv), + causal=self.causal, + return_lse=bool(return_softmax_lse), + # cached dummies — avoid per-call torch.empty inside run_decode_attention + O_partial_dummy=self._O_partial_dummy, + LSE_partial_dummy=self._LSE_partial_dummy, + ) + if return_softmax_lse: + return out, lse + return out + + +def _sparse_atten_csr_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + blk_kv: int, + causal: bool, + softmax_scale: float, + lse_temperature_scale: float, + return_temperature_lse: bool, + partial_dtype: torch.dtype, + return_softmax_lse: bool, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + page_table: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + schedule: Optional[SparseAttentionSchedule], + usable_SM_count: int, + batch: int, + head_kv: int, + max_seqlen_q: int, + max_seqlen_k: int, + qk_dtype: torch.dtype, + pv_dtype: torch.dtype, +): + total_q, head_q, dim = q.shape + if head_q % head_kv != 0: + raise ValueError("q.shape[1] must be divisible by head_kv") + max_num_kv_blocks = _csr_row_capacity(k2q_row_ptr) + temperature_lse_fast_path = ( + return_temperature_lse + and math.isclose( + float(lse_temperature_scale), + 1.0, + rel_tol=0.0, + abs_tol=_TEMPERATURE_LSE_FAST_PATH_ABS_TOL, + ) + ) + kernel_return_temperature_lse = ( + return_temperature_lse and not temperature_lse_fast_path + ) + + O_partial = torch.empty( + topK, total_q, head_q, dim, dtype=partial_dtype, device=q.device + ) + LSE_partial = torch.empty( + topK, total_q, head_q, dtype=torch.float32, device=q.device + ) + LSE_temperature_partial = ( + torch.empty(topK, total_q, head_q, dtype=torch.float32, device=q.device) + if kernel_return_temperature_lse + else None + ) + O_out = torch.empty(total_q, head_q, dim, dtype=torch.bfloat16, device=q.device) + LSE_out = torch.empty(total_q, head_q, dtype=torch.float32, device=q.device) + LSE_temperature_out = ( + torch.empty_like(LSE_out) if kernel_return_temperature_lse else None + ) + if schedule is None: + k2q_qsplit_indices = torch.empty_like(k2q_q_indices) + split_counts = torch.zeros( + (total_q, head_kv), + dtype=torch.int32, + device=q.device, + ) + else: + _validate_fwd_schedule( + schedule, + q=q, + k2q_q_indices=k2q_q_indices, + head_kv=head_kv, + ) + k2q_qsplit_indices = schedule.qsplit_indices + split_counts = schedule.split_counts + schedule = _call_sparse_forward_sm100_csr_varlen( + q, + k, + v, + k2q_row_ptr, + k2q_q_indices, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + cu_seqlens_k, + page_table, + seqused_k, + O_partial, + LSE_partial, + LSE_temperature_partial, + softmax_scale, + lse_temperature_scale, + kernel_return_temperature_lse, + max_num_kv_blocks, + blk_kv, + head_kv, + max_seqlen_q, + usable_SM_count, + causal=causal, + schedule=schedule, + qk_dtype=qk_dtype, + pv_dtype=pv_dtype, + ) + # Sparse Attention and Sparse Page Attention both use the varlen-Q + # combine path; the kernel-written LSE_out is the final contract. + combine( + O_partial, + LSE_partial, + O_out, + LSE_out, + lse_temperature_partial=LSE_temperature_partial, + lse_temperature_out=LSE_temperature_out, + cu_seqlens=cu_seqlens_q, + split_counts=split_counts, + use_pdl=True, + ) + if temperature_lse_fast_path: + LSE_temperature_out = LSE_out + + if return_softmax_lse: + if return_temperature_lse: + return O_out, LSE_out, LSE_temperature_out + return O_out, LSE_out + return O_out + + +def _call_sparse_decode_forward_sm100_paged_fp8( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k_indices: Optional[torch.Tensor], + page_table: torch.Tensor, + seqused_k: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + schedule: DecodeAttentionSchedule, + O_partial: Optional[torch.Tensor], + LSE_partial: Optional[torch.Tensor], + *, + softmax_scale: float, + seqlen_q: int, + max_seqlen_k: int, + blk_kv: int, + causal: bool, + return_lse: bool = True, +) -> None: + """Compile and launch the SM100 paged fp8 decode forward kernel. + + Dense decode is selected by ``q2k_indices=None``. Sparse decode will reuse + the same schedule wrapper but needs a separate q2k gather path. + """ + if q2k_indices is not None: + raise NotImplementedError("SM100 paged fp8 sparse decode forward is not implemented yet") + if schedule.cta_tile_q != 128: + raise NotImplementedError(f"decode forward requires cta_tile_q=128, got {schedule.cta_tile_q}") + if schedule.split_kv: + if O_partial is None or LSE_partial is None: + raise ValueError("split decode forward requires O_partial and LSE_partial") + if O_partial.dtype != torch.float32: + raise TypeError(f"O_partial must be torch.float32, got {O_partial.dtype}") + if LSE_partial.dtype != torch.float32: + raise TypeError(f"LSE_partial must be torch.float32, got {LSE_partial.dtype}") + + from .src.sm100.fwd_decode import decode_forward_paged_fp8 + + decode_forward_paged_fp8( + q, + k, + v, + page_table, + seqused_k, + out, + lse, + schedule.request_indices, + schedule.qo_tile_indices, + schedule.kv_tile_indices, + schedule.block_valid_mask, + schedule.split_counts, + schedule.o_indptr, + schedule.merge_indptr, + O_partial, + LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=int(seqlen_q), + page_size=int(blk_kv), + kv_chunk_size_pages=int(schedule.kv_chunk_size_pages), + max_split_count=int(schedule.max_split_count), + split_kv=bool(schedule.split_kv), + causal=bool(causal), + return_lse=bool(return_lse), + ) + + +def _call_sparse_forward_sm100_csr_varlen( + q, + k, + v, + k2q_row_ptr, + k2q_q_indices, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + cu_seqlens_k, + page_table, + seqused_k, + O_partial, + LSE_partial, + LSE_temperature_partial, + softmax_scale, + lse_temperature_scale, + return_temperature_lse, + max_num_kv_blocks, + blk_kv, + head_kv, + max_seqlen_q, + usable_SM_count=-1, + *, + causal=False, + use_prepare_scheduler=True, + schedule: Optional[SparseAttentionSchedule] = None, + qk_dtype: torch.dtype, + pv_dtype: torch.dtype, +): + """Compile and launch the SM100 sparse forward K1 kernel on CSR metadata.""" + head_dim = q.shape[-1] + dtype = q.dtype + qk_dtype = _normalize_forward_mma_dtype(qk_dtype, q.dtype, "qk_dtype") + pv_dtype = _normalize_forward_mma_dtype(pv_dtype, v.dtype, "pv_dtype") + partial_dtype = O_partial.dtype + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse != (LSE_temperature_partial is not None): + raise ValueError( + "return_temperature_lse must match LSE_temperature_partial presence" + ) + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + lse_temperature_inv_scale = 1.0 / lse_temperature_scale + n_block_size = int(blk_kv) + head_q = q.shape[1] + qhead_per_kv = head_q // head_kv + paged_kv = page_table is not None + if not bool(use_prepare_scheduler): + raise RuntimeError("sparse forward requires prepare scheduler") + schedule_enabled = k2q_row_ptr.shape[1] > 1 + page_size = int(k.shape[2]) if paged_kv else None + if paged_kv: + k_kernel, v_kernel = _prepare_paged_kv_for_tma(k, v, n_block_size) + else: + k_kernel = k + v_kernel = v + O_partial_flat = O_partial.reshape(-1, head_dim).contiguous() + Q_flat = q.reshape(-1, head_dim).contiguous() + Q_gather4_desc = ( + create_q_gather4_tma_desc( + Q_flat, + box_x=128 if q.dtype == torch.float8_e4m3fn else 64, + ) + if qhead_per_kv in (1, 2, 4) + else None + ) + if schedule is None: + schedule = prepare_sparse_fwd_schedule_and_split( + k2q_row_ptr=k2q_row_ptr, + k2q_q_indices=k2q_q_indices, + k2q_qsplit_indices=k2q_qsplit_indices, + split_counts=split_counts, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + total_q=int(q.shape[0]), + max_seqlen_q=max_seqlen_q, + topk=int(O_partial.shape[0]), + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + blk_kv=n_block_size, + device=q.device, + enabled=schedule_enabled, + ) + use_prepare_scheduler = schedule.enabled + scheduler_metadata = schedule.scheduler_metadata + work_count = schedule.work_count + work_capacity = schedule.work_capacity + if not use_prepare_scheduler or scheduler_metadata is None or work_count is None or work_capacity <= 0: + raise RuntimeError("sparse forward requires a non-empty prepared schedule") + + key = ( + "sparse_forward_sm100_csr_varlen", + head_dim, + n_block_size, + qhead_per_kv, + dtype, + k.dtype, + v.dtype, + qk_dtype, + pv_dtype, + partial_dtype, + bool(causal), + bool(paged_kv), + bool(use_prepare_scheduler), + page_size, + bool(seqused_k is not None), + bool(return_temperature_lse), + ) + if key not in _compile_cache: + from .src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _compile_cache[key] = loaded + else: + kernel = SparseAttentionForwardSm100( + head_dim=head_dim, + qheadperkv=qhead_per_kv, + n_block_size=n_block_size, + paged_kv=paged_kv, + page_size=page_size, + has_seqused_k=seqused_k is not None, + causal=bool(causal), + use_prepare_scheduler=use_prepare_scheduler, + qk_dtype=_torch_dtype_to_cutlass_dtype(qk_dtype), + pv_dtype=_torch_dtype_to_cutlass_dtype(pv_dtype), + ) + _compile_cache[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k_kernel), + to_cute_tensor_kvouter(v_kernel), + to_cute_tensor_kvouter(k2q_q_indices), + to_cute_tensor_kvouter(k2q_qsplit_indices), + to_cute_tensor_kvouter(k2q_row_ptr), + None if scheduler_metadata is None else to_cute_tensor_kvouter(scheduler_metadata), + None if work_count is None else to_cute_tensor_kvouter(work_count), + to_cute_tensor_kvouter(O_partial_flat), + to_cute_tensor_kvouter(LSE_partial), + None + if LSE_temperature_partial is None + else to_cute_tensor_kvouter(LSE_temperature_partial), + to_cute_tensor_kvouter(Q_flat), + None if Q_gather4_desc is None else to_cute_tensor_kvouter(Q_gather4_desc), + None if page_table is None else to_cute_tensor_kvouter(page_table), + None if seqused_k is None else to_cute_tensor_kvouter(seqused_k), + to_cute_tensor_kvouter(cu_seqlens_q), + to_cute_tensor_kvouter(cu_seqlens_k), + Float32(softmax_scale), + Float32(lse_temperature_inv_scale), + Int32(max_num_kv_blocks), + Int32(head_kv), + Int32(max_seqlen_q), + Int32(work_capacity), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _compile_cache[key]) + + with torch.cuda.nvtx.range("Fwd_SparseAttn_Sm100_CsrVarlen"): + _compile_cache[key]( + k_kernel, + v_kernel, + k2q_q_indices, + k2q_qsplit_indices, + k2q_row_ptr, + scheduler_metadata, + work_count, + O_partial_flat, + LSE_partial, + LSE_temperature_partial, + Q_flat, + Q_gather4_desc, + page_table, + seqused_k, + cu_seqlens_q, + cu_seqlens_k, + softmax_scale, + lse_temperature_inv_scale, + max_num_kv_blocks, + head_kv, + max_seqlen_q, + work_capacity, + ) + return schedule + + +def _call_sparse_forward_sm100_csr_varlen_nvfp4_kv( + q, + k, + v, + k_scale_128x4, + v_scale_128x4, + k_global_scale, + v_global_scale, + k2q_row_ptr, + k2q_q_indices, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + cu_seqlens_k, + page_table, + seqused_k, + O_partial, + LSE_partial, + LSE_temperature_partial, + softmax_scale, + lse_temperature_scale, + return_temperature_lse, + max_num_kv_blocks, + blk_kv, + head_kv, + max_seqlen_q, + *, + causal=False, + use_prepare_scheduler=True, + schedule: Optional[SparseAttentionSchedule] = None, +): + """Compile and launch the SM100 sparse forward K1 kernel with NVFP4 K/V.""" + + head_dim = q.shape[-1] + dtype = q.dtype + partial_dtype = O_partial.dtype + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse != (LSE_temperature_partial is not None): + raise ValueError( + "return_temperature_lse must match LSE_temperature_partial presence" + ) + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + lse_temperature_inv_scale = 1.0 / lse_temperature_scale + n_block_size = int(blk_kv) + head_q = q.shape[1] + qhead_per_kv = head_q // head_kv + fp8_pair_dequant = os.environ.get("MINIMAX_KVFP4_FP8_PAIR_DEQUANT", "1") != "0" + k_global_scale_kernel = k_global_scale + # V global scale is linear in the final output. Keep K1 on block-scale-only V + # and apply the tensor scale once in K2 combine. + v_global_scale_kernel = None + has_k_global_scale = k_global_scale_kernel is not None + has_v_global_scale = v_global_scale_kernel is not None + paged_kv = page_table is not None + if not bool(use_prepare_scheduler): + raise RuntimeError("KVFP4 sparse forward requires prepare scheduler") + schedule_enabled = k2q_row_ptr.shape[1] > 1 + page_size = int(k.shape[2]) if paged_kv else None + if paged_kv: + _prepare_paged_kv_for_tma(k, v, n_block_size) + k_kernel = k + v_kernel = v + O_partial_flat = O_partial.reshape(-1, head_dim).contiguous() + Q_flat = q.reshape(-1, head_dim).contiguous() + Q_gather4_desc = ( + create_q_gather4_tma_desc( + Q_flat, + box_x=128 if q.dtype == torch.float8_e4m3fn else 64, + ) + if qhead_per_kv in (1, 2, 4) + else None + ) + if schedule is None: + schedule = prepare_sparse_fwd_schedule_and_split( + k2q_row_ptr=k2q_row_ptr, + k2q_q_indices=k2q_q_indices, + k2q_qsplit_indices=k2q_qsplit_indices, + split_counts=split_counts, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + total_q=int(q.shape[0]), + max_seqlen_q=max_seqlen_q, + topk=int(O_partial.shape[0]), + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + blk_kv=n_block_size, + device=q.device, + enabled=schedule_enabled, + ) + use_prepare_scheduler = schedule.enabled + scheduler_metadata = schedule.scheduler_metadata + work_count = schedule.work_count + work_capacity = schedule.work_capacity + if not use_prepare_scheduler or scheduler_metadata is None or work_count is None or work_capacity <= 0: + raise RuntimeError("KVFP4 sparse forward requires a non-empty prepared schedule") + + key = ( + "sparse_forward_sm100_csr_varlen_nvfp4_kv", + head_dim, + n_block_size, + qhead_per_kv, + dtype, + partial_dtype, + bool(causal), + bool(paged_kv), + bool(use_prepare_scheduler), + page_size, + bool(seqused_k is not None), + bool(return_temperature_lse), + bool(fp8_pair_dequant), + bool(has_k_global_scale), + bool(has_v_global_scale), + ) + if key not in _compile_cache: + from .src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _compile_cache[key] = loaded + else: + kernel = SparseAttentionForwardNvfp4KvSm100( + head_dim=head_dim, + qheadperkv=qhead_per_kv, + n_block_size=n_block_size, + paged_kv=paged_kv, + page_size=page_size, + has_seqused_k=seqused_k is not None, + causal=bool(causal), + use_prepare_scheduler=use_prepare_scheduler, + fp8_pair_dequant=bool(fp8_pair_dequant), + has_k_global_scale=bool(has_k_global_scale), + has_v_global_scale=bool(has_v_global_scale), + ) + _compile_cache[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k_kernel), + to_cute_tensor_kvouter(v_kernel), + to_cute_tensor_kvouter(k_scale_128x4), + to_cute_tensor_kvouter(v_scale_128x4), + None if k_global_scale_kernel is None else to_cute_tensor_kvouter(k_global_scale_kernel), + None if v_global_scale_kernel is None else to_cute_tensor_kvouter(v_global_scale_kernel), + to_cute_tensor_kvouter(k2q_q_indices), + to_cute_tensor_kvouter(k2q_qsplit_indices), + to_cute_tensor_kvouter(k2q_row_ptr), + None if scheduler_metadata is None else to_cute_tensor_kvouter(scheduler_metadata), + None if work_count is None else to_cute_tensor_kvouter(work_count), + to_cute_tensor_kvouter(O_partial_flat), + to_cute_tensor_kvouter(LSE_partial), + None + if LSE_temperature_partial is None + else to_cute_tensor_kvouter(LSE_temperature_partial), + to_cute_tensor_kvouter(Q_flat), + None if Q_gather4_desc is None else to_cute_tensor_kvouter(Q_gather4_desc), + None if page_table is None else to_cute_tensor_kvouter(page_table), + None if seqused_k is None else to_cute_tensor_kvouter(seqused_k), + to_cute_tensor_kvouter(cu_seqlens_q), + to_cute_tensor_kvouter(cu_seqlens_k), + Float32(softmax_scale), + Float32(lse_temperature_inv_scale), + Int32(max_num_kv_blocks), + Int32(head_kv), + Int32(max_seqlen_q), + Int32(work_capacity), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _compile_cache[key]) + + with torch.cuda.nvtx.range("Fwd_SparseAttn_Sm100_CsrVarlen_KVFP4"): + _compile_cache[key]( + k_kernel, + v_kernel, + k_scale_128x4, + v_scale_128x4, + k_global_scale_kernel, + v_global_scale_kernel, + k2q_q_indices, + k2q_qsplit_indices, + k2q_row_ptr, + scheduler_metadata, + work_count, + O_partial_flat, + LSE_partial, + LSE_temperature_partial, + Q_flat, + Q_gather4_desc, + page_table, + seqused_k, + cu_seqlens_q, + cu_seqlens_k, + softmax_scale, + lse_temperature_inv_scale, + max_num_kv_blocks, + head_kv, + max_seqlen_q, + work_capacity, + ) + return schedule diff --git a/build/torch212-cxx11-cu130-x86_64-linux/metadata.json b/build/torch212-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..ec96b6da0a70ecf38eaf82ca990e658c11213063 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,71 @@ +{ + "name": "msa", + "id": "_msa_cuda_09d7851", + "version": 0, + "license": "other", + "upstream": "https://github.com/MiniMax-AI/MSA", + "python-depends": [ + "tvm-ffi", + "nvidia-cutlass-dsl" + ], + "backend": { + "type": "cuda", + "archs": [ + "10.0" + ] + }, + "digest": { + "algorithm": "sha256", + "files": { + "__init__.py": "+W+3U1Z5ZKc/dTA+JUG+6dMjfe9H/d9J+8fN+936wbI=", + "_msa_cuda_09d7851.abi3.so": "1EVRI9AyFVZ5oLq8hoC9dI4B8ILQZwhvtAsqHg+f64M=", + "_ops.py": "o9RBC1FB95LP9Sp+GkBILumbSek9oEtxb8F7XXO0F0g=", + "fp4_indexer_interface.py": "M+0e93gWG8CGOrhY5bm1hEQJU+TT5PrCmwJzTofaDAg=", + "interface.py": "B4AHQfNyO+vl6MdyMAHW0GhArl7HGufAEa0ATxsWorY=", + "msa/__init__.py": "DFYPlrhXwYjEqCl/8n0SmWGZV8NFml5DPhMjKfv98GY=", + "quack/__init__.py": "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=", + "quack/activation.py": "T/ypcXoz6a4wPPNZW2gKZuEj8JeucaKtKxQiQl5XrXc=", + "quack/compile_utils.py": "qJ3oTsDlbAiddrJHtEO7LPYVqn/s+neNfiw+/KvfXZU=", + "quack/copy_utils.py": "rdohXm9bKXqDHkMHf8lWQJQnCb0hMLvhzIudkj0Bxeg=", + "quack/cute_dsl_utils.py": "4uQx5aYDG9UvVzbWwJTjjJLrnoympz70/CD8b37FQWo=", + "quack/layout_utils.py": "69N1aTy+840X3seMuLfLxiV3BW8SaVsM3Tf0Vf4NCSI=", + "quantize.py": "1jePLbJngji8ANfnDK6PCG829AMSd+XOMqYVuJ5pXyY=", + "sparse_index_utils.py": "kzYMdtFPRBfaL6Vfw9xLLre7ph8svtEQrB/txC+52Fc=", + "src/__init__.py": "ZdCpznblq9UgGSNgI0hJoDXpk+evcvGjv3GGthxD/nM=", + "src/common/__init__.py": "ZdCpznblq9UgGSNgI0hJoDXpk+evcvGjv3GGthxD/nM=", + "src/common/aot_cache.py": "ya1OHE6Lqx/pb9UhH++Bu8a98Huhmdl084C6cgWdH1s=", + "src/common/barrier.py": "Godvhwwaf9iyDA/A78VoQMMRRn6ZSnq2YPosr7K2SVE=", + "src/common/blackwell_helpers.py": "BYJYCeNQ9cYVhWZlfjv0IgNaNqlnoD21nX3gAA5pRB4=", + "src/common/block_info.py": "U7qL3AZ5ROkNZdL6RTPlLlnLJ6tZ4b2VFVufZLyuuq8=", + "src/common/copy_utils.py": "bEtyb8O7Z7jIKNjN5ESlnh4WVvdf8vr5ZfQxA6vS6zA=", + "src/common/cute_dsl_utils.py": "nd8vII+r49Kk185ja3+VM6dwJlvMqCkjMBRh0WEHakw=", + "src/common/fast_math.py": "nqt6MtzAt7uplC4+kczgBfin4gHNs+QSoufR1TuMZ88=", + "src/common/mask.py": "l9v4End+9k3ZHRO6DCnuOD9K9iOCiN81osRATKvK41k=", + "src/common/mma_sm100_desc.py": "C1PqBdp6CNPA9xadQ2xBnf4wvQlE93SS/7CU+LZBQkA=", + "src/common/named_barrier.py": "5ktJiO+hP80fjTR797CslUGfm2PyhpcW6WJZrNyI5bQ=", + "src/common/pack_gqa.py": "UrAAIge5XLmilqXWGtCZJobgpuA6B0N1Vw3tDhyUi7s=", + "src/common/paged_kv.py": "j0/6stT1A5uEVALEX/GaQhYWIie+6LpGseAW8aQiHbk=", + "src/common/pipeline.py": "MIFfoDDD8Fs//SQSR+JzI/0MJ1qPGml297RtbC2qPRU=", + "src/common/seqlen_info.py": "EX2W8MTGcnAZ+J60tGG9D7IzvdLeIVQshztntGDkPMQ=", + "src/common/softmax.py": "ePjb2TUcr4fHLmw0zx9Lt+vvR6hSm2mQwiENf2J/AoQ=", + "src/common/tile_scheduler.py": "f8UknoE0j9BfPomRI/QCsDJoRk+1IpJrLfBOAh2mlls=", + "src/common/tma_utils.py": "gpAmBh58VOfHRghZTCbQ5SQpbAYy0lFnpvIcFSLBNb8=", + "src/common/utils.py": "eGGo5Ul+0XpKtiw6JLofVdFDj6s2xe4LWqDmlqp9AKk=", + "src/sm100/__init__.py": "JQpQtL58fso8B2Xwvn0XVevVqIjnk15wVQE0UUGGLCs=", + "src/sm100/build_k2q_csr/__init__.py": "75ICu6BIZir0OeyEgZ1TEYNY7pn+lA4P6McCSSC20rI=", + "src/sm100/decode_schedule.py": "/VRAmvrMX+oYLzWK1sqve86tprXkqX0/f4o5WMVeU4I=", + "src/sm100/fp4_indexer.py": "1lc9/rgU09wwF08WBRaXIE0CE2b19pBRwXekDduFs0o=", + "src/sm100/fwd/__init__.py": "A0uq2t4n5Y34mEgxb9Nzxk9sKsYr2FZ4sF+RoEilOmo=", + "src/sm100/fwd/atten_fwd.py": "4LJaUh2pn3QiwcMr+8QOVUJjNIAQqYal1xFJ/1takQY=", + "src/sm100/fwd/atten_fwd_nvfp4_kv.py": "EqU+ehJasAa9NvpDWipMPxaptOw+vcojprVas+b+x18=", + "src/sm100/fwd/combine.py": "7rQW4rUpzy0M19u+/iLfHHGMbAIQhi4HEnYeLu/qmi4=", + "src/sm100/fwd_decode/__init__.py": "XQJdwvLQm29RwVqVZvCstEnTx+dhUrwmH6RcW675pR8=", + "src/sm100/fwd_decode/atten_fwd.py": "3S4iE9h6fXUBjas51fRbakqnOzN79f0QUJ/EBRm+Ckg=", + "src/sm100/fwd_decode/build_decode_schedule/__init__.py": "qUElKK/HC03N9ntOA0sc8LB08jF5MWd7wq3MUnu4wgM=", + "src/sm100/fwd_decode/combine.py": "wIvKZzHissMLe83PUbybUoM39HTMIAexHw5I1yfJH94=", + "src/sm100/fwd_decode/tile_scheduler.py": "OWdID5fCFmwXqz6RtseFphfJtezOOQ091K+bJFcD6bc=", + "src/sm100/prepare_k2q_csr.py": "nCeG6m24dLNwJeQDFppjqR3wVCDxMY0we+20zEEeMy8=", + "src/sm100/prepare_scheduler.py": "CQuJI6Fn0uR0oMcfzmlIH+bjg+2uKTzqCXbw5H0YgSw=" + } + } +} \ No newline at end of file diff --git a/build/torch212-cxx11-cu130-x86_64-linux/metadata.json.sigstore b/build/torch212-cxx11-cu130-x86_64-linux/metadata.json.sigstore new file mode 100644 index 0000000000000000000000000000000000000000..d87c6f4c58df89ffac7077555312fa76762bd932 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/metadata.json.sigstore @@ -0,0 +1 @@ +{"mediaType":"application/vnd.dev.sigstore.bundle.v0.3+json","verificationMaterial":{"certificate":{"rawBytes":"MIIHSjCCBtGgAwIBAgIUUslxB0INf5j4eDKf+aXVHPL6zTowCgYIKoZIzj0EAwMwNzEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MR4wHAYDVQQDExVzaWdzdG9yZS1pbnRlcm1lZGlhdGUwHhcNMjYwNjMwMTc0NDA5WhcNMjYwNjMwMTc1NDA5WjAAMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE09xP8GFQP9M2os9Kw63oiq0rPsh6MEQywm7Eb+EFz2Hgy0b4vhxwT9HlCm5o8liEOQSJKuoWrV5x3r01JBcPXKOCBfAwggXsMA4GA1UdDwEB/wQEAwIHgDATBgNVHSUEDDAKBggrBgEFBQcDAzAdBgNVHQ4EFgQUH2hwnJs/0xC7TGHTQW9Jb775vI0wHwYDVR0jBBgwFoAU39Ppz1YkEZb5qNjpKFWixi4YZD8wawYDVR0RAQH/BGEwX4ZdaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5Ly5naXRodWIvd29ya2Zsb3dzL2J1aWxkLnlhbWxAcmVmcy9oZWFkcy9tYWluMDkGCisGAQQBg78wAQEEK2h0dHBzOi8vdG9rZW4uYWN0aW9ucy5naXRodWJ1c2VyY29udGVudC5jb20wHwYKKwYBBAGDvzABAgQRd29ya2Zsb3dfZGlzcGF0Y2gwNgYKKwYBBAGDvzABAwQoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTATBgorBgEEAYO/MAEEBAVCdWlsZDArBgorBgEEAYO/MAEFBB1odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eTAdBgorBgEEAYO/MAEGBA9yZWZzL2hlYWRzL21haW4wOwYKKwYBBAGDvzABCAQtDCtodHRwczovL3Rva2VuLmFjdGlvbnMuZ2l0aHVidXNlcmNvbnRlbnQuY29tMG0GCisGAQQBg78wAQkEXwxdaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5Ly5naXRodWIvd29ya2Zsb3dzL2J1aWxkLnlhbWxAcmVmcy9oZWFkcy9tYWluMDgGCisGAQQBg78wAQoEKgwoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTAbBgorBgEEAYO/MAELBA0MC3NlbGYtaG9zdGVkMEAGCisGAQQBg78wAQwEMgwwaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5MDgGCisGAQQBg78wAQ0EKgwoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTAfBgorBgEEAYO/MAEOBBEMD3JlZnMvaGVhZHMvbWFpbjAaBgorBgEEAYO/MAEPBAwMCjEwNzE0NzU1MjkwLgYKKwYBBAGDvzABEAQgDB5odHRwczovL2dpdGh1Yi5jb20vaHVnZ2luZ2ZhY2UwGAYKKwYBBAGDvzABEQQKDAgyNTcyMDc0MzBtBgorBgEEAYO/MAESBF8MXWh0dHBzOi8vZ2l0aHViLmNvbS9odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eS8uZ2l0aHViL3dvcmtmbG93cy9idWlsZC55YW1sQHJlZnMvaGVhZHMvbWFpbjA4BgorBgEEAYO/MAETBCoMKDA5ZDc4NTE1YzU1MzJlNzAwMjcwZTllMTM1NTZhMmFkMDJlOWY1ZjkwIQYKKwYBBAGDvzABFAQTDBF3b3JrZmxvd19kaXNwYXRjaDBkBgorBgEEAYO/MAEVBFYMVGh0dHBzOi8vZ2l0aHViLmNvbS9odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eS9hY3Rpb25zL3J1bnMvMjg0NjM5NjE5NTUvYXR0ZW1wdHMvMTAWBgorBgEEAYO/MAEWBAgMBnB1YmxpYzBGBgorBgEEAYO/MAEYBDgMNnJlcG86aHVnZ2luZ2ZhY2Uva2VybmVscy1jb21tdW5pdHk6cmVmOnJlZnMvaGVhZHMvbWFpbjCBigYKKwYBBAHWeQIEAgR8BHoAeAB2AN09MGrGxxEyYxkeHJlnNwKiSl643jyt/4eKcoAvKe6OAAABnxmhnZkAAAQDAEcwRQIgLfhabZQp9qi8OJLkd7wx8KxMXzgdUgrNSpnNx/Yt1LwCIQDrIW2sLQAysG+mdGRKGlEINhgEIxWUAQia6zV255ndzDAKBggqhkjOPQQDAwNnADBkAjBjnDa4dVwVkTXgWpxLv0fpr16f6bmsWHrhAQ7ZCnnW02LckbFmZA7nMrnKm5TdKVkCMFqhHgEEx9qNSJI2AcBSpqVluaP1SO8n+R/CWk8EFR4u7CBhxHfaGnnQWhwBcdiL1A=="},"tlogEntries":[{"logIndex":"2024793509","logId":{"keyId":"wNI9atQGlz+VWfO6LRygH4QUfY/8W4RFwiT5i5WRgB0="},"kindVersion":{"kind":"hashedrekord","version":"0.0.1"},"integratedTime":"1782841450","inclusionPromise":{"signedEntryTimestamp":"MEQCIGSyFyAGtjn93VDgYy5pMYkIcqcz7mI/6YbdAlgQxtKBAiByn0GtDvtWdkMToAYsGW72Sb1qHId0VifY3TbJYTtV/g=="},"inclusionProof":{"logIndex":"1902889247","rootHash":"H/sc0bl0Lo+O7HsynZszd0bSvIZEkuK51WpfQLsHYBA=","treeSize":"1902889259","hashes":["2LtcXT2qXUzXh843QRHKNIFuuXpuxza7x8AeDyQby1s=","AO4Owb1wZ6utg+d9sHuKDZ6BD8u+19RqT1cusUE3oOs=","sWG6DfkOl9ltXsT+61LcV8RFOURovNScZoVZcU8PH54=","FQOppdiB+agmFn4PY7yD7XeTzfH6iwZPScKN9Qf9XNA=","CYk5RESSTtmT3YlnQ0zuOAu8BukoBUAys6niD2REoE0=","CSdRwr4mfaSZAUK6l9RMGZUtvFsp0X86rNqu5WO9V4I=","+SguqhsiLtVi/4DqFmT73WxhBMefFfl4zk7U8tHWRnI=","5DB/VRMbICRg24kfvBoq+aFOMwCKvhr1zQj5SpDh5Ck=","NRxwUF55kxkZUtVui8nzfzj4LLT960XpxpXnY6C7pqs=","KTak07KIu/wsxelNu7DaqjZg2G0WnevWjQkjflcCfjI=","o03232Stm2HWKs2uG6lq2NP4O1Zym1pjI+LbQCbPISY=","nGtXNKgDUZj+ZjPgQKuKFp9orlBq81iSk8yjysQUTIU=","+/rlNRIrSvbSLthLGxHY8saYzo8HTl12uoWcFuXbbE0=","tC4XX6tUr8g/3yF+0T8f2DfrTWQmbDBfMxTOmNuWyzI=","E8u2TYaBleTNUd9vupjpxhOMu+bExC1kpTjfOk2GAUA=","cJbCQtmuzzN6T9df9SuhiY4cyCN7ezf1n+yFrgRkcgE=","+/VZ56MsIPxMiyLAodzKXo5TEWdQp36z89qLhpzloAo=","daxmZaajRpZV+JxHiOYZhJBiSKN5ucqjh2WnGbHhirw=","DOCeoSMovIvLExkhIvisow9AuNXgeWs4ECkyR6EcqYU="],"checkpoint":{"envelope":"rekor.sigstore.dev - 1193050959916656506\n1902889259\nH/sc0bl0Lo+O7HsynZszd0bSvIZEkuK51WpfQLsHYBA=\n\n— rekor.sigstore.dev wNI9ajBFAiEAtq00y0/QNTktfeELsydCq+tGklbso6WvgwCa/9KzwjgCID33JjnryUA9e2A4HqFz+mEKa8qlbWHr5+EGs6Ymn0yH\n"}},"canonicalizedBody":"eyJhcGlWZXJzaW9uIjoiMC4wLjEiLCJraW5kIjoiaGFzaGVkcmVrb3JkIiwic3BlYyI6eyJkYXRhIjp7Imhhc2giOnsiYWxnb3JpdGhtIjoic2hhMjU2IiwidmFsdWUiOiI5ZTk5NjQ0MDZlOTliMzA1MTBjZGUwYjI1ODZhZTgzMzVjNGJlMDE3MTQ0ODA0MmM1YzU1NjEwMzQxNTJjNTcyIn19LCJzaWduYXR1cmUiOnsiY29udGVudCI6Ik1FWUNJUUQ4OEwwd0pST1M0WWEybE81ZFBmUnErTkQ4Smh5VGZTR1l1ZGp4NEUrbzRRSWhBTEgwMEh5TS9GMzBqL05BRFN1YU1NQTh3TnlxeWFmcGV5ODQrV0Z6LzBsQiIsInB1YmxpY0tleSI6eyJjb250ZW50IjoiTFMwdExTMUNSVWRKVGlCRFJWSlVTVVpKUTBGVVJTMHRMUzB0Q2sxSlNVaFRha05EUW5SSFowRjNTVUpCWjBsVlZYTnNlRUl3U1U1bU5XbzBaVVJMWml0aFdGWklVRXcyZWxSdmQwTm5XVWxMYjFwSmVtb3dSVUYzVFhjS1RucEZWazFDVFVkQk1WVkZRMmhOVFdNeWJHNWpNMUoyWTIxVmRWcEhWakpOVWpSM1NFRlpSRlpSVVVSRmVGWjZZVmRrZW1SSE9YbGFVekZ3WW01U2JBcGpiVEZzV2tkc2FHUkhWWGRJYUdOT1RXcFpkMDVxVFhkTlZHTXdUa1JCTlZkb1kwNU5hbGwzVG1wTmQwMVVZekZPUkVFMVYycEJRVTFHYTNkRmQxbElDa3R2V2tsNmFqQkRRVkZaU1V0dldrbDZhakJFUVZGalJGRm5RVVV3T1hoUU9FZEdVVkE1VFRKdmN6bExkell6YjJseE1ISlFjMmcyVFVWUmVYZHROMFVLWWl0RlJub3lTR2Q1TUdJMGRtaDRkMVE1U0d4RGJUVnZPR3hwUlU5UlUwcExkVzlYY2xZMWVETnlNREZLUW1OUVdFdFBRMEptUVhkbloxaHpUVUUwUndwQk1WVmtSSGRGUWk5M1VVVkJkMGxJWjBSQlZFSm5UbFpJVTFWRlJFUkJTMEpuWjNKQ1owVkdRbEZqUkVGNlFXUkNaMDVXU0ZFMFJVWm5VVlZJTW1oM0NtNUtjeTh3ZUVNM1ZFZElWRkZYT1VwaU56YzFka2t3ZDBoM1dVUldVakJxUWtKbmQwWnZRVlV6T1ZCd2VqRlphMFZhWWpWeFRtcHdTMFpYYVhocE5Ga0tXa1E0ZDJGM1dVUldVakJTUVZGSUwwSkhSWGRZTkZwa1lVaFNNR05JVFRaTWVUbHVZVmhTYjJSWFNYVlpNamwwVERKb01Wb3laSEJpYldSdFdWZE9iQXBNTW5Sc1kyMDFiR0pJVFhSWk1qbDBZbGhXZFdGWVVqVk1lVFZ1WVZoU2IyUlhTWFprTWpsNVlUSmFjMkl6WkhwTU1rb3hZVmQ0YTB4dWJHaGlWM2hCQ21OdFZtMWplVGx2V2xkR2EyTjVPWFJaVjJ4MVRVUnJSME5wYzBkQlVWRkNaemM0ZDBGUlJVVkxNbWd3WkVoQ2VrOXBPSFprUnpseVdsYzBkVmxYVGpBS1lWYzVkV041Tlc1aFdGSnZaRmRLTVdNeVZubFpNamwxWkVkV2RXUkROV3BpTWpCM1NIZFpTMHQzV1VKQ1FVZEVkbnBCUWtGblVWSmtNamw1WVRKYWN3cGlNMlJtV2tkc2VtTkhSakJaTW1kM1RtZFpTMHQzV1VKQ1FVZEVkbnBCUWtGM1VXOU5SR3hyVG5wbk1VMVVWbXBPVkZWNlRXMVZNMDFFUVhsT2VrSnNDazlYVlhoTmVsVXhUbTFGZVZsWFVYZE5iVlUxV21wV2JVOVVRVlJDWjI5eVFtZEZSVUZaVHk5TlFVVkZRa0ZXUTJSWGJITmFSRUZ5UW1kdmNrSm5SVVVLUVZsUEwwMUJSVVpDUWpGdlpGZGtibUZYTlc1YWJVWnFXbE01Y2xwWVNuVmFWM2g2VEZkT2RtSlhNVEZpYld3d1pWUkJaRUpuYjNKQ1owVkZRVmxQTHdwTlFVVkhRa0U1ZVZwWFducE1NbWhzV1ZkU2Vrd3lNV2hoVnpSM1QzZFpTMHQzV1VKQ1FVZEVkbnBCUWtOQlVYUkVRM1J2WkVoU2QyTjZiM1pNTTFKMkNtRXlWblZNYlVacVpFZHNkbUp1VFhWYU1td3dZVWhXYVdSWVRteGpiVTUyWW01U2JHSnVVWFZaTWpsMFRVY3dSME5wYzBkQlVWRkNaemM0ZDBGUmEwVUtXSGQ0WkdGSVVqQmpTRTAyVEhrNWJtRllVbTlrVjBsMVdUSTVkRXd5YURGYU1tUndZbTFrYlZsWFRteE1NblJzWTIwMWJHSklUWFJaTWpsMFlsaFdkUXBoV0ZJMVRIazFibUZZVW05a1YwbDJaREk1ZVdFeVduTmlNMlI2VERKS01XRlhlR3RNYm14b1lsZDRRV050Vm0xamVUbHZXbGRHYTJONU9YUlpWMngxQ2sxRVowZERhWE5IUVZGUlFtYzNPSGRCVVc5RlMyZDNiMDFFYkd0T2VtY3hUVlJXYWs1VVZYcE5iVlV6VFVSQmVVNTZRbXhQVjFWNFRYcFZNVTV0UlhrS1dWZFJkMDF0VlRWYWFsWnRUMVJCWWtKbmIzSkNaMFZGUVZsUEwwMUJSVXhDUVRCTlF6Tk9iR0pIV1hSaFJ6bDZaRWRXYTAxRlFVZERhWE5IUVZGUlFncG5OemgzUVZGM1JVMW5kM2RoU0ZJd1kwaE5Oa3g1T1c1aFdGSnZaRmRKZFZreU9YUk1NbWd4V2pKa2NHSnRaRzFaVjA1c1RESjBiR050Tld4aVNFMTBDbGt5T1hSaVdGWjFZVmhTTlUxRVowZERhWE5IUVZGUlFtYzNPSGRCVVRCRlMyZDNiMDFFYkd0T2VtY3hUVlJXYWs1VVZYcE5iVlV6VFVSQmVVNTZRbXdLVDFkVmVFMTZWVEZPYlVWNVdWZFJkMDF0VlRWYWFsWnRUMVJCWmtKbmIzSkNaMFZGUVZsUEwwMUJSVTlDUWtWTlJETktiRnB1VFhaaFIxWm9Xa2hOZGdwaVYwWndZbXBCWVVKbmIzSkNaMFZGUVZsUEwwMUJSVkJDUVhkTlEycEZkMDU2UlRCT2VsVXhUV3ByZDB4bldVdExkMWxDUWtGSFJIWjZRVUpGUVZGbkNrUkNOVzlrU0ZKM1kzcHZka3d5WkhCa1IyZ3hXV2sxYW1JeU1IWmhTRlp1V2pKc2RWb3lXbWhaTWxWM1IwRlpTMHQzV1VKQ1FVZEVkbnBCUWtWUlVVc0tSRUZuZVU1VVkzbE5SR013VFhwQ2RFSm5iM0pDWjBWRlFWbFBMMDFCUlZOQ1JqaE5XRmRvTUdSSVFucFBhVGgyV2pKc01HRklWbWxNYlU1MllsTTVid3BrVjJSdVlWYzFibHB0Um1wYVV6bHlXbGhLZFZwWGVIcE1WMDUyWWxjeE1XSnRiREJsVXpoMVdqSnNNR0ZJVm1sTU0yUjJZMjEwYldKSE9UTmplVGxwQ21SWGJITmFRelUxV1ZjeGMxRklTbXhhYmsxMllVZFdhRnBJVFhaaVYwWndZbXBCTkVKbmIzSkNaMFZGUVZsUEwwMUJSVlJDUTI5TlMwUkJOVnBFWXpRS1RsUkZNVmw2VlRGTmVrcHNUbnBCZDAxcVkzZGFWR3hzVFZSTk1VNVVXbWhOYlVaclRVUktiRTlYV1RGYWFtdDNTVkZaUzB0M1dVSkNRVWRFZG5wQlFncEdRVkZVUkVKR00ySXpTbkphYlhoMlpERTVhMkZZVG5kWldGSnFZVVJDYTBKbmIzSkNaMFZGUVZsUEwwMUJSVlpDUmxsTlZrZG9NR1JJUW5wUGFUaDJDbG95YkRCaFNGWnBURzFPZG1KVE9XOWtWMlJ1WVZjMWJscHRSbXBhVXpseVdsaEtkVnBYZUhwTVYwNTJZbGN4TVdKdGJEQmxVemxvV1ROU2NHSXlOWG9LVEROS01XSnVUWFpOYW1jd1RtcE5OVTVxUlRWT1ZGVjJXVmhTTUZwWE1YZGtTRTEyVFZSQlYwSm5iM0pDWjBWRlFWbFBMMDFCUlZkQ1FXZE5RbTVDTVFwWmJYaHdXWHBDUjBKbmIzSkNaMFZGUVZsUEwwMUJSVmxDUkdkTlRtNUtiR05IT0RaaFNGWnVXakpzZFZveVdtaFpNbFYyWVRKV2VXSnRWbk5qZVRGcUNtSXlNWFJrVnpWd1pFaHJObU50Vm0xUGJrcHNXbTVOZG1GSFZtaGFTRTEyWWxkR2NHSnFRMEpwWjFsTFMzZFpRa0pCU0ZkbFVVbEZRV2RTT0VKSWIwRUtaVUZDTWtGT01EbE5SM0pIZUhoRmVWbDRhMlZJU214dVRuZExhVk5zTmpRemFubDBMelJsUzJOdlFYWkxaVFpQUVVGQlFtNTRiV2h1V210QlFVRlJSQXBCUldOM1VsRkpaMHhtYUdGaVdsRndPWEZwT0U5S1RHdGtOM2Q0T0V0NFRWaDZaMlJWWjNKT1UzQnVUbmd2V1hReFRIZERTVkZFY2tsWE1uTk1VVUY1Q25OSEsyMWtSMUpMUjJ4RlNVNW9aMFZKZUZkVlFWRnBZVFo2VmpJMU5XNWtla1JCUzBKblozRm9hMnBQVUZGUlJFRjNUbTVCUkVKclFXcENhbTVFWVRRS1pGWjNWbXRVV0dkWGNIaE1kakJtY0hJeE5tWTJZbTF6VjBoeWFFRlJOMXBEYm01WE1ESk1ZMnRpUm0xYVFUZHVUWEp1UzIwMVZHUkxWbXREVFVaeGFBcElaMFZGZURseFRsTktTVEpCWTBKVGNIRldiSFZoVURGVFR6aHVLMUl2UTFkck9FVkdValIxTjBOQ2FIaElabUZIYm01UlYyaDNRbU5rYVV3eFFUMDlDaTB0TFMwdFJVNUVJRU5GVWxSSlJrbERRVlJGTFMwdExTMEsifX19fQ=="}],"timestampVerificationData":{"rfc3161Timestamps":[{"signedTimestamp":"MIICyTADAgEAMIICwAYJKoZIhvcNAQcCoIICsTCCAq0CAQMxDTALBglghkgBZQMEAgEwgbgGCyqGSIb3DQEJEAEEoIGoBIGlMIGiAgEBBgkrBgEEAYO/MAIwMTANBglghkgBZQMEAgEFAAQgObnDyCZClpo/Z9qKAEx9Mh3jGppJFenw8by0WcfLjHMCFQCKfFXwLtbDZ7amCx06KhOlj3PLhhgPMjAyNjA2MzAxNzQ0MTBaMAMCAQGgMqQwMC4xFTATBgNVBAoTDHNpZ3N0b3JlLmRldjEVMBMGA1UEAxMMc2lnc3RvcmUtdHNhoAAxggHaMIIB1gIBATBRMDkxFTATBgNVBAoTDHNpZ3N0b3JlLmRldjEgMB4GA1UEAxMXc2lnc3RvcmUtdHNhLXNlbGZzaWduZWQCFDoTVC8MkGHuvMFDL8uKjosqI4sMMAsGCWCGSAFlAwQCAaCB/DAaBgkqhkiG9w0BCQMxDQYLKoZIhvcNAQkQAQQwHAYJKoZIhvcNAQkFMQ8XDTI2MDYzMDE3NDQxMFowLwYJKoZIhvcNAQkEMSIEIC4iGkbklp26RUgVTfzeAiDSe/Ts3+juCrT71lYnXG9WMIGOBgsqhkiG9w0BCRACLzF/MH0wezB5BCCF+Se8B6tiysO0Q1bBDvyBssaIP9p6uebYcNnROs0FtzBVMD2kOzA5MRUwEwYDVQQKEwxzaWdzdG9yZS5kZXYxIDAeBgNVBAMTF3NpZ3N0b3JlLXRzYS1zZWxmc2lnbmVkAhQ6E1QvDJBh7rzBQy/Lio6LKiOLDDAKBggqhkjOPQQDAgRmMGQCMDgqf/x7UkWmMWUp+kRCLPm0cVMCvLUPQSPsKedAjgUAqKtnSuHIZLb7S67fW9FtkgIwA8PWxegdU1Z9DZiFrz7m3QOE0WaM9JQLM5SMn4Ug7SDyibT0ytxDsq7pLlvX7qrv"}]}},"messageSignature":{"messageDigest":{"algorithm":"SHA2_256","digest":"nplkQG6ZswUQzeCyWGroM1xL4BcUSAQsXFVhA0FSxXI="},"signature":"MEYCIQD88L0wJROS4Ya2lO5dPfRq+ND8JhyTfSGYudjx4E+o4QIhALH00HyM/F30j/NADSuaMMA8wNyqyafpey84+WFz/0lB"}} \ No newline at end of file diff --git a/build/torch212-cxx11-cu130-x86_64-linux/msa/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/msa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/msa/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/quack/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/quack/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu130-x86_64-linux/quack/activation.py b/build/torch212-cxx11-cu130-x86_64-linux/quack/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..e8cbeb29242b92b7cc336cd336604e58c36f4459 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/quack/activation.py @@ -0,0 +1,532 @@ +# Copyright (c) 2025, Tri Dao. + +import math +from typing import Tuple +from functools import partial + +import cutlass.cute as cute +from cutlass import Float32, Boolean, const_expr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm, nvvm + + +F32_or_F32x2 = Float32 | Tuple[Float32, Float32] + + +sub_packed_f32x2 = partial( + cute.arch.calc_packed_f32x2_op, + src_c=None, + calc_func=nvvm.sub_packed_f32x2, +) + + +@dsl_user_op +def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "tanh.approx.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True) + return 0.5 + 0.5 * tanh(0.5 * x) + else: + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) + tanh_x_half = (tanh(x_half[0]), tanh(x_half[1])) + return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5)) + + +@dsl_user_op +def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32: + # return dout * out * (1.0 - out) + return dout * (out - out * out) + + +@dsl_user_op +def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) + else: + return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)) + + +@dsl_user_op +@cute.jit +def drelu( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + if const_expr(not isinstance(x, tuple)): + x_pos = Boolean(x > 0) + return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0)) + else: + x0_pos = Boolean(x[0] > 0) + x1_pos = Boolean(x[1] > 0) + dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0)) + return dx, relu(x) + + +@dsl_user_op +def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) * x + else: + relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))) + return cute.arch.mul_packed_f32x2(relu_x, x) + + +@dsl_user_op +@cute.jit +def drelu_sq( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + """ + ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward + Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out + Returns: (dx, relu_sq_out) where: + - dx = dout * 2 * x if x > 0, else 0 + - relu_sq_out = max(x, 0) * x + """ + if const_expr(not isinstance(x, tuple)): + relu_x = relu(x) + relu_sq_out = relu_x * x + # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0 + dx = 2.0 * (dout * relu_x) + return dx, relu_sq_out + else: + relu_x = relu(x) + relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x) + dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x)) + return dx, relu_sq_out + + +@dsl_user_op +def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """ + gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x))) + """ + sqrt_2_over_pi = math.sqrt(2 / math.pi) # ~0.797885 + sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # ~0.0356774 + if const_expr(not isinstance(x, tuple)): + return 0.5 * ( + x + # Currently cute.math.tanh(x, fastmath=True) generates very slow code + # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True)) + * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)))) + ) + else: + x_sq = cute.arch.mul_packed_f32x2(x, x) + x_sq_scaled = cute.arch.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + z = cute.arch.mul_packed_f32x2(x, x_sq_scaled) + tanh_z = (tanh(z[0]), tanh(z[1])) + x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x) + return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z) + + +@dsl_user_op +def dgelu_tanh_approx( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + """ + GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward + Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out + Returns: (dx, gelu_out) + + Derivative uses the chain rule: + d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2 + and sech^2(z) = 1 - tanh^2(z) + """ + sqrt_2_over_pi = math.sqrt(2 / math.pi) # c1 ~0.797885 + sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # c2 ~0.0356774 + sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff # c3 ~0.01070322 + + if const_expr(not isinstance(x, tuple)): + # Compute z = x * (c1 + c2 * x^2) + x_sq = x * x + # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True) + tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq)) + half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z + gelu_out = x * half_tanh_z_plus_one + + # Compute gradient + # sech^2(z) = 1 - tanh^2(z) + sech2_z = 1 - tanh_z * tanh_z + # dz/dx = c1 + 3 * c2 * x^2 + dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq + # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx)) + + dx = dout * dgelu + return dx, gelu_out + else: + # Compute z = x * (c1 + c2 * x^2) + x_sq = cute.arch.mul_packed_f32x2(x, x) + x_sq_scaled = cute.arch.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + z = cute.arch.mul_packed_f32x2(x, x_sq_scaled) + tanh_z = (tanh(z[0]), tanh(z[1])) + half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5)) + gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one) + + # Compute gradient + # sech^2(z) = 1 - tanh^2(z) + sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0)) + # dz/dx = c1 + 3 * c2 * x^2 + dz_dx = cute.arch.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx) + x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx) + dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one) + + dx = cute.arch.mul_packed_f32x2(dout, dgelu) + return dx, gelu_out + + +@dsl_user_op +@cute.jit +def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + use_linear = Boolean(x > 20.0) + return ( + cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True) + if not use_linear + else x + ) + else: + log2_e = math.log2(math.e) + x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e)) + x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True)) + x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0)) + log_x_exp_p1 = ( + cute.math.log2(x_exp_p1[0], fastmath=True), + cute.math.log2(x_exp_p1[1], fastmath=True), + ) + ln2 = math.log(2.0) + softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2)) + use_linear_0 = Boolean(x[0] > 20.0) + use_linear_1 = Boolean(x[1] > 20.0) + return ( + softplus_x[0] if not use_linear_0 else x[0], + softplus_x[1] if not use_linear_1 else x[1], + ) + + +@dsl_user_op +@cute.jit +def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32: + use_linear = Boolean(out > 20.0) + # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout + dx = dout - dout * cute.math.exp(-out, fastmath=True) + return dx if not use_linear else dout + + +@dsl_user_op +def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2: + """ + silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x) + This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA. + """ + if const_expr(not isinstance(x, tuple)): + x_half = 0.5 * x if const_expr(not already_halved) else x + # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half + return x_half * tanh(x_half) + x_half + else: + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x + tanh_x_half = (tanh(x_half[0]), tanh(x_half[1])) + return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half) + + +@dsl_user_op +def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return silu(x) * y + else: + return cute.arch.mul_packed_f32x2(silu(x), y) + + +@dsl_user_op +def dswiglu( + x: F32_or_F32x2, + y: F32_or_F32x2, + dout: F32_or_F32x2, + *, + already_halved: bool = False, + loc=None, + ip=None, +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out + Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x) + + d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))) + + This has been optimized to use fewer instructions (i.e. we expand things out + to use FFMA instead of FADD and FMUL). + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x)) + # FMUL, MUFU.TANH, then FFMA + if const_expr(not already_halved): + sigmoid_x = sigmoid(x) + silu_x = x * sigmoid_x # FMUL + else: + tanh_x = tanh(x) # MUFU.TANH + sigmoid_x = 0.5 * tanh_x + 0.5 # FFMA + silu_x = x * tanh_x + x # FFMA + silu_x_dout = silu_x * dout # FMUL + # d_silu(x) * dout + # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout + # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout + d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout # FFMA, FFMA + dx = d_silu_x_dout * y # FMUL + dy = silu_x_dout + swiglu_out = silu_x * y # FMUL + # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA + return dx, dy, swiglu_out + else: + # Compute sigmoid(x) and silu(x) + if const_expr(not already_halved): + sigmoid_x = sigmoid(x) + silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x) + else: + tanh_x = (tanh(x[0]), tanh(x[1])) + sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5)) + silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x) + silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout) + # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout + sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2( + sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x + ) + d_silu_x_dout = cute.arch.fma_packed_f32x2( + sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout + ) + dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y) + dy = silu_x_dout + swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y) + return dx, dy, swiglu_out + + +@dsl_user_op +def swiglu_oai( + x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None +) -> F32_or_F32x2: + """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y. + https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249 + x * sigmoid(alpha * x) * (y + 1) + Compile down to FMUL, FMUL, TANH, FFMA, FFMA + """ + # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2)) + if const_expr(not isinstance(x, tuple)): + x_half = 0.5 * x + # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half + silu_x = x_half * tanh(alpha * x_half) + x_half + return silu_x * y + silu_x + else: + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) + alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half) + tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1])) + silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half) + return cute.arch.fma_packed_f32x2(silu_x, y, silu_x) + + +@dsl_user_op +def dswiglu_oai( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + Swiglu OAI backward pass: computes gradients w.r.t. x and y + Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out + Returns: (dx, dy, swiglu_oai_out) + + Derivative of x * sigmoid(alpha * x) w.r.t. x: + d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x)) + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2)) + alpha_x_half = (0.5 * alpha) * x # FMUL + # MUFU.TANH, then FFMA + # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True) + sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half) + silu_x = x * sigmoid_alpha_x # FMUL + silu_x_dout = silu_x * dout # FMUL + # FFMA, FFMA, FMUL + d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout + dx = d_silu_x_dout * y + d_silu_x_dout # FFMA, instead of multiply by y + 1 + dy = silu_x_dout + swiglu_out = silu_x * y + silu_x # FFMA, instead of multiply by y + 1 + # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA + return dx, dy, swiglu_out + else: + # Compute sigmoid(alpha * x) + alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x) + tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1])) + sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5)) + silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x) + silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout) + # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout + silu_x_minus_product = cute.arch.fma_packed_f32x2( + silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x + ) + sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2( + (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x + ) + d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout) + dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout) + dy = silu_x_dout + swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x) + return dx, dy, swiglu_out + + +@dsl_user_op +def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """GLU: Gated Linear Unit + glu(x, y) = sigmoid(x) * y + Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2)) + """ + if const_expr(not isinstance(x, tuple)): + sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA + return sigmoid_x * y # FMUL + else: + sigmoid_x = sigmoid(x) + return cute.arch.mul_packed_f32x2(sigmoid_x, y) + + +@dsl_user_op +def dglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out + Returns: (dx, dy, glu_out) where: + - dx = dout * y * sigmoid(x) * (1 - sigmoid(x)) + - dy = dout * sigmoid(x) + - glu_out = sigmoid(x) * y + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2)) + sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA + sigmoid_x_dout = sigmoid_x * dout # FMUL + glu_out = sigmoid_x * y # FMUL + # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout + # = y * (1 - sigmoid(x)) * sigmoid_x_dout + # = (y - y * sigmoid(x)) * sigmoid_x_dout + # = (y - glu_out) * sigmoid_x_dout + dx = (y - glu_out) * sigmoid_x_dout # FADD, FMUL + dy = sigmoid_x_dout + # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA + return dx, dy, glu_out + else: + sigmoid_x = sigmoid(x) + sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout) + glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y) + # dx = (y - glu_out) * sigmoid_x_dout + y_minus_glu_out = sub_packed_f32x2(y, glu_out) + dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout) + dy = sigmoid_x_dout + return dx, dy, glu_out + + +@dsl_user_op +def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """ReGLU: ReLU Gated Linear Unit + reglu(x, y) = relu(x) * y = max(x, 0) * y + """ + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) * y + else: + relu_x = relu(x) + return cute.arch.mul_packed_f32x2(relu_x, y) + + +@dsl_user_op +@cute.jit +def dreglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out + Returns: (dx, dy, reglu_out) where: + - dx = dout * y if x > 0, else 0 + - dy = dout * relu(x) + - reglu_out = relu(x) * y + """ + if const_expr(not isinstance(x, tuple)): + x_pos = Boolean(x > 0) + relu_x = cute.arch.fmax(x, Float32(0.0)) + dx = (dout * y) if x_pos else Float32(0.0) + dy = dout * relu_x + reglu_out = relu_x * y + return dx, dy, reglu_out + else: + x0_pos = Boolean(x[0] > 0) + x1_pos = Boolean(x[1] > 0) + relu_x = relu(x) + dout_y = cute.arch.mul_packed_f32x2(dout, y) + dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0))) + dy = cute.arch.mul_packed_f32x2(dout, relu_x) + reglu_out = cute.arch.mul_packed_f32x2(relu_x, y) + return dx, dy, reglu_out + + +@dsl_user_op +def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """GeGLU: GELU Gated Linear Unit + geglu(x, y) = gelu(x) * y + Uses the tanh approximation of GELU + """ + if const_expr(not isinstance(x, tuple)): + return gelu_tanh_approx(x) * y + else: + return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y) + + +@dsl_user_op +def dgeglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out + Returns: (dx, dy, geglu_out) where: + - dx = dout * y * d_gelu(x) + - dy = dout * gelu(x) + - geglu_out = gelu(x) * y + """ + if const_expr(not isinstance(x, tuple)): + # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x) + dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout) + # Compute gradients for geglu + dx = dgelu_x_dout * y + dy = gelu_x * dout + geglu_out = gelu_x * y + return dx, dy, geglu_out + else: + # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x) + dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout) + # Compute gradients for geglu + dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y) + dy = cute.arch.mul_packed_f32x2(gelu_x, dout) + geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y) + return dx, dy, geglu_out diff --git a/build/torch212-cxx11-cu130-x86_64-linux/quack/compile_utils.py b/build/torch212-cxx11-cu130-x86_64-linux/quack/compile_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4375594669c8f12d6a79d8878316271cb819568a --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/quack/compile_utils.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +from typing import Optional + +import cutlass.cute as cute + + +def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]: + if leading_dim < 0: + leading_dim = len(shape) + leading_dim + if dtype is None: + return None + stride = tuple( + cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1 + for i in range(len(shape)) + ) + return cute.runtime.make_fake_tensor( + dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8 + ) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/quack/copy_utils.py b/build/torch212-cxx11-cu130-x86_64-linux/quack/copy_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad989559766d6ee6e8ece9d322bf08980706dfa --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/quack/copy_utils.py @@ -0,0 +1,890 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +import re +from typing import Optional, Type, Tuple, Callable, Sequence +from functools import partial + +import cutlass +import cutlass.cute as cute + +from cutlass import Int32, Int16, Boolean, const_expr +from cutlass.cute.nvgpu import cpasync, warp, warpgroup +from cutlass.cute.nvgpu.tcgen05.mma import CtaGroup # noqa +from cutlass.cutlass_dsl import dsl_user_op +import cutlass.pipeline +from cutlass._mlir.dialects import llvm +from cutlass._mlir import ir +from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir + + +Sm100MmaPeerBitMask = 0xFEFFFFFF + + +@dsl_user_op +def cvt_copy( + tiled_copy: cute.TiledCopy, + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + retile: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + if const_expr(src.element_type != dst.element_type): + src_cvt = cute.make_fragment_like(src, dst.element_type) + src_cvt.store(src.load().to(dst.element_type)) + src = src_cvt + if const_expr(retile): + src = tiled_copy.retile(src) + cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +@dsl_user_op +def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip) + cute.autovec_copy(src, dst, loc=loc, ip=ip) + return dst + + +@dsl_user_op +def load_s2r_retile( + tiled_copy: cute.TiledCopy, + src: cute.Tensor, + dst_shape: cute.Tensor | cute.Shape, + *, + loc=None, + ip=None, +) -> cute.Tensor: + # Will also accept dst_shape being a tensor, in which case we write into that tensor + if const_expr(not isinstance(dst_shape, cute.Tensor)): + dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip) + else: + dst = dst_shape + cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip) + return dst + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None +) -> cute.CopyAtom: + num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + num_copy_elems = src.shape[0][0] + copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +def tiled_copy_1d( + dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + thr_layout = cute.make_layout(num_threads) + val_layout = cute.make_layout(num_copy_elems) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def tiled_copy_2d( + dtype: Type[cutlass.Numeric], + threads_per_row: int, + num_threads: int, + num_copy_elems: int = 1, + is_async: bool = False, +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + assert num_threads % threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // threads_per_row, threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, num_copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +@cute.jit +def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor: + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" + tApA = cute.make_rmem_tensor( + cute.make_layout( + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + return tApA + + +# def tiled_copy_2d( +# dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +# ) -> cute.TiledCopy: +# num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width +# copy_elems = num_copy_bits // dtype.width +# copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() +# copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) +# gmem_threads_per_row = major_mode_size // copy_elems +# assert num_threads % gmem_threads_per_row == 0 +# thr_layout = cute.make_ordered_layout( +# (num_threads // gmem_threads_per_row, gmem_threads_per_row), +# order=(1, 0), +# ) +# val_layout = cute.make_layout((1, copy_elems)) +# return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]: + """Extract swizzle parameters from a pointer's swizzle_type. + + The swizzle_type string has the form '!cute.swizzle<"S">' where + b, m, s are the swizzle parameters (bits, base, shift). + + Returns: + A cute.Swizzle object constructed from the extracted parameters + + Raises: + ValueError: If the swizzle_type string cannot be parsed + """ + # Ideally there should be a better API to get swizzle parameters, but we'll just parse + # the string here. + swizzle_str = str(ptr.type.swizzle_type) + # Extract the inner part "S" + match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str) + if match: + b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3)) + return b, m, s + else: + raise ValueError(f"Could not parse swizzle_type: {swizzle_str}") + + +def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32: + bit_msk = (1 << b) - 1 + yyy_msk = bit_msk << (m + s) + return ptr_int ^ ((ptr_int & yyy_msk) >> s) + + +def swizzle_ptr(ptr: cute.Pointer): + b, m, s = parse_swizzle_from_pointer(ptr) + ptr_int = swizzle_int(ptr.toint(), b, m, s) + return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment) + + +def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor: + outer = tensor.layout + width = tensor.element_type.width + inner = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator)) + # Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for + # for 16 bits and <3, 2, 3> for 32 bits) + new_layout = cute.recast_layout( + width, 8, cute.make_composed_layout(inner, 0, cute.recast_layout(8, width, outer)) + ) + # recast_ptr to remove the pointer swizzle + return cute.make_tensor(cute.recast_ptr(tensor.iterator, dtype=tensor.element_type), new_layout) + + +def partition_D_position_independent( + thr_copy: cute.core.ThrCopy, tensor: cute.Tensor +) -> cute.Tensor: + return cute.make_tensor( + swizzle_ptr(thr_copy.partition_D(tensor).iterator), + thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout, + ) + + +def partition_S_position_independent( + thr_copy: cute.core.ThrCopy, tensor: cute.Tensor +) -> cute.Tensor: + return cute.make_tensor( + swizzle_ptr(thr_copy.partition_S(tensor).iterator), + thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout, + ) + + +@dsl_user_op +def sm90_get_smem_load_op( + layout_c: cutlass.utils.LayoutEnum, + elem_ty_c: Type[cutlass.Numeric], + *, + loc=None, + ip=None, +) -> cute.CopyAtom: + """ + Selects the largest vectorized smem load atom available subject to constraint of gmem layout. + + Parameters: + ----------- + layout_c : LayoutEnum + The layout enum of the output tensor D. + + elem_ty_c : Type[Numeric] + The element type for output tensor D. + + Returns: + -------- + Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters. + """ + + if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta): + raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}") + is_m_major = layout_c.is_m_major_c() + if elem_ty_c.width == 16: + return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip) + else: + return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip) + + +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=(2 if not transpose else 1) * element_type.width, + ) + else: + return cute.make_copy_atom( + warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +def get_smem_load_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=(2 if not transpose else 1) * element_type.width, + ) + else: + return cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +def get_smem_store_C( + tiled_mma: cute.TiledMma, + sC: cute.Tensor, + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sC.element_type + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tRS_sC = thr_copy.partition_D(sC) + else: + tRS_sC = partition_D_position_independent(thr_copy, sC) + + def copy_fn(src: cute.Tensor, dst_idx: Optional[Int32] = None, **new_kwargs): + dst_tensor = tRS_sC if const_expr(dst_idx is None) else tRS_sC[None, None, None, dst_idx] + cvt_copy(tiled_copy, src, dst_tensor, retile=True, **new_kwargs) + + return copy_fn, thr_copy, tRS_sC + + +def get_smem_load_C( + tiled_mma: cute.TiledMma, + sC: cute.Tensor, + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sC.element_type + copy_atom = get_smem_load_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tSR_sC = thr_copy.partition_S(sC) + else: + tSR_sC = partition_S_position_independent(thr_copy, sC) + copy_atom_RS = get_smem_store_atom(arch, dtype, transpose) + thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx) + tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape + + def copy_fn(src_idx: Optional[Int32] = None, **new_kwargs): + src_tensor = tSR_sC if const_expr(src_idx is None) else tSR_sC[None, None, None, src_idx] + return load_s2r_retile(tiled_copy, src_tensor, dst_shape=tRS_shape, **new_kwargs) + + return copy_fn, thr_copy, tSR_sC + + +def epilog_smem_copy_atom( + tiled_mma: cute.TiledMma, epi_tile: cute.Shape, transpose: bool = False +) -> cute.TiledCopy: + copy_atom_C = cute.make_copy_atom( + warp.StMatrix8x8x16bOp(transpose, num_matrices=4 if epi_tile[1] % 16 == 0 else 2), + cutlass.Float16, # this is just to get the right source layout + ) + tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + return tiled_copy_C_atom + + +def get_smem_store_epi( + tiled_mma: cute.TiledMma, + epi_tile: cute.Shape, + sC: Optional[cute.Tensor], + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor, cute.Tensor]: + dtype = sC.element_type if const_expr(sC is not None) else cutlass.Float16 + tiled_copy_C_atom = epilog_smem_copy_atom(tiled_mma, epi_tile) + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_S(copy_atom, tiled_copy_C_atom) + thr_copy = tiled_copy.get_slice(tidx) + tRS_sC = None + if const_expr(sC is not None): + if const_expr(not position_independent): + tRS_sC = thr_copy.partition_D(sC) + else: + tRS_sC = partition_D_position_independent(thr_copy, sC) + sC_shape = sC.shape[:2] if sC is not None else epi_tile + # (R2S, R2S_M, R2S_N, PIPE_C) + tRS_rC_shape = thr_copy.partition_S(cute.make_identity_tensor(sC_shape)).shape + tRS_rC = cute.make_rmem_tensor(tRS_rC_shape, tiled_mma.op.acc_dtype) + + def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs): + cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], **new_kwargs) + + return copy_fn if const_expr(sC is not None) else None, thr_copy, tRS_sC, tRS_rC + + +def get_smem_store_A( + tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sA.element_type + transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tRS_sA = thr_copy.partition_D(sA) + else: + tRS_sA = partition_D_position_independent(thr_copy, sA) + + def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs): + cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs) + + return copy_fn, thr_copy, tRS_sA + + +def get_smem_load_A( + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + tidx: Int32, + arch: int, + with_dst_tensor: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sA.element_type + transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN + copy_atom = get_smem_load_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tSR_sA = thr_copy.partition_S(sA) + else: + tSR_sA = partition_S_position_independent(thr_copy, sA) + tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2]) + + def copy_fn(src_idx: Int32, **new_kwargs): + return load_s2r_retile( + tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs + ) + + def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs): + return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs) + + return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA + + +@dsl_user_op +def cpasync_reduce_bulk_add_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: int | Int32, + *, + loc=None, + ip=None, +): + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()], + # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;", + # "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def get_tma_desc_addr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer: + """ + Get the address of the TMA descriptor embedded in a TMA Copy Atom. + + Extracts the constant memory address of the TMA descriptor for use with + custom PTX instructions. + + :param tma_atom: TMA Copy Atom from make_tiled_tma_atom + :return: Pointer to TMA descriptor in constant memory + + Example: + >>> desc_ptr = get_tma_descriptor_address(tma_atom) + """ + exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip) + tma_desc_ptr_type = ir.Type.parse( + "!cute.ptr>" + ) + return _cute_nvgpu_ir.get_tma_desc_addr(tma_desc_ptr_type, exec_atom, loc=loc, ip=ip) + + +@dsl_user_op +def tma_gather4_load( + tma_desc_ptr: cute.Pointer, + dst_smem_ptr: cute.Pointer, + mbarrier_ptr: cute.Pointer, + col_idx: Int32, + row_indices: Sequence[Int32], + *, + num_cta: int = 1, + multicast_mask=None, + loc=None, + ip=None, +) -> None: + """ + Perform TMA gather4 load from global memory to shared memory. + + Issues PTX instruction: + cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes + [dstMem], [tensorMap, {col_idx, row0, row1, row2, row3}], [smem_bar]; + + This loads 4 rows (specified by row_indices) from a 2D tensor at the given + column index into shared memory, using the TMA descriptor. + + :param tma_desc_ptr: Pointer to TMA descriptor in constant memory (128-byte aligned) + :type tma_desc_ptr: Pointer + :param dst_smem_ptr: Destination address in shared memory + :type dst_smem_ptr: Pointer + :param mbarrier_ptr: Pointer to mbarrier in shared memory for completion tracking + :type mbarrier_ptr: Pointer + :param col_idx: Column index + :type col_idx: Int32 + :param row_indices: Sequence of exactly 4 row indices + :type row_indices: Sequence[Int32] + :param num_cta: Number of CTAs participating (default: 1) + :type num_cta: int + :param multicast_mask: Optional multicast mask + :type multicast_mask: Int16 + + Requirements: + - row_indices must contain exactly 4 elements + - Compute capability >= SM_100 (Blackwell) + - TMA descriptor must be properly initialized for 2D tensor + + Example: + >>> from cutlass.cute.nvgpu import cpasync + >>> from cutlass.cute import core + >>> + >>> # Create TMA descriptor + >>> tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(...) + >>> tma_desc_ptr = get_tma_descriptor_address(tma_atom) + >>> + >>> # Compute indices (typically from kernel logic) + >>> col_idx = core.get(...) or 5 # Int32 value + >>> row_indices = [core.get(...) for _ in range(4)] # 4 Int32 values + >>> + >>> # Gather 4 rows at computed column + >>> tma_gather4_load( + ... tma_desc_ptr=tma_desc_ptr, + ... dst_smem_ptr=smem_ptr, + ... mbarrier_ptr=barrier_ptr, + ... col_idx=col_idx, + ... row_indices=row_indices + ... ) + """ + if len(row_indices) != 4: + raise ValueError(f"gather4 requires exactly 4 row indices, got {len(row_indices)}") + col_val = Int32(col_idx).ir_value() + row_vals = [Int32(row_idx).ir_value() for row_idx in row_indices] + # Convert pointers to integer addresses + desc_addr = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value() + dst_addr = dst_smem_ptr.toint(loc=loc, ip=ip).ir_value() + mbar_addr = mbarrier_ptr.toint(loc=loc, ip=ip) + if num_cta > 1: + # Executed by both CTAs. Set peer bit to 0 so that the + # transaction bytes will update CTA0's barrier. + mbar_addr = mbar_addr & Sm100MmaPeerBitMask + mbar_addr = mbar_addr.ir_value() + # Handle multicast_mask - may already be ir.Value or Python int + multicast_mask_val = None + if multicast_mask is not None: + multicast_mask_val = Int16(multicast_mask).ir_value() + assert multicast_mask_val is None, "multicast is not supported yet" + # Emit inline PTX for TMA gather4 + # PTX: cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes + # [dstMem], [tensorMap, {col, row0, row1, row2, row3}], [smem_bar]; + ptx = ( + f"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::{num_cta} " + "[$0], [$1, {$2, $3, $4, $5, $6}], [$7];" + ) + + llvm.inline_asm( + None, + [ + dst_addr, + desc_addr, + col_val, + row_vals[0], + row_vals[1], + row_vals[2], + row_vals[3], + mbar_addr, + ], + ptx, + "r,l,r,r,r,r,r,r", # constraints: register, long, 6x register + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +def cpasync_bulk_get_copy_fn( + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + single_stage: bool = False, + **kwargs, +) -> Callable: + group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0)) + group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + src = cute.group_modes(src_tensor, 0, group_rank_src) + dst = cute.group_modes(dst_tensor, 0, group_rank_dst) + + def copy_bulk(src_idx, dst_idx, tma_bar_ptr: cute.Pointer, **new_kwargs): + atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type) + with cute.arch.elect_one(): + cute.copy( + atom, + src[None, src_idx], + dst[None, dst_idx], + mbar_ptr=tma_bar_ptr, + **new_kwargs, + **kwargs, + ) + + def copy_bulk_single_stage(tma_bar_ptr: cute.Pointer, **new_kwargs): + atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type) + with cute.arch.elect_one(): + cute.copy(atom, src, dst, mbar_ptr=tma_bar_ptr, **new_kwargs, **kwargs) + + return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage + + +def tma_get_copy_fn( + atom: cute.CopyAtom, + cta_coord: cute.Coord, + cta_layout: cute.Layout, + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + filter_zeros: bool = False, + single_stage: bool = False, + **kwargs, +) -> Callable: + src_is_smem = const_expr( + isinstance(src_tensor.iterator, cute.Pointer) + and src_tensor.memspace == cute.AddressSpace.smem + ) + smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) + group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0)) + group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + s, g = cpasync.tma_partition( + atom, + cta_coord, + cta_layout, + cute.group_modes(smem_tensor, 0, group_rank_smem), + cute.group_modes(gmem_tensor, 0, group_rank_gmem), + ) + if const_expr(filter_zeros): + s = cute.filter_zeros(s) + g = cute.filter_zeros(g) + src, dst = (s, g) if src_is_smem else (g, s) + + def copy_tma(src_idx, dst_idx, **new_kwargs): + cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) + + def copy_tma_single_stage(**new_kwargs): + cute.copy(atom, src, dst, **new_kwargs, **kwargs) + + return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g + + +def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): + def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs): + copy( + src_idx=src_idx, + dst_idx=producer_state.index, + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), + **new_kwargs, + ) + + return copy_fn + + +@cute.jit +def gather_m_get_copy_fn( + thr_copy_A: cute.ThrCopy, + mA: cute.Tensor, # (whatever, K) + sA: cute.Tensor, # (tile_M, tile_K, STAGE) + gsAIdx: cute.Tensor, # (tile_M), either gmem or smem + limit_m: Int32, + limit_k: Int32, +) -> Callable: + tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1])) + tAsA = thr_copy_A.partition_D(sA) + # k-major + assert tAsA.shape[2] == 1 + tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2) + + is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0 + if const_expr(not is_even_m_smem): + limit_m = min(limit_m, tile_shape_mk[0]) + elems_per_load = cute.size(tAsA.shape[0][0]) + cA = cute.make_identity_tensor(tile_shape_mk) + tAcA = thr_copy_A.partition_S(cA) + t0AcA = thr_copy_A.get_slice(0).partition_S(cA) + # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0] + # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0]. + # This is so that when we do the comparison, t0AcA is known at compile time. + limit_m = limit_m - tAcA[0][0] + limit_k = limit_k - tAcA[0][1] + # Read and cache indices for A + rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1])) + cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2])) + tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean) + for m in cutlass.range(rows_per_thread, unroll_full=True): + tApA_m[m] = t0AcA[0, m, 0][0] < limit_m + m_idx = cute.make_rmem_tensor(rows_per_thread, Int32) + for m in cutlass.range(rows_per_thread, unroll_full=True): + row_idx = tAcA[0, m, 0][0] + if tApA_m[m]: + m_idx[m] = gsAIdx[row_idx] + else: + m_idx[m] = 0 # It's ok to load row 0 in the case of OOB + + mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1])) + + def copy_fn(src_idx, dst_idx, pred: bool = False): + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + mA_cur = mA_k[None, (None, src_idx)] + for m in cutlass.range_constexpr(tAcA.shape[1]): + # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape + # ((elems_per_load), thread_per_row) + # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA + # So we append 1s to the last dimension and then do tiled_divide, then slice. + mA_row = cute.tiled_divide( + cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1) + )[None, None, 0] + if const_expr(is_even_m_smem) or tApA_m[m]: + # There's only 1 load per row + assert cute.size(tAcA.shape, mode=[2]) == 1 + ki = tAcA[0, 0, 0][1] // elems_per_load + cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k) + + return copy_fn + + +@cute.jit +def gather_k_get_copy_fn( + thr_copy_A: cute.ThrCopy, + mA: cute.Tensor, # (tile_M, whatever) + sA: cute.Tensor, # (tile_M, tile_K, STAGE) + gsAIdx: cute.Tensor, # (tile_K, RestK), either gmem or smem + limit_m: Int32, + limit_k: Int32, +) -> Callable: + gAIdx, sAIdx = None, None + if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem): + gAIdx = gsAIdx + else: + assert gsAIdx.memspace == cute.AddressSpace.smem + sAIdx = gsAIdx + tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1])) + # (atom_v, CPY_M, 1, STAGE) + tAsA = thr_copy_A.partition_D(sA) + # m-major + tAsA = cute.group_modes(tAsA, 0, 3) + + is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0 + if const_expr(not is_even_m_smem): + limit_m = min(limit_m, tile_shape_mk[0]) + elems_per_load = cute.size(tAsA.shape[0][0]) + cA = cute.make_identity_tensor(tile_shape_mk) + tAcA = thr_copy_A.partition_S(cA) + t0AcA = thr_copy_A.get_slice(0).partition_S(cA) + # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0] + # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0]. + # This is so that when we do the comparison, t0AcA is known at compile time. + limit_m = limit_m - tAcA[0][0] + limit_k = limit_k - tAcA[0][1] + # Read and cache indices for A + rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1])) + cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2])) + tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean) + for m in cutlass.range(rows_per_thread, unroll_full=True): + tApA_m[m] = t0AcA[0, m, 0][0] < limit_m + threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load) + # This is very convoluted but idk a better way + # for tile_M=128, flat_divide gives (8, 16, K), + # then logical_divide gives ((8, 1), (8, 2), K). + tidx = thr_copy_A.thr_idx + tAmA = cute.logical_divide( + cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col) + )[None, (tidx % threads_per_col, None), None] # ((8, 1), 2, K) + + def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]: + # Prefetch mAIdx early, even before smem is free + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + gAIdx_cur = gAIdx[None, src_idx] + k_idx = cute.make_rmem_tensor(cols_per_thread, Int32) + for k in cutlass.range(cols_per_thread): + col_idx = tAcA[0, 0, k][1] + if const_expr(not pred): + k_idx[k] = gAIdx_cur[col_idx] + else: + if tApA_k[k]: + k_idx[k] = gAIdx_cur[col_idx] + else: + k_idx[k] = -1 + return k_idx, tApA_k + + def prefetch_from_smem_fn( + a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False + ) -> Tuple[cute.Tensor, cute.Tensor]: + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state) + sAIdx_cur = sAIdx[None, dst_idx] + k_idx = cute.make_rmem_tensor(cols_per_thread, Int32) + for k in cutlass.range(cols_per_thread): + col_idx = tAcA[0, 0, k][1] + k_idx[k] = sAIdx_cur[col_idx] + cute.arch.sync_warp() + with cute.arch.elect_one(): + a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state) + return k_idx, tApA_k + + def copy_fn( + src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False + ): + k_idx, tApA_k = k_idx_tApA_k + tApA_k_pred = None + if const_expr(pred): + tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2) # (1, cols_per_thread) + for k in cutlass.range_constexpr(tAcA.shape[2]): + # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2)) + for m in cutlass.range_constexpr(tAcA.shape[1]): + if tApA_m[m]: + cute.copy( + thr_copy_A, + tAmA[None, m, k_idx[k]], + tAsA[(None, m, k), dst_idx], + pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k], + ) + + return copy_fn, prefetch_from_gmem_fn if const_expr( + gAIdx is not None + ) else prefetch_from_smem_fn + + +@cute.jit +def gather_m_get_tma_copy_fn( + tma_atom: cute.CopyAtom, + mA: cute.Tensor, # (whatever, K) + sA: cute.Tensor, # ((4, 32), (64, 1), STAGE) + sAIdx: cute.Tensor, # (tile_M), + warp_idx: Int32, + num_warps: int, + num_cta: int = 1, +) -> Callable: + tile_M = cute.size(sAIdx, mode=[0]) + tile_K = cute.size(sA[None, None, 0]) // tile_M + assert tile_M % 4 == 0 + # cta_group = 1 if tma_atom.op.cta_group == CtaGroup.ONE else 2 + cta_group = num_cta # Somehow all tma_atom has CtaGroup.ONE inside the kernel + + copy_AIdx_s2r = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128), + cute.make_layout(num_warps), # thr_layout + cute.make_layout(4), # val_layout + ) + warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx) + tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx) + # ((4, 1), 8, (64, 1), STAGE) + tSR_sA = warp_copy_AIdx_s2r.partition_S(sA) + tSR_rAIdx = load_s2r(tSR_sAIdx) + tma_desc_ptr = get_tma_desc_addr(tma_atom) + tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group) + + def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer): + col_idx = tile_K * src_idx + for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True): + row_indices = [tSR_rAIdx[v, m] for v in range(4)] + smem_ptr = tSR_sA[None, m, None, dst_idx].iterator + with cute.arch.elect_one(): + tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices) + + return copy_fn diff --git a/build/torch212-cxx11-cu130-x86_64-linux/quack/cute_dsl_utils.py b/build/torch212-cxx11-cu130-x86_64-linux/quack/cute_dsl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c92cf39ac08b92245316da46526494d7d8370e1 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/quack/cute_dsl_utils.py @@ -0,0 +1,104 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Tuple +from functools import lru_cache +from dataclasses import dataclass, fields + +import torch + +try: + from triton.tools.disasm import extract +except ImportError: + extract = None + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Int64, Float16, BFloat16, Float32 +from cutlass.base_dsl.typing import JitArgument +from cutlass.cutlass_dsl import NumericMeta + + +StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) + + +load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data +cute_compile_og = cute.compile + + +torch2cute_dtype_map = { + torch.float16: Float16, + torch.bfloat16: BFloat16, + torch.float32: Float32, + torch.int32: Int32, + torch.int64: Int64, +} + + +@lru_cache +def get_max_active_clusters(cluster_size): + return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) + + +@lru_cache +def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: + return torch.cuda.get_device_capability(device) + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class ArgumentsBase(JitArgument): + def __c_pointers__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + c_ptrs = [] + for obj in non_constexpr_fields: + if hasattr(obj, "__c_pointers__"): + c_ptrs.extend(obj.__c_pointers__()) + return c_ptrs + + def __get_mlir_types__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + types, self._values_pos = [], [] + for obj in non_constexpr_fields: + if hasattr(obj, "__get_mlir_types__"): + obj_types = obj.__get_mlir_types__() + types.extend(obj_types) + self._values_pos.append(len(obj_types)) + else: + self._values_pos.append(0) + return types + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/quack/layout_utils.py b/build/torch212-cxx11-cu130-x86_64-linux/quack/layout_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..099e0daf54cdac4b25b6d96f01b35451c810249b --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/quack/layout_utils.py @@ -0,0 +1,295 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + + +import cutlass +import cutlass.cute as cute + +from cutlass import Int32, const_expr + + +def transpose_view(a: cute.Tensor) -> cute.Tensor: + """Transpose the first two dimensions of a tensor on smem.""" + shape = (a.shape[1], a.shape[0], *a.shape[2:]) + order = (1, 0, *range(2, cute.rank(a))) + return cute.composition(a, cute.make_ordered_layout(shape, order=order)) + + +def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor: + return cute.make_tensor(a.iterator, cute.select(a.layout, mode)) + + +def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor: + shape = (*a.shape[:dim], size, *a.shape[dim:]) + stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:]) + return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride)) + + +@cute.jit +def permute_gated_Cregs_b16(t: cute.Tensor) -> None: + assert t.element_type.width == 16 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation" + t_u32 = cute.recast_tensor(t, Int32) + + quad_idx = cute.arch.lane_idx() % 4 + lane_03 = quad_idx == 0 or quad_idx == 3 + selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054) + selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276) + # upper_map = [0, 3, 1, 2] + # lower_map = [1, 2, 0, 3] + # upper_idx = upper_map[quad_idx] + # indexing isn't supported so we have to do arithmetic + upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2 + lower_idx = upper_idx ^ 1 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True): + upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1] + upper0 = upper if lane_03 else lower + lower0 = lower if lane_03 else upper + upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp) + lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp) + t_u32[i * 2 + 0] = cute.arch.prmt(upper0, lower0, selector_upper) + t_u32[i * 2 + 1] = cute.arch.prmt(upper0, lower0, selector_lower) + + +@cute.jit +def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None: + """Permute and shuffle within 4 threads to change the layout from + T0 | T1 | T2 | T3 + a b | c d | e f | g h + to + T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3 + a | b | c | d | e | f | g | h + This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict. + """ + + assert t.element_type.width == 32 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation" + + quad_idx = cute.arch.lane_idx() % 4 + # left_map = [0, 2, 1, 3] + # right_map = [2, 0, 3, 1] + # indexing isn't supported so we have to do arithmetic + left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2 + right_idx = left_idx ^ 0b10 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True): + for r in cutlass.range(2, unroll_full=True): + left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1] + # a b | c d | e f | g h -> a b | c d | f e | h g + left0 = left if quad_idx < 2 else right + right0 = right if quad_idx < 2 else left + # a b | c d | f e | h g -> a b | f d | c e | h g + left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp) + # a b | f d | c e | h g -> a e | f b | c g | h d + right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp) + # a e | f b | c g | h d -> a e | b f | c g | d h + t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0 + t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0 + t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1] + + +@cute.jit +def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None: + """Permute and shuffle within 4 threads to change the layout from + T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3 + a | b | c | d | e | f | g | h + to + T0 | T1 | T2 | T3 + a b | c d | e f | g h + This is so that we can use LDSM (instead of LDS.64) to store C registers without bank conflict. + """ + + assert t.element_type.width == 32 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation" + + quad_idx = cute.arch.lane_idx() % 4 + # left_map = [0, 2, 1, 3] + # right_map = [1, 3, 0, 2] + # indexing isn't supported so we have to do arithmetic + left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2 + right_idx = left_idx ^ 0b01 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + # This is just the inverse of permute_Cregs_b32_for_stsm + for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True): + t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1] + for r in cutlass.range(2, unroll_full=True): + left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1] + # a e | b f | c g | d h -> a e | f b | c g | h d + left0 = left if quad_idx % 2 == 0 else right + right0 = right if quad_idx % 2 == 0 else left + # a e | f b | c g | h d -> a b | f d | c e | h g + right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp) + # a b | f d | c e | h g -> a b | c d | f e | h g + left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp) + # a b | c d | f e | h g -> a b | c d | e f | g h + t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0 + t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0 + + +@cute.jit +def concat_layout(*layouts: cute.Layout) -> cute.Layout: + return cute.make_layout( + tuple(l.shape for l in layouts), + stride=tuple(l.stride for l in layouts), + ) + + +def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout: + """ + For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...). + For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...). + """ + acc_layout_col_major = cute.make_layout(acc_layout.shape) + shape = ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M + ( + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), # MMA_N + *acc_layout_col_major.shape[3:], + ) + stride = ( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), # MMA_N + *acc_layout_col_major.stride[3:], + ) + if const_expr(transpose): + shape = (shape[1], shape[0], *shape[2:]) + stride = (stride[1], stride[0], *stride[2:]) + acc_layout_mn = cute.make_layout(shape, stride=stride) + return cute.composition(acc_layout, acc_layout_mn) + + +def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose)) + + +def reshape_acc_to_mn(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose)) + + +@cute.jit +def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: + # For back to back gemm, convert layout of acc0 to gemm 1 accept layout. + # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) + # TODO: Sm90 FP8 + if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90 + l = cute.logical_divide( + acc_layout, ((None, None, 2), None, None) + ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N) + rA_mma_view = cute.make_layout( + ( + (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]), + l.shape[1], + (l.shape[0][2][1], l.shape[2]), + ), + stride=( + (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]), + l.stride[1], + (l.stride[0][2][1], l.stride[2]), + ), + ) + else: # Sm80 + # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) + l = cute.logical_divide(acc_layout, (None, None, 2)) + rA_mma_view = cute.make_layout( + ( + (l.shape[0], l.shape[2][0]), + l.shape[1], + l.shape[2][1], + ), + stride=( + (l.stride[0], l.stride[2][0]), + l.stride[1], + l.stride[2][1], + ), + ) + return rA_mma_view + + +def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout)) + + +def convert_layout_zero_stride( + input: cute.Tensor | cute.Layout, ref_layout: cute.Layout +) -> cute.Layout: + layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input + # Group the modes with non-zero stride in the ref_layout together, + # and the modes with zero stride together + layout_flat = cute.flatten(layout) + ref_layout_flat = cute.flatten(ref_layout) + nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0] + zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0] + # There's an edge case when all modes are zero stride + new_shape = ( + tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,), + tuple(layout_flat[i].shape for i in zero_modes), + ) + new_stride = ( + tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,), + tuple(layout_flat[i].stride for i in zero_modes), + ) + out_layout = cute.make_layout(new_shape, stride=new_stride) + if const_expr(isinstance(input, cute.Tensor)): + return cute.make_tensor(input.iterator, out_layout) + else: + return out_layout + + +def mma_partition_C_vec( + sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_C(sVec_mma)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] + + +def mma_partition_A_vec( + sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/quantize.py b/build/torch212-cxx11-cu130-x86_64-linux/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..4719a4854bc9388b2a866598f9e21c1f14921181 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/quantize.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Transformer Engine NVFP4 quantization helper. + +This file is intended as a customer-facing example for preparing KV tensors +for the KVFP4 attention kernel: + - BF16/FP16 K/V input + - packed E2M1 FP4 data from Transformer Engine + - E4M3 block scales in cuBLAS/cuDNN 128x4 tiled layout + - one FP32 tensor/global scale per tensor +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Tuple + +import torch + + +NVFP4_BLOCK_SIZE = 16 +NVFP4_FP4_MAX = 6.0 +NVFP4_FP8_E4M3_MAX = 448.0 + + +@dataclass(frozen=True) +class Nvfp4QuantizedTensor: + """Packed NVFP4 tensor plus dequantization metadata. + + Attributes + ---------- + data : torch.Tensor + Packed E2M1 FP4 data from Transformer Engine. The last dimension is + half of the original logical last dimension because each byte stores + two FP4 values. + scale_128x4 : torch.Tensor + E4M3 block scales in cuBLAS/cuDNN 128x4 tiled rowwise storage. + global_scale : torch.Tensor + FP32 tensor/global dequant scale. + logical_scale_shape : tuple[int, int] + Logical 2D scale shape ``(rows, cols)`` before 128x4 swizzling. + original_shape : tuple[int, ...] + Original BF16/FP16 tensor shape before quantization. + """ + + data: torch.Tensor + scale_128x4: torch.Tensor + global_scale: torch.Tensor + logical_scale_shape: Tuple[int, int] + original_shape: Tuple[int, ...] + + +def _round_up(x: int, multiple: int) -> int: + return ((int(x) + multiple - 1) // multiple) * multiple + + +def nvfp4_scale_128x4_offset( + row: torch.Tensor, + col: torch.Tensor, + scale_cols: int, +) -> torch.Tensor: + """Return flat offsets for cuBLAS/cuDNN 128x4 rowwise scale storage. + + Parameters + ---------- + row : torch.Tensor + Logical row indices. + col : torch.Tensor + Logical scale-column indices. + scale_cols : int + Logical number of scale columns before padding to a multiple of 4. + + Returns + ------- + torch.Tensor + Flat offsets into the padded 128x4 tiled storage. + """ + + tiles_n = _round_up(scale_cols, 4) // 4 + tile_m = row // 128 + tile_n = col // 4 + outer = row % 128 + inner = col % 4 + return ( + (tile_m * tiles_n + tile_n) * 512 + + (outer % 32) * 16 + + (outer // 32) * 4 + + inner + ) + + +def swizzle_nvfp4_scale_to_128x4( + scale: torch.Tensor, + *, + rows: int, + cols: int, +) -> torch.Tensor: + """Convert TE logical rowwise scales to cuBLAS/cuDNN 128x4 tiled layout. + + Parameters + ---------- + scale : torch.Tensor + Logical rowwise scale tensor with at least shape ``[rows, cols]``. + rows : int + Number of logical rows to convert. + cols : int + Number of logical scale columns to convert. + + Returns + ------- + torch.Tensor + Scale tensor padded to ``round_up(rows, 128)`` by ``round_up(cols, 4)`` + and swizzled into 128x4 tiled storage. + """ + + if scale.ndim != 2: + raise ValueError(f"scale must be 2D, got shape {tuple(scale.shape)}") + + rows = int(rows) + cols = int(cols) + padded_rows = _round_up(rows, 128) + padded_cols = _round_up(cols, 4) + if scale.shape[0] < rows or scale.shape[1] < cols: + raise ValueError( + "scale is smaller than the requested logical shape: " + f"got {tuple(scale.shape)}, need at least {(rows, cols)}" + ) + + logical = scale[:rows, :cols].contiguous() + if logical.shape != (padded_rows, padded_cols): + logical = torch.nn.functional.pad( + logical.to(torch.float32), + (0, padded_cols - cols, 0, padded_rows - rows), + ).to(scale.dtype) + swizzled = torch.empty_like(logical) + + row = torch.arange(padded_rows, device=scale.device, dtype=torch.int64)[:, None] + col = torch.arange(padded_cols, device=scale.device, dtype=torch.int64)[None, :] + offset = nvfp4_scale_128x4_offset(row, col, padded_cols).reshape(-1) + swizzled.reshape(-1)[offset] = logical.reshape(-1) + return swizzled + + +def nvfp4_global_scale_from_amax(amax: torch.Tensor) -> torch.Tensor: + """Compute TE NVFP4 tensor/global dequant scale from rowwise amax. + + Parameters + ---------- + amax : torch.Tensor + Rowwise absolute maxima returned by Transformer Engine. + + Returns + ------- + torch.Tensor + FP32 global scale equal to ``amax / (448 * 6)``. + """ + + return amax.to(torch.float32) / (NVFP4_FP8_E4M3_MAX * NVFP4_FP4_MAX) + + +def _import_te_nvfp4_quantizer(): + try: + from transformer_engine.pytorch.tensor import NVFP4Quantizer + except Exception as exc: # pragma: no cover - environment dependent + raise RuntimeError( + "Transformer Engine NVFP4 quantization is unavailable. Install a " + "Transformer Engine build with its PyTorch dependencies, including " + "FlashAttention v3 when required by that TE build." + ) from exc + return NVFP4Quantizer + + +def quantize_bf16_to_nvfp4_128x4(x: torch.Tensor) -> Nvfp4QuantizedTensor: + """Quantize a BF16/FP16 tensor to NVFP4 using Transformer Engine. + + TE returns rowwise scales in logical padded layout. This helper returns + the scales in physical 128x4 tiled storage, so the attention kernel can + load them with ``nvfp4_scale_128x4_offset``. + + Parameters + ---------- + x : torch.Tensor + CUDA BF16 or FP16 tensor. The last dimension must be divisible by 16, + and the flattened row dimension ``prod(x.shape[:-1])`` must also be + divisible by 16. + + Returns + ------- + Nvfp4QuantizedTensor + Packed FP4 data, 128x4-swizzled block scales, global scale, and shape + metadata needed by the KVFP4 attention kernel or by reference + dequantization. + """ + + if not x.is_cuda: + raise ValueError("NVFP4 quantization requires a CUDA tensor") + if x.dtype not in (torch.bfloat16, torch.float16): + raise TypeError(f"x must be bf16 or fp16, got {x.dtype}") + if x.ndim < 2: + raise ValueError(f"x must have at least 2 dimensions, got {x.ndim}") + if x.shape[-1] % NVFP4_BLOCK_SIZE != 0: + raise ValueError( + f"last dimension must be divisible by {NVFP4_BLOCK_SIZE}, got {x.shape[-1]}" + ) + + rows = 1 + for dim in x.shape[:-1]: + rows *= int(dim) + if rows % NVFP4_BLOCK_SIZE != 0: + raise ValueError( + "flattened row dimension must be divisible by " + f"{NVFP4_BLOCK_SIZE}, got {rows}" + ) + + NVFP4Quantizer = _import_te_nvfp4_quantizer() + quantizer = NVFP4Quantizer(rowwise=True, columnwise=False) + qx = quantizer.quantize(x.contiguous()) + meta = qx.get_metadata() + + data = meta["rowwise_data"] + if data.dtype is not torch.uint8: + data = data.view(torch.uint8) + logical_scale = meta["rowwise_scale_inv"] + amax = meta["amax_rowwise"] + scale_cols = int(x.shape[-1]) // NVFP4_BLOCK_SIZE + scale_128x4 = swizzle_nvfp4_scale_to_128x4( + logical_scale, + rows=rows, + cols=scale_cols, + ) + global_scale = nvfp4_global_scale_from_amax(amax).contiguous() + + return Nvfp4QuantizedTensor( + data=data, + scale_128x4=scale_128x4, + global_scale=global_scale, + logical_scale_shape=(rows, scale_cols), + original_shape=tuple(int(v) for v in x.shape), + ) + + +def quantize_kv_bf16_to_nvfp4_128x4( + k: torch.Tensor, + v: torch.Tensor, +) -> tuple[Nvfp4QuantizedTensor, Nvfp4QuantizedTensor]: + """Quantize BF16/FP16 K and V tensors independently for KVFP4 attention. + + Parameters + ---------- + k : torch.Tensor + CUDA BF16 or FP16 K tensor. + v : torch.Tensor + CUDA BF16 or FP16 V tensor. + + Returns + ------- + tuple[Nvfp4QuantizedTensor, Nvfp4QuantizedTensor] + Quantized K and V tensors with independent scales. + """ + + return quantize_bf16_to_nvfp4_128x4(k), quantize_bf16_to_nvfp4_128x4(v) + + +def dequantize_nvfp4_128x4_to_bf16( + qx: Nvfp4QuantizedTensor, + *, + include_global_scale: bool = True, +) -> torch.Tensor: + """Reference dequantization for validation. + + This mirrors the kernel contract: + x = e2m1 * E4M3_block_scale_1x16 * FP32_global_scale + + Parameters + ---------- + qx : Nvfp4QuantizedTensor + Quantized tensor returned by ``quantize_bf16_to_nvfp4_128x4``. + include_global_scale : bool, optional + If True, multiply by ``qx.global_scale`` after applying per-block + scales. + + Returns + ------- + torch.Tensor + BF16 tensor with shape ``qx.original_shape``. + """ + + data = qx.data if qx.data.dtype is torch.uint8 else qx.data.view(torch.uint8) + if data.shape[-1] * 2 != qx.original_shape[-1]: + raise ValueError( + "packed data last dimension does not match original shape: " + f"{data.shape[-1]} packed vs {qx.original_shape[-1]} logical" + ) + + rows, scale_cols = qx.logical_scale_shape + logical_dim = int(qx.original_shape[-1]) + if scale_cols * NVFP4_BLOCK_SIZE != logical_dim: + raise ValueError( + "logical scale columns do not match original last dimension: " + f"{scale_cols} scale cols vs dim {logical_dim}" + ) + + fp4_lut = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float32, + device=data.device, + ) + packed = data.reshape(rows, logical_dim // 2) + lo = packed & 0x0F + hi = packed >> 4 + values = torch.empty((rows, logical_dim), dtype=torch.float32, device=data.device) + values[:, 0::2] = fp4_lut[lo.long()] + values[:, 1::2] = fp4_lut[hi.long()] + + row = torch.arange(rows, device=data.device, dtype=torch.int64)[:, None] + col = torch.arange(scale_cols, device=data.device, dtype=torch.int64)[None, :] + offset = nvfp4_scale_128x4_offset(row, col, scale_cols) + scale_u8 = qx.scale_128x4.reshape(-1)[offset.reshape(-1)].reshape(rows, scale_cols) + scale = scale_u8.view(torch.float8_e4m3fn).to(torch.float32) + scale = scale.repeat_interleave(NVFP4_BLOCK_SIZE, dim=1) + out = values * scale + if include_global_scale: + global_scale = qx.global_scale.reshape(-1)[0].to(torch.float32) + out = out * global_scale + return out.reshape(qx.original_shape).to(torch.bfloat16) + + +def _example() -> None: + device = torch.device("cuda") + k = torch.randn(128, 2, 128, device=device, dtype=torch.bfloat16) + v = torch.randn_like(k) + k_q, v_q = quantize_kv_bf16_to_nvfp4_128x4(k, v) + print("K FP4 data:", tuple(k_q.data.shape), k_q.data.dtype) + print("K scale 128x4:", tuple(k_q.scale_128x4.shape), k_q.scale_128x4.dtype) + print("K global scale:", tuple(k_q.global_scale.shape), k_q.global_scale.dtype) + print("V FP4 data:", tuple(v_q.data.shape), v_q.data.dtype) + print("V scale 128x4:", tuple(v_q.scale_128x4.shape), v_q.scale_128x4.dtype) + print("V global scale:", tuple(v_q.global_scale.shape), v_q.global_scale.dtype) + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + raise RuntimeError("quantize.py requires CUDA") + _example() diff --git a/build/torch212-cxx11-cu130-x86_64-linux/sparse_index_utils.py b/build/torch212-cxx11-cu130-x86_64-linux/sparse_index_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a54c982c9230b189051e3a0bdf76d22b397dd62a --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/sparse_index_utils.py @@ -0,0 +1,411 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Host-side q2k <-> k2q index conversion for sparse attention. + +These utilities prepare sparse metadata on the Python side for tests, +benchmarks, and other offline preprocessing flows. They are not kernel +runtime helpers, so they intentionally live outside `src/common`. + +Sparse attention pattern: + - Each Q token independently selects up to topK KV blocks (blk_kv tokens each). + - Under GQA, all Q heads in one group share the same sparsity pattern, + so indices are defined at the head_kv level. + +Shapes: + q2k_indices: [batch, head_kv, Sq, topK] int32, valid values in [0, num_kv_blocks), + trailing unused slots padded with -1 + k2q_indices: [batch, head_kv, Nkv, Sq] int32, padded with -1 + k2q_counts: [batch, head_kv, Nkv] int32 + +CSR reverse-index format: + q2k_indices: [head_kv, total_q, topK] int32, values are batch-local kv_block indices + k2q_row_ptr: [head_kv, total_rows + 1] int32 + k2q_q_indices: [head_kv, total_q * topK] int32, values are batch-local q_idx +""" + +from typing import Optional, Tuple + +import torch + +from .src.sm100.prepare_k2q_csr import SparseK2qCsrBuilderSm100 + + +def q2k_to_k2q( + q2k_indices: torch.Tensor, + num_kv_blocks: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert q2k sparse indices to k2q representation. + + For each KV block, find which Q tokens attend to it. + + Args: + q2k_indices: [batch, head_kv, Sq, topK] int32. + For each Q token, the KV blocks it attends to. Unused slots must + be padded with -1. + num_kv_blocks: Total number of KV blocks (= Skv / blk_kv). + + Returns: + k2q_indices: [batch, head_kv, num_kv_blocks, Sq] int32. + For each KV block, the Q token indices that attend to it, + left-packed and padded with -1. Last dim fixed to Sq (upper bound). + k2q_counts: [batch, head_kv, num_kv_blocks] int32. + Actual number of Q tokens per KV block. + """ + B, H, Sq, topK = q2k_indices.shape + device = q2k_indices.device + N = Sq * topK + + kv_flat = q2k_indices.reshape(B, H, N).long() + valid_flat = kv_flat >= 0 + q_flat = ( + torch.arange(Sq, device=device) + .unsqueeze(-1) + .expand(Sq, topK) + .reshape(N) + .unsqueeze(0) + .unsqueeze(0) + .expand(B, H, N) + ) + + k2q_counts = torch.zeros(B, H, num_kv_blocks, dtype=torch.int32, device=device) + safe_kv_flat = torch.where(valid_flat, kv_flat, torch.zeros_like(kv_flat)) + k2q_counts.scatter_add_( + 2, + safe_kv_flat, + valid_flat.to(torch.int32), + ) + + sort_keys = torch.where( + valid_flat, + kv_flat, + torch.full_like(kv_flat, num_kv_blocks), + ) + sorted_kv, sort_idx = sort_keys.sort(dim=-1, stable=True) + sorted_q = q_flat.gather(-1, sort_idx) + sorted_valid = valid_flat.gather(-1, sort_idx) + + offsets = torch.zeros(B, H, num_kv_blocks, dtype=torch.int64, device=device) + offsets[:, :, 1:] = k2q_counts[:, :, :-1].cumsum(dim=-1).long() + + global_pos = torch.arange(N, device=device).unsqueeze(0).unsqueeze(0).expand(B, H, N) + group_offset = offsets.gather(2, sorted_kv.clamp(max=num_kv_blocks - 1)) + pos_in_group = global_pos - group_offset + + k2q_indices = torch.full( + (B, H, num_kv_blocks, Sq), -1, dtype=torch.int32, device=device + ) + flat_k2q = k2q_indices.reshape(B, H, -1) + flat_idx = sorted_kv.clamp(max=num_kv_blocks - 1) * Sq + pos_in_group + for b in range(B): + for h in range(H): + valid = sorted_valid[b, h] + flat_k2q[b, h, flat_idx[b, h, valid]] = sorted_q[b, h, valid].int() + + return k2q_indices, k2q_counts + + +def k2q_to_q2k( + k2q_indices: torch.Tensor, + k2q_counts: torch.Tensor, + Sq: int, + topK: int, +) -> torch.Tensor: + """Convert dense k2q indices back to q2k representation. + + Parameters + ---------- + k2q_indices : torch.Tensor + Shape ``[batch, head_kv, num_kv_blocks, Sq]`` and dtype int32. Values + are Q token indices padded with ``-1``. + k2q_counts : torch.Tensor + Shape ``[batch, head_kv, num_kv_blocks]`` and dtype int32. Number of + valid Q indices per KV block. + Sq : int + Q sequence length per batch item in this dense reference format. + topK : int + Maximum number of KV blocks selected per Q token. + + Returns + ------- + torch.Tensor + Shape ``[batch, head_kv, Sq, topK]``, dtype int32. Entries are sorted + by KV block index with ``-1`` padding at the tail. + """ + B, H, Nkv, _ = k2q_indices.shape + device = k2q_indices.device + + q2k = torch.full((B, H, Sq, topK), -1, dtype=torch.int32, device=device) + counters = torch.zeros(B, H, Sq, dtype=torch.int64, device=device) + + for b in range(B): + for h in range(H): + for kv_blk in range(Nkv): + count = k2q_counts[b, h, kv_blk].item() + for j in range(count): + qt = k2q_indices[b, h, kv_blk, j].item() + if qt < 0: + continue + p = counters[b, h, qt].item() + if p < topK: + q2k[b, h, qt, p] = kv_blk + counters[b, h, qt] += 1 + + q2k_sort_key = torch.where(q2k < 0, torch.full_like(q2k, Nkv), q2k) + _, sort_idx = q2k_sort_key.sort(dim=-1) + q2k = q2k.gather(-1, sort_idx) + return q2k + + +def _validate_cu_seqlens(cu_seqlens: torch.Tensor, *, name: str) -> None: + if cu_seqlens.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32, got {cu_seqlens.dtype}") + if cu_seqlens.ndim != 1: + raise ValueError(f"{name} must be rank-1, got shape {tuple(cu_seqlens.shape)}") + if cu_seqlens.numel() < 1: + raise ValueError(f"{name} must have at least one element") + if not cu_seqlens.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _rows_per_batch(cu_seqlens_k: torch.Tensor, kv_block_size: int) -> torch.Tensor: + seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1] + return (seqlens_k + kv_block_size - 1) // kv_block_size + + +def _build_packed_row_map(rows_per_batch: torch.Tensor) -> tuple[torch.Tensor, int]: + rows_per_batch_cpu = rows_per_batch.to("cpu", non_blocking=False).tolist() + batch = len(rows_per_batch_cpu) + max_rows = max(rows_per_batch_cpu, default=0) + row_dtype = ( + torch.int32 + if sum(rows_per_batch_cpu) < torch.iinfo(torch.int32).max + else torch.int64 + ) + row_map_cpu = torch.full((batch, max_rows), -1, dtype=row_dtype) + row_linear = 0 + for kv_block_idx in range(max_rows): + for batch_idx, row_count in enumerate(rows_per_batch_cpu): + if kv_block_idx < row_count: + row_map_cpu[batch_idx, kv_block_idx] = row_linear + row_linear += 1 + return row_map_cpu.to(rows_per_batch.device), row_linear + + +def build_k2q_csr_torch_reference( + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + kv_block_size: int, +) -> tuple: + """Torch reference for q2k -> k2q CSR conversion. + + Parameters + ---------- + q2k_indices : torch.Tensor + Shape ``[head_kv, total_q, topK]``, dtype int32. Values are + batch-local KV block indices padded with ``-1``. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + kv_block_size : int + Number of KV tokens per sparse block. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(k2q_row_ptr, k2q_q_indices)`` where ``k2q_row_ptr`` has shape + ``[head_kv, total_rows + 1]`` and ``k2q_q_indices`` has shape + ``[head_kv, total_q * topK]``. + """ + if kv_block_size <= 0: + raise ValueError(f"kv_block_size must be > 0, got {kv_block_size}") + if q2k_indices.dtype != torch.int32: + raise TypeError(f"q2k_indices must be torch.int32, got {q2k_indices.dtype}") + if q2k_indices.ndim != 3: + raise ValueError( + "q2k_indices must have shape [head_kv, total_q, topK], " + f"got {tuple(q2k_indices.shape)}" + ) + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q") + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError("cu_seqlens_q and cu_seqlens_k must have the same shape [B + 1]") + if q2k_indices.device != cu_seqlens_q.device or q2k_indices.device != cu_seqlens_k.device: + raise ValueError("q2k_indices, cu_seqlens_q, and cu_seqlens_k must be on the same device") + + head_kv, total_q, topk = q2k_indices.shape + if total_q != int(cu_seqlens_q[-1].item()): + raise ValueError( + f"q2k_indices.shape[1] ({total_q}) must equal cu_seqlens_q[-1] " + f"({int(cu_seqlens_q[-1].item())})" + ) + + rows_per_batch = _rows_per_batch(cu_seqlens_k, kv_block_size) + row_map, total_rows = _build_packed_row_map(rows_per_batch) + nnz_upper_bound = total_q * topk + + k2q_row_ptr = torch.zeros((head_kv, total_rows + 1), dtype=torch.int32, device=q2k_indices.device) + k2q_q_indices = torch.full( + (head_kv, nnz_upper_bound), -1, dtype=torch.int32, device=q2k_indices.device + ) + if total_rows == 0 or total_q == 0 or topk == 0: + return k2q_row_ptr, k2q_q_indices + + counts = torch.zeros((head_kv, total_rows), dtype=torch.int32, device=q2k_indices.device) + total_entries = total_q * topk + row_dtype = torch.int32 if total_rows < torch.iinfo(torch.int32).max else torch.int64 + row_all = torch.empty((head_kv, total_entries), dtype=row_dtype, device=q2k_indices.device) + q_all = torch.empty((head_kv, total_entries), dtype=torch.int32, device=q2k_indices.device) + valid_all = torch.empty((head_kv, total_entries), dtype=torch.bool, device=q2k_indices.device) + rows_per_batch_cpu = rows_per_batch.to("cpu", non_blocking=False).tolist() + q_cu_cpu = cu_seqlens_q.to("cpu", non_blocking=False).tolist() + entry_cursor = 0 + + for batch_idx, kv_rows in enumerate(rows_per_batch_cpu): + q_start = q_cu_cpu[batch_idx] + q_end = q_cu_cpu[batch_idx + 1] + q_len = q_end - q_start + if q_len == 0: + continue + num_entries = q_len * topk + q2k_batch = q2k_indices[:, q_start:q_end, :] + valid_batch = q2k_batch >= 0 + if valid_batch.any(): + max_valid_kv = int(q2k_batch[valid_batch].max().item()) + if max_valid_kv >= kv_rows: + raise ValueError( + f"q2k_indices references kv_block {max_valid_kv} for batch {batch_idx}, " + f"but that batch only has {kv_rows} logical kv blocks" + ) + kv_flat = q2k_batch.reshape(head_kv, num_entries).long() + valid_flat = valid_batch.reshape(head_kv, num_entries) + safe_kv_flat = torch.where(valid_flat, kv_flat, torch.zeros_like(kv_flat)) + row_map_batch = row_map[batch_idx] + row_flat = row_map_batch[safe_kv_flat] + q_flat = ( + torch.arange(q_len, device=q2k_indices.device, dtype=torch.int32) + .view(1, q_len, 1) + .expand(head_kv, q_len, topk) + .reshape(head_kv, num_entries) + ) + row_all[:, entry_cursor : entry_cursor + num_entries] = row_flat + q_all[:, entry_cursor : entry_cursor + num_entries] = q_flat + valid_all[:, entry_cursor : entry_cursor + num_entries] = valid_flat + counts.scatter_add_(1, row_flat.to(torch.int64), valid_flat.to(torch.int32)) + entry_cursor += num_entries + + k2q_row_ptr[:, 1:] = counts.cumsum(dim=1, dtype=torch.int32) + + sort_stride = max(total_q, 1) + invalid_key = total_rows * sort_stride + max_sort_key = invalid_key + max(total_q - 1, 0) + if max_sort_key < torch.iinfo(torch.int32).max: + sort_keys = torch.full_like(row_all, invalid_key, dtype=torch.int32) + sort_keys[valid_all] = row_all[valid_all] * sort_stride + q_all[valid_all] + else: + sort_keys = torch.full_like(row_all, invalid_key, dtype=torch.int64) + sort_keys[valid_all] = ( + row_all[valid_all].to(torch.int64) * sort_stride + + q_all[valid_all].to(torch.int64) + ) + _, sort_idx = sort_keys.sort(dim=1, stable=True) + sorted_q = q_all.gather(1, sort_idx) + + valid_counts = valid_all.sum(dim=1) + write_mask = ( + torch.arange(total_entries, device=q2k_indices.device) + .unsqueeze(0) + .expand(head_kv, -1) + < valid_counts.unsqueeze(1) + ) + k2q_q_indices[write_mask] = sorted_q[write_mask] + + return k2q_row_ptr, k2q_q_indices + + +_K2Q_CSR_BUILDER = SparseK2qCsrBuilderSm100() + + +def build_k2q_csr( + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + kv_block_size: int, + *, + total_k: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + max_seqlen_q: Optional[int] = None, + total_rows: Optional[int] = None, + qhead_per_kv: int = 1, + return_schedule: bool = False, +) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, object]: + """Build the public k2q CSR reverse index on GPU. + + Runtime construction does not read device-side ``cu_seqlens`` on the host, + so callers must provide size hints such as ``total_k`` from already-known + tensor shapes. + + Parameters + ---------- + q2k_indices : torch.Tensor + Shape ``[head_kv, total_q, topK]``, dtype int32, contiguous. Values are + batch-local KV block indices with trailing ``-1`` padding. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + kv_block_size : int + Number of KV tokens per sparse block. + total_k : int + Total KV token count. Required; normally ``k.shape[0]`` for dense KV + or ``sum(kv_segment_lens)`` for paged KV. + max_seqlen_k : int, optional + Maximum KV sequence length. Passing this avoids recomputing a bound. + max_seqlen_q : int, optional + Maximum Q sequence length. + total_rows : int, optional + Total number of packed KV-block rows across the batch. If omitted, + the builder derives it from ``cu_seqlens_k`` and ``kv_block_size``. + qhead_per_kv : int, optional + Number of Q heads per KV head under GQA. + return_schedule : bool, optional + If True, also return the sparse forward schedule object produced by the + SM100 builder. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] or tuple[torch.Tensor, torch.Tensor, object] + ``(k2q_row_ptr, k2q_q_indices)`` or + ``(k2q_row_ptr, k2q_q_indices, schedule)``. CSR tensors are int32 on + the same CUDA device as ``q2k_indices``. + """ + if total_k is None: + raise ValueError("build_k2q_csr requires total_k from k.shape[0]") + if kv_block_size <= 0: + raise ValueError(f"kv_block_size must be > 0, got {kv_block_size}") + if q2k_indices.dtype != torch.int32: + raise TypeError(f"q2k_indices must be torch.int32, got {q2k_indices.dtype}") + if q2k_indices.ndim != 3: + raise ValueError(f"q2k_indices must be rank-3, got shape {tuple(q2k_indices.shape)}") + if not q2k_indices.is_contiguous(): + raise ValueError("q2k_indices must be contiguous with layout [head_kv, total_q, topK]") + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q") + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError("cu_seqlens_q and cu_seqlens_k must have the same shape [B + 1]") + if q2k_indices.device != cu_seqlens_q.device or q2k_indices.device != cu_seqlens_k.device: + raise ValueError("q2k_indices, cu_seqlens_q, and cu_seqlens_k must be on the same device") + return _K2Q_CSR_BUILDER( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + total_k=int(total_k), + blk_kv=int(kv_block_size), + max_seqlen_k=max_seqlen_k, + max_seqlen_q=max_seqlen_q, + total_rows=total_rows, + qhead_per_kv=qhead_per_kv, + return_schedule=return_schedule, + ) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15dfab1c4fc45dcec26ba7489b35624f5c46a698 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15dfab1c4fc45dcec26ba7489b35624f5c46a698 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/aot_cache.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/aot_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..99fd0b4da4ddb6fba21bcb18c924f5e9e8b583e6 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/aot_cache.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Persistent AOT cache for CuTe DSL compiled kernels. + +Saves compiled TVM FFI kernels as .o files on first compile, +loads them on subsequent runs to skip JIT compilation. + +Environment variables: + MM_SPARSE_ATTN_AOT_CACHE: Override cache directory + (default: ~/.cache/minfer/mm_sparse_attn) + MM_SPARSE_ATTN_AOT_DISABLE=1: Disable AOT cache entirely +""" + +import hashlib +import os +import time + +import cutlass.cute as cute + +_AOT_CACHE_DIR = os.environ.get( + "MM_SPARSE_ATTN_AOT_CACHE", + os.path.expanduser("~/.cache/minfer/mm_sparse_attn"), +) +_AOT_DISABLE = os.environ.get("MM_SPARSE_ATTN_AOT_DISABLE", "0") == "1" + +_loaded_modules: dict[str, object] = {} + + +def _key_to_path(key: tuple) -> str: + h = hashlib.sha256(repr(key).encode()).hexdigest()[:16] + name = str(key[0]).replace("/", "_") + return os.path.join(_AOT_CACHE_DIR, f"{name}_{h}") + + +def try_load_aot(key: tuple): + if _AOT_DISABLE: + return None + obj_path = _key_to_path(key) + ".o" + if not os.path.isfile(obj_path): + return None + func_name = str(key[0]) + try: + if obj_path not in _loaded_modules: + _loaded_modules[obj_path] = cute.runtime.load_module( + obj_path, enable_tvm_ffi=True + ) + return getattr(_loaded_modules[obj_path], func_name) + except Exception as e: + print(f"[aot_cache] Failed to load {obj_path}: {e}") + return None + + +def save_aot(key: tuple, compiled) -> None: + if _AOT_DISABLE: + return + if not hasattr(compiled, "export_to_c"): + return + obj_path = _key_to_path(key) + ".o" + os.makedirs(_AOT_CACHE_DIR, exist_ok=True) + tmp_path = obj_path + f".tmp.{os.getpid()}" + func_name = str(key[0]) + try: + t0 = time.time() + compiled.export_to_c(tmp_path, function_name=func_name) + os.replace(tmp_path, obj_path) + dt = time.time() - t0 + print(f"[aot_cache] Saved {func_name} -> {obj_path} ({dt:.1f}s)") + except Exception as e: + print(f"[aot_cache] Failed to save {func_name}: {e}") + if os.path.exists(tmp_path): + os.remove(tmp_path) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/barrier.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/barrier.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5753a8a175b529567e0be238f47fd4cc8401bf --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/barrier.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm + + +@dsl_user_op +def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + state = llvm.inline_asm( + T.i32(), + [lock_ptr_i64], + "ld.global.acquire.gpu.b32 $0, [$1];", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + return cutlass.Int32(state) + + +@dsl_user_op +def red_relaxed( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.relaxed.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def red_release( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.release.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + read_val = Int32(0) + while read_val != val: + read_val = ld_acquire(flag_ptr) + + +@cute.jit +def arrive_inc( + lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32] +) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + red_release(flag_ptr, val) + # red_relaxed(flag_ptr, val) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/blackwell_helpers.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/blackwell_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..fd22f7efa3cef9988b4036c2d00fc1d3b9c816e8 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/blackwell_helpers.py @@ -0,0 +1,1093 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Optional, Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Boolean, const_expr +from cutlass.cute.nvgpu import tcgen05 +from cutlass._mlir.dialects import llvm + +from . import mma_sm100_desc as sm100_desc + + +@cute.jit +def gemm_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, + swap_AB: bool = False, + num_unroll_groups: int = 1, +) -> None: + if const_expr(swap_AB): + return gemm_w_idx( + tiled_mma, acc, tCrB, tCrA, B_idx, A_idx, zero_init=zero_init, swap_AB=False + ) + else: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + + mma_atom = cute.make_mma_atom(tiled_mma.op) + for k in cutlass.range( + cute.size(tCrA.shape[2]), unroll=cute.size(tCrA.shape[2]) // num_unroll_groups + ): + mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc) + + +@cute.jit +def gemm_ptx_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, + cta_group: int = 1, + **kwargs, +) -> None: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + sA_cur = None + if const_expr(sA is not None): + sA_cur = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx] + sB_cur = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx] + mma_atom = cute.make_mma_atom(tiled_mma.op) + acc_tmem_addr = acc.iterator.toint() + gemm_ptx_partial( + mma_atom.op, + acc_tmem_addr, + rA, + rB, + sA_cur, + sB_cur, + zero_init=zero_init, + cta_group=cta_group, + **kwargs, + ) + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + mma_atom = cute.make_mma_atom(tiled_mma.op) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + + +def i64_to_i32x2(i: int) -> Tuple[int, int]: + """Convert a 64-bit integer to a tuple of two 32-bit integers.""" + return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF + + +@cute.jit +def gemm_ptx( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else None + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = sA.iterator.type.swizzle_type + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = sB.iterator.type.swizzle_type + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr( + sA[None, None, 0].iterator + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr( + sB[None, None, 0].iterator + ) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + if const_expr(not is_ts): + smem_desc_a_lo = smem_desc_start_a_lo + ( + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + ) + smem_desc_b_lo = smem_desc_start_b_lo + ( + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + ) + # with cute.arch.elect_one(): + # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo) + # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct) + with cute.arch.elect_one(): + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + smem_desc_a_lo.ir_value(), + smem_desc_b_lo.ir_value(), + Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + ".reg .b32 idesc;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + tCrA[None, None, k].iterator.toint().ir_value(), + smem_desc_b_lo.ir_value(), + Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_b;\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_loop( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = sA.iterator.type.swizzle_type + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = sB.iterator.type.swizzle_type + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + if const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [ + offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [ + offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2])) + ] + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 smem_desc_a_lo, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 tmem_a, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: Int32, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[Int32] = None, + split_arrive: Optional[int] = None, + zero_init: bool | Boolean = False, + # sA_offset: Int32 = 0, + # acc_offset: Int32 = 0, + tA_addr: Optional[Int32] = None, + cta_group: int = 1, + mma_kind: str = "f16", +) -> None: + # acc_tmem_addr += acc_offset + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = sA.iterator.type.swizzle_type + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = sB.iterator.type.swizzle_type + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + tCrA_layout = ( + tCrA.layout + if const_expr(not is_ts) + else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + # ) + sA_offset + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + "mov.b32 smem_desc_a_lo_start, $0;\n\t" + "mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + # "r,r,r", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + # For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to + # explicitly pass in the tA_addr for correctness. + tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr + input_args = [ + # Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(), + Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ] + if const_expr(mbar_ptr is not None): + assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + assert split_arrive is not None, ( + "split_arrive must be provided when mbar_ptr is not None" + ) + split_arrive_idx = split_arrive // op.shape_mnk[2] + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" + llvm.inline_asm( + None, + # [ + # # acc.iterator.toint().ir_value(), + # Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + # Int32(smem_desc_start_b_lo).ir_value(), + # Int32(not zero_init).ir_value(), + # ], + input_args, + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range( + 1, + cute.size(tCrA.shape[2]) if const_expr(mbar_ptr is None) else split_arrive_idx, + ) + ) + + mbar_wait_str + + ( + "".join( + ( + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(split_arrive_idx, cute.size(tCrA.shape[2])) + ) + if const_expr(mbar_ptr is not None) + else "" + ) + + "}\n", + "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial1( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: cutlass.Constexpr[int], + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA_base_addr_for_desc: Int32, + sA_addr_offset_for_desc: cutlass.Constexpr[int], + sA_stage: Int32, + sB_base_addr_for_desc: Int32, + sB_addr_offset_for_desc: cutlass.Constexpr[int], + sB_stage: Int32, + sA_layout: Optional[cute.Layout], + sB_layout: Optional[cute.Layout], + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + mask = [Int32(0)] * 4 + + if const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 + for k in range(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in range(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4 + for k in range(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if const_expr(not is_ts): + # smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + smem_desc_start_a_lo = const_expr(smem_desc_base_a_lo) + else: + smem_desc_start_a_lo = None + # smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + smem_desc_start_b_lo = const_expr(smem_desc_base_b_lo) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(sA_base_addr_for_desc).ir_value(), + Int32(sA_stage).ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(sB_base_addr_for_desc).ir_value(), + Int32(sB_stage).ir_value(), + Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + # "mov.b32 smem_desc_a_lo, $0;\n\t" + # f"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\n\t" + f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t" + # "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $4, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_a, $1;\n\t" + f"mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_precomputed( + acc_tmem_addr: Int32, + smem_desc_start_a: Int32, # If TS, then this is the tmem start address for A + smem_desc_start_b: Int32, + idesc: int, + smem_desc_base_a: Optional[int], + smem_desc_base_b: int, + tCrA_layout: cute.Layout, + tCrB_layout: cute.Layout, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[Int32] = None, + zero_init: bool | Boolean = False, + cta_group: int = 1, +) -> None: + # acc_tmem_addr += acc_offset + is_ts = const_expr(smem_desc_base_a is None) + num_k_tile = cute.size(tCrA_layout.shape[2]) + if const_expr(not is_ts): + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + else: + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + + tCrA_layout = ( + tCrA_layout + if const_expr(not is_ts) + # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout) + # currently hard-coding the width to 16 + else cute.recast_layout(32, 16, tCrA_layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, num_k_tile)] + offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, num_k_tile)] + + smem_desc_start_a_lo = None + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a) + # smem_desc_start_a_lo = smem_desc_start_a + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + "mov.b32 smem_desc_a_lo_start, $0;\n\t" + "mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t" + f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, num_k_tile) + ) + + "}\n", + # "r,r,r", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + input_args = [ + Int32(cute.arch.make_warp_uniform(smem_desc_start_a)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ] + if const_expr(mbar_ptr is not None): + assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" + llvm.inline_asm( + None, + # [ + # # acc.iterator.toint().ir_value(), + # Int32(tCrA_layout[None, None, 0].iterator.toint()).ir_value(), + # Int32(smem_desc_start_b_lo).ir_value(), + # Int32(not zero_init).ir_value(), + # ], + input_args, + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range( + 1, + num_k_tile if const_expr(mbar_ptr is None) else num_k_tile // 4 * 3, + ) + ) + + mbar_wait_str + + ( + "".join( + ( + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(num_k_tile // 4 * 3, num_k_tile) + ) + if const_expr(mbar_ptr is not None) + else "" + ) + + "}\n", + "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def declare_ptx_smem_desc( + smem_desc_start_a: Int32, # If TS, then this is the tmem start address for A + smem_desc_base_a: Optional[int], + tCrA_layout: cute.Layout, + var_name_prefix: str = "smem_desc", +) -> None: + is_ts = const_expr(smem_desc_base_a is None) + num_k_tile = cute.size(tCrA_layout.shape[2]) + smem_desc_base_a_lo, smem_desc_a_hi = None, None + if const_expr(not is_ts): + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + tCrA_layout = ( + tCrA_layout + if const_expr(not is_ts) + # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout) + # currently hard-coding the width to 16 + else cute.recast_layout(32, 16, tCrA_layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)] + smem_desc_start_a_lo = None + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a) + if const_expr(not is_ts): + llvm.inline_asm( + None, + [Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value()], + f".reg .b32 {var_name_prefix}_lo;\n\t" + f".reg .b64 {var_name_prefix}_<{num_k_tile}>;\n\t" + f"mov.b64 {var_name_prefix}_0, {{$0, {hex(smem_desc_a_hi)}}};\n\t" + + "".join( + ( + f"add.s32 {var_name_prefix}_lo, $0, {hex(offset_a[k])};\n\t" + f"mov.b64 {var_name_prefix}_{k}, {{{var_name_prefix}_lo, {hex(smem_desc_a_hi)}}};\n\t" + ) + for k in range(1, num_k_tile) + ), + "r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def declare_ptx_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp, var_name: str = "idesc") -> None: + idesc = const_expr(sm100_desc.mma_op_to_idesc(op)) + llvm.inline_asm( + None, + [], + f".reg .b32 {var_name};\n\t" # noqa + f"mov.b32 {var_name}, {hex(idesc)};\n\t", + constraints="", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_precomputed_varname( + acc_tmem_addr: Int32, + smem_desc_start_b: Int32, + # idesc: int, + smem_desc_base_b: int, + tCrB_layout: cute.Layout, + smem_var_name_prefix: str, + idesc_var_name: str, + smem_offset: int, + zero_init: bool | Boolean = False, + cta_group: int = 1, + mma_kind: str = "f16", +) -> None: + is_ts = False + num_k_tile = cute.size(tCrB_layout.shape[2]) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)] + + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + # ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + # ".reg .b64 smem_desc_b;\n\t" + f".reg .b64 smem_desc_b_<{num_k_tile}>;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + # f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $2;\n\t" + "mov.b32 smem_desc_b_lo_start, $0;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_0;\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + f"mov.b64 {smem_var_name_prefix}_0, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b_0, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + + "".join( + ( + f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b_{k}, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + ) + for k in range(1, num_k_tile) + ) + + "setp.ne.b32 p, $1, 0;\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t" + + "".join( + ( + # f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t" + # f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + # f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + # f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + # f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t" + ) + for k in range(1, num_k_tile) + ) + + "}\n", + "r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/block_info.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/block_info.py new file mode 100644 index 0000000000000000000000000000000000000000..463290ab3b022a8883e7d40b84ff1ab31827e5dc --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/block_info.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Tuple +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + +from ...src.common.seqlen_info import SeqlenInfoQK + + +@dataclass(frozen=True) +class BlockInfo: + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] + is_causal: cutlass.Constexpr[bool] + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + + @cute.jit + def get_n_block_min_max( + self, + seqlen_info: SeqlenInfoQK, + m_block: Int32, + split_idx: Int32 = 0, + num_splits: Int32 = 1, + ) -> Tuple[Int32, Int32]: + n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) + if const_expr(self.is_causal): + m_idx_max = (m_block + 1) * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_block_max = min(n_block_max, cute.ceil_div(n_idx, self.tile_n)) + n_block_min = 0 + if num_splits > 1: + num_n_blocks_per_split = ( + Int32(0) + if n_block_max <= n_block_min + else (n_block_max - n_block_min + num_splits - 1) // num_splits + ) + n_block_min = n_block_min + split_idx * num_n_blocks_per_split + n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max) + return n_block_min, n_block_max + + @cute.jit + def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]: + m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m) + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_block_max = cute.ceil_div( + seqlen_info.seqlen_q * self.qhead_per_kvhead_packgqa, self.tile_m + ) + m_block_min = 0 + if const_expr(self.is_causal): + n_idx_min = n_block * self.tile_n + m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx *= self.qhead_per_kvhead_packgqa + m_block_min = cutlass.max(m_block_min, m_idx // self.tile_m) + return m_block_min, m_block_max diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/copy_utils.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/copy_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..98ba5f40b7b9543744e663a96bcdf637c7e2a146 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/copy_utils.py @@ -0,0 +1,1179 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Copy, store, and layout execution helpers. + +`copy_utils.py` is the canonical owner for generic copy primitives, async +bulk copy orchestration, TMA copy adapters, and non-TMA store/layout helpers. +""" + +import math +from typing import Optional, Type, Callable + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm +import cutlass.pipeline + + +# Generic Copy Primitives + +@dsl_user_op +def cvt_copy( + atom: cute.CopyAtom, + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + if const_expr(src.element_type != dst.element_type): + src_cvt = cute.make_rmem_tensor_like(src, dst.element_type, loc=loc, ip=ip) + src_cvt.store(src.load().to(dst.element_type)) + src = src_cvt + cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +@dsl_user_op +def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip) + cute.autovec_copy(src, dst, loc=loc, ip=ip) + return dst + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None +) -> cute.CopyAtom: + num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@dsl_user_op +def make_tmem_copy( + tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None +) -> cute.CopyAtom: + num_dp, num_bits, num_rep, _ = sm100_utils.get_tmem_copy_properties(tmem_copy_atom) + assert num_dp == 32 + assert num_bits == 32 + tiler_mn = (cute.make_layout((128 * num_rep * num_wg // 32, 32), stride=(32, 1)),) + layout_tv = cute.make_layout( + ((32, 4, num_wg), (num_rep, 32)), stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg)) + ) + return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + num_copy_elems: int = 1, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +def tiled_copy_1d( + dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + thr_layout = cute.make_layout(num_threads) + val_layout = cute.make_layout(num_copy_elems) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def tiled_copy_2d( + dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width + copy_elems = num_copy_bits // dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + gmem_threads_per_row = major_mode_size // copy_elems + assert num_threads % gmem_threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +@dsl_user_op +def atomic_add_fp32x4( + a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None +) -> None: + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x12F0000000000000) + llvm.inline_asm( + None, + [ + gmem_ptr_i64, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], + "{\n\t" + # ".reg .b128 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + ".reg .v4 .f32 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + "mov.f32 abcd.x, $1;\n\t" + "mov.f32 abcd.y, $2;\n\t" + "mov.f32 abcd.z, $3;\n\t" + "mov.f32 abcd.w, $4;\n\t" + "red.global.add.v4.f32 [$0], abcd;\n\t" + # "red.global.add.L2::cache_hint.v4.f32 [$0], abcd, 0x14F0000000000000;\n\t" + "}\n", + # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;", + # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;", + "l,f,f,f,f", + # "l,f,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +# Store/Layout Helpers + +@dsl_user_op +def atomic_add_i32(gmem_ptr, *, loc=None, ip=None): + """Simple atomicAdd. Intended for use under a single-thread guard.""" + result = llvm.inline_asm( + T.i32(), + [gmem_ptr.toint().ir_value(loc=loc, ip=ip)], + "atom.global.add.u32 $0, [$1], 1;\n", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + return Int32(result) + + +@dsl_user_op +def atomic_add_broadcast_i32(gmem_ptr, *, loc=None, ip=None): + """Lane-0 atomicAdd broadcast to the whole warp via shfl.""" + result = llvm.inline_asm( + T.i32(), + [gmem_ptr.toint().ir_value(loc=loc, ip=ip)], + "{\n" + ".reg .pred p;\n" + ".reg .u32 lane, r;\n" + "mov.u32 lane, %laneid;\n" + "mov.u32 r, 0;\n" + "setp.eq.u32 p, lane, 0;\n" + "@p atom.global.add.u32 r, [$1], 1;\n" + "shfl.sync.idx.b32 r, r, 0, 31, 0xffffffff;\n" + "mov.u32 $0, r;\n" + "}\n", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + return Int32(result) + + +@dsl_user_op +def stg_128( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "st.global.v4.f32 [$4], {$5, $6, $7, $8}; " + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "st.global.cs.v4.f32 [$4], {$5, $6, $7, $8}; " + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_64_bf16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3;\n" + ".reg .b32 p0, p1;\n" + "cvt.rn.bf16.f32 h0, $5;\n" + "cvt.rn.bf16.f32 h1, $6;\n" + "cvt.rn.bf16.f32 h2, $7;\n" + "cvt.rn.bf16.f32 h3, $8;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "st.global.v2.b32 [$4], {p0, p1};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_64_f16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .f16 h0, h1, h2, h3;\n" + ".reg .b32 p0, p1;\n" + "cvt.rn.f16.f32 h0, $5;\n" + "cvt.rn.f16.f32 h1, $6;\n" + "cvt.rn.f16.f32 h2, $7;\n" + "cvt.rn.f16.f32 h3, $8;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "st.global.v2.b32 [$4], {p0, p1};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_32_fp8_e4m3( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1;\n" + ".reg .b32 p0;\n" + "cvt.rn.satfinite.e4m3x2.f32 h0, $6, $5;\n" + "cvt.rn.satfinite.e4m3x2.f32 h1, $8, $7;\n" + "mov.b32 p0, {h0, h1};\n" + "st.global.b32 [$4], p0;\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def sts_32_bf16( + smem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + *, + loc=None, + ip=None, +): + """Store two bf16 values to shared memory as one 32-bit transaction.""" + llvm.inline_asm( + None, + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa;\n" + ".reg .b16 h0, h1;\n" + ".reg .b32 p0;\n" + "cvt.u32.u64 sa, $0;\n" + "cvt.rn.bf16.f32 h0, $1;\n" + "cvt.rn.bf16.f32 h1, $2;\n" + "mov.b32 p0, {h0, h1};\n" + "st.shared.b32 [sa], p0;\n" + "}\n", + "l,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def sts_32_f16( + smem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + *, + loc=None, + ip=None, +): + """Store two fp16 values to shared memory as one 32-bit transaction.""" + llvm.inline_asm( + None, + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa;\n" + ".reg .f16 h0, h1;\n" + ".reg .b32 p0;\n" + "cvt.u32.u64 sa, $0;\n" + "cvt.rn.f16.f32 h0, $1;\n" + "cvt.rn.f16.f32 h1, $2;\n" + "mov.b32 p0, {h0, h1};\n" + "st.shared.b32 [sa], p0;\n" + "}\n", + "l,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_bf16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.bf16.f32 h0, $9;\n" + "cvt.rn.bf16.f32 h1, $10;\n" + "cvt.rn.bf16.f32 h2, $11;\n" + "cvt.rn.bf16.f32 h3, $12;\n" + "cvt.rn.bf16.f32 h4, $13;\n" + "cvt.rn.bf16.f32 h5, $14;\n" + "cvt.rn.bf16.f32 h6, $15;\n" + "cvt.rn.bf16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_bf16_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.bf16.f32 h0, $9;\n" + "cvt.rn.bf16.f32 h1, $10;\n" + "cvt.rn.bf16.f32 h2, $11;\n" + "cvt.rn.bf16.f32 h3, $12;\n" + "cvt.rn.bf16.f32 h4, $13;\n" + "cvt.rn.bf16.f32 h5, $14;\n" + "cvt.rn.bf16.f32 h6, $15;\n" + "cvt.rn.bf16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.cs.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_f16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .f16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.f16.f32 h0, $9;\n" + "cvt.rn.f16.f32 h1, $10;\n" + "cvt.rn.f16.f32 h2, $11;\n" + "cvt.rn.f16.f32 h3, $12;\n" + "cvt.rn.f16.f32 h4, $13;\n" + "cvt.rn.f16.f32 h5, $14;\n" + "cvt.rn.f16.f32 h6, $15;\n" + "cvt.rn.f16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_f16_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .f16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.f16.f32 h0, $9;\n" + "cvt.rn.f16.f32 h1, $10;\n" + "cvt.rn.f16.f32 h2, $11;\n" + "cvt.rn.f16.f32 h3, $12;\n" + "cvt.rn.f16.f32 h4, $13;\n" + "cvt.rn.f16.f32 h5, $14;\n" + "cvt.rn.f16.f32 h6, $15;\n" + "cvt.rn.f16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.cs.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_fp8_e4m3_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + v8: Float32, + v9: Float32, + v10: Float32, + v11: Float32, + v12: Float32, + v13: Float32, + v14: Float32, + v15: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [ + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + ] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + Float32(v8).ir_value(loc=loc, ip=ip), + Float32(v9).ir_value(loc=loc, ip=ip), + Float32(v10).ir_value(loc=loc, ip=ip), + Float32(v11).ir_value(loc=loc, ip=ip), + Float32(v12).ir_value(loc=loc, ip=ip), + Float32(v13).ir_value(loc=loc, ip=ip), + Float32(v14).ir_value(loc=loc, ip=ip), + Float32(v15).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.satfinite.e4m3x2.f32 h0, $18, $17;\n" + "cvt.rn.satfinite.e4m3x2.f32 h1, $20, $19;\n" + "cvt.rn.satfinite.e4m3x2.f32 h2, $22, $21;\n" + "cvt.rn.satfinite.e4m3x2.f32 h3, $24, $23;\n" + "cvt.rn.satfinite.e4m3x2.f32 h4, $26, $25;\n" + "cvt.rn.satfinite.e4m3x2.f32 h5, $28, $27;\n" + "cvt.rn.satfinite.e4m3x2.f32 h6, $30, $29;\n" + "cvt.rn.satfinite.e4m3x2.f32 h7, $32, $31;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.cs.v4.b32 [$16], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000; " + "mov.f32 $8, 0f00000000; mov.f32 $9, 0f00000000; " + "mov.f32 $10, 0f00000000; mov.f32 $11, 0f00000000; " + "mov.f32 $12, 0f00000000; mov.f32 $13, 0f00000000; " + "mov.f32 $14, 0f00000000; mov.f32 $15, 0f00000000;", + ( + "=f,=f,=f,=f,=f,=f,=f,=f," + "=f,=f,=f,=f,=f,=f,=f,=f," + "l,f,f,f,f,f,f,f,f,f,f,f,f,f,f,f,f" + ), + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +def convert_layout_from_tmem16x256b_to_acc_sm90(acc_layout: cute.Layout) -> cute.Layout: + acc_layout_col_major = cute.make_layout(acc_layout.shape) + acc_layout_mn = cute.make_layout( + ( + acc_layout_col_major.shape[0][0], + acc_layout_col_major.shape[0][1], + acc_layout_col_major.shape[1], + *acc_layout_col_major.shape[2:], + ), + stride=( + acc_layout_col_major.stride[0][0], + acc_layout_col_major.stride[0][1], + acc_layout_col_major.stride[1], + *acc_layout_col_major.stride[2:], + ), + ) + return cute.composition(acc_layout, acc_layout_mn) + + +def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout: + acc_layout_col_major = cute.make_layout(acc_layout.shape) + acc_layout_mn = cute.make_layout( + ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), + ( + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), + *acc_layout_col_major.shape[3:], + ), + stride=( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), + *acc_layout_col_major.stride[3:], + ), + ) + return cute.composition(acc_layout, acc_layout_mn) + + +def make_16x256b_tensor_mn_view(tensor: cute.Tensor) -> cute.Tensor: + layout = convert_layout_acc_mn( + convert_layout_from_tmem16x256b_to_acc_sm90(tensor.layout) + ) + return cute.make_tensor(tensor.iterator, layout) + + +def real_col_to_stg128_fake_col(col: Int32) -> Int32: + nt = col // Int32(16) + col16 = col - nt * Int32(16) + pair = col16 // Int32(2) + rank = pair % Int32(4) + kv = (pair // Int32(4)) * Int32(2) + (col16 % Int32(2)) + return nt * Int32(16) + rank * Int32(4) + kv + + +def stg128_fake_col_to_real_col(fake_col: Int32) -> Int32: + nt = fake_col // Int32(16) + fake16 = fake_col - nt * Int32(16) + rank = fake16 // Int32(4) + kv = fake16 % Int32(4) + return nt * Int32(16) + rank * Int32(2) + (kv // Int32(2)) * Int32(8) + (kv % Int32(2)) + + +def real_col_to_stg128_half_fake_col(col: Int32) -> Int32: + nt = col // Int32(32) + col32 = col - nt * Int32(32) + lane = (col32 % Int32(8)) // Int32(2) + group = col32 // Int32(8) + elem = col32 % Int32(2) + return nt * Int32(32) + lane * Int32(8) + group * Int32(2) + elem + + +def stg128_half_fake_col_to_real_col(fake_col: Int32) -> Int32: + nt = fake_col // Int32(32) + fake32 = fake_col - nt * Int32(32) + lane = fake32 // Int32(8) + lane_slot = fake32 - lane * Int32(8) + group = lane_slot // Int32(2) + elem = lane_slot - group * Int32(2) + return nt * Int32(32) + group * Int32(8) + lane * Int32(2) + elem + + +def real_col_to_stg128_fp8_fake_col(col: Int32) -> Int32: + nt = col // Int32(64) + col64 = col - nt * Int32(64) + lane = (col64 % Int32(8)) // Int32(2) + group = col64 // Int32(8) + elem = col64 % Int32(2) + return nt * Int32(64) + lane * Int32(16) + group * Int32(2) + elem + + +def stg128_fp8_fake_col_to_real_col(fake_col: Int32) -> Int32: + nt = fake_col // Int32(64) + fake64 = fake_col - nt * Int32(64) + lane = fake64 // Int32(16) + lane_slot = fake64 - lane * Int32(16) + group = lane_slot // Int32(2) + elem = lane_slot - group * Int32(2) + return nt * Int32(64) + group * Int32(8) + lane * Int32(2) + elem + + +# Cluster & Bulk Async Ops + +@dsl_user_op +def set_block_rank( + smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None +) -> Int32: + """Map the given smem pointer to the address at another CTA rank in the cluster.""" + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + return Int32( + llvm.inline_asm( + T.i32(), + [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def store_shared_remote_fp32x4( + a: Float32, + b: Float32, + c: Float32, + d: Float32, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: Int32, + *, + loc=None, + ip=None, +) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + llvm.inline_asm( + None, + [ + remote_smem_ptr_i32, + remote_mbar_ptr_i32, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .v4 .f32 abcd;\n\t" + "mov.f32 abcd.x, $2;\n\t" + "mov.f32 abcd.y, $3;\n\t" + "mov.f32 abcd.z, $4;\n\t" + "mov.f32 abcd.w, $5;\n\t" + "st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, [$1];\n\t" + "}\n", + "r,r,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_bulk_s2cluster( + smem_src_ptr: cute.Pointer, + smem_dst_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + size: int | Int32, + peer_cta_rank_in_cluster: Int32, + *, + loc=None, + ip=None, +): + smem_src_ptr_i32 = smem_src_ptr.toint(loc=loc, ip=ip).ir_value() + smem_dst_ptr_i32 = set_block_rank( + smem_dst_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + mbar_ptr_i32 = set_block_rank(mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [ + smem_dst_ptr_i32, + smem_src_ptr_i32, + mbar_ptr_i32, + Int32(size).ir_value(loc=loc, ip=ip), + ], + "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [$0], [$1], $3, [$2];", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_bulk_g2s( + gmem_ptr: cute.Pointer, + smem_ptr: cute.Pointer, + tma_bar_ptr: cute.Pointer, + size: int | Int32, + *, + loc=None, + ip=None, +): + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + mbar_ptr_i32 = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [gmem_ptr_i64, smem_ptr_i32, mbar_ptr_i32, Int32(size).ir_value()], + "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];", + "l,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_reduce_bulk_add_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: int | Int32, + *, + loc=None, + ip=None, +): + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()], + # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;", + # "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +def cpasync_bulk_get_copy_fn( + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + single_stage: bool = False, + **kwargs, +) -> Callable: + # src_is_smem = const_expr( + # isinstance(src_tensor.iterator, cute.Pointer) + # and src_tensor.memspace == cute.AddressSpace.smem + # ) + group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0)) + group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + src = cute.group_modes(src_tensor, 0, group_rank_src) + dst = cute.group_modes(dst_tensor, 0, group_rank_dst) + + def copy_bulk(src_idx, dst_idx, **new_kwargs): + size = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8) + cpasync_bulk_g2s( + src[None, src_idx].iterator, + dst[None, dst_idx].iterator, + size=size, + **new_kwargs, + **kwargs, + ) + + def copy_bulk_single_stage(**new_kwargs): + size = const_expr(cute.size(src.shape) * src.element_type.width // 8) + cpasync_bulk_g2s(src.iterator, dst.iterator, size=size, **new_kwargs, **kwargs) + + return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage + + +# TMA Copy Adapters + +def tma_get_copy_fn( + atom: cute.CopyAtom, + cta_coord: cute.Coord, + cta_layout: cute.Layout, + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + filter_zeros: bool = False, + single_stage: bool = False, + **kwargs, +) -> Callable: + src_is_smem = const_expr( + isinstance(src_tensor.iterator, cute.Pointer) + and src_tensor.memspace == cute.AddressSpace.smem + ) + smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) + group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0)) + group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + s, g = cpasync.tma_partition( + atom, + cta_coord, + cta_layout, + cute.group_modes(smem_tensor, 0, group_rank_smem), + cute.group_modes(gmem_tensor, 0, group_rank_gmem), + ) + if const_expr(filter_zeros): + s = cute.filter_zeros(s) + g = cute.filter_zeros(g) + src, dst = (s, g) if src_is_smem else (g, s) + + def copy_tma(src_idx, dst_idx, **new_kwargs): + cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) + + def copy_tma_single_stage(**new_kwargs): + cute.copy(atom, src, dst, **new_kwargs, **kwargs) + + return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g + + +def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): + def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs): + copy( + src_idx=src_idx, + dst_idx=producer_state.index, + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), + **new_kwargs, + ) + + return copy_fn + + +__all__ = [ + "atomic_add_broadcast_i32", + "atomic_add_fp32x4", + "atomic_add_i32", + "convert_layout_acc_mn", + "convert_layout_from_tmem16x256b_to_acc_sm90", + "copy", + "cpasync_bulk_g2s", + "cpasync_bulk_get_copy_fn", + "cpasync_bulk_s2cluster", + "cpasync_reduce_bulk_add_f32", + "cvt_copy", + "get_copy_atom", + "load_s2r", + "make_16x256b_tensor_mn_view", + "make_tmem_copy", + "real_col_to_stg128_fake_col", + "real_col_to_stg128_fp8_fake_col", + "real_col_to_stg128_half_fake_col", + "set_block_rank", + "stg128_fake_col_to_real_col", + "stg128_fp8_fake_col_to_real_col", + "stg128_half_fake_col_to_real_col", + "stg_128", + "stg_128_cs", + "stg_128_bf16", + "stg_128_bf16_cs", + "stg_128_f16", + "stg_128_f16_cs", + "stg_128_fp8_e4m3_cs", + "stg_32_fp8_e4m3", + "stg_64_bf16", + "stg_64_f16", + "sts_32_bf16", + "sts_32_f16", + "store_shared_remote_fp32x4", + "tiled_copy_1d", + "tiled_copy_2d", + "tma_get_copy_fn", + "tma_producer_copy_fn", +] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/cute_dsl_utils.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/cute_dsl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3473fbbf77fa1261abfc8fd960102c70d3e64bd --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/cute_dsl_utils.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import logging +import os +import pathlib +import time +from typing import Tuple +from functools import partial, lru_cache +from dataclasses import dataclass, fields + +import torch + +logger = logging.getLogger("minimax") + +try: + from triton.tools.disasm import extract +except ImportError: + extract = None + +import cutlass +import cutlass.cute as cute +from cutlass.base_dsl.typing import JitArgument +from cutlass.cutlass_dsl import NumericMeta +from cutlass.cute.runtime import from_dlpack + +StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) + + +load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data +cute_compile_og = cute.compile + + +torch2cute_dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, +} + + +@lru_cache +def get_max_active_clusters(cluster_size): + return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) + + +@lru_cache +def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: + return torch.cuda.get_device_capability(device) + + +@dataclass +class ArgumentsBase(JitArgument): + def __c_pointers__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + c_ptrs = [] + for obj in non_constexpr_fields: + if hasattr(obj, "__c_pointers__"): + c_ptrs.extend(obj.__c_pointers__()) + return c_ptrs + + def __get_mlir_types__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + types, self._values_pos = [], [] + for obj in non_constexpr_fields: + if hasattr(obj, "__get_mlir_types__"): + obj_types = obj.__get_mlir_types__() + types.extend(obj_types) + self._values_pos.append(len(obj_types)) + else: + self._values_pos.append(0) + return types + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +def load_cubin_module_data_patched(cubin_data, filepath): + pathlib.Path(filepath).write_bytes(cubin_data) + return load_cubin_module_data_og(cubin_data) + + +def cute_compile_patched(*args, **kwargs): + """A patched version of cute.compile. + + Behaviour: + - Dumps SASS to a file if ``CUTE_CUBIN_PATH`` is set. + - Logs JIT compile wall time at DEBUG level via the ``minimax`` logger, + tagged with the kernel's class name when available. Enable with + ``logging.getLogger("minimax").setLevel(logging.DEBUG)`` or env + ``MINIMAX_LOG_COMPILE=1``; this is how we distinguish a slow JIT + (~2-10s) from a kernel hang (>30s = deadlock, see CLAUDE.md). + """ + cubin_path = os.getenv("CUTE_CUBIN_PATH", None) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( + load_cubin_module_data_patched, filepath=cubin_path + ) + kernel_obj = args[0] if args else kwargs.get("op") + kernel_name = type(kernel_obj).__name__ if kernel_obj is not None else "" + t0 = time.time() + output = cute_compile_og(*args, **kwargs) + dt = time.time() - t0 + logger.debug("[%s] compiled in %.1fs", kernel_name, dt) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og + if extract is not None: + sass = extract(cubin_path, None) + pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) + return output + + +if os.getenv("MINIMAX_LOG_COMPILE", "0") == "1": + if not logger.handlers: + _h = logging.StreamHandler() + _h.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")) + logger.addHandler(_h) + logger.setLevel(logging.DEBUG) + + +# Monkey-patch cute.compile so every JIT compile across the repo gets timed +# without touching individual call sites. Idempotent: only patches once. +if cute.compile is not cute_compile_patched: + cute.compile = cute_compile_patched + + +def assume_strides_aligned(t): + """Assume all strides except the last are divisible by 128 bits. + + Python int strides (e.g., stride=0 from GQA expand) are kept as-is + since they're static and don't need alignment assumptions. + """ + divby = 128 // t.element_type.width + strides = tuple(s if isinstance(s, int) else cute.assume(s, divby=divby) for s in t.stride[:-1]) + return (*strides, t.stride[-1]) + + +def assume_tensor_aligned(t): + """Rebuild a tensor with 128-bit aligned stride assumptions. Passes through None.""" + if t is None: + return None + return cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=assume_strides_aligned(t))) + + +def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True): + """Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1.""" + tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi) + if fully_dynamic: + return tensor.mark_layout_dynamic() + if leading_dim == -1: + leading_dim = t.ndim - 1 + return tensor.mark_layout_dynamic(leading_dim=leading_dim) + + +def to_cute_aux_tensor(t, enable_tvm_ffi=True): + """Convert torch tensor to cute tensor for TVM FFI, tailored to FlexAttention aux tensors. + This allows the user to specify alignment and leading dimension for aux tensors used in + custom score_mod callables. + """ + assumed_align: int = getattr(t, "__assumed_align__", None) + leading_dim: int = getattr(t, "__leading_dim__", None) + fully_dynamic: bool = leading_dim is None + + return to_cute_tensor( + t, + assumed_align=assumed_align, + leading_dim=leading_dim, + fully_dynamic=fully_dynamic, + enable_tvm_ffi=enable_tvm_ffi, + ) + + +def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]: + """Return tuple of bools indicating which dims have stride=0 (broadcast). + + This is useful for compile keys since CuTe's mark_layout_dynamic() keeps + stride=0 as static, meaning kernels compiled with different broadcast + patterns are not interchangeable. + """ + return tuple(s == 0 for s in tensor.stride()) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/fast_math.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/fast_math.py new file mode 100644 index 0000000000000000000000000000000000000000..63a8b4a501ac499e372056a07d499832c830b474 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/fast_math.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 + + +@cute.jit +def clz(x: Int32) -> Int32: + # for i in cutlass.range_constexpr(32): + # if (1 << (31 - i)) & x: + # return Int32(i) + # return Int32(32) + # Early exit is not supported yet + res = Int32(32) + done = False + for i in cutlass.range(32): + if ((1 << (31 - i)) & x) and not done: + res = Int32(i) + done = True + return res diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/mask.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..0da42c3be9bf1c3dcff81ccde579b54131bfa4c6 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/mask.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Callable, Optional, TypeAlias +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, Uint32, const_expr + +from ...src.common import utils as utils +from ...src.common.seqlen_info import SeqlenInfoQK + +MaskGenFn: TypeAlias = Callable[[int], Uint32] +MASK_R2P_CHUNK_SIZE: int = 32 + + +@cute.jit +def r2p_bitmask_below(limit: Int32, s: int) -> Uint32: + m = max((s + 1) * MASK_R2P_CHUNK_SIZE - limit, 0) + return utils.shr_u32(Uint32(0xFFFFFFFF), Uint32(m)) + + +@cute.jit +def r2p_bitmask_above(limit: Int32, s: int) -> Uint32: + n = max(limit - s * MASK_R2P_CHUNK_SIZE, 0) + return utils.shl_u32(Uint32(0xFFFFFFFF), Uint32(n)) + + +@cute.jit +def mask_r2p_lambda( + X: cute.Tensor, + mask_gen_fn: cutlass.Constexpr[MaskGenFn], + rank1: bool = False, +) -> None: + ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape)) + for s in cutlass.range_constexpr(cute.ceil_div(ncol, MASK_R2P_CHUNK_SIZE)): + mask = mask_gen_fn(s) + for i in cutlass.range_constexpr(min(MASK_R2P_CHUNK_SIZE, ncol - s * MASK_R2P_CHUNK_SIZE)): + in_bound = cutlass.Boolean(mask & (Uint32(1) << i)) + c = s * MASK_R2P_CHUNK_SIZE + i + if const_expr(rank1): + X[c] = X[c] if in_bound else -Float32.inf + else: + for r in cutlass.range_constexpr(cute.size(X.shape[0])): + X[r, c] = X[r, c] if in_bound else -Float32.inf + + +@cute.jit +def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32: + return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep) + + +@dataclass(frozen=True) +class AttentionMask: + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] + seqlen_info: SeqlenInfoQK + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + swap_AB: cutlass.Constexpr[bool] = False + + @property + def seqlen_q(self) -> Int32: + return self.seqlen_info.seqlen_q + + @property + def seqlen_k(self) -> Int32: + return self.seqlen_info.seqlen_k + + @cute.jit + def apply_mask_sm100( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + m_block: Int32, + n_block: Int32, + mask_seqlen: cutlass.Constexpr[bool], + mask_causal: cutlass.Constexpr[bool], + row_idx: Optional[Int32] = None, + kv_valid_cols: Optional[Int32] = None, + kv_block_col_start: Optional[Int32] = None, + ) -> None: + if const_expr(not mask_seqlen and not mask_causal): + return + + col_limit = Int32(self.tile_n) + if const_expr(mask_seqlen): + if const_expr(kv_valid_cols is not None): + col_limit = kv_valid_cols + else: + col_limit = self.seqlen_k - n_block * Int32(self.tile_n) + + if const_expr(mask_causal): + if const_expr(row_idx is None): + row_axis = 0 if const_expr(not self.swap_AB) else 1 + row_idx_cur = tScS_t2r[0][row_axis] + m_block * Int32(self.tile_m) + if const_expr(self.qhead_per_kvhead_packgqa > 1): + row_idx_cur = row_idx_cur // Int32(self.qhead_per_kvhead_packgqa) + else: + row_idx_cur = row_idx + if const_expr(kv_block_col_start is not None): + block_col_start = kv_block_col_start + else: + block_col_start = n_block * Int32(self.tile_n) + causal_col_limit = ( + row_idx_cur + self.seqlen_k - self.seqlen_q + - block_col_start + Int32(1) + ) + col_limit = ( + cutlass.min(col_limit, causal_col_limit) + if const_expr(mask_seqlen) + else causal_col_limit + ) + + if col_limit < Int32(self.tile_n): + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_below(col_limit, s), + rank1=True, + ) + + @cute.jit + def apply_mask_sm100_transposed( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + t0ScS_t2r: cute.Tensor, + m_block: cutlass.Int32, + n_block: cutlass.Int32, + mask_seqlen: cutlass.Constexpr, + mask_causal: cutlass.Constexpr, + is_full_block: bool = False, + check_m_boundary: bool = True, + valid_tok_count: Optional[Int32] = None, + q_idx_tile: Optional[cute.Tensor] = None, + masked_tok_count: Optional[Int32] = None, + ) -> None: + del is_full_block, check_m_boundary + del t0ScS_t2r + row_axis = 0 if const_expr(not self.swap_AB) else 1 + col_axis = 1 if const_expr(not self.swap_AB) else 0 + + if const_expr(valid_tok_count is not None): + kv_block_col_start = n_block * Int32(self.tile_n) + causal_q_offset = self.seqlen_k - self.seqlen_q + nfrag = const_expr(cute.size(acc_S.shape)) + for i in cutlass.range(nfrag, unroll_full=True): + row_idx = tScS_t2r[i][row_axis] + tok_idx = row_idx // Int32(self.qhead_per_kvhead_packgqa) + acc_S[i] = -Float32.inf if tok_idx >= valid_tok_count else acc_S[i] + if const_expr(mask_seqlen): + kv_idx = kv_block_col_start + tScS_t2r[i][col_axis] + acc_S[i] = -Float32.inf if kv_idx >= self.seqlen_k else acc_S[i] + if const_expr(mask_causal): + if const_expr(q_idx_tile is not None): + causal_tok_count = ( + masked_tok_count + if const_expr(masked_tok_count is not None) + else Int32(0) + ) + if tok_idx < causal_tok_count: + q_idx = q_idx_tile[tok_idx] + kv_idx = kv_block_col_start + tScS_t2r[i][col_axis] + acc_S[i] = ( + -Float32.inf if kv_idx > q_idx + causal_q_offset else acc_S[i] + ) + return + + thr_col_offset = tScS_t2r[0][col_axis] + seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset + + if const_expr(not mask_causal): + if const_expr(mask_seqlen) and seqlenk_col_limit <= 0: + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + acc_S[i] = -cutlass.Float32.inf + return + + thr_row_offset = tScS_t2r[0][row_axis] + seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset + row_limit_top = seqlenq_row_limit - seqlenk_col_limit + if const_expr(mask_seqlen) and seqlenk_col_limit <= 0: + row_limit_top = self.tile_m + num_rep = cute.size(tScS_t2r, mode=[0]) + row_limit = row_to_r2p_idx(row_limit_top, num_rep, 2) + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_above(row_limit, s), + rank1=True, + ) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/mma_sm100_desc.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/mma_sm100_desc.py new file mode 100644 index 0000000000000000000000000000000000000000..53c58d17f5085d207f2a1d7b6b45d627ff3322e3 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/mma_sm100_desc.py @@ -0,0 +1,304 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT +# +# The bit-field encodings, enum values, and descriptor layout below mirror the +# SM100 tcgen05 MMA instruction descriptor as documented and +# implemented in NVIDIA CUTLASS (BSD-3-Clause). The numeric values MUST stay +# identical to the hardware/ISA encodings; see the "Third-party licenses" +# section of README.md at the repo root for attribution. + +from enum import IntEnum + +import cutlass +import cutlass.cute as cute + +# --------------------------------------------------------------------------- +# Enumerations that match the HW encodings (values MUST stay identical) +# --------------------------------------------------------------------------- + + +class Major(IntEnum): # matrix "layout" in the ISA docs + K = 0 + MN = 1 + + +class ScaleIn(IntEnum): # negate flags + One = 0 + Neg = 1 + + +class Saturate(IntEnum): + False_ = 0 + True_ = 1 + + +class CFormat(IntEnum): # 2-bit field (bits 4-5) + F16 = 0 + F32 = 1 + S32 = 2 + + +class F16F32Format(IntEnum): # 3-bit field (A/B element type) + F16 = 0 + BF16 = 1 + TF32 = 2 + + +class S8Format(IntEnum): + UINT8 = 0 + INT8 = 1 + + +class MXF8F6F4Format(IntEnum): + E4M3 = 0 + E5M2 = 1 + E2M3 = 3 + E3M2 = 4 + E2M1 = 5 + + +class MaxShift(IntEnum): + NoShift = 0 + MaxShift8 = 1 + MaxShift16 = 2 + MaxShift32 = 3 + + +# --------------------------------------------------------------------------- +# CUTLASS-type -> encoding helpers +# --------------------------------------------------------------------------- + + +def to_UMMA_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B. + """ + if cutlass_type is cutlass.Int8: + return S8Format.INT8 + # Unsigned 8-bit (if available in your CUTLASS build) + if cutlass_type is cutlass.Uint8: + return S8Format.UINT8 + # FP-16 / BF-16 + if cutlass_type is cutlass.Float16: + return F16F32Format.F16 + if cutlass_type is cutlass.BFloat16: + return F16F32Format.BF16 + # TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits) + if cutlass_type is cutlass.TFloat32: + return F16F32Format.TF32 + # Float-8 / Float-6 / Float-4 + if cutlass_type is cutlass.Float8E4M3FN: + return MXF8F6F4Format.E4M3 + if cutlass_type is cutlass.Float8E5M2: + return MXF8F6F4Format.E5M2 + raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}") + + +def to_C_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 2-bit accumulator encoding. + """ + if cutlass_type is cutlass.Float16: + return CFormat.F16 + if cutlass_type is cutlass.Float32: + return CFormat.F32 + if cutlass_type is cutlass.Int32: + return CFormat.S32 + raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}") + + +# --------------------------------------------------------------------------- +# The constructor – accepts only CUTLASS scalar classes +# --------------------------------------------------------------------------- + + +def make_instr_desc( + a_type, # CUTLASS scalar class, e.g. cutlass.Int8 + b_type, + c_type, + M: int, # 64, 128 or 256 + N: int, # 8 … 256 (multiple of 8) + a_major: Major, + b_major: Major, + a_neg: ScaleIn = ScaleIn.One, + b_neg: ScaleIn = ScaleIn.One, + c_sat: Saturate = Saturate.False_, + is_sparse: bool = False, + max_shift: MaxShift = MaxShift.NoShift, +) -> int: + """ + Build the 32-bit instruction descriptor for SM100 MMA. + All matrix/accumulator **types must be CUTLASS scalar classes** – + passing integers is forbidden. + """ + # --- encode element formats ------------------------------------------------- + a_fmt = int(to_UMMA_format(a_type)) + b_fmt = int(to_UMMA_format(b_type)) + c_fmt = int(to_C_format(c_type)) + is_f8f6f4 = a_type in (cutlass.Float8E4M3FN, cutlass.Float8E5M2) + + # --- range checks on M/N ----------------------------------------------------- + if M not in (64, 128, 256): + raise ValueError("M must be 64, 128 or 256") + if N < 8 or N > 256 or (N & 7): + raise ValueError("N must be a multiple of 8 in the range 8…256") + + m_dim = M >> 4 # 5-bit field + n_dim = N >> 3 # 6-bit field + + # fmt: off + # --- pack the bit-fields ----------------------------------------------------- + desc = 0 + desc |= (0 & 0x3) << 0 # sparse_id2 (always 0 here) + desc |= (int(is_sparse) & 0x1) << 2 # sparse_flag + desc |= (int(c_sat) & 0x1) << 3 # saturate + desc |= (c_fmt & 0x3) << 4 # c_format + desc |= (a_fmt & 0x7) << 7 # a_format + desc |= (b_fmt & 0x7) << 10 # b_format + desc |= (int(a_neg) & 0x1) << 13 # a_negate + desc |= (int(b_neg) & 0x1) << 14 # b_negate + desc |= (int(a_major) & 0x1) << 15 # a_major + desc |= (int(b_major) & 0x1) << 16 # b_major + desc |= (n_dim & 0x3F) << 17 # n_dim (6 bits) + # CUTLASS' tcgen05 lowering sets bit 23 for dense f8f6f4 MMAs; keep this + # descriptor aligned with generated/reference SM100 FP8 kernels. + desc |= (int(is_f8f6f4) & 0x1) << 23 + desc |= (m_dim & 0x1F) << 24 # m_dim (5 bits) + desc |= (int(max_shift) & 0x3) << 30 # max_shift (2 bits) + # fmt: on + + return desc & 0xFFFF_FFFF # ensure 32-bit result + + +def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp): + return make_instr_desc( + op.a_dtype, + op.b_dtype, + op.acc_dtype, + op.shape_mnk[0], + op.shape_mnk[1], + Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + ) + + +class LayoutType(IntEnum): # occupies the top-3 bits [61:64) + SWIZZLE_NONE = 0 # (a.k.a. "INTERLEAVE" in older docs) + SWIZZLE_128B_BASE32B = 1 + SWIZZLE_128B = 2 + SWIZZLE_64B = 4 + SWIZZLE_32B = 6 + # values 3,5,7 are reserved / illegal for UMMA + + +# --------------------------------------------------------------------------- +# Helpers – figure out the SWIZZLE_* family from the tensor layout +# --------------------------------------------------------------------------- + + +def _layout_type(swizzle: cute.Swizzle) -> LayoutType: + B, M, S = swizzle.num_bits, swizzle.num_base, swizzle.num_shift + + if M == 4: # Swizzle<*,4,3> + if S != 3: + raise ValueError("Unexpected swizzle shift – want S==3 for M==4") + return { + 0: LayoutType.SWIZZLE_NONE, + 1: LayoutType.SWIZZLE_32B, + 2: LayoutType.SWIZZLE_64B, + 3: LayoutType.SWIZZLE_128B, + }[B] # KeyError ⇒ invalid B→ raise + if M == 5: # Swizzle<2,5,2> (the only legal triple for M==5) + if (B, S) != (2, 2): + raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B") + return LayoutType.SWIZZLE_128B_BASE32B + + # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout + raise ValueError("Unsupported swizzle triple for UMMA smem descriptor") + + +def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int: + """ + Convert a 2-D *shared-memory* Cute layout into the SM100 64-bit + smem-descriptor, without the smem start address. + layout must correspond to layout of an uint128 tensor. + """ + # ------------------------------------------------------------------ meta + layout_type = _layout_type(swizzle) # resolve SWIZZLE_* family + + VERSION = 1 # bits 46–47 + LBO_MODE = 0 # bit 52 + BASE_OFFSET = 0 # bits 49–51 (CUTLASS always 0) + + # ---------------------------------------------------------- strides (units: uint128_t = 16 B) + swizzle_atom_mn_size = { + LayoutType.SWIZZLE_NONE: 1, + LayoutType.SWIZZLE_32B: 2, + LayoutType.SWIZZLE_64B: 4, + LayoutType.SWIZZLE_128B: 8, + LayoutType.SWIZZLE_128B_BASE32B: 8, + }[layout_type] + + if major is Major.MN: + swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8 + canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if stride_10 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1] + if layout_type is LayoutType.SWIZZLE_NONE: + stride_byte_offset, leading_byte_offset = stride_01, stride_11 + else: + stride_byte_offset, leading_byte_offset = stride_11, stride_01 + else: + if layout_type == LayoutType.SWIZZLE_128B_BASE32B: + raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K") + if not cute.size(layout.shape[0]) % 8 == 0: + raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.") + canonical_layout = cute.logical_divide(layout, (8, 2)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if stride_00 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_01 = canonical_layout.stride[0][1] + stride_byte_offset, leading_byte_offset = stride_01, stride_10 + + # ------------------------------------------------------------------ pack + desc = 0 + # leading_byte_offset_ [16:30) + desc |= (leading_byte_offset & 0x3FFF) << 16 + # stride_byte_offset_ [32:46) + desc |= (stride_byte_offset & 0x3FFF) << 32 + # version_ [46:48) + desc |= (VERSION & 0x3) << 46 + # base_offset_ [49:52) + desc |= (BASE_OFFSET & 0x7) << 49 + # lbo_mode_ [52:53) + desc |= (LBO_MODE & 0x1) << 52 + # layout_type_ [61:64) + desc |= (int(layout_type) & 0x7) << 61 + + return desc & 0xFFFF_FFFF_FFFF_FFFF # force 64-bit width + + +def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32: + # 14 bits, remove 4 LSB (bits 0-13 in desc) + return (start_addr.toint() & 0x3FFFF) >> 4 + + +def smem_desc_base_from_tensor(sA: cute.Tensor, major: Major) -> int: + sA_swizzle = sA.iterator.type.swizzle_type + return make_smem_desc_base( + cute.recast_layout(128, sA.element_type.width, sA.layout[0]), + sA_swizzle, + major, + ) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/named_barrier.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/named_barrier.py new file mode 100644 index 0000000000000000000000000000000000000000..a7722a471ca011a94d5fd7774224906001979b78 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/named_barrier.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import enum + + +class NamedBarrierFwdSm100(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() + TmemPtr = enum.auto() + SoftmaxStatsW0 = enum.auto() + SoftmaxStatsW1 = enum.auto() + SoftmaxStatsW2 = enum.auto() + SoftmaxStatsW3 = enum.auto() + SoftmaxStatsW4 = enum.auto() + SoftmaxStatsW5 = enum.auto() + SoftmaxStatsW6 = enum.auto() + SoftmaxStatsW7 = enum.auto() + LoadWG = enum.auto() + StoreEpilogue = enum.auto() + KvLoad = enum.auto() + KvDequantK = enum.auto() + KvDequantV = enum.auto() diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/pack_gqa.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/pack_gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5dc25edd3f48fbe2c77ec94c8ab3f1ea417507 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/pack_gqa.py @@ -0,0 +1,320 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""PackGQA primitives for GQA (grouped-query attention) tile layouts. + +Contains: +- ``pack_gqa_layout`` / ``unpack_gqa_layout``: fold/unfold ``qhead_per_kvhead`` + into the seqlen dimension of a tensor layout (zero-copy view). +- ``PackGQA``: base class with ``compute_ptr`` / ``load_Q`` / ``store_LSE`` / + ``store_O`` helpers for kernels that treat ``(qhead_per_kvhead × seqlen_q)`` + as a single packed row dimension. +- ``PackGQAComb``: subclass used by the K2 combine kernel; adds ``load_LSE`` + for coalesced GMEM→SMEM async copies when LSE_partial is laid out with H_q + innermost (stride-1). +""" + +from dataclasses import dataclass +from typing import Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute import FastDivmodDivisor + +from ...quack import layout_utils + +from . import utils + + +def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): + """Reshape a tensor to fold qhead_per_kvhead into the seqlen dimension (mode 0). + + The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) + are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept + as-is (e.g. batch). + + For Q/O tensors (head_idx=2): + (seqlen_q, headdim, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) + For LSE tensors (head_idx=1): + (seqlen_q, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) + """ + head_stride = T.stride[head_idx] + shape_packed = ( + (qhead_per_kvhead, T.shape[0]), + *[T.shape[i] for i in range(1, head_idx)], + nheads_kv, + *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], + ) + stride_packed = ( + (head_stride, T.stride[0]), + *[T.stride[i] for i in range(1, head_idx)], + head_stride * qhead_per_kvhead, + *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], + ) + return cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed)) + + +def unpack_gqa_layout(T, qhead_per_kvhead, head_idx): + """Reverse of pack_gqa_layout: unfold qhead_per_kvhead from the seqlen dimension (mode 0). + + The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) + are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept + as-is (e.g. batch). + + For Q/O tensors (head_idx=2): + ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) -> (seqlen_q, headdim, nheads, batch, ...) + For LSE tensors (head_idx=1): + ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) -> (seqlen_q, nheads, batch, ...) + """ + seqlen_stride = T.stride[0][1] + head_stride = T.stride[0][0] + shape_unpacked = ( + T.shape[0][1], + *[T.shape[i] for i in range(1, head_idx)], + T.shape[head_idx] * qhead_per_kvhead, + *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], + ) + stride_unpacked = ( + seqlen_stride, + *[T.stride[i] for i in range(1, head_idx)], + head_stride, + *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], + ) + return cute.make_tensor(T.iterator, cute.make_layout(shape_unpacked, stride=stride_unpacked)) + + +@dataclass +class PackGQA: + m_block_size: cutlass.Constexpr[int] + head_dim_padded: cutlass.Constexpr[int] + check_hdim_oob: cutlass.Constexpr[bool] + qhead_per_kvhead: cutlass.Constexpr[bool] + + @cute.jit + def compute_ptr( + self, + tensor: cute.Tensor, + cRows: cute.Tensor, + tidx: cutlass.Int32, + block: cutlass.Int32, + threads_per_row: cutlass.Constexpr[int], + num_threads: cutlass.Constexpr[int], + ): + num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row) + tPrPtr = cute.make_rmem_tensor(num_ptr_per_thread, cutlass.Int64) + for i in cutlass.range_constexpr(num_ptr_per_thread): + row = i * num_threads + cRows[tidx % threads_per_row][0] + idx = block * self.m_block_size + row + m_idx = idx // self.qhead_per_kvhead + h_idx = idx - m_idx * self.qhead_per_kvhead + tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint() + return tPrPtr + + @cute.jit + def load_Q( + self, + mQ: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + sQ: cute.Tensor, # (m_block_size, head_dim_padded) + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tQsQ = gmem_thr_copy.partition_D(sQ) + tQcQ = gmem_thr_copy.partition_S(cQ) + t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) + tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1]) + tQcQ_row = tQcQ[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + q_ptr_i64 = utils.shuffle_sync( + tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row + ) + q_gmem_ptr = cute.make_ptr( + mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0QcQ[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0] + ): + mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tQsQ.shape[0][0]) + mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])): + ki = tQcQ[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + mQ_cur_copy[None, ki], + tQsQ[None, m, k], + pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + + @cute.jit + def store_LSE( + self, + mLSE: cute.Tensor, # (qhead_per_kvhead, seqlen_q) + tLSErLSE: cute.Tensor, # (m_block_size, head_dim_padded) + tiled_mma: cute.TiledMma, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + thr_mma = tiled_mma.get_slice(tidx) + caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + taccOcO = thr_mma.partition_C(caccO) + taccOcO_row = layout_utils.reshape_acc_to_mn(taccOcO)[None, 0] + assert cute.size(tLSErLSE) == cute.size(taccOcO_row) + threads_per_row = tiled_mma.tv_layout_C.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + assert cute.size(tLSErLSE) <= threads_per_row + num_threads = tiled_mma.size + tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tLSErLSE)): + lse_ptr_i64 = utils.shuffle_sync( + tPrLSEPtr[m // threads_per_row], + m % threads_per_row, + width=threads_per_row, + ) + lse_gmem_ptr = cute.make_ptr( + mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4 + ) + row = block * self.m_block_size + taccOcO_row[m][0] + # Only the thread corresponding to column 0 writes out the lse to gmem + if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead: + mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,)) + mLSE_copy[0] = tLSErLSE[m] + + @cute.jit + def store_O( + self, + mO: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + tOrO: cute.Tensor, # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tOcO = gmem_thr_copy.partition_S(cO) + t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + tOcO_row = tOcO[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + o_ptr_i64 = utils.shuffle_sync( + tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row + ) + o_gmem_ptr = cute.make_ptr( + mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0OcO[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0] + ): + mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tOrO.shape[0][0]) + mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])): + ki = tOcO[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + tOrO[None, m, k], + mO_cur_copy[None, ki], + pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + + +@dataclass +class PackGQAComb(PackGQA): + """PackGQA subclass for the K2 combine kernel. + + Inherits ``compute_ptr`` / ``load_Q`` / ``store_LSE`` / ``store_O`` from + ``PackGQA``. Adds ``load_LSE`` for coalesced GMEM→SMEM async copies when + LSE_partial is laid out with H_q innermost. + + K2 combine treats each query head independently (no GQA grouping in combine + itself), so ``qhead_per_kvhead`` is set to ``num_heads_q`` by the caller — + all heads are folded into one "group" per Sq position. + """ + + @cute.jit + def load_LSE( + self, + mLSE_partial: cute.Tensor, + # Packed layout after caller-side reshape: + # shape ((qhead_per_kvhead, seqlen_q), num_splits) + # stride ((1, qhead_per_kvhead), ...) + # — H_q is the innermost (stride-1) element of the packed first dim. + sLSE: cute.Tensor, + # SMEM destination: ``(topk, m_block_size)`` fp32. + topk: cutlass.Constexpr[int], + # Explicit topk so the identity tensor shape is a plain int, + # avoiding compound-shape traps from sLSE.shape[0] after tile_to_shape. + gmem_tiled_copy: cute.TiledCopy, + tidx: Int32, + block: Int32, + num_splits: Int32, + seqlen: Int32, + num_heads_divmod: FastDivmodDivisor, + mCounter: Optional[cute.Tensor] = None, + batch_idx: Optional[Int32] = None, + qhead_per_kvhead: Int32 = Int32(1), + # divmod for ``m_pos = idx // qhead_per_kvhead``; passed explicitly so + # caller controls whether the divisor is constexpr or a runtime value. + ): + """Coalesced GMEM→SMEM async load of LSE_partial for one tile. + + For each (split, row) slot this thread owns in the tile, compute the + GMEM coordinate ``(h_pos, m_pos)`` via PackGQA divmod and copy one fp32. + Out-of-bounds rows (``m_pos >= seqlen``) and splits (``si >= num_splits``) + are filled with ``-inf`` so they flow cleanly through downstream reductions. + + Coalescing: adjacent thread rows correspond to adjacent ``h_pos`` values + (head varies fast under ``divmod(idx, qhead_per_kvhead)``), which map to + adjacent GMEM addresses when H_q is stride-1 — one sector per warp. + """ + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cLSE = cute.make_identity_tensor((topk, self.m_block_size)) + tLSEcLSE = gmem_thr_copy.partition_S(cLSE) + tLSEsLSE = gmem_thr_copy.partition_D(sLSE) + + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] + idx = block * self.m_block_size + mi + m_pos, h_pos = divmod(idx, num_heads_divmod) + + if m_pos < seqlen: + row_count = ( + mCounter[batch_idx, m_pos, h_pos // qhead_per_kvhead] + if const_expr(mCounter is not None) + else num_splits + ) + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] + if si < num_splits and si < row_count: + # Build a 1-element GMEM tensor at ((h_pos, m_pos), si), + # matching PackGQA.store_LSE's ptr pattern so cute.copy + # receives a proper Tensor, not a scalar. + src_ptr_i64 = utils.elem_pointer( + mLSE_partial, ((h_pos, m_pos), si)).toint() + src_ptr = cute.make_ptr( + Float32, src_ptr_i64, + cute.AddressSpace.gmem, assumed_align=4, + ) + src_t = cute.make_tensor(src_ptr, (1,)) + cute.copy(gmem_thr_copy, src_t, tLSEsLSE[None, s, m]) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + else: + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + tLSEsLSE[None, s, m].fill(-Float32.inf) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/paged_kv.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/paged_kv.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5f6923c42a826d4f3dd1f192ce2fdb38eefbf5 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/paged_kv.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + + +@dataclass(frozen=True) +class PagedKVManager: + mPageTable: cute.Tensor + page_size: cutlass.Constexpr[int] + n_block_size: cutlass.Constexpr[int] + + @staticmethod + def create( + mPageTable: cute.Tensor, + *, + page_size: int, + n_block_size: int, + ): + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal blk_kv ({n_block_size})" + ) + return PagedKVManager( + mPageTable, + page_size=page_size, + n_block_size=n_block_size, + ) + + @cute.jit + def logical_length( + self, + batch_idx: Int32, + num_kv_blocks: Int32, + mSeqUsedK=None, + ) -> Int32: + if const_expr(mSeqUsedK is not None): + return mSeqUsedK[batch_idx] + return num_kv_blocks * Int32(self.n_block_size) + + @cute.jit + def valid_cols_in_block( + self, + batch_idx: Int32, + kv_block_idx: Int32, + num_kv_blocks: Int32, + mSeqUsedK=None, + ) -> Int32: + seqlen_k = self.logical_length(batch_idx, num_kv_blocks, mSeqUsedK) + block_start = kv_block_idx * Int32(self.n_block_size) + remaining = seqlen_k - block_start + remaining = cutlass.max(remaining, Int32(0)) + return cutlass.min(remaining, Int32(self.n_block_size)) + + @cute.jit + def physical_block_index( + self, + batch_idx: Int32, + kv_block_idx: Int32, + ) -> Int32: + return self.mPageTable[batch_idx, kv_block_idx] + +__all__ = ["PagedKVManager"] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/pipeline.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..27f711772f5c6fa16a86f4aa305f42a0ca9322eb --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/pipeline.py @@ -0,0 +1,372 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +# import math +from typing import Optional +from dataclasses import dataclass + +import cutlass.cute as cute +from cutlass import Boolean, Int32, const_expr +from cutlass.cutlass_dsl import if_generate, dsl_user_op +from cutlass.pipeline import PipelineState +from cutlass.pipeline import PipelineUserType +from cutlass.pipeline import NamedBarrier as NamedBarrierOg +from cutlass.pipeline import PipelineAsync as PipelineAsyncOg +from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg +from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg +from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg +from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg +import cutlass.pipeline as cutlass_pipeline + + +def make_pipeline_state(type: PipelineUserType, stages: int): + """Compatibility wrapper for FA-style helpers now vendored into src.common.""" + return cutlass_pipeline.make_pipeline_state(type, stages) + +@dataclass(frozen=True) +class NamedBarrier(NamedBarrierOg): + @staticmethod + def create(*args, **kwargs): + obj = NamedBarrierOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", NamedBarrier) + return obj + + @dsl_user_op + def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None: + """ + The aligned flavor of arrive is used when all threads in the CTA will execute the + same instruction. See PTX documentation. + """ + cute.arch.barrier_arrive( + barrier_id=self.barrier_id + index, + number_of_threads=self.num_threads, + loc=loc, + ip=ip, + ) + + @dsl_user_op + def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None: + cute.arch.barrier( + barrier_id=self.barrier_id + index, + number_of_threads=self.num_threads, + loc=loc, + ip=ip, + ) + + +@dataclass(frozen=True) +class PipelineAsync(PipelineAsyncOg): + @staticmethod + def create(*args, **kwargs): + obj = PipelineAsyncOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + # obj.__class__ = PipelineAsync + object.__setattr__(obj, "__class__", PipelineAsync) + return obj + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_try_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + *, + loc=None, + ip=None, + ): + return self.sync_object_empty.try_wait(index, phase, loc=loc, ip=ip) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineTmaAsync(PipelineTmaAsyncOg): + """ + Override producer_acquire to take in extra_tx_count parameter. + """ + + @staticmethod + def create(*args, **kwargs): + obj = PipelineTmaAsyncOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", PipelineTmaAsync) + return obj + + @dsl_user_op + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + if const_expr(extra_tx_count == 0): + self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineTmaUmma(PipelineTmaUmmaOg): + """ + Override producer_acquire to take in extra_tx_count parameter. + """ + + @staticmethod + def create(*args, **kwargs): + obj = PipelineTmaUmmaOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + # obj.__class__ = PipelineTmaUmma + object.__setattr__(obj, "__class__", PipelineTmaUmma) + return obj + + @dsl_user_op + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + if const_expr(extra_tx_count == 0): + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive( + state.index, self.producer_mask, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive_and_expect_tx( + state.index, tx_count, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineUmmaAsync(PipelineUmmaAsyncOg): + @staticmethod + def create(*args, **kwargs): + obj = PipelineUmmaAsyncOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", PipelineUmmaAsync) + return obj + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + """ + UMMA producer commit buffer full, cta_group needs to be provided. + """ + self.sync_object_full.arrive(index, self.producer_mask, self.cta_group, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineAsyncUmma(PipelineAsyncUmmaOg): + @staticmethod + def create(*args, **kwargs): + obj = PipelineAsyncUmmaOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", PipelineAsyncUmma) + return obj + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/seqlen_info.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/seqlen_info.py new file mode 100644 index 0000000000000000000000000000000000000000..873304f71c2cb47ffdd1453fe771c754783f51a4 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/seqlen_info.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Optional +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + +from ...quack import copy_utils + +""" +This consolidates all the info related to sequence length. This is so that we can do all +the gmem reads once at the beginning of each tile, rather than having to repeat these reads +to compute various things like n_block_min, n_block_max, etc. +""" + + +@dataclass(frozen=True) +class SeqlenInfo: + offset: Int32 + offset_padded: Int32 + seqlen: Int32 + has_cu_seqlens: cutlass.Constexpr[bool] = False + + @staticmethod + def create( + batch_idx: Int32, + seqlen_static: Int32, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + tile: cutlass.Constexpr[int] = 128, + ): + offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] + offset_padded = ( + 0 + if const_expr(cu_seqlens is None) + # Add divby so that the compiler knows the alignment when moving by offset_padded + else cute.assume((offset + batch_idx * tile) // tile * tile, divby=tile) + ) + if const_expr(seqused is not None): + seqlen = seqused[batch_idx] + elif const_expr(cu_seqlens is not None): + seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] + else: + seqlen = seqlen_static + return SeqlenInfo(offset, offset_padded, seqlen, has_cu_seqlens=cu_seqlens is not None) + + def offset_batch( + self, + mT: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + multiple: int = 1, + ) -> cute.Tensor: + """Offset a tensor by batch index. batch dim is at position `dim`, seqlen is at dim=0.""" + if const_expr(not self.has_cu_seqlens): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mT) - 1 - dim) + return mT[idx] + else: + off = multiple * (self.offset if const_expr(not padded) else self.offset_padded) + offset = off if const_expr(cute.rank(mT.shape[0]) == 1) else (0, off) + idx = (offset,) + (None,) * (cute.rank(mT) - 1) + return cute.domain_offset(idx, mT) + + +@dataclass(frozen=True) +class SeqlenInfoQK: + offset_q: Int32 + offset_k: Int32 + padded_offset_q: Int32 + padded_offset_k: Int32 + seqlen_q: Int32 + seqlen_k: Int32 + has_cu_seqlens_q: cutlass.Constexpr[bool] + has_cu_seqlens_k: cutlass.Constexpr[bool] + has_seqused_q: cutlass.Constexpr[bool] + has_seqused_k: cutlass.Constexpr[bool] + + @staticmethod + def create( + batch_idx: Int32, + seqlen_q_static: Int32, + seqlen_k_static: Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + tile_m: cutlass.Constexpr[Int32] = 128, + tile_n: cutlass.Constexpr[Int32] = 128, + ): + offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] + offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] + padded_offset_q = ( + 0 + if const_expr(mCuSeqlensQ is None) + else cute.assume((offset_q + batch_idx * tile_m) // tile_m * tile_m, divby=tile_m) + ) + padded_offset_k = ( + 0 + if const_expr(mCuSeqlensK is None) + else cute.assume((offset_k + batch_idx * tile_n) // tile_n * tile_n, divby=tile_n) + ) + if const_expr(mSeqUsedQ is not None): + seqlen_q = mSeqUsedQ[batch_idx] + else: + seqlen_q = ( + seqlen_q_static + if const_expr(mCuSeqlensQ is None) + else mCuSeqlensQ[batch_idx + 1] - offset_q + ) + if const_expr(mSeqUsedK is not None): + seqlen_k = mSeqUsedK[batch_idx] + else: + seqlen_k = ( + seqlen_k_static + if const_expr(mCuSeqlensK is None) + else mCuSeqlensK[batch_idx + 1] - offset_k + ) + return SeqlenInfoQK( + offset_q, + offset_k, + padded_offset_q, + padded_offset_k, + seqlen_q, + seqlen_k, + has_cu_seqlens_q=mCuSeqlensQ is not None, + has_cu_seqlens_k=mCuSeqlensK is not None, + has_seqused_q=mSeqUsedQ is not None, + has_seqused_k=mSeqUsedK is not None, + ) + + def offset_batch_Q( + self, + mQ: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + ragged: cutlass.Constexpr[bool] = False, + ) -> cute.Tensor: + """Seqlen must be the first dimension of mQ""" + if const_expr(not ragged): + if const_expr(not self.has_cu_seqlens_q): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + return mQ[idx] + else: + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q + offset_q = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (None, offset_q) + idx = (offset_q,) + (None,) * (cute.rank(mQ) - 1) + return cute.domain_offset(idx, mQ) + else: + if const_expr(not self.has_cu_seqlens_q): + offset_q = 0 + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + mQ = mQ[idx] + else: + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q + if const_expr(cute.rank(mQ.shape[0]) == 1): + return copy_utils.offset_ragged_tensor( + mQ, offset_q, self.seqlen_q, ragged_dim=0, ptr_shift=True + ) + else: # PackGQA + assert cute.rank(mQ.shape[0]) == 2 + # Unpack before calling offset_ragged_tensor, then pack + idx = ((None, None),) + (None,) * (cute.rank(mQ) - 1) + mQ = mQ[idx] + mQ = copy_utils.offset_ragged_tensor( + mQ, offset_q, self.seqlen_q, ragged_dim=1, ptr_shift=True + ) + return cute.group_modes(mQ, 0, 2) + + def offset_batch_K( + self, + mK: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + ragged: cutlass.Constexpr[bool] = False, + multiple: int = 1, + ) -> cute.Tensor: + """Seqlen must be the first dimension of mK""" + if const_expr(not ragged): + if const_expr(not self.has_cu_seqlens_k): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + return mK[idx] + else: + offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k + offset_k *= multiple + idx = (offset_k,) + (None,) * (cute.rank(mK) - 1) + return cute.domain_offset(idx, mK) + else: + if const_expr(not self.has_cu_seqlens_k): + offset_k = 0 + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + mK = mK[idx] + else: + offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k + offset_k *= multiple + return copy_utils.offset_ragged_tensor( + mK, offset_k, self.seqlen_k, ragged_dim=0, ptr_shift=True + ) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/softmax.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..8f94c1c9e40aeb44c0a128165d90a502feb04afd --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/softmax.py @@ -0,0 +1,498 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Online softmax primitives. + +Contains: +- ``Softmax``: SM80/90 base class with online softmax + finalize + rescale_O. + The ``rescale_O`` path branches on ``arch >= 100`` to emit SM100 packed + ``fmul.f32x2`` (2× CUDA-core throughput) when available. +- ``SoftmaxSm100``: SM100-specific subclass exposing fused ``update_row_max``, + ``scale_apply_exp2_convert`` etc. used by the UTCMMA warp-specialized kernel. +""" + +import math +import operator +from dataclasses import dataclass +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Float32 + +from ...quack import layout_utils +from ...quack.cute_dsl_utils import ParamsBase + +from . import utils + + +@dataclass +class Softmax(ParamsBase): + scale_log2: Float32 + num_rows: cutlass.Constexpr[int] + row_max: cute.Tensor + row_sum: cute.Tensor + arch: cutlass.Constexpr[int] = 80 + softmax_scale: Float32 | None = None + + @staticmethod + def create( + scale_log2: Float32, + num_rows: cutlass.Constexpr[int], + arch: cutlass.Constexpr[int] = 80, + softmax_scale: Float32 | None = None, + ): + row_max = cute.make_rmem_tensor(num_rows, Float32) + row_sum = cute.make_rmem_tensor(num_rows, Float32) + return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale) + + def reset(self) -> None: + self.row_max.fill(-Float32.inf) + self.row_sum.fill(0.0) + + def _compute_row_max( + self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch) + + def _compute_row_sum( + self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch) + + @cute.jit + def online_softmax( + self, + acc_S: cute.Tensor, + is_first: cutlass.Constexpr[bool] = False, + check_inf: cutlass.Constexpr[bool] = True, + ) -> cute.Tensor: + """Apply online softmax and return the row_scale to rescale O. + + On SM100+ the inner ``acc_S_row * scale_log2 - row_max_scaled`` is + rewritten as explicit ``fma_packed_f32x2`` intrinsics — the DSL + compiler does not fuse TensorSSA ``mul + sub`` into FFMA2 (NCU + confirms: FFMA2 count is 0 for the TensorSSA path). The packed + rewrite issues one FFMA.F32X2 per pair, halving the scalar FFMA + instruction count for the softmax scale/subtract stage. + """ + acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S) + row_scale = cute.make_rmem_tensor_like(self.row_max, Float32) + + row_max = self.row_max + row_sum = self.row_sum + scale_log2 = self.scale_log2 + arch = self.arch + + for r in cutlass.range(cute.size(row_max), unroll_full=True): + acc_S_row_slice = acc_S_mn[r, None] + acc_S_row = acc_S_row_slice.load() + + row_max_cur = utils.fmax_reduce( + acc_S_row, + init_val=row_max[r] if cutlass.const_expr(not is_first) else None, + arch=arch, + ) + + row_max_cur = cute.arch.warp_reduction_max(row_max_cur, threads_in_group=4) + row_max_prev = row_max[r] + row_max[r] = row_max_cur + + if cutlass.const_expr(check_inf): + row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur + + row_max_cur_scaled = row_max_cur * scale_log2 + minus_row_max_scaled = -row_max_cur_scaled + n = cute.size(acc_S_row_slice) + + if cutlass.const_expr(arch >= 100 and n % 2 == 0): + # SM100 packed f32x2 FMA path: scale + subtract in one pass. + for i in cutlass.range(0, n, 2, unroll_full=True): + acc_S_row_slice[i], acc_S_row_slice[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row_slice[i], acc_S_row_slice[i + 1]), + (scale_log2, scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + for i in cutlass.range(n, unroll_full=True): + acc_S_row_slice[i] = cute.math.exp2(acc_S_row_slice[i], fastmath=True) + acc_S_row_exp = acc_S_row_slice.load() + else: + acc_S_row_exp = cute.math.exp2( + acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True + ) + acc_S_row_slice.store(acc_S_row_exp) + + if cutlass.const_expr(is_first): + acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch) + row_scale[r] = 1.0 + else: + row_scale[r] = cute.math.exp2( + (row_max_prev - row_max_cur) * scale_log2, fastmath=True + ) + acc_S_row_sum = utils.fadd_reduce( + acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch + ) + + row_sum[r] = acc_S_row_sum + + return row_scale + + @cute.jit + def finalize( + self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None + ) -> cute.Tensor: + """Finalize the online softmax by computing the scale and logsumexp. + + On SM100+ with an even ``num_rows`` and no sink_val, the loop is + unrolled in pairs so the key per-row arithmetic ― rcp*final_scale, + max*scale_log2 + log2(sum), and the final *LN2 ― collapses into one + ``mul_packed_f32x2`` + one ``fma_packed_f32x2`` + one more + ``mul_packed_f32x2`` per row pair. Sink_val path stays scalar (rare). + """ + if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)): + assert cute.size(sink_val) == cute.size(self.row_sum) + row_sum = self.row_sum + row_max = self.row_max + scale_log2 = self.scale_log2 + + row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4)) + row_scale = cute.make_rmem_tensor_like(row_max, Float32) + + LN2 = math.log(2.0) + num_rows = cute.size(row_sum) + use_packed = cutlass.const_expr( + self.arch >= 100 and num_rows % 2 == 0 and sink_val is None + ) + + if use_packed: + for r in cutlass.range(0, num_rows, 2, unroll_full=True): + s0 = row_sum[r] + s1 = row_sum[r + 1] + m0 = row_max[r] + m1 = row_max[r + 1] + bad0 = s0 == 0.0 or s0 != s0 + bad1 = s1 == 0.0 or s1 != s1 + + # row_scale = rcp_approx(safe_sum) * final_scale — rcp is scalar + # (no packed rcp intrinsic); the trailing multiply packs. + rcp0 = cute.arch.rcp_approx(1.0 if bad0 else s0) + rcp1 = cute.arch.rcp_approx(1.0 if bad1 else s1) + row_scale[r], row_scale[r + 1] = cute.arch.mul_packed_f32x2( + (rcp0, rcp1), (final_scale, final_scale) + ) + + # LSE = (row_max * scale_log2 + log2(row_sum)) * LN2 + # packed FMA for (max*scale_log2 + log2_sum), packed MUL for *LN2. + log0 = cute.math.log2(s0, fastmath=True) + log1 = cute.math.log2(s1, fastmath=True) + lse_pre_0, lse_pre_1 = cute.arch.fma_packed_f32x2( + (m0, m1), (scale_log2, scale_log2), (log0, log1) + ) + lse_0, lse_1 = cute.arch.mul_packed_f32x2( + (lse_pre_0, lse_pre_1), (LN2, LN2) + ) + row_sum[r] = -Float32.inf if bad0 else lse_0 + row_sum[r + 1] = -Float32.inf if bad1 else lse_1 + else: + for r in cutlass.range(num_rows, unroll_full=True): + if cutlass.const_expr(sink_val is not None): + sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r] + LOG2_E = math.log2(math.e) + row_sum[r] += cute.math.exp2( + sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True + ) + + acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r] + row_scale[r] = ( + cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0) + ) * final_scale + row_sum_cur = row_sum[r] + row_sum[r] = ( + (row_max[r] * scale_log2 + cute.math.log2(row_sum_cur, fastmath=True)) * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + return row_scale + + @cute.jit + def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: + """Scale each row of acc_O by the given scale tensor.""" + acc_O_mn = layout_utils.reshape_acc_to_mn(acc_O) + assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0]) + n = cute.size(acc_O_mn, mode=[1]) + if cutlass.const_expr(self.arch >= 100 and n % 2 == 0): + # SM100: pack adjacent pairs into fmul.f32x2 (2× CUDA-core throughput). + for r in cutlass.range(cute.size(row_scale), unroll_full=True): + scale = row_scale[r] + for j in cutlass.range(0, n, 2, unroll_full=True): + acc_O_mn[r, j], acc_O_mn[r, j + 1] = cute.arch.mul_packed_f32x2( + (acc_O_mn[r, j], acc_O_mn[r, j + 1]), (scale, scale) + ) + else: + for r in cutlass.range(cute.size(row_scale), unroll_full=True): + acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r]) + + +@dataclass +class SoftmaxSm100(Softmax): + """SM100-specific softmax: single-row, explicit f32x2 pack for FMA/exp2 paths.""" + + rescale_threshold: cutlass.Constexpr[float] = 0.0 + + @staticmethod + def create( + scale_log2: Float32, + rescale_threshold: cutlass.Constexpr[float] = 0.0, + softmax_scale: Float32 | None = None, + ): + num_rows = 1 + arch = 100 + row_max = cute.make_rmem_tensor(num_rows, Float32) + row_sum = cute.make_rmem_tensor(num_rows, Float32) + return SoftmaxSm100( + scale_log2, + num_rows, + row_max, + row_sum, + arch, + softmax_scale, + rescale_threshold=rescale_threshold, + ) + + @cute.jit + def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + acc_scale = cute.math.exp2(acc_scale_, fastmath=True) + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale = 1.0 + self.row_max[0] = row_max_new + return row_max_safe, acc_scale + + @cute.jit + def update_row_max_deferred_exp2( + self, + acc_S_row: cute.TensorSSA, + is_first: int, + ) -> Tuple[Float32, Float32]: + """update_row_max variant that publishes the log2-delta (un-exp2'd) so + the consumer can do the exp2 only when an actual rescale fires. + + Returns ``(row_max_safe, acc_scale_log2_or_zero)`` where: + - ``row_max_safe`` is the same row-max as ``update_row_max`` (with + ``rescale_threshold`` rollback applied). + - ``acc_scale_log2_or_zero`` is ``0.0`` for the first iteration or when + the threshold rollback fired (consumer treats as no rescale), else + the raw log2-domain value ``(row_max_old - row_max_safe)*scale_log2`` + (consumer computes ``cute.math.exp2`` and rescales). + + This keeps MUFU.EX2 off the sm_stats publication critical path that + gates the correction WG's consumer wait. + """ + publish = Float32(0.0) + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + # publish stays 0.0 (signal: no rescale needed) + else: + publish = acc_scale_ + else: + publish = acc_scale_ + self.row_max[0] = row_max_new + return row_max_safe, publish + + @cute.jit + def update_row_max_only(self, acc_S_row: cute.TensorSSA, is_first: int) -> None: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + else: + row_max_new = self._compute_row_max(acc_S_row, init_val=self.row_max[0]) + self.row_max[0] = row_max_new + + def update_row_sum( + self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False + ) -> None: + init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None + self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val) + + @cute.jit + def compute_scaled_exp2_row_sum( + self, + acc_S_row: cute.Tensor, + scale: Float32, + ) -> Float32: + return utils.fadd_exp2_scaled_reduce(acc_S_row, scale, arch=self.arch) + + @cute.jit + def scale_subtract_rowmax( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + row_max_scaled = row_max * self.scale_log2 + for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (-row_max_scaled, -row_max_scaled), + ) + + @cute.jit + def apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + acc_S_row_converted: cute.Tensor, + ex2_emu_freq: cutlass.Constexpr[int] = 0, + ex2_emu_res: cutlass.Constexpr[int] = 4, + ex2_emu_start_frg: cutlass.Constexpr[int] = 0, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + if cutlass.const_expr(ex2_emu_freq == 0): + acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) + else: + if cutlass.const_expr( + k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res + or j >= frg_cnt - 1 + or j < ex2_emu_start_frg + ): + acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2( + acc_S_row_frg[k + 1, j], fastmath=True + ) + else: + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2( + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] + ) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + @cute.jit + def scale_apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + acc_S_row_converted: cute.Tensor, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + minus_row_max_scaled = -row_max * self.scale_log2 + for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + @cute.jit + def scale_apply_exp2_convert_sum( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + acc_S_row_converted: cute.Tensor, + init_sum: Float32, + ex2_emu_freq: cutlass.Constexpr[int] = 0, + ex2_emu_res: cutlass.Constexpr[int] = 4, + ex2_emu_start_frg: cutlass.Constexpr[int] = 0, + ) -> Float32: + # When ex2_emu_freq > 0, the (k % ex2_emu_freq) >= ex2_emu_freq - ex2_emu_res + # pairs in the inner loop use the FFMA2-based polynomial ex2 emulation + # (ex2_emulation_2) instead of MUFU exp2 — mirrors prefill's + # apply_exp2_convert. This removes the MUFU "wait" stall that dominates + # the second-largest stall bucket in decode (~22% of total). + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + minus_row_max_scaled = -row_max * self.scale_log2 + acc_sum = (init_sum, Float32(0.0)) + + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = cute.arch.fma_packed_f32x2( + (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + if cutlass.const_expr(ex2_emu_freq == 0): + acc_S_row_frg[k, j] = cute.math.exp2( + acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2( + acc_S_row_frg[k + 1, j], fastmath=True) + else: + use_real = cutlass.const_expr( + k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res + or j >= frg_cnt - 1 + or j < ex2_emu_start_frg + ) + if cutlass.const_expr(use_real): + acc_S_row_frg[k, j] = cute.math.exp2( + acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2( + acc_S_row_frg[k + 1, j], fastmath=True) + else: + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( + utils.ex2_emulation_2( + acc_S_row_frg[k, j], + acc_S_row_frg[k + 1, j], + ) + ) + acc_sum = cute.arch.add_packed_f32x2( + acc_sum, + (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + ) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + return acc_sum[0] + acc_sum[1] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/tile_scheduler.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/tile_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..985b4289e146288355dfecd7169383eb64df4f09 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/tile_scheduler.py @@ -0,0 +1,967 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from enum import IntEnum, auto +from typing import Optional, Tuple, Protocol, runtime_checkable +from dataclasses import dataclass + +try: + from typing import override +except ImportError: # Python < 3.12 + from typing_extensions import override + +import cutlass +from cutlass.pipeline import PipelineClcFetchAsync, PipelineState +from cutlass._mlir import ir +import cutlass.cute as cute +from cutlass import Int32, const_expr +from cutlass.cute import FastDivmodDivisor +from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams + +from ...quack.cute_dsl_utils import ParamsBase + +from ...src.common import utils as utils +from ...src.common.fast_math import clz + + +class SchedulingMode(IntEnum): + NONE = auto() + STATIC = auto() + DYNAMIC = auto() + CLC = auto() + + +@dataclass +class ClcState(ParamsBase): + """Owns the runtime state shared by CLC-capable tile schedulers. + + `SparseAttentionForwardSm100` constructs this state because it owns the CLC + response buffer, mbarrier storage, and launch geometry needed to initialize + the hardware scheduler and async pipeline. Individual tile schedulers then + consume this state and map the returned hardware work tiles into their own + logical `WorkTileInfo` coordinates. + + To add CLC support to a scheduler: + - implement `clc_problem_shape(params)` so the kernel can create the hardware scheduler + - accept `clc: ClcState | None` in `create(...)` / `__init__` + - map `clc.initial_work_tile_info()` and `clc.get_current_work()` into scheduler coordinates + """ + + _hw_scheduler: ClcDynamicPersistentTileScheduler + _pipeline: PipelineClcFetchAsync + _consumer_state: PipelineState + _producer_state: PipelineState + + @staticmethod + def create( + *, + hw_scheduler: ClcDynamicPersistentTileScheduler, + pipeline: PipelineClcFetchAsync, + consumer_state: PipelineState, + producer_state: PipelineState, + ) -> "ClcState": + return ClcState(hw_scheduler, pipeline, consumer_state, producer_state) + + def initial_work_tile_info(self): + return self._hw_scheduler.initial_work_tile_info() + + def get_current_work(self): + return self._hw_scheduler.get_current_work() + + def prefetch_next_work(self, *, loc=None, ip=None): + self._pipeline.producer_acquire(self._producer_state, loc=loc, ip=ip) + mbarrier_addr = self._pipeline.producer_get_barrier(self._producer_state, loc=loc, ip=ip) + self._hw_scheduler.advance_to_next_work(mbarrier_addr, loc=loc, ip=ip) + self._producer_state.advance(loc=loc, ip=ip) + + def consumer_wait(self, *, loc=None, ip=None): + self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip) + + def consumer_release(self, *, loc=None, ip=None): + self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip) + self._consumer_state.advance(loc=loc, ip=ip) + + def producer_tail(self, *, loc=None, ip=None): + self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip) + + +class WorkTileInfo(cutlass.utils.WorkTileInfo): + """Altered WorkTileInfo which includes four axes: (block, head, batch, split)""" + + @override + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": + assert len(values) == 5 + new_tile_idx = cutlass.new_from_mlir_values(self._tile_idx, values[:-1]) + new_is_valid_tile = cutlass.new_from_mlir_values(self._is_valid_tile, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + +@runtime_checkable +class TileSchedulerProtocol(Protocol): + """Protocol defining the interface all tile schedulers must implement. + + Schedulers are responsible for: + 1. Coordinate mapping: linear tile index -> (m_block, head, batch, split) + 2. Work distribution: how to get the next tile (static grid-stride vs CLC dynamic) + """ + + def get_current_work(self) -> WorkTileInfo: + """Get the current work tile coordinates.""" + ... + + def initial_work_tile_info(self) -> WorkTileInfo: + """Get the initial work tile for this CTA.""" + ... + + def advance_to_next_work(self, *, loc=None, ip=None): + """Consumer-side advance: move to next tile and return it. + + For static schedulers: grid-stride increment + get_current_work. + For CLC schedulers: consumer wait + get_current_work + consumer release + state advance. + """ + ... + + def prefetch_next_work(self, *, loc=None, ip=None) -> None: + """Producer-side prefetch of next work tile (no-op for static schedulers). + + For CLC schedulers: producer acquire + issue CLC query + producer state advance. + Only called by the scheduler warp. + """ + ... + + def producer_tail(self, *, loc=None, ip=None) -> None: + """Producer-side cleanup after the last tile. + + No-op for static schedulers. For CLC schedulers: pipeline producer_tail. + """ + ... + + +@dataclass +class TileSchedulerArguments(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + seqlen_k: Int32 + headdim: Int32 + headdim_v: Int32 + total_q: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + element_size: cutlass.Constexpr[int] = 2 + is_persistent: cutlass.Constexpr[bool] = False + lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False + use_cluster_idx: cutlass.Constexpr[bool] = False + + +class SingleTileScheduler: + @dataclass + class Params(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + num_splits_divmod: FastDivmodDivisor + is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + use_cluster_idx: cutlass.Constexpr[bool] = False + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileScheduler.Params": + return SingleTileScheduler.Params( + args.num_block, + args.num_head, + args.num_batch, + args.num_splits, + FastDivmodDivisor(args.num_splits), + args.is_split_kv, + args.cluster_shape_mn, + args.use_cluster_idx, + ) + + def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): + self.params = params + self._blk_coord = blk_coord + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"SingleTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) + return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileScheduler": + if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx): + blk_coord = cute.arch.block_idx() + else: + blk_coord = cute.arch.cluster_idx() + return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1) + assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + if const_expr(params.use_cluster_idx): + # Grid must have num_block * cluster_m physical blocks so that there are num_block clusters + grid_x = params.num_block * params.cluster_shape_mn[0] + else: + grid_x = cute.round_up(params.num_block, params.cluster_shape_mn[0]) + return ( + grid_x, + params.num_head * params.num_splits, + params.num_batch, + ) + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + block_idx, head_idx, batch_idx = self._blk_coord + if const_expr(self.params.is_split_kv): + head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod) + else: + split_idx = Int32(0) + return WorkTileInfo( + (block_idx, head_idx, batch_idx, split_idx), + self._is_first_block, + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._blk_coord]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._blk_coord], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class StaticPersistentTileScheduler: + @dataclass + class Params(ParamsBase): + num_block_cluster_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor + total_blocks_cluster: Int32 + cluster_shape_m: cutlass.Constexpr[int] = 1 + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler.Params": + num_block_cluster = cute.ceil_div(args.num_block, cute.size(args.cluster_shape_mn)) + total_blocks_cluster = num_block_cluster * args.num_head * args.num_batch + return StaticPersistentTileScheduler.Params( + FastDivmodDivisor(num_block_cluster), + FastDivmodDivisor(args.num_head), + total_blocks_cluster, + cluster_shape_m=args.cluster_shape_mn[0], + ) + + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"StaticPersistentTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) + return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler": + if const_expr(cute.size(params.cluster_shape_m) == 1): + tile_idx = cute.arch.block_idx()[0] + else: + tile_idx = cute.arch.cluster_idx()[0] + return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + usable_SM_count=0, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + hardware_info = cutlass.utils.HardwareInfo() + cluster_shape_m = int(params.cluster_shape_m) + if usable_SM_count > 0: + sm_count = usable_SM_count + else: + sm_count = hardware_info.get_device_multiprocessor_count() + max_ctas = (sm_count // cluster_shape_m) * cluster_shape_m + max_ctas = max(max_ctas, cluster_shape_m) + grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * cluster_shape_m) + return (grid_x, Int32(1), Int32(1)) + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod) + batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) + is_valid = self._tile_idx < self.params.total_blocks_cluster + return WorkTileInfo( + (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.cluster_shape_m == 1): + self._tile_idx += cute.arch.grid_dim()[0] + else: + self._tile_idx += cute.arch.cluster_dim()[0] + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.params, self._tile_idx], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileLPTScheduler: + @dataclass + class Params(ParamsBase): + total_blocks: Int32 + num_splits: Int32 + num_block: Int32 + num_head: Int32 + num_batch: Int32 + l2_minor: Int32 + num_head_divmod: FastDivmodDivisor + l2_minor_divmod: FastDivmodDivisor + l2_major_divmod: FastDivmodDivisor + l2_minor_residual_divmod: FastDivmodDivisor + num_hb_quotient: Int32 + num_splits_divmod: FastDivmodDivisor + is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + lpt: cutlass.Constexpr[bool] = True + use_cluster_idx: cutlass.Constexpr[bool] = True + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> "SingleTileLPTScheduler.Params": + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) + size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size + size_one_head = size_one_kv_head + size_l2 = 50 * 1024 * 1024 # 40 MB for K & V + # Swizzle is the size of each "section". Round swizzle to a power of 2 + # Need to be careful about the case where only one head will fit + # swizzle is how many heads can fit in L2 + # Seems faster if swizzle is a power of 2 + log2_floor = lambda n: 31 - clz(n) + swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + num_hb_quotient = (args.num_head * args.num_batch) // swizzle + num_hb_remainder = (args.num_head * args.num_batch) % swizzle + return SingleTileLPTScheduler.Params( + total_blocks=args.num_block * args.num_head * args.num_batch, + num_block=args.num_block, + num_head=args.num_head, + num_batch=args.num_batch, + l2_minor=Int32(swizzle), + num_head_divmod=FastDivmodDivisor(args.num_head), + l2_minor_divmod=FastDivmodDivisor(swizzle), + l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block), + l2_minor_residual_divmod=FastDivmodDivisor(max(num_hb_remainder, 1)), + num_hb_quotient=Int32(num_hb_quotient), + num_splits=args.num_splits, + num_splits_divmod=FastDivmodDivisor(args.num_splits), + is_split_kv=args.is_split_kv, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + lpt=args.lpt, + use_cluster_idx=args.use_cluster_idx, + ) + + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): + self.params = params + self._tile_idx = tile_idx + self._split_idx = split_idx + self.clc = clc + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileLPTScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) + + @staticmethod + def _clc_grid_shape(params: Params): + num_batch_splits = ( + params.num_batch * params.num_splits + if const_expr(params.is_split_kv) + else params.num_batch + ) + return ( + cute.round_up(params.num_block, params.cluster_shape_m), + params.num_head, + num_batch_splits, + ) + + @staticmethod + @cute.jit + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileLPTScheduler._clc_grid_shape(params), + cluster_shape_mnk=(params.cluster_shape_m, 1, 1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileLPTScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler( + params, cute.arch.block_idx()[0], Int32(0), clc, loc=loc, ip=ip + ) + tile_idx, split_idx, _ = cute.arch.block_idx() + return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler._clc_grid_shape(params) + return (params.total_blocks, params.num_splits, Int32(1)) + + @cute.jit + def clc_work_to_coords(self, work) -> WorkTileInfo: + """Convert CLC response (block, head, batch_split) to WorkTileInfo. + + CLC returns raw grid coordinates — no L2 swizzle (hardware decides order). + We only apply cluster division, optional LPT block reversal, and split_kv unpacking. + """ + block_idx = work.tile_idx[0] + if const_expr(self.params.cluster_shape_m > 1): + block_idx = block_idx // self.params.cluster_shape_m + if const_expr(self.params.lpt): + # Longest-processing-time-first: reverse block order + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + num_block = self.params.num_block // self.params.cluster_shape_m + else: + num_block = self.params.num_block + block_idx = num_block - 1 - block_idx + split_idx = Int32(0) + if const_expr(self.params.is_split_kv): + batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod) + else: + batch_idx = work.tile_idx[2] + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block_idx = block_idx * self.params.cluster_shape_m + bidx_in_cluster[0] + return WorkTileInfo( + (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)), + work.is_valid_tile, + ) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.get_current_work() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) + # Static path: L2-swizzled coordinate mapping + params = self.params + # Implement LPT scheduling coordinate calculation + bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + block, bidhb_residual = 0, 0 + if bidhb < params.num_hb_quotient: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod) + else: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod) + bidhb_actual = bidhb * params.l2_minor + bidhb_residual + batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) + # Longest-processing-time-first + if const_expr(params.lpt): + block = params.num_block - 1 - block + is_valid = self._tile_idx < params.total_blocks + return WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid + ) + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.initial_work_tile_info() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._tile_idx = self.params.total_blocks + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*obj_list, loc=self._loc) + + +class SingleTileVarlenScheduler: + @dataclass + class Params(ParamsBase): + num_head: Int32 + num_batch: Int32 + total_q: Int32 + num_splits: Int32 + max_kvblock_in_l2: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> "SingleTileVarlenScheduler.Params": + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + kv_block_size = ( + (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + ) + if args.head_swizzle: + kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] + max_kvblock_in_l2 = size_l2 // kv_block_size + assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( + "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" + ) + assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + # TODO: Support varlen CLC with cluster_shape_m > 1 by refactoring the + # flattened-tile decode so cluster unpacking semantics are explicit. + assert scheduling_mode != SchedulingMode.CLC or args.cluster_shape_mn[0] == 1, ( + "Varlen CLC currently requires cluster_shape_mn[0] == 1" + ) + return SingleTileVarlenScheduler.Params( + num_head=args.num_head, + num_batch=args.num_batch, + total_q=args.total_q, + num_splits=args.num_splits, + max_kvblock_in_l2=max_kvblock_in_l2, + tile_shape_mn=args.tile_shape_mn, + mCuSeqlensQ=args.mCuSeqlensQ, + mSeqUsedQ=args.mSeqUsedQ, + qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, + lpt=args.lpt, + is_split_kv=args.is_split_kv, + head_swizzle=args.head_swizzle, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + ) + + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): + self.params = params + self._tile_idx = tile_idx + self._split_idx = split_idx + self._is_first_block = True + self.clc = clc + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileVarlenScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) + + @staticmethod + @cute.jit + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileVarlenScheduler.get_grid_shape(params), + cluster_shape_mnk=(1, 1, 1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileVarlenScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + block_idx = cute.arch.block_idx() + split_idx = Int32(0) + if const_expr(params.is_split_kv): + split_idx = block_idx[1] + return SingleTileVarlenScheduler( + params, + block_idx[0], + split_idx, + clc, + loc=loc, + ip=ip, + ) + tile_idx, split_idx, _ = cute.arch.block_idx() + return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + total_blocks_max = ( + params.total_q + + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1) + ) // params.tile_shape_mn[0] + # Round down to nearest multiple of cluster since odd excess is always padding. + total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m + return (total_blocks_max * params.num_head, params.num_splits, Int32(1)) + + @cute.jit + def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + params = self.params + batch_idx = lane + bidb_start + if cutlass.const_expr(params.mSeqUsedQ is not None): + seqlen = Int32(0) + if batch_idx < params.num_batch: + seqlen = params.mSeqUsedQ[batch_idx] + else: + assert params.mCuSeqlensQ is not None + cur_cu_seqlen = Int32(0) + if batch_idx <= params.num_batch: + cur_cu_seqlen = params.mCuSeqlensQ[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): + seqlen *= params.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), params.cluster_shape_m) + if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 + else Int32(0) + ) + + @cute.jit + def _varlen_coord_map(self) -> WorkTileInfo: + """Map self._tile_idx to (block, head, batch) via warp-level prefix sums.""" + params = self.params + lane_idx = cute.arch.lane_idx() + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + # Total number of blocks for the next 31 batches + m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1) + # Same for all lanes + group_end_tile = m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group) + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0) + next_tile_idx = self._tile_idx // params.cluster_shape_m + while group_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= params.num_batch: + batch_idx = Int32(params.num_batch) + group_end_tile = next_tile_idx + 1 + else: + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync( + num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 + ) + group_end_tile += m_blocks_in_group * params.num_head + is_valid = False + if batch_idx >= params.num_batch: + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch) + else: + group_start_tile = group_end_tile - m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx) + # The next problem to process is the first one that does not have ending tile position + # that is greater than or equal to tile index. + batch_idx_in_group = cute.arch.popc( + cute.arch.vote_ballot_sync( + group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx + ) + ) + batch_idx += batch_idx_in_group + num_m_blocks_prev_lane = ( + 0 + if batch_idx_in_group == 0 + else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1) + ) + num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) + mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head + if cutlass.const_expr(params.lpt or params.head_swizzle): + # This is a version of the SingleTileLPTScheduler, complicated by the fact that + # the seqlen can vary per batch. + # TODO: is there any case where num_m_blocks is 0? + # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here + num_n_blocks = ( + num_m_blocks + * params.tile_shape_mn[0] + * params.cluster_shape_m + // params.qhead_per_kvhead_packgqa + // params.tile_shape_mn[1] + ) + # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) + # Seems faster to have this be a power of 2 + nheads_in_l2 = ( + 16 + if num_n_blocks * 16 <= params.max_kvblock_in_l2 + else ( + 8 + if num_n_blocks * 8 <= params.max_kvblock_in_l2 + else ( + 4 + if num_n_blocks * 4 <= params.max_kvblock_in_l2 + else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1) + ) + ) + ) + nheads_in_l2 = min(nheads_in_l2, params.num_head) + mh_in_l2 = nheads_in_l2 * num_m_blocks + section_idx = mh_block // mh_in_l2 + l2_mod = mh_block - section_idx * mh_in_l2 + # Deal with tail section + nheads_in_this_section = ( + nheads_in_l2 + if nheads_in_l2 * (section_idx + 1) <= params.num_head + else params.num_head - section_idx * nheads_in_l2 + ) + block = l2_mod // nheads_in_this_section + head_idx_residual = l2_mod - block * nheads_in_this_section + head_idx = section_idx * nheads_in_l2 + head_idx_residual + if cutlass.const_expr(params.lpt): + block = num_m_blocks - 1 - block + else: + head_idx = mh_block // num_m_blocks + block = mh_block - head_idx * num_m_blocks + is_valid = self._is_first_block and batch_idx < params.num_batch + if cutlass.const_expr(params.cluster_shape_m > 1): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block = block * params.cluster_shape_m + bidx_in_cluster[0] + # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) + split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) + return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.get_current_work() + # Default to grid_dim (one past last valid flat index) so _varlen_coord_map + # returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when + # invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural + # mismatch on self inside the runtime if. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.initial_work_tile_info() + # See get_current_work for why grid_dim and local-then-assign. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() + + def prefetch_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work + self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*obj_list, loc=self._loc) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/tma_utils.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/tma_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5bdc19a08eacf9a060f2c0a7a4d50a4adb735094 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/tma_utils.py @@ -0,0 +1,515 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Raw TMA ops and descriptor builders. + +`tma_utils.py` is the canonical owner for raw TMA inline-asm helpers and TMA +descriptor construction. Non-TMA store/layout helpers are re-exported from +`copy_utils.py` for backward compatibility. +""" + +import ctypes + +from cutlass import Int32, Int64 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm +import cutlass._mlir.dialects.cute as cute_ir +import cutlass._mlir.dialects.cute_nvgpu as cute_nvgpu_ir +from cutlass._mlir.dialects import _cute_nvgpu_ops_gen as cute_nvgpu_gen + + +# Raw TMA Ops + +TMA_CACHE_EVICT_FIRST = 0x12F0000000000000 +TMA_CACHE_EVICT_LAST = 0x14F0000000000000 + + +@dsl_user_op +def tma_tile_load( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row_idx, + mbar_ptr, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile with mbar completion.""" + llvm.inline_asm( + T.i32(), + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $1;\n" + "add.u32 sa, sa, $2;\n" + "cvt.u32.u64 ma, $6;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile" + ".mbarrier::complete_tx::bytes [sa], [$3, {$4, $5}], [ma];\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_gather4( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row0, + row1, + row2, + row3, + mbar_ptr, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4 with mbar.""" + llvm.inline_asm( + T.i32(), + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row0).ir_value(loc=loc, ip=ip), + Int32(row1).ir_value(loc=loc, ip=ip), + Int32(row2).ir_value(loc=loc, ip=ip), + Int32(row3).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $1;\n" + "add.u32 sa, sa, $2;\n" + "cvt.u32.u64 ma, $9;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4" + ".mbarrier::complete_tx::bytes [sa], [$3, {$4, $5, $6, $7, $8}], [ma];\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,l,r,r,r,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def prefetch_tma_desc_raw(tma_desc_ptr, *, loc=None, ip=None): + """Prefetch a raw TMA descriptor pointer into the descriptor cache.""" + ptr_i64 = tma_desc_ptr.toint().ir_value(loc=loc, ip=ip) + ptr_i64_align_ty = cute_ir.ConstrainedIntType.get(128, ptr_i64.type.width) + ptr_i64_align = cute_ir.assume(ptr_i64_align_ty, ptr_i64, loc=loc, ip=ip) + ptr_ty = cute_ir.PtrType.get( + cute_nvgpu_ir.TmaDescriptorTiledType.get(), + cute_ir.AddressSpace.gmem, + 128, + ) + desc_ptr = cute_ir.inttoptr(ptr_ty, ptr_i64_align, loc=loc, ip=ip) + cute_nvgpu_gen.arch_prefetch_tma_desc(desc_ptr.value, loc=loc, ip=ip) + + +@dsl_user_op +def tma_tile_prefetch( + tma_desc_ptr, + col_idx, + row_idx, + cache_hint=TMA_CACHE_EVICT_FIRST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.prefetch.tensor.2d.L2.global.tile with cache hint.""" + llvm.inline_asm( + None, + [ + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "cp.async.bulk.prefetch.tensor.2d.L2.global.tile.L2::cache_hint " + "[$0, {$1, $2}], $3;\n", + "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_gather4_prefetch( + tma_desc_ptr, + col_idx, + row0, + row1, + row2, + row3, + cache_hint=TMA_CACHE_EVICT_LAST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4 with cache hint.""" + llvm.inline_asm( + None, + [ + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row0).ir_value(loc=loc, ip=ip), + Int32(row1).ir_value(loc=loc, ip=ip), + Int32(row2).ir_value(loc=loc, ip=ip), + Int32(row3).ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4.L2::cache_hint " + "[$0, {$1, $2, $3, $4, $5}], $6;\n", + "l,r,r,r,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_tile_load_cached( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row_idx, + mbar_ptr, + cache_hint=TMA_CACHE_EVICT_FIRST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile with cache hint and mbar.""" + llvm.inline_asm( + T.i32(), + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $1;\n" + "add.u32 sa, sa, $2;\n" + "cvt.u32.u64 ma, $6;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile" + ".mbarrier::complete_tx::bytes.L2::cache_hint " + "[sa], [$3, {$4, $5}], [ma], $7;\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,l,r,r,l,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_gather4_cached( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row0, + row1, + row2, + row3, + mbar_ptr, + cache_hint=TMA_CACHE_EVICT_LAST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4 with cache hint.""" + llvm.inline_asm( + None, + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row0).ir_value(loc=loc, ip=ip), + Int32(row1).ir_value(loc=loc, ip=ip), + Int32(row2).ir_value(loc=loc, ip=ip), + Int32(row3).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $0;\n" + "add.u32 sa, sa, $1;\n" + "cvt.u32.u64 ma, $8;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4" + ".mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint " + "[sa], [$2, {$3, $4, $5, $6, $7}], [ma], $9;\n" + "}\n", + "l,r,l,r,r,r,r,r,l,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_tile_store( + tma_desc_ptr, + col_idx, + row_idx, + smem_ptr, + smem_byte_offset, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.global.shared::cta.bulk_group store.""" + llvm.inline_asm( + T.i32(), + [ + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa;\n" + "cvt.u32.u64 sa, $4;\n" + "add.u32 sa, sa, $5;\n" + "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group" + " [$1, {$2, $3}], [sa];\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,r,l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +# Descriptor Builders + +_TMA_DESC_BYTES = 128 + + +def _encode_tma_desc_2d_bytes(tensor_2d, *, box_x, box_y, context: str) -> bytes: + import torch + import cuda.bindings.driver as cuda + + if tensor_2d.ndim != 2: + raise ValueError(f"{context} tensor must be rank-2, got {tuple(tensor_2d.shape)}") + rows, cols = tensor_2d.shape + if tensor_2d.stride(-1) != 1: + raise ValueError(f"{context} tensor must be contiguous in the last dimension") + dtype_map = { + torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, + torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + } + if tensor_2d.dtype not in dtype_map: + raise TypeError(f"Unsupported dtype for {context} TMA descriptor: {tensor_2d.dtype}") + + sizes = [cuda.cuuint64_t(cols), cuda.cuuint64_t(rows)] + strides = [cuda.cuuint64_t(tensor_2d.stride(0) * tensor_2d.element_size())] + box = [cuda.cuuint32_t(box_x), cuda.cuuint32_t(box_y)] + elem_stride = [cuda.cuuint32_t(1), cuda.cuuint32_t(1)] + err, tm = cuda.cuTensorMapEncodeTiled( + dtype_map[tensor_2d.dtype], + 2, + tensor_2d.data_ptr(), + sizes, + strides, + box, + elem_stride, + cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, + cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, + cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, + ) + assert err == cuda.CUresult.CUDA_SUCCESS, f"TMA encode failed: {err}" + buf = (ctypes.c_uint8 * _TMA_DESC_BYTES).from_address(tm.getPtr()) + return bytes(buf) + + +def _desc_bytes_to_device_tensor(desc_bytes: bytes | bytearray, *, device): + import torch + + desc_bytes = bytes(desc_bytes) + device = torch.device(device) + if device.type != "cuda": + raise ValueError(f"TMA descriptors require a CUDA device, got {device}") + + host_desc = torch.empty((len(desc_bytes),), dtype=torch.uint8, pin_memory=True) + host_desc.copy_(torch.frombuffer(bytearray(desc_bytes), dtype=torch.uint8)) + device_desc = torch.empty((len(desc_bytes),), dtype=torch.uint8, device=device) + stream = torch.cuda.current_stream(device) + with torch.cuda.stream(stream): + device_desc.copy_(host_desc, non_blocking=True) + device_desc.record_stream(stream) + # Keep the staging buffer alive for the async copy without caching descriptors. + device_desc._tma_host_desc = host_desc + return device_desc + + +def create_flat_gather4_tma_desc(tensor_2d, box_x=64): + """Create a gather4 CUtensorMap descriptor for a flat 2D row-major tensor.""" + if tensor_2d.ndim != 2: + raise ValueError( + f"tensor_2d must be rank-2 [rows, dim], got {tuple(tensor_2d.shape)}" + ) + desc = _encode_tma_desc_2d_bytes( + tensor_2d, + box_x=box_x, + box_y=1, + context="gather4", + ) + return _desc_bytes_to_device_tensor(desc, device=tensor_2d.device) + + +def create_q_gather4_tma_desc(q_flat, box_x=64): + return create_flat_gather4_tma_desc(q_flat, box_x=box_x) + + +def create_strided_2d_tma_desc(tensor_2d, *, box_x, box_y): + """Create a CUtensorMap descriptor for a rank-2 tensor with arbitrary row stride.""" + desc = _encode_tma_desc_2d_bytes( + tensor_2d, + box_x=box_x, + box_y=box_y, + context="strided 2D", + ) + return _desc_bytes_to_device_tensor(desc, device=tensor_2d.device) + + +def create_flat_kv_tma_descs(kv_flat, *, box_x=64, box_y=128): + """Create per-KV-head token-major TMA descriptors for flat [total_k, H, D] storage.""" + import torch + + if kv_flat.ndim != 3: + raise ValueError( + f"kv_flat must be rank-3 [total_k, H, D], got {tuple(kv_flat.shape)}" + ) + total_k, head_kv, dim = kv_flat.shape + row_stride = head_kv * dim + desc_table = bytearray() + for h in range(head_kv): + head_view = torch.as_strided( + kv_flat, + size=(total_k, dim), + stride=(row_stride, 1), + storage_offset=h * dim, + ) + desc_table.extend( + _encode_tma_desc_2d_bytes( + head_view, + box_x=box_x, + box_y=box_y, + context="flat KV", + ) + ) + return _desc_bytes_to_device_tensor(desc_table, device=kv_flat.device).reshape( + head_kv, _TMA_DESC_BYTES + ) + + +# Compatibility Re-exports + +from .copy_utils import ( + atomic_add_broadcast_i32, + atomic_add_i32, + convert_layout_acc_mn, + convert_layout_from_tmem16x256b_to_acc_sm90, + make_16x256b_tensor_mn_view, + real_col_to_stg128_fake_col, + real_col_to_stg128_fp8_fake_col, + real_col_to_stg128_half_fake_col, + stg128_fake_col_to_real_col, + stg128_fp8_fake_col_to_real_col, + stg128_half_fake_col_to_real_col, + stg_128, + stg_128_cs, + stg_128_bf16, + stg_128_bf16_cs, + stg_128_f16, + stg_128_f16_cs, + stg_128_fp8_e4m3_cs, + stg_32_fp8_e4m3, + stg_64_bf16, + stg_64_f16, +) + + +__all__ = [ + "TMA_CACHE_EVICT_FIRST", + "TMA_CACHE_EVICT_LAST", + "atomic_add_broadcast_i32", + "atomic_add_i32", + "convert_layout_acc_mn", + "convert_layout_from_tmem16x256b_to_acc_sm90", + "create_flat_gather4_tma_desc", + "create_flat_kv_tma_descs", + "create_q_gather4_tma_desc", + "create_strided_2d_tma_desc", + "make_16x256b_tensor_mn_view", + "prefetch_tma_desc_raw", + "real_col_to_stg128_fake_col", + "real_col_to_stg128_fp8_fake_col", + "real_col_to_stg128_half_fake_col", + "stg128_fake_col_to_real_col", + "stg128_fp8_fake_col_to_real_col", + "stg128_half_fake_col_to_real_col", + "stg_128", + "stg_128_cs", + "stg_128_bf16", + "stg_128_bf16_cs", + "stg_128_f16", + "stg_128_f16_cs", + "stg_128_fp8_e4m3_cs", + "stg_32_fp8_e4m3", + "stg_64_bf16", + "stg_64_f16", + "tma_gather4", + "tma_gather4_cached", + "tma_gather4_prefetch", + "tma_tile_load", + "tma_tile_load_cached", + "tma_tile_prefetch", + "tma_tile_store", +] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/common/utils.py b/build/torch212-cxx11-cu130-x86_64-linux/src/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e1bd0ba76b532cb54c159eba5e82320266c80c63 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/common/utils.py @@ -0,0 +1,1088 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import math +import hashlib +import inspect +from typing import Type, Callable, Optional, Tuple, overload + +import cutlass +import cutlass.cute as cute + +from cutlass import Float32, const_expr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import nvvm, llvm +from cutlass.cute.runtime import from_dlpack + + +from ...quack import activation +_MIXER_ATTRS = ("__vec_size__",) + +# Obtained from sollya: +# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative); +POLY_EX2 = { + 0: (1.0), + 1: ( + 1.0, + 0.922497093677520751953125, + ), + 2: ( + 1.0, + 0.6657850742340087890625, + 0.330107033252716064453125, + ), + 3: ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ), + 4: ( + 1.0, + 0.693042695522308349609375, + 0.2412912547588348388671875, + 5.2225358784198760986328125e-2, + 1.3434938155114650726318359375e-2, + ), + 5: ( + 1.0, + 0.693151414394378662109375, + 0.24016360938549041748046875, + 5.5802188813686370849609375e-2, + 9.01452265679836273193359375e-3, + 1.86810153536498546600341796875e-3, + ), +} + + +def _compute_base_hash(func: Callable) -> str: + """Compute hash from source code or bytecode and closure values.""" + try: + data = inspect.getsource(func).encode() + except (OSError, TypeError): + if hasattr(func, "__code__") and func.__code__ is not None: + data = func.__code__.co_code + else: + data = repr(func).encode() + + hasher = hashlib.sha256(data) + + if hasattr(func, "__closure__") and func.__closure__ is not None: + for cell in func.__closure__: + hasher.update(repr(cell.cell_contents).encode()) + + return hasher.hexdigest() + + +def hash_callable( + func: Callable, mixer_attrs: Tuple[str] = _MIXER_ATTRS, set_cute_hash: bool = True +) -> str: + """Hash a callable based on the source code or bytecode and closure values. + Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__`` + attribute, that value is returned immediately as the base hash, then + metadata dunders are mixed in to produce the final dict-key hash. + set_cute_hash: whether or not to set func.__cute_hash__ + """ + # Resolve base hash + if hasattr(func, "__cute_hash__"): + base_hash = func.__cute_hash__ + else: + # Unwrap decorated functions (e.g., cute.jit wrappers). + base_func = getattr(func, "__wrapped__", func) + + if hasattr(base_func, "__cute_hash__"): + base_hash = base_func.__cute_hash__ + else: + base_hash = _compute_base_hash(base_func) + + if set_cute_hash: + base_func.__cute_hash__ = base_hash + + # Mix in mutable metadata dunders + mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs) + + if all(v is None for v in mixer_values): + return base_hash + + hasher = hashlib.sha256(base_hash.encode()) + + for attr, val in zip(_MIXER_ATTRS, mixer_values): + hasher.update(f"{attr}={val!r}".encode()) + + return hasher.hexdigest() + + +LOG2_E = math.log2(math.e) + + +def compute_softmax_scale_log2(softmax_scale): + """Compute softmax_scale_log2 from softmax_scale. + + Returns (softmax_scale_log2, None). + """ + return softmax_scale * LOG2_E, None + + +def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: + return ( + from_dlpack(x, assumed_align=alignment) + .mark_layout_dynamic(leading_dim=leading_dim) + .mark_compact_shape_dynamic( + mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility + ) + ) + + +def make_tiled_copy_A( + copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.TiledCopy: + if const_expr(swapAB): + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + + +def make_tiled_copy_B( + copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.TiledCopy: + if const_expr(swapAB): + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + + +def mma_make_fragment_A( + smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.Tensor: + if const_expr(swapAB): + return mma_make_fragment_B(smem, thr_mma) + else: + return thr_mma.make_fragment_A(thr_mma.partition_A(smem)) + + +def mma_make_fragment_B( + smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.Tensor: + if const_expr(swapAB): + return mma_make_fragment_A(smem, thr_mma) + else: + return thr_mma.make_fragment_B(thr_mma.partition_B(smem)) + + +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=2 * element_type.width, + ) + else: + return cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +@cute.jit +def warp_reduce( + val: cute.TensorSSA | cute.Numeric, + op: Callable, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.TensorSSA | cute.Numeric: + if const_expr(isinstance(val, cute.TensorSSA)): + res = cute.make_rmem_tensor(val.shape, val.dtype) + res.store(val) + for i in cutlass.range_constexpr(cute.size(val.shape)): + res[i] = warp_reduce(res[i], op, width) + return res.load() + else: + for i in cutlass.range_constexpr(int(math.log2(width))): + val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i)) + return val + + +@dsl_user_op +def fmax( + a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None +) -> Float32: + from cutlass import CUDA_VERSION + + # * NVVM call based on nvvm version + if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9: + # Old API: requires explicit result type as first positional argument + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + else: + # New API: infers result type automatically + return Float32( + nvvm.fmax( + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + + +@cute.jit +def fmax_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + res = cute.make_rmem_tensor(x.shape, Float32) + res.store(x) + local_max = [res[0], res[1], res[2], res[3]] + for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + local_max[0] = fmax(local_max[0], res[i + 0]) + local_max[1] = fmax(local_max[1], res[i + 1]) + local_max[2] = fmax(local_max[2], res[i + 2]) + local_max[3] = fmax(local_max[3], res[i + 3]) + local_max[0] = fmax(local_max[0], local_max[1]) + local_max[2] = fmax(local_max[2], local_max[3]) + local_max[0] = fmax(local_max[0], local_max[2]) + return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val) + else: + res = cute.make_rmem_tensor(x.shape, Float32) + res.store(x) + local_max_0 = ( + fmax(init_val, res[0], res[1]) + if const_expr(init_val is not None) + else fmax(res[0], res[1]) + ) + local_max = [ + local_max_0, + fmax(res[2], res[3]), + fmax(res[4], res[5]), + fmax(res[6], res[7]), + ] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_max[0] = fmax(local_max[0], res[i], res[i + 1]) + local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3]) + local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5]) + local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7]) + local_max[0] = fmax(local_max[0], local_max[1]) + return fmax(local_max[0], local_max[2], local_max[3]) + + +@cute.jit +def fadd_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + if const_expr(init_val is None): + init_val = Float32.zero + return x.reduce(cute.ReductionOp.ADD, init_val, 0) + else: + res = cute.make_rmem_tensor(x.shape, Float32) + res.store(x) + local_sum_0 = ( + cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1])) + if const_expr(init_val is not None) + else (res[0], res[1]) + ) + local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1])) + local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3])) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5])) + local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7])) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + + +@cute.jit +def fadd_exp2_scaled_reduce( + x: cute.Tensor, scale: Float32, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + assert cute.size(x.shape) % 2 == 0, "x must have an even number of elements" + if const_expr(arch < 100): + return fadd_reduce(cute.math.exp2(x.load() * scale, fastmath=True), arch=arch) + elif const_expr(cute.size(x.shape) % 8 == 0): + local_sum = [ + (Float32(0.0), Float32(0.0)), + (Float32(0.0), Float32(0.0)), + (Float32(0.0), Float32(0.0)), + (Float32(0.0), Float32(0.0)), + ] + for i in cutlass.range_constexpr(0, cute.size(x.shape), 8): + acc0, acc1 = cute.arch.mul_packed_f32x2( + (x[i + 0], x[i + 1]), (scale, scale) + ) + acc2, acc3 = cute.arch.mul_packed_f32x2( + (x[i + 2], x[i + 3]), (scale, scale) + ) + acc4, acc5 = cute.arch.mul_packed_f32x2( + (x[i + 4], x[i + 5]), (scale, scale) + ) + acc6, acc7 = cute.arch.mul_packed_f32x2( + (x[i + 6], x[i + 7]), (scale, scale) + ) + acc0 = cute.math.exp2(acc0, fastmath=True) + acc1 = cute.math.exp2(acc1, fastmath=True) + acc2 = cute.math.exp2(acc2, fastmath=True) + acc3 = cute.math.exp2(acc3, fastmath=True) + acc4 = cute.math.exp2(acc4, fastmath=True) + acc5 = cute.math.exp2(acc5, fastmath=True) + acc6 = cute.math.exp2(acc6, fastmath=True) + acc7 = cute.math.exp2(acc7, fastmath=True) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (acc0, acc1)) + local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (acc2, acc3)) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (acc4, acc5)) + local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (acc6, acc7)) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + else: + row_sum = Float32(0.0) + for i in cutlass.range_constexpr(0, cute.size(x.shape), 2): + acc0, acc1 = cute.arch.mul_packed_f32x2( + (x[i], x[i + 1]), (scale, scale) + ) + acc0 = cute.math.exp2(acc0, fastmath=True) + acc1 = cute.math.exp2(acc1, fastmath=True) + row_sum += acc0 + acc1 + return row_sum + + +@dsl_user_op +def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None: + nvvm.atomicrmw( + res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value() + ) + + +@dsl_user_op +def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + +@cute.jit +def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" + tApA = cute.make_rmem_tensor( + cute.make_layout( + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + return tApA + + +def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: + warp_group_idx = cute.arch.thread_idx()[0] // 128 + if const_expr(sync): + warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx) + return warp_group_idx + + +@cute.jit +def shuffle_sync( + value: cute.Numeric, + offset: cute.typing.Int, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.Numeric: + assert value.width % 32 == 0, "value type must be a multiple of 32 bits" + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + # important: need stride 1 and not 0 for recast_tensor to work + val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value)) + val[0] = value + val_i32 = cute.recast_tensor(val, cutlass.Int32) + for i in cutlass.range_constexpr(cute.size(val_i32)): + val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp) + return val[0] + + +@dsl_user_op +def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + """ + Left-shift val by shift bits using PTX shl.b32 (sign-agnostic). + + Named ``shl_u32`` (not ``shl_b32``) because python type annotations + distinguish signed/unsigned. + + PTX semantics (9.7.8.8): "Shift amounts greater than the register width N + are clamped to N." So ``shl.b32 d, a, 32`` is well-defined and yields 0. + + This differs from C/C++ and LLVM IR, where shifting by >= the type width is + undefined behavior. CuTeDSL compiles through MLIR -> LLVM IR, so a plain + Python-level ``Uint32(x) << Uint32(n)`` inherits LLVM's UB: the optimizer + may treat the result as poison and eliminate dependent code. Inline PTX + bypasses the LLVM IR shift entirely -- the instruction is emitted verbatim + into PTX where clamping makes it safe for all shift amounts. + """ + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], + "shl.b32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + """ + Unsigned right-shift val by shift bits using PTX shr.u32 (zero-fills). + + See ``shl_u32`` docstring for why inline PTX is used instead of plain + CuTeDSL shift operators (LLVM shift-by-type-width UB). + """ + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], + "shr.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32: + if const_expr(lane is None): + lane = cute.arch.lane_idx() + for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))): + offset = 1 << i + # Very important that we set mask_and_clamp to 0 + partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0) + if lane >= offset: + val += partial_sum + return val + + +@dsl_user_op +def cvt_f16x2_f32( + a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None +) -> cutlass.Int32: + assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16" + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)], + f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;", + "=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def cvt_fp8x4_e4m3_f32( + a: float | Float32, + b: float | Float32, + c: float | Float32, + d: float | Float32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [ + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1;\n" + "cvt.rn.satfinite.e4m3x2.f32 h0, $2, $1;\n" + "cvt.rn.satfinite.e4m3x2.f32 h1, $4, $3;\n" + "mov.b32 $0, {h0, h1};\n" + "}\n", + "=r,f,f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def cvt_fp8x4_e4m3_bf16x4( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32]: + """Convert packed e4m3x4 bits into two packed bf16x2 registers.""" + out0 = cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b32 q, mant, out, bias, zero;\n\t" + "prmt.b32 q, $1, $1, 0x1302;\n\t" + "and.b32 out, q, 0x80008000;\n\t" + "and.b32 mant, q, 0x7f007f00;\n\t" + "shr.u32 mant, mant, 4;\n\t" + "or.b32 out, out, mant;\n\t" + "mov.b32 bias, 0x7b807b80;\n\t" + "mov.b32 zero, 0;\n\t" + "fma.rn.bf16x2 $0, out, bias, zero;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + out1 = cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b32 q, qs, mant, out, bias, zero;\n\t" + "prmt.b32 q, $1, $1, 0x1302;\n\t" + "shl.b32 qs, q, 8;\n\t" + "and.b32 out, qs, 0x80008000;\n\t" + "and.b32 mant, qs, 0x7f007f00;\n\t" + "shr.u32 mant, mant, 4;\n\t" + "or.b32 out, out, mant;\n\t" + "mov.b32 bias, 0x7b807b80;\n\t" + "mov.b32 zero, 0;\n\t" + "fma.rn.bf16x2 $0, out, bias, zero;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return out0, out1 + + +@dsl_user_op +def cvt_fp4x2_e2m1_f16x2( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Convert one packed E2M1 byte into one packed f16x2 register.""" + + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b8 byte0;\n\t" + "mov.b32 {byte0, _, _, _}, $1;\n\t" + "cvt.rn.f16x2.e2m1x2 $0, byte0;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def cvt_fp4x8_e2m1_f16x8( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Convert four packed E2M1 bytes into four packed f16x2 registers.""" + + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $4;\n\t" + "cvt.rn.f16x2.e2m1x2 $0, byte0;\n\t" + "cvt.rn.f16x2.e2m1x2 $1, byte1;\n\t" + "cvt.rn.f16x2.e2m1x2 $2, byte2;\n\t" + "cvt.rn.f16x2.e2m1x2 $3, byte3;\n\t" + "}\n", + "=r,=r,=r,=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + out2 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [2], loc=loc, ip=ip)) + out3 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [3], loc=loc, ip=ip)) + return out0, out1, out2, out3 + + +@dsl_user_op +def cvt_fp4x8_e2m1_bf16x8( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Convert four packed E2M1 bytes into four packed bf16x2 registers.""" + + from cutlass import CUDA_VERSION + + if CUDA_VERSION.major > 13 or (CUDA_VERSION.major == 13 and CUDA_VERSION.minor >= 2): + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $4;\n\t" + "cvt.rn.bf16x2.e2m1x2 $0, byte0;\n\t" + "cvt.rn.bf16x2.e2m1x2 $1, byte1;\n\t" + "cvt.rn.bf16x2.e2m1x2 $2, byte2;\n\t" + "cvt.rn.bf16x2.e2m1x2 $3, byte3;\n\t" + "}\n", + "=r,=r,=r,=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + out2 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [2], loc=loc, ip=ip)) + out3 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [3], loc=loc, ip=ip)) + return out0, out1, out2, out3 + + f16_pair0, f16_pair1, f16_pair2, f16_pair3 = cvt_fp4x8_e2m1_f16x8( + src, loc=loc, ip=ip + ) + return ( + cvt_f16x2_to_bf16x2(f16_pair0, loc=loc, ip=ip), + cvt_f16x2_to_bf16x2(f16_pair1, loc=loc, ip=ip), + cvt_f16x2_to_bf16x2(f16_pair2, loc=loc, ip=ip), + cvt_f16x2_to_bf16x2(f16_pair3, loc=loc, ip=ip), + ) + + +@dsl_user_op +def cvt_fp4x8_e2m1_scaled_e4m3x8( + src: cutlass.Int32, + scale_e4m3: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32]: + """Scale eight packed E2M1 values by one E4M3 byte and convert to E4M3.""" + + from cutlass import CUDA_VERSION + + if CUDA_VERSION.major > 13 or (CUDA_VERSION.major == 13 and CUDA_VERSION.minor >= 2): + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32()]), + [ + cutlass.Int32(src).ir_value(loc=loc, ip=ip), + cutlass.Int32(scale_e4m3).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .b32 tmp, ra;\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + "prmt.b32 tmp, $3, 0, 0;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $2;\n\t" + "mov.b32 ra, {byte0, byte1, _, _};\n\t" + "mul.e4m3x4.e2m1x4.e4m3x4.satfinite $0, ra, tmp;\n\t" + "mov.b32 ra, {_, _, byte2, byte3};\n\t" + "mul.e4m3x4.e2m1x4.e4m3x4.satfinite $1, ra, tmp;\n\t" + "}\n", + "=r,=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + return out0, out1 + + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32()]), + [ + cutlass.Int32(src).ir_value(loc=loc, ip=ip), + cutlass.Int32(scale_e4m3).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .b32 sf_bytes, sf_f16x2;\n\t" + ".reg .b16 sf_pair, e0, e1, e2, e3;\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + ".reg .b32 h0, h1, h2, h3;\n\t" + "prmt.b32 sf_bytes, $3, 0, 0;\n\t" + "mov.b32 {sf_pair, _}, sf_bytes;\n\t" + "cvt.rn.f16x2.e4m3x2 sf_f16x2, sf_pair;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $2;\n\t" + "cvt.rn.f16x2.e2m1x2 h0, byte0;\n\t" + "cvt.rn.f16x2.e2m1x2 h1, byte1;\n\t" + "cvt.rn.f16x2.e2m1x2 h2, byte2;\n\t" + "cvt.rn.f16x2.e2m1x2 h3, byte3;\n\t" + "mul.rn.f16x2 h0, h0, sf_f16x2;\n\t" + "mul.rn.f16x2 h1, h1, sf_f16x2;\n\t" + "mul.rn.f16x2 h2, h2, sf_f16x2;\n\t" + "mul.rn.f16x2 h3, h3, sf_f16x2;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e0, h0;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e1, h1;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e2, h2;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e3, h3;\n\t" + "mov.b32 $0, {e0, e1};\n\t" + "mov.b32 $1, {e2, e3};\n\t" + "}\n", + "=r,=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + return out0, out1 + + +@dsl_user_op +def cvt_f16x2_to_bf16x2( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Convert a packed f16x2 register into a packed bf16x2 register.""" + + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b16 h0, h1;\n\t" + ".reg .f32 f0, f1;\n\t" + "mov.b32 {h0, h1}, $1;\n\t" + "cvt.f32.f16 f0, h0;\n\t" + "cvt.f32.f16 f1, h1;\n\t" + "cvt.rn.bf16x2.f32 $0, f1, f0;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def mul_bf16x2( + a: cutlass.Int32, + b: cutlass.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Multiply two packed bf16x2 registers.""" + + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Int32(a).ir_value(loc=loc, ip=ip), + cutlass.Int32(b).ir_value(loc=loc, ip=ip), + ], + "mul.rn.bf16x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def cvt_fp8_e4m3_to_bf16x2_replicated(src: cutlass.Int32) -> cutlass.Int32: + """Decode one E4M3 byte and replicate it into a packed bf16x2 register.""" + + src_u8 = src & cutlass.Int32(0xFF) + packed = src_u8 * cutlass.Int32(0x01010101) + out0, _ = cvt_fp8x4_e4m3_bf16x4(packed) + return out0 + + +@overload +def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ... + + +@overload +def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ... + + +@cute.jit +def cvt_f16(src: cute.Tensor, dst_or_dtype): + """Convert Float32 tensor to Float16/BFloat16. + + Args: + src: Source tensor with Float32 element type + dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16) + + Returns: + None if dst is a tensor, or a new tensor if dtype is provided + """ + if const_expr(isinstance(dst_or_dtype, type)): + # dtype variant: create new tensor and call the tensor variant + dtype = dst_or_dtype + dst = cute.make_rmem_tensor(src.shape, dtype) + cvt_f16(src, dst) + return dst + else: + # tensor variant: write to dst + dst = dst_or_dtype + assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size" + assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements" + assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], ( + "dst must be BFloat16 or Float16" + ) + assert src.element_type is Float32, "src must be Float32" + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) + for i in cutlass.range_constexpr(cute.size(dst_i32)): + dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) + + +@cute.jit +def cvt_f32(src: cute.Tensor, dst: cute.Tensor) -> None: + """Convert a Float32 rmem tensor to dst's element type. + + fp8 path uses the reference fp8 quantize pattern: fragment-by-fragment + ``.store(.load().to(fp8))`` over groups of ``frg_tile=4``. This lets the + DSL emit ``cvt.rn.satfinite.e4m3x2.f32`` pairs and pack the resulting fp8 + bytes within a 32-bit register cell in the order DSL chooses, which is + expected to match the K-adjacency that SM100 fp8 UMMA fragment_A reads. + """ + if const_expr(dst.element_type in [cutlass.BFloat16, cutlass.Float16]): + cvt_f16(src, dst) + elif const_expr(dst.element_type is cutlass.Float8E4M3FN): + assert src.element_type is Float32, "src must be Float32" + assert cute.size(src.shape) == cute.size(dst.shape), "dst and src must have the same size" + assert cute.size(src.shape) % 4 == 0, "src must have a multiple of 4 elements" + frg_tile = 4 + src_frg = cute.logical_divide(src, cute.make_layout(frg_tile)) + dst_frg = cute.logical_divide(dst, cute.make_layout(frg_tile)) + for i in cutlass.range_constexpr(cute.size(src_frg, mode=[1])): + dst_frg[None, i].store(src_frg[None, i].load().to(dst.element_type)) + else: + assert src.element_type is Float32, "src must be Float32" + dst_view = cute.make_tensor(dst.iterator, src.layout) + dst_view.store(src.load().to(dst.element_type)) + + +@dsl_user_op +@cute.jit +def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32: + deg = len(poly) - 1 + out = poly[deg] + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = out * x + poly[i] + return out + + +@dsl_user_op +@cute.jit +def evaluate_polynomial_2( + x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None +) -> Tuple[Float32, Float32]: + deg = len(poly) - 1 + out = (poly[deg], poly[deg]) + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i])) + return out + + +@dsl_user_op +def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32: + # There's probably a way to call llvm or nvvm to do this instead of ptx + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)], + "add.rm.ftz.f32 $0, $1, $2;", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32: + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [ + Float32(x_rounded).ir_value(loc=loc, ip=ip), + Float32(frac_ex2).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t" + "mov.b32 x_rounded_i, $1;\n\t" + "mov.b32 frac_ex_i, $2;\n\t" + "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t" + # add.u32 generates IMAD instruction and add.s32 generates LEA instruction + # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik + "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t" + "mov.b32 $0, out_i;\n\t" + "}\n", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32: + assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported" + # We assume x <= 127.0 + fp32_round_int = float(2**23 + 2**22) + x_clamped = cute.arch.fmax(x, -127.0) + # We want to round down here, so that the fractional part is in [0, 1) + x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip) + # The integer floor of x is now in the last 8 bits of x_rounded + # We assume the next 2 ops round to nearest even. The rounding mode is important. + x_rounded_back = x_rounded - fp32_round_int + x_frac = x_clamped - x_rounded_back + x_frac_ex2 = evaluate_polynomial(x_frac, POLY_EX2[poly_degree], loc=loc, ip=ip) + return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip) + + +@dsl_user_op +def ex2_emulation_2( + x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None +) -> Tuple[Float32, Float32]: + # We assume x <= 127.0 and y <= 127.0 + fp32_round_int = float(2**23 + 2**22) + xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0)) + # We want to round down here, so that the fractional part is in [0, 1) + xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd="rm") + # The integer floor of x & y are now in the last 8 bits of xy_rounded + # We want the next 2 ops to round to nearest even. The rounding mode is important. + xy_rounded_back = activation.sub_packed_f32x2( + xy_rounded, (fp32_round_int, fp32_round_int) + ) + xy_frac = activation.sub_packed_f32x2(xy_clamped, xy_rounded_back) + xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, POLY_EX2[poly_degree], loc=loc, ip=ip) + x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip) + y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip) + return x_out, y_out + + +@dsl_user_op +def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + out_f32x2 = llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32()]), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()], + "{\n\t" + ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t" + ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t" + ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t" + "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t" + "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t" + "mov.b64 l1, {f1, f2};\n\t" + "mov.f32 f3, 0f4B400000;\n\t" + "mov.b64 l2, {f3, f3};\n\t" + "add.rm.ftz.f32x2 l7, l1, l2;\n\t" + "sub.rn.ftz.f32x2 l8, l7, l2;\n\t" + "sub.rn.ftz.f32x2 l9, l1, l8;\n\t" + "mov.f32 f7, 0f3D9DF09D;\n\t" + "mov.b64 l6, {f7, f7};\n\t" + "mov.f32 f6, 0f3E6906A4;\n\t" + "mov.b64 l5, {f6, f6};\n\t" + "mov.f32 f5, 0f3F31F519;\n\t" + "mov.b64 l4, {f5, f5};\n\t" + "mov.f32 f4, 0f3F800000;\n\t" + "mov.b64 l3, {f4, f4};\n\t" + "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t" + "mov.b64 {r1, r2}, l7;\n\t" + "mov.b64 {r3, r4}, l10;\n\t" + "shl.b32 r5, r1, 23;\n\t" + "add.s32 r7, r5, r3;\n\t" + "shl.b32 r6, r2, 23;\n\t" + "add.s32 r8, r6, r4;\n\t" + "mov.b32 $0, r7;\n\t" + "mov.b32 $1, r8;\n\t" + "}\n", + "=r,=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip)) + out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip)) + return out0, out1 + + +@dsl_user_op +def domain_offset_aligned( + coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None +) -> cute.Tensor: + assert isinstance(tensor.iterator, cute.Pointer) + # We assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + elem_pointer(tensor, coord).toint(), + tensor.memspace, + assumed_align=tensor.iterator.alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@cute.jit +def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: + """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype""" + vec = cute.make_rmem_tensor(1, dtype) + vec[0] = a + return vec.load() + + +def ssa_to_scalar(val): + """Could inline but nice for reflecting the above api""" + return val[0] + + +# ------------------------------------------------------------------ +# Host-side Python helpers (not @cute.jit — called from PyTorch host code) +# ------------------------------------------------------------------ + +def default_softmax_scale(dim: int) -> float: + """Default softmax scale: 1 / sqrt(dim).""" + return 1.0 / math.sqrt(dim) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f23267fe73800d35db382a1919bc28196da5aa8c --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 sparse attention kernels.""" diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/build_k2q_csr/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/build_k2q_csr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf19c60a32d2f57595c9666323b47738b878115 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/build_k2q_csr/__init__.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""q2k -> k2q CSR builder backed by the precompiled Torch ops. + +The CUDA implementation lives in ``csrc/build_k2q_csr.cu`` and is built +ahead of time by kernel-builder; it is reached through the ``_ops`` +namespace instead of being JIT-compiled at import time. + +The kernel pipeline is tuned and verified for SM100; other +architectures are not supported. +""" + +from __future__ import annotations + +import torch + +from ...._ops import ops + + +def run_build_k2q_csr( + q2k: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + row_ptr: torch.Tensor, + q_idx: torch.Tensor, + topk: int, + blk_kv: int, + total_rows: int, + max_kv_blocks: int, +) -> None: + """In-place fill of ``row_ptr`` and ``q_idx``. + + Args: + q2k: int32 [H, total_q, topK] contiguous (CUDA). + cu_seqlens_q: int32 [B+1] contiguous (CUDA). + cu_seqlens_k: int32 [B+1] contiguous (CUDA). + row_ptr: int32 [H, total_rows + 1] CUDA, written in place. + q_idx: int32 [H, total_q * topK] CUDA, written in place + (trailing slots set to -1). + topk: must be in {4, 8, 16, 32}. + blk_kv: must equal 128. + total_rows: sum over batches of ceil(seqlen_k / blk_kv). + max_kv_blocks: max over batches of ceil(seqlen_k / blk_kv); upper bound + used to size the row_map workspace and clamp valid kv ids. + """ + ops.run_build_k2q_csr( + q2k, + cu_seqlens_q, + cu_seqlens_k, + row_ptr, + q_idx, + int(topk), + int(blk_kv), + int(total_rows), + int(max_kv_blocks), + ) + + +def run_build_k2q_csr_with_schedule( + q2k: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + row_ptr: torch.Tensor, + q_idx: torch.Tensor, + scheduler_metadata: torch.Tensor, + work_count: torch.Tensor, + qsplit_idx: torch.Tensor, + split_counts: torch.Tensor, + topk: int, + blk_kv: int, + total_rows: int, + max_kv_blocks: int, + target_q_per_cta: int, + work_capacity: int, + max_seqlen_q: int, +) -> None: + """In-place fill of CSR plus fused sparse attention schedule metadata.""" + ops.run_build_k2q_csr_with_schedule( + q2k, + cu_seqlens_q, + cu_seqlens_k, + row_ptr, + q_idx, + scheduler_metadata, + work_count, + qsplit_idx, + split_counts, + int(topk), + int(blk_kv), + int(total_rows), + int(max_kv_blocks), + int(target_q_per_cta), + int(work_capacity), + int(max_seqlen_q), + ) + + +def is_supported(topk: int, blk_kv: int) -> bool: + return int(topk) in (4, 8, 16, 32) and int(blk_kv) == 128 + + +__all__ = ["run_build_k2q_csr", "run_build_k2q_csr_with_schedule", "is_supported"] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/decode_schedule.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/decode_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..037791818feb030a5969ebf6ac3cc3943cdb7dce --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/decode_schedule.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Split-KV schedule for paged fp8 decode attention. + +The public PageKV representation remains this repo's rectangular page table: +``page_table [B, max_pages]`` plus ``seqused_k [B]``. The schedule only +describes how query tiles and KV chunks are split into work items. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class DecodeAttentionSchedule: + split_kv: bool + cta_tile_q: int + num_q_tiles: int + kv_chunk_size_pages: int + kv_chunk_size_tokens: int + work_count: int + padded_work_count: int + partial_rows: int + max_split_count: int + max_grid_size: int + active_blocks_per_sm: int + num_sms: int + base_cta: int + request_indices: torch.Tensor + qo_tile_indices: torch.Tensor + kv_tile_indices: torch.Tensor + merge_indptr: torch.Tensor + o_indptr: torch.Tensor + block_valid_mask: torch.Tensor + kv_pages: torch.Tensor + split_counts: torch.Tensor + + +def _require_i32_cuda_1d(tensor: torch.Tensor, *, name: str) -> None: + if tensor.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32") + if tensor.ndim != 1: + raise ValueError(f"{name} must be rank-1") + if not tensor.is_cuda: + raise ValueError(f"{name} must be a CUDA tensor") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def prepare_decode_schedule( + *, + seqused_k: torch.Tensor, + page_size: int, + seqlen_q: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + max_seqlen_k: int, + enable_cuda_graph: bool = False, + max_grid_size: Optional[int] = None, + fixed_split_size: Optional[int] = None, + disable_split_kv: bool = False, +) -> DecodeAttentionSchedule: + """Build paged decode split-KV schedule on the GPU. + + A single CUDA kernel reads ``seqused_k`` on device and writes all + schedule index arrays. Only a small summary tensor is D2H-synced so + the wrapper can size O_partial / pick the kernel grid / choose the + split-vs-non-split compile path. + + ``max_seqlen_k`` is the host-side worst-case bound used to pad the + work-tile arrays. It must satisfy ``max(seqused_k) <= max_seqlen_k``. + """ + _require_i32_cuda_1d(seqused_k, name="seqused_k") + # Hard cap: current single-CTA schedule kernel stores per-batch state + # in shared memory. Larger batches require a multi-CTA cooperative + # scheduler (unimplemented). Fail fast at the Python boundary so the + # error doesn't surface from inside the CUDA extension. + if int(seqused_k.shape[0]) > 1024: + raise NotImplementedError( + "decode schedule currently supports batch <= 1024 " + f"(got batch={int(seqused_k.shape[0])}). Larger batches need " + "the multi-CTA scheduler — not yet implemented." + ) + # Two API-boundary checks tied to the kernel's packed-GQA layout + # (q_tokens_per_group = m_block_size / qhead_per_kv = 128/16 = 8): + # + # (1) seqused_k[b] >= seqlen_q. The kernel computes the causal mask as + # col_limit = row_idx + seqlen_k - seqlen_q + 1. For row 0 (first + # q-token in the packed group) this is col_limit = seqlen_k - seqlen_q + # + 1, which goes <= 0 whenever seqlen_k < seqlen_q. That all-masked + # row then enters a mask-codegen path with PTX-undefined shift counts + # and the kernel hangs. The condition is also semantically invalid + # in batched-decode: you can't emit seqlen_q new tokens with fewer + # than seqlen_q total context tokens (seqlen_k includes them). + # + # (2) seqused_k[b] % page_size ∈ {0, 8, 16, ..., 120}. Same hang fires + # when the LAST partial page has < q_tokens_per_group=8 valid + # columns, because then the *last MMA tile* hits the same all-masked + # row case for the trailing q-tokens. + # + # Both are tracked as a separate kernel-level TODO (un-pack the + # all-masked row → skip mask call, or saturate causal_col_limit at >= 1 + # in mask.py). Until then, fail fast at the Python boundary with a + # clear message rather than letting the kernel timeout. + seqlen_q_i = int(seqlen_q) + bad_q = seqused_k < seqlen_q_i + if bool(bad_q.any().item()): + bad_idx = int(torch.nonzero(bad_q, as_tuple=True)[0][0].item()) + bad_val = int(seqused_k[bad_idx].item()) + raise ValueError( + f"decode kernel requires seqused_k[b] >= seqlen_q (= {seqlen_q_i}) " + f"for every batch. Got seqused_k[{bad_idx}]={bad_val}. " + f"This is also a batched-decode invariant: seqlen_k must include " + f"the seqlen_q new tokens being emitted." + ) + rem = seqused_k % int(page_size) + bad_rem = (rem > 0) & (rem < seqlen_q_i) + if bool(bad_rem.any().item()): + bad_idx = int(torch.nonzero(bad_rem, as_tuple=True)[0][0].item()) + bad_val = int(seqused_k[bad_idx].item()) + raise ValueError( + f"decode kernel requires seqused_k[b] % page_size ∈ " + f"{{0, {seqlen_q_i}, {seqlen_q_i*2}, ..., {(page_size//seqlen_q_i)*seqlen_q_i}}}. " + f"Got seqused_k[{bad_idx}]={bad_val}, last partial page has " + f"{bad_val % int(page_size)} valid columns (< seqlen_q={seqlen_q_i}). " + f"Round seqused_k up to the next multiple of {seqlen_q_i} OR to " + f"a multiple of {page_size}." + ) + if int(page_size) <= 0: + raise ValueError("page_size must be positive") + if int(seqlen_q) <= 0: + raise ValueError("seqlen_q must be positive") + if int(num_qo_heads) <= 0 or int(num_kv_heads) <= 0: + raise ValueError("head counts must be positive") + if int(num_qo_heads) % int(num_kv_heads) != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads") + if int(num_qo_heads) // int(num_kv_heads) != 16: + raise NotImplementedError("decode schedule currently supports only qhead_per_kv=16") + if int(head_dim) != 128: + raise NotImplementedError("decode schedule currently supports only head_dim=128") + if int(max_seqlen_k) <= 0: + raise ValueError("max_seqlen_k must be positive") + + from ...src.sm100.fwd_decode.build_decode_schedule import build_decode_schedule + + raw = build_decode_schedule( + seqused_k, + page_size=int(page_size), + seqlen_q=int(seqlen_q), + num_qo_heads=int(num_qo_heads), + num_kv_heads=int(num_kv_heads), + head_dim=int(head_dim), + max_seqlen_k=int(max_seqlen_k), + enable_cuda_graph=bool(enable_cuda_graph), + max_grid_size=0 if max_grid_size is None else int(max_grid_size), + fixed_split_size=-1 if fixed_split_size is None else int(fixed_split_size), + disable_split_kv=bool(disable_split_kv), + ) + return DecodeAttentionSchedule( + split_kv=bool(raw["split_kv"]), + cta_tile_q=int(raw["cta_tile_q"]), + num_q_tiles=int(raw["num_q_tiles"]), + kv_chunk_size_pages=int(raw["kv_chunk_size_pages"]), + kv_chunk_size_tokens=int(raw["kv_chunk_size_tokens"]), + work_count=int(raw["work_count"]), + padded_work_count=int(raw["padded_work_count"]), + partial_rows=int(raw["partial_rows"]), + max_split_count=int(raw["max_split_count"]), + max_grid_size=int(raw["max_grid_size"]), + active_blocks_per_sm=int(raw["active_blocks_per_sm"]), + num_sms=int(raw["num_sms"]), + base_cta=int(raw["base_cta"]), + request_indices=raw["request_indices"], + qo_tile_indices=raw["qo_tile_indices"], + kv_tile_indices=raw["kv_tile_indices"], + merge_indptr=raw["merge_indptr"], + o_indptr=raw["o_indptr"], + block_valid_mask=raw["block_valid_mask"], + kv_pages=raw["kv_pages"], + split_counts=raw["split_counts"], + ) + + +__all__ = [ + "DecodeAttentionSchedule", + "prepare_decode_schedule", +] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fp4_indexer.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fp4_indexer.py new file mode 100644 index 0000000000000000000000000000000000000000..aa83e39a5504ac6cf8d732255e495e48b35fa20a --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fp4_indexer.py @@ -0,0 +1,1956 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 FP4 sparse-attention indexer kernels.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +import torch +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 + +from ...src.common import pipeline as common_pipeline + + +FP4_FORMAT = Literal["mxfp4", "nvfp4"] +_FP4_PACKED_D_BYTES = 64 +_HEAD_DIM = 128 +_BLOCK_K = 128 +_PAGE_SIZE = 128 +_MMA_TILER_MN = (128, 128) +_MMA_INST_SHAPE_K = 64 +_NON_CAUSAL_K_TILES_PER_CTA = 16 +_CAUSAL_K_TILES_PER_CTA = 16 +_DECODE_PACK_Q_LEN = 8 +_DECODE_QHEAD_PER_KV = 16 +_DECODE_K_TILES_PER_CTA = 16 +_AB_DTYPE = cutlass.Float4E2M1FN + + +@dataclass(frozen=True) +class Fp4FormatSpec: + name: FP4_FORMAT + sf_vec_size: int + scale_groups: int + torch_scale_dtype: torch.dtype + cutlass_scale_dtype: type + + +_FORMAT_SPECS: dict[str, Fp4FormatSpec] = { + "mxfp4": Fp4FormatSpec( + name="mxfp4", + sf_vec_size=32, + scale_groups=4, + torch_scale_dtype=torch.float8_e8m0fnu, + cutlass_scale_dtype=cutlass.Float8E8M0FNU, + ), + "nvfp4": Fp4FormatSpec( + name="nvfp4", + sf_vec_size=16, + scale_groups=8, + torch_scale_dtype=torch.float8_e4m3fn, + cutlass_scale_dtype=cutlass.Float8E4M3FN, + ), +} + + +def normalize_fp4_format(fmt: str) -> Fp4FormatSpec: + key = str(fmt).lower() + try: + return _FORMAT_SPECS[key] + except KeyError as exc: + raise ValueError(f"format must be one of {sorted(_FORMAT_SPECS)}, got {fmt!r}") from exc + + +def ceil_div(x: int, y: int) -> int: + return (int(x) + int(y) - 1) // int(y) + + +def k_tiles_per_cta_for(causal: bool) -> int: + return _CAUSAL_K_TILES_PER_CTA if bool(causal) else _NON_CAUSAL_K_TILES_PER_CTA + + +class Fp4IndexerScaleReorderSm100: + """Reorder public FP4 indexer scales to the 1CTA blockscaled MMA layout.""" + + def __init__(self, *, fmt: str): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.sf_dtype = spec.cutlass_scale_dtype + self.scale_groups = spec.scale_groups + self.threads_per_cta = 256 + + @cute.jit + def __call__( + self, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + q_scale_mma_ptr: cute.Pointer, + k_scale_mma_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + total_q, heads_q, page_count, heads_k = problem_size + rest_q_m = cute.ceil_div(total_q, 128) + rest_g = cute.ceil_div(self.scale_groups, 4) + k_l = page_count * heads_k + + q_scale = cute.make_tensor( + q_scale_ptr, + cute.make_layout( + (total_q, heads_q, self.scale_groups), + stride=(heads_q * self.scale_groups, self.scale_groups, 1), + ), + ) + k_scale = cute.make_tensor( + k_scale_ptr, + cute.make_layout( + (page_count, heads_k, _PAGE_SIZE, self.scale_groups), + stride=( + heads_k * _PAGE_SIZE * self.scale_groups, + _PAGE_SIZE * self.scale_groups, + self.scale_groups, + 1, + ), + ), + ) + + q_mma_layout = cute.make_ordered_layout( + (32, 4, rest_q_m, 4, rest_g, heads_q), + order=(2, 1, 4, 0, 3, 5), + ) + k_mma_layout = cute.make_ordered_layout( + (32, 4, 1, 4, rest_g, k_l), + order=(2, 1, 4, 0, 3, 5), + ) + q_scale_mma = cute.make_tensor(q_scale_mma_ptr, q_mma_layout) + k_scale_mma = cute.make_tensor(k_scale_mma_ptr, k_mma_layout) + q_scale_mma = cute.group_modes(q_scale_mma, 0, 3) + q_scale_mma = cute.group_modes(q_scale_mma, 1, 3) + k_scale_mma = cute.group_modes(k_scale_mma, 0, 3) + k_scale_mma = cute.group_modes(k_scale_mma, 1, 3) + + q_scale_count = total_q * heads_q * Int32(self.scale_groups) + k_scale_count = page_count * heads_k * Int32(_PAGE_SIZE * self.scale_groups) + total_scale_count = q_scale_count + k_scale_count + grid_ctas = cute.ceil_div(total_scale_count, self.threads_per_cta) + self.kernel( + q_scale, + k_scale, + q_scale_mma, + k_scale_mma, + heads_q, + heads_k, + q_scale_count, + total_scale_count, + ).launch( + grid=(grid_ctas, 1, 1), + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + q_scale: cute.Tensor, + k_scale: cute.Tensor, + q_scale_mma: cute.Tensor, + k_scale_mma: cute.Tensor, + heads_q: Int32, + heads_k: Int32, + q_scale_count: Int32, + total_scale_count: Int32, + ): + tidx, _, _ = cute.arch.thread_idx() + block_idx, _, _ = cute.arch.block_idx() + grid_dim, _, _ = cute.arch.grid_dim() + linear = block_idx * Int32(self.threads_per_cta) + tidx + stride = grid_dim * Int32(self.threads_per_cta) + + while linear < total_scale_count: + if linear < q_scale_count: + group = linear % Int32(self.scale_groups) + tmp = linear // Int32(self.scale_groups) + head = tmp % heads_q + row = tmp // heads_q + q_scale_mma[row, group, head] = q_scale[row, head, group] + else: + k_linear = linear - q_scale_count + group = k_linear % Int32(self.scale_groups) + tmp = k_linear // Int32(self.scale_groups) + row = tmp % Int32(_PAGE_SIZE) + tmp = tmp // Int32(_PAGE_SIZE) + head = tmp % heads_k + page = tmp // heads_k + scale_l = page * heads_k + head + k_scale_mma[row, group, scale_l] = k_scale[page, head, row, group] + linear += stride + + +class Fp4IndexerStagedMmaSm100: + """Single-kernel FP4 indexer for preordered MMA scale storage.""" + + def __init__( + self, + *, + fmt: str, + causal: bool, + preordered_q_scale_tma: bool = False, + compact_schedule: bool = False, + use_tmem_load_red: bool = False, + ): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.is_causal = bool(causal) + self.preordered_q_scale_tma = bool(preordered_q_scale_tma) + self.compact_schedule = bool(compact_schedule) + self.use_tmem_load_red = bool(use_tmem_load_red) + self.sf_vec_size = spec.sf_vec_size + self.sf_dtype = spec.cutlass_scale_dtype + self.scale_groups = spec.scale_groups + self.use_nvfp4 = spec.name == "nvfp4" + self.epi_threads_per_cta = 128 + self.epi_warps_per_group = 4 + self.num_epi_warpgroups = 2 + self.mma_warp_id = self.epi_warps_per_group * self.num_epi_warpgroups + self.load_warp_id = self.mma_warp_id + 1 + self.threads_per_cta = 384 + self.num_tmem_alloc_cols = 512 + self.num_q_stage = 1 + self.num_acc_stage = 3 + self.num_ab_stage = 3 + self.k_tiles_per_cta = k_tiles_per_cta_for(self.is_causal) + + @cute.jit + def __call__( + self, + q_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + ( + m, + _, + k, + _, + lk, + heads_q, + heads_k, + batch, + max_k_tiles, + total_q, + has_qo_offset, + compact_task_count, + ) = problem_size + page_count = lk // heads_k + self.mma_tiler = (_MMA_TILER_MN[0], _MMA_TILER_MN[1], _MMA_INST_SHAPE_K * 2) + self.cta_tile_shape_mnk = self.mma_tiler + + q_tma_tensor = cute.make_tensor( + cute.recast_ptr(q_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (total_q, _HEAD_DIM, heads_q), + stride=(heads_q * _HEAD_DIM, 1, _HEAD_DIM), + ), + ) + k_tma_tensor = cute.make_tensor( + cute.recast_ptr(k_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (_PAGE_SIZE, _HEAD_DIM, heads_k, page_count), + stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM, heads_k * _PAGE_SIZE * _HEAD_DIM), + ), + ) + q_scale_tensor = cute.make_tensor( + q_scale_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (total_q, _HEAD_DIM, heads_q), + self.sf_vec_size, + ), + ) + k_scale_tensor = cute.make_tensor( + k_scale_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (_PAGE_SIZE, _HEAD_DIM, page_count * heads_k), + self.sf_vec_size, + ), + ) + scores_tensor = cute.make_tensor( + scores_ptr, + cute.make_layout((heads_q, max_k_tiles, total_q), stride=(max_k_tiles * total_q, total_q, 1)), + ) + kv_indices_tensor = cute.make_tensor( + kv_indices_ptr, + cute.make_layout((page_count,), stride=(1,)), + ) + cu_layout = cute.make_layout((batch + 1,), stride=(1,)) + cu_q_tensor = cute.make_tensor(cu_seqlens_q_ptr, cu_layout) + cu_k_tensor = cute.make_tensor(cu_seqlens_k_ptr, cu_layout) + cu_page_offsets_tensor = cute.make_tensor(cu_page_offsets_ptr, cu_layout) + qo_offset_tensor = cute.make_tensor(qo_offset_ptr, cute.make_layout((batch,), stride=(1,))) + + if const_expr(self.use_nvfp4): + mma_op = tcgen05.MmaMXF4NVF4Op( + self.sf_dtype, + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + else: + mma_op = tcgen05.MmaMXF4Op( + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + q_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_q_stage) + k_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_ab_stage) + q_scale_smem_layout = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_q_stage, + ) + k_scale_smem_layout = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + cluster_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + q_smem_layout_stage = cute.slice_(q_smem_layout, (None, None, None, 0)) + k_smem_layout_stage = cute.slice_(k_smem_layout, (None, None, None, 0)) + tma_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_tma_tensor, + q_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + tma_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_tma_tensor, + k_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + if const_expr(self.preordered_q_scale_tma): + tma_qs = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_scale_tensor, + q_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + else: + tma_qs = tma_q + tma_ks = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_scale_tensor, + k_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + grid_q_tiles = cute.ceil_div(m, self.cta_tile_shape_mnk[0]) + grid_k_groups = cute.ceil_div(max_k_tiles, self.k_tiles_per_cta) + if const_expr(self.compact_schedule): + grid_x = compact_task_count + else: + grid_x = grid_q_tiles * grid_k_groups + self.kernel( + tiled_mma, + tma_q, + tma_qs, + tma_k, + tma_ks, + q_scale_tensor, + k_scale_tensor, + scores_tensor, + kv_indices_tensor, + cu_q_tensor, + cu_k_tensor, + cu_page_offsets_tensor, + qo_offset_tensor, + q_smem_layout, + k_smem_layout, + q_scale_smem_layout, + k_scale_smem_layout, + heads_q, + heads_k, + has_qo_offset, + max_k_tiles, + grid_k_groups, + ).launch( + grid=(grid_x, batch * heads_q, 1), + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.jit + def _group_has_visible( + self, + q_tile_start: Int32, + q_tile_last: Int32, + q_len: Int32, + group_first_ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = q_tile_start < q_len and group_first_ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and group_first_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + return visible + + @cute.jit + def _tile_has_visible( + self, + q_tile_start: Int32, + q_tile_last: Int32, + q_len: Int32, + ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = q_tile_start < q_len and ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + return visible + + @cute.jit + def _tile_mask_free(self, q_tile_start: Int32, ktile: Int32, causal_offset: Int32): + if const_expr(self.is_causal): + return ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) <= q_tile_start + causal_offset + return True + + @cute.jit + def _full_tile_coord_visible( + self, + coord_m: Int32, + target_m: Int32, + q_local: Int32, + k_local: Int32, + causal_offset: Int32, + ): + visible = coord_m == target_m + if const_expr(self.is_causal): + visible = visible and k_local <= q_local + causal_offset + return visible + + @cute.jit + def _partial_tile_coord_visible( + self, + coord_m: Int32, + target_m: Int32, + q_local: Int32, + k_local: Int32, + q_len: Int32, + k_len: Int32, + causal_offset: Int32, + ): + visible = coord_m == target_m and q_local < q_len and k_local < k_len + if const_expr(self.is_causal): + visible = visible and k_local <= q_local + causal_offset + return visible + + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_q: cpasync.TmaInfo, + tma_qs: cpasync.TmaInfo, + tma_k: cpasync.TmaInfo, + tma_ks: cpasync.TmaInfo, + mQS: cute.Tensor, + mKS: cute.Tensor, + mScores: cute.Tensor, + mKvIndices: cute.Tensor, + mCuQ: cute.Tensor, + mCuK: cute.Tensor, + mCuPages: cute.Tensor, + mQoOffset: cute.Tensor, + q_smem_layout: cute.ComposedLayout, + k_smem_layout: cute.ComposedLayout, + q_scale_smem_layout: cute.Layout, + k_scale_smem_layout: cute.Layout, + heads_q: Int32, + heads_k: Int32, + has_qo_offset: Int32, + max_k_tiles: Int32, + k_group_count: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + lane_idx = cute.arch.lane_idx() + epi_tidx = tidx % Int32(self.epi_threads_per_cta) + epi_warpgroup_idx = warp_idx // Int32(self.epi_warps_per_group) + task_idx, q_l, _ = cute.arch.block_idx() + batch_idx = q_l // heads_q + hq = q_l - batch_idx * heads_q + hk = hq // (heads_q // heads_k) + q_begin = mCuQ[batch_idx] + q_end = mCuQ[batch_idx + 1] + k_begin = mCuK[batch_idx] + k_end = mCuK[batch_idx + 1] + q_len = q_end - q_begin + k_len = k_end - k_begin + page_begin = mCuPages[batch_idx] + batch_k_tiles = (k_len + Int32(_PAGE_SIZE - 1)) // Int32(_PAGE_SIZE) + causal_offset = Int32(0) + if const_expr(self.is_causal): + causal_offset = k_len - q_len + if has_qo_offset != 0: + causal_offset = mQoOffset[batch_idx] + task_valid = True + q_tile_idx = Int32(0) + ktile_group = Int32(0) + if const_expr(self.compact_schedule): + remaining = task_idx + q_tile_count = (q_len + Int32(self.cta_tile_shape_mnk[0] - 1)) // Int32(self.cta_tile_shape_mnk[0]) + batch_k_group_count = (batch_k_tiles + Int32(self.k_tiles_per_cta - 1)) // Int32(self.k_tiles_per_cta) + q_scan = Int32(0) + task_valid = False + while q_scan < q_tile_count and not task_valid: + q_scan_start = q_scan * Int32(self.cta_tile_shape_mnk[0]) + q_scan_last = q_scan_start + Int32(self.cta_tile_shape_mnk[0] - 1) + if q_scan_last >= q_len: + q_scan_last = q_len - Int32(1) + visible_limit = q_scan_last + causal_offset + visible_group_count = Int32(0) + if visible_limit >= Int32(0): + visible_group_count = visible_limit // Int32(self.k_tiles_per_cta * _BLOCK_K) + Int32(1) + if visible_group_count > batch_k_group_count: + visible_group_count = batch_k_group_count + task_valid = remaining < visible_group_count + if not task_valid: + remaining -= visible_group_count + q_scan += Int32(1) + if task_valid: + q_tile_idx = q_scan + ktile_group = remaining + else: + q_len = Int32(0) + k_len = Int32(0) + else: + q_tile_idx = task_idx // k_group_count + ktile_group = task_idx - q_tile_idx * k_group_count + q_tile_start = q_tile_idx * Int32(self.cta_tile_shape_mnk[0]) + q_tile_last = q_tile_start + Int32(self.cta_tile_shape_mnk[0] - 1) + if q_tile_last >= q_len: + q_tile_last = q_len - Int32(1) + q_tile_full = q_tile_start + Int32(self.cta_tile_shape_mnk[0] - 1) < q_len + q_tile_global_start = q_begin + q_tile_start + q_scale_tma_safe = q_tile_global_start == (q_tile_global_start // Int32(128)) * Int32(128) + group_first_ktile = ktile_group * Int32(self.k_tiles_per_cta) + group_has_visible = self._group_has_visible( + q_tile_start, + q_tile_last, + q_len, + group_first_ktile, + batch_k_tiles, + causal_offset, + ) + + @cute.struct + class SharedStorage: + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + qs_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ_public = smem.allocate_tensor(_AB_DTYPE, q_smem_layout.outer, 128, swizzle=q_smem_layout.inner) + sK_public = smem.allocate_tensor(_AB_DTYPE, k_smem_layout.outer, 128, swizzle=k_smem_layout.inner) + sQS_public = smem.allocate_tensor(self.sf_dtype, q_scale_smem_layout, 128) + sKS_public = smem.allocate_tensor(self.sf_dtype, k_scale_smem_layout, 128) + mQ_tma = tma_q.tma_tensor + mQS_tma = tma_qs.tma_tensor + mK_tma = tma_k.tma_tensor + mKS_tma = tma_ks.tma_tensor + thr_mma = tiled_mma.get_slice(0) + tCsQ = thr_mma.partition_A(sQ_public) + tCsK = thr_mma.partition_B(sK_public) + mQ_tma_cur = cute.domain_offset((q_begin, 0, 0), mQ_tma) + gQ_tma = cute.local_tile( + mQ_tma_cur, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQ_tma = thr_mma.partition_A(gQ_tma) + tQsQ_tma, tQgQ_tma = cpasync.tma_partition( + tma_q.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQ_public, 0, 3), + cute.group_modes(tCgQ_tma, 0, 3), + ) + if const_expr(self.preordered_q_scale_tma): + mQS_tma_cur = cute.domain_offset((q_begin, 0, 0), mQS_tma) + gQS_tma = cute.local_tile( + mQS_tma_cur, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQS_tma = thr_mma.partition_A(gQS_tma) + tQsQS_tma, tQgQS_tma = cpasync.tma_partition( + tma_qs.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQS_public, 0, 3), + cute.group_modes(tCgQS_tma, 0, 3), + ) + tQsQS_tma = cute.filter_zeros(tQsQS_tma) + tQgQS_tma = cute.filter_zeros(tQgQS_tma) + gK_tma = cute.local_tile( + mK_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None, None), + ) + tCgK_tma = thr_mma.partition_B(gK_tma) + tKsK_tma, tKgK_tma = cpasync.tma_partition( + tma_k.atom, + 0, + cute.make_layout(1), + cute.group_modes(sK_public, 0, 3), + cute.group_modes(tCgK_tma, 0, 3), + ) + gKS_tma = cute.local_tile( + mKS_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + tCgKS_tma = thr_mma.partition_B(gKS_tma) + tKsKS_tma, tKgKS_tma = cpasync.tma_partition( + tma_ks.atom, + 0, + cute.make_layout(1), + cute.group_modes(sKS_public, 0, 3), + cute.group_modes(tCgKS_tma, 0, 3), + ) + tKsKS_tma = cute.filter_zeros(tKsKS_tma) + tKgKS_tma = cute.filter_zeros(tKgKS_tma) + sQS = sQS_public + sKS = sKS_public + + tCrQ = tiled_mma.make_fragment_A(sQ_public) + tCrK = tiled_mma.make_fragment_B(sK_public) + tCcC = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler[:2])) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + tmem = utils.TmemAllocator( + storage.tmem_holding_buf.ptr, + barrier_for_retrieve=pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * (self.mma_warp_id + 1), + ), + ) + + acc_pipeline = common_pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, self.epi_threads_per_cta), + defer_sync=True, + ) + acc_producer, _ = acc_pipeline.make_participants() + q_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_q.smem_layout) + k_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_k.smem_layout) + if const_expr(self.preordered_q_scale_tma): + qs_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_qs.smem_layout, mode=[0, 1, 2]), + ) + ks_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_ks.smem_layout, mode=[0, 1, 2]), + ) + k_pair_tma_copy_bytes = k_tma_copy_bytes + ks_tma_copy_bytes + q_producer, q_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.q_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=q_tma_copy_bytes, + defer_sync=True, + ).make_participants() + if const_expr(self.preordered_q_scale_tma): + qs_producer, qs_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.qs_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=qs_tma_copy_bytes, + defer_sync=True, + ).make_participants() + k_producer, k_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.k_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=k_pair_tma_copy_bytes, + defer_sync=True, + ).make_participants() + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + if warp_idx == self.load_warp_id: + if group_has_visible: + q_empty = q_producer.acquire_and_advance() + if const_expr(self.preordered_q_scale_tma): + if q_scale_tma_safe: + qs_empty = qs_producer.acquire_and_advance() + cute.copy( + tma_qs.atom, + tQgQS_tma[(None, q_tile_idx, 0, hq)], + tQsQS_tma[(None, qs_empty.index)], + tma_bar_ptr=qs_empty.barrier, + ) + qs_empty.commit() + else: + for row_base in cutlass.range(0, Int32(self.cta_tile_shape_mnk[0]), 32): + row = row_base + lane_idx + q_local = q_tile_start + row + row_major = row // Int32(32) + row_atom = row - row_major * Int32(32) + for group in cutlass.range_constexpr(self.scale_groups): + group_i = Int32(group) + mma_k = group_i // Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + group_in_mma_k = group_i - mma_k * Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + sf_coord = ((((row_atom, row_major), Int32(0)), (Int32(0), group_in_mma_k)), Int32(0), mma_k, Int32(0)) + q_scale_row = q_begin + q_local + if q_local >= q_len: + q_scale_row = q_begin + sQS[sf_coord] = mQS[q_scale_row, group_i * Int32(self.sf_vec_size), hq] + else: + for row_base in cutlass.range(0, Int32(self.cta_tile_shape_mnk[0]), 32): + row = row_base + lane_idx + q_local = q_tile_start + row + row_major = row // Int32(32) + row_atom = row - row_major * Int32(32) + for group in cutlass.range_constexpr(self.scale_groups): + group_i = Int32(group) + mma_k = group_i // Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + group_in_mma_k = group_i - mma_k * Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + sf_coord = ((((row_atom, row_major), Int32(0)), (Int32(0), group_in_mma_k)), Int32(0), mma_k, Int32(0)) + q_scale_row = q_begin + q_local + if q_local >= q_len: + q_scale_row = q_begin + sQS[sf_coord] = mQS[q_scale_row, group_i * Int32(self.sf_vec_size), hq] + cute.copy( + tma_q.atom, + tQgQ_tma[(None, q_tile_idx, 0, hq)], + tQsQ_tma[(None, q_empty.index)], + tma_bar_ptr=q_empty.barrier, + ) + q_empty.commit() + + if warp_idx == self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + # Move block scales into TMEM and issue one FP4 GEMM per visible K tile. + tCtQS_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(q_scale_smem_layout, (None, None, None, 0)), + ) + tCtKS_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(k_scale_smem_layout, (None, None, None, 0)), + ) + tCtQS = tmem_pool.allocate_tensor(tCtQS_layout, self.sf_dtype) + tCtKS = tmem_pool.allocate_tensor(tCtKS_layout, self.sf_dtype) + copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), self.sf_dtype) + tCsQS_compact = cute.filter_zeros(sQS) + tCtQS_compact = cute.filter_zeros(tCtQS) + tiled_copy_s2t_qs = tcgen05.make_s2t_copy(copy_atom_s2t, tCtQS_compact) + thr_copy_s2t_qs = tiled_copy_s2t_qs.get_slice(0) + tCsQS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_qs, + thr_copy_s2t_qs.partition_S(tCsQS_compact), + ) + tCtQS_compact_s2t = thr_copy_s2t_qs.partition_D(tCtQS_compact) + tCsKS_compact = cute.filter_zeros(sKS) + tCtKS_compact = cute.filter_zeros(tCtKS) + tiled_copy_s2t_ks = tcgen05.make_s2t_copy(copy_atom_s2t, tCtKS_compact) + thr_copy_s2t_ks = tiled_copy_s2t_ks.get_slice(0) + tCsKS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_ks, + thr_copy_s2t_ks.partition_S(tCsKS_compact), + ) + tCtKS_compact_s2t = thr_copy_s2t_ks.partition_D(tCtKS_compact) + if group_has_visible: + q_full = q_consumer.wait_and_advance() + if const_expr(self.preordered_q_scale_tma): + if q_scale_tma_safe: + qs_full = qs_consumer.wait_and_advance() + qs_full.release() + q_full.release() + cute.copy(tiled_copy_s2t_qs, tCsQS_compact_s2t[(None, None, None, None, 0)], tCtQS_compact_s2t) + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + q_tile_crd = (None, None, None, 0) + if const_expr(self.is_causal): + causal_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + causal_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + causal_group_full = causal_group_full and causal_group_last_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + ktile = Int32(0) + if causal_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_tile_start, + q_tile_last, + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + k_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + ktile = Int32(0) + if k_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < batch_k_tiles: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + acc_producer.tail() + + if warp_idx == self.load_warp_id: + if group_has_visible: + load_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + if const_expr(self.is_causal): + load_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + load_group_full = load_group_full and load_group_last_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + ktile = Int32(0) + if load_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_tile_start, + q_tile_last, + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + k_producer.tail() + q_producer.tail() + if const_expr(self.preordered_q_scale_tma): + if q_scale_tma_safe: + qs_producer.tail() + + if warp_idx < self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + # Load accumulators from TMEM, reduce per-row max, and store scores. + if const_expr(self.use_tmem_load_red): + copy_atom_t2r = cute.make_copy_atom( + tcgen05.LdRed32x32bOp( + tcgen05.Repetition.x128, + tcgen05.Pack.NONE, + tcgen05.TmemLoadRedOp.MAX, + ), + Float32, + ) + else: + copy_atom_t2r = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE), + Float32, + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None, None, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx) + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + tTR_cC = thr_copy_t2r.partition_D(tCcC) + tTR_rAcc = cute.make_rmem_tensor(tTR_cC.shape, Float32) + if const_expr(self.use_tmem_load_red): + tTR_rRed = cute.make_rmem_tensor((1,), Float32) + q_local_store0 = q_tile_start + epi_tidx + q_global_store0 = q_begin + q_local_store0 + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + q_local_store1 = q_tile_start + epi_tidx + Int32(self.epi_threads_per_cta) + q_global_store1 = q_begin + q_local_store1 + if group_has_visible: + visible_tile_count = Int32(0) + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_tile_start, + q_tile_last, + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + epilogue_owns_tile = epi_warpgroup_idx == Int32( + ktile_inner % self.num_epi_warpgroups + ) + if epilogue_owns_tile: + acc_stage_index = visible_tile_count % Int32(self.num_acc_stage) + acc_stage_phase = (visible_tile_count // Int32(self.num_acc_stage)) % Int32(2) + tile_mask_free = self._tile_mask_free(q_tile_start, ktile, causal_offset) + k_tile_full = ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) < k_len + tile_full = q_tile_full and k_tile_full + acc_pipeline.consumer_wait_w_index_phase(acc_stage_index, acc_stage_phase) + tTR_tAcc_stage = tTR_tAcc[(None, None, None, None, acc_stage_index)] + if const_expr(self.use_tmem_load_red): + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, [tTR_rAcc, tTR_rRed]) + else: + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, tTR_rAcc) + row_max0 = -Float32.inf + row_max1 = -Float32.inf + if tile_mask_free: + if tile_full: + if const_expr(not self.use_tmem_load_red or self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, _ = tTR_cC[i] + if coord_m == epi_tidx: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if coord_m == epi_tidx + Int32(self.epi_threads_per_cta): + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + else: + row_max0 = tTR_rRed[0] + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + q_local = q_tile_start + coord_m + k_local = ktile * Int32(_BLOCK_K) + coord_n + if coord_m == epi_tidx and q_local < q_len and k_local < k_len: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if coord_m == epi_tidx + Int32(self.epi_threads_per_cta) and q_local < q_len and k_local < k_len: + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + else: + if tile_full: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + q_local = q_tile_start + coord_m + k_local = ktile * Int32(_BLOCK_K) + coord_n + if self._full_tile_coord_visible( + coord_m, + epi_tidx, + q_local, + k_local, + causal_offset, + ): + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if self._full_tile_coord_visible( + coord_m, + epi_tidx + Int32(self.epi_threads_per_cta), + q_local, + k_local, + causal_offset, + ): + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + q_local = q_tile_start + coord_m + k_local = ktile * Int32(_BLOCK_K) + coord_n + if self._partial_tile_coord_visible( + coord_m, + epi_tidx, + q_local, + k_local, + q_len, + k_len, + causal_offset, + ): + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if self._partial_tile_coord_visible( + coord_m, + epi_tidx + Int32(self.epi_threads_per_cta), + q_local, + k_local, + q_len, + k_len, + causal_offset, + ): + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + if q_tile_full: + mScores[hq, ktile, q_global_store0] = row_max0 + elif q_local_store0 < q_len: + mScores[hq, ktile, q_global_store0] = row_max0 + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if q_tile_full: + mScores[hq, ktile, q_global_store1] = row_max1 + elif q_local_store1 < q_len: + mScores[hq, ktile, q_global_store1] = row_max1 + cute.arch.fence_view_async_tmem_load() + acc_pipeline.consumer_release_w_index(acc_stage_index) + visible_tile_count += Int32(1) + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + if q_tile_full: + mScores[hq, ktile, q_global_store0] = -Float32.inf + elif q_local_store0 < q_len: + mScores[hq, ktile, q_global_store0] = -Float32.inf + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = ktile_group * Int32(self.k_tiles_per_cta) + Int32(ktile_inner) + if ktile < max_k_tiles: + if q_tile_full: + mScores[hq, ktile, q_global_store0] = -Float32.inf + elif q_local_store0 < q_len: + mScores[hq, ktile, q_global_store0] = -Float32.inf + cute.arch.barrier() + tmem.free(tmem_pool.base_ptr) + +class Fp4IndexerDecodeQPackSm100: + """Pack decode Q rows as ``[B * Hk, 128, 64]`` and pack Q scales to MMA storage.""" + + def __init__(self, *, fmt: str): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.sf_dtype = spec.cutlass_scale_dtype + self.scale_groups = spec.scale_groups + self.threads_per_cta = 256 + + @cute.jit + def __call__( + self, + q_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + q_pack_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + total_q, heads_q, heads_k, batch = problem_size + rest_q_m = cute.ceil_div(total_q, 128) + rest_g = ceil_div(self.scale_groups, 4) + q = cute.make_tensor( + q_ptr, + cute.make_layout( + (total_q, heads_q, _FP4_PACKED_D_BYTES), + stride=(heads_q * _FP4_PACKED_D_BYTES, _FP4_PACKED_D_BYTES, 1), + ), + ) + q_scale = cute.make_tensor( + q_scale_ptr, + cute.make_layout( + (heads_q, rest_q_m, rest_g, 32, 4, 4), + stride=(512 * rest_q_m * rest_g, 512 * rest_g, 512, 16, 4, 1), + ), + ) + q_pack_l = batch * heads_k + q_pack = cute.make_tensor( + q_pack_ptr, + cute.make_layout( + (q_pack_l, _PAGE_SIZE, _FP4_PACKED_D_BYTES), + stride=(_PAGE_SIZE * _FP4_PACKED_D_BYTES, _FP4_PACKED_D_BYTES, 1), + ), + ) + q_scale_pack = cute.make_tensor( + q_scale_pack_ptr, + cute.make_layout( + (q_pack_l, 1, rest_g, 32, 4, 4), + stride=(512 * rest_g, 512 * rest_g, 512, 16, 4, 1), + ), + ) + cu_q = cute.make_tensor(cu_seqlens_q_ptr, cute.make_layout((batch + 1,), stride=(1,))) + self.kernel(q, q_scale, q_pack, q_scale_pack, cu_q, heads_q, heads_k).launch( + grid=(q_pack_l, 1, 1), + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mQS: cute.Tensor, + mQPack: cute.Tensor, + mQSPack: cute.Tensor, + mCuQ: cute.Tensor, + heads_q: Int32, + heads_k: Int32, + ): + tidx, _, _ = cute.arch.thread_idx() + q_pack_l, _, _ = cute.arch.block_idx() + batch_idx = q_pack_l // heads_k + hk = q_pack_l - batch_idx * heads_k + q_begin = mCuQ[batch_idx] + q_end = mCuQ[batch_idx + 1] + q_len = q_end - q_begin + qhead_per_kv = heads_q // heads_k + + linear = tidx + while linear < Int32(_PAGE_SIZE * _FP4_PACKED_D_BYTES): + row = linear // Int32(_FP4_PACKED_D_BYTES) + byte = linear - row * Int32(_FP4_PACKED_D_BYTES) + h_in_group = row // Int32(_DECODE_PACK_Q_LEN) + q_local = row - h_in_group * Int32(_DECODE_PACK_Q_LEN) + hq = hk * qhead_per_kv + h_in_group + if q_local < q_len and h_in_group < qhead_per_kv: + mQPack[q_pack_l, row, byte] = mQ[q_begin + q_local, hq, byte] + else: + mQPack[q_pack_l, row, byte] = cutlass.Uint8(0) + linear += Int32(self.threads_per_cta) + + scale_linear = tidx + while scale_linear < Int32(_PAGE_SIZE * self.scale_groups): + row = scale_linear // Int32(self.scale_groups) + group = scale_linear - row * Int32(self.scale_groups) + h_in_group = row // Int32(_DECODE_PACK_Q_LEN) + q_local = row - h_in_group * Int32(_DECODE_PACK_Q_LEN) + hq = hk * qhead_per_kv + h_in_group + q_abs = q_begin + q_local + if q_local >= q_len or h_in_group >= qhead_per_kv: + q_abs = q_begin + hq = hk * qhead_per_kv + src_rest_m = q_abs // Int32(128) + src_row = q_abs - src_rest_m * Int32(128) + src_row_atom = src_row % Int32(32) + src_row_major = src_row // Int32(32) + dst_row_atom = row % Int32(32) + dst_row_major = row // Int32(32) + rest_g = group // Int32(4) + group_in_rest = group - rest_g * Int32(4) + mQSPack[q_pack_l, Int32(0), rest_g, dst_row_atom, dst_row_major, group_in_rest] = mQS[ + hq, src_rest_m, rest_g, src_row_atom, src_row_major, group_in_rest + ] + scale_linear += Int32(self.threads_per_cta) + + +class Fp4IndexerDecodePackedQSm100: + """Decode score kernel with M packed as ``qhead_per_kv * q_len == 128``.""" + + def __init__(self, *, fmt: str, causal: bool, compact_schedule: bool, use_tmem_load_red: bool = False): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.is_causal = bool(causal) + self.compact_schedule = bool(compact_schedule) + self.use_tmem_load_red = bool(use_tmem_load_red) + self.sf_vec_size = spec.sf_vec_size + self.sf_dtype = spec.cutlass_scale_dtype + self.use_nvfp4 = spec.name == "nvfp4" + self.epi_threads_per_cta = 128 + self.epi_warps_per_group = 4 + self.num_epi_warpgroups = 2 + self.mma_warp_id = self.epi_warps_per_group * self.num_epi_warpgroups + self.load_warp_id = self.mma_warp_id + 1 + self.threads_per_cta = 384 + self.num_tmem_alloc_cols = 512 + self.num_q_stage = 1 + self.num_acc_stage = 3 + self.num_ab_stage = 3 + self.k_tiles_per_cta = _DECODE_K_TILES_PER_CTA + self.mma_tiler = (_MMA_TILER_MN[0], _MMA_TILER_MN[1], _MMA_INST_SHAPE_K * 2) + self.cta_tile_shape_mnk = self.mma_tiler + + @cute.jit + def __call__( + self, + q_pack_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + ( + _, + _, + _, + _, + lk, + heads_q, + heads_k, + batch, + max_k_tiles, + total_q, + has_qo_offset, + ) = problem_size + page_count = lk // heads_k + q_pack_l = batch * heads_k + q_tma_tensor = cute.make_tensor( + cute.recast_ptr(q_pack_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (_PAGE_SIZE, _HEAD_DIM, q_pack_l), + stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM), + ), + ) + k_tma_tensor = cute.make_tensor( + cute.recast_ptr(k_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (_PAGE_SIZE, _HEAD_DIM, heads_k, page_count), + stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM, heads_k * _PAGE_SIZE * _HEAD_DIM), + ), + ) + q_scale_tensor = cute.make_tensor( + q_scale_pack_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (_PAGE_SIZE, _HEAD_DIM, q_pack_l), + self.sf_vec_size, + ), + ) + k_scale_tensor = cute.make_tensor( + k_scale_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (_PAGE_SIZE, _HEAD_DIM, page_count * heads_k), + self.sf_vec_size, + ), + ) + scores_tensor = cute.make_tensor( + scores_ptr, + cute.make_layout((heads_q, max_k_tiles, total_q), stride=(max_k_tiles * total_q, total_q, 1)), + ) + kv_indices_tensor = cute.make_tensor(kv_indices_ptr, cute.make_layout((page_count,), stride=(1,))) + cu_layout = cute.make_layout((batch + 1,), stride=(1,)) + cu_q_tensor = cute.make_tensor(cu_seqlens_q_ptr, cu_layout) + cu_k_tensor = cute.make_tensor(cu_seqlens_k_ptr, cu_layout) + cu_page_offsets_tensor = cute.make_tensor(cu_page_offsets_ptr, cu_layout) + qo_offset_tensor = cute.make_tensor(qo_offset_ptr, cute.make_layout((batch,), stride=(1,))) + + if const_expr(self.use_nvfp4): + mma_op = tcgen05.MmaMXF4NVF4Op( + self.sf_dtype, + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + else: + mma_op = tcgen05.MmaMXF4Op( + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + q_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_q_stage) + k_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_ab_stage) + q_scale_smem_layout = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_q_stage, + ) + k_scale_smem_layout = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + cluster_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + q_smem_layout_stage = cute.slice_(q_smem_layout, (None, None, None, 0)) + k_smem_layout_stage = cute.slice_(k_smem_layout, (None, None, None, 0)) + tma_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_tma_tensor, + q_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + tma_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_tma_tensor, + k_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + tma_qs = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_scale_tensor, + q_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + tma_ks = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_scale_tensor, + k_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + grid_k_groups = cute.ceil_div(max_k_tiles, self.k_tiles_per_cta) + compact_k_groups = cute.ceil_div(page_count + batch * (self.k_tiles_per_cta - 1), self.k_tiles_per_cta) + if const_expr(self.compact_schedule): + grid = (compact_k_groups, heads_k, 1) + else: + grid = (grid_k_groups, batch * heads_k, 1) + self.kernel( + tiled_mma, + tma_q, + tma_qs, + tma_k, + tma_ks, + scores_tensor, + kv_indices_tensor, + cu_q_tensor, + cu_k_tensor, + cu_page_offsets_tensor, + qo_offset_tensor, + q_smem_layout, + k_smem_layout, + q_scale_smem_layout, + k_scale_smem_layout, + heads_q, + heads_k, + batch, + has_qo_offset, + max_k_tiles, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.jit + def _group_has_visible( + self, + q_len: Int32, + group_first_ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = q_len > Int32(0) and group_first_ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and group_first_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + return visible + + @cute.jit + def _tile_has_visible( + self, + q_len: Int32, + ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + return visible + + @cute.jit + def _tile_mask_free(self, ktile: Int32, causal_offset: Int32): + if const_expr(self.is_causal): + return ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) <= causal_offset + return True + + @cute.jit + def _packed_coord_visible( + self, + coord_m: Int32, + target_m: Int32, + h_in_group: Int32, + qhead_per_kv: Int32, + q_local: Int32, + q_len: Int32, + k_local: Int32, + k_len: Int32, + causal_offset: Int32, + ): + visible = coord_m == target_m and h_in_group < qhead_per_kv and q_local < q_len and k_local < k_len + if const_expr(self.is_causal): + visible = visible and k_local <= q_local + causal_offset + return visible + + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_q: cpasync.TmaInfo, + tma_qs: cpasync.TmaInfo, + tma_k: cpasync.TmaInfo, + tma_ks: cpasync.TmaInfo, + mScores: cute.Tensor, + mKvIndices: cute.Tensor, + mCuQ: cute.Tensor, + mCuK: cute.Tensor, + mCuPages: cute.Tensor, + mQoOffset: cute.Tensor, + q_smem_layout: cute.ComposedLayout, + k_smem_layout: cute.ComposedLayout, + q_scale_smem_layout: cute.Layout, + k_scale_smem_layout: cute.Layout, + heads_q: Int32, + heads_k: Int32, + batch: Int32, + has_qo_offset: Int32, + max_k_tiles: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + epi_tidx = tidx % Int32(self.epi_threads_per_cta) + epi_warpgroup_idx = warp_idx // Int32(self.epi_warps_per_group) + task_x, task_y, _ = cute.arch.block_idx() + task_valid = True + batch_idx = Int32(0) + hk = Int32(0) + ktile_group = Int32(0) + q_l = Int32(0) + if const_expr(self.compact_schedule): + hk = task_y + group_base = Int32(0) + scan_batch = Int32(0) + task_valid = False + while scan_batch < batch and not task_valid: + batch_pages = mCuPages[scan_batch + Int32(1)] - mCuPages[scan_batch] + batch_groups = (batch_pages + Int32(self.k_tiles_per_cta - 1)) // Int32(self.k_tiles_per_cta) + task_valid = task_x < group_base + batch_groups + if not task_valid: + group_base += batch_groups + scan_batch += Int32(1) + if task_valid: + batch_idx = scan_batch + ktile_group = task_x - group_base + q_l = batch_idx * heads_k + hk + else: + ktile_group = task_x + q_l = task_y + batch_idx = q_l // heads_k + hk = q_l - batch_idx * heads_k + qhead_per_kv = heads_q // heads_k + q_begin = mCuQ[batch_idx] + q_end = mCuQ[batch_idx + 1] + k_begin = mCuK[batch_idx] + k_end = mCuK[batch_idx + 1] + q_len = q_end - q_begin + k_len = k_end - k_begin + if const_expr(self.compact_schedule): + if not task_valid: + q_len = Int32(0) + k_len = Int32(0) + page_begin = mCuPages[batch_idx] + batch_k_tiles = (k_len + Int32(_PAGE_SIZE - 1)) // Int32(_PAGE_SIZE) + causal_offset = Int32(0) + if const_expr(self.is_causal): + causal_offset = k_len - q_len + if has_qo_offset != 0: + causal_offset = mQoOffset[batch_idx] + group_first_ktile = ktile_group * Int32(self.k_tiles_per_cta) + group_has_visible = self._group_has_visible( + q_len, + group_first_ktile, + batch_k_tiles, + causal_offset, + ) + + @cute.struct + class SharedStorage: + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ_public = smem.allocate_tensor(_AB_DTYPE, q_smem_layout.outer, 128, swizzle=q_smem_layout.inner) + sK_public = smem.allocate_tensor(_AB_DTYPE, k_smem_layout.outer, 128, swizzle=k_smem_layout.inner) + sQS_public = smem.allocate_tensor(self.sf_dtype, q_scale_smem_layout, 128) + sKS_public = smem.allocate_tensor(self.sf_dtype, k_scale_smem_layout, 128) + mQ_tma = tma_q.tma_tensor + mQS_tma = tma_qs.tma_tensor + mK_tma = tma_k.tma_tensor + mKS_tma = tma_ks.tma_tensor + thr_mma = tiled_mma.get_slice(0) + tCrQ = tiled_mma.make_fragment_A(sQ_public) + tCrK = tiled_mma.make_fragment_B(sK_public) + tCcC = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler[:2])) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + gQ_tma = cute.local_tile( + mQ_tma, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQ_tma = thr_mma.partition_A(gQ_tma) + tQsQ_tma, tQgQ_tma = cpasync.tma_partition( + tma_q.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQ_public, 0, 3), + cute.group_modes(tCgQ_tma, 0, 3), + ) + gQS_tma = cute.local_tile( + mQS_tma, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQS_tma = thr_mma.partition_A(gQS_tma) + tQsQS_tma, tQgQS_tma = cpasync.tma_partition( + tma_qs.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQS_public, 0, 3), + cute.group_modes(tCgQS_tma, 0, 3), + ) + tQsQS_tma = cute.filter_zeros(tQsQS_tma) + tQgQS_tma = cute.filter_zeros(tQgQS_tma) + gK_tma = cute.local_tile( + mK_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None, None), + ) + tCgK_tma = thr_mma.partition_B(gK_tma) + tKsK_tma, tKgK_tma = cpasync.tma_partition( + tma_k.atom, + 0, + cute.make_layout(1), + cute.group_modes(sK_public, 0, 3), + cute.group_modes(tCgK_tma, 0, 3), + ) + gKS_tma = cute.local_tile( + mKS_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + tCgKS_tma = thr_mma.partition_B(gKS_tma) + tKsKS_tma, tKgKS_tma = cpasync.tma_partition( + tma_ks.atom, + 0, + cute.make_layout(1), + cute.group_modes(sKS_public, 0, 3), + cute.group_modes(tCgKS_tma, 0, 3), + ) + tKsKS_tma = cute.filter_zeros(tKsKS_tma) + tKgKS_tma = cute.filter_zeros(tKgKS_tma) + + tmem = utils.TmemAllocator( + storage.tmem_holding_buf.ptr, + barrier_for_retrieve=pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * (self.mma_warp_id + 1), + ), + ) + acc_pipeline = common_pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, self.epi_threads_per_cta), + defer_sync=True, + ) + acc_producer, _ = acc_pipeline.make_participants() + q_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_q.smem_layout) + qs_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_qs.smem_layout, mode=[0, 1, 2]), + ) + k_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_k.smem_layout) + ks_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_ks.smem_layout, mode=[0, 1, 2]), + ) + q_pair_tma_copy_bytes = q_tma_copy_bytes + qs_tma_copy_bytes + k_pair_tma_copy_bytes = k_tma_copy_bytes + ks_tma_copy_bytes + q_producer, q_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.q_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=q_pair_tma_copy_bytes, + defer_sync=True, + ).make_participants() + k_producer, k_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.k_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=k_pair_tma_copy_bytes, + defer_sync=True, + ).make_participants() + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + + if warp_idx == self.load_warp_id: + if group_has_visible: + q_pair_empty = q_producer.acquire_and_advance() + cute.copy( + tma_q.atom, + tQgQ_tma[(None, 0, 0, q_l)], + tQsQ_tma[(None, q_pair_empty.index)], + tma_bar_ptr=q_pair_empty.barrier, + ) + cute.copy( + tma_qs.atom, + tQgQS_tma[(None, 0, 0, q_l)], + tQsQS_tma[(None, q_pair_empty.index)], + tma_bar_ptr=q_pair_empty.barrier, + ) + q_pair_empty.commit() + load_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + if const_expr(self.is_causal): + load_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + load_group_full = load_group_full and load_group_last_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + if load_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + k_producer.tail() + q_producer.tail() + + if warp_idx == self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + tCtQS_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(q_scale_smem_layout, (None, None, None, 0)), + ) + tCtKS_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(k_scale_smem_layout, (None, None, None, 0)), + ) + tCtQS = tmem_pool.allocate_tensor(tCtQS_layout, self.sf_dtype) + tCtKS = tmem_pool.allocate_tensor(tCtKS_layout, self.sf_dtype) + copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), self.sf_dtype) + tCsQS_compact = cute.filter_zeros(sQS_public) + tCtQS_compact = cute.filter_zeros(tCtQS) + tiled_copy_s2t_qs = tcgen05.make_s2t_copy(copy_atom_s2t, tCtQS_compact) + thr_copy_s2t_qs = tiled_copy_s2t_qs.get_slice(0) + tCsQS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_qs, + thr_copy_s2t_qs.partition_S(tCsQS_compact), + ) + tCtQS_compact_s2t = thr_copy_s2t_qs.partition_D(tCtQS_compact) + tCsKS_compact = cute.filter_zeros(sKS_public) + tCtKS_compact = cute.filter_zeros(tCtKS) + tiled_copy_s2t_ks = tcgen05.make_s2t_copy(copy_atom_s2t, tCtKS_compact) + thr_copy_s2t_ks = tiled_copy_s2t_ks.get_slice(0) + tCsKS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_ks, + thr_copy_s2t_ks.partition_S(tCsKS_compact), + ) + tCtKS_compact_s2t = thr_copy_s2t_ks.partition_D(tCtKS_compact) + if group_has_visible: + q_pair_full = q_consumer.wait_and_advance() + q_pair_full.release() + cute.copy(tiled_copy_s2t_qs, tCsQS_compact_s2t[(None, None, None, None, 0)], tCtQS_compact_s2t) + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + q_tile_crd = (None, None, None, 0) + if const_expr(self.is_causal): + causal_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + causal_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + causal_group_full = causal_group_full and causal_group_last_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + if causal_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + k_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + if k_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < batch_k_tiles: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + acc_producer.tail() + + if warp_idx < self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + if const_expr(self.use_tmem_load_red): + copy_atom_t2r = cute.make_copy_atom( + tcgen05.LdRed32x32bOp( + tcgen05.Repetition.x128, + tcgen05.Pack.NONE, + tcgen05.TmemLoadRedOp.MAX, + ), + Float32, + ) + else: + copy_atom_t2r = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE), + Float32, + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None, None, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx) + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + tTR_cC = thr_copy_t2r.partition_D(tCcC) + tTR_rAcc = cute.make_rmem_tensor(tTR_cC.shape, Float32) + if const_expr(self.use_tmem_load_red): + tTR_rRed = cute.make_rmem_tensor((1,), Float32) + h_store = epi_tidx // Int32(_DECODE_PACK_Q_LEN) + q_local_store = epi_tidx - h_store * Int32(_DECODE_PACK_Q_LEN) + h_global_store = hk * qhead_per_kv + h_store + q_global_store = q_begin + q_local_store + if group_has_visible: + visible_tile_count = Int32(0) + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + epilogue_owns_tile = epi_warpgroup_idx == Int32( + ktile_inner % self.num_epi_warpgroups + ) + if epilogue_owns_tile: + acc_stage_index = visible_tile_count % Int32(self.num_acc_stage) + acc_stage_phase = (visible_tile_count // Int32(self.num_acc_stage)) % Int32(2) + tile_mask_free = self._tile_mask_free(ktile, causal_offset) + k_tile_full = ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) < k_len + q_pack_full = q_len == Int32(_DECODE_PACK_Q_LEN) + tile_full = q_pack_full and k_tile_full + acc_pipeline.consumer_wait_w_index_phase(acc_stage_index, acc_stage_phase) + tTR_tAcc_stage = tTR_tAcc[(None, None, None, None, acc_stage_index)] + if const_expr(self.use_tmem_load_red): + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, [tTR_rAcc, tTR_rRed]) + else: + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, tTR_rAcc) + row_max0 = -Float32.inf + if tile_mask_free and tile_full: + if const_expr(self.use_tmem_load_red): + row_max0 = tTR_rRed[0] + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, _ = tTR_cC[i] + if coord_m == epi_tidx: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + h_in_group = coord_m // Int32(_DECODE_PACK_Q_LEN) + q_local = coord_m - h_in_group * Int32(_DECODE_PACK_Q_LEN) + k_local = ktile * Int32(_BLOCK_K) + coord_n + valid = self._packed_coord_visible( + coord_m, + epi_tidx, + h_in_group, + qhead_per_kv, + q_local, + q_len, + k_local, + k_len, + causal_offset, + ) + if valid: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if h_store < qhead_per_kv and q_local_store < q_len: + mScores[h_global_store, ktile, q_global_store] = row_max0 + cute.arch.fence_view_async_tmem_load() + acc_pipeline.consumer_release_w_index(acc_stage_index) + visible_tile_count += Int32(1) + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + if h_store < qhead_per_kv and q_local_store < q_len: + mScores[h_global_store, ktile, q_global_store] = -Float32.inf + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + if h_store < qhead_per_kv and q_local_store < q_len: + mScores[h_global_store, ktile, q_global_store] = -Float32.inf + cute.arch.barrier() + tmem.free(tmem_pool.base_ptr) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18b99aea3f8b4915c03fe8147127374d920970f3 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 forward kernels and combine paths.""" + +from .atten_fwd_nvfp4_kv import SparseAttentionForwardNvfp4KvSm100 + +__all__ = ["SparseAttentionForwardNvfp4KvSm100"] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd/atten_fwd.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd/atten_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..531b27c9e6b4bd8c1bc74fb1f92ed98a192ca0b2 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd/atten_fwd.py @@ -0,0 +1,3020 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 sparse attention forward kernel. + +This kernel implements the delivered Sparse Attention / Sparse Page Attention +forward contract: +- CSR sparse metadata (`k2q_row_ptr`, `k2q_q_indices`) +- varlen Q metadata via `cu_seqlens_q` +- Sparse Attention with flat varlen K/V +- Sparse Page Attention with paged K/V +""" + +import math +from functools import partial +from typing import Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.pipeline as cutlass_pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.cutlass_dsl import BaseDSL + +from ....quack import copy_utils + +from ....src.common.cute_dsl_utils import assume_tensor_aligned +from ....src.common import utils +from ....src.common import pipeline +from ....src.common import mma_sm100_desc as sm100_desc +from ....src.common import blackwell_helpers as sm100_helpers +from ....src.common.softmax import SoftmaxSm100 +from ....src.common.named_barrier import NamedBarrierFwdSm100 +from ....src.common.mask import AttentionMask +from ....src.common.seqlen_info import SeqlenInfoQK +# Shared raw PTX helpers and layout conversions used by the lean kernel. +from ....src.common.paged_kv import PagedKVManager +from ....src.common.tma_utils import ( + tma_gather4_cached, + tma_gather4_prefetch, + prefetch_tma_desc_raw, + TMA_CACHE_EVICT_LAST, + make_16x256b_tensor_mn_view, + real_col_to_stg128_fake_col, + real_col_to_stg128_fp8_fake_col, + real_col_to_stg128_half_fake_col, + stg_128_cs, + stg_128_bf16_cs, + stg_128_f16_cs, + stg_128_fp8_e4m3_cs, +) + + +class SparseAttentionForwardSm100: + """SM100 sparse attention forward kernel.""" + + k_tile = 64 # UTCMMA bf16 K-tile (matches sparse_fwd_utcmma.py) + + def __init__( + self, + head_dim: int = 128, + qheadperkv: int = 16, + m_block_size: int = 128, + n_block_size: int = 128, + paged_kv: bool = False, + page_size: Optional[int] = None, + has_seqused_k: bool = False, + causal: bool = False, + use_prepare_scheduler: bool = True, + qk_dtype=None, + pv_dtype=None, + ): + if head_dim != 128: + raise NotImplementedError( + f"SparseAttentionForwardSm100 currently supports only D=128, got D={head_dim}" + ) + self.head_dim = 128 + self.qheadperkv = qheadperkv + self.use_q_gather4 = qheadperkv in (4, 2, 1) + if qheadperkv not in (16, 8, 4, 2, 1): + raise ValueError( + "SparseAttentionForwardSm100 supports qheadperkv in " + f"{{1, 2, 4, 8, 16}}, got {qheadperkv}" + ) + self.tokens_per_gather4 = 4 // qheadperkv if self.use_q_gather4 else 0 + self.m_block_size = m_block_size # 128 packed Q heads + self.n_block_size = n_block_size # 128 KV-block width + self.paged_kv = paged_kv + self.page_size = page_size + self.has_seqused_k = has_seqused_k + self.causal = causal + self.qk_dtype_param = qk_dtype + self.pv_dtype_param = pv_dtype + if not use_prepare_scheduler: + raise ValueError("SparseAttentionForwardSm100 requires prepare scheduler") + self.use_prepare_scheduler = True + if self.paged_kv: + if page_size is None: + raise ValueError("page_size must be provided when paged_kv=True") + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal blk_kv ({n_block_size})" + ) + else: + self.page_size = n_block_size + self.q_tokens_per_group = m_block_size // qheadperkv # 8 + + self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim) + self.mma_tiler_pv = (m_block_size, self.head_dim, n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + + # Pipeline configuration — deeper Q prefetch ring plus 2-slot S/O rings. + self.q_stage = 2 + self.s_stage = 2 + self.o_stage = 2 + self.kv_stage = 1 + # Sparse q_idx metadata ring bridging load -> epilogue. Sized larger + # than the in-flight group distance so epilogue can reuse q_idx + # without rereading mK2qIndices. + self.qidx_meta_stages = 16 + + self.k_stages = 2 + self.q_stage_stride_bytes = m_block_size * self.head_dim * 2 + self.k_tile_stride_bytes = m_block_size * self.k_tile * 2 + self.token_stride_bytes = qheadperkv * self.k_tile * 2 + + # Warp layout: two softmax WGs, one Q-load/epilogue WG, one + # MMA issue warp, two K/V load warps, and one empty warp. + self.warps_per_group = 4 + self.softmax0_warp_base = 0 + self.softmax1_warp_base = self.softmax0_warp_base + self.warps_per_group + self.store_warp_base = self.softmax1_warp_base + self.warps_per_group + self.mma_warp_id = self.store_warp_base + self.warps_per_group + self.load_warp_base = self.mma_warp_id + 1 + self.q_load_warp_base = self.store_warp_base + self.kv_load_warp_base = self.load_warp_base + self.num_kv_load_warps = 2 + self.num_q_load_warps = self.warps_per_group + self.total_warps = 16 + self.threads_per_cta = cute.arch.WARP_SIZE * self.total_warps # 512 + + # TMEM layout follows FA SM100: + # S0/S1: [0:128], [128:256] + # O0/O1: [256:384], [384:512] for hdim_v=128 + # P dtype follows the PV operand policy and is packed into each S tile. + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + self.tmem_s_offset = 0 + self.tmem_stage_stride = n_block_size + self.tmem_o_stage_stride = self.head_dim + self.tmem_o_offset = self.s_stage * n_block_size + self.tmem_s_to_p_offset = n_block_size // 2 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + raw_tmem_total = self.tmem_o_offset + self.o_stage * self.tmem_o_stage_stride + # SM100 TMEM allocation requires a power-of-two column count. The + # 128-wide path naturally uses 512 columns; 64-wide KV blocks use 384 + # columns and must round the allocation up while keeping the same + # logical offsets. + self.tmem_total = 1 << (raw_tmem_total - 1).bit_length() + + # Let PV start once the first 3/4 of P is visible in TMEM. The final + # split is synchronized by a separate mbarrier consumed inside PV MMA. + self.split_P_arrive = n_block_size // 4 * 3 + self.split_P_arrive = int(self.split_P_arrive / 32) * 32 + assert self.split_P_arrive % 32 == 0 + assert self.split_P_arrive < self.n_block_size + + # Register allocation per role. The causal hdim128 split gives the + # epilogue enough room for partial-O/LSE address generation while the + # two softmax WGs still have enough registers to avoid S/P spills. + self.num_regs_softmax = 176 if causal else 192 + self.num_regs_store = 112 if causal else 80 + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_store + self.num_regs_mma = self.num_regs_other + self.num_regs_load = self.num_regs_other + self.num_regs_empty = self.num_regs_other + self.store_reg_decrease = self.num_regs_store <= 128 + self.ex2_emu_freq = 16 if causal else 0 + self.ex2_emu_start_frg = 1 + self.buffer_align_bytes = 1024 + + # SM100 config. + self.use_2cta_instrs = False + self.cta_group_size = 1 + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnk = (1, 1, 1) + + self.arch = BaseDSL._get_dsl().get_arch_enum() + + @cute.jit + def _batch_q_offset( + self, + batch_idx: Int32, + mCuSeqlensQ, + ) -> Int32: + return mCuSeqlensQ[batch_idx] + + @cute.jit + def _logical_seqlen_k( + self, + batch_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + if const_expr(self.has_seqused_k): + return mSeqUsedK[batch_idx] + if const_expr(self.paged_kv): + return Int32(mPageTable.shape[1]) * Int32(self.page_size) + return mCuSeqlensK[batch_idx + Int32(1)] - mCuSeqlensK[batch_idx] + + @cute.jit + def _valid_cols_in_block( + self, + batch_idx: Int32, + kv_block_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + seqlen_k = self._logical_seqlen_k( + batch_idx, mPageTable, mSeqUsedK, mCuSeqlensK + ) + block_start = kv_block_idx * Int32(self.n_block_size) + remaining = seqlen_k - block_start + remaining = cutlass.max(remaining, Int32(0)) + return cutlass.min(remaining, Int32(self.n_block_size)) + + @cute.jit + def _load_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _load_qsplit_idx( + self, + mK2qQSplitIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qQSplitIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _decode_q_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return qsplit & Int32(0x00FF_FFFF) + + @cute.jit + def _decode_split_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return (qsplit >> Int32(24)) & Int32(0xFF) + + @cute.jit + def _lower_bound_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + count: Int32, + q_value: Int32, + ) -> Int32: + left = Int32(0) + right = count + # k2q_q_indices is sorted by q_idx within each CSR row. A fixed + # 32-step loop covers int32-sized rows and keeps this CTA-level. + for _ in cutlass.range(32, unroll=1): + if left < right: + mid = (left + right) // Int32(2) + q_idx = self._load_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + mid, + ) + if q_idx < q_value: + left = mid + Int32(1) + else: + right = mid + return left + + # ------------------------------------------------------------------ + # Host-side: TMA descriptors, SMEM layout, launch + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + mK: cute.Tensor, # Sparse Attention: [total_k, head_kv, dim] / Sparse Page Attention: prepared paged KV tensor + mV: cute.Tensor, # Sparse Attention: [total_k, head_kv, dim] / Sparse Page Attention: prepared paged KV tensor + mK2qIndices: cute.Tensor, # csr payload: [head_kv, nnz] + mK2qQSplitIndices: cute.Tensor, # csr payload: [head_kv, nnz] packed q_idx/split slot + mK2qCounts: cute.Tensor, # csr row_ptr: [head_kv, total_rows + 1] + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, # fp32 O_partial buffer (kept alive) + mLSE_partial: cute.Tensor, # fp32 LSE_partial + mLSE_temperature_partial: Optional[cute.Tensor], # fp32 temperature-scaled LSE_partial + mQ_flat: cute.Tensor, # [batch*Sq*head_q, dim] bf16, pre-flattened + mQ_gather4_desc: Optional[cute.Tensor], # [128] uint8 tensor map for gather4 Q load + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + softmax_scale: Float32, + lse_temperature_scale: Float32, + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + stream=None, + ): + self.q_dtype = mQ_flat.element_type + self.k_input_dtype = mK.element_type + self.v_input_dtype = mV.element_type + self.qk_dtype = ( + self.q_dtype if const_expr(self.qk_dtype_param is None) else self.qk_dtype_param + ) + if const_expr(self.pv_dtype_param is None): + legacy_fp8_kv_cache = ( + self.q_dtype == cutlass.BFloat16 + and self.k_input_dtype == cutlass.Float8E4M3FN + and self.v_input_dtype == cutlass.Float8E4M3FN + ) + self.pv_dtype = cutlass.BFloat16 if legacy_fp8_kv_cache else self.v_input_dtype + else: + self.pv_dtype = self.pv_dtype_param + self.k_dtype = self.qk_dtype + self.v_dtype = self.pv_dtype + self.p_dtype = self.pv_dtype + if const_expr(self.q_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"Unsupported Q/K/V dtype: {self.q_dtype}") + if const_expr(self.qk_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"Unsupported qk_dtype: {self.qk_dtype}") + if const_expr(self.pv_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"Unsupported pv_dtype: {self.pv_dtype}") + if const_expr(self.q_dtype != self.qk_dtype): + raise TypeError("Q storage dtype must match qk_dtype") + if const_expr( + self.k_input_dtype != self.k_dtype + and not (self.k_input_dtype == cutlass.Float8E4M3FN and self.k_dtype == cutlass.BFloat16) + ): + raise TypeError("Only FP8 K -> BF16 QK staging is supported") + if const_expr( + self.v_input_dtype != self.v_dtype + and not (self.v_input_dtype == cutlass.Float8E4M3FN and self.v_dtype == cutlass.BFloat16) + ): + raise TypeError("Only FP8 V -> BF16 PV staging is supported") + self.k_fp8_to_bf16 = ( + self.k_input_dtype == cutlass.Float8E4M3FN + and self.k_dtype == cutlass.BFloat16 + ) + self.v_fp8_to_bf16 = ( + self.v_input_dtype == cutlass.Float8E4M3FN + and self.v_dtype == cutlass.BFloat16 + ) + self.kv_fp8_to_bf16 = self.k_fp8_to_bf16 or self.v_fp8_to_bf16 + self.qk_mma_kind = "f8f6f4" if const_expr(self.qk_dtype.width == 8) else "f16" + self.pv_mma_kind = "f8f6f4" if const_expr(self.pv_dtype.width == 8) else "f16" + elem_bytes = const_expr(self.q_dtype.width // 8) + self.q_stage_stride_bytes = self.m_block_size * self.head_dim * elem_bytes + self.k_tile_stride_bytes = self.m_block_size * self.k_tile * elem_bytes + self.token_stride_bytes = self.qheadperkv * self.k_tile * elem_bytes + p_cols_as_fp32 = const_expr(self.n_block_size * self.p_dtype.width // Float32.width) + self.tmem_s_to_p_offset = self.n_block_size - p_cols_as_fp32 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + self.o_dtype = mO_partial.element_type + if const_expr( + self.o_dtype + not in [Float32, cutlass.BFloat16, cutlass.Float16, cutlass.Float8E4M3FN] + ): + raise TypeError(f"Unsupported O_partial dtype: {self.o_dtype}") + mK, mV = [assume_tensor_aligned(t) for t in (mK, mV)] + + if const_expr(not self.paged_kv): + # Flat varlen K/V use CUTE-managed TMA descriptors, matching FA: + # K: [total_k, h, d] -> [total_k, d, h]. + # V: [total_k, h, d] -> [d, total_k, h] for MN-major PV. + layout_t = [0, 2, 1] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2])) + else: + # Sparse Page Attention with page-sized blocks can use the blocked + # paged TMA layout directly. Host input is [page, head, token, dim]. + layout_t = [2, 3, 1, 0] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + # V: (s,d,h,b) -> (d,s,h,b) for MN-major + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2, 3])) + + # ------------------------------------------------------------------ + # UTCMMA TiledMma: QK^T and PV + # ------------------------------------------------------------------ + cta_group = tcgen05.CtaGroup.ONE + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, + Float32, cta_group, self.mma_tiler_qk[:2]) + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN, + Float32, cta_group, self.mma_tiler_pv[:2], tcgen05.OperandSource.TMEM) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM layouts: sQ/sK/sV only. O_partial is written directly from + # registers to GMEM in the epilogue. + # ------------------------------------------------------------------ + total_q_stages = self.q_stage + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, total_q_stages) + q_load_tile = ( + self.head_dim + if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_tile + ) + q_load_subtiles_per_token = const_expr(self.head_dim // q_load_tile) + num_subtiles_total = ( + total_q_stages * self.q_tokens_per_group * q_load_subtiles_per_token + ) + sQ_load_layout = sm100_utils.make_smem_layout( + tcgen05.OperandMajorMode.K, + (self.qheadperkv, q_load_tile), self.q_dtype, num_subtiles_total) + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage) + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage) + sK_fp8_layout = cute.append( + cute.make_layout( + (self.n_block_size, self.head_dim), + stride=(self.head_dim, 1), + ), + cute.make_layout((1,)), + ) + sV_fp8_layout = cute.append( + cute.make_layout( + (self.head_dim, self.n_block_size), + stride=(1, self.head_dim), + ), + cute.make_layout((1,)), + ) + # P SMEM layout metadata (no actual SMEM allocation — P lives in TMEM, + # overlaying the S region; this layout is only used to compute the PV + # A-operand TMEM descriptor shape at the MMA issue site.) + tP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, self.p_dtype, self.s_stage) + + # ------------------------------------------------------------------ + # TMA atoms + # ------------------------------------------------------------------ + k_tma_layout = ( + cute.select(sK_fp8_layout, mode=[0, 1]) + if const_expr(self.k_fp8_to_bf16) + else cute.select(sK_layout, mode=[0, 1, 2]) + ) + v_tma_layout = ( + cute.select(sV_fp8_layout, mode=[0, 1]) + if const_expr(self.v_fp8_to_bf16) + else cute.select(sV_layout, mode=[0, 1, 2]) + ) + kv_tma_bytes = ( + cute.size_in_bytes(self.k_input_dtype, k_tma_layout) + + cute.size_in_bytes(self.v_input_dtype, v_tma_layout)) + q_tma_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + if const_expr(self.k_fp8_to_bf16): + tma_atom_K, mK = cpasync.make_tiled_tma_atom( + tma_load_op, + mK, + cute.select(sK_fp8_layout, mode=[0, 1]), + (self.n_block_size, self.head_dim), + ) + else: + tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mK, cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, tiled_mma_qk, cta_layout_vmnk.shape) + if const_expr(self.v_fp8_to_bf16): + tma_atom_V, mV = cpasync.make_tiled_tma_atom( + tma_load_op, + mV, + cute.select(sV_fp8_layout, mode=[0, 1]), + (self.head_dim, self.n_block_size), + ) + else: + tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mV, cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_pv, tiled_mma_pv, cta_layout_vmnk.shape) + + # Q per-sub-tile TMA atom: bf16 uses two 64-element halves; fp8 uses + # one 128-element row because 128 fp8 elements occupy the same 128B + # swizzle span as 64 bf16 elements. + mQ_flat = assume_tensor_aligned(mQ_flat) + mQ_2d = cute.make_tensor( + mQ_flat.iterator, cute.select(mQ_flat.layout, mode=[0, 1])) + if const_expr(self.use_q_gather4): + # Placeholder atom for unified kernel signature. Small-GQA Q load + # uses raw gather4 and keeps mQ_2d as a plain row-major GMEM tensor. + tma_atom_Q = tma_atom_V + else: + tma_atom_Q, mQ_2d = cpasync.make_tiled_tma_atom( + tma_load_op, mQ_2d, + cute.select(sQ_load_layout, mode=[0, 1]), + (self.qheadperkv, q_load_tile)) + q_subtile_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_load_layout, mode=[0, 1])) + + softmax_scale_log2 = softmax_scale * Float32(math.log2(math.e)) + lse_temperature_scale_log2 = softmax_scale_log2 * lse_temperature_scale + + # ------------------------------------------------------------------ + # SharedStorage — lean: just the mbars and tiles we actually use. + # + # Mbarriers (all storage rings stay below the 64-per-CTA limit): + # mbar_kv [2] one-shot K/V load handshake (full + empty) + # mbar_q [q_stage * 2] Q producer/consumer ring + # mbar_s [2] QK UTCMMA -> softmax (full + empty) + # mbar_o [2] PV UTCMMA -> epilogue (full + empty) + # mbar_p [s_stage * 2] softmax early-P arrive -> PV + # mbar_p_lastsplit [s_stage * 2] softmax final-P arrive -> PV + # (used only when ``self.split_P_arrive > 0``) + # mbar_sm_stats [s_stage * 2] softmax row_sum/row_max + # publish -> epilogue consumer read. In lean + # 1-WG topology the producer and consumer are + # the same 128 WG_C threads, but we keep the + # barrier for structural parity with FA so the + # softmax body reads identically. + # ------------------------------------------------------------------ + @cute.struct + class SharedStorage: + mbar_k: cute.struct.MemRange[Int64, 2] + mbar_v: cute.struct.MemRange[Int64, 2] + if const_expr(self.k_fp8_to_bf16): + mbar_k_tma: cute.struct.MemRange[Int64, 2] + if const_expr(self.v_fp8_to_bf16): + mbar_v_tma: cute.struct.MemRange[Int64, 2] + mbar_q: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_s: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p_lastsplit: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_o: cute.struct.MemRange[Int64, self.o_stage * 2] + mbar_sm_stats: cute.struct.MemRange[Int64, self.o_stage * 2] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: Int32 + # Per-row softmax stats cache (for epilogue LSE + rescale): + # [0 : m_block_size) row_sum + # [m_block_size : 2*m_block_size) row_max + sScale: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size * 2] + # Per-row temperature LSE row_sum cache. The row_max is shared with + # sScale because lse_temperature_scale is positive. + sScaleTemperature: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size] + # Per-token split_id from prepare-time per-edge metadata. + sSplitIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Per-token q_idx cache to avoid reloading sparse indices in epilogue. + sQIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Prefix length of q_idx-sorted row entries that may need causal + # masking for this KV block. This is CTA-level metadata, not a + # token-count cap. + sDiagQCount: cute.struct.MemRange[Int32, 1] + # CTA-wide row metadata, published once by tidx 0 and reused by + # all warp-specialized roles: + # [0] batch_idx + # [1] kv_block_idx + # [2] row_start + # [3] count_raw + # [4] kv_valid_cols + # [5] q_batch_offset + # [6] k_batch_offset + # [7] causal_q_offset = seqlen_k - seqlen_q + sRowMeta: cute.struct.MemRange[Int32, 8] + sPagedKvIdx: cute.struct.MemRange[Int32, 1] + sQLoadMIdx: cute.struct.MemRange[ + Int32, self.q_stage * self.q_tokens_per_group] + # Packed per-edge q/split metadata: + # low 24 bits = q_idx, high 8 bits = split slot. + sQIdxMeta: cute.struct.MemRange[ + Int32, self.qidx_meta_stages * self.q_tokens_per_group] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes] + if const_expr(self.k_fp8_to_bf16): + sKFp8: cute.struct.Align[ + cute.struct.MemRange[ + self.k_input_dtype, cute.cosize(sK_fp8_layout) + ], + self.buffer_align_bytes] + if const_expr(self.v_fp8_to_bf16): + sVFp8: cute.struct.Align[ + cute.struct.MemRange[ + self.v_input_dtype, cute.cosize(sV_fp8_layout) + ], + self.buffer_align_bytes] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes] + + self.shared_storage = SharedStorage + num_ctas = work_capacity + + self.kernel( + mK, mV, mK2qIndices, mK2qQSplitIndices, mK2qCounts, + mSchedulerMetadata, mWorkCount, + mO_partial, mLSE_partial, mLSE_temperature_partial, mQ_2d, mQ_gather4_desc, + mPageTable, mSeqUsedK, mCuSeqlensQ, mCuSeqlensK, + softmax_scale_log2, lse_temperature_scale_log2, lse_temperature_scale, + sQ_layout, sQ_load_layout, sK_layout, sV_layout, + sK_fp8_layout, sV_fp8_layout, tP_layout, + tma_atom_K, tma_atom_V, tma_atom_Q, + tiled_mma_qk, tiled_mma_pv, + kv_tma_bytes, q_tma_bytes, q_subtile_bytes, + num_kv_blocks, num_heads_kv, seq_len_q, work_capacity, + ).launch( + grid=(num_ctas,), + block=[self.threads_per_cta, 1, 1], + smem=max(SharedStorage.size_in_bytes(), 49152), + stream=stream, + min_blocks_per_mp=1, + ) + + # ------------------------------------------------------------------ + # Device-side: kernel entry, dispatch by warpgroup + # ------------------------------------------------------------------ + + @cute.kernel + def kernel( + self, + # Runtime tensors + tma_K: cute.Tensor, + tma_V: cute.Tensor, + mK2qIndices: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mK2qCounts: cute.Tensor, + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + mQ_2d: cute.Tensor, + mQ_gather4_desc: Optional[cute.Tensor], + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + # Scalars + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + # Layouts + sQ_layout: cute.ComposedLayout, + sQ_load_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sK_fp8_layout: cute.Layout, + sV_fp8_layout: cute.Layout, + tP_layout: cute.ComposedLayout, + # TMA atoms + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_Q: cute.CopyAtom, + # MMA + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + # Transfer sizes + kv_tma_bytes: cutlass.Constexpr[int], + q_tma_bytes: cutlass.Constexpr[int], + q_subtile_bytes: cutlass.Constexpr[int], + # Iteration bounds + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + ): + # ------------------------------------------------------------------ + # Thread / warp identity, CTA coordinate + # ------------------------------------------------------------------ + bidx, _, _ = cute.arch.block_idx() + row_linear = Int32(0) + head_kv_idx = Int32(0) + batch_idx = Int32(0) + kv_block_idx = Int32(0) + work_q_begin = Int32(0) + work_q_count = Int32(0) + cta_valid_work = True + work_idx = bidx + cta_valid_work = work_idx < mWorkCount[Int32(0)] + if cta_valid_work: + head_kv_idx = mSchedulerMetadata[work_idx, Int32(0)] + row_linear = mSchedulerMetadata[work_idx, Int32(1)] + work_q_begin = mSchedulerMetadata[work_idx, Int32(2)] + work_q_count = mSchedulerMetadata[work_idx, Int32(3)] + batch_idx = mSchedulerMetadata[work_idx, Int32(4)] + kv_block_idx = mSchedulerMetadata[work_idx, Int32(5)] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + head_q = num_heads_kv * Int32(self.qheadperkv) + paged_kv_manager = ( + PagedKVManager.create( + mPageTable, + page_size=self.page_size, + n_block_size=self.n_block_size, + ) + if const_expr(self.paged_kv) + else None + ) + + # Prefetch TMA descriptors (warp 0 once). + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + if const_expr(not self.use_q_gather4): + cpasync.prefetch_descriptor(tma_atom_Q) + else: + with cute.arch.elect_one(): + prefetch_tma_desc_raw(mQ_gather4_desc.iterator) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM allocation (all warps — same SharedStorage type from __call__) + # ------------------------------------------------------------------ + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + if const_expr(self.k_fp8_to_bf16): + sKFp8 = storage.sKFp8.get_tensor(sK_fp8_layout) + mbar_k_tma_ptr = storage.mbar_k_tma.data_ptr() + if const_expr(self.v_fp8_to_bf16): + sVFp8 = storage.sVFp8.get_tensor(sV_fp8_layout) + mbar_v_tma_ptr = storage.mbar_v_tma.data_ptr() + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sQ_load = storage.sQ.get_tensor(sQ_load_layout.outer, swizzle=sQ_load_layout.inner) + sScale = storage.sScale.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size * 2)) + sScaleTemperature = storage.sScaleTemperature.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size)) + sSplitIdx = storage.sSplitIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sQIdx = storage.sQIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sDiagQCount = storage.sDiagQCount.get_tensor(cute.make_layout((1,))) + sRowMeta = storage.sRowMeta.get_tensor(cute.make_layout((8,))) + sPagedKvIdx = storage.sPagedKvIdx.get_tensor(cute.make_layout((1,))) + sQLoadMIdx = storage.sQLoadMIdx.get_tensor( + cute.make_layout((self.q_stage * self.q_tokens_per_group,))) + sQIdxMeta = storage.sQIdxMeta.get_tensor( + cute.make_layout((self.qidx_meta_stages * self.q_tokens_per_group,))) + mbar_k_ptr = storage.mbar_k.data_ptr() + mbar_v_ptr = storage.mbar_v.data_ptr() + + # ------------------------------------------------------------------ + # TMEM allocator — allocator warp 0 serves the whole CTA. + # ------------------------------------------------------------------ + tmem_alloc_warps: cutlass.Constexpr[int] = self.warps_per_group * 2 + 1 + tmem_alloc_threads = cute.arch.WARP_SIZE * tmem_alloc_warps + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), + num_threads=tmem_alloc_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # ------------------------------------------------------------------ + # Warp-specialized pipelines. + # ------------------------------------------------------------------ + ThreadCooperativeGroup = partial( + cutlass_pipeline.CooperativeGroup, cutlass_pipeline.Agent.Thread) + tma_thread = ThreadCooperativeGroup(1) + mma_thread = ThreadCooperativeGroup(1) + softmax_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group) + epilogue_threads = softmax_threads + + pipeline_q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=q_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_s = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_s.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=softmax_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p_lastsplit = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p_lastsplit.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_o = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_o.data_ptr(), + num_stages=self.o_stage, + producer_group=mma_thread, + consumer_group=epilogue_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_sm_stats = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_sm_stats.data_ptr(), + num_stages=self.o_stage, + producer_group=softmax_threads, + consumer_group=epilogue_threads, + defer_sync=True, + ) + # Cluster sync (no-op for 1CTA cluster). + pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + # ------------------------------------------------------------------ + # Work count: how many Q tokens reference this CTA's KV block + # ------------------------------------------------------------------ + k_tma_bytes = cute.size_in_bytes( + self.k_input_dtype, + cute.select(sK_fp8_layout, mode=[0, 1]) + if const_expr(self.k_fp8_to_bf16) + else cute.select(sK_layout, mode=[0, 1, 2])) + v_tma_bytes = cute.size_in_bytes( + self.v_input_dtype, + cute.select(sV_fp8_layout, mode=[0, 1]) + if const_expr(self.v_fp8_to_bf16) + else cute.select(sV_layout, mode=[0, 1, 2])) + if tidx == 0: + row_batch_idx = batch_idx + row_kv_block_idx = kv_block_idx + base_row_start = mK2qCounts[head_kv_idx, row_linear] + row_start = base_row_start + count_raw = ( + mK2qCounts[head_kv_idx, row_linear + Int32(1)] + - base_row_start + ) + row_start = base_row_start + work_q_begin + count_raw = work_q_count + kv_valid_cols = ( + self._valid_cols_in_block( + row_batch_idx, + row_kv_block_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + ) + q_batch_offset = self._batch_q_offset( + row_batch_idx, mCuSeqlensQ + ) + k_batch_offset = ( + Int32(0) + if const_expr(self.paged_kv) + else mCuSeqlensK[row_batch_idx] + ) + sRowMeta[0] = row_batch_idx + sRowMeta[1] = row_kv_block_idx + sRowMeta[2] = row_start + sRowMeta[3] = count_raw + sRowMeta[4] = kv_valid_cols + sRowMeta[5] = q_batch_offset + sRowMeta[6] = k_batch_offset + causal_q_offset = Int32(0) + if const_expr(self.causal): + seqlen_q = mCuSeqlensQ[row_batch_idx + Int32(1)] - q_batch_offset + seqlen_k = self._logical_seqlen_k( + row_batch_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + causal_q_offset = seqlen_k - seqlen_q + sRowMeta[7] = causal_q_offset + if const_expr(self.paged_kv): + sPagedKvIdx[0] = paged_kv_manager.physical_block_index( + row_batch_idx, row_kv_block_idx + ) + cute.arch.mbarrier_init(mbar_k_ptr, 1) + cute.arch.mbarrier_init(mbar_v_ptr, 1) + if const_expr(self.k_fp8_to_bf16): + cute.arch.mbarrier_init(mbar_k_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_k_tma_ptr, k_tma_bytes) + else: + cute.arch.mbarrier_expect_tx(mbar_k_ptr, k_tma_bytes) + if const_expr(self.v_fp8_to_bf16): + cute.arch.mbarrier_init(mbar_v_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_v_tma_ptr, v_tma_bytes) + else: + cute.arch.mbarrier_expect_tx(mbar_v_ptr, v_tma_bytes) + diag_q_count = Int32(0) + if const_expr(self.causal): + row_has_visible_cols = (count_raw > Int32(0)) & (kv_valid_cols > Int32(0)) + if row_has_visible_cols: + kv_valid_end = ( + row_kv_block_idx * Int32(self.n_block_size) + + kv_valid_cols + ) + q_threshold = kv_valid_end - causal_q_offset + diag_q_count = self._lower_bound_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + count_raw, + q_threshold, + ) + sDiagQCount[0] = diag_q_count + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + thr_mma_qk = tiled_mma_qk.get_slice(0) + thr_mma_pv = tiled_mma_pv.get_slice(0) + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_base = thr_mma_qk.make_fragment_C(qk_acc_shape) + tStS = cute.make_tensor( + tStS_base.iterator, + cute.append( + tStS_base.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + tP_width_ratio = Float32.width // self.v_dtype.width + tP_stage_stride = self.tmem_stage_stride * tP_width_ratio + tOrP = cute.make_tensor( + tOrP.iterator + self.tmem_p_offset * tP_width_ratio, + cute.append( + tOrP.layout, + cute.make_layout((self.s_stage,), stride=(tP_stage_stride,)), + ), + ) + + tmem_cols = self.tmem_total + + load_wg_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.LoadWG), + num_threads=cute.arch.WARP_SIZE * self.num_q_load_warps, + ) + if const_expr(self.kv_fp8_to_bf16): + kv_load_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvLoad), + num_threads=cute.arch.WARP_SIZE * self.num_kv_load_warps, + ) + if const_expr(self.k_fp8_to_bf16): + kv_dequant_k_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantK), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + if const_expr(self.v_fp8_to_bf16): + kv_dequant_v_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantV), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + sm_stats_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), + num_threads=cute.arch.WARP_SIZE * 2, + ) + epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.StoreEpilogue), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + if warp_idx == Int32(self.total_warps - 1): + cute.arch.setmaxregister_decrease(self.num_regs_empty) + + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + q_load_thread_end = Int32( + (self.q_load_warp_base + self.num_q_load_warps) * cute.arch.WARP_SIZE + ) + is_q_load_thread = tidx >= q_load_thread_base and tidx < q_load_thread_end + if is_q_load_thread and cta_valid_work: + if self.store_reg_decrease: + cute.arch.setmaxregister_decrease(self.num_regs_store) + else: + cute.arch.setmaxregister_increase(self.num_regs_store) + row_start_load = sRowMeta[2] + count_raw_load = sRowMeta[3] + q_batch_offset_load = sRowMeta[5] + # Do not gate on KV validity here; sparse entries past seqused_k + # still need the all-masked path to produce neutral partials. + has_work_load = count_raw_load > Int32(0) + num_q_groups_load = ( + count_raw_load + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + if const_expr(self.use_q_gather4): + self._wg_load_q_gather4( + mQ_2d, + mQ_gather4_desc, + mK2qQSplitIndices, + sQIdxMeta, + sQ, + pipeline_q, + load_wg_barrier, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + ) + else: + self._wg_load_q_tma( + tma_atom_Q, + mQ_2d, + mK2qQSplitIndices, + sQLoadMIdx, + sQIdxMeta, + sQ_load, + pipeline_q, + load_wg_barrier, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + ) + + if ( + warp_idx >= Int32(self.kv_load_warp_base) + and warp_idx < Int32(self.kv_load_warp_base + self.num_kv_load_warps) + and cta_valid_work + ): + cute.arch.setmaxregister_decrease(self.num_regs_load) + kv_block_idx_load = sRowMeta[1] + k_batch_offset_load = sRowMeta[6] + has_work_load = sRowMeta[3] > Int32(0) + if const_expr(self.kv_fp8_to_bf16): + self._wg_load_kv_maybe_cast( + tma_atom_K, tma_atom_V, + tma_K, tma_V, + sPagedKvIdx, + sK, sV, + sKFp8 if const_expr(self.k_fp8_to_bf16) else None, + sVFp8 if const_expr(self.v_fp8_to_bf16) else None, + tiled_mma_qk, tiled_mma_pv, + mbar_k_ptr, mbar_v_ptr, + mbar_k_tma_ptr if const_expr(self.k_fp8_to_bf16) else None, + mbar_v_tma_ptr if const_expr(self.v_fp8_to_bf16) else None, + kv_load_barrier, + has_work_load, + head_kv_idx, kv_block_idx_load, + k_batch_offset_load, + ) + else: + self._wg_load_kv( + tma_atom_K, tma_atom_V, + tma_K, tma_V, + sPagedKvIdx, + sK, sV, + tiled_mma_qk, tiled_mma_pv, + mbar_k_ptr, mbar_v_ptr, + has_work_load, + head_kv_idx, kv_block_idx_load, + k_batch_offset_load, + ) + + if warp_idx == Int32(self.mma_warp_id) and cta_valid_work: + cute.arch.setmaxregister_decrease(self.num_regs_mma) + count_raw_mma = sRowMeta[3] + has_work_mma = count_raw_mma > Int32(0) + num_q_groups_mma = ( + count_raw_mma + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + tmem.allocate(tmem_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_mma_issue( + tiled_mma_qk, tiled_mma_pv, + thr_mma_qk, thr_mma_pv, + tStS, tOrP, + sK, sV, sQ, + pipeline_q, pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + mbar_k_ptr, mbar_v_ptr, + num_q_groups_mma, has_work_mma, + ) + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr, num_columns=tmem_cols) + cute.arch.griddepcontrol_launch_dependents() + + if ( + warp_idx >= Int32(self.softmax0_warp_base) + and warp_idx < Int32(self.softmax1_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + if const_expr(self.k_fp8_to_bf16): + self._wg_convert_fp8_kv_to_bf16_smem( + sKFp8, + sK, + mbar_k_tma_ptr, + mbar_k_ptr, + kv_dequant_k_barrier, + has_work_softmax, + self.softmax0_warp_base, + self.warps_per_group, + False, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 0, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + + if ( + warp_idx >= Int32(self.softmax1_warp_base) + and warp_idx < Int32(self.store_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + if const_expr(self.v_fp8_to_bf16): + self._wg_convert_fp8_kv_to_bf16_smem( + sVFp8, + sV, + mbar_v_tma_ptr, + mbar_v_ptr, + kv_dequant_v_barrier, + has_work_softmax, + self.softmax1_warp_base, + self.warps_per_group, + True, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 1, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + # ------------------------------------------------------------------ + # Warp-specialized helpers + # ------------------------------------------------------------------ + + @cute.jit + def _convert_fp8x16_to_bf16x16( + self, + src: cute.Tensor, + dst: cute.Tensor, + ): + src_i32 = cute.recast_tensor(src, cutlass.Int32) + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + for word_idx in cutlass.range_constexpr(4): + ( + dst_i32[word_idx * 2], + dst_i32[word_idx * 2 + 1], + ) = utils.cvt_fp8x4_e4m3_bf16x4(src_i32[word_idx]) + + @cute.jit + def _convert_fp8_kv_to_bf16_smem( + self, + sFp8: cute.Tensor, + sBf16: cute.Tensor, + lane: Int32, + warp_idx_in_wg: Int32, + num_dequant_warps: cutlass.Constexpr[int], + is_v: cutlass.Constexpr[bool], + ): + elems_per_load: cutlass.Constexpr[int] = 16 + elems_per_store: cutlass.Constexpr[int] = 8 + chunks_per_row: cutlass.Constexpr[int] = self.head_dim // elems_per_load + r_fp8 = cute.make_rmem_tensor((elems_per_load,), cutlass.Float8E4M3FN) + r_bf16 = cute.make_rmem_tensor((elems_per_load,), cutlass.BFloat16) + total_tasks: cutlass.Constexpr[int] = self.n_block_size * chunks_per_row + task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(task, total_tasks, task_stride, unroll=1): + row = task_idx // Int32(chunks_per_row) + chunk = task_idx - row * Int32(chunks_per_row) + col = chunk * Int32(elems_per_load) + smem_offset = row * Int32(self.head_dim) + col + s_fp8_ptr = cute.make_ptr( + cutlass.Float8E4M3FN, + sFp8.iterator.toint() + Int64(smem_offset), + mem_space=sFp8.iterator.memspace, + assumed_align=elems_per_load, + ) + s_fp8_vec = cute.make_tensor( + s_fp8_ptr, + cute.make_layout(elems_per_load), + ) + cute.autovec_copy(s_fp8_vec, r_fp8) + self._convert_fp8x16_to_bf16x16(r_fp8, r_bf16) + if const_expr(is_v): + sBf16_view = sBf16[(None, row % Int32(16)), 0, row // Int32(16), 0] + sBf16_vec = cute.local_tile(sBf16_view, (elems_per_load,), (chunk,)) + else: + sBf16_vec = sBf16[ + (row, None), + 0, + (chunk % Int32(4), chunk // Int32(4)), + 0, + ] + r_tiles = cute.logical_divide(r_bf16, cute.make_layout(elems_per_store)) + s_tiles = cute.logical_divide(sBf16_vec, cute.make_layout(elems_per_store)) + for v in cutlass.range_constexpr(elems_per_load // elems_per_store): + cute.autovec_copy(r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + + @cute.jit + def _wg_load_kv_maybe_cast( + self, + tma_atom_K, + tma_atom_V, + tma_K: cute.Tensor, + tma_V: cute.Tensor, + sPagedKvIdx: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sKFp8: Optional[cute.Tensor], + sVFp8: Optional[cute.Tensor], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mbar_k_ptr, + mbar_v_ptr, + mbar_k_tma_ptr: Optional[cutlass.Pointer], + mbar_v_tma_ptr: Optional[cutlass.Pointer], + kv_load_barrier, + has_work: Int32, + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(self.kv_load_warp_base) + + if has_work: + if warp_idx_in_wg == Int32(0): + if const_expr(self.k_fp8_to_bf16): + if const_expr(self.paged_kv): + mK_cur = tma_K[None, None, head_kv_idx, sPagedKvIdx[0]] + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim), + (None, 0), + ) + src_idx = Int32(0) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + tma_K[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim), + (None, 0), + ) + src_idx = kv_block_idx + load_K_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, + 0, + cute.make_layout(1), + gK, + sKFp8, + ) + load_K_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_k_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_tma_ptr) + else: + thr_mma_qk = tiled_mma_qk.get_slice(0) + if const_expr(self.paged_kv): + mK_cur = tma_K[None, None, head_kv_idx, None] + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0, None), + ) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + tma_K[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0), + ) + tSgK = thr_mma_qk.partition_B(gK) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + gmem_k_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_K, + tKgK[(None, 0, gmem_k_idx)] + if const_expr(self.paged_kv) + else tKgK[(None, gmem_k_idx)], + tKsK[(None, 0)], + tma_bar_ptr=mbar_k_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_ptr) + + if warp_idx_in_wg == Int32(1): + if const_expr(self.v_fp8_to_bf16): + if const_expr(self.paged_kv): + mV_cur = tma_V[None, None, head_kv_idx, sPagedKvIdx[0]] + gV = cute.local_tile( + mV_cur, + (self.head_dim, self.n_block_size), + (0, None), + ) + src_idx = Int32(0) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + tma_V[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + (self.head_dim, self.n_block_size), + (0, None), + ) + src_idx = kv_block_idx + load_V_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, + 0, + cute.make_layout(1), + gV, + sVFp8, + ) + load_V_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_v_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_tma_ptr) + else: + thr_mma_pv = tiled_mma_pv.get_slice(0) + if const_expr(self.paged_kv): + mV_cur = tma_V[None, None, head_kv_idx, None] + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None, None), + ) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + tma_V[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None), + ) + tOgV = thr_mma_pv.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + gmem_v_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_V, + tVgV[(None, 0, gmem_v_idx)] + if const_expr(self.paged_kv) + else tVgV[(None, gmem_v_idx)], + tVsV[(None, 0)], + tma_bar_ptr=mbar_v_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_ptr) + + kv_load_barrier.arrive_and_wait() + + @cute.jit + def _wg_convert_fp8_kv_to_bf16_smem( + self, + sFp8: cute.Tensor, + sBf16: cute.Tensor, + mbar_tma_ptr, + mbar_ready_ptr, + dequant_barrier, + has_work: Int32, + dequant_warp_base: cutlass.Constexpr[int], + num_dequant_warps: cutlass.Constexpr[int], + is_v: cutlass.Constexpr[bool], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(dequant_warp_base) + lane = cute.arch.lane_idx() + if has_work: + cute.arch.mbarrier_wait(mbar_tma_ptr, 0) + self._convert_fp8_kv_to_bf16_smem( + sFp8, + sBf16, + lane, + warp_idx_in_wg, + num_dequant_warps, + is_v, + ) + dequant_barrier.arrive_and_wait() + if warp_idx_in_wg == Int32(0): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_ready_ptr) + + @cute.jit + def _wg_load_kv( + self, + tma_atom_K, + tma_atom_V, + tma_K: cute.Tensor, + tma_V: cute.Tensor, + sPagedKvIdx: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mbar_k_ptr, + mbar_v_ptr, + has_work: Int32, + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(self.kv_load_warp_base) + + if has_work: + if warp_idx_in_wg == Int32(0): + thr_mma_qk = tiled_mma_qk.get_slice(0) + if const_expr(self.paged_kv): + mK_cur = tma_K[None, None, head_kv_idx, None] + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0, None), + ) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + tma_K[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0), + ) + tSgK = thr_mma_qk.partition_B(gK) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + gmem_k_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_K, + tKgK[(None, 0, gmem_k_idx)] + if const_expr(self.paged_kv) + else tKgK[(None, gmem_k_idx)], + tKsK[(None, 0)], + tma_bar_ptr=mbar_k_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_ptr) + + if warp_idx_in_wg == Int32(1): + thr_mma_pv = tiled_mma_pv.get_slice(0) + if const_expr(self.paged_kv): + mV_cur = tma_V[None, None, head_kv_idx, None] + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None, None), + ) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + tma_V[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None), + ) + tOgV = thr_mma_pv.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + gmem_v_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_V, + tVgV[(None, 0, gmem_v_idx)] + if const_expr(self.paged_kv) + else tVgV[(None, gmem_v_idx)], + tVsV[(None, 0)], + tma_bar_ptr=mbar_v_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_ptr) + + @cute.jit + def _wg_load_q_gather4( + self, + mQ_2d: cute.Tensor, + mQ_gather4_desc: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + load_wg_barrier, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + ): + tidx = cute.arch.thread_idx()[0] + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + group_tidx = tidx - q_load_thread_base + producer_warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + gathers_per_warp: cutlass.Constexpr[int] = self.m_block_size // ( + self.num_q_load_warps * 4) + + if has_work: + for qi_group in cutlass.range(num_q_groups, unroll=1): + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if producer_warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + group_tidx = tidx - q_load_thread_base + warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + lane_idx = group_tidx % Int32(cute.arch.WARP_SIZE) + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + meta_iters: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + + self.num_q_load_warps * cute.arch.WARP_SIZE - 1) + // (self.num_q_load_warps * cute.arch.WARP_SIZE) + ) + for meta_iter in cutlass.range_constexpr(meta_iters): + tok_idx_g4 = ( + (Int32(meta_iter) * Int32(self.num_q_load_warps) + + warp_idx_in_wg) + * Int32(cute.arch.WARP_SIZE) + + lane_idx + ) + if tok_idx_g4 < Int32(self.q_tokens_per_group): + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx_g4 + if qi < count_raw: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = Int32(0) + load_wg_barrier.arrive_and_wait() + + with cute.arch.elect_one(): + q_desc_ptr = mQ_gather4_desc.iterator + sQ_ptr = sQ.iterator + for gather_slot in cutlass.range_constexpr(gathers_per_warp): + gather_idx = ( + Int32(gather_slot) * Int32(self.num_q_load_warps) + + warp_idx_in_wg + ) + tok_base = gather_idx * Int32(self.tokens_per_gather4) + if const_expr(self.qheadperkv == 1): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + qi2 = qi0 + Int32(2) + qi3 = qi0 + Int32(3) + row0 = q_oob_m_idx + row1 = q_oob_m_idx + row2 = q_oob_m_idx + row3 = q_oob_m_idx + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row0 = (q_batch_offset + q_idx0) * num_heads_kv + head_kv_idx + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row1 = (q_batch_offset + q_idx1) * num_heads_kv + head_kv_idx + if qi2 < count_raw: + q_idx2 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(2)]) + row2 = (q_batch_offset + q_idx2) * num_heads_kv + head_kv_idx + if qi3 < count_raw: + q_idx3 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(3)]) + row3 = (q_batch_offset + q_idx3) * num_heads_kv + head_kv_idx + elif const_expr(self.qheadperkv == 2): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + row_base1 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row_base1 = ( + (q_batch_offset + q_idx1) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base1 + row3 = row_base1 + Int32(1) + else: + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base0 + Int32(2) + row3 = row_base0 + Int32(3) + group_byte_off = gather_idx * Int32( + 4 * self.k_tile * (self.q_dtype.width // 8) + ) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + stage_byte_off = slot * Int32(self.q_stage_stride_bytes) + full_group_byte_off = gather_idx * Int32( + 4 * self.head_dim * (self.q_dtype.width // 8) + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + full_group_byte_off, + q_desc_ptr, + Int32(0), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + else: + for ks_c in cutlass.range_constexpr(self.k_stages): + stage_idx = slot * Int32(self.k_stages) + Int32(ks_c) + stage_byte_off = stage_idx * Int32(self.k_tile_stride_bytes) + if const_expr(ks_c + 1 < self.k_stages): + tma_gather4_prefetch( + q_desc_ptr, + Int32((ks_c + 1) * self.k_tile), + row0, + row1, + row2, + row3, + TMA_CACHE_EVICT_LAST, + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + group_byte_off, + q_desc_ptr, + Int32(ks_c * self.k_tile), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + load_wg_barrier.arrive_and_wait() + + if producer_warp_idx_in_wg == Int32(0): + next_slot = num_q_groups % Int32(self.q_stage) + next_phase = ( + (num_q_groups // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_load_q_tma( + self, + tma_atom_Q, + mQ_2d: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQLoadMIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ_load: cute.Tensor, + pipeline_q, + load_wg_barrier, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + lane_idx = cute.arch.lane_idx() + warp_idx_in_wg = warp_idx - Int32(self.q_load_warp_base) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + gQ_full = cute.local_tile( + mQ_2d, (self.qheadperkv, self.head_dim), (None, 0)) + load_Q_fn_full, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_full, sQ_load) + load_Q_fn_k0, load_Q_fn_k1 = None, None + else: + gQ_k0 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 0)) + gQ_k1 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 1)) + load_Q_fn_k0, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k0, sQ_load) + load_Q_fn_k1, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k1, sQ_load) + load_Q_fn_full = None + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + tokens_per_warp: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + self.num_q_load_warps - 1) + // self.num_q_load_warps + ) + + if has_work: + for qi_group in cutlass.range(num_q_groups, unroll=1): + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + q_load_subtiles_per_token = ( + 1 if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_stages + ) + sub_stage_base = slot * Int32( + self.q_tokens_per_group * q_load_subtiles_per_token) + load_meta_slot = slot * Int32(self.q_tokens_per_group) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + if warp_idx_in_wg == Int32(0) and lane_idx < Int32(self.q_tokens_per_group): + tok_idx = lane_idx + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx + if qi < count_raw: + qsplit = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + q_idx = self._decode_q_idx_from_qsplit(qsplit) + q_abs = q_batch_offset + q_idx + sQIdxMeta[qidx_meta_slot + tok_idx] = qsplit + sQLoadMIdx[load_meta_slot + tok_idx] = ( + q_abs * num_heads_kv + head_kv_idx + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx] = Int32(0) + sQLoadMIdx[load_meta_slot + tok_idx] = q_oob_m_idx + load_wg_barrier.arrive_and_wait() + + for qi_slot in cutlass.range_constexpr(tokens_per_warp): + tok_idx = ( + warp_idx_in_wg * Int32(tokens_per_warp) + + Int32(qi_slot) + ) + if tok_idx < Int32(self.q_tokens_per_group): + m_tile_idx = sQLoadMIdx[load_meta_slot + tok_idx] + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + load_Q_fn_full( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + else: + load_Q_fn_k0( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + load_Q_fn_k1( + src_idx=m_tile_idx, + dst_idx=( + sub_stage_base + + Int32(self.q_tokens_per_group) + + tok_idx + ), + tma_bar_ptr=mbar_ptr, + ) + + if warp_idx_in_wg == Int32(0): + next_slot = num_q_groups % Int32(self.q_stage) + next_phase = ( + (num_q_groups // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_mma_issue( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + thr0_qk: cute.core.ThrMma, + thr0_pv: cute.core.ThrMma, + tStS: cute.Tensor, + tOrP: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + mbar_k_ptr, + mbar_v_ptr, + num_q_groups: Int32, + has_work: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + is_mma_warp = warp_idx == Int32(self.mma_warp_id) + + if is_mma_warp: + if has_work: + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tSrQ0 = tSrQ[None, None, None, 0] + tSrK0 = tSrK[None, None, None, 0] + tOrV = tiled_mma_pv.make_fragment_B(sV) + tOrV0 = tOrV[None, None, None, 0] + sV0 = sV[None, None, None, 0] + pv_mma_op = tiled_mma_pv.op + qk_mma_op = tiled_mma_qk.op + q_smem_base = sm100_desc.smem_desc_base_from_tensor( + sQ, sm100_desc.Major.K) + k_smem_base = sm100_desc.smem_desc_base_from_tensor( + sK, sm100_desc.Major.K) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[None, None, None, 0].iterator) + q_smem_start = sm100_desc.make_smem_desc_start_addr( + sQ[None, None, None, self.q_stage - 1].iterator) + sm100_helpers.declare_ptx_smem_desc( + q_smem_start, + q_smem_base, + tSrQ0.layout, + var_name_prefix="lean_q_desc", + ) + sm100_helpers.declare_ptx_idesc( + qk_mma_op, var_name="lean_qk_idesc") + sQ_stage_stride = ( + sQ.layout.stride[-1] * sQ.element_type.width // 8) >> 4 + if const_expr(self.q_stage == 1): + sQ_stage_stride = 0 + q_wrap_offset = -(self.q_stage - 1) * sQ_stage_stride + q_advance_offset = sQ_stage_stride + gemm_qk_s0_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_qk_s0_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_qk_s1_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_qk_s1_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_pv_0 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset), + tOrP[None, None, None, 0], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.pv_mma_kind, + ) + gemm_pv_1 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset + self.tmem_o_stage_stride), + tOrP[None, None, None, 1], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_stage_stride + self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.pv_mma_kind, + ) + + cute.arch.mbarrier_wait(mbar_k_ptr, 0) + # Issue order: + # Q0K, Q1K, P0V, Q2K, P1V, Q3K, ... + # This reuses each slot as soon as its previous PV drains, + # instead of batching both PVs after both QKs of a pair. + # The schedule is still 2-slot safe: + # - QK(qi) consumes slot qi&1 + # - PV(qi-2) frees the same slot before QK(qi) reuses it + # - phases still toggle every 2 groups per slot + + # Prologue: issue up to the first two QK tiles. Q slots come + # from the q_stage ring; S slots remain a 2-slot ring. + pipeline_q.consumer_wait_w_index_phase(Int32(0), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(0), Int32(1)) + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(0)) + pipeline_q.consumer_release_w_index(Int32(0)) + + if num_q_groups > Int32(1): + pipeline_q.consumer_wait_w_index_phase(Int32(1), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(1), Int32(1)) + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(1)) + pipeline_q.consumer_release_w_index(Int32(1)) + + cute.arch.mbarrier_wait(mbar_v_ptr, 0) + + # Steady-state: for qi >= 2, reuse the S/P slot as a BLK128- + # style handoff. The MMA warp waits for softmax to release + # the slot after P is visible, issues PV, then immediately + # reuses the same acquired slot for the next QK. + for qi in cutlass.range(Int32(2), num_q_groups, unroll=1): + pv_qi = qi - Int32(2) + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + q_slot = qi % Int32(self.q_stage) + q_phase = (qi // Int32(self.q_stage)) & Int32(1) + s_slot = qi & Int32(1) + s_phase = (qi // Int32(2)) & Int32(1) + pipeline_q.consumer_wait_w_index_phase(q_slot, q_phase) + pipeline_s.producer_acquire_w_index_phase( + s_slot, s_phase ^ Int32(1)) + if s_slot == Int32(0): + if q_slot == Int32(0): + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s0_advance(smem_desc_start_b=k_smem_start) + else: + if q_slot == Int32(0): + gemm_qk_s1_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(s_slot) + pipeline_q.consumer_release_w_index(q_slot) + + # Drain the remaining one or two PV tiles. + drain_begin = Int32(0) if num_q_groups == Int32(1) else num_q_groups - Int32(2) + for pv_qi in cutlass.range(drain_begin, num_q_groups, unroll=1): + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + @cute.jit + def _softmax_step( + self, + slot: cutlass.Constexpr[int], + s_consumer_phase: Int32, + p_producer_phase: Int32, + sm_stats_producer_phase: Int32, + softmax: SoftmaxSm100, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx: Int32, + thr_tmem_load, + thr_tmem_store, + tStS_t2r: cute.Tensor, + tStP_r2t: cute.Tensor, + tScS_t2r: cute.Tensor, + tScP_shape, + sQIdxMeta: cute.Tensor, + qidx_meta_slot: Int32, + group_tidx: Int32, + masked_tok_count: Int32, + kv_block_col_start: Int32, + seq_len_q: Int32, + causal_q_offset: Int32, + kv_valid_cols: Int32, + lse_temperature_scale: Float32, + return_temperature_lse: cutlass.Constexpr[bool], + apply_causal_mask: cutlass.Constexpr[bool] = False, + signal_stats_barrier: cutlass.Constexpr[bool] = True, + ): + slot_rt = Int32(slot) + + pipeline_s.consumer_wait_w_index_phase(slot_rt, s_consumer_phase) + + tSrS_t2r = cute.make_rmem_tensor( + tScS_t2r.shape, + self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + + seqlen_info = SeqlenInfoQK( + Int32(0), + Int32(0), + Int32(0), + Int32(0), + seq_len_q, + seq_len_q + causal_q_offset, + False, + False, + False, + False, + ) + mask = AttentionMask( + self.m_block_size, + self.n_block_size, + seqlen_info, + ) + if const_expr(self.causal and apply_causal_mask): + need_causal_mask = masked_tok_count > Int32(0) + if need_causal_mask: + tok_idx = group_tidx // Int32(self.qheadperkv) + q_idx = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_idx]) + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=True, + row_idx=q_idx, + kv_valid_cols=kv_valid_cols, + kv_block_col_start=kv_block_col_start, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + + # Each sparse CTA computes exactly one KV block for the current Q group, + # so full-tile softmax is always the first and only online-softmax step. + row_max, _ = softmax.update_row_max(tSrS_t2r.load(), True) + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + if const_expr(return_temperature_lse): + lse_temperature_row_sum = softmax.compute_scaled_exp2_row_sum( + tSrS_t2r, + lse_temperature_scale, + ) + + if cutlass.const_expr(self.split_P_arrive > 0): + # This full barrier is the late-P handoff consumed inside + # gemm_ptx_partial after its early PV k-slices are issued. + pipeline_p_lastsplit.producer_acquire_w_index_phase( + slot_rt, p_producer_phase) + pipeline_p.producer_acquire_w_index_phase(slot_rt, p_producer_phase) + tSrP_r2t_f32 = cute.make_rmem_tensor( + thr_tmem_store.partition_S( + cute.make_identity_tensor(tScP_shape)).shape, + Float32) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.p_dtype), + tSrS_t2r.layout) + softmax.apply_exp2_convert( + tSrS_t2r, + tSrP_r2t, + ex2_emu_freq=self.ex2_emu_freq, + ex2_emu_start_frg=self.ex2_emu_start_frg, + ) + + for k in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): + cute.copy( + thr_tmem_store, + tSrP_r2t_f32[None, None, k], + tStP_r2t[None, None, k]) + if cutlass.const_expr(self.split_P_arrive > 0): + split_idx = ( + cute.size(tStP_r2t.shape[2]) + * self.split_P_arrive // self.n_block_size) + if cutlass.const_expr(k + 1 == split_idx): + cute.arch.fence_view_async_tmem_store() + pipeline_p.producer_commit_w_index(slot_rt) + cute.arch.fence_view_async_tmem_store() + if cutlass.const_expr(self.split_P_arrive == 0): + pipeline_p.producer_commit_w_index(slot_rt) + else: + pipeline_p_lastsplit.producer_commit_w_index(slot_rt) + pipeline_sm_stats.producer_acquire_w_index_phase( + slot_rt, sm_stats_producer_phase) + softmax.update_row_sum(tSrS_t2r.load(), Float32(0.0), True) + del tSrS_t2r + sScale_slot = cute.make_tensor( + sScale.iterator + slot_rt * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_slot[group_tidx] = softmax.row_sum[0] + sScale_slot[group_tidx + Int32(self.m_block_size)] = ( + softmax.row_max[0]) + if const_expr(return_temperature_lse): + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot_rt * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sScale_temperature_slot[group_tidx] = lse_temperature_row_sum + cute.arch.fence_view_async_shared() + + if const_expr(signal_stats_barrier): + sm_stats_barrier.arrive_w_index(index=stats_barrier_idx) + pipeline_s.consumer_release_w_index(slot_rt) + + @cute.jit + def _wg_softmax( + self, + stage: cutlass.Constexpr[int], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tStS: cute.Tensor, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + kv_block_idx: Int32, + kv_valid_cols: Int32, + diag_q_count: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + causal_q_offset: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + mQ_2d: cute.Tensor, + ): + tidx = cute.arch.thread_idx()[0] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx % Int32(self.warps_per_group) + group_tidx = ( + warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + stats_barrier_idx = ( + Int32(stage) * Int32(self.warps_per_group) + warp_idx_in_wg) + + thr0_qk = tiled_mma_qk.get_slice(0) + tScS = thr0_qk.partition_C( + cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScS = tScS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_tiler_qk[0] // thr0_qk.thr_id.shape, + self.mma_tiler_qk[1]) + tilePlikeFP32 = ( + self.mma_tiler_qk[1] // Float32.width * self.p_dtype.width) + tScP_shape = (cta_qk_tiler[0], tilePlikeFP32) + tSAcc = tStS[(None, None), 0, 0, stage] + + softmax = SoftmaxSm100.create(softmax_scale_log2) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype) + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tSAcc).get_slice(group_tidx) + tStS_t2r = thr_tmem_load.partition_S(tSAcc) + tScS_t2r = thr_tmem_load.partition_D(tScS) + tStP_layout = cute.composition( + tSAcc.layout, + cute.make_layout((self.m_block_size, tilePlikeFP32))) + tStP = cute.make_tensor( + tSAcc.iterator + self.tmem_s_to_p_offset, tStP_layout) + # P-store Repetition is dtype-aware: each PV MMA K-segment is + # ``32 / (p_dtype.width / 8)`` fp8/bf16 columns wide, which equals + # ``32 * Float32.width / p_dtype.width`` packed fp32 TMEM columns + # ``// (p_dtype.width / 8)``. Concretely, R=16 packs two bf16 PV K + # segments per chunk (shape[2]=4 ⇒ 3/4 publish boundary aligns), + # while fp8 (PV K=32 fp8 ⇒ 8 fp32 cols) needs R=8 so + # shape[2]=4 and split_idx=3 publishes exactly 24 fp32 cols + # (= 96 fp8 cols = 3 PV K segments) at the early-arrive edge. + store_rep = const_expr( + 8 if self.p_dtype == cutlass.Float8E4M3FN else 16 + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(store_rep)), + Float32) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom, tStP).get_slice(group_tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + + total_q = mQ_2d.shape[0] // head_q + thr0_pv = tiled_mma_pv.get_slice(0) + pv_acc_shape = thr0_pv.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO_base = thr0_pv.make_fragment_C(pv_acc_shape) + corr_tile_size = 64 + tOcO = thr0_pv.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + tOcO_i = cute.logical_divide( + tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + o_tmem_copy_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.pv_acc_dtype) + + if has_work: + kv_block_col_start = Int32(0) + if const_expr(self.causal): + kv_block_col_start = kv_block_idx * Int32(self.n_block_size) + + num_stage_groups = ( + num_q_groups + Int32(1 - stage)) // Int32(2) + for qi_iter in cutlass.range(num_stage_groups, unroll=1): + qi_group = qi_iter * Int32(2) + Int32(stage) + phase = qi_iter & Int32(1) + producer_phase = phase ^ Int32(1) + qidx_meta_slot = ( + (qi_group & Int32(self.qidx_meta_stages - 1)) + * Int32(self.q_tokens_per_group) + ) + + softmax.reset() + + if const_expr(self.causal): + qi_group_start = qi_group * Int32(self.q_tokens_per_group) + masked_tok_count = cutlass.max( + Int32(0), + cutlass.min( + Int32(self.q_tokens_per_group), + diag_q_count - qi_group_start, + ), + ) + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + masked_tok_count, + kv_block_col_start, + seq_len_q, + causal_q_offset, + kv_valid_cols, + lse_temperature_scale, + const_expr(mLSE_temperature_partial is not None), + True, + False, + ) + else: + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + Int32(0), + kv_block_col_start, + seq_len_q, + Int32(0), + kv_valid_cols, + lse_temperature_scale, + const_expr(mLSE_temperature_partial is not None), + False, + False, + ) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(stage)) + self._epilogue_step( + qi_group, + group_tidx, + warp_idx_in_wg, + tOtO_base, + tOcO_i, + o_tmem_copy_atom, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, + mLSE_partial, + mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + count_raw, + batch_idx, + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + q_batch_offset, + total_q, + False, + stage, + ) + + @cute.jit + def _store_o_partial_vec4( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + ): + stg_128_cs(ptr, v0, v1, v2, v3) + + @cute.jit + def _store_o_partial_vec8_half( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + ): + if cutlass.const_expr(self.o_dtype is cutlass.BFloat16): + stg_128_bf16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + else: + stg_128_f16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + + @cute.jit + def _store_o_partial_vec16_fp8( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + v8: Float32, + v9: Float32, + v10: Float32, + v11: Float32, + v12: Float32, + v13: Float32, + v14: Float32, + v15: Float32, + ): + stg_128_fp8_e4m3_cs( + ptr, + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7, + v8, + v9, + v10, + v11, + v12, + v13, + v14, + v15, + ) + + @cute.jit + def _epilogue_step( + self, + qi_group: Int32, + group_tidx: Int32, + warp_idx_in_wg: Int32, + tOtO_base: cute.Tensor, + tOcO_i: cute.Tensor, + o_tmem_copy_atom, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + count_raw: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + total_q: Int32, + use_stats_barrier: cutlass.Constexpr[bool], + softmax_stage: cutlass.Constexpr[int], + ): + slot = qi_group & Int32(1) + phase = (qi_group // Int32(2)) & Int32(1) + stage_base = slot * Int32(self.tmem_o_stage_stride) + corr_tile_size = 64 + sScale_slot = cute.make_tensor( + sScale.iterator + slot * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sSplitIdx_slot = cute.make_tensor( + sSplitIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + sQIdx_slot = cute.make_tensor( + sQIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + pipeline_o.consumer_wait_w_index_phase(slot, phase) + if const_expr(use_stats_barrier): + sm_stats_barrier.arrive_and_wait_w_index( + index=slot * Int32(self.warps_per_group) + warp_idx_in_wg) + + if group_tidx < Int32(self.q_tokens_per_group): + tok = group_tidx + qi = qi_group * Int32(self.q_tokens_per_group) + tok + if qi < count_raw: + qsplit = sQIdxMeta[qidx_meta_slot + tok] + q_idx = self._decode_q_idx_from_qsplit(qsplit) + sQIdx_slot[tok] = q_idx + sSplitIdx_slot[tok] = self._decode_split_idx_from_qsplit(qsplit) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + tOtO = cute.make_tensor( + tOtO_base.iterator + stage_base + Int32(self.tmem_o_offset), + tOtO_base.layout) + for col_pass_idx in cutlass.range(Int32(2), unroll=1): + col_pass = col_pass_idx * Int32(corr_tile_size) + tOtO_pass_ptr = cute.make_ptr( + self.pv_acc_dtype, + tOtO.iterator.toint() + col_pass, + cute.AddressSpace.tmem, + assumed_align=8, + ) + tOtO_pass = cute.make_tensor(tOtO_pass_ptr, tOtO.layout) + tOtO_pass_i = cute.logical_divide( + tOtO_pass, + cute.make_layout((self.m_block_size, corr_tile_size))) + tiled_tmem_load_pass = tcgen05.make_tmem_copy( + o_tmem_copy_atom, tOtO_pass_i[(None, None), 0]) + thr_tmem_load_pass = tiled_tmem_load_pass.get_slice(group_tidx) + tOtO_t2r_pass = thr_tmem_load_pass.partition_S( + tOtO_pass_i[(None, None), None]) + tOcO_t2r_pass = thr_tmem_load_pass.partition_D( + tOcO_i[(None, None), None]) + + tOtO_t2r_i = tOtO_t2r_pass[None, None, None, 0] + tOcO_t2r_i = tOcO_t2r_pass[None, None, None, 0] + tOrO_frg = cute.make_rmem_tensor_like( + tOcO_t2r_i, self.pv_acc_dtype) + cute.copy(tiled_tmem_load_pass, tOtO_t2r_i, tOrO_frg) + + tOrO_mn = make_16x256b_tensor_mn_view(tOrO_frg) + tOrO_mn = cute.make_tensor( + tOrO_mn.iterator, + cute.select(tOrO_mn.layout, mode=[0, 1])) + tOcO_mn = make_16x256b_tensor_mn_view(tOcO_t2r_i) + tOcO_mn = cute.make_tensor( + tOcO_mn.iterator, + cute.select(tOcO_mn.layout, mode=[0, 1])) + num_rows = cute.size(tOrO_mn, mode=[0]) + num_cols = cute.size(tOrO_mn, mode=[1]) + + for r in cutlass.range_constexpr(num_rows): + if const_expr(self.o_dtype is Float32): + for c4 in cutlass.range_constexpr(num_cols // 4): + c_base = Int32(c4) * Int32(4) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + fake_col = real_col_to_stg128_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec4( + ptr, + o0, + o1, + o2, + o3, + ) + elif const_expr(self.o_dtype in [cutlass.BFloat16, cutlass.Float16]): + assert num_cols % 8 == 0, ( + "half O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 8" + ) + for c8 in cutlass.range_constexpr(num_cols // 8): + c_base = Int32(c8) * Int32(8) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2( + (o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2( + (o6, o7), scale_pair) + fake_col = real_col_to_stg128_half_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec8_half( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + ) + else: + assert num_cols % 16 == 0, ( + "fp8 O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 16" + ) + for c16 in cutlass.range_constexpr(num_cols // 16): + c_base = Int32(c16) * Int32(16) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = ( + qi_group * Int32(self.q_tokens_per_group) + + tok + ) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) + * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val + ) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0) + ) + row_base_ptr = flat_row * Int64(self.head_dim) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + o8 = tOrO_mn[r, c_base + Int32(8)] + o9 = tOrO_mn[r, c_base + Int32(9)] + o10 = tOrO_mn[r, c_base + Int32(10)] + o11 = tOrO_mn[r, c_base + Int32(11)] + o12 = tOrO_mn[r, c_base + Int32(12)] + o13 = tOrO_mn[r, c_base + Int32(13)] + o14 = tOrO_mn[r, c_base + Int32(14)] + o15 = tOrO_mn[r, c_base + Int32(15)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2((o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2((o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2((o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2((o6, o7), scale_pair) + o8, o9 = cute.arch.mul_packed_f32x2((o8, o9), scale_pair) + o10, o11 = cute.arch.mul_packed_f32x2((o10, o11), scale_pair) + o12, o13 = cute.arch.mul_packed_f32x2((o12, o13), scale_pair) + o14, o15 = cute.arch.mul_packed_f32x2((o14, o15), scale_pair) + fake_col = real_col_to_stg128_fp8_fake_col(col) + ptr = ( + mO_partial.iterator + + row_base_ptr + + Int64(fake_col) + ) + self._store_o_partial_vec16_fp8( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + o8, + o9, + o10, + o11, + o12, + o13, + o14, + o15, + ) + cute.arch.fence_view_async_tmem_load() + + tok_local = Int32(group_tidx) // Int32(self.qheadperkv) + h_local = Int32(group_tidx) % Int32(self.qheadperkv) + qi_lse = qi_group * Int32(self.q_tokens_per_group) + tok_local + if qi_lse < count_raw: + row_sum_val = sScale_slot[group_tidx] + row_max_val = sScale_slot[group_tidx + Int32(self.m_block_size)] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + LN2 = Float32(math.log(2.0)) + lse_cur = ( + (row_max_val * softmax_scale_log2 + + cute.math.log2(row_sum_val, fastmath=True)) * LN2 + if not is_zero_or_nan else -Float32.inf) + q_idx_lse = sQIdx_slot[tok_local] + h_abs = head_kv_idx * Int32(self.qheadperkv) + h_local + split_lse = sSplitIdx_slot[tok_local] + q_abs_lse = q_batch_offset + q_idx_lse + mLSE_partial[split_lse, q_abs_lse, h_abs] = lse_cur + if const_expr(mLSE_temperature_partial is not None): + row_sum_temperature_val = sScale_temperature_slot[group_tidx] + is_temperature_zero_or_nan = ( + row_sum_temperature_val == Float32(0.0) + or row_sum_temperature_val != row_sum_temperature_val) + lse_temperature_cur = ( + (row_max_val * lse_temperature_scale_log2 + + cute.math.log2(row_sum_temperature_val, fastmath=True)) * LN2 + if not is_temperature_zero_or_nan else -Float32.inf) + mLSE_temperature_partial[split_lse, q_abs_lse, h_abs] = ( + lse_temperature_cur) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + pipeline_sm_stats.consumer_release_w_index(slot) + pipeline_o.consumer_release_w_index(slot) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd1a8d6bf92b16d2943aa5e40fd91e26224ac40 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py @@ -0,0 +1,3305 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 sparse attention forward kernel with NVFP4 K/V. + +This kernel implements the delivered Sparse Attention / Sparse Page Attention +forward contract: +- CSR sparse metadata (`k2q_row_ptr`, `k2q_q_indices`) +- varlen Q metadata via `cu_seqlens_q` +- BF16 Q +- packed NVFP4 K/V data +- E4M3 per-1x16 K/V scales in cuBLAS/cuDNN 128x4 tiled layout +- FP32 per-tensor K/V global scales +""" + +import math +from functools import partial +from typing import Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.pipeline as cutlass_pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.cutlass_dsl import BaseDSL + +from ....quack import copy_utils + +from ....src.common.cute_dsl_utils import assume_tensor_aligned +from ....src.common import utils +from ....src.common import pipeline +from ....src.common import mma_sm100_desc as sm100_desc +from ....src.common import blackwell_helpers as sm100_helpers +from ....src.common.softmax import SoftmaxSm100 +from ....src.common.named_barrier import NamedBarrierFwdSm100 +from ....src.common.mask import AttentionMask +from ....src.common.seqlen_info import SeqlenInfoQK +# Shared raw PTX helpers and layout conversions used by the lean kernel. +from ....src.common.paged_kv import PagedKVManager +from ....src.common.tma_utils import ( + tma_gather4_cached, + tma_gather4_prefetch, + prefetch_tma_desc_raw, + TMA_CACHE_EVICT_LAST, + make_16x256b_tensor_mn_view, + real_col_to_stg128_fake_col, + real_col_to_stg128_fp8_fake_col, + real_col_to_stg128_half_fake_col, + stg_128_cs, + stg_128_bf16_cs, + stg_128_f16_cs, + stg_128_fp8_e4m3_cs, +) + + +class SparseAttentionForwardNvfp4KvSm100: + """SM100 sparse attention forward kernel with NVFP4 K/V.""" + + k_tile = 64 # UTCMMA bf16 K-tile (matches sparse_fwd_utcmma.py) + + def __init__( + self, + head_dim: int = 128, + qheadperkv: int = 16, + m_block_size: int = 128, + n_block_size: int = 128, + paged_kv: bool = False, + page_size: Optional[int] = None, + has_seqused_k: bool = False, + causal: bool = False, + use_prepare_scheduler: bool = True, + fp8_pair_dequant: bool = True, + has_k_global_scale: bool = True, + has_v_global_scale: bool = True, + ): + if head_dim != 128: + raise NotImplementedError( + f"SparseAttentionForwardNvfp4KvSm100 currently supports only D=128, got D={head_dim}" + ) + self.head_dim = 128 + self.qheadperkv = qheadperkv + self.use_q_gather4 = qheadperkv in (4, 2, 1) + if qheadperkv not in (16, 8, 4, 2, 1): + raise ValueError( + "SparseAttentionForwardNvfp4KvSm100 supports qheadperkv in " + f"{{1, 2, 4, 8, 16}}, got {qheadperkv}" + ) + self.tokens_per_gather4 = 4 // qheadperkv if self.use_q_gather4 else 0 + self.m_block_size = m_block_size # 128 packed Q heads + self.n_block_size = n_block_size # 128 KV-block width + self.paged_kv = paged_kv + self.page_size = page_size + self.has_seqused_k = has_seqused_k + self.causal = causal + self.fp8_pair_dequant = fp8_pair_dequant + self.has_k_global_scale = has_k_global_scale + self.has_v_global_scale = has_v_global_scale + if not use_prepare_scheduler: + raise ValueError("SparseAttentionForwardNvfp4KvSm100 requires prepare scheduler") + self.use_prepare_scheduler = True + if self.paged_kv: + if page_size is None: + raise ValueError("page_size must be provided when paged_kv=True") + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal blk_kv ({n_block_size})" + ) + else: + self.page_size = n_block_size + self.q_tokens_per_group = m_block_size // qheadperkv # 8 + + self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim) + self.mma_tiler_pv = (m_block_size, self.head_dim, n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + + # Pipeline configuration — deeper Q prefetch ring plus 2-slot S/O rings. + self.q_stage = 2 + self.s_stage = 2 + self.o_stage = 2 + self.kv_stage = 1 + # Sparse q_idx metadata ring bridging load -> epilogue. Sized larger + # than the in-flight group distance so epilogue can reuse q_idx + # without rereading mK2qIndices. + self.qidx_meta_stages = 16 + + self.k_stages = 2 + self.q_stage_stride_bytes = m_block_size * self.head_dim * 2 + self.k_tile_stride_bytes = m_block_size * self.k_tile * 2 + self.token_stride_bytes = qheadperkv * self.k_tile * 2 + + # Warp layout: two softmax WGs, one Q-load/epilogue WG, one + # MMA issue warp, two K/V load warps, and one empty warp. + self.warps_per_group = 4 + self.softmax0_warp_base = 0 + self.softmax1_warp_base = self.softmax0_warp_base + self.warps_per_group + self.store_warp_base = self.softmax1_warp_base + self.warps_per_group + self.mma_warp_id = self.store_warp_base + self.warps_per_group + self.load_warp_base = self.mma_warp_id + 1 + self.q_load_warp_base = self.store_warp_base + self.kv_load_warp_base = self.load_warp_base + self.num_kv_load_warps = 2 + self.num_q_load_warps = self.warps_per_group + self.total_warps = 16 + self.threads_per_cta = cute.arch.WARP_SIZE * self.total_warps # 512 + + # TMEM layout follows FA SM100: + # S0/S1: [0:128], [128:256] + # O0/O1: [256:384], [384:512] for hdim_v=128 + # P is bf16 and starts halfway into each S tile. + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + self.tmem_s_offset = 0 + self.tmem_stage_stride = n_block_size + self.tmem_o_stage_stride = self.head_dim + self.tmem_o_offset = self.s_stage * n_block_size + self.tmem_s_to_p_offset = n_block_size // 2 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + raw_tmem_total = self.tmem_o_offset + self.o_stage * self.tmem_o_stage_stride + # SM100 TMEM allocation requires a power-of-two column count. The + # 128-wide path naturally uses 512 columns; 64-wide KV blocks use 384 + # columns and must round the allocation up while keeping the same + # logical offsets. + self.tmem_total = 1 << (raw_tmem_total - 1).bit_length() + + # Let PV start once the first 3/4 of P is visible in TMEM. The final + # split is synchronized by a separate mbarrier consumed inside PV MMA. + self.split_P_arrive = n_block_size // 4 * 3 + self.split_P_arrive = int(self.split_P_arrive / 32) * 32 + assert self.split_P_arrive % 32 == 0 + assert self.split_P_arrive < self.n_block_size + + # Register allocation per role. The causal hdim128 split gives the + # epilogue enough room for partial-O/LSE address generation while the + # two softmax WGs still have enough registers to avoid S/P spills. + self.num_regs_softmax = 176 if causal else 192 + self.num_regs_store = 112 if causal else 80 + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_store + self.num_regs_mma = self.num_regs_other + self.num_regs_load = self.num_regs_other + self.num_regs_empty = self.num_regs_other + self.store_reg_decrease = self.num_regs_store <= 128 + self.ex2_emu_freq = 16 if causal else 0 + self.ex2_emu_start_frg = 1 + self.buffer_align_bytes = 1024 + + # SM100 config. + self.use_2cta_instrs = False + self.cta_group_size = 1 + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnk = (1, 1, 1) + + self.arch = BaseDSL._get_dsl().get_arch_enum() + + @cute.jit + def _batch_q_offset( + self, + batch_idx: Int32, + mCuSeqlensQ, + ) -> Int32: + return mCuSeqlensQ[batch_idx] + + @cute.jit + def _logical_seqlen_k( + self, + batch_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + if const_expr(self.has_seqused_k): + return mSeqUsedK[batch_idx] + if const_expr(self.paged_kv): + return Int32(mPageTable.shape[1]) * Int32(self.page_size) + return mCuSeqlensK[batch_idx + Int32(1)] - mCuSeqlensK[batch_idx] + + @cute.jit + def _valid_cols_in_block( + self, + batch_idx: Int32, + kv_block_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + seqlen_k = self._logical_seqlen_k( + batch_idx, mPageTable, mSeqUsedK, mCuSeqlensK + ) + block_start = kv_block_idx * Int32(self.n_block_size) + remaining = seqlen_k - block_start + remaining = cutlass.max(remaining, Int32(0)) + return cutlass.min(remaining, Int32(self.n_block_size)) + + @cute.jit + def _load_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _load_qsplit_idx( + self, + mK2qQSplitIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qQSplitIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _decode_q_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return qsplit & Int32(0x00FF_FFFF) + + @cute.jit + def _decode_split_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return (qsplit >> Int32(24)) & Int32(0xFF) + + @cute.jit + def _lower_bound_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + count: Int32, + q_value: Int32, + ) -> Int32: + left = Int32(0) + right = count + # k2q_q_indices is sorted by q_idx within each CSR row. A fixed + # 32-step loop covers int32-sized rows and keeps this CTA-level. + for _ in cutlass.range(32, unroll=1): + if left < right: + mid = (left + right) // Int32(2) + q_idx = self._load_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + mid, + ) + if q_idx < q_value: + left = mid + Int32(1) + else: + right = mid + return left + + # ------------------------------------------------------------------ + # Host-side: TMA descriptors, SMEM layout, launch + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + mK: cute.Tensor, # packed NVFP4: flat [total_k, head_kv, dim/2] or paged [page, head_kv, token, dim/2] + mV: cute.Tensor, # packed NVFP4: flat [total_k, head_kv, dim/2] or paged [page, head_kv, token, dim/2] + mKScale: cute.Tensor, # E4M3 uint8, 128x4 tiled over flattened K rows and dim/16 cols + mVScale: cute.Tensor, # E4M3 uint8, 128x4 tiled over flattened V rows and dim/16 cols + mKGlobalScale: Optional[cute.Tensor], # optional FP32 tensor/global dequant scale + mVGlobalScale: Optional[cute.Tensor], # optional FP32 tensor/global dequant scale + mK2qIndices: cute.Tensor, # csr payload: [head_kv, nnz] + mK2qQSplitIndices: cute.Tensor, # csr payload: [head_kv, nnz] packed q_idx/split slot + mK2qCounts: cute.Tensor, # csr row_ptr: [head_kv, total_rows + 1] + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, # fp32 O_partial buffer (kept alive) + mLSE_partial: cute.Tensor, # fp32 LSE_partial + mLSE_temperature_partial: Optional[cute.Tensor], # fp32 temperature-scaled LSE_partial + mQ_flat: cute.Tensor, # [batch*Sq*head_q, dim] bf16, pre-flattened + mQ_gather4_desc: Optional[cute.Tensor], # [128] uint8 tensor map for gather4 Q load + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + softmax_scale: Float32, + lse_temperature_scale: Float32, + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + stream=None, + ): + self.q_dtype = mQ_flat.element_type + self.k_cache_dtype = mK.element_type + self.v_cache_dtype = mV.element_type + self.k_scale_dtype = mKScale.element_type + self.v_scale_dtype = mVScale.element_type + if const_expr(self.q_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"KVFP4 forward requires BF16 or FP8 E4M3 Q, got {self.q_dtype}") + self.k_dtype = self.q_dtype + self.v_dtype = self.q_dtype + if const_expr(self.k_cache_dtype is not cutlass.Uint8 or self.v_cache_dtype is not cutlass.Uint8): + raise TypeError( + f"KVFP4 forward expects packed uint8 K/V, got {self.k_cache_dtype}, {self.v_cache_dtype}" + ) + if const_expr(self.k_scale_dtype is not cutlass.Uint8 or self.v_scale_dtype is not cutlass.Uint8): + raise TypeError( + f"KVFP4 forward expects uint8 E4M3 scales, got {self.k_scale_dtype}, {self.v_scale_dtype}" + ) + if const_expr(self.has_k_global_scale and mKGlobalScale.element_type is not Float32): + raise TypeError("KVFP4 forward expects FP32 K global scale") + if const_expr(self.has_v_global_scale and mVGlobalScale.element_type is not Float32): + raise TypeError("KVFP4 forward expects FP32 V global scale") + self.mma_kind = "f8f6f4" if const_expr(self.q_dtype.width == 8) else "f16" + elem_bytes = const_expr(self.q_dtype.width // 8) + self.q_stage_stride_bytes = self.m_block_size * self.head_dim * elem_bytes + self.k_tile_stride_bytes = self.m_block_size * self.k_tile * elem_bytes + self.token_stride_bytes = self.qheadperkv * self.k_tile * elem_bytes + p_cols_as_fp32 = const_expr(self.n_block_size * self.q_dtype.width // Float32.width) + self.tmem_s_to_p_offset = self.n_block_size - p_cols_as_fp32 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + self.o_dtype = mO_partial.element_type + if const_expr( + self.o_dtype + not in [Float32, cutlass.BFloat16, cutlass.Float16, cutlass.Float8E4M3FN] + ): + raise TypeError(f"Unsupported O_partial dtype: {self.o_dtype}") + mK, mV, mKScale, mVScale = [ + assume_tensor_aligned(t) for t in (mK, mV, mKScale, mVScale) + ] + + if const_expr(not self.paged_kv): + # Flat varlen K/V: + # K: [total_k, h, d/2] -> [total_k, d/2, h]. + # V: [total_k, h, d/2] -> [d/2, total_k, h] for MN-major PV. + layout_t = [0, 2, 1] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2])) + else: + # Host input is [page, head, token, dim/2]. + layout_t = [2, 3, 1, 0] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + # V: (s,d/2,h,b) -> (d/2,s,h,b) for MN-major + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2, 3])) + + # ------------------------------------------------------------------ + # UTCMMA TiledMma: QK^T and PV + # ------------------------------------------------------------------ + cta_group = tcgen05.CtaGroup.ONE + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, + Float32, cta_group, self.mma_tiler_qk[:2]) + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN, + Float32, cta_group, self.mma_tiler_pv[:2], tcgen05.OperandSource.TMEM) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM layouts: sQ/sK/sV only. O_partial is written directly from + # registers to GMEM in the epilogue. + # ------------------------------------------------------------------ + total_q_stages = self.q_stage + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, total_q_stages) + q_load_tile = ( + self.head_dim + if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_tile + ) + q_load_subtiles_per_token = const_expr(self.head_dim // q_load_tile) + num_subtiles_total = ( + total_q_stages * self.q_tokens_per_group * q_load_subtiles_per_token + ) + sQ_load_layout = sm100_utils.make_smem_layout( + tcgen05.OperandMajorMode.K, + (self.qheadperkv, q_load_tile), self.q_dtype, num_subtiles_total) + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage) + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage) + sK_fp4_layout = cute.append( + cute.make_layout( + (self.n_block_size, self.head_dim // 2), + stride=(self.head_dim // 2, 1), + ), + cute.make_layout((1,)), + ) + sV_fp4_layout = cute.append( + cute.make_layout( + (self.head_dim // 2, self.n_block_size), + stride=(1, self.head_dim // 2), + ), + cute.make_layout((1,)), + ) + # P SMEM layout metadata (no actual SMEM allocation — P lives in TMEM, + # overlaying the S region; this layout is only used to compute the PV + # A-operand TMEM descriptor shape at the MMA issue site.) + tP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.s_stage) + + # ------------------------------------------------------------------ + # TMA atoms. Packed FP4 K/V are staged by TMA, then dequantized into + # BF16 MMA SMEM layout by the KV load warps. + # ------------------------------------------------------------------ + k_fp4_tma_bytes = cute.size_in_bytes( + self.k_cache_dtype, cute.select(sK_fp4_layout, mode=[0, 1])) + v_fp4_tma_bytes = cute.size_in_bytes( + self.v_cache_dtype, cute.select(sV_fp4_layout, mode=[0, 1])) + q_tma_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_atom_K_fp4, mK_tma = cpasync.make_tiled_tma_atom( + tma_load_op, + mK, + cute.select(sK_fp4_layout, mode=[0, 1]), + (self.n_block_size, self.head_dim // 2), + ) + tma_atom_V_fp4, mV_tma = cpasync.make_tiled_tma_atom( + tma_load_op, + mV, + cute.select(sV_fp4_layout, mode=[0, 1]), + (self.head_dim // 2, self.n_block_size), + ) + mK = mK_tma + mV = mV_tma + + # Q per-sub-tile TMA atom: bf16 uses two 64-element halves; fp8 uses + # one 128-element row because 128 fp8 elements occupy the same 128B + # swizzle span as 64 bf16 elements. + mQ_flat = assume_tensor_aligned(mQ_flat) + mQ_2d = cute.make_tensor( + mQ_flat.iterator, cute.select(mQ_flat.layout, mode=[0, 1])) + if const_expr(self.use_q_gather4): + # Placeholder atom for the unified kernel signature. Small-GQA Q + # loading uses raw gather4, so mQ_2d must stay as the plain GMEM + # tensor. The placeholder uses the natural SMEM top-level shape. + tma_atom_Q, _ = cpasync.make_tiled_tma_atom( + tma_load_op, mQ_2d, + cute.select(sQ_load_layout, mode=[0, 1]), + (8, q_load_tile)) + else: + tma_atom_Q, mQ_2d_tma = cpasync.make_tiled_tma_atom( + tma_load_op, mQ_2d, + cute.select(sQ_load_layout, mode=[0, 1]), + (self.qheadperkv, q_load_tile)) + mQ_2d = mQ_2d_tma + q_subtile_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_load_layout, mode=[0, 1])) + + softmax_scale_log2 = softmax_scale * Float32(math.log2(math.e)) + lse_temperature_scale_log2 = softmax_scale_log2 * lse_temperature_scale + + # ------------------------------------------------------------------ + # SharedStorage — lean: just the mbars and tiles we actually use. + # + # Mbarriers (all storage rings stay below the 64-per-CTA limit): + # mbar_kv [2] one-shot K/V load handshake (full + empty) + # mbar_q [q_stage * 2] Q producer/consumer ring + # mbar_s [2] QK UTCMMA -> softmax (full + empty) + # mbar_o [2] PV UTCMMA -> epilogue (full + empty) + # mbar_p [s_stage * 2] softmax early-P arrive -> PV + # mbar_p_lastsplit [s_stage * 2] softmax final-P arrive -> PV + # (used only when ``self.split_P_arrive > 0``) + # mbar_sm_stats [s_stage * 2] softmax row_sum/row_max + # publish -> epilogue consumer read. In lean + # 1-WG topology the producer and consumer are + # the same 128 WG_C threads, but we keep the + # barrier for structural parity with FA so the + # softmax body reads identically. + # ------------------------------------------------------------------ + @cute.struct + class SharedStorage: + mbar_k: cute.struct.MemRange[Int64, 2] + mbar_v: cute.struct.MemRange[Int64, 2] + mbar_k_tma: cute.struct.MemRange[Int64, 2] + mbar_v_tma: cute.struct.MemRange[Int64, 2] + mbar_q: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_s: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p_lastsplit: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_o: cute.struct.MemRange[Int64, self.o_stage * 2] + mbar_sm_stats: cute.struct.MemRange[Int64, self.o_stage * 2] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: Int32 + # Per-row softmax stats cache (for epilogue LSE + rescale): + # [0 : m_block_size) row_sum + # [m_block_size : 2*m_block_size) row_max + sScale: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size * 2] + # Per-row temperature LSE row_sum cache. The row_max is shared with + # sScale because lse_temperature_scale is positive. + sScaleTemperature: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size] + # Per-token split_id from prepare-time per-edge metadata. + sSplitIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Per-token q_idx cache to avoid reloading sparse indices in epilogue. + sQIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Prefix length of q_idx-sorted row entries that may need causal + # masking for this KV block. This is CTA-level metadata, not a + # token-count cap. + sDiagQCount: cute.struct.MemRange[Int32, 1] + # CTA-wide row metadata, published once by tidx 0 and reused by + # all warp-specialized roles: + # [0] batch_idx + # [1] kv_block_idx + # [2] row_start + # [3] count_raw + # [4] kv_valid_cols + # [5] q_batch_offset + # [6] k_batch_offset + # [7] causal_q_offset = seqlen_k - seqlen_q + sRowMeta: cute.struct.MemRange[Int32, 8] + sPagedKvIdx: cute.struct.MemRange[Int32, 1] + sQLoadMIdx: cute.struct.MemRange[ + Int32, self.q_stage * self.q_tokens_per_group] + # Packed per-edge q/split metadata: + # low 24 bits = q_idx, high 8 bits = split slot. + sQIdxMeta: cute.struct.MemRange[ + Int32, self.qidx_meta_stages * self.q_tokens_per_group] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes] + sKFp4: cute.struct.Align[ + cute.struct.MemRange[self.k_cache_dtype, cute.cosize(sK_fp4_layout)], + self.buffer_align_bytes] + sVFp4: cute.struct.Align[ + cute.struct.MemRange[self.v_cache_dtype, cute.cosize(sV_fp4_layout)], + self.buffer_align_bytes] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes] + + self.shared_storage = SharedStorage + num_ctas = work_capacity + + self.kernel( + mK, mV, mKScale, mVScale, mKGlobalScale, mVGlobalScale, + mK2qIndices, mK2qQSplitIndices, mK2qCounts, + mSchedulerMetadata, mWorkCount, + mO_partial, mLSE_partial, mLSE_temperature_partial, mQ_2d, mQ_gather4_desc, + mPageTable, mSeqUsedK, mCuSeqlensQ, mCuSeqlensK, + softmax_scale_log2, lse_temperature_scale_log2, lse_temperature_scale, + sQ_layout, sQ_load_layout, sK_layout, sV_layout, + sK_fp4_layout, sV_fp4_layout, tP_layout, + tma_atom_K_fp4, tma_atom_V_fp4, tma_atom_Q, + tiled_mma_qk, tiled_mma_pv, + k_fp4_tma_bytes, v_fp4_tma_bytes, q_tma_bytes, q_subtile_bytes, + num_kv_blocks, num_heads_kv, seq_len_q, work_capacity, + ).launch( + grid=(num_ctas,), + block=[self.threads_per_cta, 1, 1], + smem=max(SharedStorage.size_in_bytes(), 49152), + stream=stream, + min_blocks_per_mp=1, + ) + + # ------------------------------------------------------------------ + # Device-side: kernel entry, dispatch by warpgroup + # ------------------------------------------------------------------ + + @cute.kernel + def kernel( + self, + # Runtime tensors + mK: cute.Tensor, + mV: cute.Tensor, + mKScale: cute.Tensor, + mVScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + mVGlobalScale: Optional[cute.Tensor], + mK2qIndices: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mK2qCounts: cute.Tensor, + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + mQ_2d: cute.Tensor, + mQ_gather4_desc: Optional[cute.Tensor], + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + # Scalars + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + # Layouts + sQ_layout: cute.ComposedLayout, + sQ_load_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sK_fp4_layout: cute.Layout, + sV_fp4_layout: cute.Layout, + tP_layout: cute.ComposedLayout, + # TMA atom + tma_atom_K_fp4: cute.CopyAtom, + tma_atom_V_fp4: cute.CopyAtom, + tma_atom_Q: cute.CopyAtom, + # MMA + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + # Transfer sizes + k_fp4_tma_bytes: cutlass.Constexpr[int], + v_fp4_tma_bytes: cutlass.Constexpr[int], + q_tma_bytes: cutlass.Constexpr[int], + q_subtile_bytes: cutlass.Constexpr[int], + # Iteration bounds + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + ): + # ------------------------------------------------------------------ + # Thread / warp identity, CTA coordinate + # ------------------------------------------------------------------ + bidx, _, _ = cute.arch.block_idx() + row_linear = Int32(0) + head_kv_idx = Int32(0) + batch_idx = Int32(0) + kv_block_idx = Int32(0) + work_q_begin = Int32(0) + work_q_count = Int32(0) + cta_valid_work = True + work_idx = bidx + cta_valid_work = work_idx < mWorkCount[Int32(0)] + if cta_valid_work: + head_kv_idx = mSchedulerMetadata[work_idx, Int32(0)] + row_linear = mSchedulerMetadata[work_idx, Int32(1)] + work_q_begin = mSchedulerMetadata[work_idx, Int32(2)] + work_q_count = mSchedulerMetadata[work_idx, Int32(3)] + batch_idx = mSchedulerMetadata[work_idx, Int32(4)] + kv_block_idx = mSchedulerMetadata[work_idx, Int32(5)] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + head_q = num_heads_kv * Int32(self.qheadperkv) + paged_kv_manager = ( + PagedKVManager.create( + mPageTable, + page_size=self.page_size, + n_block_size=self.n_block_size, + ) + if const_expr(self.paged_kv) + else None + ) + + # Prefetch TMA descriptors (warp 0 once). + if warp_idx == 0: + if const_expr(not self.use_q_gather4): + cpasync.prefetch_descriptor(tma_atom_Q) + else: + with cute.arch.elect_one(): + prefetch_tma_desc_raw(mQ_gather4_desc.iterator) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM allocation (all warps — same SharedStorage type from __call__) + # ------------------------------------------------------------------ + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sKFp4 = storage.sKFp4.get_tensor(sK_fp4_layout) + sVFp4 = storage.sVFp4.get_tensor(sV_fp4_layout) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sQ_load = storage.sQ.get_tensor(sQ_load_layout.outer, swizzle=sQ_load_layout.inner) + sScale = storage.sScale.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size * 2)) + sScaleTemperature = storage.sScaleTemperature.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size)) + sSplitIdx = storage.sSplitIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sQIdx = storage.sQIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sDiagQCount = storage.sDiagQCount.get_tensor(cute.make_layout((1,))) + sRowMeta = storage.sRowMeta.get_tensor(cute.make_layout((8,))) + sPagedKvIdx = storage.sPagedKvIdx.get_tensor(cute.make_layout((1,))) + sQLoadMIdx = storage.sQLoadMIdx.get_tensor( + cute.make_layout((self.q_stage * self.q_tokens_per_group,))) + sQIdxMeta = storage.sQIdxMeta.get_tensor( + cute.make_layout((self.qidx_meta_stages * self.q_tokens_per_group,))) + mbar_k_ptr = storage.mbar_k.data_ptr() + mbar_v_ptr = storage.mbar_v.data_ptr() + mbar_k_tma_ptr = storage.mbar_k_tma.data_ptr() + mbar_v_tma_ptr = storage.mbar_v_tma.data_ptr() + + # ------------------------------------------------------------------ + # TMEM allocator — allocator warp 0 serves the whole CTA. + # ------------------------------------------------------------------ + tmem_alloc_warps: cutlass.Constexpr[int] = self.warps_per_group * 2 + 1 + tmem_alloc_threads = cute.arch.WARP_SIZE * tmem_alloc_warps + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), + num_threads=tmem_alloc_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # ------------------------------------------------------------------ + # Warp-specialized pipelines. + # ------------------------------------------------------------------ + ThreadCooperativeGroup = partial( + cutlass_pipeline.CooperativeGroup, cutlass_pipeline.Agent.Thread) + tma_thread = ThreadCooperativeGroup(1) + mma_thread = ThreadCooperativeGroup(1) + softmax_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group) + epilogue_threads = softmax_threads + + pipeline_q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=q_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_s = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_s.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=softmax_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p_lastsplit = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p_lastsplit.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_o = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_o.data_ptr(), + num_stages=self.o_stage, + producer_group=mma_thread, + consumer_group=epilogue_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_sm_stats = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_sm_stats.data_ptr(), + num_stages=self.o_stage, + producer_group=softmax_threads, + consumer_group=epilogue_threads, + defer_sync=True, + ) + # Cluster sync (no-op for 1CTA cluster). + pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + # ------------------------------------------------------------------ + # Work count: how many Q tokens reference this CTA's KV block + # ------------------------------------------------------------------ + k_smem_bytes = cute.size_in_bytes( + self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) + v_smem_bytes = cute.size_in_bytes( + self.v_dtype, cute.select(sV_layout, mode=[0, 1, 2])) + if tidx == 0: + row_batch_idx = batch_idx + row_kv_block_idx = kv_block_idx + base_row_start = mK2qCounts[head_kv_idx, row_linear] + row_start = base_row_start + count_raw = ( + mK2qCounts[head_kv_idx, row_linear + Int32(1)] + - base_row_start + ) + row_start = base_row_start + work_q_begin + count_raw = work_q_count + kv_valid_cols = ( + self._valid_cols_in_block( + row_batch_idx, + row_kv_block_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + ) + q_batch_offset = self._batch_q_offset( + row_batch_idx, mCuSeqlensQ + ) + k_batch_offset = ( + Int32(0) + if const_expr(self.paged_kv) + else mCuSeqlensK[row_batch_idx] + ) + sRowMeta[0] = row_batch_idx + sRowMeta[1] = row_kv_block_idx + sRowMeta[2] = row_start + sRowMeta[3] = count_raw + sRowMeta[4] = kv_valid_cols + sRowMeta[5] = q_batch_offset + sRowMeta[6] = k_batch_offset + causal_q_offset = Int32(0) + if const_expr(self.causal): + seqlen_q = mCuSeqlensQ[row_batch_idx + Int32(1)] - q_batch_offset + seqlen_k = self._logical_seqlen_k( + row_batch_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + causal_q_offset = seqlen_k - seqlen_q + sRowMeta[7] = causal_q_offset + if const_expr(self.paged_kv): + sPagedKvIdx[0] = paged_kv_manager.physical_block_index( + row_batch_idx, row_kv_block_idx + ) + cute.arch.mbarrier_init(mbar_k_ptr, 1) + cute.arch.mbarrier_init(mbar_v_ptr, 1) + cute.arch.mbarrier_init(mbar_k_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_k_tma_ptr, k_fp4_tma_bytes) + cute.arch.mbarrier_init(mbar_v_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_v_tma_ptr, v_fp4_tma_bytes) + diag_q_count = Int32(0) + if const_expr(self.causal): + row_has_visible_cols = (count_raw > Int32(0)) & (kv_valid_cols > Int32(0)) + if row_has_visible_cols: + kv_valid_end = ( + row_kv_block_idx * Int32(self.n_block_size) + + kv_valid_cols + ) + q_threshold = kv_valid_end - causal_q_offset + diag_q_count = self._lower_bound_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + count_raw, + q_threshold, + ) + sDiagQCount[0] = diag_q_count + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + thr_mma_qk = tiled_mma_qk.get_slice(0) + thr_mma_pv = tiled_mma_pv.get_slice(0) + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_base = thr_mma_qk.make_fragment_C(qk_acc_shape) + tStS = cute.make_tensor( + tStS_base.iterator, + cute.append( + tStS_base.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + tP_width_ratio = Float32.width // self.v_dtype.width + tP_stage_stride = self.tmem_stage_stride * tP_width_ratio + tOrP = cute.make_tensor( + tOrP.iterator + self.tmem_p_offset * tP_width_ratio, + cute.append( + tOrP.layout, + cute.make_layout((self.s_stage,), stride=(tP_stage_stride,)), + ), + ) + + tmem_cols = self.tmem_total + + load_wg_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.LoadWG), + num_threads=cute.arch.WARP_SIZE * self.num_q_load_warps, + ) + kv_load_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvLoad), + num_threads=cute.arch.WARP_SIZE * self.num_kv_load_warps, + ) + kv_dequant_k_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantK), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + kv_dequant_v_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantV), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + sm_stats_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), + num_threads=cute.arch.WARP_SIZE * 2, + ) + epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.StoreEpilogue), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + if ( + warp_idx == Int32(self.total_warps - 1) + and warp_idx >= Int32(self.kv_load_warp_base + self.num_kv_load_warps) + ): + cute.arch.setmaxregister_decrease(self.num_regs_empty) + + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + q_load_thread_end = Int32( + (self.q_load_warp_base + self.num_q_load_warps) * cute.arch.WARP_SIZE + ) + is_q_load_thread = tidx >= q_load_thread_base and tidx < q_load_thread_end + if is_q_load_thread and cta_valid_work: + if self.store_reg_decrease: + cute.arch.setmaxregister_decrease(self.num_regs_store) + else: + cute.arch.setmaxregister_increase(self.num_regs_store) + row_start_load = sRowMeta[2] + count_raw_load = sRowMeta[3] + q_batch_offset_load = sRowMeta[5] + # Do not gate on KV validity here; sparse entries past seqused_k + # still need the all-masked path to produce neutral partials. + has_work_load = count_raw_load > Int32(0) + num_q_groups_load = ( + count_raw_load + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + q_group_start = Int32(0) + if const_expr(self.use_q_gather4): + self._wg_load_q_gather4( + mQ_2d, + mQ_gather4_desc, + mK2qQSplitIndices, + sQIdxMeta, + sQ, + pipeline_q, + load_wg_barrier, + q_group_start, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + True, + ) + else: + self._wg_load_q_tma( + tma_atom_Q, + mQ_2d, + mK2qQSplitIndices, + sQLoadMIdx, + sQIdxMeta, + sQ_load, + pipeline_q, + load_wg_barrier, + q_group_start, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + True, + ) + + if ( + warp_idx >= Int32(self.kv_load_warp_base) + and warp_idx < Int32(self.kv_load_warp_base + self.num_kv_load_warps) + and cta_valid_work + ): + cute.arch.setmaxregister_decrease(self.num_regs_load) + kv_block_idx_load = sRowMeta[1] + k_batch_offset_load = sRowMeta[6] + has_work_load = sRowMeta[3] > Int32(0) + self._wg_load_kv( + tma_atom_K_fp4, tma_atom_V_fp4, + mK, mV, + mKScale, mVScale, + mKGlobalScale, mVGlobalScale, + sPagedKvIdx, + sKFp4, sVFp4, sK, sV, + mbar_k_tma_ptr, mbar_v_tma_ptr, + mbar_k_ptr, mbar_v_ptr, + kv_load_barrier, + has_work_load, + head_kv_idx, kv_block_idx_load, + k_batch_offset_load, + num_heads_kv, + ) + + if warp_idx == Int32(self.mma_warp_id) and cta_valid_work: + cute.arch.setmaxregister_decrease(self.num_regs_mma) + count_raw_mma = sRowMeta[3] + has_work_mma = count_raw_mma > Int32(0) + num_q_groups_mma = ( + count_raw_mma + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + tmem.allocate(tmem_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_mma_issue( + tiled_mma_qk, tiled_mma_pv, + thr_mma_qk, thr_mma_pv, + tStS, tOrP, + sK, sV, sQ, + pipeline_q, pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + mbar_k_ptr, mbar_v_ptr, + num_q_groups_mma, has_work_mma, + ) + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr, num_columns=tmem_cols) + cute.arch.griddepcontrol_launch_dependents() + + if ( + warp_idx >= Int32(self.softmax0_warp_base) + and warp_idx < Int32(self.softmax1_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + self._wg_dequant_k_from_tma_staging( + mKScale, + mKGlobalScale, + sPagedKvIdx, + sKFp4, sK, + mbar_k_tma_ptr, + mbar_k_ptr, + kv_dequant_k_barrier, + has_work_softmax, + self.softmax0_warp_base, + self.warps_per_group, + head_kv_idx, kv_block_idx_softmax, + sRowMeta[6], + num_heads_kv, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 0, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + mKGlobalScale, + mVGlobalScale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + + if ( + warp_idx >= Int32(self.softmax1_warp_base) + and warp_idx < Int32(self.store_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + self._wg_dequant_v_from_tma_staging( + mVScale, + mVGlobalScale, + sPagedKvIdx, + sVFp4, sV, + mbar_v_tma_ptr, + mbar_v_ptr, + kv_dequant_v_barrier, + has_work_softmax, + self.softmax1_warp_base, + self.warps_per_group, + head_kv_idx, kv_block_idx_softmax, + sRowMeta[6], + num_heads_kv, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 1, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + mKGlobalScale, + mVGlobalScale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + # ------------------------------------------------------------------ + # Warp-specialized helpers + # ------------------------------------------------------------------ + + @cute.jit + def _scale_128x4_offset( + self, + row: Int32, + col: Int32, + scale_cols: cutlass.Constexpr[int], + ) -> Int32: + tiles_n: cutlass.Constexpr[int] = (scale_cols + 3) // 4 + tile_m = row // Int32(128) + tile_n = col // Int32(4) + outer = row % Int32(128) + inner = col % Int32(4) + return ( + (tile_m * Int32(tiles_n) + tile_n) * Int32(512) + + (outer % Int32(32)) * Int32(16) + + (outer // Int32(32)) * Int32(4) + + inner + ) + + @cute.jit + def _load_scale_bf16x2( + self, + scale: cute.Tensor, + logical_row: Int32, + scale_col: Int32, + ) -> Int32: + scale_offset = self._scale_128x4_offset( + logical_row, + scale_col, + self.head_dim // 16, + ) + scale_ptr = cute.make_ptr( + cutlass.Uint8, + scale.iterator.toint() + Int64(scale_offset), + mem_space=scale.iterator.memspace, + assumed_align=1, + ) + scale_byte = cute.make_tensor(scale_ptr, cute.make_layout(1))[0] + return utils.cvt_fp8_e4m3_to_bf16x2_replicated(cutlass.Int32(scale_byte)) + + @cute.jit + def _load_scale_e4m3_u8( + self, + scale: cute.Tensor, + logical_row: Int32, + scale_col: Int32, + ) -> Int32: + scale_offset = self._scale_128x4_offset( + logical_row, + scale_col, + self.head_dim // 16, + ) + scale_ptr = cute.make_ptr( + cutlass.Uint8, + scale.iterator.toint() + Int64(scale_offset), + mem_space=scale.iterator.memspace, + assumed_align=1, + ) + scale_byte = cute.make_tensor(scale_ptr, cute.make_layout(1))[0] + return cutlass.Int32(scale_byte) + + @cute.jit + def _dequant_fp4x16_to_bf16( + self, + src_words: cute.Tensor, + combined_scale_bf16x2: Int32, + dst: cute.Tensor, + ): + r_bf16 = cute.make_rmem_tensor((2,), cutlass.BFloat16) + r_bf16_i32 = cute.recast_tensor(r_bf16, cutlass.Int32) + for word_idx in cutlass.range_constexpr(2): + bf16_pair0, bf16_pair1, bf16_pair2, bf16_pair3 = utils.cvt_fp4x8_e2m1_bf16x8( + src_words[word_idx] + ) + bf16_pairs = (bf16_pair0, bf16_pair1, bf16_pair2, bf16_pair3) + for pair_idx in cutlass.range_constexpr(4): + r_bf16_i32[0] = utils.mul_bf16x2( + bf16_pairs[pair_idx], + combined_scale_bf16x2, + ) + dst[word_idx * 8 + 2 * pair_idx + 0] = r_bf16[0] + dst[word_idx * 8 + 2 * pair_idx + 1] = r_bf16[1] + + @cute.jit + def _dequant_fp4x16_to_fp8( + self, + src_words: cute.Tensor, + scale_e4m3: Int32, + dst: cute.Tensor, + ): + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + for word_idx in cutlass.range_constexpr(2): + fp8_lo, fp8_hi = utils.cvt_fp4x8_e2m1_scaled_e4m3x8( + src_words[word_idx], + scale_e4m3, + ) + dst_i32[word_idx * 2] = fp8_lo + dst_i32[word_idx * 2 + 1] = fp8_hi + + @cute.jit + def _dequant_fp4x32_to_fp8( + self, + src_words: cute.Tensor, + scale_e4m3_lo: Int32, + scale_e4m3_hi: Int32, + dst: cute.Tensor, + ): + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + for word_idx in cutlass.range_constexpr(2): + fp8_lo, fp8_hi = utils.cvt_fp4x8_e2m1_scaled_e4m3x8( + src_words[word_idx], + scale_e4m3_lo, + ) + dst_i32[word_idx * 2] = fp8_lo + dst_i32[word_idx * 2 + 1] = fp8_hi + for word_idx in cutlass.range_constexpr(2): + fp8_lo, fp8_hi = utils.cvt_fp4x8_e2m1_scaled_e4m3x8( + src_words[word_idx + 2], + scale_e4m3_hi, + ) + dst_i32[word_idx * 2 + 4] = fp8_lo + dst_i32[word_idx * 2 + 5] = fp8_hi + + @cute.jit + def _flat_kv_scale_row( + self, + token_idx: Int32, + head_kv_idx: Int32, + num_heads_kv: Int32, + ) -> Int32: + return token_idx * num_heads_kv + head_kv_idx + + @cute.jit + def _paged_kv_scale_row( + self, + page_idx: Int32, + token_idx: Int32, + head_kv_idx: Int32, + num_heads_kv: Int32, + ) -> Int32: + return (page_idx * num_heads_kv + head_kv_idx) * Int32(self.page_size) + token_idx + + @cute.jit + def _load_k_fp4_to_smem( + self, + sKFp4: cute.Tensor, + mKScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sK: cute.Tensor, + lane: Int32, + warp_idx_in_wg: Int32, + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + elems_per_block: cutlass.Constexpr[int] = 16 + bytes_per_block: cutlass.Constexpr[int] = 8 + scale_cols: cutlass.Constexpr[int] = self.head_dim // elems_per_block + token_in_block_base = kv_block_idx * Int32(self.n_block_size) + if const_expr(self.k_dtype == cutlass.Float8E4M3FN and self.fp8_pair_dequant): + elems_per_pair: cutlass.Constexpr[int] = 32 + bytes_per_pair: cutlass.Constexpr[int] = 16 + scale_pairs: cutlass.Constexpr[int] = scale_cols // 2 + r_words_pair = cute.make_rmem_tensor((bytes_per_pair // 4,), cutlass.Int32) + r_vals_pair = cute.make_rmem_tensor((elems_per_pair,), cutlass.Float8E4M3FN) + g2r_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_pair * cutlass.Uint8.width, + ) + r2s_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float8E4M3FN, + num_bits_per_copy=16 * cutlass.Float8E4M3FN.width, + ) + total_pair_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_pairs + pair_task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + pair_task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(pair_task, total_pair_tasks, pair_task_stride, unroll=1): + row = task_idx // Int32(scale_pairs) + pair_col = task_idx - row * Int32(scale_pairs) + scale_col = pair_col * Int32(2) + byte_col = pair_col * Int32(bytes_per_pair) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sKFp4.iterator.toint() + Int64(smem_offset), + mem_space=sKFp4.iterator.memspace, + assumed_align=bytes_per_pair, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_pair // 4)) + cute.copy(g2r_pair_atom, s_vec, r_words_pair) + scale_e4m3_lo = self._load_scale_e4m3_u8(mKScale, scale_row, scale_col) + scale_e4m3_hi = self._load_scale_e4m3_u8( + mKScale, + scale_row, + scale_col + Int32(1), + ) + self._dequant_fp4x32_to_fp8( + r_words_pair, + scale_e4m3_lo, + scale_e4m3_hi, + r_vals_pair, + ) + sK_vec = sK[(row, None), 0, pair_col, 0] + r_tiles = cute.logical_divide(r_vals_pair, cute.make_layout(16)) + s_tiles = cute.logical_divide(sK_vec, cute.make_layout(16)) + for v in cutlass.range_constexpr(elems_per_pair // 16): + cute.copy(r2s_pair_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + return + r_words = cute.make_rmem_tensor((bytes_per_block // 4,), cutlass.Int32) + r_vals = cute.make_rmem_tensor((elems_per_block,), self.k_dtype) + g2r_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_block * cutlass.Uint8.width, + ) + elems_per_store: cutlass.Constexpr[int] = ( + 16 if self.k_dtype == cutlass.Float8E4M3FN else 8 + ) + r2s_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.k_dtype, + num_bits_per_copy=elems_per_store * self.k_dtype.width, + ) + total_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_cols + task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(task, total_tasks, task_stride, unroll=1): + row = task_idx // Int32(scale_cols) + scale_col = task_idx - row * Int32(scale_cols) + byte_col = scale_col * Int32(bytes_per_block) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sKFp4.iterator.toint() + Int64(smem_offset), + mem_space=sKFp4.iterator.memspace, + assumed_align=bytes_per_block, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_block // 4)) + cute.copy(g2r_atom, s_vec, r_words) + if const_expr(self.k_dtype == cutlass.Float8E4M3FN): + scale_e4m3 = self._load_scale_e4m3_u8(mKScale, scale_row, scale_col) + else: + combined_bf16x2 = self._load_scale_bf16x2(mKScale, scale_row, scale_col) + if const_expr(self.has_k_global_scale): + global_bf16x2 = utils.cvt_f16x2_f32( + mKGlobalScale[0], + mKGlobalScale[0], + cutlass.BFloat16, + ) + combined_bf16x2 = utils.mul_bf16x2(combined_bf16x2, global_bf16x2) + if const_expr(self.k_dtype == cutlass.Float8E4M3FN): + sK_cols = sK[(row, None), 0, scale_col // Int32(2), 0] + sK_vec = cute.local_tile( + sK_cols, + (elems_per_block,), + (scale_col % Int32(2),), + ) + else: + sK_vec = sK[ + (row, None), + 0, + (scale_col % Int32(4), scale_col // Int32(4)), + 0, + ] + if const_expr(self.k_dtype == cutlass.Float8E4M3FN): + self._dequant_fp4x16_to_fp8(r_words, scale_e4m3, r_vals) + else: + self._dequant_fp4x16_to_bf16( + r_words, + combined_bf16x2, + r_vals, + ) + r_tiles = cute.logical_divide(r_vals, cute.make_layout(elems_per_store)) + s_tiles = cute.logical_divide(sK_vec, cute.make_layout(elems_per_store)) + for v in cutlass.range_constexpr(elems_per_block // elems_per_store): + cute.copy(r2s_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + + @cute.jit + def _load_v_fp4_to_smem( + self, + sVFp4: cute.Tensor, + mVScale: cute.Tensor, + mVGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sV: cute.Tensor, + lane: Int32, + warp_idx_in_wg: Int32, + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + elems_per_block: cutlass.Constexpr[int] = 16 + bytes_per_block: cutlass.Constexpr[int] = 8 + scale_cols: cutlass.Constexpr[int] = self.head_dim // elems_per_block + token_in_block_base = kv_block_idx * Int32(self.n_block_size) + if const_expr(self.v_dtype == cutlass.Float8E4M3FN and self.fp8_pair_dequant): + elems_per_pair: cutlass.Constexpr[int] = 32 + bytes_per_pair: cutlass.Constexpr[int] = 16 + scale_pairs: cutlass.Constexpr[int] = scale_cols // 2 + r_words_pair = cute.make_rmem_tensor((bytes_per_pair // 4,), cutlass.Int32) + r_vals_pair = cute.make_rmem_tensor((elems_per_pair,), cutlass.Float8E4M3FN) + g2r_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_pair * cutlass.Uint8.width, + ) + r2s_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float8E4M3FN, + num_bits_per_copy=16 * cutlass.Float8E4M3FN.width, + ) + total_pair_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_pairs + pair_task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + pair_task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(pair_task, total_pair_tasks, pair_task_stride, unroll=1): + row = task_idx // Int32(scale_pairs) + pair_col = task_idx - row * Int32(scale_pairs) + scale_col = pair_col * Int32(2) + byte_col = pair_col * Int32(bytes_per_pair) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sVFp4.iterator.toint() + Int64(smem_offset), + mem_space=sVFp4.iterator.memspace, + assumed_align=bytes_per_pair, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_pair // 4)) + cute.copy(g2r_pair_atom, s_vec, r_words_pair) + scale_e4m3_lo = self._load_scale_e4m3_u8(mVScale, scale_row, scale_col) + scale_e4m3_hi = self._load_scale_e4m3_u8( + mVScale, + scale_row, + scale_col + Int32(1), + ) + self._dequant_fp4x32_to_fp8( + r_words_pair, + scale_e4m3_lo, + scale_e4m3_hi, + r_vals_pair, + ) + sV_cols = sV[(None, row % Int32(32)), 0, row // Int32(32), 0] + sV_vec = cute.local_tile(sV_cols, (elems_per_pair,), (pair_col,)) + r_tiles = cute.logical_divide(r_vals_pair, cute.make_layout(16)) + s_tiles = cute.logical_divide(sV_vec, cute.make_layout(16)) + for v in cutlass.range_constexpr(elems_per_pair // 16): + cute.copy(r2s_pair_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + return + r_words = cute.make_rmem_tensor((bytes_per_block // 4,), cutlass.Int32) + r_vals = cute.make_rmem_tensor((elems_per_block,), self.v_dtype) + g2r_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_block * cutlass.Uint8.width, + ) + elems_per_store: cutlass.Constexpr[int] = ( + 16 if self.v_dtype == cutlass.Float8E4M3FN else 8 + ) + r2s_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.v_dtype, + num_bits_per_copy=elems_per_store * self.v_dtype.width, + ) + total_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_cols + task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(task, total_tasks, task_stride, unroll=1): + row = task_idx // Int32(scale_cols) + scale_col = task_idx - row * Int32(scale_cols) + byte_col = scale_col * Int32(bytes_per_block) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sVFp4.iterator.toint() + Int64(smem_offset), + mem_space=sVFp4.iterator.memspace, + assumed_align=bytes_per_block, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_block // 4)) + cute.copy(g2r_atom, s_vec, r_words) + if const_expr(self.v_dtype == cutlass.Float8E4M3FN): + scale_e4m3 = self._load_scale_e4m3_u8(mVScale, scale_row, scale_col) + self._dequant_fp4x16_to_fp8(r_words, scale_e4m3, r_vals) + else: + combined_bf16x2 = self._load_scale_bf16x2(mVScale, scale_row, scale_col) + if const_expr(self.has_v_global_scale): + global_bf16x2 = utils.cvt_f16x2_f32( + mVGlobalScale[0], + mVGlobalScale[0], + cutlass.BFloat16, + ) + combined_bf16x2 = utils.mul_bf16x2(combined_bf16x2, global_bf16x2) + self._dequant_fp4x16_to_bf16( + r_words, + combined_bf16x2, + r_vals, + ) + if const_expr(self.v_dtype == cutlass.Float8E4M3FN): + sV_cols = sV[(None, row % Int32(32)), 0, row // Int32(32), 0] + else: + sV_cols = sV[(None, row % Int32(16)), 0, row // Int32(16), 0] + sV_vec = cute.local_tile(sV_cols, (elems_per_block,), (scale_col,)) + r_tiles = cute.logical_divide(r_vals, cute.make_layout(elems_per_store)) + s_tiles = cute.logical_divide(sV_vec, cute.make_layout(elems_per_store)) + for v in cutlass.range_constexpr(elems_per_block // elems_per_store): + cute.copy(r2s_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + + @cute.jit + def _wg_load_kv( + self, + tma_atom_K_fp4, + tma_atom_V_fp4, + mK: cute.Tensor, + mV: cute.Tensor, + mKScale: cute.Tensor, + mVScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + mVGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sKFp4: cute.Tensor, + sVFp4: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + mbar_k_tma_ptr, + mbar_v_tma_ptr, + mbar_k_ptr, + mbar_v_ptr, + kv_load_barrier, + has_work: Int32, + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(self.kv_load_warp_base) + + if has_work: + if warp_idx_in_wg == Int32(0): + if const_expr(self.paged_kv): + mK_cur = mK[None, None, head_kv_idx, sPagedKvIdx[0]] + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim // 2), + (None, 0), + ) + src_idx = Int32(0) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + mK[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim // 2), + (None, 0), + ) + src_idx = kv_block_idx + load_K_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K_fp4, + 0, + cute.make_layout(1), + gK, + sKFp4, + ) + load_K_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_k_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_tma_ptr) + + if warp_idx_in_wg == Int32(1): + if const_expr(self.paged_kv): + mV_cur = mV[None, None, head_kv_idx, sPagedKvIdx[0]] + gV = cute.local_tile( + mV_cur, + (self.head_dim // 2, self.n_block_size), + (0, None), + ) + src_idx = Int32(0) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + mV[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + (self.head_dim // 2, self.n_block_size), + (0, None), + ) + src_idx = kv_block_idx + load_V_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V_fp4, + 0, + cute.make_layout(1), + gV, + sVFp4, + ) + load_V_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_v_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_tma_ptr) + + kv_load_barrier.arrive_and_wait() + + @cute.jit + def _wg_dequant_k_from_tma_staging( + self, + mKScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sKFp4: cute.Tensor, + sK: cute.Tensor, + mbar_k_tma_ptr, + mbar_k_ptr, + dequant_barrier, + has_work: Int32, + dequant_warp_base: cutlass.Constexpr[int], + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(dequant_warp_base) + lane = cute.arch.lane_idx() + + if has_work: + cute.arch.mbarrier_wait(mbar_k_tma_ptr, 0) + self._load_k_fp4_to_smem( + sKFp4, + mKScale, + mKGlobalScale, + sPagedKvIdx, + sK, + lane, + warp_idx_in_wg, + num_dequant_warps, + head_kv_idx, + kv_block_idx, + k_batch_offset, + num_heads_kv, + ) + dequant_barrier.arrive_and_wait() + if warp_idx_in_wg == Int32(0): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_ptr) + + @cute.jit + def _wg_dequant_v_from_tma_staging( + self, + mVScale: cute.Tensor, + mVGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sVFp4: cute.Tensor, + sV: cute.Tensor, + mbar_v_tma_ptr, + mbar_v_ptr, + dequant_barrier, + has_work: Int32, + dequant_warp_base: cutlass.Constexpr[int], + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(dequant_warp_base) + lane = cute.arch.lane_idx() + + if has_work: + cute.arch.mbarrier_wait(mbar_v_tma_ptr, 0) + self._load_v_fp4_to_smem( + sVFp4, + mVScale, + mVGlobalScale, + sPagedKvIdx, + sV, + lane, + warp_idx_in_wg, + num_dequant_warps, + head_kv_idx, + kv_block_idx, + k_batch_offset, + num_heads_kv, + ) + dequant_barrier.arrive_and_wait() + if warp_idx_in_wg == Int32(0): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_ptr) + + @cute.jit + def _wg_load_q_gather4( + self, + mQ_2d: cute.Tensor, + mQ_gather4_desc: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + load_wg_barrier, + qi_group_start: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + do_final_acquire: cutlass.Constexpr[bool], + ): + tidx = cute.arch.thread_idx()[0] + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + group_tidx = tidx - q_load_thread_base + producer_warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + gathers_per_warp: cutlass.Constexpr[int] = self.m_block_size // ( + self.num_q_load_warps * 4) + + if has_work: + for qi_group_rel in cutlass.range(num_q_groups, unroll=1): + qi_group = qi_group_start + qi_group_rel + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if producer_warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + group_tidx = tidx - q_load_thread_base + warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + lane_idx = group_tidx % Int32(cute.arch.WARP_SIZE) + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + meta_iters: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + + self.num_q_load_warps * cute.arch.WARP_SIZE - 1) + // (self.num_q_load_warps * cute.arch.WARP_SIZE) + ) + for meta_iter in cutlass.range_constexpr(meta_iters): + tok_idx_g4 = ( + (Int32(meta_iter) * Int32(self.num_q_load_warps) + + warp_idx_in_wg) + * Int32(cute.arch.WARP_SIZE) + + lane_idx + ) + if tok_idx_g4 < Int32(self.q_tokens_per_group): + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx_g4 + if qi < count_raw: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = Int32(0) + load_wg_barrier.arrive_and_wait() + + with cute.arch.elect_one(): + q_desc_ptr = mQ_gather4_desc.iterator + sQ_ptr = sQ.iterator + for gather_slot in cutlass.range_constexpr(gathers_per_warp): + gather_idx = ( + Int32(gather_slot) * Int32(self.num_q_load_warps) + + warp_idx_in_wg + ) + tok_base = gather_idx * Int32(self.tokens_per_gather4) + if const_expr(self.qheadperkv == 1): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + qi2 = qi0 + Int32(2) + qi3 = qi0 + Int32(3) + row0 = q_oob_m_idx + row1 = q_oob_m_idx + row2 = q_oob_m_idx + row3 = q_oob_m_idx + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row0 = (q_batch_offset + q_idx0) * num_heads_kv + head_kv_idx + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row1 = (q_batch_offset + q_idx1) * num_heads_kv + head_kv_idx + if qi2 < count_raw: + q_idx2 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(2)]) + row2 = (q_batch_offset + q_idx2) * num_heads_kv + head_kv_idx + if qi3 < count_raw: + q_idx3 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(3)]) + row3 = (q_batch_offset + q_idx3) * num_heads_kv + head_kv_idx + elif const_expr(self.qheadperkv == 2): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + row_base1 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row_base1 = ( + (q_batch_offset + q_idx1) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base1 + row3 = row_base1 + Int32(1) + else: + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base0 + Int32(2) + row3 = row_base0 + Int32(3) + group_byte_off = gather_idx * Int32( + 4 * self.k_tile * (self.q_dtype.width // 8) + ) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + stage_byte_off = slot * Int32(self.q_stage_stride_bytes) + full_group_byte_off = gather_idx * Int32( + 4 * self.head_dim * (self.q_dtype.width // 8) + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + full_group_byte_off, + q_desc_ptr, + Int32(0), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + else: + for ks_c in cutlass.range_constexpr(self.k_stages): + stage_idx = slot * Int32(self.k_stages) + Int32(ks_c) + stage_byte_off = stage_idx * Int32(self.k_tile_stride_bytes) + if const_expr(ks_c + 1 < self.k_stages): + tma_gather4_prefetch( + q_desc_ptr, + Int32((ks_c + 1) * self.k_tile), + row0, + row1, + row2, + row3, + TMA_CACHE_EVICT_LAST, + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + group_byte_off, + q_desc_ptr, + Int32(ks_c * self.k_tile), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + load_wg_barrier.arrive_and_wait() + + if const_expr(do_final_acquire) and producer_warp_idx_in_wg == Int32(0): + next_qi_group = qi_group_start + num_q_groups + next_slot = next_qi_group % Int32(self.q_stage) + next_phase = ( + (next_qi_group // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_load_q_tma( + self, + tma_atom_Q, + mQ_2d: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQLoadMIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ_load: cute.Tensor, + pipeline_q, + load_wg_barrier, + qi_group_start: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + do_final_acquire: cutlass.Constexpr[bool], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + lane_idx = cute.arch.lane_idx() + warp_idx_in_wg = warp_idx - Int32(self.q_load_warp_base) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + gQ_full = cute.local_tile( + mQ_2d, (self.qheadperkv, self.head_dim), (None, 0)) + load_Q_fn_full, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_full, sQ_load) + load_Q_fn_k0, load_Q_fn_k1 = None, None + else: + gQ_k0 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 0)) + gQ_k1 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 1)) + load_Q_fn_k0, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k0, sQ_load) + load_Q_fn_k1, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k1, sQ_load) + load_Q_fn_full = None + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + tokens_per_warp: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + self.num_q_load_warps - 1) + // self.num_q_load_warps + ) + + if has_work: + for qi_group_rel in cutlass.range(num_q_groups, unroll=1): + qi_group = qi_group_start + qi_group_rel + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + q_load_subtiles_per_token = ( + 1 if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_stages + ) + sub_stage_base = slot * Int32( + self.q_tokens_per_group * q_load_subtiles_per_token) + load_meta_slot = slot * Int32(self.q_tokens_per_group) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + if warp_idx_in_wg == Int32(0) and lane_idx < Int32(self.q_tokens_per_group): + tok_idx = lane_idx + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx + if qi < count_raw: + qsplit = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + q_idx = self._decode_q_idx_from_qsplit(qsplit) + q_abs = q_batch_offset + q_idx + sQIdxMeta[qidx_meta_slot + tok_idx] = qsplit + sQLoadMIdx[load_meta_slot + tok_idx] = ( + q_abs * num_heads_kv + head_kv_idx + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx] = Int32(0) + sQLoadMIdx[load_meta_slot + tok_idx] = q_oob_m_idx + load_wg_barrier.arrive_and_wait() + + for qi_slot in cutlass.range_constexpr(tokens_per_warp): + tok_idx = ( + warp_idx_in_wg * Int32(tokens_per_warp) + + Int32(qi_slot) + ) + if tok_idx < Int32(self.q_tokens_per_group): + m_tile_idx = sQLoadMIdx[load_meta_slot + tok_idx] + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + load_Q_fn_full( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + else: + load_Q_fn_k0( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + load_Q_fn_k1( + src_idx=m_tile_idx, + dst_idx=( + sub_stage_base + + Int32(self.q_tokens_per_group) + + tok_idx + ), + tma_bar_ptr=mbar_ptr, + ) + + if const_expr(do_final_acquire) and warp_idx_in_wg == Int32(0): + next_qi_group = qi_group_start + num_q_groups + next_slot = next_qi_group % Int32(self.q_stage) + next_phase = ( + (next_qi_group // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_mma_issue( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + thr0_qk: cute.core.ThrMma, + thr0_pv: cute.core.ThrMma, + tStS: cute.Tensor, + tOrP: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + mbar_k_ptr, + mbar_v_ptr, + num_q_groups: Int32, + has_work: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + is_mma_warp = warp_idx == Int32(self.mma_warp_id) + + if is_mma_warp: + if has_work: + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tSrQ0 = tSrQ[None, None, None, 0] + tSrK0 = tSrK[None, None, None, 0] + tOrV = tiled_mma_pv.make_fragment_B(sV) + tOrV0 = tOrV[None, None, None, 0] + sV0 = sV[None, None, None, 0] + pv_mma_op = tiled_mma_pv.op + qk_mma_op = tiled_mma_qk.op + q_smem_base = sm100_desc.smem_desc_base_from_tensor( + sQ, sm100_desc.Major.K) + k_smem_base = sm100_desc.smem_desc_base_from_tensor( + sK, sm100_desc.Major.K) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[None, None, None, 0].iterator) + q_smem_start = sm100_desc.make_smem_desc_start_addr( + sQ[None, None, None, self.q_stage - 1].iterator) + sm100_helpers.declare_ptx_smem_desc( + q_smem_start, + q_smem_base, + tSrQ0.layout, + var_name_prefix="lean_q_desc", + ) + sm100_helpers.declare_ptx_idesc( + qk_mma_op, var_name="lean_qk_idesc") + sQ_stage_stride = ( + sQ.layout.stride[-1] * sQ.element_type.width // 8) >> 4 + if const_expr(self.q_stage == 1): + sQ_stage_stride = 0 + q_wrap_offset = -(self.q_stage - 1) * sQ_stage_stride + q_advance_offset = sQ_stage_stride + gemm_qk_s0_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_qk_s0_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_qk_s1_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_qk_s1_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_pv_0 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset), + tOrP[None, None, None, 0], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_pv_1 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset + self.tmem_o_stage_stride), + tOrP[None, None, None, 1], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_stage_stride + self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + + cute.arch.mbarrier_wait(mbar_k_ptr, 0) + # Issue order: + # Q0K, Q1K, P0V, Q2K, P1V, Q3K, ... + # This reuses each slot as soon as its previous PV drains, + # instead of batching both PVs after both QKs of a pair. + # The schedule is still 2-slot safe: + # - QK(qi) consumes slot qi&1 + # - PV(qi-2) frees the same slot before QK(qi) reuses it + # - phases still toggle every 2 groups per slot + + # Prologue: issue up to the first two QK tiles. Q slots come + # from the q_stage ring; S slots remain a 2-slot ring. + pipeline_q.consumer_wait_w_index_phase(Int32(0), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(0), Int32(1)) + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(0)) + pipeline_q.consumer_release_w_index(Int32(0)) + + if num_q_groups > Int32(1): + pipeline_q.consumer_wait_w_index_phase(Int32(1), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(1), Int32(1)) + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(1)) + pipeline_q.consumer_release_w_index(Int32(1)) + + cute.arch.mbarrier_wait(mbar_v_ptr, 0) + + # Steady-state: for qi >= 2, reuse the S/P slot as a BLK128- + # style handoff. The MMA warp waits for softmax to release + # the slot after P is visible, issues PV, then immediately + # reuses the same acquired slot for the next QK. + for qi in cutlass.range(Int32(2), num_q_groups, unroll=1): + pv_qi = qi - Int32(2) + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + q_slot = qi % Int32(self.q_stage) + q_phase = (qi // Int32(self.q_stage)) & Int32(1) + s_slot = qi & Int32(1) + s_phase = (qi // Int32(2)) & Int32(1) + pipeline_q.consumer_wait_w_index_phase(q_slot, q_phase) + pipeline_s.producer_acquire_w_index_phase( + s_slot, s_phase ^ Int32(1)) + if s_slot == Int32(0): + if q_slot == Int32(0): + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s0_advance(smem_desc_start_b=k_smem_start) + else: + if q_slot == Int32(0): + gemm_qk_s1_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(s_slot) + pipeline_q.consumer_release_w_index(q_slot) + + # Drain the remaining one or two PV tiles. + drain_begin = Int32(0) if num_q_groups == Int32(1) else num_q_groups - Int32(2) + for pv_qi in cutlass.range(drain_begin, num_q_groups, unroll=1): + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + @cute.jit + def _softmax_step( + self, + slot: cutlass.Constexpr[int], + s_consumer_phase: Int32, + p_producer_phase: Int32, + sm_stats_producer_phase: Int32, + softmax: SoftmaxSm100, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx: Int32, + thr_tmem_load, + thr_tmem_store, + tStS_t2r: cute.Tensor, + tStP_r2t: cute.Tensor, + tScS_t2r: cute.Tensor, + tScP_shape, + sQIdxMeta: cute.Tensor, + qidx_meta_slot: Int32, + group_tidx: Int32, + masked_tok_count: Int32, + kv_block_col_start: Int32, + seq_len_q: Int32, + causal_q_offset: Int32, + kv_valid_cols: Int32, + lse_temperature_scale: Float32, + mKGlobalScale: Optional[cute.Tensor], + return_temperature_lse: cutlass.Constexpr[bool], + apply_causal_mask: cutlass.Constexpr[bool] = False, + signal_stats_barrier: cutlass.Constexpr[bool] = True, + ): + slot_rt = Int32(slot) + + pipeline_s.consumer_wait_w_index_phase(slot_rt, s_consumer_phase) + + tSrS_t2r = cute.make_rmem_tensor( + tScS_t2r.shape, + self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_k_global_scale + ): + k_global = mKGlobalScale[0] + for i in cutlass.range_constexpr(0, cute.size(tSrS_t2r.shape), 2): + tSrS_t2r[i], tSrS_t2r[i + 1] = cute.arch.mul_packed_f32x2( + (tSrS_t2r[i], tSrS_t2r[i + 1]), + (k_global, k_global), + ) + + seqlen_info = SeqlenInfoQK( + Int32(0), + Int32(0), + Int32(0), + Int32(0), + seq_len_q, + seq_len_q + causal_q_offset, + False, + False, + False, + False, + ) + mask = AttentionMask( + self.m_block_size, + self.n_block_size, + seqlen_info, + ) + if const_expr(self.causal and apply_causal_mask): + need_causal_mask = masked_tok_count > Int32(0) + if need_causal_mask: + tok_idx = group_tidx // Int32(self.qheadperkv) + q_idx = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_idx]) + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=True, + row_idx=q_idx, + kv_valid_cols=kv_valid_cols, + kv_block_col_start=kv_block_col_start, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + + # Each sparse CTA computes exactly one KV block for the current Q group, + # so full-tile softmax is always the first and only online-softmax step. + row_max, _ = softmax.update_row_max(tSrS_t2r.load(), True) + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + if const_expr(return_temperature_lse): + lse_temperature_row_sum = softmax.compute_scaled_exp2_row_sum( + tSrS_t2r, + lse_temperature_scale, + ) + + if cutlass.const_expr(self.split_P_arrive > 0): + # This full barrier is the late-P handoff consumed inside + # gemm_ptx_partial after its early PV k-slices are issued. + pipeline_p_lastsplit.producer_acquire_w_index_phase( + slot_rt, p_producer_phase) + pipeline_p.producer_acquire_w_index_phase(slot_rt, p_producer_phase) + tSrP_r2t_f32 = cute.make_rmem_tensor( + thr_tmem_store.partition_S( + cute.make_identity_tensor(tScP_shape)).shape, + Float32) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r.layout) + softmax.apply_exp2_convert( + tSrS_t2r, + tSrP_r2t, + ex2_emu_freq=self.ex2_emu_freq, + ex2_emu_start_frg=self.ex2_emu_start_frg, + ) + + for k in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): + cute.copy( + thr_tmem_store, + tSrP_r2t_f32[None, None, k], + tStP_r2t[None, None, k]) + if cutlass.const_expr(self.split_P_arrive > 0): + split_idx = ( + cute.size(tStP_r2t.shape[2]) + * self.split_P_arrive // self.n_block_size) + if cutlass.const_expr(k + 1 == split_idx): + cute.arch.fence_view_async_tmem_store() + pipeline_p.producer_commit_w_index(slot_rt) + cute.arch.fence_view_async_tmem_store() + if cutlass.const_expr(self.split_P_arrive == 0): + pipeline_p.producer_commit_w_index(slot_rt) + else: + pipeline_p_lastsplit.producer_commit_w_index(slot_rt) + pipeline_sm_stats.producer_acquire_w_index_phase( + slot_rt, sm_stats_producer_phase) + softmax.update_row_sum(tSrS_t2r.load(), Float32(0.0), True) + del tSrS_t2r + sScale_slot = cute.make_tensor( + sScale.iterator + slot_rt * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_slot[group_tidx] = softmax.row_sum[0] + sScale_slot[group_tidx + Int32(self.m_block_size)] = ( + softmax.row_max[0]) + if const_expr(return_temperature_lse): + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot_rt * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sScale_temperature_slot[group_tidx] = lse_temperature_row_sum + cute.arch.fence_view_async_shared() + + if const_expr(signal_stats_barrier): + sm_stats_barrier.arrive_w_index(index=stats_barrier_idx) + pipeline_s.consumer_release_w_index(slot_rt) + + @cute.jit + def _wg_softmax( + self, + stage: cutlass.Constexpr[int], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tStS: cute.Tensor, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + mKGlobalScale: Optional[cute.Tensor], + mVGlobalScale: Optional[cute.Tensor], + kv_block_idx: Int32, + kv_valid_cols: Int32, + diag_q_count: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + causal_q_offset: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + mQ_2d: cute.Tensor, + ): + tidx = cute.arch.thread_idx()[0] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx % Int32(self.warps_per_group) + group_tidx = ( + warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + stats_barrier_idx = ( + Int32(stage) * Int32(self.warps_per_group) + warp_idx_in_wg) + + thr0_qk = tiled_mma_qk.get_slice(0) + tScS = thr0_qk.partition_C( + cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScS = tScS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_tiler_qk[0] // thr0_qk.thr_id.shape, + self.mma_tiler_qk[1]) + tilePlikeFP32 = ( + self.mma_tiler_qk[1] // Float32.width * self.q_dtype.width) + tScP_shape = (cta_qk_tiler[0], tilePlikeFP32) + tSAcc = tStS[(None, None), 0, 0, stage] + + softmax = SoftmaxSm100.create(softmax_scale_log2) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype) + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tSAcc).get_slice(group_tidx) + tStS_t2r = thr_tmem_load.partition_S(tSAcc) + tScS_t2r = thr_tmem_load.partition_D(tScS) + tStP_layout = cute.composition( + tSAcc.layout, + cute.make_layout((self.m_block_size, tilePlikeFP32))) + tStP = cute.make_tensor( + tSAcc.iterator + self.tmem_s_to_p_offset, tStP_layout) + # P-store Repetition is dtype-aware: each PV MMA K-segment is + # ``32 / (q_dtype.width / 8)`` fp8/bf16 columns wide, which equals + # ``32 * Float32.width / q_dtype.width`` packed fp32 TMEM columns + # ``// (q_dtype.width / 8)``. Concretely, R=16 packs two bf16 PV K + # segments per chunk (shape[2]=4 ⇒ 3/4 publish boundary aligns), + # while fp8 (PV K=32 fp8 ⇒ 8 fp32 cols) needs R=8 so + # shape[2]=4 and split_idx=3 publishes exactly 24 fp32 cols + # (= 96 fp8 cols = 3 PV K segments) at the early-arrive edge. + store_rep = const_expr( + 8 if self.q_dtype == cutlass.Float8E4M3FN else 16 + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(store_rep)), + Float32) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom, tStP).get_slice(group_tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + + total_q = mQ_2d.shape[0] // head_q + thr0_pv = tiled_mma_pv.get_slice(0) + pv_acc_shape = thr0_pv.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO_base = thr0_pv.make_fragment_C(pv_acc_shape) + corr_tile_size = 64 + tOcO = thr0_pv.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + tOcO_i = cute.logical_divide( + tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + o_tmem_copy_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.pv_acc_dtype) + + if has_work: + kv_block_col_start = Int32(0) + if const_expr(self.causal): + kv_block_col_start = kv_block_idx * Int32(self.n_block_size) + + num_stage_groups = ( + num_q_groups + Int32(1 - stage)) // Int32(2) + for qi_iter in cutlass.range(num_stage_groups, unroll=1): + qi_group = qi_iter * Int32(2) + Int32(stage) + phase = qi_iter & Int32(1) + producer_phase = phase ^ Int32(1) + qidx_meta_slot = ( + (qi_group & Int32(self.qidx_meta_stages - 1)) + * Int32(self.q_tokens_per_group) + ) + + softmax.reset() + + if const_expr(self.causal): + qi_group_start = qi_group * Int32(self.q_tokens_per_group) + masked_tok_count = cutlass.max( + Int32(0), + cutlass.min( + Int32(self.q_tokens_per_group), + diag_q_count - qi_group_start, + ), + ) + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + masked_tok_count, + kv_block_col_start, + seq_len_q, + causal_q_offset, + kv_valid_cols, + lse_temperature_scale, + mKGlobalScale, + const_expr(mLSE_temperature_partial is not None), + True, + False, + ) + else: + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + Int32(0), + kv_block_col_start, + seq_len_q, + Int32(0), + kv_valid_cols, + lse_temperature_scale, + mKGlobalScale, + const_expr(mLSE_temperature_partial is not None), + False, + False, + ) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(stage)) + self._epilogue_step( + qi_group, + group_tidx, + warp_idx_in_wg, + tOtO_base, + tOcO_i, + o_tmem_copy_atom, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, + mLSE_partial, + mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + mVGlobalScale, + count_raw, + batch_idx, + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + q_batch_offset, + total_q, + False, + stage, + ) + + @cute.jit + def _store_o_partial_vec4( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + ): + stg_128_cs(ptr, v0, v1, v2, v3) + + @cute.jit + def _store_o_partial_vec8_half( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + ): + if cutlass.const_expr(self.o_dtype is cutlass.BFloat16): + stg_128_bf16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + else: + stg_128_f16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + + @cute.jit + def _store_o_partial_vec16_fp8( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + v8: Float32, + v9: Float32, + v10: Float32, + v11: Float32, + v12: Float32, + v13: Float32, + v14: Float32, + v15: Float32, + ): + stg_128_fp8_e4m3_cs( + ptr, + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7, + v8, + v9, + v10, + v11, + v12, + v13, + v14, + v15, + ) + + @cute.jit + def _epilogue_step( + self, + qi_group: Int32, + group_tidx: Int32, + warp_idx_in_wg: Int32, + tOtO_base: cute.Tensor, + tOcO_i: cute.Tensor, + o_tmem_copy_atom, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + mVGlobalScale: Optional[cute.Tensor], + count_raw: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + total_q: Int32, + use_stats_barrier: cutlass.Constexpr[bool], + softmax_stage: cutlass.Constexpr[int], + ): + slot = qi_group & Int32(1) + phase = (qi_group // Int32(2)) & Int32(1) + stage_base = slot * Int32(self.tmem_o_stage_stride) + corr_tile_size = 64 + sScale_slot = cute.make_tensor( + sScale.iterator + slot * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sSplitIdx_slot = cute.make_tensor( + sSplitIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + sQIdx_slot = cute.make_tensor( + sQIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + pipeline_o.consumer_wait_w_index_phase(slot, phase) + if const_expr(use_stats_barrier): + sm_stats_barrier.arrive_and_wait_w_index( + index=slot * Int32(self.warps_per_group) + warp_idx_in_wg) + + if group_tidx < Int32(self.q_tokens_per_group): + tok = group_tidx + qi = qi_group * Int32(self.q_tokens_per_group) + tok + if qi < count_raw: + qsplit = sQIdxMeta[qidx_meta_slot + tok] + q_idx = self._decode_q_idx_from_qsplit(qsplit) + sQIdx_slot[tok] = q_idx + sSplitIdx_slot[tok] = self._decode_split_idx_from_qsplit(qsplit) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + tOtO = cute.make_tensor( + tOtO_base.iterator + stage_base + Int32(self.tmem_o_offset), + tOtO_base.layout) + for col_pass_idx in cutlass.range(Int32(2), unroll=1): + col_pass = col_pass_idx * Int32(corr_tile_size) + tOtO_pass_ptr = cute.make_ptr( + self.pv_acc_dtype, + tOtO.iterator.toint() + col_pass, + cute.AddressSpace.tmem, + assumed_align=8, + ) + tOtO_pass = cute.make_tensor(tOtO_pass_ptr, tOtO.layout) + tOtO_pass_i = cute.logical_divide( + tOtO_pass, + cute.make_layout((self.m_block_size, corr_tile_size))) + tiled_tmem_load_pass = tcgen05.make_tmem_copy( + o_tmem_copy_atom, tOtO_pass_i[(None, None), 0]) + thr_tmem_load_pass = tiled_tmem_load_pass.get_slice(group_tidx) + tOtO_t2r_pass = thr_tmem_load_pass.partition_S( + tOtO_pass_i[(None, None), None]) + tOcO_t2r_pass = thr_tmem_load_pass.partition_D( + tOcO_i[(None, None), None]) + + tOtO_t2r_i = tOtO_t2r_pass[None, None, None, 0] + tOcO_t2r_i = tOcO_t2r_pass[None, None, None, 0] + tOrO_frg = cute.make_rmem_tensor_like( + tOcO_t2r_i, self.pv_acc_dtype) + cute.copy(tiled_tmem_load_pass, tOtO_t2r_i, tOrO_frg) + + tOrO_mn = make_16x256b_tensor_mn_view(tOrO_frg) + tOrO_mn = cute.make_tensor( + tOrO_mn.iterator, + cute.select(tOrO_mn.layout, mode=[0, 1])) + tOcO_mn = make_16x256b_tensor_mn_view(tOcO_t2r_i) + tOcO_mn = cute.make_tensor( + tOcO_mn.iterator, + cute.select(tOcO_mn.layout, mode=[0, 1])) + num_rows = cute.size(tOrO_mn, mode=[0]) + num_cols = cute.size(tOrO_mn, mode=[1]) + + for r in cutlass.range_constexpr(num_rows): + if const_expr(self.o_dtype is Float32): + for c4 in cutlass.range_constexpr(num_cols // 4): + c_base = Int32(c4) * Int32(4) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_v_global_scale + ): + row_scale *= mVGlobalScale[0] + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + fake_col = real_col_to_stg128_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec4( + ptr, + o0, + o1, + o2, + o3, + ) + elif const_expr(self.o_dtype in [cutlass.BFloat16, cutlass.Float16]): + assert num_cols % 8 == 0, ( + "half O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 8" + ) + for c8 in cutlass.range_constexpr(num_cols // 8): + c_base = Int32(c8) * Int32(8) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_v_global_scale + ): + row_scale *= mVGlobalScale[0] + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2( + (o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2( + (o6, o7), scale_pair) + fake_col = real_col_to_stg128_half_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec8_half( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + ) + else: + assert num_cols % 16 == 0, ( + "fp8 O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 16" + ) + for c16 in cutlass.range_constexpr(num_cols // 16): + c_base = Int32(c16) * Int32(16) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = ( + qi_group * Int32(self.q_tokens_per_group) + + tok + ) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) + * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val + ) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0) + ) + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_v_global_scale + ): + row_scale *= mVGlobalScale[0] + row_base_ptr = flat_row * Int64(self.head_dim) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + o8 = tOrO_mn[r, c_base + Int32(8)] + o9 = tOrO_mn[r, c_base + Int32(9)] + o10 = tOrO_mn[r, c_base + Int32(10)] + o11 = tOrO_mn[r, c_base + Int32(11)] + o12 = tOrO_mn[r, c_base + Int32(12)] + o13 = tOrO_mn[r, c_base + Int32(13)] + o14 = tOrO_mn[r, c_base + Int32(14)] + o15 = tOrO_mn[r, c_base + Int32(15)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2((o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2((o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2((o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2((o6, o7), scale_pair) + o8, o9 = cute.arch.mul_packed_f32x2((o8, o9), scale_pair) + o10, o11 = cute.arch.mul_packed_f32x2((o10, o11), scale_pair) + o12, o13 = cute.arch.mul_packed_f32x2((o12, o13), scale_pair) + o14, o15 = cute.arch.mul_packed_f32x2((o14, o15), scale_pair) + fake_col = real_col_to_stg128_fp8_fake_col(col) + ptr = ( + mO_partial.iterator + + row_base_ptr + + Int64(fake_col) + ) + self._store_o_partial_vec16_fp8( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + o8, + o9, + o10, + o11, + o12, + o13, + o14, + o15, + ) + cute.arch.fence_view_async_tmem_load() + + tok_local = Int32(group_tidx) // Int32(self.qheadperkv) + h_local = Int32(group_tidx) % Int32(self.qheadperkv) + qi_lse = qi_group * Int32(self.q_tokens_per_group) + tok_local + if qi_lse < count_raw: + row_sum_val = sScale_slot[group_tidx] + row_max_val = sScale_slot[group_tidx + Int32(self.m_block_size)] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + LN2 = Float32(math.log(2.0)) + lse_cur = ( + (row_max_val * softmax_scale_log2 + + cute.math.log2(row_sum_val, fastmath=True)) * LN2 + if not is_zero_or_nan else -Float32.inf) + q_idx_lse = sQIdx_slot[tok_local] + h_abs = head_kv_idx * Int32(self.qheadperkv) + h_local + split_lse = sSplitIdx_slot[tok_local] + q_abs_lse = q_batch_offset + q_idx_lse + mLSE_partial[split_lse, q_abs_lse, h_abs] = lse_cur + if const_expr(mLSE_temperature_partial is not None): + row_sum_temperature_val = sScale_temperature_slot[group_tidx] + is_temperature_zero_or_nan = ( + row_sum_temperature_val == Float32(0.0) + or row_sum_temperature_val != row_sum_temperature_val) + lse_temperature_cur = ( + (row_max_val * lse_temperature_scale_log2 + + cute.math.log2(row_sum_temperature_val, fastmath=True)) * LN2 + if not is_temperature_zero_or_nan else -Float32.inf) + mLSE_temperature_partial[split_lse, q_abs_lse, h_abs] = ( + lse_temperature_cur) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + pipeline_sm_stats.consumer_release_w_index(slot) + pipeline_o.consumer_release_w_index(slot) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd/combine.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd/combine.py new file mode 100644 index 0000000000000000000000000000000000000000..a3894130432f6483291fe23c064efa7369f6d509 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd/combine.py @@ -0,0 +1,1498 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Sparse forward combine kernel and public launcher. + +This keeps the local fake-layout -> real-layout epilogue needed by the lean +sparse forward path. +""" + +# Modified Step 7: O_out write with SMEM fake->real column permutation. +# O_partial dim is in STG.128 fake layout; O_out dim is real layout. +import math +from typing import Type, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import torch +from cutlass.cute.nvgpu import cpasync +from cutlass import Float32, Int32, Int64, Boolean, const_expr + +from ....src.common import utils +from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map +from ....src.common.seqlen_info import SeqlenInfo +from cutlass.cute import FastDivmodDivisor + +from ....src.common.pack_gqa import PackGQAComb +from ....src.common.tma_utils import ( + stg128_fake_col_to_real_col, + stg128_fp8_fake_col_to_real_col, + stg128_half_fake_col_to_real_col, +) + + +class SparseAttentionForwardCombine: + def __init__( + self, + dtype: Type[cutlass.Numeric], + dtype_partial: Type[cutlass.Numeric], + head_dim: int, + tile_m: int = 8, + k_block_size: int = 64, + topk: int = 16, + num_threads: int = 256, + stages: int = 4, + use_pdl: bool = False, + min_blocks_per_mp: int = 0, + ): + """ + Forward combine kernel for split attention computation. + + :param dtype: output data type + :param dtype_partial: partial accumulation data type + :param head_dim: head dimension + :param tile_m: m block size + :param k_block_size: k block size + :param topk: exact number of split partials + :param num_threads: number of threads + :param varlen: whether using variable length sequences + :param stages: number of pipeline stages + """ + self.dtype = dtype + self.dtype_partial = dtype_partial + self.head_dim = head_dim + self.tile_m = tile_m + self.k_block_size = k_block_size + self.topk = topk + self.num_threads = num_threads + self.is_even_k = head_dim % k_block_size == 0 + self.stages = stages + self.use_pdl = use_pdl + self.min_blocks_per_mp = min_blocks_per_mp + self.use_stg128_half_layout = dtype_partial in (cutlass.BFloat16, cutlass.Float16) + self.use_stg128_fp8_layout = dtype_partial is cutlass.Float8E4M3FN + + @staticmethod + def can_implement( + dtype, + dtype_partial, + head_dim, + tile_m, + k_block_size, + topk, + num_threads, + ) -> bool: + """Check if the kernel can be implemented with the given parameters.""" + if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]: + return False + if dtype_partial not in [ + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + Float32, + ]: + return False + if head_dim % 8 != 0: + return False + if num_threads % 32 != 0: + return False + if tile_m % 8 != 0: + return False + if topk > 256: + return False + if (tile_m * topk) % num_threads != 0: + return False + return True + + def _setup_attributes(self): + # GMEM copy setup for O partial + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype_partial.width + assert self.k_block_size % async_copy_elems == 0 + + k_block_gmem = ( + 128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32) + ) + gmem_threads_per_row = k_block_gmem // async_copy_elems + assert self.num_threads % gmem_threads_per_row == 0 + + # Async copy atom for O partial load + atom_async_copy_partial = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype_partial, + num_bits_per_copy=universal_copy_bits, + ) + tOpartial_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + vOpartial_layout = cute.make_layout((1, async_copy_elems)) + self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv( + atom_async_copy_partial, tOpartial_layout, vOpartial_layout + ) + + # GMEM copy setup for final O (use universal copy for store). + # Keep this independent from O_partial: fp8 partial uses 16 elements + # per 128b transaction, while bf16/fp16 O stores must remain 8-wide. + output_copy_elems = universal_copy_bits // self.dtype.width + assert self.k_block_size % output_copy_elems == 0 + gmem_threads_per_row_o = k_block_gmem // output_copy_elems + assert self.num_threads % gmem_threads_per_row_o == 0 + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + tO_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_o, gmem_threads_per_row_o), + order=(1, 0), + ) + vO_layout = cute.make_layout((1, output_copy_elems)) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, + tO_layout, + vO_layout, + ) + # LSE copy setup with async copy (alignment = 1) + lse_copy_bits = Float32.width # 1 element per copy, width is in bits + m_block_smem = ( + 128 + if self.tile_m % 128 == 0 + else ( + 64 + if self.tile_m % 64 == 0 + else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8)) + ) + ) + gmem_threads_per_row_lse = m_block_smem + assert self.num_threads % gmem_threads_per_row_lse == 0 + + # Async copy atom for LSE load + atom_async_copy_lse = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + Float32, + num_bits_per_copy=lse_copy_bits, + ) + tLSE_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse), + order=(1, 0), + ) + vLSE_layout = cute.make_layout(1) + self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + atom_async_copy_lse, tLSE_layout, vLSE_layout + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory + # /////////////////////////////////////////////////////////////////////////////// + + # Shared memory to register copy for LSE + self.smem_threads_per_col_lse = self.num_threads // m_block_smem + assert 32 % self.smem_threads_per_col_lse == 0 # Must divide warp size + + s2r_layout_atom_lse = cute.make_ordered_layout( + (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse), + order=(0, 1), + ) + self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), + s2r_layout_atom_lse, + cute.make_layout(1), + ) + + # LSE shared memory layout with swizzling to avoid bank conflicts + # This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts + if const_expr(m_block_smem == 8): + smem_lse_swizzle = cute.make_swizzle(5, 0, 5) + elif const_expr(m_block_smem == 16): + smem_lse_swizzle = cute.make_swizzle(4, 0, 4) + else: + smem_lse_swizzle = cute.make_swizzle(3, 2, 3) + lse_atom_splits = min(self.topk, 8) + smem_layout_atom_lse = cute.make_composed_layout( + smem_lse_swizzle, + 0, + cute.make_ordered_layout((lse_atom_splits, m_block_smem), order=(1, 0)), + ) + self.smem_layout_lse = cute.tile_to_shape( + smem_layout_atom_lse, (self.topk, self.tile_m), (0, 1) + ) + + # O_partial staging layout. + if const_expr( + self.dtype_partial + in [cutlass.Float16, cutlass.BFloat16, cutlass.Float8E4M3FN] + ): + smem_layout_atom_o = _get_cpasync_smem_layout_atom( + self.dtype_partial, self.k_block_size + ) + self.smem_layout_o = cute.tile_to_shape( + smem_layout_atom_o, + (self.tile_m, self.k_block_size, self.stages), + (0, 1, 2), + ) + else: + self.smem_layout_o = cute.make_ordered_layout( + (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2) + ) + + @cute.jit + def __call__( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor] = None, + mLSE_temperature_partial: Optional[cute.Tensor] = None, + mLSE_temperature: Optional[cute.Tensor] = None, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + varlen_batch_idx: Optional[cute.Tensor] = None, + semaphore_to_reset: Optional[cute.Tensor] = None, + mSplitCounts: Optional[cute.Tensor] = None, + mOutputScale: Optional[cute.Tensor] = None, + qhead_per_kvhead: Int32 = Int32(1), + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, + ): + # Type checking + if const_expr(not (mO_partial.element_type == self.dtype_partial)): + raise TypeError("O partial tensor must match dtype_partial") + if const_expr(not (mO.element_type == self.dtype)): + raise TypeError("O tensor must match dtype") + if const_expr(mLSE_partial.element_type not in [Float32]): + raise TypeError("LSE partial tensor must be Float32") + if const_expr(mLSE is not None and mLSE.element_type not in [Float32]): + raise TypeError("LSE tensor must be Float32") + if const_expr( + mLSE_temperature_partial is not None + and mLSE_temperature_partial.element_type not in [Float32] + ): + raise TypeError("temperature LSE partial tensor must be Float32") + if const_expr(mLSE_temperature is not None and mLSE_temperature.element_type not in [Float32]): + raise TypeError("temperature LSE tensor must be Float32") + if const_expr((mLSE_temperature_partial is None) != (mLSE_temperature is None)): + raise ValueError( + "temperature LSE partial and output tensors must either both be provided or both be None" + ) + + # Shape validation - input tensors are in user format, need to be converted to kernel format + if const_expr(len(mO_partial.shape) not in [4, 5]): + raise ValueError( + "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)" + ) + if const_expr(len(mLSE_partial.shape) not in [3, 4]): + raise ValueError( + "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)" + ) + if const_expr(len(mO.shape) not in [3, 4]): + raise ValueError( + "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)" + ) + if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]): + raise ValueError( + "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)" + ) + if const_expr(mLSE_temperature_partial is not None and len(mLSE_temperature_partial.shape) not in [3, 4]): + raise ValueError( + "temperature LSE partial tensor must have 3 or 4 dimensions: " + "(num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)" + ) + if const_expr(mLSE_temperature is not None and len(mLSE_temperature.shape) not in [2, 3]): + raise ValueError( + "temperature LSE tensor must have 2 or 3 dimensions: " + "(batch, seqlen, nheads) or (total_q, nheads)" + ) + if const_expr(mSplitCounts is not None): + if const_expr(mSplitCounts.element_type not in [Int32]): + raise TypeError("split_counts tensor must be Int32") + if const_expr(cu_seqlens is not None): + if const_expr(len(mSplitCounts.shape) != 2): + raise ValueError("varlen split_counts tensor must have shape (total_q, nheads_kv)") + elif const_expr(len(mSplitCounts.shape) != 3): + raise ValueError("batched split_counts tensor must have shape (batch, seqlen, nheads_kv)") + if const_expr(mOutputScale is not None and mOutputScale.element_type not in [Float32]): + raise TypeError("output_scale tensor must be Float32") + + mO_partial, mO = [assume_tensor_aligned(t) for t in (mO_partial, mO)] + # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b) + # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h) + O_partial_layout_transpose = ( + [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2] + ) + # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h) + mO_partial = cute.make_tensor( + mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose) + ) + O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1] + mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose)) + # (num_splits, b, h, seqlen) -> (seqlen, num_splits, h, b) + # Input is pre-transposed: [topK, B, Hq, Sq] with Sq innermost for K2-friendly reads. + # or (num_splits, total_q, h) -> (total_q, num_splits, h) + LSE_partial_layout_transpose = [3, 0, 2, 1] if const_expr(cu_seqlens is None) else [1, 0, 2] + mLSE_partial = cute.make_tensor( + mLSE_partial.iterator, + cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose), + ) + # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h) + LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1] + mLSE = ( + cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + if mLSE is not None + else None + ) + mLSE_temperature_partial = ( + cute.make_tensor( + mLSE_temperature_partial.iterator, + cute.select(mLSE_temperature_partial.layout, mode=LSE_partial_layout_transpose), + ) + if mLSE_temperature_partial is not None + else None + ) + mLSE_temperature = ( + cute.make_tensor( + mLSE_temperature.iterator, + cute.select(mLSE_temperature.layout, mode=LSE_layout_transpose), + ) + if mLSE_temperature is not None + else None + ) + + # Determine if we have variable length sequences + varlen = const_expr(cu_seqlens is not None or seqused is not None) + + self._setup_attributes() + + # Output-dtype permutation buffer for Step 7 (tile_m × k_block_size). + # Accumulation stays fp32; the final dtype conversion happens before + # the fake→real SMEM scatter to reduce half-output SMEM pressure. + if const_expr(self.dtype in [cutlass.Float16, cutlass.BFloat16]): + smem_layout_perm = cute.make_layout( + (self.tile_m, self.k_block_size), + stride=(self.k_block_size + 16, 1), + ) + else: + smem_layout_perm = cute.make_ordered_layout( + (self.tile_m, self.k_block_size), order=(1, 0) + ) + + @cute.struct + class SharedStorage: + sLSE: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sLSETemperature: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.tile_m], 128] + sO: cute.struct.Align[ + cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 + ] + sO_perm: cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_perm)], 128 + ] + + smem_size = SharedStorage.size_in_bytes() + + # Grid: (ceil(seqlen/tile_m), ceil(dim/k_block), num_head * batch) + # Head separated from seqlen → enables future TMA (contiguous Sq tiles) + seqlen = mO_partial.shape[0] + num_head = mO_partial.shape[3] + batch_size = ( + mO_partial.shape[4] + if const_expr(cu_seqlens is None) + else Int32(cu_seqlens.shape[0] - 1) + ) + + seqlen_divmod = FastDivmodDivisor(seqlen) + head_divmod = FastDivmodDivisor(num_head) + + grid_dim = ( + cute.ceil_div(seqlen * num_head, self.tile_m), + cute.ceil_div(self.head_dim, self.k_block_size), + batch_size, + ) + + self.kernel( + mO_partial, + mLSE_partial, + mO, + mLSE, + mLSE_temperature_partial, + mLSE_temperature, + cu_seqlens, + seqused, + num_splits_dynamic_ptr, + varlen_batch_idx, + semaphore_to_reset, + mSplitCounts, + mOutputScale, + qhead_per_kvhead, + SharedStorage, + self.smem_layout_lse, + self.smem_layout_o, + smem_layout_perm, + self.gmem_tiled_copy_O_partial, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_LSE, + self.s2r_tiled_copy_LSE, + seqlen_divmod, + head_divmod, + self.use_pdl, + varlen, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=smem_size, + stream=stream, + min_blocks_per_mp=self.min_blocks_per_mp, + use_pdl=self.use_pdl, + ) + + @cute.jit + def decode_flat_row_idx( + self, + idx: Int32, + head_divmod: FastDivmodDivisor, + ): + """Decode flattened tile rows under the H_q-innermost contract.""" + q_idx_local, head_idx = divmod(idx, head_divmod) + return q_idx_local, head_idx + + @cute.kernel + def kernel( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mLSE_temperature_partial: Optional[cute.Tensor], + mLSE_temperature: Optional[cute.Tensor], + cu_seqlens: Optional[cute.Tensor], + seqused: Optional[cute.Tensor], + num_splits_dynamic_ptr: Optional[cute.Tensor], + varlen_batch_idx: Optional[cute.Tensor], + semaphore_to_reset: Optional[cute.Tensor], + mSplitCounts: Optional[cute.Tensor], + mOutputScale: Optional[cute.Tensor], + qhead_per_kvhead: Int32, + SharedStorage: cutlass.Constexpr, + smem_layout_lse: cute.Layout | cute.ComposedLayout, + smem_layout_o: cute.Layout | cute.ComposedLayout, + smem_layout_perm: cute.Layout, + gmem_tiled_copy_O_partial: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + s2r_tiled_copy_LSE: cute.TiledCopy, + seqlen_divmod: FastDivmodDivisor, + head_divmod: FastDivmodDivisor, + use_pdl: cutlass.Constexpr[bool], + varlen: cutlass.Constexpr[bool], + ): + # Thread and block indices + tidx, _, _ = cute.arch.thread_idx() + m_block, k_block, maybe_virtual_batch = cute.arch.block_idx() + + batch_idx = ( + varlen_batch_idx[maybe_virtual_batch] + if const_expr(varlen_batch_idx is not None) + else maybe_virtual_batch + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sLSE = storage.sLSE.get_tensor(smem_layout_lse) + sLSE_temperature = storage.sLSETemperature.get_tensor(smem_layout_lse) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,)) + sO = storage.sO.get_tensor(smem_layout_o) + sO_perm_buf = storage.sO_perm.get_tensor(smem_layout_perm) + + # Handle semaphore reset — wait for dependent grids first + if const_expr(use_pdl and semaphore_to_reset is not None): + if ( + tidx == 0 + and m_block == cute.arch.grid_dim()[0] - 1 + and k_block == cute.arch.grid_dim()[1] - 1 + and maybe_virtual_batch == cute.arch.grid_dim()[2] - 1 + ): + cute.arch.griddepcontrol_wait() + semaphore_to_reset[0] = 0 + + if const_expr(num_splits_dynamic_ptr is not None): + raise ValueError("K2 combine requires compile-time exact topK") + num_splits = Int32(self.topk) + # Handle variable length sequences using SeqlenInfo + seqlen_info = SeqlenInfo.create( + batch_idx=batch_idx, + seqlen_static=mO_partial.shape[0], + cu_seqlens=cu_seqlens, + seqused=seqused, + # Don't need to pass in tile size since we won't use offset_padded + ) + seqlen, offset = seqlen_info.seqlen, seqlen_info.offset + + num_head = mO_partial.shape[3] + max_idx = seqlen * num_head + output_scale = Float32(1.0) + if const_expr(mOutputScale is not None): + output_scale = mOutputScale[0] + + if const_expr(not varlen) or m_block * self.tile_m < max_idx: + # Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial) + if const_expr(use_pdl): + cute.arch.griddepcontrol_wait() + + # =============================== + # Step 1: Load LSE_partial from gmem to shared memory + # =============================== + # `cLSE` (identity tensor for row/split coord tracking) is reused + # later in steps 4-5, so it must be defined on both branches. + cLSE = cute.make_identity_tensor((self.topk, self.tile_m)) + # Reshape mLSE_partial to PackGQA packed layout and delegate the + # tile load to PackGQAComb.load_LSE. The packed form folds (H_q, Sq) + # into one compound dim with H_q innermost (stride 1), so thread + # rows that vary along h_pos produce one-sector coalesced reads. + # Non-varlen path only — varlen keeps the original inline loop. + if const_expr(not varlen): + mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3) + # mLSE_partial_cur: (H_q, topK, Sq) — after initial transpose + # [3,0,2,1] on [topK,B,Sq,H_q] and dropping B. + # Reorder to (H_q, Sq, topK) then group modes 0..1 for packed dim: + mLSE_partial_reord = cute.make_tensor( + mLSE_partial_cur.iterator, + cute.select(mLSE_partial_cur.layout, mode=[0, 2, 1]), + ) + mLSE_partial_packed = cute.group_modes(mLSE_partial_reord, 0, 2) + # shape ((H_q, Sq), topK) with H_q innermost. + packgqa = PackGQAComb( + m_block_size=self.tile_m, + head_dim_padded=0, # unused for LSE load + check_hdim_oob=False, # unused for LSE load + qhead_per_kvhead=1, # unused; num_heads_divmod is passed explicitly + ) + packgqa.load_LSE( + mLSE_partial_packed, + sLSE, + self.topk, + gmem_tiled_copy_LSE, + tidx, + m_block, + num_splits, + seqlen, + head_divmod, + mSplitCounts, + batch_idx, + qhead_per_kvhead, + ) + if const_expr(mLSE_temperature_partial is not None): + mLSE_temperature_partial_cur = seqlen_info.offset_batch( + mLSE_temperature_partial, batch_idx, dim=3) + mLSE_temperature_partial_reord = cute.make_tensor( + mLSE_temperature_partial_cur.iterator, + cute.select(mLSE_temperature_partial_cur.layout, mode=[0, 2, 1]), + ) + mLSE_temperature_partial_packed = cute.group_modes( + mLSE_temperature_partial_reord, 0, 2) + packgqa.load_LSE( + mLSE_temperature_partial_packed, + sLSE_temperature, + self.topk, + gmem_tiled_copy_LSE, + tidx, + m_block, + num_splits, + seqlen, + head_divmod, + mSplitCounts, + batch_idx, + qhead_per_kvhead, + ) + else: + # Varlen path keeps the same H_q-innermost flat-row contract: + # after transpose [1, 0, 2], mLSE_partial_cur is + # (q_local, split, head). + # mSplitCounts is the authoritative valid-split count per + # packed (q_abs, kv_head); masked splits stay at -inf and + # therefore drop out of the final kernel LSE_out reduction. + mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3) + mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,)) + gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) + tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) + tLSEsLSE_temperature = gmem_thr_copy_LSE.partition_D(sLSE_temperature) + tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) + if const_expr(mLSE_temperature_partial is not None): + mLSE_temperature_partial_cur = seqlen_info.offset_batch( + mLSE_temperature_partial, batch_idx, dim=3) + mLSE_temperature_partial_copy = cute.tiled_divide( + mLSE_temperature_partial_cur, (1,)) + + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] + idx = m_block * self.tile_m + mi + if idx < max_idx: + m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod) + row_count = ( + mSplitCounts[offset + m_idx, head_idx // qhead_per_kvhead] + if const_expr(mSplitCounts is not None) + else num_splits + ) + mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx] + if const_expr(mLSE_temperature_partial is not None): + mLSE_temperature_partial_cur_copy = ( + mLSE_temperature_partial_copy[None, m_idx, None, head_idx]) + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] + if si < num_splits and si < row_count: + cute.copy( + gmem_thr_copy_LSE, + mLSE_partial_cur_copy[None, si], + tLSEsLSE[None, s, m], + ) + if const_expr(mLSE_temperature_partial is not None): + cute.copy( + gmem_thr_copy_LSE, + mLSE_temperature_partial_cur_copy[None, si], + tLSEsLSE_temperature[None, s, m], + ) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + if const_expr(mLSE_temperature_partial is not None): + tLSEsLSE_temperature[None, s, m].fill(-Float32.inf) + else: + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + tLSEsLSE[None, s, m].fill(-Float32.inf) + if const_expr(mLSE_temperature_partial is not None): + tLSEsLSE_temperature[None, s, m].fill(-Float32.inf) + cute.arch.cp_async_commit_group() + + # =============================== + # Step 2: Load O_partial for pipeline stages + # =============================== + + gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) + cO = cute.make_identity_tensor((self.tile_m, self.k_block_size)) + tOcO = gmem_thr_copy_O_partial.partition_D(cO) + tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) + mO_partial_cur = seqlen_info.offset_batch(mO_partial, batch_idx, dim=4) + + # Precompute per-row values for flattened (q_local, head) tiles. + num_rows = const_expr(cute.size(tOcO, mode=[1])) + tOmidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOhidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOSplitCount = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOrOptr = cute.make_rmem_tensor(num_rows, cutlass.Int64) + for m in cutlass.range(num_rows, unroll_full=True): + mi = tOcO[0, m, 0][0] # m coordinate in tile + idx = m_block * self.tile_m + mi + if idx >= max_idx: + tOhidx[m] = -1 + tOmidx[m] = 0 + tOSplitCount[m] = 0 + tOrOptr[m] = cutlass.Int64(0) + else: + tOmidx[m], tOhidx[m] = self.decode_flat_row_idx(idx, head_divmod) + if const_expr(mSplitCounts is None): + tOSplitCount[m] = num_splits + elif const_expr(cu_seqlens is None): + tOSplitCount[m] = mSplitCounts[ + batch_idx, tOmidx[m], tOhidx[m] // qhead_per_kvhead + ] + else: + tOSplitCount[m] = mSplitCounts[ + offset + tOmidx[m], tOhidx[m] // qhead_per_kvhead + ] + tOrOptr[m] = utils.elem_pointer( + mO_partial_cur, + (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m]), + ).toint() + + tOpO = None + if const_expr(not self.is_even_k): + tOpO = cute.make_rmem_tensor(cute.size(tOcO, mode=[2]), Boolean) + for k in cutlass.range(cute.size(tOpO), unroll_full=True): + tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size + # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO) + + load_O_partial = partial( + self.load_O_partial, + gmem_tiled_copy_O_partial, + tOrOptr, + tOsO_partial, + tOhidx, + tOSplitCount, + tOpO, + tOcO, + mO_partial_cur.layout, + ) + + # Load first few stages of O_partial + for stage in cutlass.range(self.stages - 1, unroll_full=True): + if stage < num_splits: + load_O_partial(stage, stage) + cute.arch.cp_async_commit_group() + + # =============================== + # Step 3: Load and transpose LSE from smem to registers + # =============================== + + # Wait for LSE and initial O partial stages to complete + cute.arch.cp_async_wait_group(self.stages - 1) + cute.arch.sync_threads() + # if cute.arch.thread_idx()[0] == 0: + # # cute.print_tensor(sLSE) + # for i in range(64): + # cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0]) + # cute.arch.sync_threads() + + s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) + ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) + ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE) + cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) + if const_expr(mLSE_temperature_partial is not None): + ts2rsLSE_temperature = s2r_thr_copy_LSE.partition_S(sLSE_temperature) + ts2rrLSE_temperature = cute.make_rmem_tensor_like(ts2rsLSE_temperature) + cute.copy( + s2r_tiled_copy_LSE, + ts2rsLSE_temperature, + ts2rrLSE_temperature, + ) + + # =============================== + # Step 4: Compute final LSE along split dimension + # =============================== + + final_lse = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32) + ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) + # We compute the max valid split for each row to short-circuit the computation later + max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32) + assert cute.size(ts2rrLSE, mode=[0]) == 1 + # Compute max, scales, and final LSE for each row. Invalid splits + # have already been filled with -inf, so Step 5 can write the + # kernel-native LSE_out directly. + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + # Find max LSE value across splits + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_max = cute.arch.warp_reduction_max( + ts2rrLSE[None, None, m] + .load() + .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + threads_in_group=threads_per_col, + ) + # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max) + # Find max valid split index + max_valid_idx = -1 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + if ts2rrLSE[0, s, m] != -Float32.inf: + max_valid_idx = ts2rcLSE[0, s, 0][0] # Get split coordinate + # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx) + max_valid_split[m] = cute.arch.warp_reduction_max( + max_valid_idx, threads_in_group=threads_per_col + ) + # Compute exp scales and sum + lse_max_cur = ( + 0.0 if lse_max == -Float32.inf else lse_max + ) # In case all local LSEs are -inf + LOG2_E = math.log2(math.e) + lse_sum_cur = 0.0 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + scale = cute.math.exp2( + ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E), fastmath=True + ) + lse_sum_cur += scale + ts2rrLSE[0, s, m] = scale # Store scale for later use + lse_sum_cur = cute.arch.warp_reduction_sum( + lse_sum_cur, threads_in_group=threads_per_col + ) + # Normalize scales + inv_sum = 0.0 + if max_valid_split[m] < 0 or lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur: + final_lse[m] = -Float32.inf + else: + final_lse[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max + inv_sum = 1.0 / lse_sum_cur + ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) + # Store the scales exp(lse - lse_logsum) back to smem + cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) + + if const_expr(mLSE_temperature_partial is not None): + final_lse_temperature = cute.make_rmem_tensor( + cute.size(ts2rrLSE_temperature, mode=[2]), Float32) + for m in cutlass.range(cute.size(ts2rrLSE_temperature, mode=[2]), unroll_full=True): + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_temperature_max = cute.arch.warp_reduction_max( + ts2rrLSE_temperature[None, None, m] + .load() + .reduce( + cute.ReductionOp.MAX, + init_val=-Float32.inf, + reduction_profile=0, + ), + threads_in_group=threads_per_col, + ) + lse_temperature_max_cur = ( + 0.0 if lse_temperature_max == -Float32.inf else lse_temperature_max + ) + LOG2_E = math.log2(math.e) + lse_temperature_sum_cur = 0.0 + for s in cutlass.range( + cute.size(ts2rrLSE_temperature, mode=[1]), unroll_full=True): + scale = cute.math.exp2( + ts2rrLSE_temperature[0, s, m] * LOG2_E + - (lse_temperature_max_cur * LOG2_E), + fastmath=True, + ) + lse_temperature_sum_cur += scale + lse_temperature_sum_cur = cute.arch.warp_reduction_sum( + lse_temperature_sum_cur, threads_in_group=threads_per_col + ) + if ( + max_valid_split[m] < 0 + or lse_temperature_sum_cur == 0.0 + or lse_temperature_sum_cur != lse_temperature_sum_cur + ): + final_lse_temperature[m] = -Float32.inf + else: + final_lse_temperature[m] = ( + cute.math.log(lse_temperature_sum_cur, fastmath=True) + + lse_temperature_max + ) + + # Store max valid split to smem + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + if mi < self.tile_m: + sMaxValidSplit[mi] = max_valid_split[m] + + # =============================== + # Step 5: Store final LSE to gmem + # This writeback is the authoritative LSE_out returned by the + # public Sparse Attention / Sparse Page Attention interface. + # =============================== + + if const_expr(mLSE is not None): + if const_expr(cu_seqlens is None): + mLSE_cur = mLSE[None, None, batch_idx] + else: + mLSE_cur = cute.domain_offset((offset, 0), mLSE) + if const_expr(mLSE_temperature is not None): + if const_expr(cu_seqlens is None): + mLSE_temperature_cur = mLSE_temperature[None, None, batch_idx] + else: + mLSE_temperature_cur = cute.domain_offset( + (offset, 0), mLSE_temperature) + if k_block == 0: # Only first k_block writes LSE when mLSE is provided + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + idx = m_block * self.tile_m + mi + if idx < max_idx: + m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod) + mLSE_cur[m_idx, head_idx] = final_lse[m] + if const_expr(mLSE_temperature is not None): + mLSE_temperature_cur[m_idx, head_idx] = ( + final_lse_temperature[m]) + + # =============================== + # Step 6: Read O_partial and accumulate final O + # =============================== + + cute.arch.sync_threads() + + # Get max valid split for this thread + thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] + for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True): + thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]]) + + tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32) + tOrO.fill(0.0) + + stage_load = self.stages - 1 + stage_compute = 0 + + # Main accumulation loop + for s in cutlass.range(thr_max_valid_split + 1, unroll=4): + # Get scales for this split + scale = cute.make_rmem_tensor(num_rows, Float32) + for m in cutlass.range(num_rows, unroll_full=True): + scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem + + # Load next stage if needed + split_to_load = s + self.stages - 1 + if split_to_load <= thr_max_valid_split: + load_O_partial(split_to_load, stage_load) + cute.arch.cp_async_commit_group() + stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1 + + # Wait for the current stage to be ready + cute.arch.cp_async_wait_group(self.stages - 1) + # We don't need __syncthreads() because each thread is just reading its own data from smem + # Copy from smem to registers + cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial) + stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1 + + # Accumulate scaled partial results + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0 and scale[m] > 0.0: + tOrO[None, m, None].store( + tOrO[None, m, None].load() + + scale[m] * tOrO_partial[None, m, None].load().to(Float32) + ) + + # Flush any outstanding async-copy groups before the local Step-7 + # permutation buffer is read on the tail of the kernel. + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + + # =============================== + # Step 7: Write final O to gmem (fake→real via SMEM) + # =============================== + + mO_cur = seqlen_info.offset_batch(mO, batch_idx, dim=3) + if const_expr(cu_seqlens is None): + mO_cur = mO[None, None, None, batch_idx] + else: + mO_cur = cute.domain_offset((offset, 0, 0), mO) + mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur) + num_vals = const_expr(cute.size(tOcO, mode=[0])) + if const_expr(not use_pdl): + # Direct / standalone calls don't participate in the K1->K2 + # dependency chain. Use a simple per-element real-column store + # path here to keep mixed-shape launches stable. + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0: + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + if const_expr(self.is_even_k) or tOpO[k]: + for v in cutlass.range(num_vals, unroll_full=True): + fake_col = tOcO[v, 0, k][1] + if const_expr(self.use_stg128_fp8_layout): + real_col = stg128_fp8_fake_col_to_real_col(fake_col) + elif const_expr(self.use_stg128_half_layout): + real_col = stg128_half_fake_col_to_real_col(fake_col) + else: + real_col = stg128_fake_col_to_real_col(fake_col) + o_val = tOrO[v, m, k] + if const_expr(mOutputScale is not None): + o_val = o_val * output_scale + mO_cur[tOmidx[m], real_col, tOhidx[m]] = o_val.to(self.dtype) + else: + # 7a: fp32 accumulator -> output dtype SMEM with fake→real + # permutation. The dedicated permutation buffer stays separate + # from the O_partial pipeline staging buffer. + sO_perm = sO_perm_buf + + if const_expr(self.dtype in [cutlass.BFloat16, cutlass.Float16]): + # O_partial uses a dtype-specific STG.128 fake layout, but + # sO_perm is in the final O dtype. For all supported fake + # layouts, adjacent fake pairs map to adjacent real columns, + # so write the final BF16/F16 O pair as one 32-bit SMEM store. + assert num_vals % 2 == 0 + r2s_o_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=32, + ) + rO_pair_word = cute.make_rmem_tensor((1,), cutlass.Int32) + sO_perm_i32_base = cute.make_ptr( + dtype=cutlass.Int32, + value=sO_perm.iterator.toint(), + mem_space=sO_perm.iterator.memspace, + assumed_align=4, + ) + sO_perm_i32_row_stride = Int32((self.k_block_size + 16) // 2) + for m in cutlass.range(num_rows, unroll_full=True): + row_local = tOcO[0, m, 0][0] + if tOhidx[m] >= 0: + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + for v_pair in cutlass.range(num_vals // 2, unroll_full=True): + v = v_pair * 2 + fake_col = tOcO[v, 0, k][1] + if const_expr(self.use_stg128_fp8_layout): + real_col = stg128_fp8_fake_col_to_real_col(fake_col) + elif const_expr(self.use_stg128_half_layout): + real_col = stg128_half_fake_col_to_real_col(fake_col) + else: + real_col = stg128_fake_col_to_real_col(fake_col) + o0 = tOrO[v, m, k] + o1 = tOrO[v + 1, m, k] + if const_expr(mOutputScale is not None): + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), + (output_scale, output_scale), + ) + rO_pair_word[0] = utils.cvt_f16x2_f32(o0, o1, self.dtype) + smem_pair_ptr = cute.make_ptr( + dtype=cutlass.Int32, + value=( + sO_perm_i32_base.toint() + + Int64( + row_local * sO_perm_i32_row_stride + + real_col // Int32(2) + ) + * Int64(4) + ), + mem_space=sO_perm.iterator.memspace, + assumed_align=4, + ) + sO_pair = cute.make_tensor( + smem_pair_ptr, + cute.make_layout((1,), stride=(1,)), + ) + cute.copy(r2s_o_pair_atom, rO_pair_word, sO_pair) + else: + # 7a: iterate over ALL val elements in mode[0]. + # tOcO[v, m, k][1] gives different fake_col for each v. + r2s_o_scalar_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=self.dtype.width, + ) + rO_scalar = cute.make_rmem_tensor((1,), self.dtype) + for m in cutlass.range(num_rows, unroll_full=True): + row_local = tOcO[0, m, 0][0] + if tOhidx[m] >= 0: + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + for v in cutlass.range(num_vals, unroll_full=True): + fake_col = tOcO[v, 0, k][1] + if const_expr(self.use_stg128_fp8_layout): + real_col = stg128_fp8_fake_col_to_real_col(fake_col) + elif const_expr(self.use_stg128_half_layout): + real_col = stg128_half_fake_col_to_real_col(fake_col) + else: + real_col = stg128_fake_col_to_real_col(fake_col) + o_val = tOrO[v, m, k] + if const_expr(mOutputScale is not None): + o_val = o_val * output_scale + rO_scalar[0] = o_val.to(self.dtype) + smem_ptr = utils.elem_pointer(sO_perm, (row_local, real_col)) + smem_scalar_ptr = cute.make_ptr( + dtype=self.dtype, + value=smem_ptr.toint(), + mem_space=sO_perm.iterator.memspace, + assumed_align=self.dtype.width // 8, + ) + sO_scalar = cute.make_tensor( + smem_scalar_ptr, + cute.make_layout((1,), stride=(1,)), + ) + cute.copy(r2s_o_scalar_atom, rO_scalar, sO_scalar) + + cute.arch.sync_threads() + + # 7b: SMEM (real order, output dtype) → GMEM + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOcO_store = gmem_thr_copy_O.partition_D(cO) + tOsO_store = gmem_thr_copy_O.partition_D(sO_perm) + rO = cute.make_rmem_tensor(tOcO_store.shape, self.dtype) + elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1])) + num_store_rows = const_expr(cute.size(tOcO_store, mode=[1])) + num_store_vals = const_expr(cute.size(tOcO_store, mode=[0])) + tOpO_store = None + if const_expr(not self.is_even_k): + tOpO_store = cute.make_rmem_tensor(cute.size(tOcO_store, mode=[2]), Boolean) + for k in cutlass.range(cute.size(tOpO_store), unroll_full=True): + tOpO_store[k] = ( + tOcO_store[0, 0, k][1] + < mO_partial.shape[1] - k_block * self.k_block_size + ) + + # Read output dtype from SMEM (now in real column order). + for m in cutlass.range(num_store_rows, unroll_full=True): + for k in cutlass.range(cute.size(tOcO_store, mode=[2]), unroll_full=True): + if const_expr(self.is_even_k) or tOpO_store[k]: + cute.autovec_copy(tOsO_store[None, m, k], rO[None, m, k]) + + # Write bf16 to GMEM using gmem_tiled_copy_O (same as original FA Step 7) + for m in cutlass.range(num_store_rows, unroll_full=True): + row_local = tOcO_store[0, m, 0][0] + idx = m_block * self.tile_m + row_local + if idx < max_idx: + m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod) + mO_cur_copy = cute.tiled_divide( + mO_cur[m_idx, None, head_idx], (elems_per_store,) + ) + for k in cutlass.range(cute.size(tOcO_store, mode=[2]), unroll_full=True): + k_idx = tOcO_store[0, 0, k][1] // elems_per_store + if const_expr(self.is_even_k) or tOpO_store[k]: + cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx]) + + @cute.jit + def load_O_partial( + self, + gmem_tiled_copy_O_partial: cute.TiledCopy, + tOrOptr: cute.Tensor, + tOsO_partial: cute.Tensor, + tOhidx: cute.Tensor, + tOSplitCount: cute.Tensor, + tOpO: Optional[cute.Tensor], + tOcO: cute.Tensor, + mO_cur_partial_layout: cute.Layout, + split: Int32, + stage: Int32, + ) -> None: + elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1])) + tOsO_partial_cur = tOsO_partial[None, None, None, stage] + for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True): + if tOhidx[m] >= 0: + o_gmem_ptr = cute.make_ptr( + tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16 + ) + mO_partial_cur = cute.make_tensor( + o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)) + ) + mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_load + if split < tOSplitCount[m] and (const_expr(tOpO is None) or tOpO[k]): + cute.copy( + gmem_tiled_copy_O_partial, + mO_partial_cur_copy[None, k_idx, split], + tOsO_partial_cur[None, m, k], + ) + else: + tOsO_partial_cur[None, m, k].fill(0) + + +def _get_cutlass_dtype(torch_dtype: torch.dtype): + if torch_dtype not in torch2cute_dtype_map: + raise TypeError(f"Unsupported dtype: {torch_dtype}") + return torch2cute_dtype_map[torch_dtype] + + +_combine_compile_cache = {} + + +def _get_cpasync_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout: + dtype_byte = const_expr(dtype.width // 8) + bytes_per_row = const_expr(k_dim * dtype_byte) + smem_k_block_size = ( + const_expr( + 128 + if bytes_per_row % 128 == 0 + else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16)) + ) + // dtype_byte + ) + swizzle_bits = ( + 4 + if smem_k_block_size == 128 + else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1)) + ) + swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4) + return cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base), + 0, + cute.make_ordered_layout( + (8 if const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), + order=(1, 0), + ), + ) + + +def combine( + o_partial_fake, + lse_partial, + o_out, + lse_out, + *, + lse_temperature_partial=None, + lse_temperature_out=None, + cu_seqlens=None, + seqused=None, + split_counts=None, + output_scale=None, + use_pdl=False, +): + """K2: merge sparse forward split partials into the final output. + + STG.128 fake-layout handling remains an internal implementation detail. + When lse_out is provided, the kernel writes the final authoritative + log-sum-exp for each query row/head directly into that tensor. + + Args: + o_partial_fake: + Batched: [num_splits, batch, Sq, head_q, dim] + Varlen: [num_splits, total_q, head_q, dim] + lse_partial: + Batched: [num_splits, batch, Sq, head_q] + Varlen: [num_splits, total_q, head_q] + o_out: + Batched: [batch, Sq, head_q, dim] + Varlen: [total_q, head_q, dim] + lse_out: + Batched: [batch, Sq, head_q] + Varlen: [total_q, head_q] + lse_temperature_partial: + Optional temperature-scaled LSE partial with the same shape as + lse_partial. + lse_temperature_out: + Optional temperature-scaled final LSE with the same shape as + lse_out. + cu_seqlens: Optional [batch + 1] int32 for varlen-Q combine. + seqused: Optional [batch] int32 effective lengths for combine. + split_counts: Optional int32 rowwise valid split counts prepared from + q2k metadata. Batched: [batch, seqlen, head_kv]. Varlen: + [total_q, head_kv]. + output_scale: Optional fp32 tensor with at least one element. When + provided, the final O accumulator is multiplied once before store. + use_pdl: When True, wait on PDL dependencies from the producer K1 + kernel. When False, launch without PDL waits. + """ + D = o_partial_fake.shape[-1] + num_splits = o_partial_fake.shape[0] + return_temperature_lse = ( + lse_temperature_partial is not None or lse_temperature_out is not None + ) + if (lse_temperature_partial is None) != (lse_temperature_out is None): + raise ValueError( + "lse_temperature_partial and lse_temperature_out must either both be provided or both be None" + ) + if lse_temperature_partial is not None and lse_temperature_partial.shape != lse_partial.shape: + raise ValueError( + "lse_temperature_partial must have the same shape as lse_partial, " + f"got {lse_temperature_partial.shape} vs {lse_partial.shape}" + ) + if lse_temperature_out is not None: + if lse_out is None: + raise ValueError("lse_temperature_out requires lse_out") + if lse_temperature_out.shape != lse_out.shape: + raise ValueError( + "lse_temperature_out must have the same shape as lse_out, " + f"got {lse_temperature_out.shape} vs {lse_out.shape}" + ) + if lse_temperature_out.dtype != torch.float32 or lse_temperature_partial.dtype != torch.float32: + raise TypeError("temperature LSE tensors must be torch.float32") + + partial_dtype = _get_cutlass_dtype(o_partial_fake.dtype) + out_dtype = _get_cutlass_dtype(o_out.dtype) + if output_scale is not None: + if output_scale.dtype != torch.float32: + raise TypeError(f"output_scale must be torch.float32, got {output_scale.dtype}") + if output_scale.numel() < 1: + raise ValueError("output_scale must contain at least one element") + if output_scale.device != o_out.device: + raise ValueError("output_scale must be on the same device as o_out") + output_scale = output_scale.contiguous() + if split_counts is not None: + if split_counts.dtype != torch.int32: + raise TypeError(f"split_counts must be torch.int32, got {split_counts.dtype}") + if o_out.ndim == 4: + if split_counts.ndim != 3: + raise ValueError( + f"batched split_counts must have shape [batch, seqlen, head_kv], got {split_counts.shape}" + ) + if split_counts.shape[:2] != o_out.shape[:2]: + raise ValueError( + f"split_counts shape {split_counts.shape} must match batch/seqlen of o_out {o_out.shape}" + ) + else: + if cu_seqlens is None: + raise ValueError("split_counts with varlen output requires cu_seqlens") + if split_counts.ndim != 2: + raise ValueError( + f"varlen split_counts must have shape [total_q, head_kv], got {split_counts.shape}" + ) + if split_counts.shape[0] != o_out.shape[0]: + raise ValueError( + f"split_counts total_q ({split_counts.shape[0]}) must match o_out total_q " + f"({o_out.shape[0]})" + ) + if o_out.shape[-2] % split_counts.shape[-1] != 0: + raise ValueError( + f"o_out heads ({o_out.shape[-2]}) must be divisible by split_counts heads ({split_counts.shape[-1]})" + ) + qheadperkv = o_out.shape[-2] // split_counts.shape[-1] + else: + qheadperkv = 1 + if cu_seqlens is not None: + if cu_seqlens.dtype != torch.int32: + raise TypeError(f"cu_seqlens must be torch.int32, got {cu_seqlens.dtype}") + if cu_seqlens.ndim != 1: + raise ValueError(f"cu_seqlens must be rank-1, got {cu_seqlens.shape}") + if not cu_seqlens.is_contiguous(): + raise ValueError("cu_seqlens must be contiguous") + if seqused is not None: + if seqused.dtype != torch.int32: + raise TypeError(f"seqused must be torch.int32, got {seqused.dtype}") + if seqused.ndim != 1: + raise ValueError(f"seqused must be rank-1, got {seqused.shape}") + if not seqused.is_contiguous(): + raise ValueError("seqused must be contiguous") + + k_block_size = 128 if D > 64 else 64 + tile_m = 64 + has_cu_seqlens = cu_seqlens is not None + has_seqused = seqused is not None + has_lse = lse_out is not None + has_split_counts = split_counts is not None + has_output_scale = output_scale is not None + min_blocks_per_mp = 3 if has_output_scale and use_pdl else 0 + + key = ( + "combine", + D, + k_block_size, + tile_m, + num_splits, + partial_dtype, + out_dtype, + has_cu_seqlens, + has_seqused, + has_lse, + bool(return_temperature_lse), + has_split_counts, + has_output_scale, + use_pdl, + min_blocks_per_mp, + ) + if key not in _combine_compile_cache: + from ....src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _combine_compile_cache[key] = loaded + else: + from ....quack.compile_utils import make_fake_tensor + + kernel = SparseAttentionForwardCombine( + dtype=out_dtype, + dtype_partial=partial_dtype, + head_dim=D, + tile_m=tile_m, + k_block_size=k_block_size, + topk=num_splits, + use_pdl=use_pdl, + min_blocks_per_mp=min_blocks_per_mp, + # stages=2 halves per-block SMEM (168 KB -> 103 KB) -> 2 blocks/SM, + # theoretical occupancy 12.5% -> 25%. NCU DRAM throughput 76.35% + # -> 88.64%. Runtime latency within noise (kernel already at HBM + # bandwidth ceiling in practice) but the cleaner SOL profile + # matters for downstream NCU comparison. + stages=2, + ) + div = 128 // partial_dtype.width + if has_cu_seqlens: + total_q, nheads = (cute.sym_int64() for _ in range(2)) + mO_partial = make_fake_tensor( + partial_dtype, (num_splits, total_q, nheads, D), divisibility=div + ) + mLSE_partial = make_fake_tensor( + Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=2 + ) + mO = make_fake_tensor( + out_dtype, (total_q, nheads, D), divisibility=128 // out_dtype.width + ) + mLSE = ( + make_fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=1) + if has_lse + else None + ) + mLSE_temperature_partial = ( + make_fake_tensor( + Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=2 + ) + if return_temperature_lse + else None + ) + mLSE_temperature = ( + make_fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=1) + if return_temperature_lse + else None + ) + else: + batch, sq, nheads = (cute.sym_int64() for _ in range(3)) + mO_partial = make_fake_tensor( + partial_dtype, (num_splits, batch, sq, nheads, D), divisibility=div + ) + mLSE_partial = make_fake_tensor( + Float32, (num_splits, batch, sq, nheads), divisibility=1, leading_dim=3 + ) + mO = make_fake_tensor( + out_dtype, (batch, sq, nheads, D), divisibility=128 // out_dtype.width + ) + mLSE = ( + make_fake_tensor(Float32, (batch, sq, nheads), divisibility=1, leading_dim=2) + if has_lse + else None + ) + mLSE_temperature_partial = ( + make_fake_tensor( + Float32, (num_splits, batch, sq, nheads), divisibility=1, leading_dim=3 + ) + if return_temperature_lse + else None + ) + mLSE_temperature = ( + make_fake_tensor(Float32, (batch, sq, nheads), divisibility=1, leading_dim=2) + if return_temperature_lse + else None + ) + if not has_split_counts: + mSplitCounts = None + elif has_cu_seqlens: + total_q_ctr, nheads_kv = (cute.sym_int64() for _ in range(2)) + mSplitCounts = make_fake_tensor( + Int32, (total_q_ctr, nheads_kv), divisibility=1, leading_dim=1 + ) + else: + nheads_kv = cute.sym_int64() + mSplitCounts = make_fake_tensor( + Int32, (batch, sq, nheads_kv), divisibility=1, leading_dim=2 + ) + mOutputScale = ( + make_fake_tensor(Float32, (cute.sym_int64(),), divisibility=1, leading_dim=0) + if has_output_scale + else None + ) + stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + _combine_compile_cache[key] = cute.compile( + kernel, + mO_partial, + mLSE_partial, + mO, + mLSE, + mLSE_temperature_partial, + mLSE_temperature, + None + if cu_seqlens is None + else make_fake_tensor(Int32, (cute.sym_int64(),), divisibility=1, leading_dim=0), + None + if seqused is None + else make_fake_tensor(Int32, (cute.sym_int64(),), divisibility=1, leading_dim=0), + None, + None, + None, + mSplitCounts, + mOutputScale, + Int32(qheadperkv), + stream, + options="--enable-tvm-ffi", + ) + save_aot(key, _combine_compile_cache[key]) + + with torch.cuda.nvtx.range("K2_Combine"): + _combine_compile_cache[key]( + o_partial_fake, + lse_partial, + o_out, + lse_out, + lse_temperature_partial, + lse_temperature_out, + cu_seqlens, + seqused, + None, + None, + None, + split_counts, + output_scale, + qheadperkv, + ) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d64a0616bd5bb9c987e43b87bcbf9e89001fbb36 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/__init__.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""CUTE DSL launchers for paged fp8 decode forward.""" + +from __future__ import annotations + +import torch + +from .atten_fwd import run_decode_attention +from .combine import run_decode_combine + + +def decode_forward_paged_fp8( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + request_indices: torch.Tensor, + qo_tile_indices: torch.Tensor, + kv_tile_indices: torch.Tensor, + block_valid_mask: torch.Tensor, + split_counts: torch.Tensor, + o_indptr: torch.Tensor, + merge_indptr: torch.Tensor, + O_partial: torch.Tensor | None, + LSE_partial: torch.Tensor | None, + *, + softmax_scale: float, + seqlen_q: int, + page_size: int, + kv_chunk_size_pages: int, + max_split_count: int, + split_kv: bool, + causal: bool, + return_lse: bool = True, + O_partial_dummy: torch.Tensor | None = None, + LSE_partial_dummy: torch.Tensor | None = None, +) -> None: + """Launch dense paged fp8 decode forward and optional compressed combine. + + ``O_partial_dummy`` / ``LSE_partial_dummy`` are caller-provided pre-allocated + placeholder buffers for the non-split path. When supplied, ``run_decode_attention`` + skips the per-call ``torch.empty`` it would otherwise need to satisfy the + kernel's positional arg signature, saving ~5us on small-kv calls. + """ + + run_decode_attention( + q, + k, + v, + page_table, + seqused_k, + request_indices, + qo_tile_indices, + kv_tile_indices, + block_valid_mask, + split_counts, + o_indptr, + out, + lse, + O_partial, + LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=int(seqlen_q), + page_size=int(page_size), + kv_chunk_size_pages=int(kv_chunk_size_pages), + split_kv=bool(split_kv), + causal=bool(causal), + return_lse=bool(return_lse), + O_partial_dummy=O_partial_dummy, + LSE_partial_dummy=LSE_partial_dummy, + ) + if split_kv: + if O_partial is None or LSE_partial is None: + raise ValueError("split decode requires O_partial and LSE_partial") + qhead_per_kv = q.shape[1] // k.shape[1] + q_tokens_per_group = 128 // int(qhead_per_kv) + run_decode_combine( + O_partial, + LSE_partial, + split_counts, + o_indptr, + out, + lse, + seqlen_q=int(seqlen_q), + q_tokens_per_group=q_tokens_per_group, + max_split_count=int(max_split_count), + ) + + +__all__ = ["decode_forward_paged_fp8", "run_decode_attention", "run_decode_combine"] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..9a56bb20363deffd4c850533484427bc128b3c84 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py @@ -0,0 +1,2691 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Dense paged fp8 decode forward path. + +This file owns the CUTE DSL entry point for decode attention via +``SparseDecodeAttentionForwardSm100`` — SM100 UTCMMA + persistent +scheduling, paged fp8 Q/K/V, BSA blk128-style intra-warp overlap pipeline. +Forward only. +""" + +from __future__ import annotations + +import math +from functools import partial +from typing import Callable, Optional + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.pipeline as cutlass_pipeline +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cutlass_dsl import BaseDSL +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +from ....quack import copy_utils, layout_utils + +from ....src.common import pipeline +from ....src.common import blackwell_helpers as sm100_helpers +from ....src.common import mma_sm100_desc as sm100_desc +from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map +from ....src.common.named_barrier import NamedBarrierFwdSm100 +from ....src.common.softmax import SoftmaxSm100 +from ....src.common.mask import AttentionMask +from ....src.common.seqlen_info import SeqlenInfoQK +from ....src.common.pack_gqa import pack_gqa_layout +from ....src.common.tile_scheduler import SchedulingMode +from ....src.sm100.fwd_decode.tile_scheduler import ( + DecodeTileScheduler, + DecodeTileSchedulerArguments, +) + + +class SparseDecodeAttentionForwardSm100: + """SM100 dense paged fp8 decode forward attention (UTCMMA + CLC). + + Scope (Phase 1): + - Dense decode, ``split_kv=False``, single q-tile per work item + (``packed_q = seqlen_q * qhead_per_kv <= tile_m=128``). + - Causal only. KV reverse page loop; first reverse block applies + causal/seqlen mask, the rest is unmasked. + - fp8 Q/K/V, bf16 O, fp32 LSE. P is quantized to fp8_e4m3fn before PV + via ``SoftmaxSm100.apply_exp2_convert`` (mirror of prefill fp8 PV). + - per-batch ``mSeqUsedK[b]`` heterogeneous; no uniform-length assumptions. + + Production scope reached at Phase 4+: + - Multi q-tile (Phase 2), split-KV partial writeback (Phase 3), + CLC persistent scheduling (Phase 4), TC SOL >= 90% (Phase 7). + """ + + # UTCMMA K-tile width (matches prefill SparseAttentionForwardSm100). + k_tile = 64 + + def __init__( + self, + head_dim: int = 128, + qhead_per_kv: int = 16, + m_block_size: int = 128, + n_block_size: int = 128, + page_size: int = 128, + split_kv: bool = False, + causal: bool = True, + write_lse: bool = True, + disable_softmax_exp2: bool = False, + ): + # --- structural constraints (Phase 1 scope) ------------------------- + if head_dim != 128: + raise NotImplementedError( + f"SparseDecodeAttentionForwardSm100 currently supports only D=128, " + f"got D={head_dim}" + ) + if m_block_size != 128: + raise NotImplementedError( + f"decode UMMA forward requires tile_m=128, got {m_block_size}" + ) + if n_block_size != 128: + raise NotImplementedError( + f"decode UMMA forward requires n_block_size=128, got {n_block_size}" + ) + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal n_block_size ({n_block_size})" + ) + if qhead_per_kv not in (16, 8, 4, 2, 1): + raise ValueError( + f"qhead_per_kv must be in {{1, 2, 4, 8, 16}}, got {qhead_per_kv}" + ) + if not causal: + raise NotImplementedError( + "decode UMMA forward currently supports only causal=True" + ) + + self.head_dim = int(head_dim) + self.qhead_per_kv = int(qhead_per_kv) + self.m_block_size = int(m_block_size) + self.n_block_size = int(n_block_size) + self.page_size = int(page_size) + self.tile_m = int(m_block_size) + self.split_kv = bool(split_kv) + self.causal = bool(causal) + self.write_lse = bool(write_lse) + self.disable_softmax_exp2 = bool(disable_softmax_exp2) + # FA fp8 SM100 fwd uses a threshold of 4.0 to avoid rescaling O for + # small row-max movements; correction receives acc_scale directly. + self.rescale_threshold = 4.0 + + # q tokens packed per (m_block_size) row group along M. + self.q_tokens_per_group = self.m_block_size // self.qhead_per_kv + + self.mma_tiler_qk = (self.m_block_size, self.n_block_size, self.head_dim) + self.mma_tiler_pv = (self.m_block_size, self.head_dim, self.n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + + # --- pipeline ring stages (BSA blk128 q_stage=1, s_stage=2) --- + self.q_stage = 1 + self.s_stage = 2 + self.o_stage = 2 + # Keep the fp8 decode KV ring deep enough to cover the K0/Q/K1/V0... + # order. This matches sage's fp8 setting and removes the underfed + # two-stage KV pipeline seen in the q8/16K non-split case. + self.kv_stage = 4 + self.k_stages = 2 + # Match prefill: PV is split at 3/4 of n_block_size for fp8. The + # producer (P store) must publish exactly 3N/4 fp8 columns at the + # signal point; that requires the TMEM-store atom Repetition to be + # ``8`` (one PV ``f8f6f4`` K=32 segment = 8 fp32 packed cols), so + # ``shape[2]=4`` chunks and ``split_idx=3`` lands on the 3N/4 + # boundary exactly. The previous N/2 cap was a workaround for + # ``Repetition(16)`` whose coarser chunk boundary could not + # represent 3N/4. + self.split_P_arrive = self.n_block_size // 4 * 3 + self.split_P_arrive = int(self.split_P_arrive / 32) * 32 + assert self.split_P_arrive % 32 == 0 + assert self.split_P_arrive < self.n_block_size + + # --- warp layout (16 warps / 512 threads) — BSA-aligned (Phase 1.10.6b) + # 0-3 softmax WG 0 + # 4-7 softmax WG 1 + # 8-11 correction WG (acc_O rescale across pages + final epilogue + # write-back; participates in TmemPtr barrier) + # 12 MMA issue warp + # 13 spare / future CLC scheduler + # 14 load warp (serial Q + K + V TMA loads) + # 15 empty / register-budget reserve + self.warps_per_group = 4 + self.softmax0_warp_base = 0 + self.softmax1_warp_base = self.softmax0_warp_base + self.warps_per_group + self.correction_warp_base = ( + self.softmax1_warp_base + self.warps_per_group) + self.mma_warp_id = self.correction_warp_base + self.warps_per_group + self.spare_warp_id = self.mma_warp_id + 1 + self.load_warp_id = self.spare_warp_id + 1 + self.empty_warp_id = self.load_warp_id + 1 + self.total_warps = 16 + self.threads_per_cta = cute.arch.WARP_SIZE * self.total_warps + + # --- TMEM layout (fp8 P width-pack: 4 fp8 lanes per fp32 column) --- + # S0/S1: [0:128], [128:256] + # O0/O1: [256:384], [384:512] for head_dim_v=128 + # P (fp8) overlays the second half of each S tile via recast_ptr. + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + self.tmem_s_offset = 0 + self.tmem_stage_stride = self.n_block_size + self.tmem_o_stage_stride = self.head_dim + self.tmem_o_offset = self.s_stage * self.n_block_size + # fp8 P occupies n_block_size * fp8_width / fp32_width = n/4 fp32 cols. + # P offset is set in __call__ once q_dtype is known (defer to Phase 1.3). + raw_tmem_total = self.tmem_o_offset + self.o_stage * self.tmem_o_stage_stride + # SM100 TMEM allocation requires a power-of-two column count. + self.tmem_total = 1 << (raw_tmem_total - 1).bit_length() + + # --- register budget per role (BSA hdim>=96 default) --- + self.num_regs_softmax = 184 + self.num_regs_correction = 88 + self.num_regs_other = 56 + self.num_regs_mma = self.num_regs_other + self.num_regs_load = self.num_regs_other + self.num_regs_epilogue = self.num_regs_other + self.num_regs_empty = self.num_regs_other + + # exp2 emulation for causal: matches prefill ex2_emu_freq=16. + # disable_softmax_exp2 (Phase 7 SOL gate) bypasses both emulation and + # native exp2 — the convert pass becomes a pure fp32 -> fp8 cast. + self.ex2_emu_freq = 16 if (self.causal and not self.disable_softmax_exp2) else 0 + self.ex2_emu_start_frg = 1 + self.buffer_align_bytes = 1024 + + # --- SM100 cluster config (single-CTA for decode, no 2-CTA pair) - + self.use_2cta_instrs = False + self.cta_group_size = 1 + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnk = (1, 1, 1) + self.use_clc_scheduler = True + self.scheduling_mode = SchedulingMode.CLC + self.sched_stages = 2 + self.clc_scheduler_warp_id = self.empty_warp_id + + self.arch = BaseDSL._get_dsl().get_arch_enum() + + # ------------------------------------------------------------------ + # Host-side: TMA descriptors, SMEM layout, launch + # Phase 1.2+ fills in the body. Phase 1.1 keeps signatures stable so + # the rest of the codepath (run_decode_attention dispatch in 1.10) + # can wire to this class without further churn. + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # [B, Sq, Hq, D] fp8 + mK: cute.Tensor, # [num_pages, Hkv, page_size, D] fp8 + mV: cute.Tensor, # [num_pages, Hkv, page_size, D] fp8 + mPageTable: cute.Tensor, # [B, max_pages] int32 + mSeqUsedK: cute.Tensor, # [B] int32 + mRequestIndices: cute.Tensor, # [work_capacity] int32 + mQoTileIndices: cute.Tensor, # [work_capacity] int32 + mKvTileIndices: cute.Tensor, # [work_capacity] int32 + mBlockValidMask: cute.Tensor, # [work_capacity] int32 + mSplitCounts: cute.Tensor, # [B] int32 + mOIndptr: cute.Tensor, # [B + 1] int32 + mO: cute.Tensor, # [total_q, Hq, D] bf16 + mLSE: cute.Tensor, # [total_q, Hq] fp32 + mO_partial: Optional[cute.Tensor], + mLSE_partial: Optional[cute.Tensor], + softmax_scale: Float32, + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + stream: cuda.CUstream = None, + ): + # --- dtype contract ------------------------------------------------ + if const_expr(mQ.element_type is not cutlass.Float8E4M3FN): + raise TypeError("decode UMMA Q must be Float8E4M3FN") + if const_expr(mK.element_type is not cutlass.Float8E4M3FN): + raise TypeError("decode UMMA K must be Float8E4M3FN") + if const_expr(mV.element_type is not cutlass.Float8E4M3FN): + raise TypeError("decode UMMA V must be Float8E4M3FN") + if const_expr(mO.element_type is not cutlass.BFloat16): + raise TypeError("decode UMMA output O must be BFloat16") + if const_expr(mLSE.element_type is not Float32): + raise TypeError("decode UMMA output LSE must be Float32") + if const_expr(self.split_kv): + if const_expr(mO_partial is None or mO_partial.element_type is not Float32): + raise TypeError("decode UMMA split path requires Float32 O_partial") + if const_expr(mLSE_partial is None or mLSE_partial.element_type is not Float32): + raise TypeError("decode UMMA split path requires Float32 LSE_partial") + + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type + self.o_dtype = ( + mO_partial.element_type if const_expr(self.split_kv) + else mO.element_type + ) + # f8f6f4 MMA descriptor kind for fp8 Q/K/V. + self.mma_kind = "f8f6f4" + # fp8 P width-pack ratio: each fp32 TMEM column holds 4 fp8 P lanes. + # Computed here so __init__ stays dtype-agnostic and the TMEM offsets + # can later be derived from this ratio in Phase 1.3. + elem_bytes = const_expr(self.q_dtype.width // 8) + p_cols_as_fp32 = const_expr( + self.n_block_size * self.q_dtype.width // Float32.width + ) + self.tmem_s_to_p_offset = self.n_block_size - p_cols_as_fp32 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + + mQ, mK, mV, mO, mLSE = [ + assume_tensor_aligned(t) for t in (mQ, mK, mV, mO, mLSE) + ] + if const_expr(mO_partial is not None): + mO_partial = assume_tensor_aligned(mO_partial) + if const_expr(mLSE_partial is not None): + mLSE_partial = assume_tensor_aligned(mLSE_partial) + mO_epilogue = mO_partial if const_expr(self.split_kv) else mO + self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO_epilogue) + self.epi_tile = (self.m_block_size, self.head_dim) + + # ------------------------------------------------------------------ + # UTCMMA TiledMma: QK^T + PV. PV uses MN-major V operand (V already + # transposed in the layout below) and a TMEM operand source for P. + # Phase 1.4 builds tiled_mma_qk; Phase 1.5 adds tiled_mma_pv so sV + # layout can derive the MN-major swizzle. + # ------------------------------------------------------------------ + cta_group = tcgen05.CtaGroup.ONE + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, + Float32, cta_group, self.mma_tiler_qk[:2]) + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN, + Float32, cta_group, self.mma_tiler_pv[:2], tcgen05.OperandSource.TMEM) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # Paged K/V tensor view permutation. + # Input layout [num_pages, Hkv, page_size, D] (nhsd) is permuted to + # [page_size, D, Hkv, num_pages] for the paged TMA descriptor (K). + # V gets an additional (s,d) swap to become MN-major: + # [D, page_size, Hkv, num_pages]. + # ------------------------------------------------------------------ + mK_paged = cute.make_tensor( + mK.iterator, cute.select(mK.layout, mode=[2, 3, 1, 0]) + ) + mV_kv = cute.make_tensor( + mV.iterator, cute.select(mV.layout, mode=[2, 3, 1, 0]) + ) + mV_paged = cute.make_tensor( + mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2, 3]) + ) + + # ------------------------------------------------------------------ + # Q SMEM layout + BSA/FA PackGQA full-tile TMA atom. + # + # Runtime Q is [B, Sq, Hq, D]. We transpose to [Sq, D, Hq, B], then + # fold qhead_per_kv into the M dimension: + # ((qhead_per_kv, Sq), D, Hkv, B) + # This lets one Q TMA load cover the whole packed (tile_m, D) tile + # instead of issuing one TMA per q token. + # ------------------------------------------------------------------ + total_q_stages = self.q_stage + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, total_q_stages) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + mQ = cute.make_tensor( + mQ.iterator, cute.select(mQ.layout, mode=[1, 3, 2, 0])) + nheads_kv = mK.shape[1] + mQ = pack_gqa_layout(mQ, self.qhead_per_kv, nheads_kv, head_idx=2) + tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mQ, + cute.select(sQ_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + cta_layout_vmnk.shape, + ) + + # ------------------------------------------------------------------ + # K / V SMEM layouts + TMA atoms (paged). + # sK uses the QK MMA operand B swizzle; sV uses the PV MMA operand B + # swizzle (MN-major). tP_layout is the TMEM-side P descriptor — no + # SMEM is actually allocated for P, it overlays the S region in TMEM + # via cute.recast_ptr in Phase 1.7. + # ------------------------------------------------------------------ + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage) + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage) + tP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.s_stage) + + tma_atom_K, mK_paged = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mK_paged, cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, tiled_mma_qk, cta_layout_vmnk.shape) + tma_atom_V, mV_paged = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mV_paged, cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_pv, tiled_mma_pv, cta_layout_vmnk.shape) + + # ------------------------------------------------------------------ + # Phase 1.10.6b-B-2: TMA-store atom for the epilogue write-back. + # Non-split writes bf16 final O; split-KV writes fp32 O_partial. + # sO follows FA/BSA epilogue layout: one full m_block x D tile in + # SMEM. Both paths expose global O as a packed-GQA tensor view so the + # final store is a full BSA-style m_block x D TMA tile. + # ------------------------------------------------------------------ + sO_layout = sm100_utils.make_smem_layout_epi( + self.o_dtype, + self.o_layout, + self.epi_tile, + self.q_stage, + ) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + num_heads_kv_tma = mK.shape[1] + total_o_rows_tma = ( + mO_epilogue.shape[0] + // (num_heads_kv_tma * self.qhead_per_kv) + ) + head_stride_tma = self.head_dim + o_row_stride_tma = ( + num_heads_kv_tma * self.qhead_per_kv * self.head_dim) + kv_head_stride_tma = self.qhead_per_kv * self.head_dim + mO_epilogue_tma = cute.make_tensor( + mO_epilogue.iterator, + cute.make_layout( + ((self.qhead_per_kv, total_o_rows_tma), self.head_dim, num_heads_kv_tma), + stride=((head_stride_tma, o_row_stride_tma), 1, kv_head_stride_tma), + ), + ) + tma_atom_O, mO_tma = cpasync.make_tiled_tma_atom( + tma_store_op, + mO_epilogue_tma, + cute.select(sO_layout, mode=[0, 1]), + self.epi_tile, + ) + + # Pre-multiply softmax scale by log2(e) so the inner exp2 path can + # operate without re-scaling at every iteration. Mirrors prefill. + softmax_scale_log2 = softmax_scale * Float32(math.log2(math.e)) + + work_capacity = mRequestIndices.shape[0] + num_heads_kv = mK.shape[1] + tile_sched_args = DecodeTileSchedulerArguments( + Int32(work_capacity), + Int32(num_heads_kv), + cluster_shape_mn=self.cluster_shape_mn, + ) + tile_sched_params = DecodeTileScheduler.to_underlying_arguments( + tile_sched_args, + scheduling_mode=self.scheduling_mode, + ) + self.tile_scheduler_cls = DecodeTileScheduler + grid = DecodeTileScheduler.get_grid_shape(tile_sched_params) + + clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 + clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0 + + # ------------------------------------------------------------------ + # SharedStorage mirrors BSA blk128's pipeline mesh for dense paged + # decode: Q, shared K/V, S/P/O, P-lastsplit, O-acc, O-epilogue and + # softmax stats mbarriers, plus the TMEM allocator state and SMEM + # staging tensors. + # ------------------------------------------------------------------ + @cute.struct + class SharedStorage: + mbar_load_Q: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_load_KV: cute.struct.MemRange[Int64, self.kv_stage * 2] + mbar_S_full_P_full_O_rescaled: cute.struct.MemRange[ + Int64, self.s_stage * 2] + mbar_P_full_lastsplit: cute.struct.MemRange[ + Int64, self.s_stage * 2] + mbar_O_full: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_softmax_stats0: cute.struct.MemRange[Int64, 2] + mbar_softmax_stats1: cute.struct.MemRange[Int64, 2] + mbar_O_epi: cute.struct.MemRange[Int64, self.s_stage * 2] + # Phase 1.10.6b-B-2: bf16 sO SMEM staging buffer for the TMA + # store epilogue. Sized for one full m_block_size × head_dim + # tile (single stage; overlap with sQ left for later perf tune). + sO: cute.struct.Align[ + cute.struct.MemRange[self.o_dtype, cute.cosize(sO_layout)], + self.buffer_align_bytes, + ] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: Int32 + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] + clc_response: cute.struct.MemRange[Int32, clc_response_size] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes, + ] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # ------------------------------------------------------------------ + # Launch — decode tasks are consumed from the + # (work_idx, head_kv_idx) scheduler space. In CLC mode grid is the + # BSA-style hardware problem shape; in static mode it is capped to the + # SM count and each CTA walks the flattened task stream. + # ------------------------------------------------------------------ + # q_tma_bytes (and Phase 1.5+: kv_tma_bytes / q_subtile_bytes) are + # recomputed inside the kernel from the constexpr SMEM layouts. + # Passing them as Constexpr[int] kernel args ended up marshalling + # to dynamic Int32 here, which then tripped MbarrierArray's + # `if tx_count < 0` check inside PipelineTmaUmma.create. + self.kernel( + mQ, mK_paged, mV_paged, + mPageTable, mSeqUsedK, + mRequestIndices, mQoTileIndices, mKvTileIndices, mBlockValidMask, + mSplitCounts, mOIndptr, + mO, mO_tma, mLSE, + mO_partial, mLSE_partial, + softmax_scale_log2, + sQ_layout, sK_layout, sV_layout, tP_layout, sO_layout, + tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O, + tiled_mma_qk, tiled_mma_pv, + tile_sched_params, + seqlen_q, page_size, kv_chunk_size_pages, + Int32(num_heads_kv), + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=( + self.cluster_shape_mnk + if cute.size(self.cluster_shape_mnk) > 1 else None + ), + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + # --- runtime tensors ------------------------------------------------- + mQ: cute.Tensor, # [((qhead_per_kv, Sq), D, Hkv, B)] + mK_paged: cute.Tensor, # [page_size, D, Hkv, num_pages] fp8 + mV_paged: cute.Tensor, # [D, page_size, Hkv, num_pages] fp8 + mPageTable: cute.Tensor, + mSeqUsedK: cute.Tensor, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mBlockValidMask: cute.Tensor, + mSplitCounts: cute.Tensor, + mOIndptr: cute.Tensor, + mO: cute.Tensor, + mO_tma: cute.Tensor, + mLSE: cute.Tensor, + mO_partial: Optional[cute.Tensor], + mLSE_partial: Optional[cute.Tensor], + # --- scalars --------------------------------------------------------- + softmax_scale_log2: Float32, + # --- SMEM layouts ---------------------------------------------------- + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + tP_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + # --- TMA atoms ------------------------------------------------------- + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_O: cute.CopyAtom, + # --- TiledMma -------------------------------------------------------- + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tile_sched_params: DecodeTileScheduler.Params, + # --- Int32 iteration bounds ------------------------------------------ + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + num_heads_kv: Int32, + ): + # ------------------------------------------------------------------ + # Thread / warp identity, work-item dispatch. + # ------------------------------------------------------------------ + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + if warp_idx == Int32(0): + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_O) + + # ------------------------------------------------------------------ + # SMEM allocation — same SharedStorage type was registered on the + # class in __call__ (Phase 1.3). Every warp materialises the same + # storage view; later phases populate sQ/sK/sV/mbar contents. + # ------------------------------------------------------------------ + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + # sQ is the MMA-operand layout and now also the Q TMA load target: + # PackGQA makes the global Q view match the full BSA (tile_m, D) tile. + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) + + # ------------------------------------------------------------------ + # TMEM allocator — MMA warp performs the allocation, all softmax / + # store / MMA warps participate in the TmemPtr named barrier that + # broadcasts the allocator pointer. Spare warp and KV-load warps + # do not touch TMEM directly. + # ------------------------------------------------------------------ + # TmemPtr participants: 2 softmax WGs (8 warps) + correction WG + # (4 warps) + MMA warp = 13 warps × WARP_SIZE. Load / spare / + # empty warps don't touch TMEM and don't arrive on this barrier. + tmem_alloc_warps: cutlass.Constexpr[int] = ( + self.warps_per_group * 3 + 1) + tmem_alloc_threads = cute.arch.WARP_SIZE * tmem_alloc_warps + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), + num_threads=tmem_alloc_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + ) + tmem_cols = self.tmem_total + + # ------------------------------------------------------------------ + # Cluster layout + warp-specialized pipelines. + # Mirrors prefill (src/sm100/fwd/atten_fwd.py:617-683): cta_layout_vmnk + # is rebuilt in-kernel from tiled_mma_qk.thr_id.shape so its size is + # constexpr (the `cute.size(cta_layout_vmnk) == 1` check inside + # PipelineTmaUmma.create folds at compile time). pipeline_q is + # joined by the BSA S/P/O and shared K/V pipelines below. + # ------------------------------------------------------------------ + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + ThreadCooperativeGroup = partial( + cutlass_pipeline.CooperativeGroup, cutlass_pipeline.Agent.Thread) + tma_thread = ThreadCooperativeGroup(1) + mma_thread = ThreadCooperativeGroup(1) + # One softmax WG participates per S/P/O stage; correction and the + # epilogue warp handle O rescale and TMA write-back. + softmax_warps = ThreadCooperativeGroup(self.warps_per_group) + softmax_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group) + + # Recompute TMA byte counts inside the kernel from the constexpr SMEM + # layouts — see note in __call__ above the self.kernel(...) call for + # why these can't be plumbed through as Constexpr[int] kernel args. + q_tma_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + k_tma_bytes = cute.size_in_bytes( + self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) + + pipeline_q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_load_Q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=q_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + # Decode KV follows BSA's single K/V ring: K0 is primed before Q, + # then K1, V0, K2, V1, ... share one PipelineTmaUmma state while + # landing in separate sK/sV SMEM tensors. For fp8 decode K/V TMA + # tiles have the same byte count, so the shared barrier uses K's count. + pipeline_kv = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_load_KV.data_ptr(), + num_stages=self.kv_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=k_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + # ------------------------------------------------------------------ + # BSA pipeline mesh. + # pipeline_s_p_o — MMA→{softmax,correction} (8-warp cluster + # consumer). MMA producer_commit signals + # "S ready"; consumer_release signals "P stored + # and acc_O rescaled — MMA can issue next QK". + # pipeline_o_acc — MMA→correction (acc_O updated by PV). + # pipeline_sm_stats0/1 — softmax→correction stage-local stats. + # This avoids the per-warp NamedBarrier used by + # the BSA reference while preserving the same + # first/rescale/final signal sequence. + # pipeline_o_epi — correction→epilogue warp 13 (final O ready). + # ------------------------------------------------------------------ + softmax_correction_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE + * (self.warps_per_group + self.warps_per_group) # = 256 + ) + correction_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group # = 128 + ) + epilogue_warp_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE # warp 13 = 32 threads + ) + + pipeline_s_p_o = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_S_full_P_full_O_rescaled.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=softmax_correction_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p_lastsplit = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_P_full_lastsplit.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_warps, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_o_acc = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_O_full.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=correction_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_sm_stats0 = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_softmax_stats0.data_ptr(), + num_stages=1, + producer_group=softmax_threads, + consumer_group=correction_threads, + defer_sync=True, + ) + pipeline_sm_stats1 = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_softmax_stats1.data_ptr(), + num_stages=1, + producer_group=softmax_threads, + consumer_group=correction_threads, + defer_sync=True, + ) + pipeline_o_epi = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_O_epi.data_ptr(), + num_stages=self.s_stage, + producer_group=correction_threads, + consumer_group=epilogue_warp_threads, + defer_sync=True, + ) + + # Fence mbar init across all regular pipelines. CLC pipeline setup + # follows the BSA ordering: arrive after mbar init, create scheduler + # state, then wait before TMEM allocation and role dispatch. + pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + + if const_expr(self.use_clc_scheduler): + clc_response_ptr = storage.clc_response.data_ptr() + clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() + clc_pipeline_producer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread + ) + num_clc_consumer_warps = ( + self.threads_per_cta // cute.arch.WARP_SIZE + ) * self.cta_group_size + clc_pipeline_consumer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread, + cute.arch.WARP_SIZE * num_clc_consumer_warps, + ) + clc_pipeline = cutlass_pipeline.PipelineClcFetchAsync.create( + barrier_storage=clc_mbar_ptr, + num_stages=self.sched_stages, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ) + tile_scheduler = self.tile_scheduler_cls.create( + tile_sched_params, clc_response_ptr=clc_response_ptr + ) + clc_consumer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, + self.sched_stages, + ) + tile_scheduler.set_clc_pipeline( + clc_pipeline, clc_consumer_state) + else: + clc_pipeline = None + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + + pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + # Single load warp issues Q + K + V TMA serially; no inter-warp + # broadcast / Q-load WG barrier needed (the BSA-aligned layout + # collapses the previous 4-warp Q-load fan-out into one warp). + + # ------------------------------------------------------------------ + # Phase 1.10.3: pre-dispatch TMEM partitions for softmax read/write. + # Mirrors prefill softmax body setup + # (src/sm100/fwd/atten_fwd.py:807-829, 1891-1921). Built once across + # all warps so each softmax WG can take its stage slice. + # ------------------------------------------------------------------ + thr_mma_qk_pre = tiled_mma_qk.get_slice(0) + qk_acc_shape_pre = thr_mma_qk_pre.partition_shape_C( + self.mma_tiler_qk[:2]) + tStS_base_pre = thr_mma_qk_pre.make_fragment_C(qk_acc_shape_pre) + tStS_pre = cute.make_tensor( + tStS_base_pre.iterator, + cute.append( + tStS_base_pre.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tScS_pre = thr_mma_qk_pre.partition_C( + cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScS_pre = tScS_pre[(None, None), 0, 0] + # fp8 P occupies n_block_size * fp8_width / fp32_width fp32 cols. + tilePlikeFP32 = const_expr( + self.mma_tiler_qk[1] * self.q_dtype.width // Float32.width) + tmem_load_atom_pre = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype, + ) + # Repetition(8) gives ``tStP_r2t.shape[2] = tilePlikeFP32 / 8 = 4`` + # chunks for fp8 (tilePlikeFP32=32), with each chunk publishing + # 8 fp32 cols = 32 fp8 cols = exactly one PV ``f8f6f4`` K=32 + # segment. ``split_idx = 4 * 3N/4 / N = 3`` aligns the early + # publish edge to the producer/consumer K boundary. Larger + # Repetition (e.g. 16) would coarsen shape[2] to 2 and force + # split_idx to floor to 1, publishing only N/2 of P before MMA's + # first three K=32 segments need cols 0..3N/4 — that mismatch is + # the NaN source the workaround used to dodge with split=N/2. + tmem_store_atom_pre = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), + Float32, + ) + tmem_store_vec_atom_pre = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + tmem_load_vec_atom_pre = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + + # ------------------------------------------------------------------ + # Warp role dispatch. Bodies are filled in Phase 1.3-1.9: + # softmax WG 0/1 (warps 0-3, 4-7) — softmax + P fp32->fp8 convert + # store / Q-load WG (warps 8-11) — Q TMA gather + epilogue store + # MMA warp (warp 12) — UTCMMA QK + PV issue + # correction WG (warps 8-11) — per-page acc_O rescale + epilogue + # MMA warp (warp 12) — UTCMMA QK + PV issue + # spare warp (warp 13) — empty / future CLC scheduler + # load warp (warp 14) — serial Q + K + V TMA loads + # empty warp (warp 15) — register-budget reserve + # ------------------------------------------------------------------ + is_softmax0_warp = ( + warp_idx >= Int32(self.softmax0_warp_base) + and warp_idx < Int32(self.softmax1_warp_base) + ) + is_softmax1_warp = ( + warp_idx >= Int32(self.softmax1_warp_base) + and warp_idx < Int32(self.correction_warp_base) + ) + is_correction_warp = ( + warp_idx >= Int32(self.correction_warp_base) + and warp_idx < Int32(self.mma_warp_id) + ) + is_mma_warp = warp_idx == Int32(self.mma_warp_id) + is_spare_warp = warp_idx == Int32(self.spare_warp_id) + is_load_warp = warp_idx == Int32(self.load_warp_id) + is_empty_warp = warp_idx == Int32(self.empty_warp_id) + + if const_expr(self.use_clc_scheduler): + if warp_idx == Int32(self.clc_scheduler_warp_id): + cute.arch.setmaxregister_decrease(self.num_regs_empty) + self.clc_scheduler_warp(clc_pipeline, tile_scheduler) + is_empty_warp = False + + if is_softmax0_warp: + cute.arch.setmaxregister_increase(self.num_regs_softmax) + tmem.wait_for_alloc() + tmem_ptr_wg0 = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr_wg0 + self.softmax_loop( + 0, + self.softmax0_warp_base, + softmax_scale_log2, + tStS_pre, + tScS_pre, + tilePlikeFP32, + tmem_load_atom_pre, + tmem_store_atom_pre, + tmem_store_vec_atom_pre, + thr_mma_qk_pre, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_sm_stats0, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + seqlen_q, + page_size, + kv_chunk_size_pages, + ) + tmem_alloc_barrier.arrive() + + if is_softmax1_warp: + cute.arch.setmaxregister_increase(self.num_regs_softmax) + tmem.wait_for_alloc() + tmem_ptr_wg1 = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr_wg1 + self.softmax_loop( + 1, + self.softmax1_warp_base, + softmax_scale_log2, + tStS_pre, + tScS_pre, + tilePlikeFP32, + tmem_load_atom_pre, + tmem_store_atom_pre, + tmem_store_vec_atom_pre, + thr_mma_qk_pre, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_sm_stats1, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + seqlen_q, + page_size, + kv_chunk_size_pages, + ) + tmem_alloc_barrier.arrive() + + if is_correction_warp: + cute.arch.setmaxregister_decrease(self.num_regs_correction) + # Participate in TmemPtr handshake so the MMA warp can free. + tmem.wait_for_alloc() + tmem_ptr_corr = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr_corr + + self.correction_loop( + tiled_mma_pv, + tStS_pre, + tScS_pre, + tmem_load_vec_atom_pre, + pipeline_s_p_o, + pipeline_sm_stats0, + pipeline_sm_stats1, + pipeline_o_acc, + pipeline_o_epi, + sO, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mSplitCounts, + mOIndptr, + mLSE, + mLSE_partial, + mBlockValidMask, + tile_scheduler, + seqlen_q, + page_size, + kv_chunk_size_pages, + num_heads_kv, + softmax_scale_log2, + ) + tmem_alloc_barrier.arrive() + + if is_spare_warp: + cute.arch.setmaxregister_decrease(self.num_regs_epilogue) + self.epilogue_s2g( + mO_tma, + sO, + tma_atom_O, + pipeline_o_epi, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mOIndptr, + mBlockValidMask, + tile_scheduler, + seqlen_q, + ) + + if is_load_warp: + self.load( + tiled_mma_qk, + tiled_mma_pv, + mQ, + mK_paged, + mV_paged, + sQ, + sK, + sV, + mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_q, + pipeline_kv, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + page_size, + kv_chunk_size_pages, + ) + + if is_empty_warp: + cute.arch.setmaxregister_decrease(self.num_regs_empty) + + if is_mma_warp: + cute.arch.setmaxregister_decrease(self.num_regs_mma) + # ---------------------------------------------------------------- + # MMA warp — Phase 1.6: QK fp8×fp8→fp32 UMMA. Phase 1.10.1 now + # wraps the body in the real TMEM allocator lifecycle: + # tmem.allocate(cols) -> wait_for_alloc -> retrieve_ptr + # -> ... QK work ... + # -> relinquish_alloc_permit -> tmem_alloc_barrier.arrive_and_wait + # -> free(ptr, cols) + # Softmax WG 0/1 participate via wait_for_alloc + retrieve_ptr + + # tmem_alloc_barrier.arrive (4+4+1 = 9 warps). + # ---------------------------------------------------------------- + tmem.allocate(tmem_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr # consumed by gemm_pv via raw TMEM offsets + + self.mma( + sQ, + sK, + sV, + tP_layout, + tiled_mma_qk, + tiled_mma_pv, + pipeline_q, + pipeline_kv, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_o_acc, + mRequestIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + page_size, + kv_chunk_size_pages, + ) + + # Phase 1.10.1: TMEM allocator teardown. + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr, num_columns=tmem_cols) + + @cute.jit + def clc_scheduler_warp( + self, + clc_pipeline: cutlass_pipeline.PipelineClcFetchAsync, + tile_scheduler: DecodeTileScheduler, + ) -> None: + clc_producer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Producer, + self.sched_stages, + ) + clc_consumer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, + self.sched_stages, + ) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + clc_pipeline.producer_acquire(clc_producer_state) + mbarrier_addr = clc_pipeline.producer_get_barrier( + clc_producer_state) + tile_scheduler.advance_to_next_work( + mbarrier_addr=mbarrier_addr, + response_stage=clc_producer_state.index, + ) + clc_producer_state.advance() + + clc_pipeline.consumer_wait(clc_consumer_state) + work_tile = tile_scheduler.get_current_work( + response_stage=clc_consumer_state.index) + clc_pipeline.consumer_release(clc_consumer_state) + clc_consumer_state.advance() + clc_pipeline.producer_tail(clc_producer_state) + + @cute.jit + def correction_loop( + self, + tiled_mma_pv: cute.TiledMma, + tStS_pre: cute.Tensor, + tScS_pre: cute.Tensor, + tmem_load_vec_atom_pre: cute.CopyAtom, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_sm_stats0: pipeline.PipelineAsync, + pipeline_sm_stats1: pipeline.PipelineAsync, + pipeline_o_acc: pipeline.PipelineAsync, + pipeline_o_epi: pipeline.PipelineAsync, + sO: cute.Tensor, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mSplitCounts: cute.Tensor, + mOIndptr: cute.Tensor, + mLSE: cute.Tensor, + mLSE_partial: Optional[cute.Tensor], + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + num_heads_kv: Int32, + softmax_scale_log2: Float32, + ) -> None: + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + warp_idx_in_wg_corr = warp_idx - Int32(self.correction_warp_base) + group_tidx_corr = ( + warp_idx_in_wg_corr * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + + # First iter: no correction is required. Notify MMA that the + # initial O slots are available, matching BSA's correction_loop. + for stage_init in cutlass.range_constexpr(self.s_stage): + pipeline_s_p_o.consumer_release_w_index(Int32(stage_init)) + + o_corr_consumer_phase = Int32(0) + sm_stats0_consumer_phase = Int32(0) + sm_stats1_consumer_phase = Int32(0) + corr_epi_producer_phase = Int32(1) + + thr0_rs = tiled_mma_pv.get_slice(0) + pv_acc_shape_rs_c = thr0_rs.partition_shape_C( + self.mma_tiler_pv[:2]) + tOtO_base_rs_c = thr0_rs.make_fragment_C(pv_acc_shape_rs_c) + tOtO_rs_c = cute.make_tensor( + tOtO_base_rs_c.iterator + Int32(self.tmem_o_offset), + cute.append( + tOtO_base_rs_c.layout, + cute.make_layout( + (self.s_stage,), + stride=(self.tmem_o_stage_stride,), + ), + ), + ) + tScS_vec_layout_corr = cute.composition( + tScS_pre.layout, cute.make_layout((self.m_block_size, 2))) + tScS_vec_corr = cute.make_tensor( + tScS_pre.iterator, tScS_vec_layout_corr) + tSAcc_corr0 = tStS_pre[(None, None), 0, 0, 0] + tSAcc_corr1 = tStS_pre[(None, None), 0, 0, 1] + tStS_vec0_layout_corr = cute.composition( + tSAcc_corr0.layout, cute.make_layout((self.m_block_size, 2))) + tStS_vec1_layout_corr = cute.composition( + tSAcc_corr1.layout, cute.make_layout((self.m_block_size, 2))) + tStStats0_t2r_src = cute.make_tensor( + tSAcc_corr0.iterator, tStS_vec0_layout_corr) + tStStats1_t2r_src = cute.make_tensor( + tSAcc_corr1.iterator, tStS_vec1_layout_corr) + thr_tmem_load_vec = tcgen05.make_tmem_copy( + tmem_load_vec_atom_pre, + tStStats0_t2r_src, + ).get_slice(group_tidx_corr) + tStStats0_t2r = thr_tmem_load_vec.partition_S(tStStats0_t2r_src) + tStStats1_t2r = thr_tmem_load_vec.partition_S(tStStats1_t2r_src) + tScStats_t2r = thr_tmem_load_vec.partition_D(tScS_vec_corr) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, head_kv_idx, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx_corr = mRequestIndices[work_idx] + qo_tile_corr = mQoTileIndices[work_idx] + seqused_k_corr = mSeqUsedK[batch_idx_corr] + split_idx_corr = mKvTileIndices[work_idx] + kv_pages_corr = ( + seqused_k_corr + page_size - Int32(1)) // page_size + kv_page_begin_corr = split_idx_corr * kv_chunk_size_pages + kv_page_end_corr = cutlass.min( + kv_pages_corr, + kv_page_begin_corr + kv_chunk_size_pages, + ) + page_count_corr = kv_page_end_corr - kv_page_begin_corr + block_iter_count_corr = ( + page_count_corr + Int32(1)) & ~Int32(1) + stage0_count_corr = block_iter_count_corr // Int32(2) + stage1_count_corr = block_iter_count_corr // Int32(2) + + if stage0_count_corr > Int32(0): + pipeline_sm_stats0.consumer_wait_w_index_phase( + Int32(0), sm_stats0_consumer_phase) + sm_stats0_consumer_phase = ( + sm_stats0_consumer_phase ^ Int32(1)) + pipeline_sm_stats0.consumer_release_w_index(Int32(0)) + if stage1_count_corr > Int32(0): + pipeline_sm_stats1.consumer_wait_w_index_phase( + Int32(0), sm_stats1_consumer_phase) + sm_stats1_consumer_phase = ( + sm_stats1_consumer_phase ^ Int32(1)) + pipeline_sm_stats1.consumer_release_w_index(Int32(0)) + + for page_rel_corr in cutlass.range( + Int32(self.s_stage), block_iter_count_corr, unroll=1 + ): + # sm_stats[0] now holds the deferred-exp2 log2-delta: + # 0.0 means "no rescale needed", a negative value is the + # raw delta that needs exp2 to become a true scale factor. + if (page_rel_corr & Int32(1)) == Int32(0): + pipeline_sm_stats0.consumer_wait_w_index_phase( + Int32(0), sm_stats0_consumer_phase) + sm_stats0_consumer_phase = ( + sm_stats0_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats0_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + scale_corr_log2 = tSrStats[0] + pipeline_sm_stats0.consumer_release_w_index(Int32(0)) + should_rescale = ( + cute.arch.vote_ballot_sync( + scale_corr_log2 < Float32(0.0)) != 0 + ) + if should_rescale: + scale_corr = cute.math.exp2( + scale_corr_log2, fastmath=True) + self.correction_rescale( + tiled_mma_pv, + tOtO_rs_c[None, None, None, 0], + group_tidx_corr, + scale_corr, + ) + pipeline_s_p_o.consumer_release_w_index(Int32(0)) + else: + pipeline_sm_stats1.consumer_wait_w_index_phase( + Int32(0), sm_stats1_consumer_phase) + sm_stats1_consumer_phase = ( + sm_stats1_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats1_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + scale_corr_log2 = tSrStats[0] + pipeline_sm_stats1.consumer_release_w_index(Int32(0)) + should_rescale = ( + cute.arch.vote_ballot_sync( + scale_corr_log2 < Float32(0.0)) != 0 + ) + if should_rescale: + scale_corr = cute.math.exp2( + scale_corr_log2, fastmath=True) + self.correction_rescale( + tiled_mma_pv, + tOtO_rs_c[None, None, None, 1], + group_tidx_corr, + scale_corr, + ) + pipeline_s_p_o.consumer_release_w_index(Int32(1)) + + for stage_wait in cutlass.range_constexpr(self.s_stage): + stage_count_wait = ( + stage0_count_corr + if const_expr(stage_wait == 0) + else stage1_count_corr + ) + if stage_count_wait > Int32(0): + pipeline_o_acc.consumer_wait_w_index_phase( + Int32(stage_wait), o_corr_consumer_phase) + + row_sum0 = Float32(0.0) + row_sum1 = Float32(0.0) + row_max0 = -Float32.inf + row_max1 = -Float32.inf + for stage_final in cutlass.range_constexpr(self.s_stage): + if const_expr(stage_final == 0): + pipeline_sm_stats0.consumer_wait_w_index_phase( + Int32(0), sm_stats0_consumer_phase) + sm_stats0_consumer_phase = ( + sm_stats0_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats0_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + row_sum0 = tSrStats[0] + row_max0 = tSrStats[1] + pipeline_sm_stats0.consumer_release_w_index(Int32(0)) + else: + pipeline_sm_stats1.consumer_wait_w_index_phase( + Int32(0), sm_stats1_consumer_phase) + sm_stats1_consumer_phase = ( + sm_stats1_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats1_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + row_sum1 = tSrStats[0] + row_max1 = tSrStats[1] + pipeline_sm_stats1.consumer_release_w_index(Int32(0)) + + zero0 = row_sum0 == Float32(0.0) or row_sum0 != row_sum0 + zero1 = row_sum1 == Float32(0.0) or row_sum1 != row_sum1 + rm0 = -Float32.inf if zero0 else row_max0 + rm1 = -Float32.inf if zero1 else row_max1 + row_max_comb = cutlass.max(rm0, rm1) + row_max_safe = ( + Float32(0.0) + if row_max_comb == -Float32.inf + else row_max_comb + ) + scale0 = ( + Float32(0.0) + if zero0 + else cute.math.exp2( + (rm0 - row_max_safe) * softmax_scale_log2, + fastmath=True, + ) + ) + scale1 = ( + Float32(0.0) + if zero1 + else cute.math.exp2( + (rm1 - row_max_safe) * softmax_scale_log2, + fastmath=True, + ) + ) + row_sum_comb = row_sum0 * scale0 + row_sum1 * scale1 + combined_zero_or_nan = ( + row_sum_comb == Float32(0.0) + or row_sum_comb != row_sum_comb + ) + inv_sum = cute.arch.rcp_approx( + Float32(1.0) + if combined_zero_or_nan else row_sum_comb) + final_scale0 = scale0 * inv_sum + final_scale1 = scale1 * inv_sum + + pipeline_o_epi.producer_acquire_w_index_phase( + Int32(0), corr_epi_producer_phase) + self.correction_epilogue_combine( + tiled_mma_pv, + sO[None, None, 0], + group_tidx_corr, + final_scale0, + final_scale1, + ) + + if const_expr(self.write_lse or self.split_kv): + if group_tidx_corr < Int32(self.m_block_size): + is_bad_lse = ( + row_sum_comb == Float32(0.0) + or row_sum_comb != row_sum_comb + ) + LN2 = Float32(math.log(2.0)) + lse_val = ( + -Float32.inf if is_bad_lse + else ( + row_max_safe * softmax_scale_log2 + + cute.math.log2(row_sum_comb, fastmath=True) + ) * LN2 + ) + tok_lse = group_tidx_corr // Int32(self.qhead_per_kv) + if tok_lse < seqlen_q: + h_in_kv_lse = ( + group_tidx_corr + - tok_lse * Int32(self.qhead_per_kv)) + q_idx_lse = ( + qo_tile_corr * Int32(self.q_tokens_per_group) + + tok_lse + ) + h_abs_lse = ( + head_kv_idx * Int32(self.qhead_per_kv) + + h_in_kv_lse + ) + if const_expr(self.split_kv): + q_tokens_per_group = Int32( + self.q_tokens_per_group) + q_stride_partial = ( + (seqlen_q + q_tokens_per_group - Int32(1)) + // q_tokens_per_group + ) * q_tokens_per_group + partial_row_lse = ( + mOIndptr[batch_idx_corr] + + split_idx_corr * q_stride_partial + + q_idx_lse + ) + mLSE_partial[ + partial_row_lse, h_abs_lse] = lse_val + else: + q_abs_lse = ( + batch_idx_corr * seqlen_q + q_idx_lse) + mLSE[q_abs_lse, h_abs_lse] = lse_val + + for stage_release in cutlass.range_constexpr(self.s_stage): + stage_count_release = ( + stage0_count_corr + if const_expr(stage_release == 0) + else stage1_count_corr + ) + if stage_count_release > Int32(0): + pipeline_s_p_o.consumer_release_w_index( + Int32(stage_release)) + pipeline_o_acc.consumer_release_w_index( + Int32(stage_release)) + if block_iter_count_corr > Int32(0): + o_corr_consumer_phase = ( + o_corr_consumer_phase ^ Int32(1)) + + pipeline_o_epi.producer_commit_w_index(Int32(0)) + corr_epi_producer_phase = ( + corr_epi_producer_phase ^ Int32(1)) + + work_tile = tile_scheduler.consumer_advance() + + pipeline_o_epi.producer_acquire_w_index_phase( + Int32(self.q_stage - 1), corr_epi_producer_phase) + + @cute.jit + def epilogue_s2g( + self, + mO_tma: cute.Tensor, + sO: cute.Tensor, + tma_atom_O: cute.CopyAtom, + pipeline_o_epi: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mOIndptr: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + seqlen_q: Int32, + ) -> None: + epi_consumer_phase = Int32(0) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, head_kv_idx, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx = mRequestIndices[work_idx] + qo_tile = mQoTileIndices[work_idx] + split_idx = mKvTileIndices[work_idx] + + pipeline_o_epi.consumer_wait_w_index_phase( + Int32(0), epi_consumer_phase) + q_tokens_per_group = Int32(self.q_tokens_per_group) + gO = cute.local_tile( + mO_tma[None, None, head_kv_idx], + self.epi_tile, + (None, 0), + ) + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, cute.make_layout(1), sO, gO) + if const_expr(not self.split_kv): + q_abs = ( + batch_idx * seqlen_q + + qo_tile * q_tokens_per_group + ) + dst_idx = q_abs // q_tokens_per_group + else: + q_stride_partial = ( + (seqlen_q + q_tokens_per_group - Int32(1)) + // q_tokens_per_group + ) * q_tokens_per_group + partial_row = ( + mOIndptr[batch_idx] + + split_idx * q_stride_partial + + qo_tile * q_tokens_per_group + ) + dst_idx = partial_row // q_tokens_per_group + store_O(src_idx=Int32(0), dst_idx=dst_idx) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0) + pipeline_o_epi.consumer_release_w_index(Int32(0)) + epi_consumer_phase = epi_consumer_phase ^ Int32(1) + + work_tile = tile_scheduler.consumer_advance() + + @cute.jit + def correction_epilogue_combine( + self, + tiled_mma_pv: cute.TiledMma, + sO: cute.Tensor, + tidx: Int32, + scale0: Float32, + scale1: Float32, + ) -> None: + thr_mma = tiled_mma_pv.get_slice(0) + pv_acc_shape = thr_mma.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO_base = thr_mma.make_fragment_C(pv_acc_shape) + tOtO = cute.make_tensor( + tOtO_base.iterator + Int32(self.tmem_o_offset), + cute.append( + tOtO_base.layout, + cute.make_layout( + (self.s_stage,), + stride=(self.tmem_o_stage_stride,), + ), + ), + ) + tOsO = thr_mma.get_slice(0).partition_C(sO) + tOcO_full = thr_mma.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + corr_tile_size: cutlass.Constexpr[int] = ( + 8 * 32 // self.o_dtype.width + ) + tOsO_i = cute.logical_divide( + tOsO, + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + tOcO_i = cute.logical_divide( + tOcO_full, + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + tOtO0_i = cute.logical_divide( + tOtO[None, None, None, 0], + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + tOtO1_i = cute.logical_divide( + tOtO[None, None, None, 1], + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + epi_subtile = (self.epi_tile[0], corr_tile_size) + tmem_load_atom = sm100_utils.get_tmem_load_op( + self.mma_tiler_pv, + self.o_layout, + self.o_dtype, + self.pv_acc_dtype, + epi_subtile, + use_2cta_instrs=self.use_2cta_instrs, + ) + tiled_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tOtO0_i[(None, None), 0]) + thr_tmem_load = tiled_tmem_load.get_slice(tidx) + smem_copy_atom = sm100_utils.get_smem_store_op( + self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load) + tiled_smem_store = cute.make_tiled_copy_D( + smem_copy_atom, tiled_tmem_load) + tOtO0_t2r = thr_tmem_load.partition_S( + tOtO0_i[(None, None), None]) + tOtO1_t2r = thr_tmem_load.partition_S( + tOtO1_i[(None, None), None]) + tOsO_s2r = copy_utils.partition_D_position_independent( + thr_tmem_load, tOsO_i[(None, None), None]) + tOcO_t2r = thr_tmem_load.partition_D( + tOcO_i[(None, None), None]) + + for col_pass_idx in cutlass.range( + self.head_dim // corr_tile_size, unroll_full=True): + tOtO0_t2r_i = tOtO0_t2r[None, 0, 0, col_pass_idx] + tOtO1_t2r_i = tOtO1_t2r[None, 0, 0, col_pass_idx] + tOsO_r2s_i = tOsO_s2r[None, 0, 0, col_pass_idx] + frg_shape = tOcO_t2r[None, 0, 0, col_pass_idx].shape + tOrO0_frg = cute.make_fragment(frg_shape, self.pv_acc_dtype) + tOrO1_frg = cute.make_fragment(frg_shape, self.pv_acc_dtype) + is_zero_output = ( + scale0 == Float32(0.0) and scale1 == Float32(0.0) + ) + if not is_zero_output: + cute.copy(tiled_tmem_load, tOtO0_t2r_i, tOrO0_frg) + cute.copy(tiled_tmem_load, tOtO1_t2r_i, tOrO1_frg) + for j in cutlass.range( + 0, cute.size(tOrO0_frg), 2, unroll_full=True + ): + o0_a, o0_b = cute.arch.mul_packed_f32x2( + (tOrO0_frg[j], tOrO0_frg[j + 1]), + (scale0, scale0), + ) + o1_a, o1_b = cute.arch.mul_packed_f32x2( + (tOrO1_frg[j], tOrO1_frg[j + 1]), + (scale1, scale1), + ) + tOrO0_frg[j], tOrO0_frg[j + 1] = ( + cute.arch.add_packed_f32x2( + (o0_a, o0_b), (o1_a, o1_b)) + ) + else: + tOrO0_frg.fill(Float32(0.0)) + copy_utils.cvt_copy(tiled_smem_store, tOrO0_frg, tOsO_r2s_i) + cute.arch.fence_view_async_shared() + + @cute.jit + def correction_rescale( + self, + tiled_mma_pv: cute.TiledMma, + tOtO: cute.Tensor, + tidx: Int32, + scale: Float32, + ) -> None: + thr_mma = tiled_mma_pv.get_slice(0) + tOcO = thr_mma.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + corr_tile_size: cutlass.Constexpr[int] = 16 + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp( + tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp( + tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tOtO_i = cute.composition( + tOtO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOcO_i = cute.composition( + tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tOtO_i).get_slice(tidx) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom, tOtO_i).get_slice(tidx) + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i) + tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape + tOtO_r2t = thr_tmem_store.partition_D(tOtO_i) + + frg_count: cutlass.Constexpr[int] = self.head_dim // corr_tile_size + for fi in cutlass.range_constexpr(frg_count): + tOrO_frg = cute.make_fragment( + tOrO_t2r_shape, self.pv_acc_dtype) + tOtO_t2r_i = cute.make_tensor( + tOtO_t2r.iterator + fi * corr_tile_size, + tOtO_t2r.layout, + ) + cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg) + for j in cutlass.range( + 0, cute.size(tOrO_frg), 2, unroll_full=True + ): + tOrO_frg[j], tOrO_frg[j + 1] = ( + cute.arch.mul_packed_f32x2( + (tOrO_frg[j], tOrO_frg[j + 1]), + (scale, scale), + ) + ) + tOtO_r2t_i = cute.make_tensor( + tOtO_r2t.iterator + fi * corr_tile_size, + tOtO_r2t.layout, + ) + cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i) + cute.arch.fence_view_async_tmem_store() + + @cute.jit + def mma( + self, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tP_layout: cute.ComposedLayout, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + pipeline_q: pipeline.PipelineAsync, + pipeline_kv: pipeline.PipelineAsync, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_o_acc: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + page_size: Int32, + kv_chunk_size_pages: Int32, + ) -> None: + thr_mma_qk = tiled_mma_qk.get_slice(0) + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tSrQ0_layout = tSrQ[None, None, None, 0].layout + tSrK0_layout = tSrK[None, None, None, 0].layout + qk_mma_op = tiled_mma_qk.op + q_smem_base = sm100_desc.smem_desc_base_from_tensor( + sQ, sm100_desc.Major.K) + k_smem_base = sm100_desc.smem_desc_base_from_tensor( + sK, sm100_desc.Major.K) + q_smem_start = sm100_desc.make_smem_desc_start_addr( + sQ[None, None, None, 0].iterator) + sm100_helpers.declare_ptx_smem_desc( + q_smem_start, q_smem_base, tSrQ0_layout, + var_name_prefix="decode_q_smem_desc", + ) + sm100_helpers.declare_ptx_idesc( + qk_mma_op, var_name="decode_qk_idesc") + gemm_qk = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0_layout, + smem_var_name_prefix="decode_q_smem_desc", + idesc_var_name="decode_qk_idesc", + smem_offset=0, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + + thr_mma_pv = tiled_mma_pv.get_slice(0) + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_base = thr_mma_qk.make_fragment_C(qk_acc_shape) + tStS = cute.make_tensor( + tStS_base.iterator, + cute.append( + tStS_base.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP_base = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + tP_width_ratio = const_expr(Float32.width // self.v_dtype.width) + tP_stage_stride = const_expr( + self.tmem_stage_stride * tP_width_ratio) + tOrP = cute.make_tensor( + tOrP_base.iterator + self.tmem_p_offset * tP_width_ratio, + cute.append( + tOrP_base.layout, + cute.make_layout((self.s_stage,), stride=(tP_stage_stride,)), + ), + ) + tOrV = tiled_mma_pv.make_fragment_B(sV) + pv_mma_op = tiled_mma_pv.op + sm100_helpers.declare_ptx_idesc( + pv_mma_op, var_name="decode_pv_idesc") + + mma_q_consumer_phase = Int32(0) + mma_kv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kv_stage) + phase_s0 = Int32(0) + phase_s1 = Int32(0) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, _, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx_mma = mRequestIndices[work_idx] + split_idx_mma = mKvTileIndices[work_idx] + seqused_k_mma = mSeqUsedK[batch_idx_mma] + kv_pages_mma = ( + seqused_k_mma + page_size - Int32(1)) // page_size + kv_page_begin_mma = split_idx_mma * kv_chunk_size_pages + kv_page_end_mma = cutlass.min( + kv_pages_mma, + kv_page_begin_mma + kv_chunk_size_pages, + ) + page_count_mma = kv_page_end_mma - kv_page_begin_mma + block_iter_count_mma = ( + page_count_mma + Int32(1)) & ~Int32(1) + + pipeline_q.consumer_wait_w_index_phase( + Int32(0), mma_q_consumer_phase) + mma_q_consumer_phase = mma_q_consumer_phase ^ Int32(1) + if block_iter_count_mma > Int32(0): + pipeline_kv.consumer_wait(mma_kv_consumer_state) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[ + None, None, None, + mma_kv_consumer_state.index, + ].iterator + ) + gemm_qk( + Int32(self.tmem_s_offset), + smem_desc_start_b=k_smem_start, + ) + pipeline_s_p_o.producer_commit_w_index(Int32(0)) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + if block_iter_count_mma > Int32(1): + pipeline_kv.consumer_wait(mma_kv_consumer_state) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[ + None, None, None, + mma_kv_consumer_state.index, + ].iterator + ) + gemm_qk( + Int32(self.tmem_s_offset) + + Int32(self.tmem_stage_stride), + smem_desc_start_b=k_smem_start, + ) + pipeline_s_p_o.producer_commit_w_index(Int32(1)) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + if block_iter_count_mma > Int32(self.s_stage): + for page_rel_pv in cutlass.range( + Int32(0), + block_iter_count_mma - Int32(self.s_stage), + unroll=1, + ): + pv_slot = page_rel_pv & Int32(1) + pv_stage_iter = page_rel_pv // Int32(self.s_stage) + pv_phase = phase_s0 + if pv_slot != Int32(0): + pv_phase = phase_s1 + pipeline_s_p_o.producer_acquire_w_index_phase( + pv_slot, pv_phase) + pipeline_kv.consumer_wait(mma_kv_consumer_state) + v_idx = mma_kv_consumer_state.index + sm100_helpers.gemm_ptx_partial( + pv_mma_op, + Int32(self.tmem_o_offset) + + pv_slot * Int32(self.tmem_o_stage_stride), + tOrP[None, None, None, pv_slot], + tOrV[None, None, None, v_idx], + sA=None, + sB=sV[None, None, None, v_idx], + tA_addr=( + Int32(self.tmem_p_offset) + + pv_slot * Int32(self.tmem_stage_stride) + ), + zero_init=pv_stage_iter == Int32(0), + mbar_ptr=( + pipeline_p_lastsplit + .sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None + ), + mbar_phase=( + pv_phase + if self.split_P_arrive > 0 else None), + split_arrive=( + self.split_P_arrive + if self.split_P_arrive > 0 else None + ), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + if pv_slot == Int32(0): + phase_s0 = phase_s0 ^ Int32(1) + else: + phase_s1 = phase_s1 ^ Int32(1) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + page_rel_qk = page_rel_pv + Int32(self.s_stage) + qk_slot = page_rel_qk & Int32(1) + pipeline_kv.consumer_wait(mma_kv_consumer_state) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[ + None, None, None, + mma_kv_consumer_state.index, + ].iterator + ) + gemm_qk( + Int32(self.tmem_s_offset) + + qk_slot * Int32(self.tmem_stage_stride), + smem_desc_start_b=k_smem_start, + ) + pipeline_s_p_o.producer_commit_w_index(qk_slot) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + pipeline_q.consumer_release_w_index(Int32(0)) + + if block_iter_count_mma > Int32(0): + page_rel_epi_begin = cutlass.max( + Int32(0), + block_iter_count_mma - Int32(self.s_stage), + ) + for page_rel_epi in cutlass.range( + page_rel_epi_begin, block_iter_count_mma, unroll=1 + ): + pv_slot = page_rel_epi & Int32(1) + pv_stage_iter = page_rel_epi // Int32(self.s_stage) + pv_phase = phase_s0 + if pv_slot != Int32(0): + pv_phase = phase_s1 + pipeline_s_p_o.producer_acquire_w_index_phase( + pv_slot, pv_phase) + pipeline_kv.consumer_wait(mma_kv_consumer_state) + v_idx = mma_kv_consumer_state.index + sm100_helpers.gemm_ptx_partial( + pv_mma_op, + Int32(self.tmem_o_offset) + + pv_slot * Int32(self.tmem_o_stage_stride), + tOrP[None, None, None, pv_slot], + tOrV[None, None, None, v_idx], + sA=None, + sB=sV[None, None, None, v_idx], + tA_addr=( + Int32(self.tmem_p_offset) + + pv_slot * Int32(self.tmem_stage_stride) + ), + zero_init=pv_stage_iter == Int32(0), + mbar_ptr=( + pipeline_p_lastsplit + .sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None + ), + mbar_phase=( + pv_phase + if self.split_P_arrive > 0 else None), + split_arrive=( + self.split_P_arrive + if self.split_P_arrive > 0 else None + ), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + pipeline_o_acc.producer_commit_w_index(pv_slot) + if pv_slot == Int32(0): + phase_s0 = phase_s0 ^ Int32(1) + else: + phase_s1 = phase_s1 ^ Int32(1) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + work_tile = tile_scheduler.consumer_advance() + + @cute.jit + def softmax_loop( + self, + stage: cutlass.Constexpr[int], + warp_base: cutlass.Constexpr[int], + softmax_scale_log2: Float32, + tStS_pre: cute.Tensor, + tScS_pre: cute.Tensor, + tilePlikeFP32: cutlass.Constexpr[int], + tmem_load_atom_pre: cute.CopyAtom, + tmem_store_atom_pre: cute.CopyAtom, + tmem_store_vec_atom_pre: cute.CopyAtom, + thr_mma_qk_pre: cute.core.ThrMma, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + ) -> None: + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + warp_idx_in_wg = warp_idx - Int32(warp_base) + group_tidx = ( + warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + stage_i32 = Int32(stage) + + tSAcc = tStS_pre[(None, None), 0, 0, stage] + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom_pre, tSAcc).get_slice(group_tidx) + tStS_t2r = thr_tmem_load.partition_S(tSAcc) + tScS_t2r = thr_tmem_load.partition_D(tScS_pre) + tStP_layout = cute.composition( + tSAcc.layout, + cute.make_layout((self.m_block_size, tilePlikeFP32)), + ) + tStP = cute.make_tensor( + tSAcc.iterator + self.tmem_s_to_p_offset, + tStP_layout, + ) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom_pre, tStP).get_slice(group_tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + tScS_vec_layout = cute.composition( + tScS_pre.layout, cute.make_layout((self.m_block_size, 2))) + tScS_vec = cute.make_tensor(tScS_pre.iterator, tScS_vec_layout) + tStS_vec_layout = cute.composition( + tSAcc.layout, cute.make_layout((self.m_block_size, 2))) + tStStats_r2t_dst = cute.make_tensor( + tSAcc.iterator, tStS_vec_layout) + thr_tmem_store_vec = tcgen05.make_tmem_copy( + tmem_store_vec_atom_pre, + tStStats_r2t_dst, + ).get_slice(group_tidx) + tStStats_r2t = thr_tmem_store_vec.partition_D(tStStats_r2t_dst) + tScStats_r2t = thr_tmem_store_vec.partition_S(tScS_vec) + tScP_shape = ( + self.mma_tiler_qk[0] // thr_mma_qk_pre.thr_id.shape, + tilePlikeFP32, + ) + + tSrP_r2t_f32 = cute.make_rmem_tensor( + thr_tmem_store.partition_S( + cute.make_identity_tensor(tScP_shape)).shape, + Float32, + ) + s_consumer_phase = Int32(0) + sm_stats_producer_phase = Int32(1) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, _, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + softmax = SoftmaxSm100.create( + softmax_scale_log2, + rescale_threshold=self.rescale_threshold, + ) + softmax.reset() + batch_idx = mRequestIndices[work_idx] + qo_tile = mQoTileIndices[work_idx] + seqused_k = mSeqUsedK[batch_idx] + split_idx = mKvTileIndices[work_idx] + kv_pages = ( + seqused_k + page_size - Int32(1)) // page_size + kv_page_begin = split_idx * kv_chunk_size_pages + kv_page_end = cutlass.min( + kv_pages, kv_page_begin + kv_chunk_size_pages + ) + page_count = kv_page_end - kv_page_begin + block_iter_count = (page_count + Int32(1)) & ~Int32(1) + if const_expr(stage == 0): + stage_page_count = block_iter_count // Int32(2) + else: + stage_page_count = block_iter_count // Int32(2) + + seqlen_info = SeqlenInfoQK( + Int32(0), + Int32(0), + Int32(0), + Int32(0), + seqlen_q, + seqused_k, + False, + False, + False, + True, + ) + mask = AttentionMask( + self.m_block_size, + self.n_block_size, + seqlen_info, + qhead_per_kvhead_packgqa=self.qhead_per_kv, + ) + wg_count = stage_page_count + if wg_count > Int32(0): + page_rel0 = stage_i32 + page_rel0_clamped = cutlass.min( + page_rel0, page_count - Int32(1)) + page_idx_global = kv_page_end - Int32(1) - page_rel0_clamped + kv_valid_cols = cutlass.min( + Int32(self.n_block_size), + seqused_k - page_idx_global * page_size, + ) + if page_rel0 >= page_count: + kv_valid_cols = Int32(0) + sm_stats_producer_phase = self.softmax_step( + softmax, + mask, + stage_i32, + s_consumer_phase, + page_idx_global, + qo_tile, + kv_valid_cols, + tStS_t2r, + tScS_t2r, + tStP_r2t, + tSrP_r2t_f32, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_sm_stats, + group_tidx, + warp_idx_in_wg, + tStStats_r2t, + tScStats_r2t, + sm_stats_producer_phase, + is_first=True, + ) + s_consumer_phase = s_consumer_phase ^ Int32(1) + + for stage_iter in cutlass.range( + Int32(1), wg_count, unroll=1 + ): + page_rel = ( + stage_iter * Int32(self.s_stage) + stage_i32) + page_rel_clamped = cutlass.min( + page_rel, page_count - Int32(1)) + page_idx_global_n = ( + kv_page_end - Int32(1) - page_rel_clamped) + kv_valid_cols_n = cutlass.min( + Int32(self.n_block_size), + seqused_k - page_idx_global_n * page_size, + ) + # Dummy-iter analysis: with s_stage=2, the WG that + # handles stage_i32=0 only ever sees page_rel ≤ + # block_iter_count - 2 < page_count → NEVER dummy. + # The WG with stage_i32=1 sees page_rel = + # block_iter_count - 1 at its last iter, which + # equals page_count iff page_count is odd → only + # WG1 may need the runtime mask_dummy_only guard. + # Pass None for WG0 so the const_expr branch in + # softmax_step eliminates the runtime check + # entirely (compile-time disappears). + if const_expr(stage == 0): + sm_stats_producer_phase = self.softmax_step( + softmax, mask, stage_i32, s_consumer_phase, + page_idx_global_n, qo_tile, kv_valid_cols_n, + tStS_t2r, tScS_t2r, tStP_r2t, tSrP_r2t_f32, + thr_tmem_load, thr_tmem_store, thr_tmem_store_vec, + pipeline_s_p_o, pipeline_p_lastsplit, + pipeline_sm_stats, + group_tidx, warp_idx_in_wg, + tStStats_r2t, tScStats_r2t, + sm_stats_producer_phase, + is_first=False, + apply_mask=False, + # mask_dummy_only=None → no runtime check + ) + else: + is_dummy = page_rel >= page_count + if is_dummy: + kv_valid_cols_n = Int32(0) + sm_stats_producer_phase = self.softmax_step( + softmax, mask, stage_i32, s_consumer_phase, + page_idx_global_n, qo_tile, kv_valid_cols_n, + tStS_t2r, tScS_t2r, tStP_r2t, tSrP_r2t_f32, + thr_tmem_load, thr_tmem_store, thr_tmem_store_vec, + pipeline_s_p_o, pipeline_p_lastsplit, + pipeline_sm_stats, + group_tidx, warp_idx_in_wg, + tStStats_r2t, tScStats_r2t, + sm_stats_producer_phase, + is_first=False, + apply_mask=False, + mask_dummy_only=is_dummy, + ) + s_consumer_phase = s_consumer_phase ^ Int32(1) + + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + tSrStats = cute.make_rmem_tensor( + tScStats_r2t.shape, self.qk_acc_dtype) + tSrStats[0] = softmax.row_sum[0] + tSrStats[1] = softmax.row_max[0] + cute.copy(thr_tmem_store_vec, tSrStats, tStStats_r2t) + cute.arch.fence_view_async_tmem_store() + else: + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + tSrStats = cute.make_rmem_tensor( + tScStats_r2t.shape, self.qk_acc_dtype) + tSrStats[0] = Float32(0.0) + tSrStats[1] = -Float32.inf + cute.copy(thr_tmem_store_vec, tSrStats, tStStats_r2t) + cute.arch.fence_view_async_tmem_store() + pipeline_sm_stats.producer_commit_w_index(Int32(0)) + sm_stats_producer_phase = sm_stats_producer_phase ^ Int32(1) + + work_tile = tile_scheduler.consumer_advance() + + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + + @cute.jit + def softmax_step( + self, + softmax: SoftmaxSm100, + mask: AttentionMask, + stage: Int32, + s_phase: Int32, + page_idx: Int32, + qo_tile: Int32, + kv_valid_cols: Int32, + tStS_t2r: cute.Tensor, + tScS_t2r: cute.Tensor, + tStP_r2t: cute.Tensor, + tSrP_r2t_f32: cute.Tensor, + thr_tmem_load: cute.CopyAtom, + thr_tmem_store: cute.CopyAtom, + thr_tmem_store_vec: cute.CopyAtom, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + group_tidx: Int32, + warp_idx_in_wg: Int32, + tStStats_r2t: cute.Tensor, + tScStats_r2t: cute.Tensor, + sm_stats_producer_phase: Int32, + is_first: cutlass.Constexpr[bool], + apply_mask: cutlass.Constexpr[bool] = True, + mask_dummy_only: Optional[cutlass.Boolean] = None, + ) -> Int32: + # apply_mask=False is the inner-page fast path: skip both the seqlen + # bounds check and the causal-diagonal check, which together cost ~15 + # cyc per iter on the producer pre-publication critical path that + # gates correction WG's consumer_wait (top long_scoreboard PC in NCU). + # Callers must only set apply_mask=False when they can prove the tile + # is fully unmasked (no partial-page seqlen tail, no causal diagonal + # cut). + # + # mask_dummy_only (runtime bool, used only when apply_mask=False): + # when True the iter is a "dummy" rounded-up iter that needs the + # mask to zero out garbage S — runs the mask at runtime cost. For + # non-dummy iters it stays the fast no-mask path. + pipeline_s_p_o.consumer_wait_w_index_phase(stage, s_phase) + sm_stats_try_acquire = ( + pipeline_sm_stats.producer_try_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + ) + tSrS_t2r = cute.make_rmem_tensor( + tScS_t2r.shape, self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + if const_expr(apply_mask): + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=qo_tile, + n_block=page_idx, + mask_seqlen=True, + mask_causal=self.causal, + kv_valid_cols=kv_valid_cols, + ) + elif const_expr(mask_dummy_only is not None): + if mask_dummy_only: + # Dummy iter — zero everything via mask (kv_valid_cols=0 + # makes mask_r2p_lambda set all positions to -inf). + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=qo_tile, + n_block=page_idx, + mask_seqlen=True, + mask_causal=self.causal, + kv_valid_cols=kv_valid_cols, + ) + # Publish acc_scale in log2-domain (un-exp2'd); correction WG does + # the exp2 only when an actual rescale fires. Removes MUFU.EX2 from + # the sm_stats publication critical path that gates correction's + # consumer_wait (the dominant long_scoreboard hot PC in NCU). + row_max, acc_scale_log2 = softmax.update_row_max_deferred_exp2( + tSrS_t2r.load(), is_first) + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase, sm_stats_try_acquire) + tSrStats = cute.make_rmem_tensor( + tScStats_r2t.shape, self.qk_acc_dtype) + tSrStats[0] = acc_scale_log2 + tSrStats[1] = row_max + cute.copy(thr_tmem_store_vec, tSrStats, tStStats_r2t) + cute.arch.fence_view_async_tmem_store() + pipeline_sm_stats.producer_commit_w_index(Int32(0)) + sm_stats_producer_phase = sm_stats_producer_phase ^ Int32(1) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r.layout, + ) + # exp2 for the internal row_sum carry happens AFTER producer_commit, so + # it no longer extends correction's consumer-wait window. + # acc_scale_log2 == 0.0 in the threshold/first-iter paths makes + # exp2(0)=1.0, which is the no-rescale identity for the row_sum carry — + # semantically equivalent to the original ``acc_scale=1.0`` branch. + if const_expr(is_first): + row_sum_init = Float32(0.0) + else: + acc_scale_mult = cute.math.exp2(acc_scale_log2, fastmath=True) + row_sum_init = softmax.row_sum[0] * acc_scale_mult + # Bulk EX2 emulation parameters. + # + # ex2_emu_freq=16 emulate exp2 with FFMA2 polynomial on + # 15 of every 16 (j, k) positions; the + # remaining 1/16 still issues MUFU.EX2. + # This cuts the MUFU.EX2 throughput bottleneck + # in the softmax inner loop (≈22k cyc + # saved per stage at baseline). + # ex2_emu_res=3 degree-3 polynomial; res=4 broke + # kv=1024 close-tolerance even with + # poly_degree=5 — 3 is the most aggressive + # setting that still passes cos_sim ≥ 0.99 + # against the reference for the fp8 PV path. + # ex2_emu_start_frg=1 skip the emulation for fragment index 0 + # (preserves accuracy on the first iter + # where row_max is least settled). + # + # If you tune these, re-run the variable-kv self-consistency check + # (split vs non-split must stay at cos_min ≥ 0.99). + softmax.row_sum[0] = softmax.scale_apply_exp2_convert_sum( + tSrS_t2r, + row_max, + tSrP_r2t, + row_sum_init, + ex2_emu_freq=16, + ex2_emu_res=3, + ex2_emu_start_frg=1, + ) + for k in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): + cute.copy( + thr_tmem_store, + tSrP_r2t_f32[None, None, k], + tStP_r2t[None, None, k], + ) + if const_expr(self.split_P_arrive > 0): + split_P_arrive_idx = ( + cute.size(tStP_r2t.shape[2]) + * self.split_P_arrive + // self.n_block_size + ) + if const_expr(k + 1 == split_P_arrive_idx): + cute.arch.fence_view_async_tmem_store() + pipeline_s_p_o.consumer_release_w_index(stage) + cute.arch.fence_view_async_tmem_store() + if const_expr(self.split_P_arrive > 0): + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_p_lastsplit.producer_commit_w_index(stage) + else: + pipeline_s_p_o.consumer_release_w_index(stage) + return sm_stats_producer_phase + + @cute.jit + def load( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mQ: cute.Tensor, + mK_paged: cute.Tensor, + mV_paged: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + mPageTable: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + pipeline_q: pipeline.PipelineAsync, + pipeline_kv: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + page_size: Int32, + kv_chunk_size_pages: Int32, + ) -> None: + cute.arch.setmaxregister_decrease(self.num_regs_load) + thr_mma_qk_ld = tiled_mma_qk.get_slice(0) + thr_mma_pv_ld = tiled_mma_pv.get_slice(0) + q_producer_phase = Int32(1) + kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kv_stage) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, head_kv_idx, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx_ld = mRequestIndices[work_idx] + qo_tile_ld = mQoTileIndices[work_idx] + split_idx_ld = mKvTileIndices[work_idx] + seqused_k_ld = mSeqUsedK[batch_idx_ld] + kv_pages_ld = ( + seqused_k_ld + page_size - Int32(1)) // page_size + kv_page_begin_ld = split_idx_ld * kv_chunk_size_pages + kv_page_end_ld = cutlass.min( + kv_pages_ld, kv_page_begin_ld + kv_chunk_size_pages + ) + page_count_ld = kv_page_end_ld - kv_page_begin_ld + block_iter_count_ld = ( + page_count_ld + Int32(1)) & ~Int32(1) + physical_page_v0 = Int32(0) + physical_page_v1 = Int32(0) + + mQ_cur_ld = mQ[None, None, None, batch_idx_ld][ + None, None, head_kv_idx + ] + tiler_gQ_ld = ( + (self.mma_tiler_qk[0] * self.q_stage), + self.head_dim, + ) + gQ_ld = cute.local_tile( + mQ_cur_ld, tiler_gQ_ld, (qo_tile_ld, 0)) + gQ_ld = layout_utils.select( + cute.flat_divide(gQ_ld, (self.mma_tiler_qk[0],)), + mode=[0, 2, 1], + ) + tSgQ_ld = thr_mma_qk_ld.partition_A(gQ_ld) + load_Q_fn_full, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), tSgQ_ld, sQ + ) + mK_cur_ld = mK_paged[None, None, head_kv_idx, None] + gK_ld = cute.local_tile( + mK_cur_ld, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0, None), + ) + tSgK_ld = thr_mma_qk_ld.partition_B(gK_ld) + tKsK_ld, tKgK_ld = cpasync.tma_partition( + tma_atom_K, 0, cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_ld, 0, 3), + ) + mV_cur_ld = mV_paged[None, None, head_kv_idx, None] + gV_ld = cute.local_tile( + mV_cur_ld, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None, None), + ) + tOgV_ld = thr_mma_pv_ld.partition_B(gV_ld) + tVsV_ld, tVgV_ld = cpasync.tma_partition( + tma_atom_V, 0, cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV_ld, 0, 3), + ) + + if block_iter_count_ld > Int32(0): + # Prime K0 before Q; then follow BSA order + # K1, V0, K2, V1, ... + page_idx_ld0 = kv_page_end_ld - Int32(1) + physical_page_v0 = mPageTable[batch_idx_ld, page_idx_ld0] + physical_page_v1 = physical_page_v0 + self.load_KV_physical( + tma_atom_K, + tKgK_ld, + tKsK_ld, + physical_page_v0, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + + self.load_Q( + load_Q_fn_full, + pipeline_q, + Int32(0), + q_producer_phase, + ) + q_producer_phase = q_producer_phase ^ Int32(1) + + if block_iter_count_ld > Int32(0): + if block_iter_count_ld > Int32(1): + page_rel_k1 = cutlass.min( + Int32(1), page_count_ld - Int32(1)) + page_idx_ld1 = kv_page_end_ld - Int32(1) - page_rel_k1 + physical_page_v1 = mPageTable[ + batch_idx_ld, page_idx_ld1] + self.load_KV_physical( + tma_atom_K, + tKgK_ld, + tKsK_ld, + physical_page_v1, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + + if block_iter_count_ld > Int32(2): + for page_rel in cutlass.range( + Int32(0), + block_iter_count_ld - Int32(2), + unroll=1, + ): + page_rel_v_ld = cutlass.min( + page_rel, page_count_ld - Int32(1)) + physical_page_v_ld = physical_page_v0 + if (page_rel & Int32(1)) != Int32(0): + physical_page_v_ld = physical_page_v1 + self.load_KV_physical( + tma_atom_V, + tVgV_ld, + tVsV_ld, + physical_page_v_ld, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + page_rel_k_ld = cutlass.min( + page_rel + Int32(2), + page_count_ld - Int32(1), + ) + page_idx_k_ld = ( + kv_page_end_ld - Int32(1) - page_rel_k_ld) + physical_page_k_ld = mPageTable[ + batch_idx_ld, page_idx_k_ld] + self.load_KV_physical( + tma_atom_K, + tKgK_ld, + tKsK_ld, + physical_page_k_ld, + pipeline_kv, + kv_producer_state, + ) + if (page_rel & Int32(1)) == Int32(0): + physical_page_v0 = physical_page_k_ld + else: + physical_page_v1 = physical_page_k_ld + kv_producer_state.advance() + + page_rel_epi_begin_ld = cutlass.max( + Int32(0), + block_iter_count_ld - Int32(2), + ) + for page_rel_epi in cutlass.range( + page_rel_epi_begin_ld, + block_iter_count_ld, + unroll=1, + ): + page_rel_v_ld = cutlass.min( + page_rel_epi, page_count_ld - Int32(1)) + physical_page_v_ld = physical_page_v0 + if (page_rel_epi & Int32(1)) != Int32(0): + physical_page_v_ld = physical_page_v1 + self.load_KV_physical( + tma_atom_V, + tVgV_ld, + tVsV_ld, + physical_page_v_ld, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + + tile_scheduler.prefetch_next_work() + work_tile = tile_scheduler.consumer_advance() + + pipeline_kv.producer_tail(kv_producer_state) + pipeline_q.producer_acquire_w_index_phase( + Int32(self.q_stage - 1), q_producer_phase) + + @cute.jit + def load_Q( + self, + load_Q_fn: Callable, + pipeline_q: pipeline.PipelineAsync, + stage: Int32, + phase: Int32, + ) -> None: + pipeline_q.producer_acquire_w_index_phase(stage, phase) + load_Q_fn( + src_idx=Int32(0), + dst_idx=stage, + tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(stage), + ) + + @cute.jit + def load_KV_physical( + self, + tma_atom: cute.CopyAtom, + tXgX: cute.Tensor, + tXsX: cute.Tensor, + physical_page: Int32, + pipeline_kv: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + ) -> None: + pipeline_kv.producer_acquire(producer_state) + cute.copy( + tma_atom, + tXgX[(None, 0, physical_page)], + tXsX[(None, producer_state.index)], + tma_bar_ptr=pipeline_kv.producer_get_barrier(producer_state), + ) + +_atten_compile_cache: dict[tuple[object, ...], object] = {} + + +def run_decode_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + request_indices: torch.Tensor, + qo_tile_indices: torch.Tensor, + kv_tile_indices: torch.Tensor, + block_valid_mask: torch.Tensor, + split_counts: torch.Tensor, + o_indptr: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + O_partial: Optional[torch.Tensor], + LSE_partial: Optional[torch.Tensor], + *, + softmax_scale: float, + seqlen_q: int, + page_size: int, + kv_chunk_size_pages: int, + split_kv: bool, + causal: bool, + return_lse: bool = True, + disable_softmax_exp2: bool = False, + O_partial_dummy: Optional[torch.Tensor] = None, + LSE_partial_dummy: Optional[torch.Tensor] = None, +) -> None: + """Launch the SM100 UMMA paged decode attention CUTE DSL kernel. + + qhead_per_kv is derived from input shapes (q.shape[1] // k.shape[1]). + disable_softmax_exp2 toggles the sage-style host flag (decision §1.7); + default False keeps full ex2 emulation. + + ``O_partial_dummy`` / ``LSE_partial_dummy`` let callers pre-allocate the + placeholder buffers for the non-split path, avoiding ~5us of per-call + ``torch.empty`` overhead in tight decoding loops. + """ + + q_dtype = torch2cute_dtype_map[q.dtype] + o_dtype = torch2cute_dtype_map[out.dtype] + qhead_per_kv = q.shape[1] // k.shape[1] + q_tokens_per_group = 128 // int(qhead_per_kv) + write_lse = bool(return_lse) or bool(split_kv) + if int(seqlen_q) != q_tokens_per_group: + raise NotImplementedError( + "decode fp8 currently assumes one full packed-q tile: " + f"seqlen_q must equal {q_tokens_per_group}, got {seqlen_q}" + ) + key = ( + "decode_attention", + q.shape[-1], + q_dtype, + o_dtype, + bool(split_kv), + bool(causal), + int(qhead_per_kv), + int(seqlen_q), + bool(write_lse), + bool(disable_softmax_exp2), + ) + if key not in _atten_compile_cache: + from ....quack.compile_utils import make_fake_tensor + + total_q = cute.sym_int64() + head_q = cute.sym_int64() + num_pages = cute.sym_int64() + head_kv = cute.sym_int64() + batch = cute.sym_int64() + batch_plus_one = cute.sym_int64() + max_pages = cute.sym_int64() + work_capacity = cute.sym_int64() + partial_rows = cute.sym_int64() + partial_rows_flat = cute.sym_int64() + head_dim = int(q.shape[-1]) + kernel = SparseDecodeAttentionForwardSm100( + head_dim=head_dim, + qhead_per_kv=int(qhead_per_kv), + page_size=int(page_size), + split_kv=bool(split_kv), + causal=bool(causal), + write_lse=bool(write_lse), + disable_softmax_exp2=bool(disable_softmax_exp2), + ) + # Always pass non-None fake tensors so the @cute.kernel positional + # arg marshalling stays stable; the kernel only reads these when + # split_kv=True (decision #10 epilogue branch). + fake_O_partial = make_fake_tensor( + Float32, (partial_rows_flat, head_dim), divisibility=4) + fake_LSE_partial = make_fake_tensor( + Float32, (partial_rows, head_q), divisibility=1, leading_dim=1) + # Q is passed as a [B, Sq, Hq, D] view so the kernel can build the same + # PackGQA TMA view used by FA/BSA and issue one full-tile Q TMA. + # O still uses the compact 2D view for the packed-GQA TMA epilogue. + total_q_flat = cute.sym_int64() + _atten_compile_cache[key] = cute.compile( + kernel, + make_fake_tensor( + q_dtype, (batch, int(seqlen_q), head_q, head_dim), + divisibility=16), + make_fake_tensor(q_dtype, (num_pages, head_kv, int(page_size), head_dim), divisibility=16), + make_fake_tensor(q_dtype, (num_pages, head_kv, int(page_size), head_dim), divisibility=16), + make_fake_tensor(Int32, (batch, max_pages), divisibility=1, leading_dim=1), + make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (batch_plus_one,), divisibility=1, leading_dim=0), + make_fake_tensor(o_dtype, (total_q_flat, head_dim), divisibility=128 // o_dtype.width), + make_fake_tensor(Float32, (total_q, head_q), divisibility=1, leading_dim=1), + fake_O_partial, + fake_LSE_partial, + Float32(float(softmax_scale)), + Int32(int(seqlen_q)), + Int32(int(page_size)), + Int32(int(kv_chunk_size_pages)), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + q_4d = q.view( + q.shape[0] // int(seqlen_q), int(seqlen_q), q.shape[1], q.shape[2]) + out_2d = out.view(out.shape[0] * out.shape[1], out.shape[2]) + # Compile keeps non-None fake partial buffers for positional stability + # (see fake_O_partial / fake_LSE_partial above). Runtime callers that + # don't need them (split_kv=False) pass None; allocate small uninitialized + # dummy buffers so the kernel signature still matches without launching + # torch fill kernels. + if O_partial is None: + # Reuse caller-cached dummy when available (e.g. the + # SparseDecodePagedAttentionWrapper plan() pre-allocation), else + # allocate a small placeholder on the fly. + O_partial_kernel = ( + O_partial_dummy + if O_partial_dummy is not None + else torch.empty( + (1, q.shape[2]), dtype=torch.float32, device=q.device) + ) + else: + O_partial_kernel = O_partial.view( + O_partial.shape[0] * O_partial.shape[1], O_partial.shape[2]) + if LSE_partial is None: + LSE_partial = ( + LSE_partial_dummy + if LSE_partial_dummy is not None + else torch.empty( + (1, q.shape[1]), dtype=torch.float32, device=q.device) + ) + with torch.cuda.nvtx.range("Decode_Attention"): + _atten_compile_cache[key]( + q_4d, k, v, page_table, seqused_k, + request_indices, qo_tile_indices, kv_tile_indices, block_valid_mask, + split_counts, o_indptr, + out_2d, lse, O_partial_kernel, LSE_partial, + softmax_scale, seqlen_q, page_size, kv_chunk_size_pages, + ) + + +__all__ = ["SparseDecodeAttentionForwardSm100", "run_decode_attention"] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6bab26c200fff9c62644849b18e55f060fa8783f --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Paged decode split-KV scheduling backed by the precompiled Torch op. + +The CUDA implementation lives in ``csrc/build_decode_schedule.cu`` and is +built ahead of time by kernel-builder. The op returns the schedule arrays +plus a fixed-order scalar summary, which is reassembled into the schedule +dict here. +""" + +from __future__ import annotations + +import torch + +from ....._ops import ops + +# Order of the scalar summary returned by the op; must match +# csrc/build_decode_schedule.cu. +_SCALAR_KEYS = ( + "split_kv", + "cta_tile_q", + "num_q_tiles", + "kv_chunk_size_pages", + "kv_chunk_size_tokens", + "work_count", + "padded_work_count", + "partial_rows", + "max_split_count", + "max_grid_size", + "active_blocks_per_sm", + "num_sms", + "base_cta", +) + + +def build_decode_schedule( + seqused_k: torch.Tensor, + *, + page_size: int, + seqlen_q: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + max_seqlen_k: int, + enable_cuda_graph: bool = False, + max_grid_size: int = 0, + fixed_split_size: int = -1, + disable_split_kv: bool = False, +) -> dict[str, object]: + """GPU-only schedule build: single CUDA kernel produces all schedule + index arrays on device. Only a small summary tensor is D2H'd at the end + so the wrapper can size O_partial, pick the kernel grid, and choose + split/non-split compile path. + + ``max_seqlen_k`` is required as the host-side worst-case bound for + padding the work-tile arrays. + """ + + ( + request_indices, + qo_tile_indices, + kv_tile_indices, + block_valid_mask, + split_counts, + kv_pages, + merge_indptr, + o_indptr, + scalars, + ) = ops.build_decode_schedule( + seqused_k, + int(page_size), + int(seqlen_q), + int(num_qo_heads), + int(num_kv_heads), + int(head_dim), + int(max_seqlen_k), + bool(enable_cuda_graph), + int(max_grid_size), + int(fixed_split_size), + bool(disable_split_kv), + ) + + raw: dict[str, object] = dict(zip(_SCALAR_KEYS, (int(s) for s in scalars))) + raw["split_kv"] = bool(raw["split_kv"]) + raw["request_indices"] = request_indices + raw["qo_tile_indices"] = qo_tile_indices + raw["kv_tile_indices"] = kv_tile_indices + raw["block_valid_mask"] = block_valid_mask + raw["split_counts"] = split_counts + raw["kv_pages"] = kv_pages + raw["merge_indptr"] = merge_indptr + raw["o_indptr"] = o_indptr + + # The CUDA kernel writes into worst-case-padded buffers (size = + # batch * num_q_tiles * max_pages_global) but only the first + # ``padded_work_count`` entries are valid. Downstream consumers + # (tile_scheduler) take grid size from ``request_indices.shape[0]`` + # so we narrow the views to that count; the underlying allocation + # is unchanged so this is a view, no copy. + pad = int(raw["padded_work_count"]) + for key in ( + "request_indices", + "qo_tile_indices", + "kv_tile_indices", + "block_valid_mask", + ): + raw[key] = raw[key].narrow(0, 0, pad) + return raw + + +__all__ = ["build_decode_schedule"] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/combine.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/combine.py new file mode 100644 index 0000000000000000000000000000000000000000..3d308bd26c281e744cc7289b1265d8192c1f39e7 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/combine.py @@ -0,0 +1,680 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""LDGSTS split-KV combine for paged decode attention.""" + +import math +from functools import partial +from typing import Type + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute import FastDivmodDivisor +from cutlass.cute.nvgpu import cpasync + +from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map + + +class SparseDecodeForwardCombine: + """Combine split-KV decode partials with FA-style LDGSTS staging. + + ``mO_partial`` and ``mLSE_partial`` use the split-major padded layout: + ``partial_row = o_indptr[b] + split_idx * q_stride + q_token`` where + ``q_stride = ceil_div(seqlen_q, q_tokens_per_group) * q_tokens_per_group``. + A CTA covers ``tile_m`` flattened ``(q_token, q_head)`` rows and one + ``k_block_size`` slice of D. O_partial and LSE_partial are loaded to SMEM + via ``cpasync.CopyG2SOp`` before the split reduction. + """ + + def __init__( + self, + dtype: Type[cutlass.Numeric], + dtype_partial: Type[cutlass.Numeric], + head_dim: int, + *, + tile_m: int = 64, + k_block_size: int = 128, + max_splits: int = 4, + num_threads: int = 256, + stages: int = 2, + ): + if head_dim != 128: + raise NotImplementedError( + f"SparseDecodeForwardCombine currently supports only D=128, got D={head_dim}" + ) + if dtype not in [cutlass.BFloat16, cutlass.Float16, cutlass.Float32]: + raise TypeError(f"Unsupported output dtype: {dtype}") + if dtype_partial is not Float32: + raise TypeError("decode O_partial must be Float32") + if k_block_size != head_dim: + raise NotImplementedError("decode combine currently uses one D=128 k block") + if tile_m % 8 != 0: + raise ValueError("decode combine tile_m must be divisible by 8") + if max_splits < 1 or max_splits > 256: + raise ValueError("decode combine max_splits must be in [1, 256]") + + self.dtype = dtype + self.dtype_partial = dtype_partial + self.head_dim = head_dim + self.tile_m = tile_m + self.k_block_size = k_block_size + self.max_splits = max_splits + self.num_threads = num_threads + self.stages = stages + self.is_even_k = head_dim % k_block_size == 0 + + def _setup_attributes(self) -> None: + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype_partial.width + assert self.k_block_size % async_copy_elems == 0 + + k_block_gmem = ( + 128 + if self.k_block_size % 128 == 0 + else (64 if self.k_block_size % 64 == 0 else 32) + ) + gmem_threads_per_row = k_block_gmem // async_copy_elems + assert self.num_threads % gmem_threads_per_row == 0 + + atom_async_copy_partial = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype_partial, + num_bits_per_copy=universal_copy_bits, + ) + tOpartial_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + vOpartial_layout = cute.make_layout((1, async_copy_elems)) + self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv( + atom_async_copy_partial, tOpartial_layout, vOpartial_layout + ) + + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=async_copy_elems * self.dtype.width, + ) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, tOpartial_layout, vOpartial_layout + ) + + lse_copy_bits = Float32.width + m_block_smem = ( + 128 + if self.tile_m % 128 == 0 + else ( + 64 + if self.tile_m % 64 == 0 + else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8)) + ) + ) + gmem_threads_per_row_lse = m_block_smem + assert self.num_threads % gmem_threads_per_row_lse == 0 + + atom_async_copy_lse = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + Float32, + num_bits_per_copy=lse_copy_bits, + ) + tLSE_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse), + order=(1, 0), + ) + self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + atom_async_copy_lse, tLSE_layout, cute.make_layout(1) + ) + + self.smem_threads_per_col_lse = self.num_threads // m_block_smem + assert 32 % self.smem_threads_per_col_lse == 0 + s2r_layout_atom_lse = cute.make_ordered_layout( + (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse), + order=(0, 1), + ) + self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), + s2r_layout_atom_lse, + cute.make_layout(1), + ) + + if const_expr(m_block_smem == 8): + smem_lse_swizzle = cute.make_swizzle(5, 0, 5) + elif const_expr(m_block_smem == 16): + smem_lse_swizzle = cute.make_swizzle(4, 0, 4) + else: + smem_lse_swizzle = cute.make_swizzle(3, 2, 3) + lse_atom_splits = min(self.max_splits, 8) + smem_layout_atom_lse = cute.make_composed_layout( + smem_lse_swizzle, + 0, + cute.make_ordered_layout((lse_atom_splits, m_block_smem), order=(1, 0)), + ) + self.smem_layout_lse = cute.tile_to_shape( + smem_layout_atom_lse, (self.max_splits, self.tile_m), (0, 1) + ) + self.smem_layout_o = cute.make_ordered_layout( + (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2) + ) + + @cute.jit + def __call__( + self, + mO_partial: cute.Tensor, # [partial_rows, Hq, D] fp32 + mLSE_partial: cute.Tensor, # [partial_rows, Hq] fp32 + mSplitCounts: cute.Tensor, # [B] int32 + mOIndptr: cute.Tensor, # [B + 1] int32 + mO: cute.Tensor, # [total_q, Hq, D] + mLSE: cute.Tensor, # [total_q, Hq] fp32 + seqlen_q: Int32, + q_tokens_per_group: Int32, + stream: cuda.CUstream = None, + ): + if const_expr(mO_partial.element_type is not Float32): + raise TypeError("decode O_partial tensor must be Float32") + if const_expr(mLSE_partial.element_type is not Float32): + raise TypeError("decode LSE_partial tensor must be Float32") + if const_expr(mLSE.element_type is not Float32): + raise TypeError("decode LSE tensor must be Float32") + if const_expr(mO.element_type != self.dtype): + raise TypeError("decode O tensor dtype must match kernel dtype") + if const_expr(mSplitCounts.element_type is not Int32): + raise TypeError("decode split_counts tensor must be Int32") + if const_expr(mOIndptr.element_type is not Int32): + raise TypeError("decode o_indptr tensor must be Int32") + + mO_partial, mLSE_partial, mSplitCounts, mOIndptr, mO, mLSE = [ + assume_tensor_aligned(t) + for t in (mO_partial, mLSE_partial, mSplitCounts, mOIndptr, mO, mLSE) + ] + self._setup_attributes() + + @cute.struct + class SharedStorage: + sLSE: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sMaxValidSplit: cute.struct.Align[ + cute.struct.MemRange[Int32, self.tile_m], 128 + ] + sO: cute.struct.Align[ + cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 + ] + + total_q = mO.shape[0] + head_q = mO.shape[1] + batch = mSplitCounts.shape[0] + head_divmod = FastDivmodDivisor(head_q) + grid = ( + cute.ceil_div(seqlen_q * head_q, self.tile_m), + cute.ceil_div(self.head_dim, self.k_block_size), + batch, + ) + + self.kernel( + mO_partial, + mLSE_partial, + mSplitCounts, + mOIndptr, + mO, + mLSE, + SharedStorage, + self.smem_layout_lse, + self.smem_layout_o, + self.gmem_tiled_copy_O_partial, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_LSE, + self.s2r_tiled_copy_LSE, + head_divmod, + Int32(total_q), + Int32(head_q), + seqlen_q, + q_tokens_per_group, + ).launch( + grid=grid, + block=[self.num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mSplitCounts: cute.Tensor, + mOIndptr: cute.Tensor, + mO: cute.Tensor, + mLSE: cute.Tensor, + SharedStorage: cutlass.Constexpr, + smem_layout_lse: cute.Layout | cute.ComposedLayout, + smem_layout_o: cute.Layout, + gmem_tiled_copy_O_partial: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + s2r_tiled_copy_LSE: cute.TiledCopy, + head_divmod: FastDivmodDivisor, + total_q: Int32, + head_q: Int32, + seqlen_q: Int32, + q_tokens_per_group: Int32, + ): + tidx, _, _ = cute.arch.thread_idx() + m_block, k_block, batch_idx = cute.arch.block_idx() + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sLSE = storage.sLSE.get_tensor(smem_layout_lse) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,)) + sO = storage.sO.get_tensor(smem_layout_o) + + split_count = mSplitCounts[batch_idx] + q_stride = ( + (seqlen_q + q_tokens_per_group - Int32(1)) + // q_tokens_per_group + ) * q_tokens_per_group + max_idx = seqlen_q * head_q + + if m_block * Int32(self.tile_m) < max_idx: + gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) + tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) + cLSE = cute.make_identity_tensor((self.max_splits, self.tile_m)) + tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) + + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] + idx = m_block * Int32(self.tile_m) + mi + if idx < max_idx: + q_idx, q_head = divmod(idx, head_divmod) + partial_base = mOIndptr[batch_idx] + q_idx + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] + if si < split_count: + partial_row = partial_base + si * q_stride + lse_ptr = ( + mLSE_partial.iterator + + Int64(partial_row) * Int64(head_q) + + Int64(q_head) + ) + lse_gmem_ptr = cute.make_ptr( + Float32, + lse_ptr.toint(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + lse_src = cute.make_tensor(lse_gmem_ptr, (1,)) + cute.copy( + gmem_thr_copy_LSE, + lse_src, + tLSEsLSE[None, s, m], + ) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + else: + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + tLSEsLSE[None, s, m].fill(-Float32.inf) + cute.arch.cp_async_commit_group() + + gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) + cO = cute.make_identity_tensor((self.tile_m, self.k_block_size)) + tOcO = gmem_thr_copy_O_partial.partition_D(cO) + tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) + + num_rows = const_expr(cute.size(tOcO, mode=[1])) + tOqidx = cute.make_rmem_tensor(num_rows, Int32) + tOhidx = cute.make_rmem_tensor(num_rows, Int32) + for m in cutlass.range(num_rows, unroll_full=True): + mi = tOcO[0, m, 0][0] + idx = m_block * Int32(self.tile_m) + mi + if idx >= max_idx: + tOqidx[m] = Int32(0) + tOhidx[m] = -Int32(1) + else: + tOqidx[m], tOhidx[m] = divmod(idx, head_divmod) + + load_O_partial = partial( + self.load_O_partial, + mO_partial, + mOIndptr, + gmem_tiled_copy_O_partial, + tOsO_partial, + tOqidx, + tOhidx, + tOcO, + batch_idx, + q_stride, + split_count, + head_q, + k_block, + ) + + for stage in cutlass.range(self.stages - 1, unroll_full=True): + if stage < split_count: + load_O_partial(stage, stage) + cute.arch.cp_async_commit_group() + + cute.arch.cp_async_wait_group(self.stages - 1) + cute.arch.sync_threads() + + s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) + ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) + ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE) + cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) + + lse_sum = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32) + ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) + max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32) + assert cute.size(ts2rrLSE, mode=[0]) == 1 + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_max = cute.arch.warp_reduction_max( + ts2rrLSE[None, None, m] + .load() + .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + threads_in_group=threads_per_col, + ) + max_valid_idx = -Int32(1) + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + if ts2rrLSE[0, s, m] != -Float32.inf: + max_valid_idx = ts2rcLSE[0, s, 0][0] + max_valid_split[m] = cute.arch.warp_reduction_max( + max_valid_idx, threads_in_group=threads_per_col + ) + + lse_max_cur = Float32(0.0) if lse_max == -Float32.inf else lse_max + LOG2_E = Float32(math.log2(math.e)) + lse_sum_cur = Float32(0.0) + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + scale = cute.math.exp2( + (ts2rrLSE[0, s, m] - lse_max_cur) * LOG2_E, + fastmath=True, + ) + lse_sum_cur += scale + ts2rrLSE[0, s, m] = scale + lse_sum_cur = cute.arch.warp_reduction_sum( + lse_sum_cur, threads_in_group=threads_per_col + ) + lse_sum[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max + inv_sum = ( + Float32(0.0) + if (lse_sum_cur == Float32(0.0) or lse_sum_cur != lse_sum_cur) + else cute.arch.rcp_approx(lse_sum_cur) + ) + ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) + cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) + + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == Int32(0): + mi = ts2rcLSE[0, 0, m][1] + if mi < Int32(self.tile_m): + sMaxValidSplit[mi] = max_valid_split[m] + + if k_block == Int32(0): + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == Int32(0): + mi = ts2rcLSE[0, 0, m][1] + idx = m_block * Int32(self.tile_m) + mi + if idx < max_idx: + q_idx, q_head = divmod(idx, head_divmod) + q_abs = batch_idx * seqlen_q + q_idx + mLSE[q_abs, q_head] = lse_sum[m] + + cute.arch.sync_threads() + + thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] + for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True): + thr_max_valid_split = max( + thr_max_valid_split, + sMaxValidSplit[tOcO[0, m, 0][0]], + ) + + tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32) + tOrO.fill(Float32(0.0)) + + stage_load = self.stages - 1 + stage_compute = 0 + for s in cutlass.range(thr_max_valid_split + Int32(1), unroll=4): + scale = cute.make_rmem_tensor(num_rows, Float32) + for m in cutlass.range(num_rows, unroll_full=True): + scale[m] = sLSE[s, tOcO[0, m, 0][0]] + + split_to_load = s + Int32(self.stages - 1) + if split_to_load <= thr_max_valid_split: + load_O_partial(split_to_load, stage_load) + cute.arch.cp_async_commit_group() + stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1 + + cute.arch.cp_async_wait_group(self.stages - 1) + cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial) + stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1 + + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= Int32(0) and scale[m] > Float32(0.0): + tOrO[None, m, None].store( + tOrO[None, m, None].load() + + scale[m] * tOrO_partial[None, m, None].load().to(Float32) + ) + + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + + rO = cute.make_rmem_tensor_like(tOrO, self.dtype) + rO.store(tOrO.load().to(self.dtype)) + elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1])) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= Int32(0): + q_abs = batch_idx * seqlen_q + tOqidx[m] + row_ptr = ( + mO.iterator + + ( + (Int64(q_abs) * Int64(head_q) + Int64(tOhidx[m])) + * Int64(self.head_dim) + + Int64(k_block * Int32(self.k_block_size)) + ) + ) + row_gmem_ptr = cute.make_ptr( + mO.element_type, + row_ptr.toint(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + mO_row = cute.make_tensor( + row_gmem_ptr, + cute.make_layout((self.k_block_size,)), + ) + mO_row_copy = cute.tiled_divide(mO_row, (elems_per_store,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_store + cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_row_copy[None, k_idx]) + + @cute.jit + def load_O_partial( + self, + mO_partial: cute.Tensor, + mOIndptr: cute.Tensor, + gmem_tiled_copy_O_partial: cute.TiledCopy, + tOsO_partial: cute.Tensor, + tOqidx: cute.Tensor, + tOhidx: cute.Tensor, + tOcO: cute.Tensor, + batch_idx: Int32, + q_stride: Int32, + split_count: Int32, + head_q: Int32, + k_block: Int32, + split: Int32, + stage: Int32, + ) -> None: + elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1])) + tOsO_partial_cur = tOsO_partial[None, None, None, stage] + for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True): + if tOhidx[m] >= Int32(0): + if split < split_count: + partial_row = mOIndptr[batch_idx] + split * q_stride + tOqidx[m] + row_ptr = ( + mO_partial.iterator + + ( + (Int64(partial_row) * Int64(head_q) + Int64(tOhidx[m])) + * Int64(self.head_dim) + + Int64(k_block * Int32(self.k_block_size)) + ) + ) + row_gmem_ptr = cute.make_ptr( + mO_partial.element_type, + row_ptr.toint(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + mO_partial_row = cute.make_tensor( + row_gmem_ptr, + cute.make_layout((self.k_block_size,)), + ) + mO_partial_row_copy = cute.tiled_divide( + mO_partial_row, (elems_per_load,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_load + cute.copy( + gmem_tiled_copy_O_partial, + mO_partial_row_copy[None, k_idx], + tOsO_partial_cur[None, m, k], + ) + else: + tOsO_partial_cur[None, m, None].fill(Float32(0.0)) + + +_combine_compile_cache: dict[tuple[object, ...], object] = {} + + +def _next_power_of_2(x: int) -> int: + return 1 << (max(int(x), 1) - 1).bit_length() + + +def run_decode_combine( + O_partial: torch.Tensor, + LSE_partial: torch.Tensor, + split_counts: torch.Tensor, + o_indptr: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + *, + seqlen_q: int, + q_tokens_per_group: int, + max_split_count: int, +) -> None: + """Launch LDGSTS decode split-KV combine.""" + + if O_partial.dtype != torch.float32: + raise TypeError(f"O_partial must be torch.float32, got {O_partial.dtype}") + if LSE_partial.dtype != torch.float32: + raise TypeError(f"LSE_partial must be torch.float32, got {LSE_partial.dtype}") + if lse.dtype != torch.float32: + raise TypeError(f"lse must be torch.float32, got {lse.dtype}") + if split_counts.dtype != torch.int32: + raise TypeError(f"split_counts must be torch.int32, got {split_counts.dtype}") + if o_indptr.dtype != torch.int32: + raise TypeError(f"o_indptr must be torch.int32, got {o_indptr.dtype}") + if out.ndim != 3 or O_partial.ndim != 3: + raise ValueError("decode combine expects O tensors with shape [rows, heads, D]") + if LSE_partial.ndim != 2 or lse.ndim != 2: + raise ValueError("decode combine expects LSE tensors with shape [rows, heads]") + if out.shape[1:] != O_partial.shape[1:]: + raise ValueError(f"O shape mismatch: out={out.shape}, O_partial={O_partial.shape}") + if lse.shape != out.shape[:2]: + raise ValueError(f"lse shape {lse.shape} must match out[:2] {out.shape[:2]}") + if LSE_partial.shape != O_partial.shape[:2]: + raise ValueError( + f"LSE_partial shape {LSE_partial.shape} must match O_partial[:2] {O_partial.shape[:2]}" + ) + if split_counts.ndim != 1 or o_indptr.ndim != 1: + raise ValueError("split_counts and o_indptr must be rank-1 tensors") + if o_indptr.shape != (split_counts.shape[0] + 1,): + raise ValueError( + f"o_indptr shape {o_indptr.shape} must be ({split_counts.shape[0] + 1},)" + ) + seqlen_q = int(seqlen_q) + q_tokens_per_group = int(q_tokens_per_group) + if seqlen_q <= 0: + raise ValueError("seqlen_q must be positive") + if q_tokens_per_group <= 0: + raise ValueError("q_tokens_per_group must be positive") + if out.shape[0] != split_counts.shape[0] * seqlen_q: + raise ValueError( + f"out rows {out.shape[0]} must equal batch*seqlen_q " + f"{split_counts.shape[0]}*{seqlen_q}" + ) + + max_split_count = int(max_split_count) + if max_split_count <= 0: + raise ValueError("max_split_count must be positive") + if max_split_count > 256: + raise NotImplementedError( + f"LDGSTS decode combine supports at most 256 splits, got {max_split_count}" + ) + max_splits = max(4, _next_power_of_2(max_split_count)) + tile_m = 64 + k_block_size = int(out.shape[-1]) + stages = 2 + + dtype = torch2cute_dtype_map[out.dtype] + key = ( + "decode_combine_ldgsts", + out.shape[-1], + dtype, + O_partial.dtype, + seqlen_q, + q_tokens_per_group, + tile_m, + k_block_size, + max_splits, + stages, + ) + if key not in _combine_compile_cache: + from ....quack.compile_utils import make_fake_tensor + + total_q = cute.sym_int64() + batch = cute.sym_int64() + batch_plus_one = cute.sym_int64() + partial_rows = cute.sym_int64() + head_q = cute.sym_int64() + head_dim = int(out.shape[-1]) + kernel = SparseDecodeForwardCombine( + dtype=dtype, + dtype_partial=Float32, + head_dim=head_dim, + tile_m=tile_m, + k_block_size=k_block_size, + max_splits=max_splits, + stages=stages, + ) + _combine_compile_cache[key] = cute.compile( + kernel, + make_fake_tensor(Float32, (partial_rows, head_q, head_dim), divisibility=4), + make_fake_tensor(Float32, (partial_rows, head_q), divisibility=1, leading_dim=1), + make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (batch_plus_one,), divisibility=1, leading_dim=0), + make_fake_tensor(dtype, (total_q, head_q, head_dim), divisibility=128 // dtype.width), + make_fake_tensor(Float32, (total_q, head_q), divisibility=1, leading_dim=1), + Int32(seqlen_q), + Int32(q_tokens_per_group), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + with torch.cuda.nvtx.range("Decode_Combine_LDGSTS"): + _combine_compile_cache[key]( + O_partial, + LSE_partial, + split_counts, + o_indptr, + out, + lse, + seqlen_q, + q_tokens_per_group, + ) + + +__all__ = ["SparseDecodeForwardCombine", "run_decode_combine"] diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..13b487402bf52d008b7ff7edbe9d584f366256b9 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py @@ -0,0 +1,300 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Decode-specific tile scheduler for paged fp8 attention. + +The pre-schedule step builds a dense worklist over decode KV chunks. Static +persistent scheduling walks a flattened ``(work_idx, head_kv_idx)`` task id. +CLC scheduling keeps BSA's hardware grid shape, ``(work_idx, head_kv_idx, 1)``, +and maps the canceled CTA coordinate back to the same logical task space. +""" + +from dataclasses import dataclass +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr +from cutlass.cute import FastDivmodDivisor + +from ....quack.cute_dsl_utils import ParamsBase + +from ....src.common.tile_scheduler import SchedulingMode, WorkTileInfo + + +@dataclass +class DecodeTileSchedulerArguments(ParamsBase): + work_capacity: Int32 + num_heads_kv: Int32 + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + + +class DecodeTileScheduler: + """Persistent scheduler over decode ``(work_idx, head_kv_idx)`` tasks.""" + + @dataclass + class Params(ParamsBase): + work_capacity: Int32 + num_heads_kv: Int32 + num_heads_kv_divmod: FastDivmodDivisor + total_tasks: Int32 + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + + def __init__( + self, + params: Params, + task_idx: Int32, + clc_scheduler=None, + clc_pipeline=None, + clc_consumer_state=None, + clc_response_ptr=None, + *, + loc=None, + ip=None, + ): + self.params = params + self._task_idx = task_idx + self._clc_scheduler = clc_scheduler + self._clc_pipeline = clc_pipeline + self._clc_consumer_state = clc_consumer_state + self._clc_response_ptr = clc_response_ptr + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: DecodeTileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert args.cluster_shape_mn[1] == 1, "Decode scheduler requires cluster N == 1" + total_tasks = args.work_capacity * args.num_heads_kv + return DecodeTileScheduler.Params( + args.work_capacity, + args.num_heads_kv, + FastDivmodDivisor(args.num_heads_kv), + total_tasks, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + ) + + @staticmethod + def _clc_grid_shape(params: Params): + return ( + cute.round_up(params.work_capacity, params.cluster_shape_m), + params.num_heads_kv, + Int32(1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, + clc_response_ptr=None, + *, + loc=None, + ip=None, + ) -> "DecodeTileScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + from cutlass.utils import ( + ClcDynamicPersistentTileScheduler, + ClcDynamicPersistentTileSchedulerParams, + ) + + cutlass_params = ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=DecodeTileScheduler._clc_grid_shape(params), + cluster_shape_mnk=(params.cluster_shape_m, 1, 1), + ) + block_idx = cute.arch.block_idx() + grid_dim = cute.arch.grid_dim() + clc_scheduler = ClcDynamicPersistentTileScheduler.create( + cutlass_params, + block_idx, + grid_dim, + clc_response_ptr, + ) + return DecodeTileScheduler( + params, + block_idx[0], + clc_scheduler, + clc_response_ptr=clc_response_ptr, + loc=loc, + ip=ip, + ) + + if const_expr(params.cluster_shape_m == 1): + task_idx = cute.arch.block_idx()[0] + else: + task_idx = cute.arch.cluster_idx()[0] + return DecodeTileScheduler(params, task_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return DecodeTileScheduler._clc_grid_shape(params) + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m + grid_x = cutlass.min(max_ctas, params.total_tasks * params.cluster_shape_m) + return (grid_x, Int32(1), Int32(1)) + + @cute.jit + def _task_to_work(self, task_idx: Int32, is_valid) -> WorkTileInfo: + work_idx, head_kv_idx = divmod(task_idx, self.params.num_heads_kv_divmod) + return WorkTileInfo( + (Int32(work_idx), Int32(head_kv_idx), Int32(0), Int32(0)), + is_valid, + ) + + @cute.jit + def _clc_work_to_coords(self, work) -> WorkTileInfo: + work_idx = work.tile_idx[0] + if const_expr(self.params.cluster_shape_m > 1): + work_idx = work_idx // self.params.cluster_shape_m + return WorkTileInfo( + ( + Int32(work_idx), + Int32(work.tile_idx[1]), + Int32(0), + Int32(0), + ), + work.is_valid_tile, + ) + + @cute.jit + def _clc_response_to_work( + self, + response_stage: Int32, + *, + loc=None, + ip=None, + ) -> WorkTileInfo: + # CLC responses are 16B opaque records. The scheduler warp can query + # the next stage before all consumer warps have read the current one, + # so each pipeline stage needs its own response slot. + response_ptr = self._clc_response_ptr + response_stage * Int32(4) + m_idx, n_idx, l_idx, is_valid = cute.arch.clc_response( + response_ptr, loc=loc, ip=ip) + cute.arch.fence_proxy("async.shared", space="cta") + cta_idx_in_cluster = cute.arch.block_idx()[0] % Int32( + self.params.cluster_shape_m) + return WorkTileInfo( + ( + Int32(m_idx) + cta_idx_in_cluster, + Int32(n_idx), + Int32(l_idx), + Int32(0), + ), + is_valid, + ) + + @cute.jit + def get_current_work( + self, + response_stage: Int32 = Int32(0), + *, + loc=None, + ip=None, + ) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self._clc_response_to_work( + response_stage, loc=loc, ip=ip) + self._task_idx = ( + work.tile_idx[0] * self.params.num_heads_kv + + work.tile_idx[1] + ) + return self._clc_work_to_coords(work) + is_valid = self._task_idx < self.params.total_tasks + return self._task_to_work(self._task_idx, is_valid) + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self._clc_scheduler.initial_work_tile_info() + self._task_idx = ( + work.tile_idx[0] * self.params.num_heads_kv + + work.tile_idx[1] + ) + return self._clc_work_to_coords(work) + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work( + self, + *, + loc=None, + ip=None, + mbarrier_addr=None, + response_stage: Int32 = Int32(0), + ): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + assert mbarrier_addr is not None + response_ptr = self._clc_response_ptr + response_stage * Int32(4) + with cute.arch.elect_one(): + cute.arch.issue_clc_query( + mbarrier_addr, response_ptr, loc=loc, ip=ip) + else: + assert mbarrier_addr is None + if const_expr(self.params.cluster_shape_m == 1): + self._task_idx += cute.arch.grid_dim()[0] + else: + self._task_idx += cute.arch.cluster_dim()[0] + + def consumer_advance(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + response_stage = self._clc_consumer_state.index + self._clc_pipeline.consumer_wait(self._clc_consumer_state) + work_tile = self.get_current_work(response_stage=response_stage) + self._clc_pipeline.consumer_release(self._clc_consumer_state) + self._clc_consumer_state.advance() + return work_tile + self.advance_to_next_work() + return self.get_current_work() + + def set_clc_pipeline(self, clc_pipeline, clc_consumer_state): + self._clc_pipeline = clc_pipeline + self._clc_consumer_state = clc_consumer_state + + def producer_tail(self, *, loc=None, ip=None): + pass + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + objs = [self.params, self._task_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [ + self._clc_scheduler, + self._clc_pipeline, + self._clc_consumer_state, + self._clc_response_ptr, + ] + for obj in objs: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + objs = [self.params, self._task_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [ + self._clc_scheduler, + self._clc_pipeline, + self._clc_consumer_state, + self._clc_response_ptr, + ] + for obj, n_items in zip(objs, self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return DecodeTileScheduler(*obj_list, loc=self._loc) diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/prepare_k2q_csr.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/prepare_k2q_csr.py new file mode 100644 index 0000000000000000000000000000000000000000..8e59b3d55bd3e9b164dac1e474dd648501c1aa51 --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/prepare_k2q_csr.py @@ -0,0 +1,227 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Sparse k2q CSR builder for SM100. + +Thin dispatcher that calls the CUDA C++ kernel pipeline in +``src.sm100.build_k2q_csr``. Supports ``topK in {4, 8, 16, 32}`` and +``blk_kv == 128`` only — other shapes raise ``ValueError`` rather than +silently falling back to a torch-reference path. +""" + +from __future__ import annotations + +from typing import Optional + +import torch + +from ...src.sm100.prepare_scheduler import SparseAttentionSchedule, SPARSE_SCHEDULE_MODEL + + +_SUPPORTED_TOPK = (4, 8, 16, 32) +_SUPPORTED_BLK_KV = 128 + + +def _ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +class SparseK2qCsrBuilderSm100: + """Build the k2q CSR reverse index for sparse attention on SM100. + + The public API matches the historical CUTE DSL builder so callers + (``sparse_index_utils.build_k2q_csr``, attention kernels) need no + changes. Internally the kernel pipeline runs five CUDA C++ kernels: + ``build_row_map`` -> ``hist`` -> ``row_prefix`` -> ``tile_prefix_smem`` + -> ``scatter`` (5 kernels + 2 ``cudaMemsetAsync``). + """ + + def __init__(self) -> None: + # No persistent state — the JIT-compiled extension is loaded + # lazily by ``src.sm100.build_k2q_csr`` on first call. + self._run = None + self._run_with_schedule = None + + def _ensure_loaded(self) -> None: + if self._run is None: + from ...src.sm100.build_k2q_csr import ( + run_build_k2q_csr, + run_build_k2q_csr_with_schedule, + ) + self._run = run_build_k2q_csr + self._run_with_schedule = run_build_k2q_csr_with_schedule + + def __call__( + self, + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + *, + total_k: int, + blk_kv: int = 128, + max_seqlen_k: Optional[int] = None, + max_seqlen_q: Optional[int] = None, + total_rows: Optional[int] = None, + qhead_per_kv: int = 1, + return_schedule: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, SparseAttentionSchedule]: + # ---- Validation ---------------------------------------------------- + if blk_kv != _SUPPORTED_BLK_KV: + raise ValueError( + f"SparseK2qCsrBuilderSm100 only supports blk_kv == " + f"{_SUPPORTED_BLK_KV}, got {blk_kv}" + ) + if q2k_indices.dtype != torch.int32: + raise TypeError( + f"q2k_indices must be torch.int32, got {q2k_indices.dtype}" + ) + if q2k_indices.ndim != 3: + raise ValueError( + f"q2k_indices must be rank-3 [head_kv, total_q, topK], " + f"got shape {tuple(q2k_indices.shape)}" + ) + if not q2k_indices.is_contiguous(): + raise ValueError("q2k_indices must be contiguous") + if cu_seqlens_q.dtype != torch.int32 or cu_seqlens_k.dtype != torch.int32: + raise TypeError("cu_seqlens_q and cu_seqlens_k must be torch.int32") + if cu_seqlens_q.ndim != 1 or cu_seqlens_k.ndim != 1: + raise ValueError("cu_seqlens_q and cu_seqlens_k must be rank-1") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError( + "cu_seqlens_q and cu_seqlens_k must share shape [B + 1]" + ) + if not (q2k_indices.is_cuda and cu_seqlens_q.is_cuda and cu_seqlens_k.is_cuda): + raise ValueError("all inputs must be CUDA tensors") + if ( + q2k_indices.device != cu_seqlens_q.device + or q2k_indices.device != cu_seqlens_k.device + ): + raise ValueError("all inputs must share a device") + if not cu_seqlens_q.is_contiguous() or not cu_seqlens_k.is_contiguous(): + raise ValueError("cu_seqlens_q and cu_seqlens_k must be contiguous") + + total_k = int(total_k) + if total_k < 0: + raise ValueError(f"total_k must be non-negative, got {total_k}") + + head_kv, total_q, topk = (int(v) for v in q2k_indices.shape) + if topk not in _SUPPORTED_TOPK: + raise ValueError( + f"SparseK2qCsrBuilderSm100 only supports topK in " + f"{_SUPPORTED_TOPK}, got {topk}" + ) + + batch = int(cu_seqlens_q.shape[0] - 1) + if batch < 0: + raise ValueError("cu_seqlens tensors must have shape [B + 1]") + if return_schedule and max_seqlen_k is None: + raise ValueError("build_k2q_csr requires max_seqlen_k when return_schedule=True") + max_k_tokens = int(max_seqlen_k) if max_seqlen_k is not None else total_k + max_kv_blocks = _ceil_div(max(max_k_tokens, blk_kv), blk_kv) + if total_rows is not None: + total_rows = int(total_rows) + elif total_k % blk_kv == 0: + total_rows = total_k // blk_kv + else: + total_rows = _ceil_div(total_k + batch * (blk_kv - 1), blk_kv) + if total_rows < 0: + raise ValueError(f"total_rows must be non-negative, got {total_rows}") + total_rows = max(total_rows, 0) + nnz_upper_bound = total_q * topk + qhead_per_kv = int(qhead_per_kv) + if qhead_per_kv <= 0: + raise ValueError(f"qhead_per_kv must be positive, got {qhead_per_kv}") + if return_schedule: + if max_seqlen_q is None: + raise ValueError("build_k2q_csr requires max_seqlen_q when return_schedule=True") + max_seqlen_q = int(max_seqlen_q) + + # ---- Output tensors ------------------------------------------------ + device = q2k_indices.device + k2q_row_ptr = torch.empty( + (head_kv, total_rows + 1), dtype=torch.int32, device=device, + ) + k2q_q_indices = torch.empty( + (head_kv, nnz_upper_bound), dtype=torch.int32, device=device, + ) + schedule = None + if return_schedule: + target_q_per_cta = SPARSE_SCHEDULE_MODEL.balanced_target_q_per_cta( + total_q=total_q, + topk=topk, + blk_kv=blk_kv, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + ) + scheduler_metadata_capacity = SPARSE_SCHEDULE_MODEL.flat_schedule_capacity( + total_rows=total_rows, + total_q=total_q, + topk=topk, + head_kv=head_kv, + target_q_per_cta=target_q_per_cta, + ) + scheduler_metadata = torch.empty( + (scheduler_metadata_capacity, 6), dtype=torch.int32, device=device + ) + work_count = torch.empty((1,), dtype=torch.int32, device=device) + qsplit_indices = torch.empty_like(k2q_q_indices) + split_counts = torch.empty( + (total_q, head_kv), dtype=torch.int32, device=device + ) + schedule = SparseAttentionSchedule( + enabled=True, + scheduler_metadata=scheduler_metadata, + work_count=work_count, + qsplit_indices=qsplit_indices, + split_counts=split_counts, + target_q_per_cta=target_q_per_cta, + ) + + # Empty workload short-circuit (the CUDA path also handles this, + # but doing it here saves a JIT load for trivial calls). + if total_rows == 0 or total_q == 0 or head_kv == 0 or topk == 0: + k2q_row_ptr.zero_() + k2q_q_indices.fill_(-1) + if schedule is not None: + schedule.work_count.zero_() + schedule.split_counts.zero_() + return k2q_row_ptr, k2q_q_indices, schedule + return k2q_row_ptr, k2q_q_indices + + self._ensure_loaded() + with torch.cuda.nvtx.range("SparseK2qCsr_Pipeline"): + if schedule is None: + self._run( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + k2q_row_ptr, + k2q_q_indices, + topk, + blk_kv, + total_rows, + max_kv_blocks, + ) + else: + self._run_with_schedule( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + k2q_row_ptr, + k2q_q_indices, + schedule.scheduler_metadata, + schedule.work_count, + schedule.qsplit_indices, + schedule.split_counts, + topk, + blk_kv, + total_rows, + max_kv_blocks, + schedule.target_q_per_cta, + schedule.work_capacity, + max_seqlen_q, + ) + if schedule is not None: + return k2q_row_ptr, k2q_q_indices, schedule + return k2q_row_ptr, k2q_q_indices diff --git a/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/prepare_scheduler.py b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/prepare_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..662e48f905249913a381f5d11a3f0c49626e98bd --- /dev/null +++ b/build/torch212-cxx11-cu130-x86_64-linux/src/sm100/prepare_scheduler.py @@ -0,0 +1,752 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Prepare scheduler for SM100 sparse attention. + +The scheduler converts uneven CSR k2q row fanout into a flat worklist consumed +by sparse attention kernels. Each work item covers a contiguous q-index range +within one (head_kv, csr row) and carries the decoded batch/KV-block coordinate. +""" + +from dataclasses import dataclass +from typing import Optional + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Int32, const_expr + +from ...src.common import copy_utils, utils +from ...src.common.cute_dsl_utils import ( + assume_tensor_aligned, + to_cute_tensor as to_cute_tensor_kvouter, +) + + +_PREPARE_COMPILE_CACHE: dict = {} + + +@dataclass +class SparseAttentionSchedule: + enabled: bool + scheduler_metadata: Optional[torch.Tensor] + work_count: Optional[torch.Tensor] + qsplit_indices: Optional[torch.Tensor] = None + split_counts: Optional[torch.Tensor] = None + target_q_per_cta: int = 0 + + @property + def work_capacity(self) -> int: + return 0 if self.scheduler_metadata is None else int(self.scheduler_metadata.shape[0]) + + +SparseSchedulePlan = SparseAttentionSchedule + + +class SparseAttentionScheduleModel: + """Host-side helpers for sparse attention schedule sizing.""" + + @staticmethod + def _round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + @staticmethod + def _ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + def _target_q_per_cta( + self, + *, + total_q: int, + topk: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + usable_SM_count: int = -1, + ) -> int: + num_sm = torch.cuda.get_device_properties(device).multi_processor_count + if usable_SM_count > 0: + num_sm = min(int(usable_SM_count), num_sm) + q_tokens_per_group = 128 // qhead_per_kv + total_refs_upper = total_q * topk * head_kv + desired_work_items = max(num_sm * 2, 1) + total_groups_upper = self._ceil_div(max(total_refs_upper, 1), q_tokens_per_group) + target_groups_per_cta = min( + 512, + max(1, self._ceil_div(total_groups_upper, desired_work_items)), + ) + return target_groups_per_cta * q_tokens_per_group + + def balanced_target_q_per_cta( + self, + *, + total_q: int, + topk: int, + blk_kv: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + usable_SM_count: int = -1, + ) -> int: + q_tokens_per_group = 128 // qhead_per_kv + occupancy_target = self._target_q_per_cta( + total_q=total_q, + topk=topk, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + usable_SM_count=usable_SM_count, + ) + sink_balance_cap = max(q_tokens_per_group, int(topk) * int(blk_kv) * 2) + target = min(max(occupancy_target, q_tokens_per_group), sink_balance_cap) + return self._round_up(target, q_tokens_per_group) + + def flat_schedule_capacity( + self, + *, + total_rows: int, + total_q: int, + topk: int, + head_kv: int, + target_q_per_cta: int, + ) -> int: + row_upper = max(total_rows, 0) * max(head_kv, 1) + refs_upper = max(total_q, 0) * max(topk, 1) * max(head_kv, 1) + split_upper = self._ceil_div(max(refs_upper, 1), max(target_q_per_cta, 1)) + return max(1, row_upper + split_upper) + + +SPARSE_SCHEDULE_MODEL = SparseAttentionScheduleModel() + + +class SparseAttentionPrepareFlatScheduleSm100: + """Build a compact flat worklist by splitting each CSR row into chunks.""" + + def __init__( + self, + *, + num_threads: int = 128, + ): + if num_threads % 32 != 0: + raise ValueError(f"num_threads must be a multiple of 32, got {num_threads}") + self.num_threads = num_threads + self.warps_per_cta = num_threads // 32 + + @cute.jit + def _emit_work( + self, + mSchedulerMetadata: cute.Tensor, + work_idx: Int32, + work_capacity: Int32, + head_kv_idx: Int32, + row_linear: Int32, + q_begin: Int32, + q_count: Int32, + batch_idx: Int32, + kv_block_idx: Int32, + ): + if work_idx < work_capacity: + mSchedulerMetadata[work_idx, Int32(0)] = head_kv_idx + mSchedulerMetadata[work_idx, Int32(1)] = row_linear + mSchedulerMetadata[work_idx, Int32(2)] = q_begin + mSchedulerMetadata[work_idx, Int32(3)] = q_count + mSchedulerMetadata[work_idx, Int32(4)] = batch_idx + mSchedulerMetadata[work_idx, Int32(5)] = kv_block_idx + + @cute.jit + def _rows_in_batch( + self, + mCuSeqlensK: cute.Tensor, + batch_idx: Int32, + blk_kv: Int32, + ) -> Int32: + seqlen = mCuSeqlensK[batch_idx + Int32(1)] - mCuSeqlensK[batch_idx] + return (seqlen + blk_kv - Int32(1)) // blk_kv + + @cute.jit + def _rows_before_level( + self, + mCuSeqlensK: cute.Tensor, + level: Int32, + blk_kv: Int32, + ) -> Int32: + total = Int32(0) + batch = mCuSeqlensK.shape[0] - Int32(1) + for b in cutlass.range(batch, unroll=1): + rows = self._rows_in_batch(mCuSeqlensK, b, blk_kv) + total += cutlass.min(rows, level) + return total + + @cute.jit + def _max_rows_per_batch( + self, + mCuSeqlensK: cute.Tensor, + blk_kv: Int32, + ) -> Int32: + max_rows = Int32(0) + batch = mCuSeqlensK.shape[0] - Int32(1) + for b in cutlass.range(batch, unroll=1): + rows = self._rows_in_batch(mCuSeqlensK, b, blk_kv) + max_rows = cutlass.max(max_rows, rows) + return max_rows + + @cute.jit + def _decode_sparse_row_linear( + self, + mCuSeqlensK: cute.Tensor, + row_linear: Int32, + blk_kv: Int32, + ) -> tuple[Int32, Int32]: + lo = Int32(0) + hi = self._max_rows_per_batch(mCuSeqlensK, blk_kv) + while lo < hi: + mid = (lo + hi) // Int32(2) + rows_before_next = self._rows_before_level( + mCuSeqlensK, + mid + Int32(1), + blk_kv, + ) + if rows_before_next <= row_linear: + lo = mid + Int32(1) + else: + hi = mid + + level = lo + offset = row_linear - self._rows_before_level(mCuSeqlensK, level, blk_kv) + active_idx = Int32(0) + batch_idx = Int32(0) + found = Int32(0) + batch = mCuSeqlensK.shape[0] - Int32(1) + for b in cutlass.range(batch, unroll=1): + if found == Int32(0): + rows = self._rows_in_batch(mCuSeqlensK, b, blk_kv) + if rows > level: + if active_idx == offset: + batch_idx = b + found = Int32(1) + active_idx += Int32(1) + return batch_idx, level + + @cute.jit + def __call__( + self, + mK2qCounts: cute.Tensor, + mCuSeqlensK: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + target_q_per_cta: Int32, + work_capacity: Int32, + num_heads_kv: Int32, + blk_kv: Int32, + stream: cuda.CUstream = None, + ): + if const_expr(mK2qCounts.element_type != Int32): + raise TypeError("mK2qCounts must be Int32") + if const_expr(mCuSeqlensK.element_type != Int32): + raise TypeError("mCuSeqlensK must be Int32") + if const_expr(mSchedulerMetadata.element_type != Int32): + raise TypeError("mSchedulerMetadata must be Int32") + if const_expr(mWorkCount.element_type != Int32): + raise TypeError("mWorkCount must be Int32") + mK2qCounts, mCuSeqlensK, mSchedulerMetadata, mWorkCount = [ + assume_tensor_aligned(t) + for t in (mK2qCounts, mCuSeqlensK, mSchedulerMetadata, mWorkCount) + ] + total_rows = mK2qCounts.shape[1] - Int32(1) + total_row_heads = total_rows * num_heads_kv + grid_ctas = cute.ceil_div(total_row_heads, self.warps_per_cta) + + self.kernel( + mK2qCounts, + mCuSeqlensK, + mSchedulerMetadata, + mWorkCount, + target_q_per_cta, + work_capacity, + num_heads_kv, + total_rows, + blk_kv, + ).launch( + grid=(grid_ctas,), + block=[self.num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mK2qCounts: cute.Tensor, + mCuSeqlensK: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + target_q_per_cta: Int32, + work_capacity: Int32, + num_heads_kv: Int32, + total_rows: Int32, + blk_kv: Int32, + ): + tidx = cute.arch.thread_idx()[0] + block_idx = cute.arch.block_idx()[0] + lane_idx = tidx % Int32(32) + warp_idx = tidx // Int32(32) + row_head_idx = block_idx * Int32(self.warps_per_cta) + warp_idx + total_row_heads = total_rows * num_heads_kv + + head_kv_idx = Int32(0) + row_linear = Int32(0) + row_count = Int32(0) + num_chunks = Int32(0) + batch_idx = Int32(0) + kv_block_idx = Int32(0) + if row_head_idx < total_row_heads: + row_linear = row_head_idx // num_heads_kv + head_kv_idx = row_head_idx - row_linear * num_heads_kv + if lane_idx == Int32(0): + row_start = mK2qCounts[head_kv_idx, row_linear] + row_end = mK2qCounts[head_kv_idx, row_linear + Int32(1)] + row_count = row_end - row_start + batch_idx, kv_block_idx = self._decode_sparse_row_linear( + mCuSeqlensK, + row_linear, + blk_kv, + ) + if row_count > Int32(0): + num_chunks = ( + row_count + target_q_per_cta - Int32(1) + ) // target_q_per_cta + row_count = cute.arch.shuffle_sync(row_count, offset=0) + num_chunks = cute.arch.shuffle_sync(num_chunks, offset=0) + batch_idx = cute.arch.shuffle_sync(batch_idx, offset=0) + kv_block_idx = cute.arch.shuffle_sync(kv_block_idx, offset=0) + + chunk_idx = lane_idx + while chunk_idx < num_chunks: + work_idx = cute.arch.atomic_add( + mWorkCount.iterator.llvm_ptr, + Int32(1), + sem="relaxed", + scope="gpu", + ) + q_begin = chunk_idx * target_q_per_cta + q_count = cutlass.min(target_q_per_cta, row_count - q_begin) + self._emit_work( + mSchedulerMetadata, + work_idx, + work_capacity, + head_kv_idx, + row_linear, + q_begin, + q_count, + batch_idx, + kv_block_idx, + ) + chunk_idx += Int32(32) + + +class SparseAttentionPrepareFwdSplitAtomicSm100: + """Build packed q_idx/split_slot metadata for fwd K1 without K1 atomics.""" + + def __init__( + self, + *, + num_threads: int = 256, + ): + if num_threads % 32 != 0: + raise ValueError(f"num_threads must be a multiple of 32, got {num_threads}") + self.num_threads = num_threads + + @cute.struct + class SharedStorage: + sRow: cute.struct.MemRange[Int32, 3] + + self.shared_storage = SharedStorage + + @cute.jit + def __call__( + self, + mK2qCounts: cute.Tensor, + mK2qIndices: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mSplitCounts: cute.Tensor, + mCuSeqlensQ: cute.Tensor, + work_capacity: Int32, + max_seqlen_q: Int32, + topk: Int32, + stream: cuda.CUstream = None, + ): + if const_expr(mK2qCounts.element_type != Int32): + raise TypeError("mK2qCounts must be Int32") + if const_expr(mK2qIndices.element_type != Int32): + raise TypeError("mK2qIndices must be Int32") + if const_expr(mSchedulerMetadata.element_type != Int32): + raise TypeError("mSchedulerMetadata must be Int32") + if const_expr(mWorkCount.element_type != Int32): + raise TypeError("mWorkCount must be Int32") + if const_expr(mK2qQSplitIndices.element_type != Int32): + raise TypeError("mK2qQSplitIndices must be Int32") + if const_expr(mSplitCounts.element_type != Int32): + raise TypeError("mSplitCounts must be Int32") + if const_expr(mCuSeqlensQ.element_type != Int32): + raise TypeError("mCuSeqlensQ must be Int32") + ( + mK2qCounts, + mK2qIndices, + mSchedulerMetadata, + mWorkCount, + mK2qQSplitIndices, + mSplitCounts, + mCuSeqlensQ, + ) = [ + assume_tensor_aligned(t) + for t in ( + mK2qCounts, + mK2qIndices, + mSchedulerMetadata, + mWorkCount, + mK2qQSplitIndices, + mSplitCounts, + mCuSeqlensQ, + ) + ] + self.kernel( + mK2qCounts, + mK2qIndices, + mSchedulerMetadata, + mWorkCount, + mK2qQSplitIndices, + mSplitCounts, + mCuSeqlensQ, + max_seqlen_q, + topk, + ).launch( + grid=(work_capacity,), + block=[self.num_threads, 1, 1], + smem=self.shared_storage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mK2qCounts: cute.Tensor, + mK2qIndices: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mSplitCounts: cute.Tensor, + mCuSeqlensQ: cute.Tensor, + max_seqlen_q: Int32, + topk: Int32, + ): + tidx = cute.arch.thread_idx()[0] + block_idx = cute.arch.block_idx()[0] + if block_idx < mWorkCount[Int32(0)]: + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sRow = storage.sRow.get_tensor(cute.make_layout((3,))) + head_kv_idx = mSchedulerMetadata[block_idx, Int32(0)] + row_linear = mSchedulerMetadata[block_idx, Int32(1)] + q_begin = mSchedulerMetadata[block_idx, Int32(2)] + q_count = mSchedulerMetadata[block_idx, Int32(3)] + batch_idx_t0 = mSchedulerMetadata[block_idx, Int32(4)] + + if tidx == Int32(0): + row_start_t0 = mK2qCounts[head_kv_idx, row_linear] + q_begin + sRow[0] = row_start_t0 + sRow[1] = q_count + sRow[2] = batch_idx_t0 + cute.arch.barrier() + row_start = sRow[0] + row_count = sRow[1] + batch_idx = sRow[2] + qi = tidx + while qi < row_count: + edge = row_start + qi + q_idx = mK2qIndices[head_kv_idx, edge] + if q_idx >= Int32(0) and q_idx < max_seqlen_q: + q_abs = mCuSeqlensQ[batch_idx] + q_idx + split_ptr = utils.elem_pointer( + mSplitCounts, + (q_abs, head_kv_idx), + ) + split_slot = copy_utils.atomic_add_i32(split_ptr) + if split_slot < topk: + mK2qQSplitIndices[head_kv_idx, edge] = ( + q_idx | ((split_slot & Int32(0xFF)) << Int32(24)) + ) + qi += Int32(self.num_threads) + + +def _get_sparse_prepare_fwd_split_atomic( + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + scheduler_metadata: torch.Tensor, + work_count: torch.Tensor, + k2q_qsplit_indices: torch.Tensor, + split_counts: torch.Tensor, + cu_seqlens_q: torch.Tensor, + work_capacity: int, + max_seqlen_q: int, + topk: int, +): + key = ( + "sparse_prepare_fwd_split_atomic_sm100_csr_varlen", + ) + if key not in _PREPARE_COMPILE_CACHE: + from ...src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _PREPARE_COMPILE_CACHE[key] = loaded + else: + kernel = SparseAttentionPrepareFwdSplitAtomicSm100() + _PREPARE_COMPILE_CACHE[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k2q_row_ptr), + to_cute_tensor_kvouter(k2q_q_indices), + to_cute_tensor_kvouter(scheduler_metadata), + to_cute_tensor_kvouter(work_count), + to_cute_tensor_kvouter(k2q_qsplit_indices), + to_cute_tensor_kvouter(split_counts), + to_cute_tensor_kvouter(cu_seqlens_q), + Int32(work_capacity), + Int32(max_seqlen_q), + Int32(topk), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _PREPARE_COMPILE_CACHE[key]) + return _PREPARE_COMPILE_CACHE[key] + + +def _get_sparse_prepare_flat_schedule( + k2q_row_ptr: torch.Tensor, + cu_seqlens_k: torch.Tensor, + scheduler_metadata: torch.Tensor, + work_count: torch.Tensor, + target_q_per_cta: int, + scheduler_metadata_capacity: int, + head_kv: int, + blk_kv: int, +): + key = ( + "sparse_prepare_flat_schedule_sm100_csr_varlen", + ) + if key not in _PREPARE_COMPILE_CACHE: + from ...src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _PREPARE_COMPILE_CACHE[key] = loaded + else: + kernel = SparseAttentionPrepareFlatScheduleSm100() + _PREPARE_COMPILE_CACHE[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k2q_row_ptr), + to_cute_tensor_kvouter(cu_seqlens_k), + to_cute_tensor_kvouter(scheduler_metadata), + to_cute_tensor_kvouter(work_count), + Int32(target_q_per_cta), + Int32(scheduler_metadata_capacity), + Int32(head_kv), + Int32(blk_kv), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _PREPARE_COMPILE_CACHE[key]) + return _PREPARE_COMPILE_CACHE[key] + + +def prepare_sparse_flat_schedule( + *, + k2q_row_ptr: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_q: int, + topk: int, + blk_kv: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + enabled: bool, + usable_SM_count: int = -1, +) -> SparseSchedulePlan: + if not enabled: + return SparseSchedulePlan(enabled=False, scheduler_metadata=None, work_count=None) + + total_rows = int(k2q_row_ptr.shape[1] - 1) + if total_rows <= 0 or head_kv <= 0: + return SparseSchedulePlan(enabled=False, scheduler_metadata=None, work_count=None) + if cu_seqlens_k.dtype != torch.int32: + raise TypeError("cu_seqlens_k must be torch.int32") + + target_q_per_cta = SPARSE_SCHEDULE_MODEL.balanced_target_q_per_cta( + total_q=total_q, + topk=topk, + blk_kv=blk_kv, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + usable_SM_count=usable_SM_count, + ) + scheduler_metadata_capacity = SPARSE_SCHEDULE_MODEL.flat_schedule_capacity( + total_rows=total_rows, + total_q=total_q, + topk=topk, + head_kv=head_kv, + target_q_per_cta=target_q_per_cta, + ) + scheduler_metadata = torch.empty( + (scheduler_metadata_capacity, 6), + dtype=torch.int32, + device=device, + ) + work_count = torch.zeros((1,), dtype=torch.int32, device=device) + scheduler_metadata.zero_() + + compiled_prepare = _get_sparse_prepare_flat_schedule( + k2q_row_ptr, + cu_seqlens_k, + scheduler_metadata, + work_count, + target_q_per_cta, + scheduler_metadata_capacity, + head_kv, + blk_kv, + ) + with torch.cuda.nvtx.range("SparseAttention_PrepareFlatSchedule"): + compiled_prepare( + k2q_row_ptr, + cu_seqlens_k, + scheduler_metadata, + work_count, + target_q_per_cta, + scheduler_metadata_capacity, + head_kv, + blk_kv, + ) + + return SparseSchedulePlan( + enabled=True, + scheduler_metadata=scheduler_metadata, + work_count=work_count, + target_q_per_cta=target_q_per_cta, + ) + +def prepare_sparse_fwd_schedule_and_split( + *, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + k2q_qsplit_indices: torch.Tensor, + split_counts: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_q: int, + max_seqlen_q: int, + topk: int, + head_kv: int, + qhead_per_kv: int, + blk_kv: int, + device: torch.device, + enabled: bool, + usable_SM_count: int = -1, +) -> SparseSchedulePlan: + plan = prepare_sparse_fwd_schedule( + k2q_row_ptr=k2q_row_ptr, + cu_seqlens_k=cu_seqlens_k, + total_q=total_q, + topk=topk, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + blk_kv=blk_kv, + device=device, + enabled=enabled, + usable_SM_count=usable_SM_count, + ) + if not plan.enabled: + return plan + if plan.scheduler_metadata is None or plan.work_count is None: + raise RuntimeError("fwd GPU schedule requires metadata") + if topk > 255: + raise ValueError(f"packed qsplit metadata supports topK <= 255, got {topk}") + if max_seqlen_q >= (1 << 24): + raise ValueError( + "packed qsplit metadata supports batch-local q_idx < 2^24, " + f"got max_seqlen_q={max_seqlen_q}" + ) + if k2q_qsplit_indices.shape != k2q_q_indices.shape: + raise ValueError("k2q_qsplit_indices shape must match k2q_q_indices") + if split_counts.dtype != torch.int32 or k2q_qsplit_indices.dtype != torch.int32: + raise TypeError("split metadata tensors must be torch.int32") + if split_counts.shape != (total_q, head_kv): + raise ValueError( + f"split_counts must have shape ({total_q}, {head_kv}), got {tuple(split_counts.shape)}" + ) + if cu_seqlens_q.dtype != torch.int32: + raise TypeError("cu_seqlens_q must be torch.int32") + if cu_seqlens_q.ndim != 1 or not cu_seqlens_q.is_contiguous(): + raise ValueError("cu_seqlens_q must be a contiguous rank-1 tensor") + if cu_seqlens_k.dtype != torch.int32: + raise TypeError("cu_seqlens_k must be torch.int32") + + with torch.cuda.nvtx.range("SparseAttention_InitFwdSplitState"): + split_counts.zero_() + + compiled_split = _get_sparse_prepare_fwd_split_atomic( + k2q_row_ptr, + k2q_q_indices, + plan.scheduler_metadata, + plan.work_count, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + plan.work_capacity, + max_seqlen_q, + topk, + ) + with torch.cuda.nvtx.range("SparseAttention_PrepareFwdSplit_Atomic"): + compiled_split( + k2q_row_ptr, + k2q_q_indices, + plan.scheduler_metadata, + plan.work_count, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + plan.work_capacity, + max_seqlen_q, + topk, + ) + plan.qsplit_indices = k2q_qsplit_indices + plan.split_counts = split_counts + return plan + + +def prepare_sparse_fwd_schedule( + *, + k2q_row_ptr: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_q: int, + topk: int, + blk_kv: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + enabled: bool, + usable_SM_count: int = -1, +) -> SparseSchedulePlan: + return prepare_sparse_flat_schedule( + k2q_row_ptr=k2q_row_ptr, + cu_seqlens_k=cu_seqlens_k, + total_q=int(total_q), + topk=int(topk), + blk_kv=int(blk_kv), + head_kv=int(head_kv), + qhead_per_kv=int(qhead_per_kv), + device=device, + enabled=bool(enabled), + usable_SM_count=int(usable_SM_count), + ) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7d5e4ade468de366bb73eed0ccb38d4e358cf8 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/__init__.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""MiniMax Sparse Attention (MSA) CuTe-DSL kernels for NVIDIA SM100. + +Hub-kernel packaging of the CuTe-DSL sparse attention stack from +https://github.com/MiniMax-AI/MSA (``python/fmha_sm100/cute``). The +host-side helper kernels (CSR builder, decode scheduler) are precompiled +Torch ops; the attention kernels are compiled at runtime through +nvidia-cutlass-dsl. +""" + +# Sparse attention forward / decode. +from .interface import ( + SparseDecodePagedAttentionWrapper, + sparse_atten_func, + sparse_atten_nvfp4_kv_func, + sparse_decode_atten_func, +) + +# CSR + schedule construction. +from .sparse_index_utils import build_k2q_csr + +# SM100 fused CSR builder. +from .src.sm100.prepare_k2q_csr import SparseK2qCsrBuilderSm100 + +# FP4 block-score indexer. Returns per-(Hq, kv_block, q) max scores; topK +# selection + q2k construction remain caller-owned downstream steps. +from .fp4_indexer_interface import fp4_indexer_block_scores + +# NVFP4 quantization helpers used to feed the FP4 indexer / NVFP4 attention. +from .quantize import ( + Nvfp4QuantizedTensor, + dequantize_nvfp4_128x4_to_bf16, + nvfp4_global_scale_from_amax, + quantize_bf16_to_nvfp4_128x4, + quantize_kv_bf16_to_nvfp4_128x4, + swizzle_nvfp4_scale_to_128x4, +) + +__version__ = "0.1.1" + +__all__ = [ + # attention + "sparse_atten_func", + "sparse_atten_nvfp4_kv_func", + "sparse_decode_atten_func", + "SparseDecodePagedAttentionWrapper", + # indexing / CSR + "fp4_indexer_block_scores", + "build_k2q_csr", + "SparseK2qCsrBuilderSm100", + # nvfp4 quantization helpers + "Nvfp4QuantizedTensor", + "quantize_bf16_to_nvfp4_128x4", + "quantize_kv_bf16_to_nvfp4_128x4", + "dequantize_nvfp4_128x4_to_bf16", + "swizzle_nvfp4_scale_to_128x4", + "nvfp4_global_scale_from_amax", +] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_msa_cuda_09d7851.abi3.so b/build/torch212-cxx11-cu132-x86_64-linux/_msa_cuda_09d7851.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..0aff2ac5b7d09bf6cb1ae44cb7befb137bbc9530 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/_msa_cuda_09d7851.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d91c623337055f07630205f86db31d922124d14ea12b3ee6aa092ebdfb405a8 +size 1414048 diff --git a/build/torch212-cxx11-cu132-x86_64-linux/_ops.py b/build/torch212-cxx11-cu132-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6be2da4d5d784683e9e2fb8bfe08e93847dc6640 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _msa_cuda_09d7851 +ops = torch.ops._msa_cuda_09d7851 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_msa_cuda_09d7851::{op_name}" diff --git a/build/torch212-cxx11-cu132-x86_64-linux/fp4_indexer_interface.py b/build/torch212-cxx11-cu132-x86_64-linux/fp4_indexer_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..48dc1d05480355d2af4f4e47142ae4cd692184b0 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/fp4_indexer_interface.py @@ -0,0 +1,1061 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Public FP4 sparse-attention indexer block-score interface.""" + +from __future__ import annotations + +from typing import Optional + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Int32 +from cutlass.cute.runtime import make_ptr + +from .src.sm100.fp4_indexer import ( + Fp4FormatSpec, + Fp4IndexerDecodePackedQSm100, + Fp4IndexerDecodeQPackSm100, + Fp4IndexerScaleReorderSm100, + Fp4IndexerStagedMmaSm100, + _BLOCK_K, + _DECODE_K_TILES_PER_CTA, + _DECODE_PACK_Q_LEN, + _DECODE_QHEAD_PER_KV, + _FP4_PACKED_D_BYTES, + _HEAD_DIM, + _MMA_TILER_MN, + _PAGE_SIZE, + ceil_div, + k_tiles_per_cta_for, + normalize_fp4_format, +) + + +_PUBLIC_SCALE_LAYOUT = "public" +_PREORDERED_MMA_SCALE_LAYOUT = "preordered_mma" +_FP4_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +def _device_arch(device: torch.device) -> tuple[int, int]: + major, minor = torch.cuda.get_device_capability(device) + return int(major), int(minor) + + +def _supports_tmem_load_red(device_arch: tuple[int, int]) -> bool: + return device_arch >= (10, 3) + + +def normalize_scale_layout(scale_layout: str) -> str: + """Normalize and validate FP4 indexer scale layout mode. + + Parameters + ---------- + scale_layout : str + Either ``"public"`` for logical scale tensors or ``"preordered_mma"`` + for tensors already laid out with ``fp4_indexer_mma_scale_storage_*``. + + Returns + ------- + str + The normalized scale layout string. + """ + + scale_layout = str(scale_layout) + if scale_layout not in (_PUBLIC_SCALE_LAYOUT, _PREORDERED_MMA_SCALE_LAYOUT): + raise ValueError(f"scale_layout must be 'public' or 'preordered_mma', got {scale_layout!r}") + return scale_layout + + +def _causal_compact_task_count(q_len: int, k_len: int, k_tiles_per_cta: int) -> int: + if q_len <= 0 or k_len <= 0: + return 0 + q_tile_count = ceil_div(q_len, _MMA_TILER_MN[0]) + k_group_count = ceil_div(ceil_div(k_len, _PAGE_SIZE), k_tiles_per_cta) + group_tokens = k_tiles_per_cta * _BLOCK_K + causal_offset = int(k_len) - int(q_len) + tasks = 0 + for q_tile_idx in range(q_tile_count): + q_tile_start = q_tile_idx * _MMA_TILER_MN[0] + q_tile_last = min(q_tile_start + _MMA_TILER_MN[0] - 1, int(q_len) - 1) + visible_limit = q_tile_last + causal_offset + if visible_limit >= 0: + tasks += min(k_group_count, visible_limit // group_tokens + 1) + return tasks + + +def _causal_compact_task_bound(max_q_len: int, max_k_len: int, k_tiles_per_cta: int) -> int: + """Conservative X-grid bound for per-batch causal prefill compact mapping.""" + + if max_q_len <= 0 or max_k_len <= 0: + return 0 + q_tile_count = ceil_div(max_q_len, _MMA_TILER_MN[0]) + candidates = {int(max_q_len)} + for q_tile_idx in range(q_tile_count): + q_len = q_tile_idx * _MMA_TILER_MN[0] + 1 + if q_len <= max_q_len: + candidates.add(q_len) + return max(_causal_compact_task_count(q_len, max_k_len, k_tiles_per_cta) for q_len in candidates) + + +def _require_cuda_tensor(tensor: torch.Tensor, *, name: str) -> None: + if not tensor.is_cuda: + raise ValueError(f"{name} must be a CUDA tensor") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _require_int32_vector(tensor: torch.Tensor, *, name: str, device: torch.device) -> None: + if tensor.device != device: + raise ValueError(f"{name} must be on the same CUDA device") + if tensor.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32") + if tensor.ndim != 1: + raise ValueError(f"{name} must be rank-1") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _require_fp4_packed_dtype(tensor: torch.Tensor, *, name: str) -> None: + fp4_x2_dtype = getattr(torch, "float4_e2m1fn_x2", None) + allowed = {torch.uint8, torch.int8} + if fp4_x2_dtype is not None: + allowed.add(fp4_x2_dtype) + if tensor.dtype not in allowed: + raise TypeError(f"{name} must use packed FP4 storage dtype, got {tensor.dtype}") + + +def _as_fp4_thd_bytes(tensor: torch.Tensor, *, name: str) -> torch.Tensor: + if tensor.ndim != 3: + raise ValueError(f"{name} must have shape [total_q, Hq, 64]") + if int(tensor.shape[-1]) != _FP4_PACKED_D_BYTES: + raise ValueError(f"{name}.shape[-1] must be 64 packed bytes for D=128") + _require_fp4_packed_dtype(tensor, name=name) + if tensor.dtype == torch.uint8: + return tensor + return tensor.view(torch.uint8) + + +def _as_fp4_paged_hnd_bytes(tensor: torch.Tensor, *, name: str) -> torch.Tensor: + if tensor.ndim != 4: + raise ValueError(f"{name} must have shape [total_pages, Hk, 128, 64]") + if int(tensor.shape[-2]) != _PAGE_SIZE: + raise ValueError(f"{name}.shape[-2] must be 128") + if int(tensor.shape[-1]) != _FP4_PACKED_D_BYTES: + raise ValueError(f"{name}.shape[-1] must be 64 packed bytes for D=128") + _require_fp4_packed_dtype(tensor, name=name) + if tensor.dtype == torch.uint8: + return tensor + return tensor.view(torch.uint8) + + +def validate_q_scale_thg( + scale: torch.Tensor, + *, + name: str, + fmt: Fp4FormatSpec, + total_q: int, + heads: int, +) -> None: + """Validate public Q FP4 scale layout ``[total_q, Hq, G]``. + + Parameters + ---------- + scale : torch.Tensor + Logical Q scale tensor. + name : str + Name used in validation error messages. + fmt : Fp4FormatSpec + FP4 format specification from ``normalize_fp4_format``. + total_q : int + Total query token count. + heads : int + Number of Q heads. + """ + + expected = (int(total_q), int(heads), fmt.scale_groups) + if tuple(scale.shape) != expected: + raise ValueError(f"{name} must have shape {expected}, got {tuple(scale.shape)}") + if scale.dtype != fmt.torch_scale_dtype: + raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}") + if not scale.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def validate_k_scale_phsg( + scale: torch.Tensor, + *, + name: str, + fmt: Fp4FormatSpec, + page_count: int, + heads: int, +) -> None: + """Validate public K FP4 scale layout ``[page_count, Hk, 128, G]``. + + Parameters + ---------- + scale : torch.Tensor + Logical K scale tensor. + name : str + Name used in validation error messages. + fmt : Fp4FormatSpec + FP4 format specification from ``normalize_fp4_format``. + page_count : int + Number of physical KV pages. + heads : int + Number of KV heads. + """ + + expected = (int(page_count), int(heads), _PAGE_SIZE, fmt.scale_groups) + if tuple(scale.shape) != expected: + raise ValueError(f"{name} must have shape {expected}, got {tuple(scale.shape)}") + if scale.dtype != fmt.torch_scale_dtype: + raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}") + if not scale.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def fp4_indexer_mma_scale_shape(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return semantic MMA scale view shape ``(32,4,restM,4,restG,L)``.""" + + spec = normalize_fp4_format(fp4_format) + return (32, 4, ceil_div(mn, 128), 4, ceil_div(spec.scale_groups, 4), int(l)) + + +def fp4_indexer_mma_scale_stride(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return element strides for ``fp4_indexer_mma_scale_shape``.""" + + spec = normalize_fp4_format(fp4_format) + rest_m = ceil_div(mn, 128) + rest_g = ceil_div(spec.scale_groups, 4) + return (16, 4, 512 * rest_g, 1, 512, 512 * rest_m * rest_g) + + +def fp4_indexer_mma_scale_storage_shape(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return contiguous storage shape for preordered MMA scales.""" + + spec = normalize_fp4_format(fp4_format) + return (int(l), ceil_div(mn, 128), ceil_div(spec.scale_groups, 4), 32, 4, 4) + + +def fp4_indexer_mma_scale_storage_stride(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]: + """Return element strides for ``fp4_indexer_mma_scale_storage_shape``.""" + + spec = normalize_fp4_format(fp4_format) + rest_m = ceil_div(mn, 128) + rest_g = ceil_div(spec.scale_groups, 4) + return (512 * rest_m * rest_g, 512 * rest_g, 512, 16, 4, 1) + + +def validate_mma_scale_storage( + scale: torch.Tensor, + *, + name: str, + fmt: Fp4FormatSpec, + mn: int, + l: int, +) -> None: + """Validate preordered MMA scale storage expected by the FP4 indexer. + + Parameters + ---------- + scale : torch.Tensor + Tensor view whose shape/stride should match + ``fp4_indexer_mma_scale_storage_shape`` and + ``fp4_indexer_mma_scale_storage_stride``. + name : str + Name used in validation error messages. + fmt : Fp4FormatSpec + FP4 format specification from ``normalize_fp4_format``. + mn : int + Logical M/N extent of the scale domain. + l : int + Logical batch/head extent folded into the final layout dimension. + """ + + expected_shape = fp4_indexer_mma_scale_storage_shape(mn, l, fp4_format=fmt.name) + expected_stride = fp4_indexer_mma_scale_storage_stride(mn, l, fp4_format=fmt.name) + if tuple(scale.shape) != expected_shape: + raise ValueError(f"{name} must have MMA storage shape {expected_shape}, got {tuple(scale.shape)}") + if tuple(scale.stride()) != expected_stride: + raise ValueError(f"{name} must have MMA storage stride {expected_stride}, got {tuple(scale.stride())}") + if scale.dtype != fmt.torch_scale_dtype: + raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}") + + +def _empty_mma_scale_tensor( + *, + mn: int, + l: int, + spec: Fp4FormatSpec, + device: torch.device, +) -> torch.Tensor: + rest_m = ceil_div(mn, 128) + rest_g = ceil_div(spec.scale_groups, 4) + storage = torch.empty( + (int(l), rest_m, rest_g, 32, 4, 4), + dtype=spec.torch_scale_dtype, + device=device, + ) + return storage.permute(3, 4, 1, 5, 2, 0) + + +def _compile_fp4_scale_reorder_kernel( + *, + fmt: Fp4FormatSpec, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + q_scale_mma_ptr: cute.Pointer, + k_scale_mma_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_scale_reorder_sm100_1cta", + fmt.name, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerScaleReorderSm100(fmt=fmt.name) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_scale_ptr, + k_scale_ptr, + q_scale_mma_ptr, + k_scale_mma_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def fp4_indexer_reorder_scales_for_mma_cute( + q_scale: torch.Tensor, + k_scale: torch.Tensor, + *, + fp4_format: str, +) -> tuple[torch.Tensor, torch.Tensor]: + """Reorder public Q/K FP4 scales to MMA-friendly storage. + + Parameters + ---------- + q_scale : torch.Tensor + Public Q scale tensor with shape ``[total_q, Hq, G]``. + k_scale : torch.Tensor + Public K scale tensor with shape ``[page_count, Hk, 128, G]``. + fp4_format : str + ``"mxfp4"`` or ``"nvfp4"``. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(q_scale_mma, k_scale_mma)`` views in the storage layout validated by + ``validate_mma_scale_storage``. These tensors can be passed to + ``fp4_indexer_block_scores`` with ``scale_layout="preordered_mma"``. + """ + + spec = normalize_fp4_format(fp4_format) + if q_scale.device != k_scale.device: + raise ValueError("q_scale and k_scale must be on the same CUDA device") + _require_cuda_tensor(q_scale, name="q_scale") + _require_cuda_tensor(k_scale, name="k_scale") + if q_scale.ndim != 3: + raise ValueError(f"q_scale must have shape [total_q, Hq, G], got {tuple(q_scale.shape)}") + if k_scale.ndim != 4: + raise ValueError(f"k_scale must have shape [page_count, Hk, 128, G], got {tuple(k_scale.shape)}") + total_q, heads_q, _ = (int(v) for v in q_scale.shape) + page_count, heads_k, _, _ = (int(v) for v in k_scale.shape) + validate_q_scale_thg(q_scale, name="q_scale", fmt=spec, total_q=total_q, heads=heads_q) + validate_k_scale_phsg(k_scale, name="k_scale", fmt=spec, page_count=page_count, heads=heads_k) + + q_scale_mma = _empty_mma_scale_tensor( + mn=total_q, + l=heads_q, + spec=spec, + device=q_scale.device, + ) + k_scale_mma = _empty_mma_scale_tensor( + mn=_PAGE_SIZE, + l=page_count * heads_k, + spec=spec, + device=k_scale.device, + ) + + q_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + q_scale.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + k_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + k_scale.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + q_scale_mma_ptr = make_ptr( + spec.cutlass_scale_dtype, + q_scale_mma.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + k_scale_mma_ptr = make_ptr( + spec.cutlass_scale_dtype, + k_scale_mma.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + problem_size = ( + Int32(total_q), + Int32(heads_q), + Int32(page_count), + Int32(heads_k), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_scale.device).cuda_stream) + compiled = _compile_fp4_scale_reorder_kernel( + fmt=spec, + q_scale_ptr=q_scale_ptr, + k_scale_ptr=k_scale_ptr, + q_scale_mma_ptr=q_scale_mma_ptr, + k_scale_mma_ptr=k_scale_mma_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_scale_ptr, + k_scale_ptr, + q_scale_mma_ptr, + k_scale_mma_ptr, + problem_size, + stream, + ) + return q_scale_mma, k_scale_mma + + +def _compile_fp4_decode_q_pack_kernel( + *, + fmt: Fp4FormatSpec, + q_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + q_pack_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_decode_q_pack_sm100", + fmt.name, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerDecodeQPackSm100(fmt=fmt.name) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_ptr, + q_scale_ptr, + q_pack_ptr, + q_scale_pack_ptr, + cu_seqlens_q_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def _pack_decode_q_for_mma( + q_bytes: torch.Tensor, + q_scale_storage: torch.Tensor, + cu_seqlens_q: torch.Tensor, + *, + fmt: Fp4FormatSpec, + heads_q: int, + heads_k: int, + batch: int, +) -> tuple[torch.Tensor, torch.Tensor]: + q_pack = torch.empty( + (batch * heads_k, _PAGE_SIZE, _FP4_PACKED_D_BYTES), + dtype=torch.uint8, + device=q_bytes.device, + ) + q_scale_pack = torch.empty( + fp4_indexer_mma_scale_storage_shape(_PAGE_SIZE, batch * heads_k, fp4_format=fmt.name), + dtype=fmt.torch_scale_dtype, + device=q_bytes.device, + ) + if q_pack.data_ptr() % 128 != 0: + raise ValueError("internal decode q_pack data pointer must be 128B aligned for TMA") + if q_scale_pack.data_ptr() % 32 != 0: + raise ValueError("internal decode q_scale_pack data pointer must be 32B aligned") + q_ptr = make_ptr(cutlass.Uint8, q_bytes.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + q_scale_ptr = make_ptr( + fmt.cutlass_scale_dtype, + q_scale_storage.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + q_pack_ptr = make_ptr(cutlass.Uint8, q_pack.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + q_scale_pack_ptr = make_ptr( + fmt.cutlass_scale_dtype, + q_scale_pack.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + cu_seqlens_q_ptr = make_ptr( + cutlass.Int32, + cu_seqlens_q.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + problem_size = ( + Int32(q_bytes.shape[0]), + Int32(heads_q), + Int32(heads_k), + Int32(batch), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_bytes.device).cuda_stream) + compiled = _compile_fp4_decode_q_pack_kernel( + fmt=fmt, + q_ptr=q_ptr, + q_scale_ptr=q_scale_ptr, + q_pack_ptr=q_pack_ptr, + q_scale_pack_ptr=q_scale_pack_ptr, + cu_seqlens_q_ptr=cu_seqlens_q_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_ptr, + q_scale_ptr, + q_pack_ptr, + q_scale_pack_ptr, + cu_seqlens_q_ptr, + problem_size, + stream, + ) + return q_pack, q_scale_pack + + +def _compile_fp4_decode_packed_q_kernel( + *, + fmt: Fp4FormatSpec, + causal: bool, + compact_schedule: bool, + device_arch: tuple[int, int], + use_tmem_load_red: bool, + q_pack_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_decode_packed_q_sm100", + fmt.name, + bool(causal), + bool(compact_schedule), + device_arch, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerDecodePackedQSm100( + fmt=fmt.name, + causal=causal, + compact_schedule=compact_schedule, + use_tmem_load_red=use_tmem_load_red, + ) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_pack_ptr, + k_ptr, + q_scale_pack_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def _run_fp4_decode_packed_q_scores( + q_pack: torch.Tensor, + k_bytes: torch.Tensor, + q_scale_pack: torch.Tensor, + k_scale_storage: torch.Tensor, + scores: torch.Tensor, + kv_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_page_offsets: torch.Tensor, + qo_offset_arg: torch.Tensor, + *, + fmt: Fp4FormatSpec, + causal: bool, + has_qo_offset: int, + heads_q: int, + heads_k: int, + batch: int, + max_k_tiles: int, + total_q: int, + device_arch: tuple[int, int], + use_tmem_load_red: bool, +) -> None: + page_count = int(k_bytes.shape[0]) + rectangular_groups = batch * ceil_div(max_k_tiles, _DECODE_K_TILES_PER_CTA) + compact_groups = ceil_div(page_count + batch * (_DECODE_K_TILES_PER_CTA - 1), _DECODE_K_TILES_PER_CTA) + compact_schedule = compact_groups < rectangular_groups + if compact_schedule: + scores.fill_(float("-inf")) + + q_pack_ptr = make_ptr(cutlass.Uint8, q_pack.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + k_ptr = make_ptr(cutlass.Uint8, k_bytes.data_ptr(), cute.AddressSpace.gmem, assumed_align=128) + q_scale_pack_ptr = make_ptr( + fmt.cutlass_scale_dtype, + q_scale_pack.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + k_scale_ptr = make_ptr( + fmt.cutlass_scale_dtype, + k_scale_storage.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + scores_ptr = make_ptr(cutlass.Float32, scores.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) + kv_indices_ptr = make_ptr(cutlass.Int32, kv_indices.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + cu_seqlens_q_ptr = make_ptr(cutlass.Int32, cu_seqlens_q.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + cu_seqlens_k_ptr = make_ptr(cutlass.Int32, cu_seqlens_k.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + cu_page_offsets_ptr = make_ptr(cutlass.Int32, cu_page_offsets.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + qo_offset_ptr = make_ptr(cutlass.Int32, qo_offset_arg.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + problem_size = ( + Int32(_PAGE_SIZE), + Int32(max_k_tiles * _PAGE_SIZE), + Int32(_HEAD_DIM), + Int32(batch * heads_k), + Int32(page_count * heads_k), + Int32(heads_q), + Int32(heads_k), + Int32(batch), + Int32(max_k_tiles), + Int32(total_q), + Int32(has_qo_offset), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_pack.device).cuda_stream) + compiled = _compile_fp4_decode_packed_q_kernel( + fmt=fmt, + causal=causal, + compact_schedule=compact_schedule, + device_arch=device_arch, + use_tmem_load_red=use_tmem_load_red, + q_pack_ptr=q_pack_ptr, + k_ptr=k_ptr, + q_scale_pack_ptr=q_scale_pack_ptr, + k_scale_ptr=k_scale_ptr, + scores_ptr=scores_ptr, + kv_indices_ptr=kv_indices_ptr, + cu_seqlens_q_ptr=cu_seqlens_q_ptr, + cu_seqlens_k_ptr=cu_seqlens_k_ptr, + cu_page_offsets_ptr=cu_page_offsets_ptr, + qo_offset_ptr=qo_offset_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_pack_ptr, + k_ptr, + q_scale_pack_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + + +def _compile_fp4_qk_kernel( + *, + fmt: Fp4FormatSpec, + causal: bool, + preordered_q_scale_tma: bool, + compact_schedule: bool, + device_arch: tuple[int, int], + use_tmem_load_red: bool, + q_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, +): + key = ( + "fp4_indexer_staged_mma_sm100", + fmt.name, + bool(causal), + bool(preordered_q_scale_tma), + bool(compact_schedule), + device_arch, + ) + if key not in _FP4_COMPILE_CACHE: + kernel = Fp4IndexerStagedMmaSm100( + fmt=fmt.name, + causal=causal, + preordered_q_scale_tma=preordered_q_scale_tma, + compact_schedule=compact_schedule, + use_tmem_load_red=use_tmem_load_red, + ) + _FP4_COMPILE_CACHE[key] = cute.compile( + kernel, + q_ptr, + k_ptr, + q_scale_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + return _FP4_COMPILE_CACHE[key] + + +def fp4_indexer_block_scores( + q_fp4: torch.Tensor, + k_fp4: torch.Tensor, + q_scale: torch.Tensor, + k_scale: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_page_offsets: torch.Tensor, + *, + max_seqlen_q: int, + max_seqlen_k: int, + kv_indices: torch.Tensor, + fp4_format: str, + causal: bool = False, + qo_offset: Optional[torch.Tensor] = None, + scale_layout: str = _PREORDERED_MMA_SCALE_LAYOUT, +) -> torch.Tensor: + """Return FP4 QK max scores per 128-token KV page. + + Parameters + ---------- + q_fp4 : torch.Tensor + Packed FP4 Q tensor with shape ``[total_qo_len, Hq, 64]``. The last + dimension stores two FP4 values per byte for logical head dimension + 128. + k_fp4 : torch.Tensor + Packed paged FP4 K tensor with shape ``[total_pages, Hk, 128, 64]``. + q_scale : torch.Tensor + Q scale tensor. With ``scale_layout="public"``, shape is + ``[total_qo_len, Hq, G]``. With ``"preordered_mma"``, use + ``fp4_indexer_reorder_scales_for_mma_cute`` output layout. + k_scale : torch.Tensor + K scale tensor. With ``scale_layout="public"``, shape is + ``[total_pages, Hk, 128, G]``. With ``"preordered_mma"``, use the + preordered MMA scale layout. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + cu_page_offsets : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of per-request + page counts. + max_seqlen_q : int + Maximum Q sequence length. + max_seqlen_k : int + Maximum KV sequence length. + kv_indices : torch.Tensor + Flattened physical page indices with shape ``[sum_pages]`` and dtype + int32. + fp4_format : str + ``"mxfp4"`` or ``"nvfp4"``. + causal : bool, optional + Whether to apply causal masking. + qo_offset : torch.Tensor, optional + Shape ``[batch_size]``, dtype int32. Per-request causal offset. Valid + only when ``causal=True``. + scale_layout : str, optional + ``"public"`` accepts logical public scale tensors and launches a scale + reorder kernel. ``"preordered_mma"`` expects preordered MMA scale + tensors and skips the reorder. + + Returns + ------- + torch.Tensor + Shape ``[Hq, ceil(max_seqlen_k / 128), total_qo_len]``, dtype float32. + Entries beyond the valid KV page range are ``-inf``. + """ + + spec = normalize_fp4_format(fp4_format) + causal = bool(causal) + scale_layout = normalize_scale_layout(scale_layout) + use_preordered_q_scale_tma = int(max_seqlen_q) >= _PAGE_SIZE + q_bytes = _as_fp4_thd_bytes(q_fp4, name="q_fp4") + k_bytes = _as_fp4_paged_hnd_bytes(k_fp4, name="k_fp4") + total_q, heads_q, _ = (int(v) for v in q_bytes.shape) + page_count, heads_k, page_size, _ = (int(v) for v in k_bytes.shape) + if page_size != _PAGE_SIZE: + raise ValueError(f"k_fp4 page_size must be 128, got {page_size}") + if heads_q % heads_k != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads") + _require_cuda_tensor(q_fp4, name="q_fp4") + _require_cuda_tensor(k_fp4, name="k_fp4") + device_arch = _device_arch(q_fp4.device) + use_tmem_load_red = _supports_tmem_load_red(device_arch) + _require_int32_vector(cu_seqlens_q, name="cu_seqlens_q", device=q_fp4.device) + _require_int32_vector(cu_seqlens_k, name="cu_seqlens_k", device=q_fp4.device) + _require_int32_vector(cu_page_offsets, name="cu_page_offsets", device=q_fp4.device) + if q_scale.device != q_fp4.device or k_scale.device != q_fp4.device: + raise ValueError("q_scale and k_scale must be on the same CUDA device as q_fp4") + if scale_layout == _PUBLIC_SCALE_LAYOUT: + validate_q_scale_thg(q_scale, name="q_scale", fmt=spec, total_q=total_q, heads=heads_q) + validate_k_scale_phsg(k_scale, name="k_scale", fmt=spec, page_count=page_count, heads=heads_k) + else: + validate_mma_scale_storage(q_scale, name="q_scale", fmt=spec, mn=total_q, l=heads_q) + validate_mma_scale_storage(k_scale, name="k_scale", fmt=spec, mn=_PAGE_SIZE, l=page_count * heads_k) + batch = int(cu_seqlens_q.shape[0]) - 1 + if batch < 0: + raise ValueError("cu_seqlens_q must have shape [B + 1]") + if cu_seqlens_q.shape != cu_seqlens_k.shape or cu_seqlens_q.shape != cu_page_offsets.shape: + raise ValueError("cu_seqlens_q, cu_seqlens_k, and cu_page_offsets must have shape [B + 1]") + if q_bytes.data_ptr() % 128 != 0: + raise ValueError("q_fp4 data pointer must be 128B aligned for TMA") + if k_bytes.data_ptr() % 128 != 0: + raise ValueError("k_fp4 data pointer must be 128B aligned for TMA") + if kv_indices is None: + raise ValueError("kv_indices is required") + if kv_indices.device != q_fp4.device or kv_indices.dtype != torch.int32 or kv_indices.ndim != 1: + raise ValueError("kv_indices must have shape [sum_pages], dtype torch.int32, and match q_fp4.device") + if not kv_indices.is_contiguous(): + raise ValueError("kv_indices must be contiguous") + if qo_offset is not None: + if not causal: + raise ValueError("qo_offset is only valid when causal=True") + if qo_offset.device != q_fp4.device or qo_offset.dtype != torch.int32 or qo_offset.shape != (batch,): + raise ValueError("qo_offset must have shape [B], dtype torch.int32, and match q_fp4.device") + if not qo_offset.is_contiguous(): + raise ValueError("qo_offset must be contiguous") + + m_extent = int(max_seqlen_q) + max_k_tiles = ceil_div(int(max_seqlen_k), _PAGE_SIZE) + n_aligned = max_k_tiles * _PAGE_SIZE + if max_k_tiles == 0: + return torch.full( + (heads_q, 0, total_q), + float("-inf"), + dtype=torch.float32, + device=q_fp4.device, + ) + + scores = torch.empty( + (heads_q, max_k_tiles, total_q), + dtype=torch.float32, + device=q_fp4.device, + ) + if qo_offset is None: + qo_offset_arg = torch.empty((batch,), dtype=torch.int32, device=q_fp4.device) + has_qo_offset = 0 + else: + qo_offset_arg = qo_offset + has_qo_offset = 1 + if scale_layout == _PUBLIC_SCALE_LAYOUT: + q_scale_arg, k_scale_arg = fp4_indexer_reorder_scales_for_mma_cute( + q_scale, + k_scale, + fp4_format=spec.name, + ) + else: + q_scale_arg = q_scale + k_scale_arg = k_scale + scale_assumed_align = 32 + if q_scale_arg.data_ptr() % scale_assumed_align != 0: + raise ValueError(f"q_scale data pointer must be {scale_assumed_align}B aligned for MMA storage scale") + if k_scale_arg.data_ptr() % scale_assumed_align != 0: + raise ValueError(f"k_scale data pointer must be {scale_assumed_align}B aligned for MMA storage scale") + use_decode_packed_q = int(max_seqlen_q) <= _DECODE_PACK_Q_LEN and heads_q // heads_k == _DECODE_QHEAD_PER_KV + if use_decode_packed_q: + q_pack, q_scale_pack = _pack_decode_q_for_mma( + q_bytes, + q_scale_arg, + cu_seqlens_q, + fmt=spec, + heads_q=heads_q, + heads_k=heads_k, + batch=batch, + ) + _run_fp4_decode_packed_q_scores( + q_pack, + k_bytes, + q_scale_pack, + k_scale_arg, + scores, + kv_indices, + cu_seqlens_q, + cu_seqlens_k, + cu_page_offsets, + qo_offset_arg, + fmt=spec, + causal=causal, + has_qo_offset=has_qo_offset, + heads_q=heads_q, + heads_k=heads_k, + batch=batch, + max_k_tiles=max_k_tiles, + total_q=total_q, + device_arch=device_arch, + use_tmem_load_red=use_tmem_load_red, + ) + return scores + prefill_compact_task_count = 0 + prefill_compact_schedule = False + if causal and has_qo_offset == 0: + k_tiles_per_cta = k_tiles_per_cta_for(causal) + q_tile_count = ceil_div(m_extent, _MMA_TILER_MN[0]) + k_group_count = ceil_div(max_k_tiles, k_tiles_per_cta) + rectangular_task_count = q_tile_count * k_group_count + prefill_compact_task_count = min( + rectangular_task_count, + _causal_compact_task_bound(m_extent, int(max_seqlen_k), k_tiles_per_cta), + ) + prefill_compact_schedule = prefill_compact_task_count * 20 <= rectangular_task_count * 19 + if prefill_compact_schedule: + scores.fill_(float("-inf")) + q_ptr = make_ptr( + cutlass.Uint8, + q_bytes.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=128, + ) + k_ptr = make_ptr( + cutlass.Uint8, + k_bytes.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=128, + ) + q_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + q_scale_arg.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=scale_assumed_align, + ) + k_scale_ptr = make_ptr( + spec.cutlass_scale_dtype, + k_scale_arg.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=scale_assumed_align, + ) + scores_ptr = make_ptr( + cutlass.Float32, + scores.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + kv_indices_ptr = make_ptr( + cutlass.Int32, + kv_indices.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + cu_seqlens_q_ptr = make_ptr( + cutlass.Int32, + cu_seqlens_q.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + cu_seqlens_k_ptr = make_ptr( + cutlass.Int32, + cu_seqlens_k.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + cu_page_offsets_ptr = make_ptr( + cutlass.Int32, + cu_page_offsets.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + qo_offset_ptr = make_ptr( + cutlass.Int32, + qo_offset_arg.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + problem_size = ( + Int32(m_extent), + Int32(n_aligned), + Int32(_HEAD_DIM), + Int32(batch * heads_q), + Int32(page_count * heads_k), + Int32(heads_q), + Int32(heads_k), + Int32(batch), + Int32(max_k_tiles), + Int32(total_q), + Int32(has_qo_offset), + Int32(prefill_compact_task_count), + ) + stream = cuda.CUstream(torch.cuda.current_stream(q_fp4.device).cuda_stream) + compiled = _compile_fp4_qk_kernel( + fmt=spec, + causal=causal, + preordered_q_scale_tma=use_preordered_q_scale_tma, + compact_schedule=prefill_compact_schedule, + device_arch=device_arch, + use_tmem_load_red=use_tmem_load_red, + q_ptr=q_ptr, + k_ptr=k_ptr, + q_scale_ptr=q_scale_ptr, + k_scale_ptr=k_scale_ptr, + scores_ptr=scores_ptr, + kv_indices_ptr=kv_indices_ptr, + cu_seqlens_q_ptr=cu_seqlens_q_ptr, + cu_seqlens_k_ptr=cu_seqlens_k_ptr, + cu_page_offsets_ptr=cu_page_offsets_ptr, + qo_offset_ptr=qo_offset_ptr, + problem_size=problem_size, + stream=stream, + ) + compiled( + q_ptr, + k_ptr, + q_scale_ptr, + k_scale_ptr, + scores_ptr, + kv_indices_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + cu_page_offsets_ptr, + qo_offset_ptr, + problem_size, + stream, + ) + return scores + + +__all__ = [ + "fp4_indexer_block_scores", +] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/interface.py b/build/torch212-cxx11-cu132-x86_64-linux/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..9e507961840b3322238646ffffe3e97cf5d13130 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/interface.py @@ -0,0 +1,2011 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Sparse attention interface. + +Current delivery scope: + - head dimension is supported only for D=128 + +Public API: + sparse_atten_func(...) + sparse_decode_atten_func(...) + SparseDecodePagedAttentionWrapper + +Internal forward core: + _sparse_atten_csr_varlen_forward(...) + +Preprocessing (external, done once): + q2k_indices [head_kv, total_q, topK] -> sparse_index_utils.build_k2q_csr() + -> k2q_row_ptr [head_kv, total_rows + 1] int32 + -> k2q_q_indices [head_kv, total_q * topK] int32 +""" + +import math +import os +from typing import Optional + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32 +from cutlass.cute.runtime import from_dlpack + +from .src.sm100.fwd.combine import combine +from .src.sm100.fwd.atten_fwd import SparseAttentionForwardSm100 +from .src.sm100.fwd.atten_fwd_nvfp4_kv import SparseAttentionForwardNvfp4KvSm100 +from .src.sm100.prepare_scheduler import ( + SparseAttentionSchedule, + prepare_sparse_fwd_schedule_and_split, +) +from .src.sm100.decode_schedule import ( + DecodeAttentionSchedule, + prepare_decode_schedule, +) +from .src.common.cute_dsl_utils import to_cute_tensor as to_cute_tensor_kvouter +from .src.common.tma_utils import ( + create_q_gather4_tma_desc, +) + +_compile_cache: dict = {} +_TEMPERATURE_LSE_FAST_PATH_ABS_TOL = 1e-12 +_SUPPORTED_SPARSE_TOPK = (4, 8, 16, 32) +_SUPPORTED_FWD_DTYPES = (torch.bfloat16, torch.float8_e4m3fn) +_SUPPORTED_FWD_MMA_DTYPES = (torch.bfloat16, torch.float8_e4m3fn) +_SUPPORTED_DECODE_QHEAD_PER_KV = 16 + + +def _normalize_partial_dtype(partial_dtype: torch.dtype) -> torch.dtype: + supported = {torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn} + if partial_dtype not in supported: + raise TypeError( + "partial_dtype must be one of torch.float32 / torch.bfloat16 / " + "torch.float16 / torch.float8_e4m3fn, " + f"got {partial_dtype}" + ) + return partial_dtype + + +def _normalize_forward_mma_dtype(dtype: Optional[torch.dtype], fallback: torch.dtype, name: str) -> torch.dtype: + dtype = fallback if dtype is None else dtype + if dtype not in _SUPPORTED_FWD_MMA_DTYPES: + raise TypeError( + f"{name} must be one of torch.bfloat16 / torch.float8_e4m3fn, got {dtype}" + ) + return dtype + + +def _resolve_forward_mma_dtypes( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qk_dtype: Optional[torch.dtype], + pv_dtype: Optional[torch.dtype], +) -> tuple[torch.dtype, torch.dtype]: + qk_dtype = _normalize_forward_mma_dtype(qk_dtype, q.dtype, "qk_dtype") + if pv_dtype is None: + # Preserve the historical fp8 KV-cache path: BF16 Q with FP8 K/V + # stages both K and V as BF16 compute operands. + if ( + q.dtype == torch.bfloat16 + and k.dtype == torch.float8_e4m3fn + and v.dtype == torch.float8_e4m3fn + ): + pv_dtype = torch.bfloat16 + else: + pv_dtype = v.dtype + pv_dtype = _normalize_forward_mma_dtype(pv_dtype, pv_dtype, "pv_dtype") + + if q.dtype != qk_dtype: + raise ValueError( + "qk_dtype must match q storage dtype; Q fp8->bf16 staging is not supported" + ) + if k.dtype != qk_dtype: + if not (k.dtype == torch.float8_e4m3fn and qk_dtype == torch.bfloat16): + raise ValueError( + "unsupported K storage/qk_dtype combination; only fp8 K -> bf16 QK staging is supported" + ) + if v.dtype != pv_dtype: + if not (v.dtype == torch.float8_e4m3fn and pv_dtype == torch.bfloat16): + raise ValueError( + "unsupported V storage/pv_dtype combination; only fp8 V -> bf16 PV staging is supported" + ) + return qk_dtype, pv_dtype + + +def _to_cute_tensor_meta(t: torch.Tensor, assumed_align: int = 4): + tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=True) + return tensor.mark_layout_dynamic(leading_dim=0) + + +def _torch_dtype_to_cutlass_dtype(dtype: torch.dtype): + if dtype == torch.bfloat16: + return cutlass.BFloat16 + if dtype == torch.float16: + return cutlass.Float16 + if dtype == torch.float8_e4m3fn: + return cutlass.Float8E4M3FN + raise TypeError( + f"Only torch.bfloat16, torch.float16, torch.float8_e4m3fn supported, got {dtype}" + ) + + +def _prepare_paged_kv_for_tma(k, v, blk_kv: int): + page_size = int(k.shape[2]) + if page_size != blk_kv: + raise ValueError(f"Sparse Page Attention requires page_size == blk_kv, got {page_size} vs {blk_kv}") + return k, v + + +def _validate_cu_seqlens( + cu_seqlens: torch.Tensor, + *, + name: str, + device: torch.device, +) -> None: + if cu_seqlens.device != device: + raise ValueError(f"{name} must be on the same device as q") + if cu_seqlens.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32") + if cu_seqlens.ndim != 1: + raise ValueError(f"{name} must have shape [B + 1]") + if cu_seqlens.shape[0] < 1: + raise ValueError(f"{name} must have at least one element") + if not cu_seqlens.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _csr_row_capacity(k2q_row_ptr: torch.Tensor) -> int: + return int(k2q_row_ptr.shape[1] - 1) + + +def _validate_csr_varlen_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + blk_kv: int, + page_table: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], +) -> tuple[int, int]: + if q.ndim != 3: + raise ValueError("CSR sparse forward requires q to have shape [total_q, Hq, D]") + if q.dtype not in _SUPPORTED_FWD_DTYPES: + raise TypeError( + "CSR sparse forward supports only torch.bfloat16 and " + f"torch.float8_e4m3fn Q/K/V, got {q.dtype}" + ) + if q.device != k.device or q.device != v.device: + raise ValueError("q, k, v must be on the same device") + mixed_fp8_kv_bf16_q = ( + q.dtype == torch.bfloat16 + and k.dtype == torch.float8_e4m3fn + and v.dtype == torch.float8_e4m3fn + ) + if not mixed_fp8_kv_bf16_q and (q.dtype != k.dtype or q.dtype != v.dtype): + raise ValueError( + "q, k, v must have the same dtype, except q=bf16 with fp8_e4m3 K/V cache" + ) + if q.shape[-1] != k.shape[-1] or q.shape[-1] != v.shape[-1]: + raise ValueError("q, k, v must have the same head dimension") + dim = q.shape[-1] + if dim != 128: + raise NotImplementedError( + f"CSR sparse forward currently supports only D=128, got D={dim}" + ) + if page_table is None: + if k.shape[-2] != v.shape[-2] or k.shape[-1] != v.shape[-1]: + raise ValueError("k and v must have the same [Hkv, D] tail dimensions") + head_kv = k.shape[-2] + else: + if k.ndim != 4 or v.ndim != 4: + raise ValueError( + "Sparse Page Attention requires k and v to have shape " + "[num_pages, Hkv, page_size, D]" + ) + if k.shape[1] != v.shape[1] or k.shape[-1] != v.shape[-1]: + raise ValueError( + "Sparse Page Attention k and v must have the same Hkv and D" + ) + head_kv = k.shape[1] + if ( + q.device != k2q_row_ptr.device + or q.device != k2q_q_indices.device + ): + raise ValueError("CSR metadata must be on the same device as q") + if ( + k2q_row_ptr.dtype != torch.int32 + or k2q_q_indices.dtype != torch.int32 + ): + raise TypeError("CSR metadata tensors must be torch.int32") + if k2q_row_ptr.ndim != 2 or k2q_q_indices.ndim != 2: + raise ValueError("k2q_row_ptr and k2q_q_indices must be rank-2") + + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q", device=q.device) + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k", device=q.device) + if cu_seqlens_k.shape != cu_seqlens_q.shape: + raise ValueError("cu_seqlens_k must have shape [B + 1] matching cu_seqlens_q") + batch = int(cu_seqlens_q.shape[0] - 1) + total_q = q.shape[0] + + head_q = q.shape[1] + if head_q % head_kv != 0: + raise ValueError("q.shape[1] must be divisible by Hkv") + qhead_per_kv = head_q // head_kv + if qhead_per_kv not in (1, 2, 4, 8, 16): + raise NotImplementedError( + "CSR forward is currently supported only for qhead_per_kv in {1, 2, 4, 8, 16}" + ) + if k2q_row_ptr.shape[0] != head_kv or k2q_q_indices.shape[0] != head_kv: + raise ValueError("CSR metadata head dimension must match KV head count") + if k2q_q_indices.shape[1] < total_q * topK: + raise ValueError( + f"k2q_q_indices.shape[1] ({k2q_q_indices.shape[1]}) must be >= total_q * topK ({total_q * topK})" + ) + if k2q_row_ptr.shape[1] < 1: + raise ValueError("k2q_row_ptr must contain at least one row pointer column") + + if page_table is None: + if seqused_k is not None: + raise ValueError("seqused_k is only supported together with page_table") + total_k = k.shape[0] + if k.ndim != 3 or v.ndim != 3: + raise ValueError("Sparse Attention requires k and v to have shape [total_k, Hkv, D]") + if k.shape != (total_k, head_kv, q.shape[-1]) or v.shape != (total_k, head_kv, q.shape[-1]): + raise ValueError("Sparse Attention k and v must match [total_k, Hkv, D]") + else: + if page_table.device != q.device: + raise ValueError("page_table must be on the same device as q") + if page_table.dtype != torch.int32: + raise TypeError("page_table must be torch.int32") + if page_table.ndim != 2 or page_table.shape[0] != batch: + raise ValueError("page_table must have shape [B, max_num_pages_per_seq]") + if page_table.stride(-1) != 1: + raise ValueError("page_table must be contiguous in the last dimension") + if k.ndim != 4 or v.ndim != 4: + raise ValueError( + "Sparse Page Attention requires k and v to have shape " + "[num_pages, Hkv, page_size, D]" + ) + if k.shape != v.shape: + raise ValueError(f"k and v must have the same shape, got {k.shape} and {v.shape}") + if k.shape[1] != head_kv or k.shape[3] != q.shape[-1]: + raise ValueError( + "Sparse Page Attention k and v must match " + "[num_pages, Hkv, page_size, D]" + ) + page_size = int(k.shape[2]) + if page_size != blk_kv: + raise ValueError( + f"Unsupported Sparse Page Attention page_size={page_size} for blk_kv={blk_kv}; " + "require page_size == blk_kv" + ) + if seqused_k is not None: + if seqused_k.device != q.device: + raise ValueError("seqused_k must be on the same device as q") + if seqused_k.dtype != torch.int32: + raise TypeError("seqused_k must be torch.int32") + if seqused_k.shape != (batch,): + raise ValueError("seqused_k must have shape [B]") + if not seqused_k.is_contiguous(): + raise ValueError("seqused_k must be contiguous") + if topK not in _SUPPORTED_SPARSE_TOPK: + raise ValueError( + f"CSR sparse forward supports topK in {_SUPPORTED_SPARSE_TOPK}, got {topK}" + ) + return batch, head_kv + + +def _validate_csr_varlen_nvfp4_kv_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_scale_128x4: torch.Tensor, + v_scale_128x4: torch.Tensor, + k_global_scale: Optional[torch.Tensor], + v_global_scale: Optional[torch.Tensor], + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + blk_kv: int, + page_table: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], +) -> tuple[int, int]: + if q.ndim != 3: + raise ValueError("KVFP4 CSR sparse forward requires q to have shape [total_q, Hq, D]") + if q.dtype not in (torch.bfloat16, torch.float8_e4m3fn): + raise TypeError(f"KVFP4 CSR sparse forward requires BF16 or FP8 E4M3 q, got {q.dtype}") + if q.shape[-1] != 128: + raise NotImplementedError( + f"KVFP4 CSR sparse forward currently supports only D=128, got {q.shape[-1]}" + ) + if k.dtype != torch.uint8 or v.dtype != torch.uint8: + raise TypeError(f"KVFP4 k/v must be torch.uint8, got {k.dtype} and {v.dtype}") + if k_scale_128x4.dtype != torch.uint8 or v_scale_128x4.dtype != torch.uint8: + raise TypeError( + "KVFP4 block scales must be torch.uint8 E4M3 tensors, got " + f"{k_scale_128x4.dtype} and {v_scale_128x4.dtype}" + ) + if k_global_scale is not None and k_global_scale.dtype != torch.float32: + raise TypeError("KVFP4 K global scale must be a torch.float32 tensor or None") + if v_global_scale is not None and v_global_scale.dtype != torch.float32: + raise TypeError("KVFP4 V global scale must be a torch.float32 tensor or None") + tensors = ( + k, + v, + k_scale_128x4, + v_scale_128x4, + k2q_row_ptr, + k2q_q_indices, + cu_seqlens_q, + cu_seqlens_k, + ) + optional_tensors = tuple(t for t in (k_global_scale, v_global_scale) if t is not None) + if any(t.device != q.device for t in tensors + optional_tensors): + raise ValueError("KVFP4 inputs and metadata must be on the same device as q") + if k.shape != v.shape: + raise ValueError(f"KVFP4 k and v must have the same shape, got {k.shape} and {v.shape}") + packed_dim = q.shape[-1] // 2 + scale_cols = q.shape[-1] // 16 + if k_scale_128x4.ndim != 2 or v_scale_128x4.ndim != 2: + raise ValueError("KVFP4 block scales must be rank-2 128x4 tiled tensors") + if k_scale_128x4.shape[1] < scale_cols or v_scale_128x4.shape[1] < scale_cols: + raise ValueError( + "KVFP4 block scales must have at least D/16 columns; " + f"need {scale_cols}, got {k_scale_128x4.shape[1]} and {v_scale_128x4.shape[1]}" + ) + if k_global_scale is not None and k_global_scale.numel() < 1: + raise ValueError("KVFP4 K global scale must contain at least one element") + if v_global_scale is not None and v_global_scale.numel() < 1: + raise ValueError("KVFP4 V global scale must contain at least one element") + + if page_table is None: + if seqused_k is not None: + raise ValueError("seqused_k is only supported together with page_table") + if k.ndim != 3: + raise ValueError("KVFP4 Sparse Attention requires k/v shape [total_k, Hkv, D/2]") + if k.shape[-1] != packed_dim: + raise ValueError(f"KVFP4 packed K/V last dimension must be D/2={packed_dim}") + total_k = int(k.shape[0]) + head_kv = int(k.shape[1]) + required_scale_rows = total_k * head_kv + else: + if k.ndim != 4: + raise ValueError( + "KVFP4 Sparse Page Attention requires k/v shape " + "[num_pages, Hkv, page_size, D/2]" + ) + if k.shape[-1] != packed_dim: + raise ValueError(f"KVFP4 packed K/V last dimension must be D/2={packed_dim}") + page_size = int(k.shape[2]) + if page_size != int(blk_kv): + raise ValueError( + f"KVFP4 Sparse Page Attention requires page_size == blk_kv, got {page_size} vs {blk_kv}" + ) + head_kv = int(k.shape[1]) + required_scale_rows = int(k.shape[0]) * head_kv * page_size + if page_table.device != q.device: + raise ValueError("page_table must be on the same device as q") + if page_table.dtype != torch.int32: + raise TypeError("page_table must be torch.int32") + if page_table.ndim != 2: + raise ValueError("page_table must have shape [B, max_num_pages_per_seq]") + if page_table.stride(-1) != 1: + raise ValueError("page_table must be contiguous in the last dimension") + if seqused_k is not None: + if seqused_k.device != q.device: + raise ValueError("seqused_k must be on the same device as q") + if seqused_k.dtype != torch.int32: + raise TypeError("seqused_k must be torch.int32") + if not seqused_k.is_contiguous(): + raise ValueError("seqused_k must be contiguous") + + padded_scale_rows = ((required_scale_rows + 127) // 128) * 128 + padded_scale_cols = ((scale_cols + 3) // 4) * 4 + for name, scale in (("k_scale_128x4", k_scale_128x4), ("v_scale_128x4", v_scale_128x4)): + if scale.shape[0] < padded_scale_rows or scale.shape[1] < padded_scale_cols: + raise ValueError( + f"{name} is too small for 128x4 layout: got {tuple(scale.shape)}, " + f"need at least {(padded_scale_rows, padded_scale_cols)}" + ) + + if k2q_row_ptr.device != q.device or k2q_q_indices.device != q.device: + raise ValueError("CSR metadata must be on the same device as q") + if k2q_row_ptr.dtype != torch.int32 or k2q_q_indices.dtype != torch.int32: + raise TypeError("CSR metadata tensors must be torch.int32") + if k2q_row_ptr.ndim != 2 or k2q_q_indices.ndim != 2: + raise ValueError("k2q_row_ptr and k2q_q_indices must be rank-2") + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q", device=q.device) + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k", device=q.device) + if cu_seqlens_k.shape != cu_seqlens_q.shape: + raise ValueError("cu_seqlens_k must have shape [B + 1] matching cu_seqlens_q") + batch = int(cu_seqlens_q.shape[0] - 1) + if page_table is not None and page_table.shape[0] != batch: + raise ValueError("page_table must have shape [B, max_num_pages_per_seq]") + if seqused_k is not None and seqused_k.shape != (batch,): + raise ValueError("seqused_k must have shape [B]") + head_q = int(q.shape[1]) + if head_q % head_kv != 0: + raise ValueError("q.shape[1] must be divisible by Hkv") + qhead_per_kv = head_q // head_kv + if qhead_per_kv not in (1, 2, 4, 8, 16): + raise NotImplementedError( + "KVFP4 CSR forward is currently supported only for qhead_per_kv in {1, 2, 4, 8, 16}" + ) + if k2q_row_ptr.shape[0] != head_kv or k2q_q_indices.shape[0] != head_kv: + raise ValueError("CSR metadata head dimension must match KV head count") + if k2q_q_indices.shape[1] < q.shape[0] * topK: + raise ValueError( + f"k2q_q_indices.shape[1] ({k2q_q_indices.shape[1]}) must be >= total_q * topK ({q.shape[0] * topK})" + ) + if k2q_row_ptr.shape[1] < 1: + raise ValueError("k2q_row_ptr must contain at least one row pointer column") + if topK not in _SUPPORTED_SPARSE_TOPK: + raise ValueError( + f"KVFP4 CSR sparse forward supports topK in {_SUPPORTED_SPARSE_TOPK}, got {topK}" + ) + return batch, head_kv + + +def _validate_sparse_decode_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k_indices: Optional[torch.Tensor], + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + blk_kv: int, + causal: bool, +) -> tuple[int, int]: + if q.ndim != 3: + raise ValueError("decode attention requires q to have shape [total_q, Hq, D]") + if k.ndim != 4 or v.ndim != 4: + raise ValueError( + "decode attention requires paged k/v with shape [num_pages, Hkv, page_size, D]" + ) + if q.device != k.device or q.device != v.device: + raise ValueError("decode q, k, and v must be on the same device") + if q.dtype != torch.float8_e4m3fn or k.dtype != q.dtype or v.dtype != q.dtype: + raise TypeError( + "decode attention currently supports only torch.float8_e4m3fn Q/K/V" + ) + if k.shape != v.shape: + raise ValueError(f"decode k and v must have the same shape, got {k.shape} and {v.shape}") + if q.shape[-1] != 128 or k.shape[-1] != 128: + raise NotImplementedError( + f"decode attention currently supports only D=128, got q={q.shape[-1]} k={k.shape[-1]}" + ) + if not bool(causal): + raise NotImplementedError("decode attention currently supports only causal=True") + page_size = int(k.shape[2]) + if page_size != int(blk_kv): + raise ValueError(f"decode attention requires page_size == blk_kv, got {page_size} vs {blk_kv}") + + head_kv = int(k.shape[1]) + head_q = int(q.shape[1]) + if head_q % head_kv != 0: + raise ValueError("decode q.shape[1] must be divisible by Hkv") + qhead_per_kv = head_q // head_kv + if qhead_per_kv != _SUPPORTED_DECODE_QHEAD_PER_KV: + raise NotImplementedError( + "decode attention currently supports only " + f"qhead_per_kv={_SUPPORTED_DECODE_QHEAD_PER_KV}, got {qhead_per_kv}" + ) + + if page_table is None: + raise ValueError("decode attention requires page_table") + if page_table.device != q.device: + raise ValueError("decode page_table must be on the same device as q") + if page_table.dtype != torch.int32: + raise TypeError("decode page_table must be torch.int32") + if page_table.ndim != 2: + raise ValueError("decode page_table must have shape [B, max_num_pages_per_seq]") + batch = int(page_table.shape[0]) + if page_table.stride(-1) != 1: + raise ValueError("decode page_table must be contiguous in the last dimension") + + if seqused_k is None: + raise ValueError("decode attention requires seqused_k") + if seqused_k.device != q.device: + raise ValueError("decode seqused_k must be on the same device as q") + if seqused_k.dtype != torch.int32: + raise TypeError("decode seqused_k must be torch.int32") + if seqused_k.shape != (batch,): + raise ValueError("decode seqused_k must have shape [B]") + if not seqused_k.is_contiguous(): + raise ValueError("decode seqused_k must be contiguous") + + seqlen_q = int(seqlen_q) + max_seqlen_k = int(max_seqlen_k) + if seqlen_q <= 0 or max_seqlen_k <= 0: + raise ValueError("decode seqlen_q and max_seqlen_k must be positive") + if int(q.shape[0]) != batch * seqlen_q: + raise ValueError("decode q.shape[0] must equal batch * seqlen_q") + + if q2k_indices is not None: + if q2k_indices.device != q.device: + raise ValueError("decode q2k_indices must be on the same device as q") + if q2k_indices.dtype != torch.int32: + raise TypeError("decode q2k_indices must be torch.int32") + if q2k_indices.ndim != 3: + raise ValueError("decode q2k_indices must have shape [Hkv, total_q, topK]") + if q2k_indices.shape[0] != head_kv or q2k_indices.shape[1] != q.shape[0]: + raise ValueError("decode q2k_indices must match [Hkv, total_q, topK]") + if not q2k_indices.is_contiguous(): + raise ValueError("decode q2k_indices must be contiguous") + return batch, head_kv + + +def _validate_schedule_common( + schedule: SparseAttentionSchedule, + *, + device: torch.device, +) -> None: + if schedule.scheduler_metadata is None: + raise ValueError("schedule.scheduler_metadata is required") + if schedule.work_count is None: + raise ValueError("schedule.work_count is required") + metadata = schedule.scheduler_metadata + work_count = schedule.work_count + if metadata.device != device or work_count.device != device: + raise ValueError("schedule tensors must be on the same device as q") + if metadata.dtype != torch.int32 or work_count.dtype != torch.int32: + raise TypeError("schedule.scheduler_metadata and schedule.work_count must be torch.int32") + if metadata.ndim != 2 or metadata.shape[1] != 6: + raise ValueError("schedule.scheduler_metadata must have shape [capacity, 6]") + if work_count.shape != (1,): + raise ValueError("schedule.work_count must have shape [1]") + if not metadata.is_contiguous() or not work_count.is_contiguous(): + raise ValueError("schedule.scheduler_metadata and schedule.work_count must be contiguous") + + +def _validate_fwd_schedule( + schedule: SparseAttentionSchedule, + *, + q: torch.Tensor, + k2q_q_indices: torch.Tensor, + head_kv: int, +) -> None: + _validate_schedule_common(schedule, device=q.device) + if schedule.qsplit_indices is None: + raise ValueError("schedule.qsplit_indices is required for forward") + if schedule.split_counts is None: + raise ValueError("schedule.split_counts is required for forward") + qsplit = schedule.qsplit_indices + split_counts = schedule.split_counts + if qsplit.device != q.device or split_counts.device != q.device: + raise ValueError("forward schedule tensors must be on the same device as q") + if qsplit.dtype != torch.int32 or split_counts.dtype != torch.int32: + raise TypeError("schedule.qsplit_indices and schedule.split_counts must be torch.int32") + if qsplit.shape != k2q_q_indices.shape: + raise ValueError("schedule.qsplit_indices shape must match k2q_q_indices") + total_q = q.shape[0] + if split_counts.shape != (total_q, head_kv): + raise ValueError( + "schedule.split_counts must have shape " + f"({total_q}, {head_kv}), got {tuple(split_counts.shape)}" + ) + if not qsplit.is_contiguous() or not split_counts.is_contiguous(): + raise ValueError("schedule.qsplit_indices and schedule.split_counts must be contiguous") + + +def sparse_atten_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + *, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = False, + softmax_scale: Optional[float] = None, + lse_temperature_scale: float = 1.0, + return_temperature_lse: bool = False, + partial_dtype: torch.dtype = torch.bfloat16, + return_softmax_lse: bool = False, + page_table: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + schedule: Optional[SparseAttentionSchedule] = None, + usable_SM_count: int = -1, + qk_dtype: Optional[torch.dtype] = None, + pv_dtype: Optional[torch.dtype] = None, +): + """Run SM100 CSR block-sparse varlen attention. + + This is the public forward-only sparse attention API. It consumes + query-to-key block selections converted to CSR metadata by + ``build_k2q_csr`` and supports both dense KV layout and paged KV layout. + + Parameters + ---------- + q : torch.Tensor + Shape ``[total_q, Hq, 128]`` on CUDA. Supported dtypes are BF16 and + FP8 E4M3. + k : torch.Tensor + Dense layout ``[total_k, Hkv, 128]`` or paged layout + ``[num_pages, Hkv, blk_kv, 128]``. For BF16 Q with FP8 K/V cache, K + may be FP8 E4M3 while QK compute uses BF16 staging. + v : torch.Tensor + Same layout and head count as ``k``. + k2q_row_ptr : torch.Tensor + CSR row pointers with shape ``[Hkv, total_rows + 1]`` and dtype int32. + k2q_q_indices : torch.Tensor + CSR query indices with shape ``[Hkv, >= total_q * topK]`` and dtype + int32. + topK : int + Number of selected KV blocks per query. Supported values are + ``4, 8, 16, 32``. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + max_seqlen_q : int + Maximum Q sequence length in the batch. + max_seqlen_k : int + Maximum KV sequence length in the batch. + blk_kv : int, optional + KV block size. Paged KV requires ``k.shape[2] == blk_kv``. + causal : bool, optional + Whether to apply causal masking. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + lse_temperature_scale : float, optional + Extra divisor used only for temperature-scaled LSE output. + return_temperature_lse : bool, optional + If True, also return LSE computed with logits scaled by + ``softmax_scale / lse_temperature_scale``. Requires + ``return_softmax_lse=True``. + partial_dtype : torch.dtype, optional + Accumulation dtype for per-block partial O. Supported values are + FP32, BF16, FP16, and FP8 E4M3. + return_softmax_lse : bool, optional + If True, return ``(out, softmax_lse)`` or + ``(out, softmax_lse, temperature_lse)``. + page_table : torch.Tensor, optional + Paged-KV physical page table with shape + ``[batch_size, max_num_pages_per_seq]`` and dtype int32. + seqused_k : torch.Tensor, optional + Shape ``[batch_size]``, dtype int32. Effective KV length per request + for paged causal attention. + schedule : SparseAttentionSchedule, optional + Prebuilt sparse forward schedule. If omitted, the schedule is built + during the call. + usable_SM_count : int, optional + Maximum number of SMs used by the scheduler. ``-1`` uses all SMs. + qk_dtype : torch.dtype, optional + Compile-time MMA operand dtype for QK. Defaults to Q storage dtype, + except supported FP8 K/V cache staging modes. + pv_dtype : torch.dtype, optional + Compile-time MMA operand dtype for PV. Defaults to V storage dtype, + except supported FP8 K/V cache staging modes. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + Output shape ``[total_q, Hq, 128]`` with BF16 dtype. Optional LSE + outputs have shape ``[total_q, Hq]`` and dtype float32. + + Notes + ----- + ``Hq / Hkv`` must be one of ``1, 2, 4, 8, 16``. Current kernels support + head dimension 128 only. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse and not return_softmax_lse: + raise ValueError("return_temperature_lse=True requires return_softmax_lse=True") + partial_dtype = _normalize_partial_dtype(partial_dtype) + qk_dtype, pv_dtype = _resolve_forward_mma_dtypes(q, k, v, qk_dtype, pv_dtype) + + if cu_seqlens_q is None or cu_seqlens_k is None: + raise ValueError( + "sparse_atten_func requires CSR varlen metadata: pass cu_seqlens_q and cu_seqlens_k" + ) + batch, head_kv = _validate_csr_varlen_inputs( + q, + k, + v, + k2q_row_ptr, + k2q_q_indices, + topK, + blk_kv, + page_table, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + ) + max_seqlen_q = int(max_seqlen_q) + max_seqlen_k = int(max_seqlen_k) + + return _sparse_atten_csr_varlen_forward( + q.contiguous(), + k.contiguous(), + v.contiguous(), + k2q_row_ptr.contiguous(), + k2q_q_indices.contiguous(), + int(topK), + int(blk_kv), + bool(causal), + float(softmax_scale), + lse_temperature_scale, + return_temperature_lse, + partial_dtype, + bool(return_softmax_lse), + cu_seqlens_q.contiguous(), + cu_seqlens_k.contiguous(), + None if page_table is None else page_table.contiguous(), + None if seqused_k is None else seqused_k.contiguous(), + schedule, + int(usable_SM_count), + int(batch), + int(head_kv), + int(max_seqlen_q), + int(max_seqlen_k), + qk_dtype, + pv_dtype, + ) + + +def sparse_atten_nvfp4_kv_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_scale_128x4: torch.Tensor, + v_scale_128x4: torch.Tensor, + k_global_scale: Optional[torch.Tensor], + v_global_scale: Optional[torch.Tensor], + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + *, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = False, + softmax_scale: Optional[float] = None, + lse_temperature_scale: float = 1.0, + return_temperature_lse: bool = False, + partial_dtype: torch.dtype = torch.bfloat16, + return_softmax_lse: bool = False, + page_table: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + schedule: Optional[SparseAttentionSchedule] = None, +): + """Run SM100 CSR sparse attention with packed NVFP4 K/V. + + Parameters + ---------- + q : torch.Tensor + Shape ``[total_q, Hq, 128]`` on CUDA. Supported dtypes are BF16 and + FP8 E4M3. + k : torch.Tensor + Packed NVFP4 K data. Dense layout is ``[total_k, Hkv, 64]``; paged + layout is ``[num_pages, Hkv, blk_kv, 64]``. Dtype must be uint8 + because each byte packs two FP4 values. + v : torch.Tensor + Packed NVFP4 V data with the same shape as ``k``. + k_scale_128x4 : torch.Tensor + K block scales in cuBLAS/cuDNN 128x4 tiled storage. Dtype uint8 + containing FP8 E4M3 scale values. + v_scale_128x4 : torch.Tensor + V block scales in the same 128x4 tiled storage. + k_global_scale : torch.Tensor, optional + FP32 tensor/global dequant scale for K. May be ``None``. + v_global_scale : torch.Tensor, optional + FP32 tensor/global dequant scale for V. May be ``None``. The V global + scale is applied in the combine stage. + k2q_row_ptr : torch.Tensor + CSR row pointers with shape ``[Hkv, total_rows + 1]`` and dtype int32. + k2q_q_indices : torch.Tensor + CSR query indices with shape ``[Hkv, >= total_q * topK]`` and dtype + int32. + topK : int + Number of selected KV blocks per query. Supported values are + ``4, 8, 16, 32``. + cu_seqlens_q, cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q and KV + lengths. + max_seqlen_q, max_seqlen_k : int + Maximum Q and KV sequence lengths in the batch. + blk_kv : int, optional + KV block/page size. Paged KV requires ``k.shape[2] == blk_kv``. + causal : bool, optional + Whether to apply causal masking. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + lse_temperature_scale : float, optional + Extra divisor used only for temperature-scaled LSE output. + return_temperature_lse : bool, optional + If True, also return temperature-scaled LSE. Requires + ``return_softmax_lse=True``. + partial_dtype : torch.dtype, optional + Accumulation dtype for per-block partial O. + return_softmax_lse : bool, optional + If True, return LSE together with the output. + page_table : torch.Tensor, optional + Paged-KV physical page table with shape + ``[batch_size, max_num_pages_per_seq]`` and dtype int32. + seqused_k : torch.Tensor, optional + Effective KV length per request for paged causal attention. + schedule : SparseAttentionSchedule, optional + Prebuilt sparse forward schedule. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + Output shape ``[total_q, Hq, 128]`` with BF16 dtype. Optional LSE + outputs have shape ``[total_q, Hq]`` and dtype float32. + """ + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse and not return_softmax_lse: + raise ValueError("return_temperature_lse=True requires return_softmax_lse=True") + partial_dtype = _normalize_partial_dtype(partial_dtype) + + if cu_seqlens_q is None or cu_seqlens_k is None: + raise ValueError( + "sparse_atten_nvfp4_kv_func requires CSR varlen metadata: pass cu_seqlens_q and cu_seqlens_k" + ) + batch, head_kv = _validate_csr_varlen_nvfp4_kv_inputs( + q, + k, + v, + k_scale_128x4, + v_scale_128x4, + k_global_scale, + v_global_scale, + k2q_row_ptr, + k2q_q_indices, + topK, + blk_kv, + page_table, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + ) + total_q, head_q, dim = q.shape + max_num_kv_blocks = _csr_row_capacity(k2q_row_ptr) + temperature_lse_fast_path = ( + return_temperature_lse + and math.isclose( + float(lse_temperature_scale), + 1.0, + rel_tol=0.0, + abs_tol=_TEMPERATURE_LSE_FAST_PATH_ABS_TOL, + ) + ) + kernel_return_temperature_lse = ( + return_temperature_lse and not temperature_lse_fast_path + ) + + O_partial = torch.empty( + topK, total_q, head_q, dim, dtype=partial_dtype, device=q.device + ) + LSE_partial = torch.empty( + topK, total_q, head_q, dtype=torch.float32, device=q.device + ) + LSE_temperature_partial = ( + torch.empty(topK, total_q, head_q, dtype=torch.float32, device=q.device) + if kernel_return_temperature_lse + else None + ) + O_out = torch.empty(total_q, head_q, dim, dtype=torch.bfloat16, device=q.device) + LSE_out = torch.empty(total_q, head_q, dtype=torch.float32, device=q.device) + LSE_temperature_out = ( + torch.empty_like(LSE_out) if kernel_return_temperature_lse else None + ) + if schedule is None: + k2q_qsplit_indices = torch.empty_like(k2q_q_indices) + split_counts = torch.zeros( + (total_q, head_kv), + dtype=torch.int32, + device=q.device, + ) + else: + _validate_fwd_schedule( + schedule, + q=q, + k2q_q_indices=k2q_q_indices, + head_kv=head_kv, + ) + k2q_qsplit_indices = schedule.qsplit_indices + split_counts = schedule.split_counts + + schedule = _call_sparse_forward_sm100_csr_varlen_nvfp4_kv( + q.contiguous(), + k.contiguous(), + v.contiguous(), + k_scale_128x4.contiguous(), + v_scale_128x4.contiguous(), + None if k_global_scale is None else k_global_scale.contiguous(), + None if v_global_scale is None else v_global_scale.contiguous(), + k2q_row_ptr.contiguous(), + k2q_q_indices.contiguous(), + k2q_qsplit_indices.contiguous(), + split_counts.contiguous(), + cu_seqlens_q.contiguous(), + cu_seqlens_k.contiguous(), + None if page_table is None else page_table.contiguous(), + None if seqused_k is None else seqused_k.contiguous(), + O_partial, + LSE_partial, + LSE_temperature_partial, + float(softmax_scale), + lse_temperature_scale, + kernel_return_temperature_lse, + max_num_kv_blocks, + int(blk_kv), + head_kv, + int(max_seqlen_q), + causal=bool(causal), + schedule=schedule, + ) + + combine( + O_partial, + LSE_partial, + O_out, + LSE_out, + lse_temperature_partial=LSE_temperature_partial, + lse_temperature_out=LSE_temperature_out, + cu_seqlens=cu_seqlens_q, + split_counts=split_counts, + output_scale=v_global_scale, + use_pdl=True, + ) + if temperature_lse_fast_path: + LSE_temperature_out = LSE_out + + if return_softmax_lse: + if return_temperature_lse: + return O_out, LSE_out, LSE_temperature_out + return O_out, LSE_out + return O_out + + +def sparse_decode_atten_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k_indices: Optional[torch.Tensor] = None, + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = True, + softmax_scale: Optional[float] = None, + return_softmax_lse: bool = False, + schedule: Optional[DecodeAttentionSchedule] = None, + O_partial: Optional[torch.Tensor] = None, + LSE_partial: Optional[torch.Tensor] = None, +): + """Run forward-only paged FP8 decode attention. + + Parameters + ---------- + q : torch.Tensor + Shape ``[batch_size * seqlen_q, Hq, 128]``. Dtype must be FP8 E4M3. + k : torch.Tensor + Paged K cache with shape ``[num_pages, Hkv, blk_kv, 128]`` and FP8 + E4M3 dtype. + v : torch.Tensor + Paged V cache with the same shape and dtype as ``k``. + q2k_indices : torch.Tensor, optional + Sparse selected KV blocks with shape ``[Hkv, total_q, topK]`` and dtype + int32. ``None`` selects the dense all-KV decode path. + page_table : torch.Tensor + Physical page table with shape ``[batch_size, max_num_pages_per_seq]`` + and dtype int32. + seqused_k : torch.Tensor + Shape ``[batch_size]``, dtype int32. Effective KV length per request. + seqlen_q : int + Uniform query length per request. Ragged Q lengths should use prefill + or append paths instead. + max_seqlen_k : int + Maximum KV sequence length in the batch. + blk_kv : int, optional + Page size. Must match ``k.shape[2]``. + causal : bool, optional + Whether to apply causal masking. Current decode kernel requires True. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + return_softmax_lse : bool, optional + If True, return ``(out, lse)``. + schedule : DecodeAttentionSchedule, optional + Prebuilt decode schedule. + O_partial, LSE_partial : torch.Tensor, optional + Optional split-KV partial workspaces. Normally owned by + ``SparseDecodePagedAttentionWrapper``. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + BF16 output with shape ``q.shape``. Optional LSE has shape + ``[batch_size * seqlen_q, Hq]`` and dtype float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + batch, head_kv = _validate_sparse_decode_inputs( + q, + k, + v, + q2k_indices, + page_table=page_table, + seqused_k=seqused_k, + seqlen_q=seqlen_q, + max_seqlen_k=max_seqlen_k, + blk_kv=blk_kv, + causal=causal, + ) + head_q = int(q.shape[1]) + head_dim = int(q.shape[2]) + if schedule is None: + schedule = prepare_decode_schedule( + seqused_k=seqused_k.contiguous(), + page_size=int(blk_kv), + seqlen_q=int(seqlen_q), + num_qo_heads=head_q, + num_kv_heads=head_kv, + head_dim=head_dim, + max_seqlen_k=int(max_seqlen_k), + ) + if schedule.split_kv: + if O_partial is None: + O_partial = torch.empty( + (schedule.partial_rows, head_q, head_dim), + dtype=torch.float32, + device=q.device, + ) + if LSE_partial is None: + LSE_partial = torch.empty( + (schedule.partial_rows, head_q), + dtype=torch.float32, + device=q.device, + ) + out = torch.empty(q.shape, dtype=torch.bfloat16, device=q.device) + lse = torch.empty( + q.shape[:2] if (return_softmax_lse or schedule.split_kv) else (1, head_q), + dtype=torch.float32, + device=q.device, + ) + _call_sparse_decode_forward_sm100_paged_fp8( + q.contiguous(), + k.contiguous(), + v.contiguous(), + None if q2k_indices is None else q2k_indices.contiguous(), + page_table.contiguous(), + seqused_k.contiguous(), + out, + lse, + schedule, + O_partial, + LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=int(seqlen_q), + max_seqlen_k=int(max_seqlen_k), + blk_kv=int(blk_kv), + causal=bool(causal), + return_lse=bool(return_softmax_lse), + ) + if return_softmax_lse: + return out, lse + return out + + +class SparseDecodePagedAttentionWrapper: + """Plan/run helper for paged FP8 decode attention. + + Use this wrapper when the same page table shape and sequence metadata are + reused across multiple decode layers. ``plan`` validates metadata and + allocates persistent schedules/workspaces; ``run`` then launches the decode + kernel with lower per-call overhead than ``sparse_decode_atten_func``. + """ + + def __init__(self, *, blk_kv: int = 128, causal: bool = True): + self.blk_kv = int(blk_kv) + self.causal = bool(causal) + self.batch: Optional[int] = None + self.num_qo_heads: Optional[int] = None + self.num_kv_heads: Optional[int] = None + self.head_dim: Optional[int] = None + self.page_table: Optional[torch.Tensor] = None + self.seqused_k: Optional[torch.Tensor] = None + self.q2k_indices: Optional[torch.Tensor] = None + self.seqlen_q: Optional[int] = None + self.max_seqlen_k: Optional[int] = None + self.is_sparse: bool = False + self.decode_schedule: Optional[DecodeAttentionSchedule] = None + self.request_indices: Optional[torch.Tensor] = None + self.qo_tile_indices: Optional[torch.Tensor] = None + self.kv_tile_indices: Optional[torch.Tensor] = None + self.merge_indptr: Optional[torch.Tensor] = None + self.o_indptr: Optional[torch.Tensor] = None + self.block_valid_mask: Optional[torch.Tensor] = None + self.kv_pages: Optional[torch.Tensor] = None + self.split_counts: Optional[torch.Tensor] = None + self.split_kv: bool = False + self.cta_tile_q: int = 0 + self.num_q_tiles: int = 0 + self.kv_chunk_size_pages: int = 0 + self.kv_chunk_size_tokens: int = 0 + self.work_count: int = 0 + self.padded_work_count: int = 0 + self.O_partial: Optional[torch.Tensor] = None + self.LSE_partial: Optional[torch.Tensor] = None + # Cached dummy buffers used in non-split path to satisfy the kernel's + # positional arg signature without per-call torch.empty (saves ~5us + # on every run() for small kv). + self._O_partial_dummy: Optional[torch.Tensor] = None + self._LSE_partial_dummy: Optional[torch.Tensor] = None + # When the caller doesn't ask for LSE, the kernel still needs a valid + # tensor pointer to write to. Cache a small placeholder so run() can + # skip the per-call torch.empty for it as well. + self._lse_dummy: Optional[torch.Tensor] = None + + def plan( + self, + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + q2k_indices: Optional[torch.Tensor] = None, + num_qo_heads: Optional[int] = None, + num_kv_heads: Optional[int] = None, + head_dim: Optional[int] = 128, + enable_cuda_graph: bool = False, + max_grid_size: Optional[int] = None, + fixed_split_size: Optional[int] = None, + disable_split_kv: bool = False, + ) -> "SparseDecodePagedAttentionWrapper": + """Prepare decode scheduling metadata and reusable workspaces. + + Parameters + ---------- + page_table : torch.Tensor + Shape ``[batch_size, max_num_pages_per_seq]``, dtype int32. Maps + logical pages to physical KV-cache pages. + seqused_k : torch.Tensor + Shape ``[batch_size]``, dtype int32. Effective KV length per + request. + seqlen_q : int + Uniform query length per request. + max_seqlen_k : int + Maximum KV sequence length in the batch. + q2k_indices : torch.Tensor, optional + Sparse selected KV blocks with shape ``[Hkv, total_q, topK]`` and + dtype int32. ``None`` selects the dense all-KV path. + num_qo_heads : int + Number of Q/O heads. + num_kv_heads : int + Number of KV heads. Current decode kernel requires + ``num_qo_heads / num_kv_heads == 16`` at run time. + head_dim : int, optional + Head dimension. Must be 128. + enable_cuda_graph : bool, optional + Build schedule metadata compatible with CUDA graph capture. + max_grid_size : int, optional + Override maximum CTA count used by the scheduler. + fixed_split_size : int, optional + Force a fixed split-KV chunk size in pages. + disable_split_kv : bool, optional + Disable split-KV even for long KV sequences. + + Returns + ------- + SparseDecodePagedAttentionWrapper + ``self``, planned and ready for ``run``. + """ + if page_table.ndim != 2: + raise ValueError("decode plan requires page_table with shape [B, max_num_pages_per_seq]") + if page_table.dtype != torch.int32: + raise TypeError("decode plan requires page_table to be torch.int32") + if seqused_k.dtype != torch.int32: + raise TypeError("decode plan requires seqused_k to be torch.int32") + if not page_table.is_cuda or not seqused_k.is_cuda: + raise ValueError("decode plan requires page_table and seqused_k to be CUDA tensors") + if page_table.device != seqused_k.device: + raise ValueError("decode plan requires page_table and seqused_k on the same device") + if page_table.stride(-1) != 1: + raise ValueError("decode plan requires page_table contiguous in the last dimension") + if seqused_k.shape != (int(page_table.shape[0]),): + raise ValueError("decode plan requires seqused_k with shape [B]") + if q2k_indices is not None and q2k_indices.dtype != torch.int32: + raise TypeError("decode plan requires q2k_indices to be torch.int32") + if int(seqlen_q) <= 0 or int(max_seqlen_k) <= 0: + raise ValueError("decode plan requires positive seqlen_q and max_seqlen_k") + if num_qo_heads is None or num_kv_heads is None or head_dim is None: + raise ValueError("decode plan requires num_qo_heads, num_kv_heads, and head_dim") + if head_dim is not None and int(head_dim) != 128: + raise NotImplementedError("decode plan currently supports only head_dim=128") + if int(num_qo_heads) % int(num_kv_heads) != 0: + raise ValueError("decode plan requires num_qo_heads divisible by num_kv_heads") + + self.batch = int(page_table.shape[0]) + self.num_qo_heads = None if num_qo_heads is None else int(num_qo_heads) + self.num_kv_heads = None if num_kv_heads is None else int(num_kv_heads) + self.head_dim = None if head_dim is None else int(head_dim) + self.page_table = page_table.contiguous() + self.seqused_k = seqused_k.contiguous() + self.q2k_indices = None if q2k_indices is None else q2k_indices.contiguous() + self.seqlen_q = int(seqlen_q) + self.max_seqlen_k = int(max_seqlen_k) + self.is_sparse = q2k_indices is not None + + # max_grid_size is hardcoded to num_sms (1 CTA/SM) inside the C++ + # schedule launcher because the decode attn kernel always runs at + # 1 CTA/SM (its register/smem budget saturates the SM). Callers + # can still override via the explicit max_grid_size kwarg. + schedule = prepare_decode_schedule( + seqused_k=self.seqused_k, + page_size=self.blk_kv, + seqlen_q=self.seqlen_q, + num_qo_heads=self.num_qo_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_seqlen_k=self.max_seqlen_k, + enable_cuda_graph=bool(enable_cuda_graph), + max_grid_size=max_grid_size, + fixed_split_size=fixed_split_size, + disable_split_kv=bool(disable_split_kv), + ) + self.decode_schedule = schedule + self.request_indices = schedule.request_indices + self.qo_tile_indices = schedule.qo_tile_indices + self.kv_tile_indices = schedule.kv_tile_indices + self.merge_indptr = schedule.merge_indptr + self.o_indptr = schedule.o_indptr + self.block_valid_mask = schedule.block_valid_mask + self.kv_pages = schedule.kv_pages + self.split_counts = schedule.split_counts + self.split_kv = schedule.split_kv + self.cta_tile_q = schedule.cta_tile_q + self.num_q_tiles = schedule.num_q_tiles + self.kv_chunk_size_pages = schedule.kv_chunk_size_pages + self.kv_chunk_size_tokens = schedule.kv_chunk_size_tokens + self.work_count = schedule.work_count + self.padded_work_count = schedule.padded_work_count + if schedule.split_kv: + self.O_partial = torch.empty( + (schedule.partial_rows, self.num_qo_heads, self.head_dim), + dtype=torch.float32, + device=page_table.device, + ) + self.LSE_partial = torch.empty( + (schedule.partial_rows, self.num_qo_heads), + dtype=torch.float32, + device=page_table.device, + ) + self._O_partial_dummy = None + self._LSE_partial_dummy = None + else: + self.O_partial = None + self.LSE_partial = None + # decode_forward_paged_fp8 always wants non-None partial buffers + # for the kernel's positional arg layout (compile keeps the slot + # alive even when split_kv=False). Allocate once here and reuse. + self._O_partial_dummy = torch.empty( + (1, self.head_dim), + dtype=torch.float32, + device=page_table.device, + ) + self._LSE_partial_dummy = torch.empty( + (1, self.num_qo_heads), + dtype=torch.float32, + device=page_table.device, + ) + # LSE dummy is shape (1, head_q) — used when caller doesn't request + # LSE and the schedule isn't split-KV (split-KV always writes LSE). + self._lse_dummy = torch.empty( + (1, self.num_qo_heads), + dtype=torch.float32, + device=page_table.device, + ) + return self + + def run( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + softmax_scale: Optional[float] = None, + return_softmax_lse: bool = False, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + ): + """Launch decode using metadata cached by ``plan``. + + Parameters + ---------- + q : torch.Tensor + Shape ``[batch_size * seqlen_q, Hq, 128]`` and dtype FP8 E4M3. + k : torch.Tensor + Paged K cache with shape ``[num_pages, Hkv, blk_kv, 128]``. + v : torch.Tensor + Paged V cache with the same shape as ``k``. + softmax_scale : float, optional + Softmax scale. Defaults to ``1 / sqrt(128)``. + return_softmax_lse : bool, optional + If True, return ``(out, lse)``. + out : torch.Tensor, optional + Preallocated BF16 output buffer with shape ``q.shape``. + lse : torch.Tensor, optional + Preallocated float32 LSE buffer with shape ``[total_q, Hq]``. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + BF16 output, optionally with float32 LSE. + """ + if self.decode_schedule is None: + raise RuntimeError("decode wrapper must be planned before run") + if self.is_sparse: + # Sparse path still goes through the validating wrapper for now; + # only the dense fast path is collapsed. + return sparse_decode_atten_func( + q, k, v, self.q2k_indices, + page_table=self.page_table, seqused_k=self.seqused_k, + seqlen_q=self.seqlen_q, max_seqlen_k=self.max_seqlen_k, + blk_kv=self.blk_kv, causal=self.causal, + softmax_scale=softmax_scale, return_softmax_lse=return_softmax_lse, + schedule=self.decode_schedule, + O_partial=self.O_partial, LSE_partial=self.LSE_partial, + ) + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + if out is None: + out = torch.empty(q.shape, dtype=torch.bfloat16, device=q.device) + if lse is None: + if return_softmax_lse or self.split_kv: + # Real LSE needed — must allocate per-call (shape depends on q). + lse = torch.empty( + q.shape[:2], dtype=torch.float32, device=q.device, + ) + else: + # Kernel only needs a valid pointer; reuse cached dummy. + lse = self._lse_dummy + from .src.sm100.fwd_decode import decode_forward_paged_fp8 + schedule = self.decode_schedule + decode_forward_paged_fp8( + q, k, v, + self.page_table, self.seqused_k, + out, lse, + schedule.request_indices, schedule.qo_tile_indices, + schedule.kv_tile_indices, schedule.block_valid_mask, + schedule.split_counts, schedule.o_indptr, schedule.merge_indptr, + self.O_partial, self.LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=self.seqlen_q, + page_size=self.blk_kv, + kv_chunk_size_pages=int(schedule.kv_chunk_size_pages), + max_split_count=int(schedule.max_split_count), + split_kv=bool(schedule.split_kv), + causal=self.causal, + return_lse=bool(return_softmax_lse), + # cached dummies — avoid per-call torch.empty inside run_decode_attention + O_partial_dummy=self._O_partial_dummy, + LSE_partial_dummy=self._LSE_partial_dummy, + ) + if return_softmax_lse: + return out, lse + return out + + +def _sparse_atten_csr_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + blk_kv: int, + causal: bool, + softmax_scale: float, + lse_temperature_scale: float, + return_temperature_lse: bool, + partial_dtype: torch.dtype, + return_softmax_lse: bool, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + page_table: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + schedule: Optional[SparseAttentionSchedule], + usable_SM_count: int, + batch: int, + head_kv: int, + max_seqlen_q: int, + max_seqlen_k: int, + qk_dtype: torch.dtype, + pv_dtype: torch.dtype, +): + total_q, head_q, dim = q.shape + if head_q % head_kv != 0: + raise ValueError("q.shape[1] must be divisible by head_kv") + max_num_kv_blocks = _csr_row_capacity(k2q_row_ptr) + temperature_lse_fast_path = ( + return_temperature_lse + and math.isclose( + float(lse_temperature_scale), + 1.0, + rel_tol=0.0, + abs_tol=_TEMPERATURE_LSE_FAST_PATH_ABS_TOL, + ) + ) + kernel_return_temperature_lse = ( + return_temperature_lse and not temperature_lse_fast_path + ) + + O_partial = torch.empty( + topK, total_q, head_q, dim, dtype=partial_dtype, device=q.device + ) + LSE_partial = torch.empty( + topK, total_q, head_q, dtype=torch.float32, device=q.device + ) + LSE_temperature_partial = ( + torch.empty(topK, total_q, head_q, dtype=torch.float32, device=q.device) + if kernel_return_temperature_lse + else None + ) + O_out = torch.empty(total_q, head_q, dim, dtype=torch.bfloat16, device=q.device) + LSE_out = torch.empty(total_q, head_q, dtype=torch.float32, device=q.device) + LSE_temperature_out = ( + torch.empty_like(LSE_out) if kernel_return_temperature_lse else None + ) + if schedule is None: + k2q_qsplit_indices = torch.empty_like(k2q_q_indices) + split_counts = torch.zeros( + (total_q, head_kv), + dtype=torch.int32, + device=q.device, + ) + else: + _validate_fwd_schedule( + schedule, + q=q, + k2q_q_indices=k2q_q_indices, + head_kv=head_kv, + ) + k2q_qsplit_indices = schedule.qsplit_indices + split_counts = schedule.split_counts + schedule = _call_sparse_forward_sm100_csr_varlen( + q, + k, + v, + k2q_row_ptr, + k2q_q_indices, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + cu_seqlens_k, + page_table, + seqused_k, + O_partial, + LSE_partial, + LSE_temperature_partial, + softmax_scale, + lse_temperature_scale, + kernel_return_temperature_lse, + max_num_kv_blocks, + blk_kv, + head_kv, + max_seqlen_q, + usable_SM_count, + causal=causal, + schedule=schedule, + qk_dtype=qk_dtype, + pv_dtype=pv_dtype, + ) + # Sparse Attention and Sparse Page Attention both use the varlen-Q + # combine path; the kernel-written LSE_out is the final contract. + combine( + O_partial, + LSE_partial, + O_out, + LSE_out, + lse_temperature_partial=LSE_temperature_partial, + lse_temperature_out=LSE_temperature_out, + cu_seqlens=cu_seqlens_q, + split_counts=split_counts, + use_pdl=True, + ) + if temperature_lse_fast_path: + LSE_temperature_out = LSE_out + + if return_softmax_lse: + if return_temperature_lse: + return O_out, LSE_out, LSE_temperature_out + return O_out, LSE_out + return O_out + + +def _call_sparse_decode_forward_sm100_paged_fp8( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k_indices: Optional[torch.Tensor], + page_table: torch.Tensor, + seqused_k: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + schedule: DecodeAttentionSchedule, + O_partial: Optional[torch.Tensor], + LSE_partial: Optional[torch.Tensor], + *, + softmax_scale: float, + seqlen_q: int, + max_seqlen_k: int, + blk_kv: int, + causal: bool, + return_lse: bool = True, +) -> None: + """Compile and launch the SM100 paged fp8 decode forward kernel. + + Dense decode is selected by ``q2k_indices=None``. Sparse decode will reuse + the same schedule wrapper but needs a separate q2k gather path. + """ + if q2k_indices is not None: + raise NotImplementedError("SM100 paged fp8 sparse decode forward is not implemented yet") + if schedule.cta_tile_q != 128: + raise NotImplementedError(f"decode forward requires cta_tile_q=128, got {schedule.cta_tile_q}") + if schedule.split_kv: + if O_partial is None or LSE_partial is None: + raise ValueError("split decode forward requires O_partial and LSE_partial") + if O_partial.dtype != torch.float32: + raise TypeError(f"O_partial must be torch.float32, got {O_partial.dtype}") + if LSE_partial.dtype != torch.float32: + raise TypeError(f"LSE_partial must be torch.float32, got {LSE_partial.dtype}") + + from .src.sm100.fwd_decode import decode_forward_paged_fp8 + + decode_forward_paged_fp8( + q, + k, + v, + page_table, + seqused_k, + out, + lse, + schedule.request_indices, + schedule.qo_tile_indices, + schedule.kv_tile_indices, + schedule.block_valid_mask, + schedule.split_counts, + schedule.o_indptr, + schedule.merge_indptr, + O_partial, + LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=int(seqlen_q), + page_size=int(blk_kv), + kv_chunk_size_pages=int(schedule.kv_chunk_size_pages), + max_split_count=int(schedule.max_split_count), + split_kv=bool(schedule.split_kv), + causal=bool(causal), + return_lse=bool(return_lse), + ) + + +def _call_sparse_forward_sm100_csr_varlen( + q, + k, + v, + k2q_row_ptr, + k2q_q_indices, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + cu_seqlens_k, + page_table, + seqused_k, + O_partial, + LSE_partial, + LSE_temperature_partial, + softmax_scale, + lse_temperature_scale, + return_temperature_lse, + max_num_kv_blocks, + blk_kv, + head_kv, + max_seqlen_q, + usable_SM_count=-1, + *, + causal=False, + use_prepare_scheduler=True, + schedule: Optional[SparseAttentionSchedule] = None, + qk_dtype: torch.dtype, + pv_dtype: torch.dtype, +): + """Compile and launch the SM100 sparse forward K1 kernel on CSR metadata.""" + head_dim = q.shape[-1] + dtype = q.dtype + qk_dtype = _normalize_forward_mma_dtype(qk_dtype, q.dtype, "qk_dtype") + pv_dtype = _normalize_forward_mma_dtype(pv_dtype, v.dtype, "pv_dtype") + partial_dtype = O_partial.dtype + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse != (LSE_temperature_partial is not None): + raise ValueError( + "return_temperature_lse must match LSE_temperature_partial presence" + ) + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + lse_temperature_inv_scale = 1.0 / lse_temperature_scale + n_block_size = int(blk_kv) + head_q = q.shape[1] + qhead_per_kv = head_q // head_kv + paged_kv = page_table is not None + if not bool(use_prepare_scheduler): + raise RuntimeError("sparse forward requires prepare scheduler") + schedule_enabled = k2q_row_ptr.shape[1] > 1 + page_size = int(k.shape[2]) if paged_kv else None + if paged_kv: + k_kernel, v_kernel = _prepare_paged_kv_for_tma(k, v, n_block_size) + else: + k_kernel = k + v_kernel = v + O_partial_flat = O_partial.reshape(-1, head_dim).contiguous() + Q_flat = q.reshape(-1, head_dim).contiguous() + Q_gather4_desc = ( + create_q_gather4_tma_desc( + Q_flat, + box_x=128 if q.dtype == torch.float8_e4m3fn else 64, + ) + if qhead_per_kv in (1, 2, 4) + else None + ) + if schedule is None: + schedule = prepare_sparse_fwd_schedule_and_split( + k2q_row_ptr=k2q_row_ptr, + k2q_q_indices=k2q_q_indices, + k2q_qsplit_indices=k2q_qsplit_indices, + split_counts=split_counts, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + total_q=int(q.shape[0]), + max_seqlen_q=max_seqlen_q, + topk=int(O_partial.shape[0]), + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + blk_kv=n_block_size, + device=q.device, + enabled=schedule_enabled, + ) + use_prepare_scheduler = schedule.enabled + scheduler_metadata = schedule.scheduler_metadata + work_count = schedule.work_count + work_capacity = schedule.work_capacity + if not use_prepare_scheduler or scheduler_metadata is None or work_count is None or work_capacity <= 0: + raise RuntimeError("sparse forward requires a non-empty prepared schedule") + + key = ( + "sparse_forward_sm100_csr_varlen", + head_dim, + n_block_size, + qhead_per_kv, + dtype, + k.dtype, + v.dtype, + qk_dtype, + pv_dtype, + partial_dtype, + bool(causal), + bool(paged_kv), + bool(use_prepare_scheduler), + page_size, + bool(seqused_k is not None), + bool(return_temperature_lse), + ) + if key not in _compile_cache: + from .src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _compile_cache[key] = loaded + else: + kernel = SparseAttentionForwardSm100( + head_dim=head_dim, + qheadperkv=qhead_per_kv, + n_block_size=n_block_size, + paged_kv=paged_kv, + page_size=page_size, + has_seqused_k=seqused_k is not None, + causal=bool(causal), + use_prepare_scheduler=use_prepare_scheduler, + qk_dtype=_torch_dtype_to_cutlass_dtype(qk_dtype), + pv_dtype=_torch_dtype_to_cutlass_dtype(pv_dtype), + ) + _compile_cache[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k_kernel), + to_cute_tensor_kvouter(v_kernel), + to_cute_tensor_kvouter(k2q_q_indices), + to_cute_tensor_kvouter(k2q_qsplit_indices), + to_cute_tensor_kvouter(k2q_row_ptr), + None if scheduler_metadata is None else to_cute_tensor_kvouter(scheduler_metadata), + None if work_count is None else to_cute_tensor_kvouter(work_count), + to_cute_tensor_kvouter(O_partial_flat), + to_cute_tensor_kvouter(LSE_partial), + None + if LSE_temperature_partial is None + else to_cute_tensor_kvouter(LSE_temperature_partial), + to_cute_tensor_kvouter(Q_flat), + None if Q_gather4_desc is None else to_cute_tensor_kvouter(Q_gather4_desc), + None if page_table is None else to_cute_tensor_kvouter(page_table), + None if seqused_k is None else to_cute_tensor_kvouter(seqused_k), + to_cute_tensor_kvouter(cu_seqlens_q), + to_cute_tensor_kvouter(cu_seqlens_k), + Float32(softmax_scale), + Float32(lse_temperature_inv_scale), + Int32(max_num_kv_blocks), + Int32(head_kv), + Int32(max_seqlen_q), + Int32(work_capacity), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _compile_cache[key]) + + with torch.cuda.nvtx.range("Fwd_SparseAttn_Sm100_CsrVarlen"): + _compile_cache[key]( + k_kernel, + v_kernel, + k2q_q_indices, + k2q_qsplit_indices, + k2q_row_ptr, + scheduler_metadata, + work_count, + O_partial_flat, + LSE_partial, + LSE_temperature_partial, + Q_flat, + Q_gather4_desc, + page_table, + seqused_k, + cu_seqlens_q, + cu_seqlens_k, + softmax_scale, + lse_temperature_inv_scale, + max_num_kv_blocks, + head_kv, + max_seqlen_q, + work_capacity, + ) + return schedule + + +def _call_sparse_forward_sm100_csr_varlen_nvfp4_kv( + q, + k, + v, + k_scale_128x4, + v_scale_128x4, + k_global_scale, + v_global_scale, + k2q_row_ptr, + k2q_q_indices, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + cu_seqlens_k, + page_table, + seqused_k, + O_partial, + LSE_partial, + LSE_temperature_partial, + softmax_scale, + lse_temperature_scale, + return_temperature_lse, + max_num_kv_blocks, + blk_kv, + head_kv, + max_seqlen_q, + *, + causal=False, + use_prepare_scheduler=True, + schedule: Optional[SparseAttentionSchedule] = None, +): + """Compile and launch the SM100 sparse forward K1 kernel with NVFP4 K/V.""" + + head_dim = q.shape[-1] + dtype = q.dtype + partial_dtype = O_partial.dtype + return_temperature_lse = bool(return_temperature_lse) + if return_temperature_lse != (LSE_temperature_partial is not None): + raise ValueError( + "return_temperature_lse must match LSE_temperature_partial presence" + ) + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + lse_temperature_inv_scale = 1.0 / lse_temperature_scale + n_block_size = int(blk_kv) + head_q = q.shape[1] + qhead_per_kv = head_q // head_kv + fp8_pair_dequant = os.environ.get("MINIMAX_KVFP4_FP8_PAIR_DEQUANT", "1") != "0" + k_global_scale_kernel = k_global_scale + # V global scale is linear in the final output. Keep K1 on block-scale-only V + # and apply the tensor scale once in K2 combine. + v_global_scale_kernel = None + has_k_global_scale = k_global_scale_kernel is not None + has_v_global_scale = v_global_scale_kernel is not None + paged_kv = page_table is not None + if not bool(use_prepare_scheduler): + raise RuntimeError("KVFP4 sparse forward requires prepare scheduler") + schedule_enabled = k2q_row_ptr.shape[1] > 1 + page_size = int(k.shape[2]) if paged_kv else None + if paged_kv: + _prepare_paged_kv_for_tma(k, v, n_block_size) + k_kernel = k + v_kernel = v + O_partial_flat = O_partial.reshape(-1, head_dim).contiguous() + Q_flat = q.reshape(-1, head_dim).contiguous() + Q_gather4_desc = ( + create_q_gather4_tma_desc( + Q_flat, + box_x=128 if q.dtype == torch.float8_e4m3fn else 64, + ) + if qhead_per_kv in (1, 2, 4) + else None + ) + if schedule is None: + schedule = prepare_sparse_fwd_schedule_and_split( + k2q_row_ptr=k2q_row_ptr, + k2q_q_indices=k2q_q_indices, + k2q_qsplit_indices=k2q_qsplit_indices, + split_counts=split_counts, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + total_q=int(q.shape[0]), + max_seqlen_q=max_seqlen_q, + topk=int(O_partial.shape[0]), + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + blk_kv=n_block_size, + device=q.device, + enabled=schedule_enabled, + ) + use_prepare_scheduler = schedule.enabled + scheduler_metadata = schedule.scheduler_metadata + work_count = schedule.work_count + work_capacity = schedule.work_capacity + if not use_prepare_scheduler or scheduler_metadata is None or work_count is None or work_capacity <= 0: + raise RuntimeError("KVFP4 sparse forward requires a non-empty prepared schedule") + + key = ( + "sparse_forward_sm100_csr_varlen_nvfp4_kv", + head_dim, + n_block_size, + qhead_per_kv, + dtype, + partial_dtype, + bool(causal), + bool(paged_kv), + bool(use_prepare_scheduler), + page_size, + bool(seqused_k is not None), + bool(return_temperature_lse), + bool(fp8_pair_dequant), + bool(has_k_global_scale), + bool(has_v_global_scale), + ) + if key not in _compile_cache: + from .src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _compile_cache[key] = loaded + else: + kernel = SparseAttentionForwardNvfp4KvSm100( + head_dim=head_dim, + qheadperkv=qhead_per_kv, + n_block_size=n_block_size, + paged_kv=paged_kv, + page_size=page_size, + has_seqused_k=seqused_k is not None, + causal=bool(causal), + use_prepare_scheduler=use_prepare_scheduler, + fp8_pair_dequant=bool(fp8_pair_dequant), + has_k_global_scale=bool(has_k_global_scale), + has_v_global_scale=bool(has_v_global_scale), + ) + _compile_cache[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k_kernel), + to_cute_tensor_kvouter(v_kernel), + to_cute_tensor_kvouter(k_scale_128x4), + to_cute_tensor_kvouter(v_scale_128x4), + None if k_global_scale_kernel is None else to_cute_tensor_kvouter(k_global_scale_kernel), + None if v_global_scale_kernel is None else to_cute_tensor_kvouter(v_global_scale_kernel), + to_cute_tensor_kvouter(k2q_q_indices), + to_cute_tensor_kvouter(k2q_qsplit_indices), + to_cute_tensor_kvouter(k2q_row_ptr), + None if scheduler_metadata is None else to_cute_tensor_kvouter(scheduler_metadata), + None if work_count is None else to_cute_tensor_kvouter(work_count), + to_cute_tensor_kvouter(O_partial_flat), + to_cute_tensor_kvouter(LSE_partial), + None + if LSE_temperature_partial is None + else to_cute_tensor_kvouter(LSE_temperature_partial), + to_cute_tensor_kvouter(Q_flat), + None if Q_gather4_desc is None else to_cute_tensor_kvouter(Q_gather4_desc), + None if page_table is None else to_cute_tensor_kvouter(page_table), + None if seqused_k is None else to_cute_tensor_kvouter(seqused_k), + to_cute_tensor_kvouter(cu_seqlens_q), + to_cute_tensor_kvouter(cu_seqlens_k), + Float32(softmax_scale), + Float32(lse_temperature_inv_scale), + Int32(max_num_kv_blocks), + Int32(head_kv), + Int32(max_seqlen_q), + Int32(work_capacity), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _compile_cache[key]) + + with torch.cuda.nvtx.range("Fwd_SparseAttn_Sm100_CsrVarlen_KVFP4"): + _compile_cache[key]( + k_kernel, + v_kernel, + k_scale_128x4, + v_scale_128x4, + k_global_scale_kernel, + v_global_scale_kernel, + k2q_q_indices, + k2q_qsplit_indices, + k2q_row_ptr, + scheduler_metadata, + work_count, + O_partial_flat, + LSE_partial, + LSE_temperature_partial, + Q_flat, + Q_gather4_desc, + page_table, + seqused_k, + cu_seqlens_q, + cu_seqlens_k, + softmax_scale, + lse_temperature_inv_scale, + max_num_kv_blocks, + head_kv, + max_seqlen_q, + work_capacity, + ) + return schedule diff --git a/build/torch212-cxx11-cu132-x86_64-linux/metadata.json b/build/torch212-cxx11-cu132-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..aa9f5795b50acb6564b56d16492191d3351afffd --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/metadata.json @@ -0,0 +1,71 @@ +{ + "name": "msa", + "id": "_msa_cuda_09d7851", + "version": 0, + "license": "other", + "upstream": "https://github.com/MiniMax-AI/MSA", + "python-depends": [ + "tvm-ffi", + "nvidia-cutlass-dsl" + ], + "backend": { + "type": "cuda", + "archs": [ + "10.0" + ] + }, + "digest": { + "algorithm": "sha256", + "files": { + "__init__.py": "+W+3U1Z5ZKc/dTA+JUG+6dMjfe9H/d9J+8fN+936wbI=", + "_msa_cuda_09d7851.abi3.so": "XZHGIzNwVfB2MCBfhtsx2SISTRTqErPuaqCS69+0Bag=", + "_ops.py": "o9RBC1FB95LP9Sp+GkBILumbSek9oEtxb8F7XXO0F0g=", + "fp4_indexer_interface.py": "M+0e93gWG8CGOrhY5bm1hEQJU+TT5PrCmwJzTofaDAg=", + "interface.py": "B4AHQfNyO+vl6MdyMAHW0GhArl7HGufAEa0ATxsWorY=", + "msa/__init__.py": "DFYPlrhXwYjEqCl/8n0SmWGZV8NFml5DPhMjKfv98GY=", + "quack/__init__.py": "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=", + "quack/activation.py": "T/ypcXoz6a4wPPNZW2gKZuEj8JeucaKtKxQiQl5XrXc=", + "quack/compile_utils.py": "qJ3oTsDlbAiddrJHtEO7LPYVqn/s+neNfiw+/KvfXZU=", + "quack/copy_utils.py": "rdohXm9bKXqDHkMHf8lWQJQnCb0hMLvhzIudkj0Bxeg=", + "quack/cute_dsl_utils.py": "4uQx5aYDG9UvVzbWwJTjjJLrnoympz70/CD8b37FQWo=", + "quack/layout_utils.py": "69N1aTy+840X3seMuLfLxiV3BW8SaVsM3Tf0Vf4NCSI=", + "quantize.py": "1jePLbJngji8ANfnDK6PCG829AMSd+XOMqYVuJ5pXyY=", + "sparse_index_utils.py": "kzYMdtFPRBfaL6Vfw9xLLre7ph8svtEQrB/txC+52Fc=", + "src/__init__.py": "ZdCpznblq9UgGSNgI0hJoDXpk+evcvGjv3GGthxD/nM=", + "src/common/__init__.py": "ZdCpznblq9UgGSNgI0hJoDXpk+evcvGjv3GGthxD/nM=", + "src/common/aot_cache.py": "ya1OHE6Lqx/pb9UhH++Bu8a98Huhmdl084C6cgWdH1s=", + "src/common/barrier.py": "Godvhwwaf9iyDA/A78VoQMMRRn6ZSnq2YPosr7K2SVE=", + "src/common/blackwell_helpers.py": "BYJYCeNQ9cYVhWZlfjv0IgNaNqlnoD21nX3gAA5pRB4=", + "src/common/block_info.py": "U7qL3AZ5ROkNZdL6RTPlLlnLJ6tZ4b2VFVufZLyuuq8=", + "src/common/copy_utils.py": "bEtyb8O7Z7jIKNjN5ESlnh4WVvdf8vr5ZfQxA6vS6zA=", + "src/common/cute_dsl_utils.py": "nd8vII+r49Kk185ja3+VM6dwJlvMqCkjMBRh0WEHakw=", + "src/common/fast_math.py": "nqt6MtzAt7uplC4+kczgBfin4gHNs+QSoufR1TuMZ88=", + "src/common/mask.py": "l9v4End+9k3ZHRO6DCnuOD9K9iOCiN81osRATKvK41k=", + "src/common/mma_sm100_desc.py": "C1PqBdp6CNPA9xadQ2xBnf4wvQlE93SS/7CU+LZBQkA=", + "src/common/named_barrier.py": "5ktJiO+hP80fjTR797CslUGfm2PyhpcW6WJZrNyI5bQ=", + "src/common/pack_gqa.py": "UrAAIge5XLmilqXWGtCZJobgpuA6B0N1Vw3tDhyUi7s=", + "src/common/paged_kv.py": "j0/6stT1A5uEVALEX/GaQhYWIie+6LpGseAW8aQiHbk=", + "src/common/pipeline.py": "MIFfoDDD8Fs//SQSR+JzI/0MJ1qPGml297RtbC2qPRU=", + "src/common/seqlen_info.py": "EX2W8MTGcnAZ+J60tGG9D7IzvdLeIVQshztntGDkPMQ=", + "src/common/softmax.py": "ePjb2TUcr4fHLmw0zx9Lt+vvR6hSm2mQwiENf2J/AoQ=", + "src/common/tile_scheduler.py": "f8UknoE0j9BfPomRI/QCsDJoRk+1IpJrLfBOAh2mlls=", + "src/common/tma_utils.py": "gpAmBh58VOfHRghZTCbQ5SQpbAYy0lFnpvIcFSLBNb8=", + "src/common/utils.py": "eGGo5Ul+0XpKtiw6JLofVdFDj6s2xe4LWqDmlqp9AKk=", + "src/sm100/__init__.py": "JQpQtL58fso8B2Xwvn0XVevVqIjnk15wVQE0UUGGLCs=", + "src/sm100/build_k2q_csr/__init__.py": "75ICu6BIZir0OeyEgZ1TEYNY7pn+lA4P6McCSSC20rI=", + "src/sm100/decode_schedule.py": "/VRAmvrMX+oYLzWK1sqve86tprXkqX0/f4o5WMVeU4I=", + "src/sm100/fp4_indexer.py": "1lc9/rgU09wwF08WBRaXIE0CE2b19pBRwXekDduFs0o=", + "src/sm100/fwd/__init__.py": "A0uq2t4n5Y34mEgxb9Nzxk9sKsYr2FZ4sF+RoEilOmo=", + "src/sm100/fwd/atten_fwd.py": "4LJaUh2pn3QiwcMr+8QOVUJjNIAQqYal1xFJ/1takQY=", + "src/sm100/fwd/atten_fwd_nvfp4_kv.py": "EqU+ehJasAa9NvpDWipMPxaptOw+vcojprVas+b+x18=", + "src/sm100/fwd/combine.py": "7rQW4rUpzy0M19u+/iLfHHGMbAIQhi4HEnYeLu/qmi4=", + "src/sm100/fwd_decode/__init__.py": "XQJdwvLQm29RwVqVZvCstEnTx+dhUrwmH6RcW675pR8=", + "src/sm100/fwd_decode/atten_fwd.py": "3S4iE9h6fXUBjas51fRbakqnOzN79f0QUJ/EBRm+Ckg=", + "src/sm100/fwd_decode/build_decode_schedule/__init__.py": "qUElKK/HC03N9ntOA0sc8LB08jF5MWd7wq3MUnu4wgM=", + "src/sm100/fwd_decode/combine.py": "wIvKZzHissMLe83PUbybUoM39HTMIAexHw5I1yfJH94=", + "src/sm100/fwd_decode/tile_scheduler.py": "OWdID5fCFmwXqz6RtseFphfJtezOOQ091K+bJFcD6bc=", + "src/sm100/prepare_k2q_csr.py": "nCeG6m24dLNwJeQDFppjqR3wVCDxMY0we+20zEEeMy8=", + "src/sm100/prepare_scheduler.py": "CQuJI6Fn0uR0oMcfzmlIH+bjg+2uKTzqCXbw5H0YgSw=" + } + } +} \ No newline at end of file diff --git a/build/torch212-cxx11-cu132-x86_64-linux/metadata.json.sigstore b/build/torch212-cxx11-cu132-x86_64-linux/metadata.json.sigstore new file mode 100644 index 0000000000000000000000000000000000000000..5d93479b38df19a403b9fdb517a2888dbbd992fa --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/metadata.json.sigstore @@ -0,0 +1 @@ +{"mediaType":"application/vnd.dev.sigstore.bundle.v0.3+json","verificationMaterial":{"certificate":{"rawBytes":"MIIHTDCCBtKgAwIBAgIUQpv9YQ5ULyqkPFWTb8d81JHHwQ8wCgYIKoZIzj0EAwMwNzEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MR4wHAYDVQQDExVzaWdzdG9yZS1pbnRlcm1lZGlhdGUwHhcNMjYwNjMwMTc0NDA3WhcNMjYwNjMwMTc1NDA3WjAAMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE0q6b6WGSG0+/oJTHLDb5sUwY35M16jKkbc6FYI4Dl8h1coW/y5EtYW/eJsQHEqbDKMMPYi1BrVsOQejgDdlKPqOCBfEwggXtMA4GA1UdDwEB/wQEAwIHgDATBgNVHSUEDDAKBggrBgEFBQcDAzAdBgNVHQ4EFgQU0LZpnBn23+///p+JIHWKejjyyu4wHwYDVR0jBBgwFoAU39Ppz1YkEZb5qNjpKFWixi4YZD8wawYDVR0RAQH/BGEwX4ZdaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5Ly5naXRodWIvd29ya2Zsb3dzL2J1aWxkLnlhbWxAcmVmcy9oZWFkcy9tYWluMDkGCisGAQQBg78wAQEEK2h0dHBzOi8vdG9rZW4uYWN0aW9ucy5naXRodWJ1c2VyY29udGVudC5jb20wHwYKKwYBBAGDvzABAgQRd29ya2Zsb3dfZGlzcGF0Y2gwNgYKKwYBBAGDvzABAwQoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTATBgorBgEEAYO/MAEEBAVCdWlsZDArBgorBgEEAYO/MAEFBB1odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eTAdBgorBgEEAYO/MAEGBA9yZWZzL2hlYWRzL21haW4wOwYKKwYBBAGDvzABCAQtDCtodHRwczovL3Rva2VuLmFjdGlvbnMuZ2l0aHVidXNlcmNvbnRlbnQuY29tMG0GCisGAQQBg78wAQkEXwxdaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5Ly5naXRodWIvd29ya2Zsb3dzL2J1aWxkLnlhbWxAcmVmcy9oZWFkcy9tYWluMDgGCisGAQQBg78wAQoEKgwoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTAbBgorBgEEAYO/MAELBA0MC3NlbGYtaG9zdGVkMEAGCisGAQQBg78wAQwEMgwwaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5MDgGCisGAQQBg78wAQ0EKgwoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTAfBgorBgEEAYO/MAEOBBEMD3JlZnMvaGVhZHMvbWFpbjAaBgorBgEEAYO/MAEPBAwMCjEwNzE0NzU1MjkwLgYKKwYBBAGDvzABEAQgDB5odHRwczovL2dpdGh1Yi5jb20vaHVnZ2luZ2ZhY2UwGAYKKwYBBAGDvzABEQQKDAgyNTcyMDc0MzBtBgorBgEEAYO/MAESBF8MXWh0dHBzOi8vZ2l0aHViLmNvbS9odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eS8uZ2l0aHViL3dvcmtmbG93cy9idWlsZC55YW1sQHJlZnMvaGVhZHMvbWFpbjA4BgorBgEEAYO/MAETBCoMKDA5ZDc4NTE1YzU1MzJlNzAwMjcwZTllMTM1NTZhMmFkMDJlOWY1ZjkwIQYKKwYBBAGDvzABFAQTDBF3b3JrZmxvd19kaXNwYXRjaDBkBgorBgEEAYO/MAEVBFYMVGh0dHBzOi8vZ2l0aHViLmNvbS9odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eS9hY3Rpb25zL3J1bnMvMjg0NjM5NjE5NTUvYXR0ZW1wdHMvMTAWBgorBgEEAYO/MAEWBAgMBnB1YmxpYzBGBgorBgEEAYO/MAEYBDgMNnJlcG86aHVnZ2luZ2ZhY2Uva2VybmVscy1jb21tdW5pdHk6cmVmOnJlZnMvaGVhZHMvbWFpbjCBiwYKKwYBBAHWeQIEAgR9BHsAeQB3AN09MGrGxxEyYxkeHJlnNwKiSl643jyt/4eKcoAvKe6OAAABnxmhk5kAAAQDAEgwRgIhAMa40CGZhAVrkrmStk5Q+zLDyIkx9Z7bc7IcKmqDYyhgAiEAv41b/F/aY9LvxOV8yBrFTwtOVcxV6mrtSvZMm9ec5YMwCgYIKoZIzj0EAwMDaAAwZQIxALfZDJxva9qoPY2CQVmfl4wtEwxiXeM3eDSMuPYidDzBOYxlFd9dfPyEjMxm16GVVwIwRSQuDn/8FXgZvhfMVzMQVqHm/G7zF2Ug0Lh16t8uDSGWpGvpvPcPVkKVc17zI1y5"},"tlogEntries":[{"logIndex":"2024793273","logId":{"keyId":"wNI9atQGlz+VWfO6LRygH4QUfY/8W4RFwiT5i5WRgB0="},"kindVersion":{"kind":"hashedrekord","version":"0.0.1"},"integratedTime":"1782841447","inclusionPromise":{"signedEntryTimestamp":"MEUCIQDwCxpqw7k7K791EwT2ggAw+Ej79OAwtd28bnfrP0ydbQIgEn5LMg5DuDPlsrzvOwoRU9pScvFMjt86W/hzjZUC/cU="},"inclusionProof":{"logIndex":"1902889011","rootHash":"ox7ck+UcDZ2wx2jI4fM2mGzq1IjeFjuh0oTspHSnlXQ=","treeSize":"1902889018","hashes":["Q4tQDfmLQ815yYvzXjrO5rDEzzpMUlAv6bqEvzkDdkI=","+4UYodE5ZXZBeP6x0mLWBXPMlyC2jswT737krdr6ZSs=","a9d3EQ7qr9o8u8+66PPMHDTxIBOE+MnCd7fh3g71PP0=","LM0BYIvW7lQaS7QX37Sf1wU8uLuOl3921bfOjDXu960=","HgDvlQ62RGLndg9GG5ql9Z/MRgUTyM2bXAu9kUuwC4Y=","pnClwW6lUYu/+5y41b8AT34qCvDe5Ym+MGv8VWKuU5c=","5DB/VRMbICRg24kfvBoq+aFOMwCKvhr1zQj5SpDh5Ck=","NRxwUF55kxkZUtVui8nzfzj4LLT960XpxpXnY6C7pqs=","KTak07KIu/wsxelNu7DaqjZg2G0WnevWjQkjflcCfjI=","o03232Stm2HWKs2uG6lq2NP4O1Zym1pjI+LbQCbPISY=","nGtXNKgDUZj+ZjPgQKuKFp9orlBq81iSk8yjysQUTIU=","+/rlNRIrSvbSLthLGxHY8saYzo8HTl12uoWcFuXbbE0=","tC4XX6tUr8g/3yF+0T8f2DfrTWQmbDBfMxTOmNuWyzI=","E8u2TYaBleTNUd9vupjpxhOMu+bExC1kpTjfOk2GAUA=","cJbCQtmuzzN6T9df9SuhiY4cyCN7ezf1n+yFrgRkcgE=","+/VZ56MsIPxMiyLAodzKXo5TEWdQp36z89qLhpzloAo=","daxmZaajRpZV+JxHiOYZhJBiSKN5ucqjh2WnGbHhirw=","DOCeoSMovIvLExkhIvisow9AuNXgeWs4ECkyR6EcqYU="],"checkpoint":{"envelope":"rekor.sigstore.dev - 1193050959916656506\n1902889018\nox7ck+UcDZ2wx2jI4fM2mGzq1IjeFjuh0oTspHSnlXQ=\n\n— rekor.sigstore.dev wNI9ajBEAh9hz7tbYPTN9RlSGoeUVpnM6MMH9GNtjvMqykWw9WCjAiEAlcnCrADhrf7Ckm53FB09VzwSvKyM0PUM1eEO5QcAYa0=\n"}},"canonicalizedBody":"eyJhcGlWZXJzaW9uIjoiMC4wLjEiLCJraW5kIjoiaGFzaGVkcmVrb3JkIiwic3BlYyI6eyJkYXRhIjp7Imhhc2giOnsiYWxnb3JpdGhtIjoic2hhMjU2IiwidmFsdWUiOiI5ZDEzMTIyODY2M2M2YzNjNjc2NmZjMWI4YTE3NTU4ZTYyYTY2NWVlYjg1MGUxNzllOGFhMTU3ZWEyYTQyMjA3In19LCJzaWduYXR1cmUiOnsiY29udGVudCI6Ik1FVUNJQlQ5QkIwWkEzd1lqVDVoUys1Sy9VQytJL1BBS2JNWUJmNjh1NUMwbWtFQ0FpRUFrWVpJU0FobWgrZytVN3FUQndiZGo4aTN1SVREbUhaTDMrUG9ZRDhyWjhvPSIsInB1YmxpY0tleSI6eyJjb250ZW50IjoiTFMwdExTMUNSVWRKVGlCRFJWSlVTVVpKUTBGVVJTMHRMUzB0Q2sxSlNVaFVSRU5EUW5STFowRjNTVUpCWjBsVlVYQjJPVmxSTlZWTWVYRnJVRVpYVkdJNFpEZ3hTa2hJZDFFNGQwTm5XVWxMYjFwSmVtb3dSVUYzVFhjS1RucEZWazFDVFVkQk1WVkZRMmhOVFdNeWJHNWpNMUoyWTIxVmRWcEhWakpOVWpSM1NFRlpSRlpSVVVSRmVGWjZZVmRrZW1SSE9YbGFVekZ3WW01U2JBcGpiVEZzV2tkc2FHUkhWWGRJYUdOT1RXcFpkMDVxVFhkTlZHTXdUa1JCTTFkb1kwNU5hbGwzVG1wTmQwMVVZekZPUkVFelYycEJRVTFHYTNkRmQxbElDa3R2V2tsNmFqQkRRVkZaU1V0dldrbDZhakJFUVZGalJGRm5RVVV3Y1RaaU5sZEhVMGN3S3k5dlNsUklURVJpTlhOVmQxa3pOVTB4Tm1wTGEySmpOa1lLV1VrMFJHdzRhREZqYjFjdmVUVkZkRmxYTDJWS2MxRklSWEZpUkV0TlRWQlphVEZDY2xaelQxRmxhbWRFWkd4TFVIRlBRMEptUlhkbloxaDBUVUUwUndwQk1WVmtSSGRGUWk5M1VVVkJkMGxJWjBSQlZFSm5UbFpJVTFWRlJFUkJTMEpuWjNKQ1owVkdRbEZqUkVGNlFXUkNaMDVXU0ZFMFJVWm5VVlV3VEZwd0NtNUNiakl6S3k4dkwzQXJTa2xJVjB0bGFtcDVlWFUwZDBoM1dVUldVakJxUWtKbmQwWnZRVlV6T1ZCd2VqRlphMFZhWWpWeFRtcHdTMFpYYVhocE5Ga0tXa1E0ZDJGM1dVUldVakJTUVZGSUwwSkhSWGRZTkZwa1lVaFNNR05JVFRaTWVUbHVZVmhTYjJSWFNYVlpNamwwVERKb01Wb3laSEJpYldSdFdWZE9iQXBNTW5Sc1kyMDFiR0pJVFhSWk1qbDBZbGhXZFdGWVVqVk1lVFZ1WVZoU2IyUlhTWFprTWpsNVlUSmFjMkl6WkhwTU1rb3hZVmQ0YTB4dWJHaGlWM2hCQ21OdFZtMWplVGx2V2xkR2EyTjVPWFJaVjJ4MVRVUnJSME5wYzBkQlVWRkNaemM0ZDBGUlJVVkxNbWd3WkVoQ2VrOXBPSFprUnpseVdsYzBkVmxYVGpBS1lWYzVkV041Tlc1aFdGSnZaRmRLTVdNeVZubFpNamwxWkVkV2RXUkROV3BpTWpCM1NIZFpTMHQzV1VKQ1FVZEVkbnBCUWtGblVWSmtNamw1WVRKYWN3cGlNMlJtV2tkc2VtTkhSakJaTW1kM1RtZFpTMHQzV1VKQ1FVZEVkbnBCUWtGM1VXOU5SR3hyVG5wbk1VMVVWbXBPVkZWNlRXMVZNMDFFUVhsT2VrSnNDazlYVlhoTmVsVXhUbTFGZVZsWFVYZE5iVlUxV21wV2JVOVVRVlJDWjI5eVFtZEZSVUZaVHk5TlFVVkZRa0ZXUTJSWGJITmFSRUZ5UW1kdmNrSm5SVVVLUVZsUEwwMUJSVVpDUWpGdlpGZGtibUZYTlc1YWJVWnFXbE01Y2xwWVNuVmFWM2g2VEZkT2RtSlhNVEZpYld3d1pWUkJaRUpuYjNKQ1owVkZRVmxQTHdwTlFVVkhRa0U1ZVZwWFducE1NbWhzV1ZkU2Vrd3lNV2hoVnpSM1QzZFpTMHQzV1VKQ1FVZEVkbnBCUWtOQlVYUkVRM1J2WkVoU2QyTjZiM1pNTTFKMkNtRXlWblZNYlVacVpFZHNkbUp1VFhWYU1td3dZVWhXYVdSWVRteGpiVTUyWW01U2JHSnVVWFZaTWpsMFRVY3dSME5wYzBkQlVWRkNaemM0ZDBGUmEwVUtXSGQ0WkdGSVVqQmpTRTAyVEhrNWJtRllVbTlrVjBsMVdUSTVkRXd5YURGYU1tUndZbTFrYlZsWFRteE1NblJzWTIwMWJHSklUWFJaTWpsMFlsaFdkUXBoV0ZJMVRIazFibUZZVW05a1YwbDJaREk1ZVdFeVduTmlNMlI2VERKS01XRlhlR3RNYm14b1lsZDRRV050Vm0xamVUbHZXbGRHYTJONU9YUlpWMngxQ2sxRVowZERhWE5IUVZGUlFtYzNPSGRCVVc5RlMyZDNiMDFFYkd0T2VtY3hUVlJXYWs1VVZYcE5iVlV6VFVSQmVVNTZRbXhQVjFWNFRYcFZNVTV0UlhrS1dWZFJkMDF0VlRWYWFsWnRUMVJCWWtKbmIzSkNaMFZGUVZsUEwwMUJSVXhDUVRCTlF6Tk9iR0pIV1hSaFJ6bDZaRWRXYTAxRlFVZERhWE5IUVZGUlFncG5OemgzUVZGM1JVMW5kM2RoU0ZJd1kwaE5Oa3g1T1c1aFdGSnZaRmRKZFZreU9YUk1NbWd4V2pKa2NHSnRaRzFaVjA1c1RESjBiR050Tld4aVNFMTBDbGt5T1hSaVdGWjFZVmhTTlUxRVowZERhWE5IUVZGUlFtYzNPSGRCVVRCRlMyZDNiMDFFYkd0T2VtY3hUVlJXYWs1VVZYcE5iVlV6VFVSQmVVNTZRbXdLVDFkVmVFMTZWVEZPYlVWNVdWZFJkMDF0VlRWYWFsWnRUMVJCWmtKbmIzSkNaMFZGUVZsUEwwMUJSVTlDUWtWTlJETktiRnB1VFhaaFIxWm9Xa2hOZGdwaVYwWndZbXBCWVVKbmIzSkNaMFZGUVZsUEwwMUJSVkJDUVhkTlEycEZkMDU2UlRCT2VsVXhUV3ByZDB4bldVdExkMWxDUWtGSFJIWjZRVUpGUVZGbkNrUkNOVzlrU0ZKM1kzcHZka3d5WkhCa1IyZ3hXV2sxYW1JeU1IWmhTRlp1V2pKc2RWb3lXbWhaTWxWM1IwRlpTMHQzV1VKQ1FVZEVkbnBCUWtWUlVVc0tSRUZuZVU1VVkzbE5SR013VFhwQ2RFSm5iM0pDWjBWRlFWbFBMMDFCUlZOQ1JqaE5XRmRvTUdSSVFucFBhVGgyV2pKc01HRklWbWxNYlU1MllsTTVid3BrVjJSdVlWYzFibHB0Um1wYVV6bHlXbGhLZFZwWGVIcE1WMDUyWWxjeE1XSnRiREJsVXpoMVdqSnNNR0ZJVm1sTU0yUjJZMjEwYldKSE9UTmplVGxwQ21SWGJITmFRelUxV1ZjeGMxRklTbXhhYmsxMllVZFdhRnBJVFhaaVYwWndZbXBCTkVKbmIzSkNaMFZGUVZsUEwwMUJSVlJDUTI5TlMwUkJOVnBFWXpRS1RsUkZNVmw2VlRGTmVrcHNUbnBCZDAxcVkzZGFWR3hzVFZSTk1VNVVXbWhOYlVaclRVUktiRTlYV1RGYWFtdDNTVkZaUzB0M1dVSkNRVWRFZG5wQlFncEdRVkZVUkVKR00ySXpTbkphYlhoMlpERTVhMkZZVG5kWldGSnFZVVJDYTBKbmIzSkNaMFZGUVZsUEwwMUJSVlpDUmxsTlZrZG9NR1JJUW5wUGFUaDJDbG95YkRCaFNGWnBURzFPZG1KVE9XOWtWMlJ1WVZjMWJscHRSbXBhVXpseVdsaEtkVnBYZUhwTVYwNTJZbGN4TVdKdGJEQmxVemxvV1ROU2NHSXlOWG9LVEROS01XSnVUWFpOYW1jd1RtcE5OVTVxUlRWT1ZGVjJXVmhTTUZwWE1YZGtTRTEyVFZSQlYwSm5iM0pDWjBWRlFWbFBMMDFCUlZkQ1FXZE5RbTVDTVFwWmJYaHdXWHBDUjBKbmIzSkNaMFZGUVZsUEwwMUJSVmxDUkdkTlRtNUtiR05IT0RaaFNGWnVXakpzZFZveVdtaFpNbFYyWVRKV2VXSnRWbk5qZVRGcUNtSXlNWFJrVnpWd1pFaHJObU50Vm0xUGJrcHNXbTVOZG1GSFZtaGFTRTEyWWxkR2NHSnFRMEpwZDFsTFMzZFpRa0pCU0ZkbFVVbEZRV2RTT1VKSWMwRUtaVkZDTTBGT01EbE5SM0pIZUhoRmVWbDRhMlZJU214dVRuZExhVk5zTmpRemFubDBMelJsUzJOdlFYWkxaVFpQUVVGQlFtNTRiV2hyTld0QlFVRlJSQXBCUldkM1VtZEphRUZOWVRRd1EwZGFhRUZXY210eWJWTjBhelZSSzNwTVJIbEphM2c1V2pkaVl6ZEpZMHR0Y1VSWmVXaG5RV2xGUVhZME1XSXZSaTloQ2xrNVRIWjRUMVk0ZVVKeVJsUjNkRTlXWTNoV05tMXlkRk4yV2sxdE9XVmpOVmxOZDBObldVbExiMXBKZW1vd1JVRjNUVVJoUVVGM1dsRkplRUZNWmxvS1JFcDRkbUU1Y1c5UVdUSkRVVlp0Wm13MGQzUkZkM2hwV0dWTk0yVkVVMDExVUZscFpFUjZRazlaZUd4R1pEbGtabEI1UldwTmVHMHhOa2RXVm5kSmR3cFNVMUYxUkc0dk9FWllaMXAyYUdaTlZucE5VVlp4U0cwdlJ6ZDZSakpWWnpCTWFERTJkRGgxUkZOSFYzQkhkbkIyVUdOUVZtdExWbU14TjNwSk1YazFDaTB0TFMwdFJVNUVJRU5GVWxSSlJrbERRVlJGTFMwdExTMEsifX19fQ=="}],"timestampVerificationData":{"rfc3161Timestamps":[{"signedTimestamp":"MIICyjADAgEAMIICwQYJKoZIhvcNAQcCoIICsjCCAq4CAQMxDTALBglghkgBZQMEAgEwgbgGCyqGSIb3DQEJEAEEoIGoBIGlMIGiAgEBBgkrBgEEAYO/MAIwMTANBglghkgBZQMEAgEFAAQgIO1cEzm94+e20Xhvqsis5nqe17uqSKlYTAPjPSzo5boCFQDDYmGiK+uOw649a+CZ8HDDkRQVrRgPMjAyNjA2MzAxNzQ0MDdaMAMCAQGgMqQwMC4xFTATBgNVBAoTDHNpZ3N0b3JlLmRldjEVMBMGA1UEAxMMc2lnc3RvcmUtdHNhoAAxggHbMIIB1wIBATBRMDkxFTATBgNVBAoTDHNpZ3N0b3JlLmRldjEgMB4GA1UEAxMXc2lnc3RvcmUtdHNhLXNlbGZzaWduZWQCFDoTVC8MkGHuvMFDL8uKjosqI4sMMAsGCWCGSAFlAwQCAaCB/DAaBgkqhkiG9w0BCQMxDQYLKoZIhvcNAQkQAQQwHAYJKoZIhvcNAQkFMQ8XDTI2MDYzMDE3NDQwN1owLwYJKoZIhvcNAQkEMSIEIDamfhTMk4zgMmhXnMTA1R5FFReeiwadN7O4+ZP7UHXFMIGOBgsqhkiG9w0BCRACLzF/MH0wezB5BCCF+Se8B6tiysO0Q1bBDvyBssaIP9p6uebYcNnROs0FtzBVMD2kOzA5MRUwEwYDVQQKEwxzaWdzdG9yZS5kZXYxIDAeBgNVBAMTF3NpZ3N0b3JlLXRzYS1zZWxmc2lnbmVkAhQ6E1QvDJBh7rzBQy/Lio6LKiOLDDAKBggqhkjOPQQDAgRnMGUCMQCk0R36yJLXxfVrcOyKxnOaf/AqVeule1jCmYDp6noXvr/iDh57kjw6twFTRefMaAgCMBK0v7qzD4/kjrcMvNVBC3JDjrS+lroR6JnLdoPQo2eUO9ugfUAT9PoXlFUvNho+Aw=="}]}},"messageSignature":{"messageDigest":{"algorithm":"SHA2_256","digest":"nRMSKGY8bDxnZvwbihdVjmKmZe64UOF56KoVfqKkIgc="},"signature":"MEUCIBT9BB0ZA3wYjT5hS+5K/UC+I/PAKbMYBf68u5C0mkECAiEAkYZISAhmh+g+U7qTBwbdj8i3uITDmHZL3+PoYD8rZ8o="}} \ No newline at end of file diff --git a/build/torch212-cxx11-cu132-x86_64-linux/msa/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/msa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/msa/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/quack/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/quack/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch212-cxx11-cu132-x86_64-linux/quack/activation.py b/build/torch212-cxx11-cu132-x86_64-linux/quack/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..e8cbeb29242b92b7cc336cd336604e58c36f4459 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/quack/activation.py @@ -0,0 +1,532 @@ +# Copyright (c) 2025, Tri Dao. + +import math +from typing import Tuple +from functools import partial + +import cutlass.cute as cute +from cutlass import Float32, Boolean, const_expr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm, nvvm + + +F32_or_F32x2 = Float32 | Tuple[Float32, Float32] + + +sub_packed_f32x2 = partial( + cute.arch.calc_packed_f32x2_op, + src_c=None, + calc_func=nvvm.sub_packed_f32x2, +) + + +@dsl_user_op +def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "tanh.approx.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True) + return 0.5 + 0.5 * tanh(0.5 * x) + else: + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) + tanh_x_half = (tanh(x_half[0]), tanh(x_half[1])) + return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5)) + + +@dsl_user_op +def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32: + # return dout * out * (1.0 - out) + return dout * (out - out * out) + + +@dsl_user_op +def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) + else: + return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)) + + +@dsl_user_op +@cute.jit +def drelu( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + if const_expr(not isinstance(x, tuple)): + x_pos = Boolean(x > 0) + return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0)) + else: + x0_pos = Boolean(x[0] > 0) + x1_pos = Boolean(x[1] > 0) + dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0)) + return dx, relu(x) + + +@dsl_user_op +def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) * x + else: + relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))) + return cute.arch.mul_packed_f32x2(relu_x, x) + + +@dsl_user_op +@cute.jit +def drelu_sq( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + """ + ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward + Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out + Returns: (dx, relu_sq_out) where: + - dx = dout * 2 * x if x > 0, else 0 + - relu_sq_out = max(x, 0) * x + """ + if const_expr(not isinstance(x, tuple)): + relu_x = relu(x) + relu_sq_out = relu_x * x + # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0 + dx = 2.0 * (dout * relu_x) + return dx, relu_sq_out + else: + relu_x = relu(x) + relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x) + dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x)) + return dx, relu_sq_out + + +@dsl_user_op +def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """ + gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x))) + """ + sqrt_2_over_pi = math.sqrt(2 / math.pi) # ~0.797885 + sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # ~0.0356774 + if const_expr(not isinstance(x, tuple)): + return 0.5 * ( + x + # Currently cute.math.tanh(x, fastmath=True) generates very slow code + # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True)) + * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)))) + ) + else: + x_sq = cute.arch.mul_packed_f32x2(x, x) + x_sq_scaled = cute.arch.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + z = cute.arch.mul_packed_f32x2(x, x_sq_scaled) + tanh_z = (tanh(z[0]), tanh(z[1])) + x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x) + return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z) + + +@dsl_user_op +def dgelu_tanh_approx( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + """ + GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward + Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out + Returns: (dx, gelu_out) + + Derivative uses the chain rule: + d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2 + and sech^2(z) = 1 - tanh^2(z) + """ + sqrt_2_over_pi = math.sqrt(2 / math.pi) # c1 ~0.797885 + sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # c2 ~0.0356774 + sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff # c3 ~0.01070322 + + if const_expr(not isinstance(x, tuple)): + # Compute z = x * (c1 + c2 * x^2) + x_sq = x * x + # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True) + tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq)) + half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z + gelu_out = x * half_tanh_z_plus_one + + # Compute gradient + # sech^2(z) = 1 - tanh^2(z) + sech2_z = 1 - tanh_z * tanh_z + # dz/dx = c1 + 3 * c2 * x^2 + dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq + # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx)) + + dx = dout * dgelu + return dx, gelu_out + else: + # Compute z = x * (c1 + c2 * x^2) + x_sq = cute.arch.mul_packed_f32x2(x, x) + x_sq_scaled = cute.arch.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + z = cute.arch.mul_packed_f32x2(x, x_sq_scaled) + tanh_z = (tanh(z[0]), tanh(z[1])) + half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5)) + gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one) + + # Compute gradient + # sech^2(z) = 1 - tanh^2(z) + sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0)) + # dz/dx = c1 + 3 * c2 * x^2 + dz_dx = cute.arch.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx) + x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx) + dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one) + + dx = cute.arch.mul_packed_f32x2(dout, dgelu) + return dx, gelu_out + + +@dsl_user_op +@cute.jit +def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + use_linear = Boolean(x > 20.0) + return ( + cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True) + if not use_linear + else x + ) + else: + log2_e = math.log2(math.e) + x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e)) + x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True)) + x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0)) + log_x_exp_p1 = ( + cute.math.log2(x_exp_p1[0], fastmath=True), + cute.math.log2(x_exp_p1[1], fastmath=True), + ) + ln2 = math.log(2.0) + softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2)) + use_linear_0 = Boolean(x[0] > 20.0) + use_linear_1 = Boolean(x[1] > 20.0) + return ( + softplus_x[0] if not use_linear_0 else x[0], + softplus_x[1] if not use_linear_1 else x[1], + ) + + +@dsl_user_op +@cute.jit +def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32: + use_linear = Boolean(out > 20.0) + # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout + dx = dout - dout * cute.math.exp(-out, fastmath=True) + return dx if not use_linear else dout + + +@dsl_user_op +def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2: + """ + silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x) + This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA. + """ + if const_expr(not isinstance(x, tuple)): + x_half = 0.5 * x if const_expr(not already_halved) else x + # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half + return x_half * tanh(x_half) + x_half + else: + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x + tanh_x_half = (tanh(x_half[0]), tanh(x_half[1])) + return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half) + + +@dsl_user_op +def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return silu(x) * y + else: + return cute.arch.mul_packed_f32x2(silu(x), y) + + +@dsl_user_op +def dswiglu( + x: F32_or_F32x2, + y: F32_or_F32x2, + dout: F32_or_F32x2, + *, + already_halved: bool = False, + loc=None, + ip=None, +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out + Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x) + + d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))) + + This has been optimized to use fewer instructions (i.e. we expand things out + to use FFMA instead of FADD and FMUL). + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x)) + # FMUL, MUFU.TANH, then FFMA + if const_expr(not already_halved): + sigmoid_x = sigmoid(x) + silu_x = x * sigmoid_x # FMUL + else: + tanh_x = tanh(x) # MUFU.TANH + sigmoid_x = 0.5 * tanh_x + 0.5 # FFMA + silu_x = x * tanh_x + x # FFMA + silu_x_dout = silu_x * dout # FMUL + # d_silu(x) * dout + # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout + # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout + d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout # FFMA, FFMA + dx = d_silu_x_dout * y # FMUL + dy = silu_x_dout + swiglu_out = silu_x * y # FMUL + # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA + return dx, dy, swiglu_out + else: + # Compute sigmoid(x) and silu(x) + if const_expr(not already_halved): + sigmoid_x = sigmoid(x) + silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x) + else: + tanh_x = (tanh(x[0]), tanh(x[1])) + sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5)) + silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x) + silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout) + # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout + sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2( + sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x + ) + d_silu_x_dout = cute.arch.fma_packed_f32x2( + sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout + ) + dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y) + dy = silu_x_dout + swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y) + return dx, dy, swiglu_out + + +@dsl_user_op +def swiglu_oai( + x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None +) -> F32_or_F32x2: + """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y. + https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249 + x * sigmoid(alpha * x) * (y + 1) + Compile down to FMUL, FMUL, TANH, FFMA, FFMA + """ + # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2)) + if const_expr(not isinstance(x, tuple)): + x_half = 0.5 * x + # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half + silu_x = x_half * tanh(alpha * x_half) + x_half + return silu_x * y + silu_x + else: + x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) + alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half) + tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1])) + silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half) + return cute.arch.fma_packed_f32x2(silu_x, y, silu_x) + + +@dsl_user_op +def dswiglu_oai( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + Swiglu OAI backward pass: computes gradients w.r.t. x and y + Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out + Returns: (dx, dy, swiglu_oai_out) + + Derivative of x * sigmoid(alpha * x) w.r.t. x: + d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x)) + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2)) + alpha_x_half = (0.5 * alpha) * x # FMUL + # MUFU.TANH, then FFMA + # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True) + sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half) + silu_x = x * sigmoid_alpha_x # FMUL + silu_x_dout = silu_x * dout # FMUL + # FFMA, FFMA, FMUL + d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout + dx = d_silu_x_dout * y + d_silu_x_dout # FFMA, instead of multiply by y + 1 + dy = silu_x_dout + swiglu_out = silu_x * y + silu_x # FFMA, instead of multiply by y + 1 + # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA + return dx, dy, swiglu_out + else: + # Compute sigmoid(alpha * x) + alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x) + tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1])) + sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5)) + silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x) + silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout) + # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout + silu_x_minus_product = cute.arch.fma_packed_f32x2( + silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x + ) + sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2( + (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x + ) + d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout) + dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout) + dy = silu_x_dout + swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x) + return dx, dy, swiglu_out + + +@dsl_user_op +def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """GLU: Gated Linear Unit + glu(x, y) = sigmoid(x) * y + Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2)) + """ + if const_expr(not isinstance(x, tuple)): + sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA + return sigmoid_x * y # FMUL + else: + sigmoid_x = sigmoid(x) + return cute.arch.mul_packed_f32x2(sigmoid_x, y) + + +@dsl_user_op +def dglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out + Returns: (dx, dy, glu_out) where: + - dx = dout * y * sigmoid(x) * (1 - sigmoid(x)) + - dy = dout * sigmoid(x) + - glu_out = sigmoid(x) * y + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2)) + sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA + sigmoid_x_dout = sigmoid_x * dout # FMUL + glu_out = sigmoid_x * y # FMUL + # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout + # = y * (1 - sigmoid(x)) * sigmoid_x_dout + # = (y - y * sigmoid(x)) * sigmoid_x_dout + # = (y - glu_out) * sigmoid_x_dout + dx = (y - glu_out) * sigmoid_x_dout # FADD, FMUL + dy = sigmoid_x_dout + # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA + return dx, dy, glu_out + else: + sigmoid_x = sigmoid(x) + sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout) + glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y) + # dx = (y - glu_out) * sigmoid_x_dout + y_minus_glu_out = sub_packed_f32x2(y, glu_out) + dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout) + dy = sigmoid_x_dout + return dx, dy, glu_out + + +@dsl_user_op +def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """ReGLU: ReLU Gated Linear Unit + reglu(x, y) = relu(x) * y = max(x, 0) * y + """ + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) * y + else: + relu_x = relu(x) + return cute.arch.mul_packed_f32x2(relu_x, y) + + +@dsl_user_op +@cute.jit +def dreglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out + Returns: (dx, dy, reglu_out) where: + - dx = dout * y if x > 0, else 0 + - dy = dout * relu(x) + - reglu_out = relu(x) * y + """ + if const_expr(not isinstance(x, tuple)): + x_pos = Boolean(x > 0) + relu_x = cute.arch.fmax(x, Float32(0.0)) + dx = (dout * y) if x_pos else Float32(0.0) + dy = dout * relu_x + reglu_out = relu_x * y + return dx, dy, reglu_out + else: + x0_pos = Boolean(x[0] > 0) + x1_pos = Boolean(x[1] > 0) + relu_x = relu(x) + dout_y = cute.arch.mul_packed_f32x2(dout, y) + dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0))) + dy = cute.arch.mul_packed_f32x2(dout, relu_x) + reglu_out = cute.arch.mul_packed_f32x2(relu_x, y) + return dx, dy, reglu_out + + +@dsl_user_op +def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """GeGLU: GELU Gated Linear Unit + geglu(x, y) = gelu(x) * y + Uses the tanh approximation of GELU + """ + if const_expr(not isinstance(x, tuple)): + return gelu_tanh_approx(x) * y + else: + return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y) + + +@dsl_user_op +def dgeglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out + Returns: (dx, dy, geglu_out) where: + - dx = dout * y * d_gelu(x) + - dy = dout * gelu(x) + - geglu_out = gelu(x) * y + """ + if const_expr(not isinstance(x, tuple)): + # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x) + dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout) + # Compute gradients for geglu + dx = dgelu_x_dout * y + dy = gelu_x * dout + geglu_out = gelu_x * y + return dx, dy, geglu_out + else: + # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x) + dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout) + # Compute gradients for geglu + dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y) + dy = cute.arch.mul_packed_f32x2(gelu_x, dout) + geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y) + return dx, dy, geglu_out diff --git a/build/torch212-cxx11-cu132-x86_64-linux/quack/compile_utils.py b/build/torch212-cxx11-cu132-x86_64-linux/quack/compile_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4375594669c8f12d6a79d8878316271cb819568a --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/quack/compile_utils.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +from typing import Optional + +import cutlass.cute as cute + + +def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]: + if leading_dim < 0: + leading_dim = len(shape) + leading_dim + if dtype is None: + return None + stride = tuple( + cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1 + for i in range(len(shape)) + ) + return cute.runtime.make_fake_tensor( + dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8 + ) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/quack/copy_utils.py b/build/torch212-cxx11-cu132-x86_64-linux/quack/copy_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad989559766d6ee6e8ece9d322bf08980706dfa --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/quack/copy_utils.py @@ -0,0 +1,890 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +import re +from typing import Optional, Type, Tuple, Callable, Sequence +from functools import partial + +import cutlass +import cutlass.cute as cute + +from cutlass import Int32, Int16, Boolean, const_expr +from cutlass.cute.nvgpu import cpasync, warp, warpgroup +from cutlass.cute.nvgpu.tcgen05.mma import CtaGroup # noqa +from cutlass.cutlass_dsl import dsl_user_op +import cutlass.pipeline +from cutlass._mlir.dialects import llvm +from cutlass._mlir import ir +from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir + + +Sm100MmaPeerBitMask = 0xFEFFFFFF + + +@dsl_user_op +def cvt_copy( + tiled_copy: cute.TiledCopy, + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + retile: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + if const_expr(src.element_type != dst.element_type): + src_cvt = cute.make_fragment_like(src, dst.element_type) + src_cvt.store(src.load().to(dst.element_type)) + src = src_cvt + if const_expr(retile): + src = tiled_copy.retile(src) + cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +@dsl_user_op +def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip) + cute.autovec_copy(src, dst, loc=loc, ip=ip) + return dst + + +@dsl_user_op +def load_s2r_retile( + tiled_copy: cute.TiledCopy, + src: cute.Tensor, + dst_shape: cute.Tensor | cute.Shape, + *, + loc=None, + ip=None, +) -> cute.Tensor: + # Will also accept dst_shape being a tensor, in which case we write into that tensor + if const_expr(not isinstance(dst_shape, cute.Tensor)): + dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip) + else: + dst = dst_shape + cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip) + return dst + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None +) -> cute.CopyAtom: + num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + num_copy_elems = src.shape[0][0] + copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +def tiled_copy_1d( + dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + thr_layout = cute.make_layout(num_threads) + val_layout = cute.make_layout(num_copy_elems) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def tiled_copy_2d( + dtype: Type[cutlass.Numeric], + threads_per_row: int, + num_threads: int, + num_copy_elems: int = 1, + is_async: bool = False, +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + assert num_threads % threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // threads_per_row, threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, num_copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +@cute.jit +def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor: + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" + tApA = cute.make_rmem_tensor( + cute.make_layout( + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + return tApA + + +# def tiled_copy_2d( +# dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +# ) -> cute.TiledCopy: +# num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width +# copy_elems = num_copy_bits // dtype.width +# copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() +# copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) +# gmem_threads_per_row = major_mode_size // copy_elems +# assert num_threads % gmem_threads_per_row == 0 +# thr_layout = cute.make_ordered_layout( +# (num_threads // gmem_threads_per_row, gmem_threads_per_row), +# order=(1, 0), +# ) +# val_layout = cute.make_layout((1, copy_elems)) +# return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]: + """Extract swizzle parameters from a pointer's swizzle_type. + + The swizzle_type string has the form '!cute.swizzle<"S">' where + b, m, s are the swizzle parameters (bits, base, shift). + + Returns: + A cute.Swizzle object constructed from the extracted parameters + + Raises: + ValueError: If the swizzle_type string cannot be parsed + """ + # Ideally there should be a better API to get swizzle parameters, but we'll just parse + # the string here. + swizzle_str = str(ptr.type.swizzle_type) + # Extract the inner part "S" + match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str) + if match: + b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3)) + return b, m, s + else: + raise ValueError(f"Could not parse swizzle_type: {swizzle_str}") + + +def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32: + bit_msk = (1 << b) - 1 + yyy_msk = bit_msk << (m + s) + return ptr_int ^ ((ptr_int & yyy_msk) >> s) + + +def swizzle_ptr(ptr: cute.Pointer): + b, m, s = parse_swizzle_from_pointer(ptr) + ptr_int = swizzle_int(ptr.toint(), b, m, s) + return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment) + + +def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor: + outer = tensor.layout + width = tensor.element_type.width + inner = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator)) + # Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for + # for 16 bits and <3, 2, 3> for 32 bits) + new_layout = cute.recast_layout( + width, 8, cute.make_composed_layout(inner, 0, cute.recast_layout(8, width, outer)) + ) + # recast_ptr to remove the pointer swizzle + return cute.make_tensor(cute.recast_ptr(tensor.iterator, dtype=tensor.element_type), new_layout) + + +def partition_D_position_independent( + thr_copy: cute.core.ThrCopy, tensor: cute.Tensor +) -> cute.Tensor: + return cute.make_tensor( + swizzle_ptr(thr_copy.partition_D(tensor).iterator), + thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout, + ) + + +def partition_S_position_independent( + thr_copy: cute.core.ThrCopy, tensor: cute.Tensor +) -> cute.Tensor: + return cute.make_tensor( + swizzle_ptr(thr_copy.partition_S(tensor).iterator), + thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout, + ) + + +@dsl_user_op +def sm90_get_smem_load_op( + layout_c: cutlass.utils.LayoutEnum, + elem_ty_c: Type[cutlass.Numeric], + *, + loc=None, + ip=None, +) -> cute.CopyAtom: + """ + Selects the largest vectorized smem load atom available subject to constraint of gmem layout. + + Parameters: + ----------- + layout_c : LayoutEnum + The layout enum of the output tensor D. + + elem_ty_c : Type[Numeric] + The element type for output tensor D. + + Returns: + -------- + Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters. + """ + + if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta): + raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}") + is_m_major = layout_c.is_m_major_c() + if elem_ty_c.width == 16: + return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip) + else: + return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip) + + +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=(2 if not transpose else 1) * element_type.width, + ) + else: + return cute.make_copy_atom( + warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +def get_smem_load_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=(2 if not transpose else 1) * element_type.width, + ) + else: + return cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +def get_smem_store_C( + tiled_mma: cute.TiledMma, + sC: cute.Tensor, + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sC.element_type + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tRS_sC = thr_copy.partition_D(sC) + else: + tRS_sC = partition_D_position_independent(thr_copy, sC) + + def copy_fn(src: cute.Tensor, dst_idx: Optional[Int32] = None, **new_kwargs): + dst_tensor = tRS_sC if const_expr(dst_idx is None) else tRS_sC[None, None, None, dst_idx] + cvt_copy(tiled_copy, src, dst_tensor, retile=True, **new_kwargs) + + return copy_fn, thr_copy, tRS_sC + + +def get_smem_load_C( + tiled_mma: cute.TiledMma, + sC: cute.Tensor, + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sC.element_type + copy_atom = get_smem_load_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tSR_sC = thr_copy.partition_S(sC) + else: + tSR_sC = partition_S_position_independent(thr_copy, sC) + copy_atom_RS = get_smem_store_atom(arch, dtype, transpose) + thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx) + tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape + + def copy_fn(src_idx: Optional[Int32] = None, **new_kwargs): + src_tensor = tSR_sC if const_expr(src_idx is None) else tSR_sC[None, None, None, src_idx] + return load_s2r_retile(tiled_copy, src_tensor, dst_shape=tRS_shape, **new_kwargs) + + return copy_fn, thr_copy, tSR_sC + + +def epilog_smem_copy_atom( + tiled_mma: cute.TiledMma, epi_tile: cute.Shape, transpose: bool = False +) -> cute.TiledCopy: + copy_atom_C = cute.make_copy_atom( + warp.StMatrix8x8x16bOp(transpose, num_matrices=4 if epi_tile[1] % 16 == 0 else 2), + cutlass.Float16, # this is just to get the right source layout + ) + tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + return tiled_copy_C_atom + + +def get_smem_store_epi( + tiled_mma: cute.TiledMma, + epi_tile: cute.Shape, + sC: Optional[cute.Tensor], + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor, cute.Tensor]: + dtype = sC.element_type if const_expr(sC is not None) else cutlass.Float16 + tiled_copy_C_atom = epilog_smem_copy_atom(tiled_mma, epi_tile) + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_S(copy_atom, tiled_copy_C_atom) + thr_copy = tiled_copy.get_slice(tidx) + tRS_sC = None + if const_expr(sC is not None): + if const_expr(not position_independent): + tRS_sC = thr_copy.partition_D(sC) + else: + tRS_sC = partition_D_position_independent(thr_copy, sC) + sC_shape = sC.shape[:2] if sC is not None else epi_tile + # (R2S, R2S_M, R2S_N, PIPE_C) + tRS_rC_shape = thr_copy.partition_S(cute.make_identity_tensor(sC_shape)).shape + tRS_rC = cute.make_rmem_tensor(tRS_rC_shape, tiled_mma.op.acc_dtype) + + def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs): + cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], **new_kwargs) + + return copy_fn if const_expr(sC is not None) else None, thr_copy, tRS_sC, tRS_rC + + +def get_smem_store_A( + tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sA.element_type + transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tRS_sA = thr_copy.partition_D(sA) + else: + tRS_sA = partition_D_position_independent(thr_copy, sA) + + def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs): + cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs) + + return copy_fn, thr_copy, tRS_sA + + +def get_smem_load_A( + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + tidx: Int32, + arch: int, + with_dst_tensor: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sA.element_type + transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN + copy_atom = get_smem_load_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tSR_sA = thr_copy.partition_S(sA) + else: + tSR_sA = partition_S_position_independent(thr_copy, sA) + tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2]) + + def copy_fn(src_idx: Int32, **new_kwargs): + return load_s2r_retile( + tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs + ) + + def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs): + return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs) + + return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA + + +@dsl_user_op +def cpasync_reduce_bulk_add_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: int | Int32, + *, + loc=None, + ip=None, +): + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()], + # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;", + # "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def get_tma_desc_addr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer: + """ + Get the address of the TMA descriptor embedded in a TMA Copy Atom. + + Extracts the constant memory address of the TMA descriptor for use with + custom PTX instructions. + + :param tma_atom: TMA Copy Atom from make_tiled_tma_atom + :return: Pointer to TMA descriptor in constant memory + + Example: + >>> desc_ptr = get_tma_descriptor_address(tma_atom) + """ + exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip) + tma_desc_ptr_type = ir.Type.parse( + "!cute.ptr>" + ) + return _cute_nvgpu_ir.get_tma_desc_addr(tma_desc_ptr_type, exec_atom, loc=loc, ip=ip) + + +@dsl_user_op +def tma_gather4_load( + tma_desc_ptr: cute.Pointer, + dst_smem_ptr: cute.Pointer, + mbarrier_ptr: cute.Pointer, + col_idx: Int32, + row_indices: Sequence[Int32], + *, + num_cta: int = 1, + multicast_mask=None, + loc=None, + ip=None, +) -> None: + """ + Perform TMA gather4 load from global memory to shared memory. + + Issues PTX instruction: + cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes + [dstMem], [tensorMap, {col_idx, row0, row1, row2, row3}], [smem_bar]; + + This loads 4 rows (specified by row_indices) from a 2D tensor at the given + column index into shared memory, using the TMA descriptor. + + :param tma_desc_ptr: Pointer to TMA descriptor in constant memory (128-byte aligned) + :type tma_desc_ptr: Pointer + :param dst_smem_ptr: Destination address in shared memory + :type dst_smem_ptr: Pointer + :param mbarrier_ptr: Pointer to mbarrier in shared memory for completion tracking + :type mbarrier_ptr: Pointer + :param col_idx: Column index + :type col_idx: Int32 + :param row_indices: Sequence of exactly 4 row indices + :type row_indices: Sequence[Int32] + :param num_cta: Number of CTAs participating (default: 1) + :type num_cta: int + :param multicast_mask: Optional multicast mask + :type multicast_mask: Int16 + + Requirements: + - row_indices must contain exactly 4 elements + - Compute capability >= SM_100 (Blackwell) + - TMA descriptor must be properly initialized for 2D tensor + + Example: + >>> from cutlass.cute.nvgpu import cpasync + >>> from cutlass.cute import core + >>> + >>> # Create TMA descriptor + >>> tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(...) + >>> tma_desc_ptr = get_tma_descriptor_address(tma_atom) + >>> + >>> # Compute indices (typically from kernel logic) + >>> col_idx = core.get(...) or 5 # Int32 value + >>> row_indices = [core.get(...) for _ in range(4)] # 4 Int32 values + >>> + >>> # Gather 4 rows at computed column + >>> tma_gather4_load( + ... tma_desc_ptr=tma_desc_ptr, + ... dst_smem_ptr=smem_ptr, + ... mbarrier_ptr=barrier_ptr, + ... col_idx=col_idx, + ... row_indices=row_indices + ... ) + """ + if len(row_indices) != 4: + raise ValueError(f"gather4 requires exactly 4 row indices, got {len(row_indices)}") + col_val = Int32(col_idx).ir_value() + row_vals = [Int32(row_idx).ir_value() for row_idx in row_indices] + # Convert pointers to integer addresses + desc_addr = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value() + dst_addr = dst_smem_ptr.toint(loc=loc, ip=ip).ir_value() + mbar_addr = mbarrier_ptr.toint(loc=loc, ip=ip) + if num_cta > 1: + # Executed by both CTAs. Set peer bit to 0 so that the + # transaction bytes will update CTA0's barrier. + mbar_addr = mbar_addr & Sm100MmaPeerBitMask + mbar_addr = mbar_addr.ir_value() + # Handle multicast_mask - may already be ir.Value or Python int + multicast_mask_val = None + if multicast_mask is not None: + multicast_mask_val = Int16(multicast_mask).ir_value() + assert multicast_mask_val is None, "multicast is not supported yet" + # Emit inline PTX for TMA gather4 + # PTX: cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes + # [dstMem], [tensorMap, {col, row0, row1, row2, row3}], [smem_bar]; + ptx = ( + f"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::{num_cta} " + "[$0], [$1, {$2, $3, $4, $5, $6}], [$7];" + ) + + llvm.inline_asm( + None, + [ + dst_addr, + desc_addr, + col_val, + row_vals[0], + row_vals[1], + row_vals[2], + row_vals[3], + mbar_addr, + ], + ptx, + "r,l,r,r,r,r,r,r", # constraints: register, long, 6x register + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +def cpasync_bulk_get_copy_fn( + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + single_stage: bool = False, + **kwargs, +) -> Callable: + group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0)) + group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + src = cute.group_modes(src_tensor, 0, group_rank_src) + dst = cute.group_modes(dst_tensor, 0, group_rank_dst) + + def copy_bulk(src_idx, dst_idx, tma_bar_ptr: cute.Pointer, **new_kwargs): + atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type) + with cute.arch.elect_one(): + cute.copy( + atom, + src[None, src_idx], + dst[None, dst_idx], + mbar_ptr=tma_bar_ptr, + **new_kwargs, + **kwargs, + ) + + def copy_bulk_single_stage(tma_bar_ptr: cute.Pointer, **new_kwargs): + atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type) + with cute.arch.elect_one(): + cute.copy(atom, src, dst, mbar_ptr=tma_bar_ptr, **new_kwargs, **kwargs) + + return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage + + +def tma_get_copy_fn( + atom: cute.CopyAtom, + cta_coord: cute.Coord, + cta_layout: cute.Layout, + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + filter_zeros: bool = False, + single_stage: bool = False, + **kwargs, +) -> Callable: + src_is_smem = const_expr( + isinstance(src_tensor.iterator, cute.Pointer) + and src_tensor.memspace == cute.AddressSpace.smem + ) + smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) + group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0)) + group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + s, g = cpasync.tma_partition( + atom, + cta_coord, + cta_layout, + cute.group_modes(smem_tensor, 0, group_rank_smem), + cute.group_modes(gmem_tensor, 0, group_rank_gmem), + ) + if const_expr(filter_zeros): + s = cute.filter_zeros(s) + g = cute.filter_zeros(g) + src, dst = (s, g) if src_is_smem else (g, s) + + def copy_tma(src_idx, dst_idx, **new_kwargs): + cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) + + def copy_tma_single_stage(**new_kwargs): + cute.copy(atom, src, dst, **new_kwargs, **kwargs) + + return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g + + +def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): + def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs): + copy( + src_idx=src_idx, + dst_idx=producer_state.index, + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), + **new_kwargs, + ) + + return copy_fn + + +@cute.jit +def gather_m_get_copy_fn( + thr_copy_A: cute.ThrCopy, + mA: cute.Tensor, # (whatever, K) + sA: cute.Tensor, # (tile_M, tile_K, STAGE) + gsAIdx: cute.Tensor, # (tile_M), either gmem or smem + limit_m: Int32, + limit_k: Int32, +) -> Callable: + tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1])) + tAsA = thr_copy_A.partition_D(sA) + # k-major + assert tAsA.shape[2] == 1 + tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2) + + is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0 + if const_expr(not is_even_m_smem): + limit_m = min(limit_m, tile_shape_mk[0]) + elems_per_load = cute.size(tAsA.shape[0][0]) + cA = cute.make_identity_tensor(tile_shape_mk) + tAcA = thr_copy_A.partition_S(cA) + t0AcA = thr_copy_A.get_slice(0).partition_S(cA) + # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0] + # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0]. + # This is so that when we do the comparison, t0AcA is known at compile time. + limit_m = limit_m - tAcA[0][0] + limit_k = limit_k - tAcA[0][1] + # Read and cache indices for A + rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1])) + cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2])) + tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean) + for m in cutlass.range(rows_per_thread, unroll_full=True): + tApA_m[m] = t0AcA[0, m, 0][0] < limit_m + m_idx = cute.make_rmem_tensor(rows_per_thread, Int32) + for m in cutlass.range(rows_per_thread, unroll_full=True): + row_idx = tAcA[0, m, 0][0] + if tApA_m[m]: + m_idx[m] = gsAIdx[row_idx] + else: + m_idx[m] = 0 # It's ok to load row 0 in the case of OOB + + mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1])) + + def copy_fn(src_idx, dst_idx, pred: bool = False): + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + mA_cur = mA_k[None, (None, src_idx)] + for m in cutlass.range_constexpr(tAcA.shape[1]): + # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape + # ((elems_per_load), thread_per_row) + # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA + # So we append 1s to the last dimension and then do tiled_divide, then slice. + mA_row = cute.tiled_divide( + cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1) + )[None, None, 0] + if const_expr(is_even_m_smem) or tApA_m[m]: + # There's only 1 load per row + assert cute.size(tAcA.shape, mode=[2]) == 1 + ki = tAcA[0, 0, 0][1] // elems_per_load + cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k) + + return copy_fn + + +@cute.jit +def gather_k_get_copy_fn( + thr_copy_A: cute.ThrCopy, + mA: cute.Tensor, # (tile_M, whatever) + sA: cute.Tensor, # (tile_M, tile_K, STAGE) + gsAIdx: cute.Tensor, # (tile_K, RestK), either gmem or smem + limit_m: Int32, + limit_k: Int32, +) -> Callable: + gAIdx, sAIdx = None, None + if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem): + gAIdx = gsAIdx + else: + assert gsAIdx.memspace == cute.AddressSpace.smem + sAIdx = gsAIdx + tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1])) + # (atom_v, CPY_M, 1, STAGE) + tAsA = thr_copy_A.partition_D(sA) + # m-major + tAsA = cute.group_modes(tAsA, 0, 3) + + is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0 + if const_expr(not is_even_m_smem): + limit_m = min(limit_m, tile_shape_mk[0]) + elems_per_load = cute.size(tAsA.shape[0][0]) + cA = cute.make_identity_tensor(tile_shape_mk) + tAcA = thr_copy_A.partition_S(cA) + t0AcA = thr_copy_A.get_slice(0).partition_S(cA) + # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0] + # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0]. + # This is so that when we do the comparison, t0AcA is known at compile time. + limit_m = limit_m - tAcA[0][0] + limit_k = limit_k - tAcA[0][1] + # Read and cache indices for A + rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1])) + cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2])) + tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean) + for m in cutlass.range(rows_per_thread, unroll_full=True): + tApA_m[m] = t0AcA[0, m, 0][0] < limit_m + threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load) + # This is very convoluted but idk a better way + # for tile_M=128, flat_divide gives (8, 16, K), + # then logical_divide gives ((8, 1), (8, 2), K). + tidx = thr_copy_A.thr_idx + tAmA = cute.logical_divide( + cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col) + )[None, (tidx % threads_per_col, None), None] # ((8, 1), 2, K) + + def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]: + # Prefetch mAIdx early, even before smem is free + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + gAIdx_cur = gAIdx[None, src_idx] + k_idx = cute.make_rmem_tensor(cols_per_thread, Int32) + for k in cutlass.range(cols_per_thread): + col_idx = tAcA[0, 0, k][1] + if const_expr(not pred): + k_idx[k] = gAIdx_cur[col_idx] + else: + if tApA_k[k]: + k_idx[k] = gAIdx_cur[col_idx] + else: + k_idx[k] = -1 + return k_idx, tApA_k + + def prefetch_from_smem_fn( + a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False + ) -> Tuple[cute.Tensor, cute.Tensor]: + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state) + sAIdx_cur = sAIdx[None, dst_idx] + k_idx = cute.make_rmem_tensor(cols_per_thread, Int32) + for k in cutlass.range(cols_per_thread): + col_idx = tAcA[0, 0, k][1] + k_idx[k] = sAIdx_cur[col_idx] + cute.arch.sync_warp() + with cute.arch.elect_one(): + a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state) + return k_idx, tApA_k + + def copy_fn( + src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False + ): + k_idx, tApA_k = k_idx_tApA_k + tApA_k_pred = None + if const_expr(pred): + tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2) # (1, cols_per_thread) + for k in cutlass.range_constexpr(tAcA.shape[2]): + # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2)) + for m in cutlass.range_constexpr(tAcA.shape[1]): + if tApA_m[m]: + cute.copy( + thr_copy_A, + tAmA[None, m, k_idx[k]], + tAsA[(None, m, k), dst_idx], + pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k], + ) + + return copy_fn, prefetch_from_gmem_fn if const_expr( + gAIdx is not None + ) else prefetch_from_smem_fn + + +@cute.jit +def gather_m_get_tma_copy_fn( + tma_atom: cute.CopyAtom, + mA: cute.Tensor, # (whatever, K) + sA: cute.Tensor, # ((4, 32), (64, 1), STAGE) + sAIdx: cute.Tensor, # (tile_M), + warp_idx: Int32, + num_warps: int, + num_cta: int = 1, +) -> Callable: + tile_M = cute.size(sAIdx, mode=[0]) + tile_K = cute.size(sA[None, None, 0]) // tile_M + assert tile_M % 4 == 0 + # cta_group = 1 if tma_atom.op.cta_group == CtaGroup.ONE else 2 + cta_group = num_cta # Somehow all tma_atom has CtaGroup.ONE inside the kernel + + copy_AIdx_s2r = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128), + cute.make_layout(num_warps), # thr_layout + cute.make_layout(4), # val_layout + ) + warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx) + tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx) + # ((4, 1), 8, (64, 1), STAGE) + tSR_sA = warp_copy_AIdx_s2r.partition_S(sA) + tSR_rAIdx = load_s2r(tSR_sAIdx) + tma_desc_ptr = get_tma_desc_addr(tma_atom) + tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group) + + def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer): + col_idx = tile_K * src_idx + for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True): + row_indices = [tSR_rAIdx[v, m] for v in range(4)] + smem_ptr = tSR_sA[None, m, None, dst_idx].iterator + with cute.arch.elect_one(): + tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices) + + return copy_fn diff --git a/build/torch212-cxx11-cu132-x86_64-linux/quack/cute_dsl_utils.py b/build/torch212-cxx11-cu132-x86_64-linux/quack/cute_dsl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c92cf39ac08b92245316da46526494d7d8370e1 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/quack/cute_dsl_utils.py @@ -0,0 +1,104 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Tuple +from functools import lru_cache +from dataclasses import dataclass, fields + +import torch + +try: + from triton.tools.disasm import extract +except ImportError: + extract = None + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Int64, Float16, BFloat16, Float32 +from cutlass.base_dsl.typing import JitArgument +from cutlass.cutlass_dsl import NumericMeta + + +StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) + + +load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data +cute_compile_og = cute.compile + + +torch2cute_dtype_map = { + torch.float16: Float16, + torch.bfloat16: BFloat16, + torch.float32: Float32, + torch.int32: Int32, + torch.int64: Int64, +} + + +@lru_cache +def get_max_active_clusters(cluster_size): + return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) + + +@lru_cache +def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: + return torch.cuda.get_device_capability(device) + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class ArgumentsBase(JitArgument): + def __c_pointers__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + c_ptrs = [] + for obj in non_constexpr_fields: + if hasattr(obj, "__c_pointers__"): + c_ptrs.extend(obj.__c_pointers__()) + return c_ptrs + + def __get_mlir_types__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + types, self._values_pos = [], [] + for obj in non_constexpr_fields: + if hasattr(obj, "__get_mlir_types__"): + obj_types = obj.__get_mlir_types__() + types.extend(obj_types) + self._values_pos.append(len(obj_types)) + else: + self._values_pos.append(0) + return types + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/quack/layout_utils.py b/build/torch212-cxx11-cu132-x86_64-linux/quack/layout_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..099e0daf54cdac4b25b6d96f01b35451c810249b --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/quack/layout_utils.py @@ -0,0 +1,295 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + + +import cutlass +import cutlass.cute as cute + +from cutlass import Int32, const_expr + + +def transpose_view(a: cute.Tensor) -> cute.Tensor: + """Transpose the first two dimensions of a tensor on smem.""" + shape = (a.shape[1], a.shape[0], *a.shape[2:]) + order = (1, 0, *range(2, cute.rank(a))) + return cute.composition(a, cute.make_ordered_layout(shape, order=order)) + + +def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor: + return cute.make_tensor(a.iterator, cute.select(a.layout, mode)) + + +def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor: + shape = (*a.shape[:dim], size, *a.shape[dim:]) + stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:]) + return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride)) + + +@cute.jit +def permute_gated_Cregs_b16(t: cute.Tensor) -> None: + assert t.element_type.width == 16 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation" + t_u32 = cute.recast_tensor(t, Int32) + + quad_idx = cute.arch.lane_idx() % 4 + lane_03 = quad_idx == 0 or quad_idx == 3 + selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054) + selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276) + # upper_map = [0, 3, 1, 2] + # lower_map = [1, 2, 0, 3] + # upper_idx = upper_map[quad_idx] + # indexing isn't supported so we have to do arithmetic + upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2 + lower_idx = upper_idx ^ 1 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True): + upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1] + upper0 = upper if lane_03 else lower + lower0 = lower if lane_03 else upper + upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp) + lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp) + t_u32[i * 2 + 0] = cute.arch.prmt(upper0, lower0, selector_upper) + t_u32[i * 2 + 1] = cute.arch.prmt(upper0, lower0, selector_lower) + + +@cute.jit +def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None: + """Permute and shuffle within 4 threads to change the layout from + T0 | T1 | T2 | T3 + a b | c d | e f | g h + to + T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3 + a | b | c | d | e | f | g | h + This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict. + """ + + assert t.element_type.width == 32 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation" + + quad_idx = cute.arch.lane_idx() % 4 + # left_map = [0, 2, 1, 3] + # right_map = [2, 0, 3, 1] + # indexing isn't supported so we have to do arithmetic + left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2 + right_idx = left_idx ^ 0b10 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True): + for r in cutlass.range(2, unroll_full=True): + left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1] + # a b | c d | e f | g h -> a b | c d | f e | h g + left0 = left if quad_idx < 2 else right + right0 = right if quad_idx < 2 else left + # a b | c d | f e | h g -> a b | f d | c e | h g + left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp) + # a b | f d | c e | h g -> a e | f b | c g | h d + right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp) + # a e | f b | c g | h d -> a e | b f | c g | d h + t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0 + t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0 + t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1] + + +@cute.jit +def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None: + """Permute and shuffle within 4 threads to change the layout from + T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3 + a | b | c | d | e | f | g | h + to + T0 | T1 | T2 | T3 + a b | c d | e f | g h + This is so that we can use LDSM (instead of LDS.64) to store C registers without bank conflict. + """ + + assert t.element_type.width == 32 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation" + + quad_idx = cute.arch.lane_idx() % 4 + # left_map = [0, 2, 1, 3] + # right_map = [1, 3, 0, 2] + # indexing isn't supported so we have to do arithmetic + left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2 + right_idx = left_idx ^ 0b01 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + # This is just the inverse of permute_Cregs_b32_for_stsm + for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True): + t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1] + for r in cutlass.range(2, unroll_full=True): + left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1] + # a e | b f | c g | d h -> a e | f b | c g | h d + left0 = left if quad_idx % 2 == 0 else right + right0 = right if quad_idx % 2 == 0 else left + # a e | f b | c g | h d -> a b | f d | c e | h g + right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp) + # a b | f d | c e | h g -> a b | c d | f e | h g + left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp) + # a b | c d | f e | h g -> a b | c d | e f | g h + t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0 + t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0 + + +@cute.jit +def concat_layout(*layouts: cute.Layout) -> cute.Layout: + return cute.make_layout( + tuple(l.shape for l in layouts), + stride=tuple(l.stride for l in layouts), + ) + + +def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout: + """ + For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...). + For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...). + """ + acc_layout_col_major = cute.make_layout(acc_layout.shape) + shape = ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M + ( + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), # MMA_N + *acc_layout_col_major.shape[3:], + ) + stride = ( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), # MMA_N + *acc_layout_col_major.stride[3:], + ) + if const_expr(transpose): + shape = (shape[1], shape[0], *shape[2:]) + stride = (stride[1], stride[0], *stride[2:]) + acc_layout_mn = cute.make_layout(shape, stride=stride) + return cute.composition(acc_layout, acc_layout_mn) + + +def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose)) + + +def reshape_acc_to_mn(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose)) + + +@cute.jit +def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: + # For back to back gemm, convert layout of acc0 to gemm 1 accept layout. + # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) + # TODO: Sm90 FP8 + if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90 + l = cute.logical_divide( + acc_layout, ((None, None, 2), None, None) + ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N) + rA_mma_view = cute.make_layout( + ( + (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]), + l.shape[1], + (l.shape[0][2][1], l.shape[2]), + ), + stride=( + (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]), + l.stride[1], + (l.stride[0][2][1], l.stride[2]), + ), + ) + else: # Sm80 + # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) + l = cute.logical_divide(acc_layout, (None, None, 2)) + rA_mma_view = cute.make_layout( + ( + (l.shape[0], l.shape[2][0]), + l.shape[1], + l.shape[2][1], + ), + stride=( + (l.stride[0], l.stride[2][0]), + l.stride[1], + l.stride[2][1], + ), + ) + return rA_mma_view + + +def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout)) + + +def convert_layout_zero_stride( + input: cute.Tensor | cute.Layout, ref_layout: cute.Layout +) -> cute.Layout: + layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input + # Group the modes with non-zero stride in the ref_layout together, + # and the modes with zero stride together + layout_flat = cute.flatten(layout) + ref_layout_flat = cute.flatten(ref_layout) + nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0] + zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0] + # There's an edge case when all modes are zero stride + new_shape = ( + tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,), + tuple(layout_flat[i].shape for i in zero_modes), + ) + new_stride = ( + tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,), + tuple(layout_flat[i].stride for i in zero_modes), + ) + out_layout = cute.make_layout(new_shape, stride=new_stride) + if const_expr(isinstance(input, cute.Tensor)): + return cute.make_tensor(input.iterator, out_layout) + else: + return out_layout + + +def mma_partition_C_vec( + sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_C(sVec_mma)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] + + +def mma_partition_A_vec( + sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/quantize.py b/build/torch212-cxx11-cu132-x86_64-linux/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..4719a4854bc9388b2a866598f9e21c1f14921181 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/quantize.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Transformer Engine NVFP4 quantization helper. + +This file is intended as a customer-facing example for preparing KV tensors +for the KVFP4 attention kernel: + - BF16/FP16 K/V input + - packed E2M1 FP4 data from Transformer Engine + - E4M3 block scales in cuBLAS/cuDNN 128x4 tiled layout + - one FP32 tensor/global scale per tensor +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Tuple + +import torch + + +NVFP4_BLOCK_SIZE = 16 +NVFP4_FP4_MAX = 6.0 +NVFP4_FP8_E4M3_MAX = 448.0 + + +@dataclass(frozen=True) +class Nvfp4QuantizedTensor: + """Packed NVFP4 tensor plus dequantization metadata. + + Attributes + ---------- + data : torch.Tensor + Packed E2M1 FP4 data from Transformer Engine. The last dimension is + half of the original logical last dimension because each byte stores + two FP4 values. + scale_128x4 : torch.Tensor + E4M3 block scales in cuBLAS/cuDNN 128x4 tiled rowwise storage. + global_scale : torch.Tensor + FP32 tensor/global dequant scale. + logical_scale_shape : tuple[int, int] + Logical 2D scale shape ``(rows, cols)`` before 128x4 swizzling. + original_shape : tuple[int, ...] + Original BF16/FP16 tensor shape before quantization. + """ + + data: torch.Tensor + scale_128x4: torch.Tensor + global_scale: torch.Tensor + logical_scale_shape: Tuple[int, int] + original_shape: Tuple[int, ...] + + +def _round_up(x: int, multiple: int) -> int: + return ((int(x) + multiple - 1) // multiple) * multiple + + +def nvfp4_scale_128x4_offset( + row: torch.Tensor, + col: torch.Tensor, + scale_cols: int, +) -> torch.Tensor: + """Return flat offsets for cuBLAS/cuDNN 128x4 rowwise scale storage. + + Parameters + ---------- + row : torch.Tensor + Logical row indices. + col : torch.Tensor + Logical scale-column indices. + scale_cols : int + Logical number of scale columns before padding to a multiple of 4. + + Returns + ------- + torch.Tensor + Flat offsets into the padded 128x4 tiled storage. + """ + + tiles_n = _round_up(scale_cols, 4) // 4 + tile_m = row // 128 + tile_n = col // 4 + outer = row % 128 + inner = col % 4 + return ( + (tile_m * tiles_n + tile_n) * 512 + + (outer % 32) * 16 + + (outer // 32) * 4 + + inner + ) + + +def swizzle_nvfp4_scale_to_128x4( + scale: torch.Tensor, + *, + rows: int, + cols: int, +) -> torch.Tensor: + """Convert TE logical rowwise scales to cuBLAS/cuDNN 128x4 tiled layout. + + Parameters + ---------- + scale : torch.Tensor + Logical rowwise scale tensor with at least shape ``[rows, cols]``. + rows : int + Number of logical rows to convert. + cols : int + Number of logical scale columns to convert. + + Returns + ------- + torch.Tensor + Scale tensor padded to ``round_up(rows, 128)`` by ``round_up(cols, 4)`` + and swizzled into 128x4 tiled storage. + """ + + if scale.ndim != 2: + raise ValueError(f"scale must be 2D, got shape {tuple(scale.shape)}") + + rows = int(rows) + cols = int(cols) + padded_rows = _round_up(rows, 128) + padded_cols = _round_up(cols, 4) + if scale.shape[0] < rows or scale.shape[1] < cols: + raise ValueError( + "scale is smaller than the requested logical shape: " + f"got {tuple(scale.shape)}, need at least {(rows, cols)}" + ) + + logical = scale[:rows, :cols].contiguous() + if logical.shape != (padded_rows, padded_cols): + logical = torch.nn.functional.pad( + logical.to(torch.float32), + (0, padded_cols - cols, 0, padded_rows - rows), + ).to(scale.dtype) + swizzled = torch.empty_like(logical) + + row = torch.arange(padded_rows, device=scale.device, dtype=torch.int64)[:, None] + col = torch.arange(padded_cols, device=scale.device, dtype=torch.int64)[None, :] + offset = nvfp4_scale_128x4_offset(row, col, padded_cols).reshape(-1) + swizzled.reshape(-1)[offset] = logical.reshape(-1) + return swizzled + + +def nvfp4_global_scale_from_amax(amax: torch.Tensor) -> torch.Tensor: + """Compute TE NVFP4 tensor/global dequant scale from rowwise amax. + + Parameters + ---------- + amax : torch.Tensor + Rowwise absolute maxima returned by Transformer Engine. + + Returns + ------- + torch.Tensor + FP32 global scale equal to ``amax / (448 * 6)``. + """ + + return amax.to(torch.float32) / (NVFP4_FP8_E4M3_MAX * NVFP4_FP4_MAX) + + +def _import_te_nvfp4_quantizer(): + try: + from transformer_engine.pytorch.tensor import NVFP4Quantizer + except Exception as exc: # pragma: no cover - environment dependent + raise RuntimeError( + "Transformer Engine NVFP4 quantization is unavailable. Install a " + "Transformer Engine build with its PyTorch dependencies, including " + "FlashAttention v3 when required by that TE build." + ) from exc + return NVFP4Quantizer + + +def quantize_bf16_to_nvfp4_128x4(x: torch.Tensor) -> Nvfp4QuantizedTensor: + """Quantize a BF16/FP16 tensor to NVFP4 using Transformer Engine. + + TE returns rowwise scales in logical padded layout. This helper returns + the scales in physical 128x4 tiled storage, so the attention kernel can + load them with ``nvfp4_scale_128x4_offset``. + + Parameters + ---------- + x : torch.Tensor + CUDA BF16 or FP16 tensor. The last dimension must be divisible by 16, + and the flattened row dimension ``prod(x.shape[:-1])`` must also be + divisible by 16. + + Returns + ------- + Nvfp4QuantizedTensor + Packed FP4 data, 128x4-swizzled block scales, global scale, and shape + metadata needed by the KVFP4 attention kernel or by reference + dequantization. + """ + + if not x.is_cuda: + raise ValueError("NVFP4 quantization requires a CUDA tensor") + if x.dtype not in (torch.bfloat16, torch.float16): + raise TypeError(f"x must be bf16 or fp16, got {x.dtype}") + if x.ndim < 2: + raise ValueError(f"x must have at least 2 dimensions, got {x.ndim}") + if x.shape[-1] % NVFP4_BLOCK_SIZE != 0: + raise ValueError( + f"last dimension must be divisible by {NVFP4_BLOCK_SIZE}, got {x.shape[-1]}" + ) + + rows = 1 + for dim in x.shape[:-1]: + rows *= int(dim) + if rows % NVFP4_BLOCK_SIZE != 0: + raise ValueError( + "flattened row dimension must be divisible by " + f"{NVFP4_BLOCK_SIZE}, got {rows}" + ) + + NVFP4Quantizer = _import_te_nvfp4_quantizer() + quantizer = NVFP4Quantizer(rowwise=True, columnwise=False) + qx = quantizer.quantize(x.contiguous()) + meta = qx.get_metadata() + + data = meta["rowwise_data"] + if data.dtype is not torch.uint8: + data = data.view(torch.uint8) + logical_scale = meta["rowwise_scale_inv"] + amax = meta["amax_rowwise"] + scale_cols = int(x.shape[-1]) // NVFP4_BLOCK_SIZE + scale_128x4 = swizzle_nvfp4_scale_to_128x4( + logical_scale, + rows=rows, + cols=scale_cols, + ) + global_scale = nvfp4_global_scale_from_amax(amax).contiguous() + + return Nvfp4QuantizedTensor( + data=data, + scale_128x4=scale_128x4, + global_scale=global_scale, + logical_scale_shape=(rows, scale_cols), + original_shape=tuple(int(v) for v in x.shape), + ) + + +def quantize_kv_bf16_to_nvfp4_128x4( + k: torch.Tensor, + v: torch.Tensor, +) -> tuple[Nvfp4QuantizedTensor, Nvfp4QuantizedTensor]: + """Quantize BF16/FP16 K and V tensors independently for KVFP4 attention. + + Parameters + ---------- + k : torch.Tensor + CUDA BF16 or FP16 K tensor. + v : torch.Tensor + CUDA BF16 or FP16 V tensor. + + Returns + ------- + tuple[Nvfp4QuantizedTensor, Nvfp4QuantizedTensor] + Quantized K and V tensors with independent scales. + """ + + return quantize_bf16_to_nvfp4_128x4(k), quantize_bf16_to_nvfp4_128x4(v) + + +def dequantize_nvfp4_128x4_to_bf16( + qx: Nvfp4QuantizedTensor, + *, + include_global_scale: bool = True, +) -> torch.Tensor: + """Reference dequantization for validation. + + This mirrors the kernel contract: + x = e2m1 * E4M3_block_scale_1x16 * FP32_global_scale + + Parameters + ---------- + qx : Nvfp4QuantizedTensor + Quantized tensor returned by ``quantize_bf16_to_nvfp4_128x4``. + include_global_scale : bool, optional + If True, multiply by ``qx.global_scale`` after applying per-block + scales. + + Returns + ------- + torch.Tensor + BF16 tensor with shape ``qx.original_shape``. + """ + + data = qx.data if qx.data.dtype is torch.uint8 else qx.data.view(torch.uint8) + if data.shape[-1] * 2 != qx.original_shape[-1]: + raise ValueError( + "packed data last dimension does not match original shape: " + f"{data.shape[-1]} packed vs {qx.original_shape[-1]} logical" + ) + + rows, scale_cols = qx.logical_scale_shape + logical_dim = int(qx.original_shape[-1]) + if scale_cols * NVFP4_BLOCK_SIZE != logical_dim: + raise ValueError( + "logical scale columns do not match original last dimension: " + f"{scale_cols} scale cols vs dim {logical_dim}" + ) + + fp4_lut = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float32, + device=data.device, + ) + packed = data.reshape(rows, logical_dim // 2) + lo = packed & 0x0F + hi = packed >> 4 + values = torch.empty((rows, logical_dim), dtype=torch.float32, device=data.device) + values[:, 0::2] = fp4_lut[lo.long()] + values[:, 1::2] = fp4_lut[hi.long()] + + row = torch.arange(rows, device=data.device, dtype=torch.int64)[:, None] + col = torch.arange(scale_cols, device=data.device, dtype=torch.int64)[None, :] + offset = nvfp4_scale_128x4_offset(row, col, scale_cols) + scale_u8 = qx.scale_128x4.reshape(-1)[offset.reshape(-1)].reshape(rows, scale_cols) + scale = scale_u8.view(torch.float8_e4m3fn).to(torch.float32) + scale = scale.repeat_interleave(NVFP4_BLOCK_SIZE, dim=1) + out = values * scale + if include_global_scale: + global_scale = qx.global_scale.reshape(-1)[0].to(torch.float32) + out = out * global_scale + return out.reshape(qx.original_shape).to(torch.bfloat16) + + +def _example() -> None: + device = torch.device("cuda") + k = torch.randn(128, 2, 128, device=device, dtype=torch.bfloat16) + v = torch.randn_like(k) + k_q, v_q = quantize_kv_bf16_to_nvfp4_128x4(k, v) + print("K FP4 data:", tuple(k_q.data.shape), k_q.data.dtype) + print("K scale 128x4:", tuple(k_q.scale_128x4.shape), k_q.scale_128x4.dtype) + print("K global scale:", tuple(k_q.global_scale.shape), k_q.global_scale.dtype) + print("V FP4 data:", tuple(v_q.data.shape), v_q.data.dtype) + print("V scale 128x4:", tuple(v_q.scale_128x4.shape), v_q.scale_128x4.dtype) + print("V global scale:", tuple(v_q.global_scale.shape), v_q.global_scale.dtype) + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + raise RuntimeError("quantize.py requires CUDA") + _example() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/sparse_index_utils.py b/build/torch212-cxx11-cu132-x86_64-linux/sparse_index_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a54c982c9230b189051e3a0bdf76d22b397dd62a --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/sparse_index_utils.py @@ -0,0 +1,411 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Host-side q2k <-> k2q index conversion for sparse attention. + +These utilities prepare sparse metadata on the Python side for tests, +benchmarks, and other offline preprocessing flows. They are not kernel +runtime helpers, so they intentionally live outside `src/common`. + +Sparse attention pattern: + - Each Q token independently selects up to topK KV blocks (blk_kv tokens each). + - Under GQA, all Q heads in one group share the same sparsity pattern, + so indices are defined at the head_kv level. + +Shapes: + q2k_indices: [batch, head_kv, Sq, topK] int32, valid values in [0, num_kv_blocks), + trailing unused slots padded with -1 + k2q_indices: [batch, head_kv, Nkv, Sq] int32, padded with -1 + k2q_counts: [batch, head_kv, Nkv] int32 + +CSR reverse-index format: + q2k_indices: [head_kv, total_q, topK] int32, values are batch-local kv_block indices + k2q_row_ptr: [head_kv, total_rows + 1] int32 + k2q_q_indices: [head_kv, total_q * topK] int32, values are batch-local q_idx +""" + +from typing import Optional, Tuple + +import torch + +from .src.sm100.prepare_k2q_csr import SparseK2qCsrBuilderSm100 + + +def q2k_to_k2q( + q2k_indices: torch.Tensor, + num_kv_blocks: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert q2k sparse indices to k2q representation. + + For each KV block, find which Q tokens attend to it. + + Args: + q2k_indices: [batch, head_kv, Sq, topK] int32. + For each Q token, the KV blocks it attends to. Unused slots must + be padded with -1. + num_kv_blocks: Total number of KV blocks (= Skv / blk_kv). + + Returns: + k2q_indices: [batch, head_kv, num_kv_blocks, Sq] int32. + For each KV block, the Q token indices that attend to it, + left-packed and padded with -1. Last dim fixed to Sq (upper bound). + k2q_counts: [batch, head_kv, num_kv_blocks] int32. + Actual number of Q tokens per KV block. + """ + B, H, Sq, topK = q2k_indices.shape + device = q2k_indices.device + N = Sq * topK + + kv_flat = q2k_indices.reshape(B, H, N).long() + valid_flat = kv_flat >= 0 + q_flat = ( + torch.arange(Sq, device=device) + .unsqueeze(-1) + .expand(Sq, topK) + .reshape(N) + .unsqueeze(0) + .unsqueeze(0) + .expand(B, H, N) + ) + + k2q_counts = torch.zeros(B, H, num_kv_blocks, dtype=torch.int32, device=device) + safe_kv_flat = torch.where(valid_flat, kv_flat, torch.zeros_like(kv_flat)) + k2q_counts.scatter_add_( + 2, + safe_kv_flat, + valid_flat.to(torch.int32), + ) + + sort_keys = torch.where( + valid_flat, + kv_flat, + torch.full_like(kv_flat, num_kv_blocks), + ) + sorted_kv, sort_idx = sort_keys.sort(dim=-1, stable=True) + sorted_q = q_flat.gather(-1, sort_idx) + sorted_valid = valid_flat.gather(-1, sort_idx) + + offsets = torch.zeros(B, H, num_kv_blocks, dtype=torch.int64, device=device) + offsets[:, :, 1:] = k2q_counts[:, :, :-1].cumsum(dim=-1).long() + + global_pos = torch.arange(N, device=device).unsqueeze(0).unsqueeze(0).expand(B, H, N) + group_offset = offsets.gather(2, sorted_kv.clamp(max=num_kv_blocks - 1)) + pos_in_group = global_pos - group_offset + + k2q_indices = torch.full( + (B, H, num_kv_blocks, Sq), -1, dtype=torch.int32, device=device + ) + flat_k2q = k2q_indices.reshape(B, H, -1) + flat_idx = sorted_kv.clamp(max=num_kv_blocks - 1) * Sq + pos_in_group + for b in range(B): + for h in range(H): + valid = sorted_valid[b, h] + flat_k2q[b, h, flat_idx[b, h, valid]] = sorted_q[b, h, valid].int() + + return k2q_indices, k2q_counts + + +def k2q_to_q2k( + k2q_indices: torch.Tensor, + k2q_counts: torch.Tensor, + Sq: int, + topK: int, +) -> torch.Tensor: + """Convert dense k2q indices back to q2k representation. + + Parameters + ---------- + k2q_indices : torch.Tensor + Shape ``[batch, head_kv, num_kv_blocks, Sq]`` and dtype int32. Values + are Q token indices padded with ``-1``. + k2q_counts : torch.Tensor + Shape ``[batch, head_kv, num_kv_blocks]`` and dtype int32. Number of + valid Q indices per KV block. + Sq : int + Q sequence length per batch item in this dense reference format. + topK : int + Maximum number of KV blocks selected per Q token. + + Returns + ------- + torch.Tensor + Shape ``[batch, head_kv, Sq, topK]``, dtype int32. Entries are sorted + by KV block index with ``-1`` padding at the tail. + """ + B, H, Nkv, _ = k2q_indices.shape + device = k2q_indices.device + + q2k = torch.full((B, H, Sq, topK), -1, dtype=torch.int32, device=device) + counters = torch.zeros(B, H, Sq, dtype=torch.int64, device=device) + + for b in range(B): + for h in range(H): + for kv_blk in range(Nkv): + count = k2q_counts[b, h, kv_blk].item() + for j in range(count): + qt = k2q_indices[b, h, kv_blk, j].item() + if qt < 0: + continue + p = counters[b, h, qt].item() + if p < topK: + q2k[b, h, qt, p] = kv_blk + counters[b, h, qt] += 1 + + q2k_sort_key = torch.where(q2k < 0, torch.full_like(q2k, Nkv), q2k) + _, sort_idx = q2k_sort_key.sort(dim=-1) + q2k = q2k.gather(-1, sort_idx) + return q2k + + +def _validate_cu_seqlens(cu_seqlens: torch.Tensor, *, name: str) -> None: + if cu_seqlens.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32, got {cu_seqlens.dtype}") + if cu_seqlens.ndim != 1: + raise ValueError(f"{name} must be rank-1, got shape {tuple(cu_seqlens.shape)}") + if cu_seqlens.numel() < 1: + raise ValueError(f"{name} must have at least one element") + if not cu_seqlens.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _rows_per_batch(cu_seqlens_k: torch.Tensor, kv_block_size: int) -> torch.Tensor: + seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1] + return (seqlens_k + kv_block_size - 1) // kv_block_size + + +def _build_packed_row_map(rows_per_batch: torch.Tensor) -> tuple[torch.Tensor, int]: + rows_per_batch_cpu = rows_per_batch.to("cpu", non_blocking=False).tolist() + batch = len(rows_per_batch_cpu) + max_rows = max(rows_per_batch_cpu, default=0) + row_dtype = ( + torch.int32 + if sum(rows_per_batch_cpu) < torch.iinfo(torch.int32).max + else torch.int64 + ) + row_map_cpu = torch.full((batch, max_rows), -1, dtype=row_dtype) + row_linear = 0 + for kv_block_idx in range(max_rows): + for batch_idx, row_count in enumerate(rows_per_batch_cpu): + if kv_block_idx < row_count: + row_map_cpu[batch_idx, kv_block_idx] = row_linear + row_linear += 1 + return row_map_cpu.to(rows_per_batch.device), row_linear + + +def build_k2q_csr_torch_reference( + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + kv_block_size: int, +) -> tuple: + """Torch reference for q2k -> k2q CSR conversion. + + Parameters + ---------- + q2k_indices : torch.Tensor + Shape ``[head_kv, total_q, topK]``, dtype int32. Values are + batch-local KV block indices padded with ``-1``. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + kv_block_size : int + Number of KV tokens per sparse block. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(k2q_row_ptr, k2q_q_indices)`` where ``k2q_row_ptr`` has shape + ``[head_kv, total_rows + 1]`` and ``k2q_q_indices`` has shape + ``[head_kv, total_q * topK]``. + """ + if kv_block_size <= 0: + raise ValueError(f"kv_block_size must be > 0, got {kv_block_size}") + if q2k_indices.dtype != torch.int32: + raise TypeError(f"q2k_indices must be torch.int32, got {q2k_indices.dtype}") + if q2k_indices.ndim != 3: + raise ValueError( + "q2k_indices must have shape [head_kv, total_q, topK], " + f"got {tuple(q2k_indices.shape)}" + ) + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q") + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError("cu_seqlens_q and cu_seqlens_k must have the same shape [B + 1]") + if q2k_indices.device != cu_seqlens_q.device or q2k_indices.device != cu_seqlens_k.device: + raise ValueError("q2k_indices, cu_seqlens_q, and cu_seqlens_k must be on the same device") + + head_kv, total_q, topk = q2k_indices.shape + if total_q != int(cu_seqlens_q[-1].item()): + raise ValueError( + f"q2k_indices.shape[1] ({total_q}) must equal cu_seqlens_q[-1] " + f"({int(cu_seqlens_q[-1].item())})" + ) + + rows_per_batch = _rows_per_batch(cu_seqlens_k, kv_block_size) + row_map, total_rows = _build_packed_row_map(rows_per_batch) + nnz_upper_bound = total_q * topk + + k2q_row_ptr = torch.zeros((head_kv, total_rows + 1), dtype=torch.int32, device=q2k_indices.device) + k2q_q_indices = torch.full( + (head_kv, nnz_upper_bound), -1, dtype=torch.int32, device=q2k_indices.device + ) + if total_rows == 0 or total_q == 0 or topk == 0: + return k2q_row_ptr, k2q_q_indices + + counts = torch.zeros((head_kv, total_rows), dtype=torch.int32, device=q2k_indices.device) + total_entries = total_q * topk + row_dtype = torch.int32 if total_rows < torch.iinfo(torch.int32).max else torch.int64 + row_all = torch.empty((head_kv, total_entries), dtype=row_dtype, device=q2k_indices.device) + q_all = torch.empty((head_kv, total_entries), dtype=torch.int32, device=q2k_indices.device) + valid_all = torch.empty((head_kv, total_entries), dtype=torch.bool, device=q2k_indices.device) + rows_per_batch_cpu = rows_per_batch.to("cpu", non_blocking=False).tolist() + q_cu_cpu = cu_seqlens_q.to("cpu", non_blocking=False).tolist() + entry_cursor = 0 + + for batch_idx, kv_rows in enumerate(rows_per_batch_cpu): + q_start = q_cu_cpu[batch_idx] + q_end = q_cu_cpu[batch_idx + 1] + q_len = q_end - q_start + if q_len == 0: + continue + num_entries = q_len * topk + q2k_batch = q2k_indices[:, q_start:q_end, :] + valid_batch = q2k_batch >= 0 + if valid_batch.any(): + max_valid_kv = int(q2k_batch[valid_batch].max().item()) + if max_valid_kv >= kv_rows: + raise ValueError( + f"q2k_indices references kv_block {max_valid_kv} for batch {batch_idx}, " + f"but that batch only has {kv_rows} logical kv blocks" + ) + kv_flat = q2k_batch.reshape(head_kv, num_entries).long() + valid_flat = valid_batch.reshape(head_kv, num_entries) + safe_kv_flat = torch.where(valid_flat, kv_flat, torch.zeros_like(kv_flat)) + row_map_batch = row_map[batch_idx] + row_flat = row_map_batch[safe_kv_flat] + q_flat = ( + torch.arange(q_len, device=q2k_indices.device, dtype=torch.int32) + .view(1, q_len, 1) + .expand(head_kv, q_len, topk) + .reshape(head_kv, num_entries) + ) + row_all[:, entry_cursor : entry_cursor + num_entries] = row_flat + q_all[:, entry_cursor : entry_cursor + num_entries] = q_flat + valid_all[:, entry_cursor : entry_cursor + num_entries] = valid_flat + counts.scatter_add_(1, row_flat.to(torch.int64), valid_flat.to(torch.int32)) + entry_cursor += num_entries + + k2q_row_ptr[:, 1:] = counts.cumsum(dim=1, dtype=torch.int32) + + sort_stride = max(total_q, 1) + invalid_key = total_rows * sort_stride + max_sort_key = invalid_key + max(total_q - 1, 0) + if max_sort_key < torch.iinfo(torch.int32).max: + sort_keys = torch.full_like(row_all, invalid_key, dtype=torch.int32) + sort_keys[valid_all] = row_all[valid_all] * sort_stride + q_all[valid_all] + else: + sort_keys = torch.full_like(row_all, invalid_key, dtype=torch.int64) + sort_keys[valid_all] = ( + row_all[valid_all].to(torch.int64) * sort_stride + + q_all[valid_all].to(torch.int64) + ) + _, sort_idx = sort_keys.sort(dim=1, stable=True) + sorted_q = q_all.gather(1, sort_idx) + + valid_counts = valid_all.sum(dim=1) + write_mask = ( + torch.arange(total_entries, device=q2k_indices.device) + .unsqueeze(0) + .expand(head_kv, -1) + < valid_counts.unsqueeze(1) + ) + k2q_q_indices[write_mask] = sorted_q[write_mask] + + return k2q_row_ptr, k2q_q_indices + + +_K2Q_CSR_BUILDER = SparseK2qCsrBuilderSm100() + + +def build_k2q_csr( + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + kv_block_size: int, + *, + total_k: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + max_seqlen_q: Optional[int] = None, + total_rows: Optional[int] = None, + qhead_per_kv: int = 1, + return_schedule: bool = False, +) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, object]: + """Build the public k2q CSR reverse index on GPU. + + Runtime construction does not read device-side ``cu_seqlens`` on the host, + so callers must provide size hints such as ``total_k`` from already-known + tensor shapes. + + Parameters + ---------- + q2k_indices : torch.Tensor + Shape ``[head_kv, total_q, topK]``, dtype int32, contiguous. Values are + batch-local KV block indices with trailing ``-1`` padding. + cu_seqlens_q : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths. + cu_seqlens_k : torch.Tensor + Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths. + kv_block_size : int + Number of KV tokens per sparse block. + total_k : int + Total KV token count. Required; normally ``k.shape[0]`` for dense KV + or ``sum(kv_segment_lens)`` for paged KV. + max_seqlen_k : int, optional + Maximum KV sequence length. Passing this avoids recomputing a bound. + max_seqlen_q : int, optional + Maximum Q sequence length. + total_rows : int, optional + Total number of packed KV-block rows across the batch. If omitted, + the builder derives it from ``cu_seqlens_k`` and ``kv_block_size``. + qhead_per_kv : int, optional + Number of Q heads per KV head under GQA. + return_schedule : bool, optional + If True, also return the sparse forward schedule object produced by the + SM100 builder. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] or tuple[torch.Tensor, torch.Tensor, object] + ``(k2q_row_ptr, k2q_q_indices)`` or + ``(k2q_row_ptr, k2q_q_indices, schedule)``. CSR tensors are int32 on + the same CUDA device as ``q2k_indices``. + """ + if total_k is None: + raise ValueError("build_k2q_csr requires total_k from k.shape[0]") + if kv_block_size <= 0: + raise ValueError(f"kv_block_size must be > 0, got {kv_block_size}") + if q2k_indices.dtype != torch.int32: + raise TypeError(f"q2k_indices must be torch.int32, got {q2k_indices.dtype}") + if q2k_indices.ndim != 3: + raise ValueError(f"q2k_indices must be rank-3, got shape {tuple(q2k_indices.shape)}") + if not q2k_indices.is_contiguous(): + raise ValueError("q2k_indices must be contiguous with layout [head_kv, total_q, topK]") + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q") + _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError("cu_seqlens_q and cu_seqlens_k must have the same shape [B + 1]") + if q2k_indices.device != cu_seqlens_q.device or q2k_indices.device != cu_seqlens_k.device: + raise ValueError("q2k_indices, cu_seqlens_q, and cu_seqlens_k must be on the same device") + return _K2Q_CSR_BUILDER( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + total_k=int(total_k), + blk_kv=int(kv_block_size), + max_seqlen_k=max_seqlen_k, + max_seqlen_q=max_seqlen_q, + total_rows=total_rows, + qhead_per_kv=qhead_per_kv, + return_schedule=return_schedule, + ) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15dfab1c4fc45dcec26ba7489b35624f5c46a698 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15dfab1c4fc45dcec26ba7489b35624f5c46a698 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/aot_cache.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/aot_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..99fd0b4da4ddb6fba21bcb18c924f5e9e8b583e6 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/aot_cache.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Persistent AOT cache for CuTe DSL compiled kernels. + +Saves compiled TVM FFI kernels as .o files on first compile, +loads them on subsequent runs to skip JIT compilation. + +Environment variables: + MM_SPARSE_ATTN_AOT_CACHE: Override cache directory + (default: ~/.cache/minfer/mm_sparse_attn) + MM_SPARSE_ATTN_AOT_DISABLE=1: Disable AOT cache entirely +""" + +import hashlib +import os +import time + +import cutlass.cute as cute + +_AOT_CACHE_DIR = os.environ.get( + "MM_SPARSE_ATTN_AOT_CACHE", + os.path.expanduser("~/.cache/minfer/mm_sparse_attn"), +) +_AOT_DISABLE = os.environ.get("MM_SPARSE_ATTN_AOT_DISABLE", "0") == "1" + +_loaded_modules: dict[str, object] = {} + + +def _key_to_path(key: tuple) -> str: + h = hashlib.sha256(repr(key).encode()).hexdigest()[:16] + name = str(key[0]).replace("/", "_") + return os.path.join(_AOT_CACHE_DIR, f"{name}_{h}") + + +def try_load_aot(key: tuple): + if _AOT_DISABLE: + return None + obj_path = _key_to_path(key) + ".o" + if not os.path.isfile(obj_path): + return None + func_name = str(key[0]) + try: + if obj_path not in _loaded_modules: + _loaded_modules[obj_path] = cute.runtime.load_module( + obj_path, enable_tvm_ffi=True + ) + return getattr(_loaded_modules[obj_path], func_name) + except Exception as e: + print(f"[aot_cache] Failed to load {obj_path}: {e}") + return None + + +def save_aot(key: tuple, compiled) -> None: + if _AOT_DISABLE: + return + if not hasattr(compiled, "export_to_c"): + return + obj_path = _key_to_path(key) + ".o" + os.makedirs(_AOT_CACHE_DIR, exist_ok=True) + tmp_path = obj_path + f".tmp.{os.getpid()}" + func_name = str(key[0]) + try: + t0 = time.time() + compiled.export_to_c(tmp_path, function_name=func_name) + os.replace(tmp_path, obj_path) + dt = time.time() - t0 + print(f"[aot_cache] Saved {func_name} -> {obj_path} ({dt:.1f}s)") + except Exception as e: + print(f"[aot_cache] Failed to save {func_name}: {e}") + if os.path.exists(tmp_path): + os.remove(tmp_path) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/barrier.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/barrier.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5753a8a175b529567e0be238f47fd4cc8401bf --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/barrier.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm + + +@dsl_user_op +def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + state = llvm.inline_asm( + T.i32(), + [lock_ptr_i64], + "ld.global.acquire.gpu.b32 $0, [$1];", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + return cutlass.Int32(state) + + +@dsl_user_op +def red_relaxed( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.relaxed.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def red_release( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.release.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + read_val = Int32(0) + while read_val != val: + read_val = ld_acquire(flag_ptr) + + +@cute.jit +def arrive_inc( + lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32] +) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + red_release(flag_ptr, val) + # red_relaxed(flag_ptr, val) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/blackwell_helpers.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/blackwell_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..fd22f7efa3cef9988b4036c2d00fc1d3b9c816e8 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/blackwell_helpers.py @@ -0,0 +1,1093 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Optional, Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Boolean, const_expr +from cutlass.cute.nvgpu import tcgen05 +from cutlass._mlir.dialects import llvm + +from . import mma_sm100_desc as sm100_desc + + +@cute.jit +def gemm_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, + swap_AB: bool = False, + num_unroll_groups: int = 1, +) -> None: + if const_expr(swap_AB): + return gemm_w_idx( + tiled_mma, acc, tCrB, tCrA, B_idx, A_idx, zero_init=zero_init, swap_AB=False + ) + else: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + + mma_atom = cute.make_mma_atom(tiled_mma.op) + for k in cutlass.range( + cute.size(tCrA.shape[2]), unroll=cute.size(tCrA.shape[2]) // num_unroll_groups + ): + mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc) + + +@cute.jit +def gemm_ptx_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, + cta_group: int = 1, + **kwargs, +) -> None: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + sA_cur = None + if const_expr(sA is not None): + sA_cur = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx] + sB_cur = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx] + mma_atom = cute.make_mma_atom(tiled_mma.op) + acc_tmem_addr = acc.iterator.toint() + gemm_ptx_partial( + mma_atom.op, + acc_tmem_addr, + rA, + rB, + sA_cur, + sB_cur, + zero_init=zero_init, + cta_group=cta_group, + **kwargs, + ) + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + mma_atom = cute.make_mma_atom(tiled_mma.op) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + + +def i64_to_i32x2(i: int) -> Tuple[int, int]: + """Convert a 64-bit integer to a tuple of two 32-bit integers.""" + return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF + + +@cute.jit +def gemm_ptx( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else None + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = sA.iterator.type.swizzle_type + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = sB.iterator.type.swizzle_type + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr( + sA[None, None, 0].iterator + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr( + sB[None, None, 0].iterator + ) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + if const_expr(not is_ts): + smem_desc_a_lo = smem_desc_start_a_lo + ( + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + ) + smem_desc_b_lo = smem_desc_start_b_lo + ( + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + ) + # with cute.arch.elect_one(): + # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo) + # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct) + with cute.arch.elect_one(): + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + smem_desc_a_lo.ir_value(), + smem_desc_b_lo.ir_value(), + Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + ".reg .b32 idesc;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + tCrA[None, None, k].iterator.toint().ir_value(), + smem_desc_b_lo.ir_value(), + Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_b;\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_loop( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = sA.iterator.type.swizzle_type + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = sB.iterator.type.swizzle_type + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + if const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [ + offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [ + offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2])) + ] + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 smem_desc_a_lo, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 tmem_a, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: Int32, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[Int32] = None, + split_arrive: Optional[int] = None, + zero_init: bool | Boolean = False, + # sA_offset: Int32 = 0, + # acc_offset: Int32 = 0, + tA_addr: Optional[Int32] = None, + cta_group: int = 1, + mma_kind: str = "f16", +) -> None: + # acc_tmem_addr += acc_offset + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = sA.iterator.type.swizzle_type + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = sB.iterator.type.swizzle_type + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + tCrA_layout = ( + tCrA.layout + if const_expr(not is_ts) + else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + # ) + sA_offset + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + "mov.b32 smem_desc_a_lo_start, $0;\n\t" + "mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + # "r,r,r", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + # For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to + # explicitly pass in the tA_addr for correctness. + tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr + input_args = [ + # Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(), + Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ] + if const_expr(mbar_ptr is not None): + assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + assert split_arrive is not None, ( + "split_arrive must be provided when mbar_ptr is not None" + ) + split_arrive_idx = split_arrive // op.shape_mnk[2] + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" + llvm.inline_asm( + None, + # [ + # # acc.iterator.toint().ir_value(), + # Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + # Int32(smem_desc_start_b_lo).ir_value(), + # Int32(not zero_init).ir_value(), + # ], + input_args, + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range( + 1, + cute.size(tCrA.shape[2]) if const_expr(mbar_ptr is None) else split_arrive_idx, + ) + ) + + mbar_wait_str + + ( + "".join( + ( + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(split_arrive_idx, cute.size(tCrA.shape[2])) + ) + if const_expr(mbar_ptr is not None) + else "" + ) + + "}\n", + "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial1( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: cutlass.Constexpr[int], + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA_base_addr_for_desc: Int32, + sA_addr_offset_for_desc: cutlass.Constexpr[int], + sA_stage: Int32, + sB_base_addr_for_desc: Int32, + sB_addr_offset_for_desc: cutlass.Constexpr[int], + sB_stage: Int32, + sA_layout: Optional[cute.Layout], + sB_layout: Optional[cute.Layout], + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + mask = [Int32(0)] * 4 + + if const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 + for k in range(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in range(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4 + for k in range(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if const_expr(not is_ts): + # smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + smem_desc_start_a_lo = const_expr(smem_desc_base_a_lo) + else: + smem_desc_start_a_lo = None + # smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + smem_desc_start_b_lo = const_expr(smem_desc_base_b_lo) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(sA_base_addr_for_desc).ir_value(), + Int32(sA_stage).ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(sB_base_addr_for_desc).ir_value(), + Int32(sB_stage).ir_value(), + Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + # "mov.b32 smem_desc_a_lo, $0;\n\t" + # f"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\n\t" + f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t" + # "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $4, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_a, $1;\n\t" + f"mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_precomputed( + acc_tmem_addr: Int32, + smem_desc_start_a: Int32, # If TS, then this is the tmem start address for A + smem_desc_start_b: Int32, + idesc: int, + smem_desc_base_a: Optional[int], + smem_desc_base_b: int, + tCrA_layout: cute.Layout, + tCrB_layout: cute.Layout, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[Int32] = None, + zero_init: bool | Boolean = False, + cta_group: int = 1, +) -> None: + # acc_tmem_addr += acc_offset + is_ts = const_expr(smem_desc_base_a is None) + num_k_tile = cute.size(tCrA_layout.shape[2]) + if const_expr(not is_ts): + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + else: + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + + tCrA_layout = ( + tCrA_layout + if const_expr(not is_ts) + # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout) + # currently hard-coding the width to 16 + else cute.recast_layout(32, 16, tCrA_layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, num_k_tile)] + offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, num_k_tile)] + + smem_desc_start_a_lo = None + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a) + # smem_desc_start_a_lo = smem_desc_start_a + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + "mov.b32 smem_desc_a_lo_start, $0;\n\t" + "mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t" + f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, num_k_tile) + ) + + "}\n", + # "r,r,r", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + input_args = [ + Int32(cute.arch.make_warp_uniform(smem_desc_start_a)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ] + if const_expr(mbar_ptr is not None): + assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" + llvm.inline_asm( + None, + # [ + # # acc.iterator.toint().ir_value(), + # Int32(tCrA_layout[None, None, 0].iterator.toint()).ir_value(), + # Int32(smem_desc_start_b_lo).ir_value(), + # Int32(not zero_init).ir_value(), + # ], + input_args, + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range( + 1, + num_k_tile if const_expr(mbar_ptr is None) else num_k_tile // 4 * 3, + ) + ) + + mbar_wait_str + + ( + "".join( + ( + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(num_k_tile // 4 * 3, num_k_tile) + ) + if const_expr(mbar_ptr is not None) + else "" + ) + + "}\n", + "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def declare_ptx_smem_desc( + smem_desc_start_a: Int32, # If TS, then this is the tmem start address for A + smem_desc_base_a: Optional[int], + tCrA_layout: cute.Layout, + var_name_prefix: str = "smem_desc", +) -> None: + is_ts = const_expr(smem_desc_base_a is None) + num_k_tile = cute.size(tCrA_layout.shape[2]) + smem_desc_base_a_lo, smem_desc_a_hi = None, None + if const_expr(not is_ts): + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + tCrA_layout = ( + tCrA_layout + if const_expr(not is_ts) + # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout) + # currently hard-coding the width to 16 + else cute.recast_layout(32, 16, tCrA_layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)] + smem_desc_start_a_lo = None + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a) + if const_expr(not is_ts): + llvm.inline_asm( + None, + [Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value()], + f".reg .b32 {var_name_prefix}_lo;\n\t" + f".reg .b64 {var_name_prefix}_<{num_k_tile}>;\n\t" + f"mov.b64 {var_name_prefix}_0, {{$0, {hex(smem_desc_a_hi)}}};\n\t" + + "".join( + ( + f"add.s32 {var_name_prefix}_lo, $0, {hex(offset_a[k])};\n\t" + f"mov.b64 {var_name_prefix}_{k}, {{{var_name_prefix}_lo, {hex(smem_desc_a_hi)}}};\n\t" + ) + for k in range(1, num_k_tile) + ), + "r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def declare_ptx_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp, var_name: str = "idesc") -> None: + idesc = const_expr(sm100_desc.mma_op_to_idesc(op)) + llvm.inline_asm( + None, + [], + f".reg .b32 {var_name};\n\t" # noqa + f"mov.b32 {var_name}, {hex(idesc)};\n\t", + constraints="", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_precomputed_varname( + acc_tmem_addr: Int32, + smem_desc_start_b: Int32, + # idesc: int, + smem_desc_base_b: int, + tCrB_layout: cute.Layout, + smem_var_name_prefix: str, + idesc_var_name: str, + smem_offset: int, + zero_init: bool | Boolean = False, + cta_group: int = 1, + mma_kind: str = "f16", +) -> None: + is_ts = False + num_k_tile = cute.size(tCrB_layout.shape[2]) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)] + + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + # ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + # ".reg .b64 smem_desc_b;\n\t" + f".reg .b64 smem_desc_b_<{num_k_tile}>;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + # f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $2;\n\t" + "mov.b32 smem_desc_b_lo_start, $0;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_0;\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + f"mov.b64 {smem_var_name_prefix}_0, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b_0, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + + "".join( + ( + f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b_{k}, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + ) + for k in range(1, num_k_tile) + ) + + "setp.ne.b32 p, $1, 0;\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t" + + "".join( + ( + # f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t" + # f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + # f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + # f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + # f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t" + ) + for k in range(1, num_k_tile) + ) + + "}\n", + "r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/block_info.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/block_info.py new file mode 100644 index 0000000000000000000000000000000000000000..463290ab3b022a8883e7d40b84ff1ab31827e5dc --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/block_info.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Tuple +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + +from ...src.common.seqlen_info import SeqlenInfoQK + + +@dataclass(frozen=True) +class BlockInfo: + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] + is_causal: cutlass.Constexpr[bool] + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + + @cute.jit + def get_n_block_min_max( + self, + seqlen_info: SeqlenInfoQK, + m_block: Int32, + split_idx: Int32 = 0, + num_splits: Int32 = 1, + ) -> Tuple[Int32, Int32]: + n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) + if const_expr(self.is_causal): + m_idx_max = (m_block + 1) * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_block_max = min(n_block_max, cute.ceil_div(n_idx, self.tile_n)) + n_block_min = 0 + if num_splits > 1: + num_n_blocks_per_split = ( + Int32(0) + if n_block_max <= n_block_min + else (n_block_max - n_block_min + num_splits - 1) // num_splits + ) + n_block_min = n_block_min + split_idx * num_n_blocks_per_split + n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max) + return n_block_min, n_block_max + + @cute.jit + def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]: + m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m) + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_block_max = cute.ceil_div( + seqlen_info.seqlen_q * self.qhead_per_kvhead_packgqa, self.tile_m + ) + m_block_min = 0 + if const_expr(self.is_causal): + n_idx_min = n_block * self.tile_n + m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx *= self.qhead_per_kvhead_packgqa + m_block_min = cutlass.max(m_block_min, m_idx // self.tile_m) + return m_block_min, m_block_max diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/copy_utils.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/copy_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..98ba5f40b7b9543744e663a96bcdf637c7e2a146 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/copy_utils.py @@ -0,0 +1,1179 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Copy, store, and layout execution helpers. + +`copy_utils.py` is the canonical owner for generic copy primitives, async +bulk copy orchestration, TMA copy adapters, and non-TMA store/layout helpers. +""" + +import math +from typing import Optional, Type, Callable + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm +import cutlass.pipeline + + +# Generic Copy Primitives + +@dsl_user_op +def cvt_copy( + atom: cute.CopyAtom, + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + if const_expr(src.element_type != dst.element_type): + src_cvt = cute.make_rmem_tensor_like(src, dst.element_type, loc=loc, ip=ip) + src_cvt.store(src.load().to(dst.element_type)) + src = src_cvt + cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +@dsl_user_op +def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip) + cute.autovec_copy(src, dst, loc=loc, ip=ip) + return dst + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None +) -> cute.CopyAtom: + num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@dsl_user_op +def make_tmem_copy( + tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None +) -> cute.CopyAtom: + num_dp, num_bits, num_rep, _ = sm100_utils.get_tmem_copy_properties(tmem_copy_atom) + assert num_dp == 32 + assert num_bits == 32 + tiler_mn = (cute.make_layout((128 * num_rep * num_wg // 32, 32), stride=(32, 1)),) + layout_tv = cute.make_layout( + ((32, 4, num_wg), (num_rep, 32)), stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg)) + ) + return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + num_copy_elems: int = 1, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +def tiled_copy_1d( + dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + thr_layout = cute.make_layout(num_threads) + val_layout = cute.make_layout(num_copy_elems) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def tiled_copy_2d( + dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width + copy_elems = num_copy_bits // dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + gmem_threads_per_row = major_mode_size // copy_elems + assert num_threads % gmem_threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +@dsl_user_op +def atomic_add_fp32x4( + a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None +) -> None: + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x12F0000000000000) + llvm.inline_asm( + None, + [ + gmem_ptr_i64, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], + "{\n\t" + # ".reg .b128 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + ".reg .v4 .f32 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + "mov.f32 abcd.x, $1;\n\t" + "mov.f32 abcd.y, $2;\n\t" + "mov.f32 abcd.z, $3;\n\t" + "mov.f32 abcd.w, $4;\n\t" + "red.global.add.v4.f32 [$0], abcd;\n\t" + # "red.global.add.L2::cache_hint.v4.f32 [$0], abcd, 0x14F0000000000000;\n\t" + "}\n", + # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;", + # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;", + "l,f,f,f,f", + # "l,f,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +# Store/Layout Helpers + +@dsl_user_op +def atomic_add_i32(gmem_ptr, *, loc=None, ip=None): + """Simple atomicAdd. Intended for use under a single-thread guard.""" + result = llvm.inline_asm( + T.i32(), + [gmem_ptr.toint().ir_value(loc=loc, ip=ip)], + "atom.global.add.u32 $0, [$1], 1;\n", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + return Int32(result) + + +@dsl_user_op +def atomic_add_broadcast_i32(gmem_ptr, *, loc=None, ip=None): + """Lane-0 atomicAdd broadcast to the whole warp via shfl.""" + result = llvm.inline_asm( + T.i32(), + [gmem_ptr.toint().ir_value(loc=loc, ip=ip)], + "{\n" + ".reg .pred p;\n" + ".reg .u32 lane, r;\n" + "mov.u32 lane, %laneid;\n" + "mov.u32 r, 0;\n" + "setp.eq.u32 p, lane, 0;\n" + "@p atom.global.add.u32 r, [$1], 1;\n" + "shfl.sync.idx.b32 r, r, 0, 31, 0xffffffff;\n" + "mov.u32 $0, r;\n" + "}\n", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + return Int32(result) + + +@dsl_user_op +def stg_128( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "st.global.v4.f32 [$4], {$5, $6, $7, $8}; " + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "st.global.cs.v4.f32 [$4], {$5, $6, $7, $8}; " + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_64_bf16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3;\n" + ".reg .b32 p0, p1;\n" + "cvt.rn.bf16.f32 h0, $5;\n" + "cvt.rn.bf16.f32 h1, $6;\n" + "cvt.rn.bf16.f32 h2, $7;\n" + "cvt.rn.bf16.f32 h3, $8;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "st.global.v2.b32 [$4], {p0, p1};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_64_f16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .f16 h0, h1, h2, h3;\n" + ".reg .b32 p0, p1;\n" + "cvt.rn.f16.f32 h0, $5;\n" + "cvt.rn.f16.f32 h1, $6;\n" + "cvt.rn.f16.f32 h2, $7;\n" + "cvt.rn.f16.f32 h3, $8;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "st.global.v2.b32 [$4], {p0, p1};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_32_fp8_e4m3( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1;\n" + ".reg .b32 p0;\n" + "cvt.rn.satfinite.e4m3x2.f32 h0, $6, $5;\n" + "cvt.rn.satfinite.e4m3x2.f32 h1, $8, $7;\n" + "mov.b32 p0, {h0, h1};\n" + "st.global.b32 [$4], p0;\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;", + "=f,=f,=f,=f,l,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def sts_32_bf16( + smem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + *, + loc=None, + ip=None, +): + """Store two bf16 values to shared memory as one 32-bit transaction.""" + llvm.inline_asm( + None, + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa;\n" + ".reg .b16 h0, h1;\n" + ".reg .b32 p0;\n" + "cvt.u32.u64 sa, $0;\n" + "cvt.rn.bf16.f32 h0, $1;\n" + "cvt.rn.bf16.f32 h1, $2;\n" + "mov.b32 p0, {h0, h1};\n" + "st.shared.b32 [sa], p0;\n" + "}\n", + "l,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def sts_32_f16( + smem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + *, + loc=None, + ip=None, +): + """Store two fp16 values to shared memory as one 32-bit transaction.""" + llvm.inline_asm( + None, + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa;\n" + ".reg .f16 h0, h1;\n" + ".reg .b32 p0;\n" + "cvt.u32.u64 sa, $0;\n" + "cvt.rn.f16.f32 h0, $1;\n" + "cvt.rn.f16.f32 h1, $2;\n" + "mov.b32 p0, {h0, h1};\n" + "st.shared.b32 [sa], p0;\n" + "}\n", + "l,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_bf16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.bf16.f32 h0, $9;\n" + "cvt.rn.bf16.f32 h1, $10;\n" + "cvt.rn.bf16.f32 h2, $11;\n" + "cvt.rn.bf16.f32 h3, $12;\n" + "cvt.rn.bf16.f32 h4, $13;\n" + "cvt.rn.bf16.f32 h5, $14;\n" + "cvt.rn.bf16.f32 h6, $15;\n" + "cvt.rn.bf16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_bf16_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.bf16.f32 h0, $9;\n" + "cvt.rn.bf16.f32 h1, $10;\n" + "cvt.rn.bf16.f32 h2, $11;\n" + "cvt.rn.bf16.f32 h3, $12;\n" + "cvt.rn.bf16.f32 h4, $13;\n" + "cvt.rn.bf16.f32 h5, $14;\n" + "cvt.rn.bf16.f32 h6, $15;\n" + "cvt.rn.bf16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.cs.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_f16( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .f16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.f16.f32 h0, $9;\n" + "cvt.rn.f16.f32 h1, $10;\n" + "cvt.rn.f16.f32 h2, $11;\n" + "cvt.rn.f16.f32 h3, $12;\n" + "cvt.rn.f16.f32 h4, $13;\n" + "cvt.rn.f16.f32 h5, $14;\n" + "cvt.rn.f16.f32 h6, $15;\n" + "cvt.rn.f16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_f16_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .f16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.f16.f32 h0, $9;\n" + "cvt.rn.f16.f32 h1, $10;\n" + "cvt.rn.f16.f32 h2, $11;\n" + "cvt.rn.f16.f32 h3, $12;\n" + "cvt.rn.f16.f32 h4, $13;\n" + "cvt.rn.f16.f32 h5, $14;\n" + "cvt.rn.f16.f32 h6, $15;\n" + "cvt.rn.f16.f32 h7, $16;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.cs.v4.b32 [$8], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;", + "=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stg_128_fp8_e4m3_cs( + gmem_ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + v8: Float32, + v9: Float32, + v10: Float32, + v11: Float32, + v12: Float32, + v13: Float32, + v14: Float32, + v15: Float32, + *, + loc=None, + ip=None, +): + llvm.inline_asm( + llvm.StructType.get_literal( + [ + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + T.f32(), + ] + ), + [ + gmem_ptr.toint().ir_value(loc=loc, ip=ip), + Float32(v0).ir_value(loc=loc, ip=ip), + Float32(v1).ir_value(loc=loc, ip=ip), + Float32(v2).ir_value(loc=loc, ip=ip), + Float32(v3).ir_value(loc=loc, ip=ip), + Float32(v4).ir_value(loc=loc, ip=ip), + Float32(v5).ir_value(loc=loc, ip=ip), + Float32(v6).ir_value(loc=loc, ip=ip), + Float32(v7).ir_value(loc=loc, ip=ip), + Float32(v8).ir_value(loc=loc, ip=ip), + Float32(v9).ir_value(loc=loc, ip=ip), + Float32(v10).ir_value(loc=loc, ip=ip), + Float32(v11).ir_value(loc=loc, ip=ip), + Float32(v12).ir_value(loc=loc, ip=ip), + Float32(v13).ir_value(loc=loc, ip=ip), + Float32(v14).ir_value(loc=loc, ip=ip), + Float32(v15).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n" + ".reg .b32 p0, p1, p2, p3;\n" + "cvt.rn.satfinite.e4m3x2.f32 h0, $18, $17;\n" + "cvt.rn.satfinite.e4m3x2.f32 h1, $20, $19;\n" + "cvt.rn.satfinite.e4m3x2.f32 h2, $22, $21;\n" + "cvt.rn.satfinite.e4m3x2.f32 h3, $24, $23;\n" + "cvt.rn.satfinite.e4m3x2.f32 h4, $26, $25;\n" + "cvt.rn.satfinite.e4m3x2.f32 h5, $28, $27;\n" + "cvt.rn.satfinite.e4m3x2.f32 h6, $30, $29;\n" + "cvt.rn.satfinite.e4m3x2.f32 h7, $32, $31;\n" + "mov.b32 p0, {h0, h1};\n" + "mov.b32 p1, {h2, h3};\n" + "mov.b32 p2, {h4, h5};\n" + "mov.b32 p3, {h6, h7};\n" + "st.global.cs.v4.b32 [$16], {p0, p1, p2, p3};\n" + "}\n" + "mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; " + "mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; " + "mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; " + "mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000; " + "mov.f32 $8, 0f00000000; mov.f32 $9, 0f00000000; " + "mov.f32 $10, 0f00000000; mov.f32 $11, 0f00000000; " + "mov.f32 $12, 0f00000000; mov.f32 $13, 0f00000000; " + "mov.f32 $14, 0f00000000; mov.f32 $15, 0f00000000;", + ( + "=f,=f,=f,=f,=f,=f,=f,=f," + "=f,=f,=f,=f,=f,=f,=f,=f," + "l,f,f,f,f,f,f,f,f,f,f,f,f,f,f,f,f" + ), + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +def convert_layout_from_tmem16x256b_to_acc_sm90(acc_layout: cute.Layout) -> cute.Layout: + acc_layout_col_major = cute.make_layout(acc_layout.shape) + acc_layout_mn = cute.make_layout( + ( + acc_layout_col_major.shape[0][0], + acc_layout_col_major.shape[0][1], + acc_layout_col_major.shape[1], + *acc_layout_col_major.shape[2:], + ), + stride=( + acc_layout_col_major.stride[0][0], + acc_layout_col_major.stride[0][1], + acc_layout_col_major.stride[1], + *acc_layout_col_major.stride[2:], + ), + ) + return cute.composition(acc_layout, acc_layout_mn) + + +def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout: + acc_layout_col_major = cute.make_layout(acc_layout.shape) + acc_layout_mn = cute.make_layout( + ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), + ( + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), + *acc_layout_col_major.shape[3:], + ), + stride=( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), + *acc_layout_col_major.stride[3:], + ), + ) + return cute.composition(acc_layout, acc_layout_mn) + + +def make_16x256b_tensor_mn_view(tensor: cute.Tensor) -> cute.Tensor: + layout = convert_layout_acc_mn( + convert_layout_from_tmem16x256b_to_acc_sm90(tensor.layout) + ) + return cute.make_tensor(tensor.iterator, layout) + + +def real_col_to_stg128_fake_col(col: Int32) -> Int32: + nt = col // Int32(16) + col16 = col - nt * Int32(16) + pair = col16 // Int32(2) + rank = pair % Int32(4) + kv = (pair // Int32(4)) * Int32(2) + (col16 % Int32(2)) + return nt * Int32(16) + rank * Int32(4) + kv + + +def stg128_fake_col_to_real_col(fake_col: Int32) -> Int32: + nt = fake_col // Int32(16) + fake16 = fake_col - nt * Int32(16) + rank = fake16 // Int32(4) + kv = fake16 % Int32(4) + return nt * Int32(16) + rank * Int32(2) + (kv // Int32(2)) * Int32(8) + (kv % Int32(2)) + + +def real_col_to_stg128_half_fake_col(col: Int32) -> Int32: + nt = col // Int32(32) + col32 = col - nt * Int32(32) + lane = (col32 % Int32(8)) // Int32(2) + group = col32 // Int32(8) + elem = col32 % Int32(2) + return nt * Int32(32) + lane * Int32(8) + group * Int32(2) + elem + + +def stg128_half_fake_col_to_real_col(fake_col: Int32) -> Int32: + nt = fake_col // Int32(32) + fake32 = fake_col - nt * Int32(32) + lane = fake32 // Int32(8) + lane_slot = fake32 - lane * Int32(8) + group = lane_slot // Int32(2) + elem = lane_slot - group * Int32(2) + return nt * Int32(32) + group * Int32(8) + lane * Int32(2) + elem + + +def real_col_to_stg128_fp8_fake_col(col: Int32) -> Int32: + nt = col // Int32(64) + col64 = col - nt * Int32(64) + lane = (col64 % Int32(8)) // Int32(2) + group = col64 // Int32(8) + elem = col64 % Int32(2) + return nt * Int32(64) + lane * Int32(16) + group * Int32(2) + elem + + +def stg128_fp8_fake_col_to_real_col(fake_col: Int32) -> Int32: + nt = fake_col // Int32(64) + fake64 = fake_col - nt * Int32(64) + lane = fake64 // Int32(16) + lane_slot = fake64 - lane * Int32(16) + group = lane_slot // Int32(2) + elem = lane_slot - group * Int32(2) + return nt * Int32(64) + group * Int32(8) + lane * Int32(2) + elem + + +# Cluster & Bulk Async Ops + +@dsl_user_op +def set_block_rank( + smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None +) -> Int32: + """Map the given smem pointer to the address at another CTA rank in the cluster.""" + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + return Int32( + llvm.inline_asm( + T.i32(), + [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def store_shared_remote_fp32x4( + a: Float32, + b: Float32, + c: Float32, + d: Float32, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: Int32, + *, + loc=None, + ip=None, +) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + llvm.inline_asm( + None, + [ + remote_smem_ptr_i32, + remote_mbar_ptr_i32, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .v4 .f32 abcd;\n\t" + "mov.f32 abcd.x, $2;\n\t" + "mov.f32 abcd.y, $3;\n\t" + "mov.f32 abcd.z, $4;\n\t" + "mov.f32 abcd.w, $5;\n\t" + "st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, [$1];\n\t" + "}\n", + "r,r,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_bulk_s2cluster( + smem_src_ptr: cute.Pointer, + smem_dst_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + size: int | Int32, + peer_cta_rank_in_cluster: Int32, + *, + loc=None, + ip=None, +): + smem_src_ptr_i32 = smem_src_ptr.toint(loc=loc, ip=ip).ir_value() + smem_dst_ptr_i32 = set_block_rank( + smem_dst_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + mbar_ptr_i32 = set_block_rank(mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [ + smem_dst_ptr_i32, + smem_src_ptr_i32, + mbar_ptr_i32, + Int32(size).ir_value(loc=loc, ip=ip), + ], + "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [$0], [$1], $3, [$2];", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_bulk_g2s( + gmem_ptr: cute.Pointer, + smem_ptr: cute.Pointer, + tma_bar_ptr: cute.Pointer, + size: int | Int32, + *, + loc=None, + ip=None, +): + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + mbar_ptr_i32 = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [gmem_ptr_i64, smem_ptr_i32, mbar_ptr_i32, Int32(size).ir_value()], + "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];", + "l,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_reduce_bulk_add_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: int | Int32, + *, + loc=None, + ip=None, +): + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()], + # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;", + # "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +def cpasync_bulk_get_copy_fn( + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + single_stage: bool = False, + **kwargs, +) -> Callable: + # src_is_smem = const_expr( + # isinstance(src_tensor.iterator, cute.Pointer) + # and src_tensor.memspace == cute.AddressSpace.smem + # ) + group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0)) + group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + src = cute.group_modes(src_tensor, 0, group_rank_src) + dst = cute.group_modes(dst_tensor, 0, group_rank_dst) + + def copy_bulk(src_idx, dst_idx, **new_kwargs): + size = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8) + cpasync_bulk_g2s( + src[None, src_idx].iterator, + dst[None, dst_idx].iterator, + size=size, + **new_kwargs, + **kwargs, + ) + + def copy_bulk_single_stage(**new_kwargs): + size = const_expr(cute.size(src.shape) * src.element_type.width // 8) + cpasync_bulk_g2s(src.iterator, dst.iterator, size=size, **new_kwargs, **kwargs) + + return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage + + +# TMA Copy Adapters + +def tma_get_copy_fn( + atom: cute.CopyAtom, + cta_coord: cute.Coord, + cta_layout: cute.Layout, + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + filter_zeros: bool = False, + single_stage: bool = False, + **kwargs, +) -> Callable: + src_is_smem = const_expr( + isinstance(src_tensor.iterator, cute.Pointer) + and src_tensor.memspace == cute.AddressSpace.smem + ) + smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) + group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0)) + group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + s, g = cpasync.tma_partition( + atom, + cta_coord, + cta_layout, + cute.group_modes(smem_tensor, 0, group_rank_smem), + cute.group_modes(gmem_tensor, 0, group_rank_gmem), + ) + if const_expr(filter_zeros): + s = cute.filter_zeros(s) + g = cute.filter_zeros(g) + src, dst = (s, g) if src_is_smem else (g, s) + + def copy_tma(src_idx, dst_idx, **new_kwargs): + cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) + + def copy_tma_single_stage(**new_kwargs): + cute.copy(atom, src, dst, **new_kwargs, **kwargs) + + return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g + + +def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): + def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs): + copy( + src_idx=src_idx, + dst_idx=producer_state.index, + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), + **new_kwargs, + ) + + return copy_fn + + +__all__ = [ + "atomic_add_broadcast_i32", + "atomic_add_fp32x4", + "atomic_add_i32", + "convert_layout_acc_mn", + "convert_layout_from_tmem16x256b_to_acc_sm90", + "copy", + "cpasync_bulk_g2s", + "cpasync_bulk_get_copy_fn", + "cpasync_bulk_s2cluster", + "cpasync_reduce_bulk_add_f32", + "cvt_copy", + "get_copy_atom", + "load_s2r", + "make_16x256b_tensor_mn_view", + "make_tmem_copy", + "real_col_to_stg128_fake_col", + "real_col_to_stg128_fp8_fake_col", + "real_col_to_stg128_half_fake_col", + "set_block_rank", + "stg128_fake_col_to_real_col", + "stg128_fp8_fake_col_to_real_col", + "stg128_half_fake_col_to_real_col", + "stg_128", + "stg_128_cs", + "stg_128_bf16", + "stg_128_bf16_cs", + "stg_128_f16", + "stg_128_f16_cs", + "stg_128_fp8_e4m3_cs", + "stg_32_fp8_e4m3", + "stg_64_bf16", + "stg_64_f16", + "sts_32_bf16", + "sts_32_f16", + "store_shared_remote_fp32x4", + "tiled_copy_1d", + "tiled_copy_2d", + "tma_get_copy_fn", + "tma_producer_copy_fn", +] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/cute_dsl_utils.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/cute_dsl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3473fbbf77fa1261abfc8fd960102c70d3e64bd --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/cute_dsl_utils.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import logging +import os +import pathlib +import time +from typing import Tuple +from functools import partial, lru_cache +from dataclasses import dataclass, fields + +import torch + +logger = logging.getLogger("minimax") + +try: + from triton.tools.disasm import extract +except ImportError: + extract = None + +import cutlass +import cutlass.cute as cute +from cutlass.base_dsl.typing import JitArgument +from cutlass.cutlass_dsl import NumericMeta +from cutlass.cute.runtime import from_dlpack + +StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) + + +load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data +cute_compile_og = cute.compile + + +torch2cute_dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, +} + + +@lru_cache +def get_max_active_clusters(cluster_size): + return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) + + +@lru_cache +def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: + return torch.cuda.get_device_capability(device) + + +@dataclass +class ArgumentsBase(JitArgument): + def __c_pointers__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + c_ptrs = [] + for obj in non_constexpr_fields: + if hasattr(obj, "__c_pointers__"): + c_ptrs.extend(obj.__c_pointers__()) + return c_ptrs + + def __get_mlir_types__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + types, self._values_pos = [], [] + for obj in non_constexpr_fields: + if hasattr(obj, "__get_mlir_types__"): + obj_types = obj.__get_mlir_types__() + types.extend(obj_types) + self._values_pos.append(len(obj_types)) + else: + self._values_pos.append(0) + return types + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +def load_cubin_module_data_patched(cubin_data, filepath): + pathlib.Path(filepath).write_bytes(cubin_data) + return load_cubin_module_data_og(cubin_data) + + +def cute_compile_patched(*args, **kwargs): + """A patched version of cute.compile. + + Behaviour: + - Dumps SASS to a file if ``CUTE_CUBIN_PATH`` is set. + - Logs JIT compile wall time at DEBUG level via the ``minimax`` logger, + tagged with the kernel's class name when available. Enable with + ``logging.getLogger("minimax").setLevel(logging.DEBUG)`` or env + ``MINIMAX_LOG_COMPILE=1``; this is how we distinguish a slow JIT + (~2-10s) from a kernel hang (>30s = deadlock, see CLAUDE.md). + """ + cubin_path = os.getenv("CUTE_CUBIN_PATH", None) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( + load_cubin_module_data_patched, filepath=cubin_path + ) + kernel_obj = args[0] if args else kwargs.get("op") + kernel_name = type(kernel_obj).__name__ if kernel_obj is not None else "" + t0 = time.time() + output = cute_compile_og(*args, **kwargs) + dt = time.time() - t0 + logger.debug("[%s] compiled in %.1fs", kernel_name, dt) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og + if extract is not None: + sass = extract(cubin_path, None) + pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) + return output + + +if os.getenv("MINIMAX_LOG_COMPILE", "0") == "1": + if not logger.handlers: + _h = logging.StreamHandler() + _h.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")) + logger.addHandler(_h) + logger.setLevel(logging.DEBUG) + + +# Monkey-patch cute.compile so every JIT compile across the repo gets timed +# without touching individual call sites. Idempotent: only patches once. +if cute.compile is not cute_compile_patched: + cute.compile = cute_compile_patched + + +def assume_strides_aligned(t): + """Assume all strides except the last are divisible by 128 bits. + + Python int strides (e.g., stride=0 from GQA expand) are kept as-is + since they're static and don't need alignment assumptions. + """ + divby = 128 // t.element_type.width + strides = tuple(s if isinstance(s, int) else cute.assume(s, divby=divby) for s in t.stride[:-1]) + return (*strides, t.stride[-1]) + + +def assume_tensor_aligned(t): + """Rebuild a tensor with 128-bit aligned stride assumptions. Passes through None.""" + if t is None: + return None + return cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=assume_strides_aligned(t))) + + +def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True): + """Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1.""" + tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi) + if fully_dynamic: + return tensor.mark_layout_dynamic() + if leading_dim == -1: + leading_dim = t.ndim - 1 + return tensor.mark_layout_dynamic(leading_dim=leading_dim) + + +def to_cute_aux_tensor(t, enable_tvm_ffi=True): + """Convert torch tensor to cute tensor for TVM FFI, tailored to FlexAttention aux tensors. + This allows the user to specify alignment and leading dimension for aux tensors used in + custom score_mod callables. + """ + assumed_align: int = getattr(t, "__assumed_align__", None) + leading_dim: int = getattr(t, "__leading_dim__", None) + fully_dynamic: bool = leading_dim is None + + return to_cute_tensor( + t, + assumed_align=assumed_align, + leading_dim=leading_dim, + fully_dynamic=fully_dynamic, + enable_tvm_ffi=enable_tvm_ffi, + ) + + +def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]: + """Return tuple of bools indicating which dims have stride=0 (broadcast). + + This is useful for compile keys since CuTe's mark_layout_dynamic() keeps + stride=0 as static, meaning kernels compiled with different broadcast + patterns are not interchangeable. + """ + return tuple(s == 0 for s in tensor.stride()) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/fast_math.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/fast_math.py new file mode 100644 index 0000000000000000000000000000000000000000..63a8b4a501ac499e372056a07d499832c830b474 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/fast_math.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 + + +@cute.jit +def clz(x: Int32) -> Int32: + # for i in cutlass.range_constexpr(32): + # if (1 << (31 - i)) & x: + # return Int32(i) + # return Int32(32) + # Early exit is not supported yet + res = Int32(32) + done = False + for i in cutlass.range(32): + if ((1 << (31 - i)) & x) and not done: + res = Int32(i) + done = True + return res diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/mask.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..0da42c3be9bf1c3dcff81ccde579b54131bfa4c6 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/mask.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Callable, Optional, TypeAlias +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, Uint32, const_expr + +from ...src.common import utils as utils +from ...src.common.seqlen_info import SeqlenInfoQK + +MaskGenFn: TypeAlias = Callable[[int], Uint32] +MASK_R2P_CHUNK_SIZE: int = 32 + + +@cute.jit +def r2p_bitmask_below(limit: Int32, s: int) -> Uint32: + m = max((s + 1) * MASK_R2P_CHUNK_SIZE - limit, 0) + return utils.shr_u32(Uint32(0xFFFFFFFF), Uint32(m)) + + +@cute.jit +def r2p_bitmask_above(limit: Int32, s: int) -> Uint32: + n = max(limit - s * MASK_R2P_CHUNK_SIZE, 0) + return utils.shl_u32(Uint32(0xFFFFFFFF), Uint32(n)) + + +@cute.jit +def mask_r2p_lambda( + X: cute.Tensor, + mask_gen_fn: cutlass.Constexpr[MaskGenFn], + rank1: bool = False, +) -> None: + ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape)) + for s in cutlass.range_constexpr(cute.ceil_div(ncol, MASK_R2P_CHUNK_SIZE)): + mask = mask_gen_fn(s) + for i in cutlass.range_constexpr(min(MASK_R2P_CHUNK_SIZE, ncol - s * MASK_R2P_CHUNK_SIZE)): + in_bound = cutlass.Boolean(mask & (Uint32(1) << i)) + c = s * MASK_R2P_CHUNK_SIZE + i + if const_expr(rank1): + X[c] = X[c] if in_bound else -Float32.inf + else: + for r in cutlass.range_constexpr(cute.size(X.shape[0])): + X[r, c] = X[r, c] if in_bound else -Float32.inf + + +@cute.jit +def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32: + return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep) + + +@dataclass(frozen=True) +class AttentionMask: + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] + seqlen_info: SeqlenInfoQK + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + swap_AB: cutlass.Constexpr[bool] = False + + @property + def seqlen_q(self) -> Int32: + return self.seqlen_info.seqlen_q + + @property + def seqlen_k(self) -> Int32: + return self.seqlen_info.seqlen_k + + @cute.jit + def apply_mask_sm100( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + m_block: Int32, + n_block: Int32, + mask_seqlen: cutlass.Constexpr[bool], + mask_causal: cutlass.Constexpr[bool], + row_idx: Optional[Int32] = None, + kv_valid_cols: Optional[Int32] = None, + kv_block_col_start: Optional[Int32] = None, + ) -> None: + if const_expr(not mask_seqlen and not mask_causal): + return + + col_limit = Int32(self.tile_n) + if const_expr(mask_seqlen): + if const_expr(kv_valid_cols is not None): + col_limit = kv_valid_cols + else: + col_limit = self.seqlen_k - n_block * Int32(self.tile_n) + + if const_expr(mask_causal): + if const_expr(row_idx is None): + row_axis = 0 if const_expr(not self.swap_AB) else 1 + row_idx_cur = tScS_t2r[0][row_axis] + m_block * Int32(self.tile_m) + if const_expr(self.qhead_per_kvhead_packgqa > 1): + row_idx_cur = row_idx_cur // Int32(self.qhead_per_kvhead_packgqa) + else: + row_idx_cur = row_idx + if const_expr(kv_block_col_start is not None): + block_col_start = kv_block_col_start + else: + block_col_start = n_block * Int32(self.tile_n) + causal_col_limit = ( + row_idx_cur + self.seqlen_k - self.seqlen_q + - block_col_start + Int32(1) + ) + col_limit = ( + cutlass.min(col_limit, causal_col_limit) + if const_expr(mask_seqlen) + else causal_col_limit + ) + + if col_limit < Int32(self.tile_n): + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_below(col_limit, s), + rank1=True, + ) + + @cute.jit + def apply_mask_sm100_transposed( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + t0ScS_t2r: cute.Tensor, + m_block: cutlass.Int32, + n_block: cutlass.Int32, + mask_seqlen: cutlass.Constexpr, + mask_causal: cutlass.Constexpr, + is_full_block: bool = False, + check_m_boundary: bool = True, + valid_tok_count: Optional[Int32] = None, + q_idx_tile: Optional[cute.Tensor] = None, + masked_tok_count: Optional[Int32] = None, + ) -> None: + del is_full_block, check_m_boundary + del t0ScS_t2r + row_axis = 0 if const_expr(not self.swap_AB) else 1 + col_axis = 1 if const_expr(not self.swap_AB) else 0 + + if const_expr(valid_tok_count is not None): + kv_block_col_start = n_block * Int32(self.tile_n) + causal_q_offset = self.seqlen_k - self.seqlen_q + nfrag = const_expr(cute.size(acc_S.shape)) + for i in cutlass.range(nfrag, unroll_full=True): + row_idx = tScS_t2r[i][row_axis] + tok_idx = row_idx // Int32(self.qhead_per_kvhead_packgqa) + acc_S[i] = -Float32.inf if tok_idx >= valid_tok_count else acc_S[i] + if const_expr(mask_seqlen): + kv_idx = kv_block_col_start + tScS_t2r[i][col_axis] + acc_S[i] = -Float32.inf if kv_idx >= self.seqlen_k else acc_S[i] + if const_expr(mask_causal): + if const_expr(q_idx_tile is not None): + causal_tok_count = ( + masked_tok_count + if const_expr(masked_tok_count is not None) + else Int32(0) + ) + if tok_idx < causal_tok_count: + q_idx = q_idx_tile[tok_idx] + kv_idx = kv_block_col_start + tScS_t2r[i][col_axis] + acc_S[i] = ( + -Float32.inf if kv_idx > q_idx + causal_q_offset else acc_S[i] + ) + return + + thr_col_offset = tScS_t2r[0][col_axis] + seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset + + if const_expr(not mask_causal): + if const_expr(mask_seqlen) and seqlenk_col_limit <= 0: + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + acc_S[i] = -cutlass.Float32.inf + return + + thr_row_offset = tScS_t2r[0][row_axis] + seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset + row_limit_top = seqlenq_row_limit - seqlenk_col_limit + if const_expr(mask_seqlen) and seqlenk_col_limit <= 0: + row_limit_top = self.tile_m + num_rep = cute.size(tScS_t2r, mode=[0]) + row_limit = row_to_r2p_idx(row_limit_top, num_rep, 2) + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_above(row_limit, s), + rank1=True, + ) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/mma_sm100_desc.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/mma_sm100_desc.py new file mode 100644 index 0000000000000000000000000000000000000000..53c58d17f5085d207f2a1d7b6b45d627ff3322e3 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/mma_sm100_desc.py @@ -0,0 +1,304 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT +# +# The bit-field encodings, enum values, and descriptor layout below mirror the +# SM100 tcgen05 MMA instruction descriptor as documented and +# implemented in NVIDIA CUTLASS (BSD-3-Clause). The numeric values MUST stay +# identical to the hardware/ISA encodings; see the "Third-party licenses" +# section of README.md at the repo root for attribution. + +from enum import IntEnum + +import cutlass +import cutlass.cute as cute + +# --------------------------------------------------------------------------- +# Enumerations that match the HW encodings (values MUST stay identical) +# --------------------------------------------------------------------------- + + +class Major(IntEnum): # matrix "layout" in the ISA docs + K = 0 + MN = 1 + + +class ScaleIn(IntEnum): # negate flags + One = 0 + Neg = 1 + + +class Saturate(IntEnum): + False_ = 0 + True_ = 1 + + +class CFormat(IntEnum): # 2-bit field (bits 4-5) + F16 = 0 + F32 = 1 + S32 = 2 + + +class F16F32Format(IntEnum): # 3-bit field (A/B element type) + F16 = 0 + BF16 = 1 + TF32 = 2 + + +class S8Format(IntEnum): + UINT8 = 0 + INT8 = 1 + + +class MXF8F6F4Format(IntEnum): + E4M3 = 0 + E5M2 = 1 + E2M3 = 3 + E3M2 = 4 + E2M1 = 5 + + +class MaxShift(IntEnum): + NoShift = 0 + MaxShift8 = 1 + MaxShift16 = 2 + MaxShift32 = 3 + + +# --------------------------------------------------------------------------- +# CUTLASS-type -> encoding helpers +# --------------------------------------------------------------------------- + + +def to_UMMA_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B. + """ + if cutlass_type is cutlass.Int8: + return S8Format.INT8 + # Unsigned 8-bit (if available in your CUTLASS build) + if cutlass_type is cutlass.Uint8: + return S8Format.UINT8 + # FP-16 / BF-16 + if cutlass_type is cutlass.Float16: + return F16F32Format.F16 + if cutlass_type is cutlass.BFloat16: + return F16F32Format.BF16 + # TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits) + if cutlass_type is cutlass.TFloat32: + return F16F32Format.TF32 + # Float-8 / Float-6 / Float-4 + if cutlass_type is cutlass.Float8E4M3FN: + return MXF8F6F4Format.E4M3 + if cutlass_type is cutlass.Float8E5M2: + return MXF8F6F4Format.E5M2 + raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}") + + +def to_C_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 2-bit accumulator encoding. + """ + if cutlass_type is cutlass.Float16: + return CFormat.F16 + if cutlass_type is cutlass.Float32: + return CFormat.F32 + if cutlass_type is cutlass.Int32: + return CFormat.S32 + raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}") + + +# --------------------------------------------------------------------------- +# The constructor – accepts only CUTLASS scalar classes +# --------------------------------------------------------------------------- + + +def make_instr_desc( + a_type, # CUTLASS scalar class, e.g. cutlass.Int8 + b_type, + c_type, + M: int, # 64, 128 or 256 + N: int, # 8 … 256 (multiple of 8) + a_major: Major, + b_major: Major, + a_neg: ScaleIn = ScaleIn.One, + b_neg: ScaleIn = ScaleIn.One, + c_sat: Saturate = Saturate.False_, + is_sparse: bool = False, + max_shift: MaxShift = MaxShift.NoShift, +) -> int: + """ + Build the 32-bit instruction descriptor for SM100 MMA. + All matrix/accumulator **types must be CUTLASS scalar classes** – + passing integers is forbidden. + """ + # --- encode element formats ------------------------------------------------- + a_fmt = int(to_UMMA_format(a_type)) + b_fmt = int(to_UMMA_format(b_type)) + c_fmt = int(to_C_format(c_type)) + is_f8f6f4 = a_type in (cutlass.Float8E4M3FN, cutlass.Float8E5M2) + + # --- range checks on M/N ----------------------------------------------------- + if M not in (64, 128, 256): + raise ValueError("M must be 64, 128 or 256") + if N < 8 or N > 256 or (N & 7): + raise ValueError("N must be a multiple of 8 in the range 8…256") + + m_dim = M >> 4 # 5-bit field + n_dim = N >> 3 # 6-bit field + + # fmt: off + # --- pack the bit-fields ----------------------------------------------------- + desc = 0 + desc |= (0 & 0x3) << 0 # sparse_id2 (always 0 here) + desc |= (int(is_sparse) & 0x1) << 2 # sparse_flag + desc |= (int(c_sat) & 0x1) << 3 # saturate + desc |= (c_fmt & 0x3) << 4 # c_format + desc |= (a_fmt & 0x7) << 7 # a_format + desc |= (b_fmt & 0x7) << 10 # b_format + desc |= (int(a_neg) & 0x1) << 13 # a_negate + desc |= (int(b_neg) & 0x1) << 14 # b_negate + desc |= (int(a_major) & 0x1) << 15 # a_major + desc |= (int(b_major) & 0x1) << 16 # b_major + desc |= (n_dim & 0x3F) << 17 # n_dim (6 bits) + # CUTLASS' tcgen05 lowering sets bit 23 for dense f8f6f4 MMAs; keep this + # descriptor aligned with generated/reference SM100 FP8 kernels. + desc |= (int(is_f8f6f4) & 0x1) << 23 + desc |= (m_dim & 0x1F) << 24 # m_dim (5 bits) + desc |= (int(max_shift) & 0x3) << 30 # max_shift (2 bits) + # fmt: on + + return desc & 0xFFFF_FFFF # ensure 32-bit result + + +def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp): + return make_instr_desc( + op.a_dtype, + op.b_dtype, + op.acc_dtype, + op.shape_mnk[0], + op.shape_mnk[1], + Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + ) + + +class LayoutType(IntEnum): # occupies the top-3 bits [61:64) + SWIZZLE_NONE = 0 # (a.k.a. "INTERLEAVE" in older docs) + SWIZZLE_128B_BASE32B = 1 + SWIZZLE_128B = 2 + SWIZZLE_64B = 4 + SWIZZLE_32B = 6 + # values 3,5,7 are reserved / illegal for UMMA + + +# --------------------------------------------------------------------------- +# Helpers – figure out the SWIZZLE_* family from the tensor layout +# --------------------------------------------------------------------------- + + +def _layout_type(swizzle: cute.Swizzle) -> LayoutType: + B, M, S = swizzle.num_bits, swizzle.num_base, swizzle.num_shift + + if M == 4: # Swizzle<*,4,3> + if S != 3: + raise ValueError("Unexpected swizzle shift – want S==3 for M==4") + return { + 0: LayoutType.SWIZZLE_NONE, + 1: LayoutType.SWIZZLE_32B, + 2: LayoutType.SWIZZLE_64B, + 3: LayoutType.SWIZZLE_128B, + }[B] # KeyError ⇒ invalid B→ raise + if M == 5: # Swizzle<2,5,2> (the only legal triple for M==5) + if (B, S) != (2, 2): + raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B") + return LayoutType.SWIZZLE_128B_BASE32B + + # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout + raise ValueError("Unsupported swizzle triple for UMMA smem descriptor") + + +def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int: + """ + Convert a 2-D *shared-memory* Cute layout into the SM100 64-bit + smem-descriptor, without the smem start address. + layout must correspond to layout of an uint128 tensor. + """ + # ------------------------------------------------------------------ meta + layout_type = _layout_type(swizzle) # resolve SWIZZLE_* family + + VERSION = 1 # bits 46–47 + LBO_MODE = 0 # bit 52 + BASE_OFFSET = 0 # bits 49–51 (CUTLASS always 0) + + # ---------------------------------------------------------- strides (units: uint128_t = 16 B) + swizzle_atom_mn_size = { + LayoutType.SWIZZLE_NONE: 1, + LayoutType.SWIZZLE_32B: 2, + LayoutType.SWIZZLE_64B: 4, + LayoutType.SWIZZLE_128B: 8, + LayoutType.SWIZZLE_128B_BASE32B: 8, + }[layout_type] + + if major is Major.MN: + swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8 + canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if stride_10 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1] + if layout_type is LayoutType.SWIZZLE_NONE: + stride_byte_offset, leading_byte_offset = stride_01, stride_11 + else: + stride_byte_offset, leading_byte_offset = stride_11, stride_01 + else: + if layout_type == LayoutType.SWIZZLE_128B_BASE32B: + raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K") + if not cute.size(layout.shape[0]) % 8 == 0: + raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.") + canonical_layout = cute.logical_divide(layout, (8, 2)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if stride_00 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_01 = canonical_layout.stride[0][1] + stride_byte_offset, leading_byte_offset = stride_01, stride_10 + + # ------------------------------------------------------------------ pack + desc = 0 + # leading_byte_offset_ [16:30) + desc |= (leading_byte_offset & 0x3FFF) << 16 + # stride_byte_offset_ [32:46) + desc |= (stride_byte_offset & 0x3FFF) << 32 + # version_ [46:48) + desc |= (VERSION & 0x3) << 46 + # base_offset_ [49:52) + desc |= (BASE_OFFSET & 0x7) << 49 + # lbo_mode_ [52:53) + desc |= (LBO_MODE & 0x1) << 52 + # layout_type_ [61:64) + desc |= (int(layout_type) & 0x7) << 61 + + return desc & 0xFFFF_FFFF_FFFF_FFFF # force 64-bit width + + +def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32: + # 14 bits, remove 4 LSB (bits 0-13 in desc) + return (start_addr.toint() & 0x3FFFF) >> 4 + + +def smem_desc_base_from_tensor(sA: cute.Tensor, major: Major) -> int: + sA_swizzle = sA.iterator.type.swizzle_type + return make_smem_desc_base( + cute.recast_layout(128, sA.element_type.width, sA.layout[0]), + sA_swizzle, + major, + ) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/named_barrier.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/named_barrier.py new file mode 100644 index 0000000000000000000000000000000000000000..a7722a471ca011a94d5fd7774224906001979b78 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/named_barrier.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import enum + + +class NamedBarrierFwdSm100(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() + TmemPtr = enum.auto() + SoftmaxStatsW0 = enum.auto() + SoftmaxStatsW1 = enum.auto() + SoftmaxStatsW2 = enum.auto() + SoftmaxStatsW3 = enum.auto() + SoftmaxStatsW4 = enum.auto() + SoftmaxStatsW5 = enum.auto() + SoftmaxStatsW6 = enum.auto() + SoftmaxStatsW7 = enum.auto() + LoadWG = enum.auto() + StoreEpilogue = enum.auto() + KvLoad = enum.auto() + KvDequantK = enum.auto() + KvDequantV = enum.auto() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/pack_gqa.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/pack_gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5dc25edd3f48fbe2c77ec94c8ab3f1ea417507 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/pack_gqa.py @@ -0,0 +1,320 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""PackGQA primitives for GQA (grouped-query attention) tile layouts. + +Contains: +- ``pack_gqa_layout`` / ``unpack_gqa_layout``: fold/unfold ``qhead_per_kvhead`` + into the seqlen dimension of a tensor layout (zero-copy view). +- ``PackGQA``: base class with ``compute_ptr`` / ``load_Q`` / ``store_LSE`` / + ``store_O`` helpers for kernels that treat ``(qhead_per_kvhead × seqlen_q)`` + as a single packed row dimension. +- ``PackGQAComb``: subclass used by the K2 combine kernel; adds ``load_LSE`` + for coalesced GMEM→SMEM async copies when LSE_partial is laid out with H_q + innermost (stride-1). +""" + +from dataclasses import dataclass +from typing import Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute import FastDivmodDivisor + +from ...quack import layout_utils + +from . import utils + + +def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): + """Reshape a tensor to fold qhead_per_kvhead into the seqlen dimension (mode 0). + + The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) + are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept + as-is (e.g. batch). + + For Q/O tensors (head_idx=2): + (seqlen_q, headdim, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) + For LSE tensors (head_idx=1): + (seqlen_q, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) + """ + head_stride = T.stride[head_idx] + shape_packed = ( + (qhead_per_kvhead, T.shape[0]), + *[T.shape[i] for i in range(1, head_idx)], + nheads_kv, + *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], + ) + stride_packed = ( + (head_stride, T.stride[0]), + *[T.stride[i] for i in range(1, head_idx)], + head_stride * qhead_per_kvhead, + *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], + ) + return cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed)) + + +def unpack_gqa_layout(T, qhead_per_kvhead, head_idx): + """Reverse of pack_gqa_layout: unfold qhead_per_kvhead from the seqlen dimension (mode 0). + + The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) + are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept + as-is (e.g. batch). + + For Q/O tensors (head_idx=2): + ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) -> (seqlen_q, headdim, nheads, batch, ...) + For LSE tensors (head_idx=1): + ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) -> (seqlen_q, nheads, batch, ...) + """ + seqlen_stride = T.stride[0][1] + head_stride = T.stride[0][0] + shape_unpacked = ( + T.shape[0][1], + *[T.shape[i] for i in range(1, head_idx)], + T.shape[head_idx] * qhead_per_kvhead, + *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], + ) + stride_unpacked = ( + seqlen_stride, + *[T.stride[i] for i in range(1, head_idx)], + head_stride, + *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], + ) + return cute.make_tensor(T.iterator, cute.make_layout(shape_unpacked, stride=stride_unpacked)) + + +@dataclass +class PackGQA: + m_block_size: cutlass.Constexpr[int] + head_dim_padded: cutlass.Constexpr[int] + check_hdim_oob: cutlass.Constexpr[bool] + qhead_per_kvhead: cutlass.Constexpr[bool] + + @cute.jit + def compute_ptr( + self, + tensor: cute.Tensor, + cRows: cute.Tensor, + tidx: cutlass.Int32, + block: cutlass.Int32, + threads_per_row: cutlass.Constexpr[int], + num_threads: cutlass.Constexpr[int], + ): + num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row) + tPrPtr = cute.make_rmem_tensor(num_ptr_per_thread, cutlass.Int64) + for i in cutlass.range_constexpr(num_ptr_per_thread): + row = i * num_threads + cRows[tidx % threads_per_row][0] + idx = block * self.m_block_size + row + m_idx = idx // self.qhead_per_kvhead + h_idx = idx - m_idx * self.qhead_per_kvhead + tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint() + return tPrPtr + + @cute.jit + def load_Q( + self, + mQ: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + sQ: cute.Tensor, # (m_block_size, head_dim_padded) + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tQsQ = gmem_thr_copy.partition_D(sQ) + tQcQ = gmem_thr_copy.partition_S(cQ) + t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) + tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1]) + tQcQ_row = tQcQ[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + q_ptr_i64 = utils.shuffle_sync( + tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row + ) + q_gmem_ptr = cute.make_ptr( + mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0QcQ[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0] + ): + mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tQsQ.shape[0][0]) + mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])): + ki = tQcQ[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + mQ_cur_copy[None, ki], + tQsQ[None, m, k], + pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + + @cute.jit + def store_LSE( + self, + mLSE: cute.Tensor, # (qhead_per_kvhead, seqlen_q) + tLSErLSE: cute.Tensor, # (m_block_size, head_dim_padded) + tiled_mma: cute.TiledMma, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + thr_mma = tiled_mma.get_slice(tidx) + caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + taccOcO = thr_mma.partition_C(caccO) + taccOcO_row = layout_utils.reshape_acc_to_mn(taccOcO)[None, 0] + assert cute.size(tLSErLSE) == cute.size(taccOcO_row) + threads_per_row = tiled_mma.tv_layout_C.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + assert cute.size(tLSErLSE) <= threads_per_row + num_threads = tiled_mma.size + tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tLSErLSE)): + lse_ptr_i64 = utils.shuffle_sync( + tPrLSEPtr[m // threads_per_row], + m % threads_per_row, + width=threads_per_row, + ) + lse_gmem_ptr = cute.make_ptr( + mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4 + ) + row = block * self.m_block_size + taccOcO_row[m][0] + # Only the thread corresponding to column 0 writes out the lse to gmem + if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead: + mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,)) + mLSE_copy[0] = tLSErLSE[m] + + @cute.jit + def store_O( + self, + mO: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + tOrO: cute.Tensor, # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tOcO = gmem_thr_copy.partition_S(cO) + t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + tOcO_row = tOcO[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + o_ptr_i64 = utils.shuffle_sync( + tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row + ) + o_gmem_ptr = cute.make_ptr( + mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0OcO[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0] + ): + mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tOrO.shape[0][0]) + mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])): + ki = tOcO[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + tOrO[None, m, k], + mO_cur_copy[None, ki], + pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + + +@dataclass +class PackGQAComb(PackGQA): + """PackGQA subclass for the K2 combine kernel. + + Inherits ``compute_ptr`` / ``load_Q`` / ``store_LSE`` / ``store_O`` from + ``PackGQA``. Adds ``load_LSE`` for coalesced GMEM→SMEM async copies when + LSE_partial is laid out with H_q innermost. + + K2 combine treats each query head independently (no GQA grouping in combine + itself), so ``qhead_per_kvhead`` is set to ``num_heads_q`` by the caller — + all heads are folded into one "group" per Sq position. + """ + + @cute.jit + def load_LSE( + self, + mLSE_partial: cute.Tensor, + # Packed layout after caller-side reshape: + # shape ((qhead_per_kvhead, seqlen_q), num_splits) + # stride ((1, qhead_per_kvhead), ...) + # — H_q is the innermost (stride-1) element of the packed first dim. + sLSE: cute.Tensor, + # SMEM destination: ``(topk, m_block_size)`` fp32. + topk: cutlass.Constexpr[int], + # Explicit topk so the identity tensor shape is a plain int, + # avoiding compound-shape traps from sLSE.shape[0] after tile_to_shape. + gmem_tiled_copy: cute.TiledCopy, + tidx: Int32, + block: Int32, + num_splits: Int32, + seqlen: Int32, + num_heads_divmod: FastDivmodDivisor, + mCounter: Optional[cute.Tensor] = None, + batch_idx: Optional[Int32] = None, + qhead_per_kvhead: Int32 = Int32(1), + # divmod for ``m_pos = idx // qhead_per_kvhead``; passed explicitly so + # caller controls whether the divisor is constexpr or a runtime value. + ): + """Coalesced GMEM→SMEM async load of LSE_partial for one tile. + + For each (split, row) slot this thread owns in the tile, compute the + GMEM coordinate ``(h_pos, m_pos)`` via PackGQA divmod and copy one fp32. + Out-of-bounds rows (``m_pos >= seqlen``) and splits (``si >= num_splits``) + are filled with ``-inf`` so they flow cleanly through downstream reductions. + + Coalescing: adjacent thread rows correspond to adjacent ``h_pos`` values + (head varies fast under ``divmod(idx, qhead_per_kvhead)``), which map to + adjacent GMEM addresses when H_q is stride-1 — one sector per warp. + """ + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cLSE = cute.make_identity_tensor((topk, self.m_block_size)) + tLSEcLSE = gmem_thr_copy.partition_S(cLSE) + tLSEsLSE = gmem_thr_copy.partition_D(sLSE) + + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] + idx = block * self.m_block_size + mi + m_pos, h_pos = divmod(idx, num_heads_divmod) + + if m_pos < seqlen: + row_count = ( + mCounter[batch_idx, m_pos, h_pos // qhead_per_kvhead] + if const_expr(mCounter is not None) + else num_splits + ) + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] + if si < num_splits and si < row_count: + # Build a 1-element GMEM tensor at ((h_pos, m_pos), si), + # matching PackGQA.store_LSE's ptr pattern so cute.copy + # receives a proper Tensor, not a scalar. + src_ptr_i64 = utils.elem_pointer( + mLSE_partial, ((h_pos, m_pos), si)).toint() + src_ptr = cute.make_ptr( + Float32, src_ptr_i64, + cute.AddressSpace.gmem, assumed_align=4, + ) + src_t = cute.make_tensor(src_ptr, (1,)) + cute.copy(gmem_thr_copy, src_t, tLSEsLSE[None, s, m]) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + else: + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + tLSEsLSE[None, s, m].fill(-Float32.inf) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/paged_kv.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/paged_kv.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5f6923c42a826d4f3dd1f192ce2fdb38eefbf5 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/paged_kv.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + + +@dataclass(frozen=True) +class PagedKVManager: + mPageTable: cute.Tensor + page_size: cutlass.Constexpr[int] + n_block_size: cutlass.Constexpr[int] + + @staticmethod + def create( + mPageTable: cute.Tensor, + *, + page_size: int, + n_block_size: int, + ): + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal blk_kv ({n_block_size})" + ) + return PagedKVManager( + mPageTable, + page_size=page_size, + n_block_size=n_block_size, + ) + + @cute.jit + def logical_length( + self, + batch_idx: Int32, + num_kv_blocks: Int32, + mSeqUsedK=None, + ) -> Int32: + if const_expr(mSeqUsedK is not None): + return mSeqUsedK[batch_idx] + return num_kv_blocks * Int32(self.n_block_size) + + @cute.jit + def valid_cols_in_block( + self, + batch_idx: Int32, + kv_block_idx: Int32, + num_kv_blocks: Int32, + mSeqUsedK=None, + ) -> Int32: + seqlen_k = self.logical_length(batch_idx, num_kv_blocks, mSeqUsedK) + block_start = kv_block_idx * Int32(self.n_block_size) + remaining = seqlen_k - block_start + remaining = cutlass.max(remaining, Int32(0)) + return cutlass.min(remaining, Int32(self.n_block_size)) + + @cute.jit + def physical_block_index( + self, + batch_idx: Int32, + kv_block_idx: Int32, + ) -> Int32: + return self.mPageTable[batch_idx, kv_block_idx] + +__all__ = ["PagedKVManager"] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/pipeline.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..27f711772f5c6fa16a86f4aa305f42a0ca9322eb --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/pipeline.py @@ -0,0 +1,372 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +# import math +from typing import Optional +from dataclasses import dataclass + +import cutlass.cute as cute +from cutlass import Boolean, Int32, const_expr +from cutlass.cutlass_dsl import if_generate, dsl_user_op +from cutlass.pipeline import PipelineState +from cutlass.pipeline import PipelineUserType +from cutlass.pipeline import NamedBarrier as NamedBarrierOg +from cutlass.pipeline import PipelineAsync as PipelineAsyncOg +from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg +from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg +from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg +from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg +import cutlass.pipeline as cutlass_pipeline + + +def make_pipeline_state(type: PipelineUserType, stages: int): + """Compatibility wrapper for FA-style helpers now vendored into src.common.""" + return cutlass_pipeline.make_pipeline_state(type, stages) + +@dataclass(frozen=True) +class NamedBarrier(NamedBarrierOg): + @staticmethod + def create(*args, **kwargs): + obj = NamedBarrierOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", NamedBarrier) + return obj + + @dsl_user_op + def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None: + """ + The aligned flavor of arrive is used when all threads in the CTA will execute the + same instruction. See PTX documentation. + """ + cute.arch.barrier_arrive( + barrier_id=self.barrier_id + index, + number_of_threads=self.num_threads, + loc=loc, + ip=ip, + ) + + @dsl_user_op + def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None: + cute.arch.barrier( + barrier_id=self.barrier_id + index, + number_of_threads=self.num_threads, + loc=loc, + ip=ip, + ) + + +@dataclass(frozen=True) +class PipelineAsync(PipelineAsyncOg): + @staticmethod + def create(*args, **kwargs): + obj = PipelineAsyncOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + # obj.__class__ = PipelineAsync + object.__setattr__(obj, "__class__", PipelineAsync) + return obj + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_try_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + *, + loc=None, + ip=None, + ): + return self.sync_object_empty.try_wait(index, phase, loc=loc, ip=ip) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineTmaAsync(PipelineTmaAsyncOg): + """ + Override producer_acquire to take in extra_tx_count parameter. + """ + + @staticmethod + def create(*args, **kwargs): + obj = PipelineTmaAsyncOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", PipelineTmaAsync) + return obj + + @dsl_user_op + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + if const_expr(extra_tx_count == 0): + self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineTmaUmma(PipelineTmaUmmaOg): + """ + Override producer_acquire to take in extra_tx_count parameter. + """ + + @staticmethod + def create(*args, **kwargs): + obj = PipelineTmaUmmaOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + # obj.__class__ = PipelineTmaUmma + object.__setattr__(obj, "__class__", PipelineTmaUmma) + return obj + + @dsl_user_op + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + if const_expr(extra_tx_count == 0): + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive( + state.index, self.producer_mask, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive_and_expect_tx( + state.index, tx_count, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineUmmaAsync(PipelineUmmaAsyncOg): + @staticmethod + def create(*args, **kwargs): + obj = PipelineUmmaAsyncOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", PipelineUmmaAsync) + return obj + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + """ + UMMA producer commit buffer full, cta_group needs to be provided. + """ + self.sync_object_full.arrive(index, self.producer_mask, self.cta_group, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class PipelineAsyncUmma(PipelineAsyncUmmaOg): + @staticmethod + def create(*args, **kwargs): + obj = PipelineAsyncUmmaOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", PipelineAsyncUmma) + return obj + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/seqlen_info.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/seqlen_info.py new file mode 100644 index 0000000000000000000000000000000000000000..873304f71c2cb47ffdd1453fe771c754783f51a4 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/seqlen_info.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from typing import Optional +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + +from ...quack import copy_utils + +""" +This consolidates all the info related to sequence length. This is so that we can do all +the gmem reads once at the beginning of each tile, rather than having to repeat these reads +to compute various things like n_block_min, n_block_max, etc. +""" + + +@dataclass(frozen=True) +class SeqlenInfo: + offset: Int32 + offset_padded: Int32 + seqlen: Int32 + has_cu_seqlens: cutlass.Constexpr[bool] = False + + @staticmethod + def create( + batch_idx: Int32, + seqlen_static: Int32, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + tile: cutlass.Constexpr[int] = 128, + ): + offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] + offset_padded = ( + 0 + if const_expr(cu_seqlens is None) + # Add divby so that the compiler knows the alignment when moving by offset_padded + else cute.assume((offset + batch_idx * tile) // tile * tile, divby=tile) + ) + if const_expr(seqused is not None): + seqlen = seqused[batch_idx] + elif const_expr(cu_seqlens is not None): + seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] + else: + seqlen = seqlen_static + return SeqlenInfo(offset, offset_padded, seqlen, has_cu_seqlens=cu_seqlens is not None) + + def offset_batch( + self, + mT: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + multiple: int = 1, + ) -> cute.Tensor: + """Offset a tensor by batch index. batch dim is at position `dim`, seqlen is at dim=0.""" + if const_expr(not self.has_cu_seqlens): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mT) - 1 - dim) + return mT[idx] + else: + off = multiple * (self.offset if const_expr(not padded) else self.offset_padded) + offset = off if const_expr(cute.rank(mT.shape[0]) == 1) else (0, off) + idx = (offset,) + (None,) * (cute.rank(mT) - 1) + return cute.domain_offset(idx, mT) + + +@dataclass(frozen=True) +class SeqlenInfoQK: + offset_q: Int32 + offset_k: Int32 + padded_offset_q: Int32 + padded_offset_k: Int32 + seqlen_q: Int32 + seqlen_k: Int32 + has_cu_seqlens_q: cutlass.Constexpr[bool] + has_cu_seqlens_k: cutlass.Constexpr[bool] + has_seqused_q: cutlass.Constexpr[bool] + has_seqused_k: cutlass.Constexpr[bool] + + @staticmethod + def create( + batch_idx: Int32, + seqlen_q_static: Int32, + seqlen_k_static: Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + tile_m: cutlass.Constexpr[Int32] = 128, + tile_n: cutlass.Constexpr[Int32] = 128, + ): + offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] + offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] + padded_offset_q = ( + 0 + if const_expr(mCuSeqlensQ is None) + else cute.assume((offset_q + batch_idx * tile_m) // tile_m * tile_m, divby=tile_m) + ) + padded_offset_k = ( + 0 + if const_expr(mCuSeqlensK is None) + else cute.assume((offset_k + batch_idx * tile_n) // tile_n * tile_n, divby=tile_n) + ) + if const_expr(mSeqUsedQ is not None): + seqlen_q = mSeqUsedQ[batch_idx] + else: + seqlen_q = ( + seqlen_q_static + if const_expr(mCuSeqlensQ is None) + else mCuSeqlensQ[batch_idx + 1] - offset_q + ) + if const_expr(mSeqUsedK is not None): + seqlen_k = mSeqUsedK[batch_idx] + else: + seqlen_k = ( + seqlen_k_static + if const_expr(mCuSeqlensK is None) + else mCuSeqlensK[batch_idx + 1] - offset_k + ) + return SeqlenInfoQK( + offset_q, + offset_k, + padded_offset_q, + padded_offset_k, + seqlen_q, + seqlen_k, + has_cu_seqlens_q=mCuSeqlensQ is not None, + has_cu_seqlens_k=mCuSeqlensK is not None, + has_seqused_q=mSeqUsedQ is not None, + has_seqused_k=mSeqUsedK is not None, + ) + + def offset_batch_Q( + self, + mQ: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + ragged: cutlass.Constexpr[bool] = False, + ) -> cute.Tensor: + """Seqlen must be the first dimension of mQ""" + if const_expr(not ragged): + if const_expr(not self.has_cu_seqlens_q): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + return mQ[idx] + else: + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q + offset_q = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (None, offset_q) + idx = (offset_q,) + (None,) * (cute.rank(mQ) - 1) + return cute.domain_offset(idx, mQ) + else: + if const_expr(not self.has_cu_seqlens_q): + offset_q = 0 + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + mQ = mQ[idx] + else: + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q + if const_expr(cute.rank(mQ.shape[0]) == 1): + return copy_utils.offset_ragged_tensor( + mQ, offset_q, self.seqlen_q, ragged_dim=0, ptr_shift=True + ) + else: # PackGQA + assert cute.rank(mQ.shape[0]) == 2 + # Unpack before calling offset_ragged_tensor, then pack + idx = ((None, None),) + (None,) * (cute.rank(mQ) - 1) + mQ = mQ[idx] + mQ = copy_utils.offset_ragged_tensor( + mQ, offset_q, self.seqlen_q, ragged_dim=1, ptr_shift=True + ) + return cute.group_modes(mQ, 0, 2) + + def offset_batch_K( + self, + mK: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + ragged: cutlass.Constexpr[bool] = False, + multiple: int = 1, + ) -> cute.Tensor: + """Seqlen must be the first dimension of mK""" + if const_expr(not ragged): + if const_expr(not self.has_cu_seqlens_k): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + return mK[idx] + else: + offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k + offset_k *= multiple + idx = (offset_k,) + (None,) * (cute.rank(mK) - 1) + return cute.domain_offset(idx, mK) + else: + if const_expr(not self.has_cu_seqlens_k): + offset_k = 0 + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + mK = mK[idx] + else: + offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k + offset_k *= multiple + return copy_utils.offset_ragged_tensor( + mK, offset_k, self.seqlen_k, ragged_dim=0, ptr_shift=True + ) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/softmax.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..8f94c1c9e40aeb44c0a128165d90a502feb04afd --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/softmax.py @@ -0,0 +1,498 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Online softmax primitives. + +Contains: +- ``Softmax``: SM80/90 base class with online softmax + finalize + rescale_O. + The ``rescale_O`` path branches on ``arch >= 100`` to emit SM100 packed + ``fmul.f32x2`` (2× CUDA-core throughput) when available. +- ``SoftmaxSm100``: SM100-specific subclass exposing fused ``update_row_max``, + ``scale_apply_exp2_convert`` etc. used by the UTCMMA warp-specialized kernel. +""" + +import math +import operator +from dataclasses import dataclass +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Float32 + +from ...quack import layout_utils +from ...quack.cute_dsl_utils import ParamsBase + +from . import utils + + +@dataclass +class Softmax(ParamsBase): + scale_log2: Float32 + num_rows: cutlass.Constexpr[int] + row_max: cute.Tensor + row_sum: cute.Tensor + arch: cutlass.Constexpr[int] = 80 + softmax_scale: Float32 | None = None + + @staticmethod + def create( + scale_log2: Float32, + num_rows: cutlass.Constexpr[int], + arch: cutlass.Constexpr[int] = 80, + softmax_scale: Float32 | None = None, + ): + row_max = cute.make_rmem_tensor(num_rows, Float32) + row_sum = cute.make_rmem_tensor(num_rows, Float32) + return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale) + + def reset(self) -> None: + self.row_max.fill(-Float32.inf) + self.row_sum.fill(0.0) + + def _compute_row_max( + self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch) + + def _compute_row_sum( + self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch) + + @cute.jit + def online_softmax( + self, + acc_S: cute.Tensor, + is_first: cutlass.Constexpr[bool] = False, + check_inf: cutlass.Constexpr[bool] = True, + ) -> cute.Tensor: + """Apply online softmax and return the row_scale to rescale O. + + On SM100+ the inner ``acc_S_row * scale_log2 - row_max_scaled`` is + rewritten as explicit ``fma_packed_f32x2`` intrinsics — the DSL + compiler does not fuse TensorSSA ``mul + sub`` into FFMA2 (NCU + confirms: FFMA2 count is 0 for the TensorSSA path). The packed + rewrite issues one FFMA.F32X2 per pair, halving the scalar FFMA + instruction count for the softmax scale/subtract stage. + """ + acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S) + row_scale = cute.make_rmem_tensor_like(self.row_max, Float32) + + row_max = self.row_max + row_sum = self.row_sum + scale_log2 = self.scale_log2 + arch = self.arch + + for r in cutlass.range(cute.size(row_max), unroll_full=True): + acc_S_row_slice = acc_S_mn[r, None] + acc_S_row = acc_S_row_slice.load() + + row_max_cur = utils.fmax_reduce( + acc_S_row, + init_val=row_max[r] if cutlass.const_expr(not is_first) else None, + arch=arch, + ) + + row_max_cur = cute.arch.warp_reduction_max(row_max_cur, threads_in_group=4) + row_max_prev = row_max[r] + row_max[r] = row_max_cur + + if cutlass.const_expr(check_inf): + row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur + + row_max_cur_scaled = row_max_cur * scale_log2 + minus_row_max_scaled = -row_max_cur_scaled + n = cute.size(acc_S_row_slice) + + if cutlass.const_expr(arch >= 100 and n % 2 == 0): + # SM100 packed f32x2 FMA path: scale + subtract in one pass. + for i in cutlass.range(0, n, 2, unroll_full=True): + acc_S_row_slice[i], acc_S_row_slice[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row_slice[i], acc_S_row_slice[i + 1]), + (scale_log2, scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + for i in cutlass.range(n, unroll_full=True): + acc_S_row_slice[i] = cute.math.exp2(acc_S_row_slice[i], fastmath=True) + acc_S_row_exp = acc_S_row_slice.load() + else: + acc_S_row_exp = cute.math.exp2( + acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True + ) + acc_S_row_slice.store(acc_S_row_exp) + + if cutlass.const_expr(is_first): + acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch) + row_scale[r] = 1.0 + else: + row_scale[r] = cute.math.exp2( + (row_max_prev - row_max_cur) * scale_log2, fastmath=True + ) + acc_S_row_sum = utils.fadd_reduce( + acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch + ) + + row_sum[r] = acc_S_row_sum + + return row_scale + + @cute.jit + def finalize( + self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None + ) -> cute.Tensor: + """Finalize the online softmax by computing the scale and logsumexp. + + On SM100+ with an even ``num_rows`` and no sink_val, the loop is + unrolled in pairs so the key per-row arithmetic ― rcp*final_scale, + max*scale_log2 + log2(sum), and the final *LN2 ― collapses into one + ``mul_packed_f32x2`` + one ``fma_packed_f32x2`` + one more + ``mul_packed_f32x2`` per row pair. Sink_val path stays scalar (rare). + """ + if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)): + assert cute.size(sink_val) == cute.size(self.row_sum) + row_sum = self.row_sum + row_max = self.row_max + scale_log2 = self.scale_log2 + + row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4)) + row_scale = cute.make_rmem_tensor_like(row_max, Float32) + + LN2 = math.log(2.0) + num_rows = cute.size(row_sum) + use_packed = cutlass.const_expr( + self.arch >= 100 and num_rows % 2 == 0 and sink_val is None + ) + + if use_packed: + for r in cutlass.range(0, num_rows, 2, unroll_full=True): + s0 = row_sum[r] + s1 = row_sum[r + 1] + m0 = row_max[r] + m1 = row_max[r + 1] + bad0 = s0 == 0.0 or s0 != s0 + bad1 = s1 == 0.0 or s1 != s1 + + # row_scale = rcp_approx(safe_sum) * final_scale — rcp is scalar + # (no packed rcp intrinsic); the trailing multiply packs. + rcp0 = cute.arch.rcp_approx(1.0 if bad0 else s0) + rcp1 = cute.arch.rcp_approx(1.0 if bad1 else s1) + row_scale[r], row_scale[r + 1] = cute.arch.mul_packed_f32x2( + (rcp0, rcp1), (final_scale, final_scale) + ) + + # LSE = (row_max * scale_log2 + log2(row_sum)) * LN2 + # packed FMA for (max*scale_log2 + log2_sum), packed MUL for *LN2. + log0 = cute.math.log2(s0, fastmath=True) + log1 = cute.math.log2(s1, fastmath=True) + lse_pre_0, lse_pre_1 = cute.arch.fma_packed_f32x2( + (m0, m1), (scale_log2, scale_log2), (log0, log1) + ) + lse_0, lse_1 = cute.arch.mul_packed_f32x2( + (lse_pre_0, lse_pre_1), (LN2, LN2) + ) + row_sum[r] = -Float32.inf if bad0 else lse_0 + row_sum[r + 1] = -Float32.inf if bad1 else lse_1 + else: + for r in cutlass.range(num_rows, unroll_full=True): + if cutlass.const_expr(sink_val is not None): + sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r] + LOG2_E = math.log2(math.e) + row_sum[r] += cute.math.exp2( + sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True + ) + + acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r] + row_scale[r] = ( + cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0) + ) * final_scale + row_sum_cur = row_sum[r] + row_sum[r] = ( + (row_max[r] * scale_log2 + cute.math.log2(row_sum_cur, fastmath=True)) * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + return row_scale + + @cute.jit + def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: + """Scale each row of acc_O by the given scale tensor.""" + acc_O_mn = layout_utils.reshape_acc_to_mn(acc_O) + assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0]) + n = cute.size(acc_O_mn, mode=[1]) + if cutlass.const_expr(self.arch >= 100 and n % 2 == 0): + # SM100: pack adjacent pairs into fmul.f32x2 (2× CUDA-core throughput). + for r in cutlass.range(cute.size(row_scale), unroll_full=True): + scale = row_scale[r] + for j in cutlass.range(0, n, 2, unroll_full=True): + acc_O_mn[r, j], acc_O_mn[r, j + 1] = cute.arch.mul_packed_f32x2( + (acc_O_mn[r, j], acc_O_mn[r, j + 1]), (scale, scale) + ) + else: + for r in cutlass.range(cute.size(row_scale), unroll_full=True): + acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r]) + + +@dataclass +class SoftmaxSm100(Softmax): + """SM100-specific softmax: single-row, explicit f32x2 pack for FMA/exp2 paths.""" + + rescale_threshold: cutlass.Constexpr[float] = 0.0 + + @staticmethod + def create( + scale_log2: Float32, + rescale_threshold: cutlass.Constexpr[float] = 0.0, + softmax_scale: Float32 | None = None, + ): + num_rows = 1 + arch = 100 + row_max = cute.make_rmem_tensor(num_rows, Float32) + row_sum = cute.make_rmem_tensor(num_rows, Float32) + return SoftmaxSm100( + scale_log2, + num_rows, + row_max, + row_sum, + arch, + softmax_scale, + rescale_threshold=rescale_threshold, + ) + + @cute.jit + def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + acc_scale = cute.math.exp2(acc_scale_, fastmath=True) + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale = 1.0 + self.row_max[0] = row_max_new + return row_max_safe, acc_scale + + @cute.jit + def update_row_max_deferred_exp2( + self, + acc_S_row: cute.TensorSSA, + is_first: int, + ) -> Tuple[Float32, Float32]: + """update_row_max variant that publishes the log2-delta (un-exp2'd) so + the consumer can do the exp2 only when an actual rescale fires. + + Returns ``(row_max_safe, acc_scale_log2_or_zero)`` where: + - ``row_max_safe`` is the same row-max as ``update_row_max`` (with + ``rescale_threshold`` rollback applied). + - ``acc_scale_log2_or_zero`` is ``0.0`` for the first iteration or when + the threshold rollback fired (consumer treats as no rescale), else + the raw log2-domain value ``(row_max_old - row_max_safe)*scale_log2`` + (consumer computes ``cute.math.exp2`` and rescales). + + This keeps MUFU.EX2 off the sm_stats publication critical path that + gates the correction WG's consumer wait. + """ + publish = Float32(0.0) + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + # publish stays 0.0 (signal: no rescale needed) + else: + publish = acc_scale_ + else: + publish = acc_scale_ + self.row_max[0] = row_max_new + return row_max_safe, publish + + @cute.jit + def update_row_max_only(self, acc_S_row: cute.TensorSSA, is_first: int) -> None: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + else: + row_max_new = self._compute_row_max(acc_S_row, init_val=self.row_max[0]) + self.row_max[0] = row_max_new + + def update_row_sum( + self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False + ) -> None: + init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None + self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val) + + @cute.jit + def compute_scaled_exp2_row_sum( + self, + acc_S_row: cute.Tensor, + scale: Float32, + ) -> Float32: + return utils.fadd_exp2_scaled_reduce(acc_S_row, scale, arch=self.arch) + + @cute.jit + def scale_subtract_rowmax( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + row_max_scaled = row_max * self.scale_log2 + for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (-row_max_scaled, -row_max_scaled), + ) + + @cute.jit + def apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + acc_S_row_converted: cute.Tensor, + ex2_emu_freq: cutlass.Constexpr[int] = 0, + ex2_emu_res: cutlass.Constexpr[int] = 4, + ex2_emu_start_frg: cutlass.Constexpr[int] = 0, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + if cutlass.const_expr(ex2_emu_freq == 0): + acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) + else: + if cutlass.const_expr( + k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res + or j >= frg_cnt - 1 + or j < ex2_emu_start_frg + ): + acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2( + acc_S_row_frg[k + 1, j], fastmath=True + ) + else: + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2( + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] + ) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + @cute.jit + def scale_apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + acc_S_row_converted: cute.Tensor, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + minus_row_max_scaled = -row_max * self.scale_log2 + for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + @cute.jit + def scale_apply_exp2_convert_sum( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + acc_S_row_converted: cute.Tensor, + init_sum: Float32, + ex2_emu_freq: cutlass.Constexpr[int] = 0, + ex2_emu_res: cutlass.Constexpr[int] = 4, + ex2_emu_start_frg: cutlass.Constexpr[int] = 0, + ) -> Float32: + # When ex2_emu_freq > 0, the (k % ex2_emu_freq) >= ex2_emu_freq - ex2_emu_res + # pairs in the inner loop use the FFMA2-based polynomial ex2 emulation + # (ex2_emulation_2) instead of MUFU exp2 — mirrors prefill's + # apply_exp2_convert. This removes the MUFU "wait" stall that dominates + # the second-largest stall bucket in decode (~22% of total). + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + minus_row_max_scaled = -row_max * self.scale_log2 + acc_sum = (init_sum, Float32(0.0)) + + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = cute.arch.fma_packed_f32x2( + (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + if cutlass.const_expr(ex2_emu_freq == 0): + acc_S_row_frg[k, j] = cute.math.exp2( + acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2( + acc_S_row_frg[k + 1, j], fastmath=True) + else: + use_real = cutlass.const_expr( + k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res + or j >= frg_cnt - 1 + or j < ex2_emu_start_frg + ) + if cutlass.const_expr(use_real): + acc_S_row_frg[k, j] = cute.math.exp2( + acc_S_row_frg[k, j], fastmath=True) + acc_S_row_frg[k + 1, j] = cute.math.exp2( + acc_S_row_frg[k + 1, j], fastmath=True) + else: + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( + utils.ex2_emulation_2( + acc_S_row_frg[k, j], + acc_S_row_frg[k + 1, j], + ) + ) + acc_sum = cute.arch.add_packed_f32x2( + acc_sum, + (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + ) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + return acc_sum[0] + acc_sum[1] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/tile_scheduler.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/tile_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..985b4289e146288355dfecd7169383eb64df4f09 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/tile_scheduler.py @@ -0,0 +1,967 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +from enum import IntEnum, auto +from typing import Optional, Tuple, Protocol, runtime_checkable +from dataclasses import dataclass + +try: + from typing import override +except ImportError: # Python < 3.12 + from typing_extensions import override + +import cutlass +from cutlass.pipeline import PipelineClcFetchAsync, PipelineState +from cutlass._mlir import ir +import cutlass.cute as cute +from cutlass import Int32, const_expr +from cutlass.cute import FastDivmodDivisor +from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams + +from ...quack.cute_dsl_utils import ParamsBase + +from ...src.common import utils as utils +from ...src.common.fast_math import clz + + +class SchedulingMode(IntEnum): + NONE = auto() + STATIC = auto() + DYNAMIC = auto() + CLC = auto() + + +@dataclass +class ClcState(ParamsBase): + """Owns the runtime state shared by CLC-capable tile schedulers. + + `SparseAttentionForwardSm100` constructs this state because it owns the CLC + response buffer, mbarrier storage, and launch geometry needed to initialize + the hardware scheduler and async pipeline. Individual tile schedulers then + consume this state and map the returned hardware work tiles into their own + logical `WorkTileInfo` coordinates. + + To add CLC support to a scheduler: + - implement `clc_problem_shape(params)` so the kernel can create the hardware scheduler + - accept `clc: ClcState | None` in `create(...)` / `__init__` + - map `clc.initial_work_tile_info()` and `clc.get_current_work()` into scheduler coordinates + """ + + _hw_scheduler: ClcDynamicPersistentTileScheduler + _pipeline: PipelineClcFetchAsync + _consumer_state: PipelineState + _producer_state: PipelineState + + @staticmethod + def create( + *, + hw_scheduler: ClcDynamicPersistentTileScheduler, + pipeline: PipelineClcFetchAsync, + consumer_state: PipelineState, + producer_state: PipelineState, + ) -> "ClcState": + return ClcState(hw_scheduler, pipeline, consumer_state, producer_state) + + def initial_work_tile_info(self): + return self._hw_scheduler.initial_work_tile_info() + + def get_current_work(self): + return self._hw_scheduler.get_current_work() + + def prefetch_next_work(self, *, loc=None, ip=None): + self._pipeline.producer_acquire(self._producer_state, loc=loc, ip=ip) + mbarrier_addr = self._pipeline.producer_get_barrier(self._producer_state, loc=loc, ip=ip) + self._hw_scheduler.advance_to_next_work(mbarrier_addr, loc=loc, ip=ip) + self._producer_state.advance(loc=loc, ip=ip) + + def consumer_wait(self, *, loc=None, ip=None): + self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip) + + def consumer_release(self, *, loc=None, ip=None): + self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip) + self._consumer_state.advance(loc=loc, ip=ip) + + def producer_tail(self, *, loc=None, ip=None): + self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip) + + +class WorkTileInfo(cutlass.utils.WorkTileInfo): + """Altered WorkTileInfo which includes four axes: (block, head, batch, split)""" + + @override + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": + assert len(values) == 5 + new_tile_idx = cutlass.new_from_mlir_values(self._tile_idx, values[:-1]) + new_is_valid_tile = cutlass.new_from_mlir_values(self._is_valid_tile, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + +@runtime_checkable +class TileSchedulerProtocol(Protocol): + """Protocol defining the interface all tile schedulers must implement. + + Schedulers are responsible for: + 1. Coordinate mapping: linear tile index -> (m_block, head, batch, split) + 2. Work distribution: how to get the next tile (static grid-stride vs CLC dynamic) + """ + + def get_current_work(self) -> WorkTileInfo: + """Get the current work tile coordinates.""" + ... + + def initial_work_tile_info(self) -> WorkTileInfo: + """Get the initial work tile for this CTA.""" + ... + + def advance_to_next_work(self, *, loc=None, ip=None): + """Consumer-side advance: move to next tile and return it. + + For static schedulers: grid-stride increment + get_current_work. + For CLC schedulers: consumer wait + get_current_work + consumer release + state advance. + """ + ... + + def prefetch_next_work(self, *, loc=None, ip=None) -> None: + """Producer-side prefetch of next work tile (no-op for static schedulers). + + For CLC schedulers: producer acquire + issue CLC query + producer state advance. + Only called by the scheduler warp. + """ + ... + + def producer_tail(self, *, loc=None, ip=None) -> None: + """Producer-side cleanup after the last tile. + + No-op for static schedulers. For CLC schedulers: pipeline producer_tail. + """ + ... + + +@dataclass +class TileSchedulerArguments(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + seqlen_k: Int32 + headdim: Int32 + headdim_v: Int32 + total_q: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + element_size: cutlass.Constexpr[int] = 2 + is_persistent: cutlass.Constexpr[bool] = False + lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False + use_cluster_idx: cutlass.Constexpr[bool] = False + + +class SingleTileScheduler: + @dataclass + class Params(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + num_splits_divmod: FastDivmodDivisor + is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + use_cluster_idx: cutlass.Constexpr[bool] = False + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileScheduler.Params": + return SingleTileScheduler.Params( + args.num_block, + args.num_head, + args.num_batch, + args.num_splits, + FastDivmodDivisor(args.num_splits), + args.is_split_kv, + args.cluster_shape_mn, + args.use_cluster_idx, + ) + + def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): + self.params = params + self._blk_coord = blk_coord + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"SingleTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) + return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileScheduler": + if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx): + blk_coord = cute.arch.block_idx() + else: + blk_coord = cute.arch.cluster_idx() + return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1) + assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + if const_expr(params.use_cluster_idx): + # Grid must have num_block * cluster_m physical blocks so that there are num_block clusters + grid_x = params.num_block * params.cluster_shape_mn[0] + else: + grid_x = cute.round_up(params.num_block, params.cluster_shape_mn[0]) + return ( + grid_x, + params.num_head * params.num_splits, + params.num_batch, + ) + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + block_idx, head_idx, batch_idx = self._blk_coord + if const_expr(self.params.is_split_kv): + head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod) + else: + split_idx = Int32(0) + return WorkTileInfo( + (block_idx, head_idx, batch_idx, split_idx), + self._is_first_block, + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._blk_coord]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._blk_coord], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class StaticPersistentTileScheduler: + @dataclass + class Params(ParamsBase): + num_block_cluster_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor + total_blocks_cluster: Int32 + cluster_shape_m: cutlass.Constexpr[int] = 1 + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler.Params": + num_block_cluster = cute.ceil_div(args.num_block, cute.size(args.cluster_shape_mn)) + total_blocks_cluster = num_block_cluster * args.num_head * args.num_batch + return StaticPersistentTileScheduler.Params( + FastDivmodDivisor(num_block_cluster), + FastDivmodDivisor(args.num_head), + total_blocks_cluster, + cluster_shape_m=args.cluster_shape_mn[0], + ) + + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"StaticPersistentTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) + return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler": + if const_expr(cute.size(params.cluster_shape_m) == 1): + tile_idx = cute.arch.block_idx()[0] + else: + tile_idx = cute.arch.cluster_idx()[0] + return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + usable_SM_count=0, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + hardware_info = cutlass.utils.HardwareInfo() + cluster_shape_m = int(params.cluster_shape_m) + if usable_SM_count > 0: + sm_count = usable_SM_count + else: + sm_count = hardware_info.get_device_multiprocessor_count() + max_ctas = (sm_count // cluster_shape_m) * cluster_shape_m + max_ctas = max(max_ctas, cluster_shape_m) + grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * cluster_shape_m) + return (grid_x, Int32(1), Int32(1)) + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod) + batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) + is_valid = self._tile_idx < self.params.total_blocks_cluster + return WorkTileInfo( + (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.cluster_shape_m == 1): + self._tile_idx += cute.arch.grid_dim()[0] + else: + self._tile_idx += cute.arch.cluster_dim()[0] + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.params, self._tile_idx], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileLPTScheduler: + @dataclass + class Params(ParamsBase): + total_blocks: Int32 + num_splits: Int32 + num_block: Int32 + num_head: Int32 + num_batch: Int32 + l2_minor: Int32 + num_head_divmod: FastDivmodDivisor + l2_minor_divmod: FastDivmodDivisor + l2_major_divmod: FastDivmodDivisor + l2_minor_residual_divmod: FastDivmodDivisor + num_hb_quotient: Int32 + num_splits_divmod: FastDivmodDivisor + is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + lpt: cutlass.Constexpr[bool] = True + use_cluster_idx: cutlass.Constexpr[bool] = True + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> "SingleTileLPTScheduler.Params": + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) + size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size + size_one_head = size_one_kv_head + size_l2 = 50 * 1024 * 1024 # 40 MB for K & V + # Swizzle is the size of each "section". Round swizzle to a power of 2 + # Need to be careful about the case where only one head will fit + # swizzle is how many heads can fit in L2 + # Seems faster if swizzle is a power of 2 + log2_floor = lambda n: 31 - clz(n) + swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + num_hb_quotient = (args.num_head * args.num_batch) // swizzle + num_hb_remainder = (args.num_head * args.num_batch) % swizzle + return SingleTileLPTScheduler.Params( + total_blocks=args.num_block * args.num_head * args.num_batch, + num_block=args.num_block, + num_head=args.num_head, + num_batch=args.num_batch, + l2_minor=Int32(swizzle), + num_head_divmod=FastDivmodDivisor(args.num_head), + l2_minor_divmod=FastDivmodDivisor(swizzle), + l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block), + l2_minor_residual_divmod=FastDivmodDivisor(max(num_hb_remainder, 1)), + num_hb_quotient=Int32(num_hb_quotient), + num_splits=args.num_splits, + num_splits_divmod=FastDivmodDivisor(args.num_splits), + is_split_kv=args.is_split_kv, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + lpt=args.lpt, + use_cluster_idx=args.use_cluster_idx, + ) + + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): + self.params = params + self._tile_idx = tile_idx + self._split_idx = split_idx + self.clc = clc + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileLPTScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) + + @staticmethod + def _clc_grid_shape(params: Params): + num_batch_splits = ( + params.num_batch * params.num_splits + if const_expr(params.is_split_kv) + else params.num_batch + ) + return ( + cute.round_up(params.num_block, params.cluster_shape_m), + params.num_head, + num_batch_splits, + ) + + @staticmethod + @cute.jit + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileLPTScheduler._clc_grid_shape(params), + cluster_shape_mnk=(params.cluster_shape_m, 1, 1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileLPTScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler( + params, cute.arch.block_idx()[0], Int32(0), clc, loc=loc, ip=ip + ) + tile_idx, split_idx, _ = cute.arch.block_idx() + return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler._clc_grid_shape(params) + return (params.total_blocks, params.num_splits, Int32(1)) + + @cute.jit + def clc_work_to_coords(self, work) -> WorkTileInfo: + """Convert CLC response (block, head, batch_split) to WorkTileInfo. + + CLC returns raw grid coordinates — no L2 swizzle (hardware decides order). + We only apply cluster division, optional LPT block reversal, and split_kv unpacking. + """ + block_idx = work.tile_idx[0] + if const_expr(self.params.cluster_shape_m > 1): + block_idx = block_idx // self.params.cluster_shape_m + if const_expr(self.params.lpt): + # Longest-processing-time-first: reverse block order + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + num_block = self.params.num_block // self.params.cluster_shape_m + else: + num_block = self.params.num_block + block_idx = num_block - 1 - block_idx + split_idx = Int32(0) + if const_expr(self.params.is_split_kv): + batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod) + else: + batch_idx = work.tile_idx[2] + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block_idx = block_idx * self.params.cluster_shape_m + bidx_in_cluster[0] + return WorkTileInfo( + (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)), + work.is_valid_tile, + ) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.get_current_work() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) + # Static path: L2-swizzled coordinate mapping + params = self.params + # Implement LPT scheduling coordinate calculation + bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + block, bidhb_residual = 0, 0 + if bidhb < params.num_hb_quotient: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod) + else: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod) + bidhb_actual = bidhb * params.l2_minor + bidhb_residual + batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) + # Longest-processing-time-first + if const_expr(params.lpt): + block = params.num_block - 1 - block + is_valid = self._tile_idx < params.total_blocks + return WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid + ) + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.initial_work_tile_info() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._tile_idx = self.params.total_blocks + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*obj_list, loc=self._loc) + + +class SingleTileVarlenScheduler: + @dataclass + class Params(ParamsBase): + num_head: Int32 + num_batch: Int32 + total_q: Int32 + num_splits: Int32 + max_kvblock_in_l2: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> "SingleTileVarlenScheduler.Params": + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + kv_block_size = ( + (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + ) + if args.head_swizzle: + kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] + max_kvblock_in_l2 = size_l2 // kv_block_size + assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( + "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" + ) + assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + # TODO: Support varlen CLC with cluster_shape_m > 1 by refactoring the + # flattened-tile decode so cluster unpacking semantics are explicit. + assert scheduling_mode != SchedulingMode.CLC or args.cluster_shape_mn[0] == 1, ( + "Varlen CLC currently requires cluster_shape_mn[0] == 1" + ) + return SingleTileVarlenScheduler.Params( + num_head=args.num_head, + num_batch=args.num_batch, + total_q=args.total_q, + num_splits=args.num_splits, + max_kvblock_in_l2=max_kvblock_in_l2, + tile_shape_mn=args.tile_shape_mn, + mCuSeqlensQ=args.mCuSeqlensQ, + mSeqUsedQ=args.mSeqUsedQ, + qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, + lpt=args.lpt, + is_split_kv=args.is_split_kv, + head_swizzle=args.head_swizzle, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + ) + + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): + self.params = params + self._tile_idx = tile_idx + self._split_idx = split_idx + self._is_first_block = True + self.clc = clc + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileVarlenScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) + + @staticmethod + @cute.jit + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileVarlenScheduler.get_grid_shape(params), + cluster_shape_mnk=(1, 1, 1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileVarlenScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + block_idx = cute.arch.block_idx() + split_idx = Int32(0) + if const_expr(params.is_split_kv): + split_idx = block_idx[1] + return SingleTileVarlenScheduler( + params, + block_idx[0], + split_idx, + clc, + loc=loc, + ip=ip, + ) + tile_idx, split_idx, _ = cute.arch.block_idx() + return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + total_blocks_max = ( + params.total_q + + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1) + ) // params.tile_shape_mn[0] + # Round down to nearest multiple of cluster since odd excess is always padding. + total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m + return (total_blocks_max * params.num_head, params.num_splits, Int32(1)) + + @cute.jit + def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + params = self.params + batch_idx = lane + bidb_start + if cutlass.const_expr(params.mSeqUsedQ is not None): + seqlen = Int32(0) + if batch_idx < params.num_batch: + seqlen = params.mSeqUsedQ[batch_idx] + else: + assert params.mCuSeqlensQ is not None + cur_cu_seqlen = Int32(0) + if batch_idx <= params.num_batch: + cur_cu_seqlen = params.mCuSeqlensQ[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): + seqlen *= params.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), params.cluster_shape_m) + if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 + else Int32(0) + ) + + @cute.jit + def _varlen_coord_map(self) -> WorkTileInfo: + """Map self._tile_idx to (block, head, batch) via warp-level prefix sums.""" + params = self.params + lane_idx = cute.arch.lane_idx() + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + # Total number of blocks for the next 31 batches + m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1) + # Same for all lanes + group_end_tile = m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group) + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0) + next_tile_idx = self._tile_idx // params.cluster_shape_m + while group_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= params.num_batch: + batch_idx = Int32(params.num_batch) + group_end_tile = next_tile_idx + 1 + else: + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync( + num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 + ) + group_end_tile += m_blocks_in_group * params.num_head + is_valid = False + if batch_idx >= params.num_batch: + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch) + else: + group_start_tile = group_end_tile - m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx) + # The next problem to process is the first one that does not have ending tile position + # that is greater than or equal to tile index. + batch_idx_in_group = cute.arch.popc( + cute.arch.vote_ballot_sync( + group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx + ) + ) + batch_idx += batch_idx_in_group + num_m_blocks_prev_lane = ( + 0 + if batch_idx_in_group == 0 + else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1) + ) + num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) + mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head + if cutlass.const_expr(params.lpt or params.head_swizzle): + # This is a version of the SingleTileLPTScheduler, complicated by the fact that + # the seqlen can vary per batch. + # TODO: is there any case where num_m_blocks is 0? + # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here + num_n_blocks = ( + num_m_blocks + * params.tile_shape_mn[0] + * params.cluster_shape_m + // params.qhead_per_kvhead_packgqa + // params.tile_shape_mn[1] + ) + # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) + # Seems faster to have this be a power of 2 + nheads_in_l2 = ( + 16 + if num_n_blocks * 16 <= params.max_kvblock_in_l2 + else ( + 8 + if num_n_blocks * 8 <= params.max_kvblock_in_l2 + else ( + 4 + if num_n_blocks * 4 <= params.max_kvblock_in_l2 + else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1) + ) + ) + ) + nheads_in_l2 = min(nheads_in_l2, params.num_head) + mh_in_l2 = nheads_in_l2 * num_m_blocks + section_idx = mh_block // mh_in_l2 + l2_mod = mh_block - section_idx * mh_in_l2 + # Deal with tail section + nheads_in_this_section = ( + nheads_in_l2 + if nheads_in_l2 * (section_idx + 1) <= params.num_head + else params.num_head - section_idx * nheads_in_l2 + ) + block = l2_mod // nheads_in_this_section + head_idx_residual = l2_mod - block * nheads_in_this_section + head_idx = section_idx * nheads_in_l2 + head_idx_residual + if cutlass.const_expr(params.lpt): + block = num_m_blocks - 1 - block + else: + head_idx = mh_block // num_m_blocks + block = mh_block - head_idx * num_m_blocks + is_valid = self._is_first_block and batch_idx < params.num_batch + if cutlass.const_expr(params.cluster_shape_m > 1): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block = block * params.cluster_shape_m + bidx_in_cluster[0] + # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) + split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) + return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.get_current_work() + # Default to grid_dim (one past last valid flat index) so _varlen_coord_map + # returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when + # invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural + # mismatch on self inside the runtime if. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.initial_work_tile_info() + # See get_current_work for why grid_dim and local-then-assign. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() + + def prefetch_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work + self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*obj_list, loc=self._loc) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/tma_utils.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/tma_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5bdc19a08eacf9a060f2c0a7a4d50a4adb735094 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/tma_utils.py @@ -0,0 +1,515 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Raw TMA ops and descriptor builders. + +`tma_utils.py` is the canonical owner for raw TMA inline-asm helpers and TMA +descriptor construction. Non-TMA store/layout helpers are re-exported from +`copy_utils.py` for backward compatibility. +""" + +import ctypes + +from cutlass import Int32, Int64 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm +import cutlass._mlir.dialects.cute as cute_ir +import cutlass._mlir.dialects.cute_nvgpu as cute_nvgpu_ir +from cutlass._mlir.dialects import _cute_nvgpu_ops_gen as cute_nvgpu_gen + + +# Raw TMA Ops + +TMA_CACHE_EVICT_FIRST = 0x12F0000000000000 +TMA_CACHE_EVICT_LAST = 0x14F0000000000000 + + +@dsl_user_op +def tma_tile_load( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row_idx, + mbar_ptr, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile with mbar completion.""" + llvm.inline_asm( + T.i32(), + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $1;\n" + "add.u32 sa, sa, $2;\n" + "cvt.u32.u64 ma, $6;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile" + ".mbarrier::complete_tx::bytes [sa], [$3, {$4, $5}], [ma];\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_gather4( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row0, + row1, + row2, + row3, + mbar_ptr, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4 with mbar.""" + llvm.inline_asm( + T.i32(), + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row0).ir_value(loc=loc, ip=ip), + Int32(row1).ir_value(loc=loc, ip=ip), + Int32(row2).ir_value(loc=loc, ip=ip), + Int32(row3).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $1;\n" + "add.u32 sa, sa, $2;\n" + "cvt.u32.u64 ma, $9;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4" + ".mbarrier::complete_tx::bytes [sa], [$3, {$4, $5, $6, $7, $8}], [ma];\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,l,r,r,r,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def prefetch_tma_desc_raw(tma_desc_ptr, *, loc=None, ip=None): + """Prefetch a raw TMA descriptor pointer into the descriptor cache.""" + ptr_i64 = tma_desc_ptr.toint().ir_value(loc=loc, ip=ip) + ptr_i64_align_ty = cute_ir.ConstrainedIntType.get(128, ptr_i64.type.width) + ptr_i64_align = cute_ir.assume(ptr_i64_align_ty, ptr_i64, loc=loc, ip=ip) + ptr_ty = cute_ir.PtrType.get( + cute_nvgpu_ir.TmaDescriptorTiledType.get(), + cute_ir.AddressSpace.gmem, + 128, + ) + desc_ptr = cute_ir.inttoptr(ptr_ty, ptr_i64_align, loc=loc, ip=ip) + cute_nvgpu_gen.arch_prefetch_tma_desc(desc_ptr.value, loc=loc, ip=ip) + + +@dsl_user_op +def tma_tile_prefetch( + tma_desc_ptr, + col_idx, + row_idx, + cache_hint=TMA_CACHE_EVICT_FIRST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.prefetch.tensor.2d.L2.global.tile with cache hint.""" + llvm.inline_asm( + None, + [ + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "cp.async.bulk.prefetch.tensor.2d.L2.global.tile.L2::cache_hint " + "[$0, {$1, $2}], $3;\n", + "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_gather4_prefetch( + tma_desc_ptr, + col_idx, + row0, + row1, + row2, + row3, + cache_hint=TMA_CACHE_EVICT_LAST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4 with cache hint.""" + llvm.inline_asm( + None, + [ + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row0).ir_value(loc=loc, ip=ip), + Int32(row1).ir_value(loc=loc, ip=ip), + Int32(row2).ir_value(loc=loc, ip=ip), + Int32(row3).ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4.L2::cache_hint " + "[$0, {$1, $2, $3, $4, $5}], $6;\n", + "l,r,r,r,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_tile_load_cached( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row_idx, + mbar_ptr, + cache_hint=TMA_CACHE_EVICT_FIRST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile with cache hint and mbar.""" + llvm.inline_asm( + T.i32(), + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $1;\n" + "add.u32 sa, sa, $2;\n" + "cvt.u32.u64 ma, $6;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile" + ".mbarrier::complete_tx::bytes.L2::cache_hint " + "[sa], [$3, {$4, $5}], [ma], $7;\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,l,r,r,l,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_gather4_cached( + smem_ptr, + smem_byte_offset, + tma_desc_ptr, + col_idx, + row0, + row1, + row2, + row3, + mbar_ptr, + cache_hint=TMA_CACHE_EVICT_LAST, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4 with cache hint.""" + llvm.inline_asm( + None, + [ + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row0).ir_value(loc=loc, ip=ip), + Int32(row1).ir_value(loc=loc, ip=ip), + Int32(row2).ir_value(loc=loc, ip=ip), + Int32(row3).ir_value(loc=loc, ip=ip), + mbar_ptr.toint().ir_value(loc=loc, ip=ip), + Int64(cache_hint).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa, ma;\n" + "cvt.u32.u64 sa, $0;\n" + "add.u32 sa, sa, $1;\n" + "cvt.u32.u64 ma, $8;\n" + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4" + ".mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint " + "[sa], [$2, {$3, $4, $5, $6, $7}], [ma], $9;\n" + "}\n", + "l,r,l,r,r,r,r,r,l,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_tile_store( + tma_desc_ptr, + col_idx, + row_idx, + smem_ptr, + smem_byte_offset, + *, + loc=None, + ip=None, +): + """cp.async.bulk.tensor.2d.global.shared::cta.bulk_group store.""" + llvm.inline_asm( + T.i32(), + [ + tma_desc_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(col_idx).ir_value(loc=loc, ip=ip), + Int32(row_idx).ir_value(loc=loc, ip=ip), + smem_ptr.toint().ir_value(loc=loc, ip=ip), + Int32(smem_byte_offset).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .u32 sa;\n" + "cvt.u32.u64 sa, $4;\n" + "add.u32 sa, sa, $5;\n" + "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group" + " [$1, {$2, $3}], [sa];\n" + "mov.u32 $0, 0;\n" + "}\n", + "=r,l,r,r,l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +# Descriptor Builders + +_TMA_DESC_BYTES = 128 + + +def _encode_tma_desc_2d_bytes(tensor_2d, *, box_x, box_y, context: str) -> bytes: + import torch + import cuda.bindings.driver as cuda + + if tensor_2d.ndim != 2: + raise ValueError(f"{context} tensor must be rank-2, got {tuple(tensor_2d.shape)}") + rows, cols = tensor_2d.shape + if tensor_2d.stride(-1) != 1: + raise ValueError(f"{context} tensor must be contiguous in the last dimension") + dtype_map = { + torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, + torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + } + if tensor_2d.dtype not in dtype_map: + raise TypeError(f"Unsupported dtype for {context} TMA descriptor: {tensor_2d.dtype}") + + sizes = [cuda.cuuint64_t(cols), cuda.cuuint64_t(rows)] + strides = [cuda.cuuint64_t(tensor_2d.stride(0) * tensor_2d.element_size())] + box = [cuda.cuuint32_t(box_x), cuda.cuuint32_t(box_y)] + elem_stride = [cuda.cuuint32_t(1), cuda.cuuint32_t(1)] + err, tm = cuda.cuTensorMapEncodeTiled( + dtype_map[tensor_2d.dtype], + 2, + tensor_2d.data_ptr(), + sizes, + strides, + box, + elem_stride, + cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE, + cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, + cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, + ) + assert err == cuda.CUresult.CUDA_SUCCESS, f"TMA encode failed: {err}" + buf = (ctypes.c_uint8 * _TMA_DESC_BYTES).from_address(tm.getPtr()) + return bytes(buf) + + +def _desc_bytes_to_device_tensor(desc_bytes: bytes | bytearray, *, device): + import torch + + desc_bytes = bytes(desc_bytes) + device = torch.device(device) + if device.type != "cuda": + raise ValueError(f"TMA descriptors require a CUDA device, got {device}") + + host_desc = torch.empty((len(desc_bytes),), dtype=torch.uint8, pin_memory=True) + host_desc.copy_(torch.frombuffer(bytearray(desc_bytes), dtype=torch.uint8)) + device_desc = torch.empty((len(desc_bytes),), dtype=torch.uint8, device=device) + stream = torch.cuda.current_stream(device) + with torch.cuda.stream(stream): + device_desc.copy_(host_desc, non_blocking=True) + device_desc.record_stream(stream) + # Keep the staging buffer alive for the async copy without caching descriptors. + device_desc._tma_host_desc = host_desc + return device_desc + + +def create_flat_gather4_tma_desc(tensor_2d, box_x=64): + """Create a gather4 CUtensorMap descriptor for a flat 2D row-major tensor.""" + if tensor_2d.ndim != 2: + raise ValueError( + f"tensor_2d must be rank-2 [rows, dim], got {tuple(tensor_2d.shape)}" + ) + desc = _encode_tma_desc_2d_bytes( + tensor_2d, + box_x=box_x, + box_y=1, + context="gather4", + ) + return _desc_bytes_to_device_tensor(desc, device=tensor_2d.device) + + +def create_q_gather4_tma_desc(q_flat, box_x=64): + return create_flat_gather4_tma_desc(q_flat, box_x=box_x) + + +def create_strided_2d_tma_desc(tensor_2d, *, box_x, box_y): + """Create a CUtensorMap descriptor for a rank-2 tensor with arbitrary row stride.""" + desc = _encode_tma_desc_2d_bytes( + tensor_2d, + box_x=box_x, + box_y=box_y, + context="strided 2D", + ) + return _desc_bytes_to_device_tensor(desc, device=tensor_2d.device) + + +def create_flat_kv_tma_descs(kv_flat, *, box_x=64, box_y=128): + """Create per-KV-head token-major TMA descriptors for flat [total_k, H, D] storage.""" + import torch + + if kv_flat.ndim != 3: + raise ValueError( + f"kv_flat must be rank-3 [total_k, H, D], got {tuple(kv_flat.shape)}" + ) + total_k, head_kv, dim = kv_flat.shape + row_stride = head_kv * dim + desc_table = bytearray() + for h in range(head_kv): + head_view = torch.as_strided( + kv_flat, + size=(total_k, dim), + stride=(row_stride, 1), + storage_offset=h * dim, + ) + desc_table.extend( + _encode_tma_desc_2d_bytes( + head_view, + box_x=box_x, + box_y=box_y, + context="flat KV", + ) + ) + return _desc_bytes_to_device_tensor(desc_table, device=kv_flat.device).reshape( + head_kv, _TMA_DESC_BYTES + ) + + +# Compatibility Re-exports + +from .copy_utils import ( + atomic_add_broadcast_i32, + atomic_add_i32, + convert_layout_acc_mn, + convert_layout_from_tmem16x256b_to_acc_sm90, + make_16x256b_tensor_mn_view, + real_col_to_stg128_fake_col, + real_col_to_stg128_fp8_fake_col, + real_col_to_stg128_half_fake_col, + stg128_fake_col_to_real_col, + stg128_fp8_fake_col_to_real_col, + stg128_half_fake_col_to_real_col, + stg_128, + stg_128_cs, + stg_128_bf16, + stg_128_bf16_cs, + stg_128_f16, + stg_128_f16_cs, + stg_128_fp8_e4m3_cs, + stg_32_fp8_e4m3, + stg_64_bf16, + stg_64_f16, +) + + +__all__ = [ + "TMA_CACHE_EVICT_FIRST", + "TMA_CACHE_EVICT_LAST", + "atomic_add_broadcast_i32", + "atomic_add_i32", + "convert_layout_acc_mn", + "convert_layout_from_tmem16x256b_to_acc_sm90", + "create_flat_gather4_tma_desc", + "create_flat_kv_tma_descs", + "create_q_gather4_tma_desc", + "create_strided_2d_tma_desc", + "make_16x256b_tensor_mn_view", + "prefetch_tma_desc_raw", + "real_col_to_stg128_fake_col", + "real_col_to_stg128_fp8_fake_col", + "real_col_to_stg128_half_fake_col", + "stg128_fake_col_to_real_col", + "stg128_fp8_fake_col_to_real_col", + "stg128_half_fake_col_to_real_col", + "stg_128", + "stg_128_cs", + "stg_128_bf16", + "stg_128_bf16_cs", + "stg_128_f16", + "stg_128_f16_cs", + "stg_128_fp8_e4m3_cs", + "stg_32_fp8_e4m3", + "stg_64_bf16", + "stg_64_f16", + "tma_gather4", + "tma_gather4_cached", + "tma_gather4_prefetch", + "tma_tile_load", + "tma_tile_load_cached", + "tma_tile_prefetch", + "tma_tile_store", +] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/common/utils.py b/build/torch212-cxx11-cu132-x86_64-linux/src/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e1bd0ba76b532cb54c159eba5e82320266c80c63 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/common/utils.py @@ -0,0 +1,1088 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +import math +import hashlib +import inspect +from typing import Type, Callable, Optional, Tuple, overload + +import cutlass +import cutlass.cute as cute + +from cutlass import Float32, const_expr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import nvvm, llvm +from cutlass.cute.runtime import from_dlpack + + +from ...quack import activation +_MIXER_ATTRS = ("__vec_size__",) + +# Obtained from sollya: +# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative); +POLY_EX2 = { + 0: (1.0), + 1: ( + 1.0, + 0.922497093677520751953125, + ), + 2: ( + 1.0, + 0.6657850742340087890625, + 0.330107033252716064453125, + ), + 3: ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ), + 4: ( + 1.0, + 0.693042695522308349609375, + 0.2412912547588348388671875, + 5.2225358784198760986328125e-2, + 1.3434938155114650726318359375e-2, + ), + 5: ( + 1.0, + 0.693151414394378662109375, + 0.24016360938549041748046875, + 5.5802188813686370849609375e-2, + 9.01452265679836273193359375e-3, + 1.86810153536498546600341796875e-3, + ), +} + + +def _compute_base_hash(func: Callable) -> str: + """Compute hash from source code or bytecode and closure values.""" + try: + data = inspect.getsource(func).encode() + except (OSError, TypeError): + if hasattr(func, "__code__") and func.__code__ is not None: + data = func.__code__.co_code + else: + data = repr(func).encode() + + hasher = hashlib.sha256(data) + + if hasattr(func, "__closure__") and func.__closure__ is not None: + for cell in func.__closure__: + hasher.update(repr(cell.cell_contents).encode()) + + return hasher.hexdigest() + + +def hash_callable( + func: Callable, mixer_attrs: Tuple[str] = _MIXER_ATTRS, set_cute_hash: bool = True +) -> str: + """Hash a callable based on the source code or bytecode and closure values. + Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__`` + attribute, that value is returned immediately as the base hash, then + metadata dunders are mixed in to produce the final dict-key hash. + set_cute_hash: whether or not to set func.__cute_hash__ + """ + # Resolve base hash + if hasattr(func, "__cute_hash__"): + base_hash = func.__cute_hash__ + else: + # Unwrap decorated functions (e.g., cute.jit wrappers). + base_func = getattr(func, "__wrapped__", func) + + if hasattr(base_func, "__cute_hash__"): + base_hash = base_func.__cute_hash__ + else: + base_hash = _compute_base_hash(base_func) + + if set_cute_hash: + base_func.__cute_hash__ = base_hash + + # Mix in mutable metadata dunders + mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs) + + if all(v is None for v in mixer_values): + return base_hash + + hasher = hashlib.sha256(base_hash.encode()) + + for attr, val in zip(_MIXER_ATTRS, mixer_values): + hasher.update(f"{attr}={val!r}".encode()) + + return hasher.hexdigest() + + +LOG2_E = math.log2(math.e) + + +def compute_softmax_scale_log2(softmax_scale): + """Compute softmax_scale_log2 from softmax_scale. + + Returns (softmax_scale_log2, None). + """ + return softmax_scale * LOG2_E, None + + +def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: + return ( + from_dlpack(x, assumed_align=alignment) + .mark_layout_dynamic(leading_dim=leading_dim) + .mark_compact_shape_dynamic( + mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility + ) + ) + + +def make_tiled_copy_A( + copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.TiledCopy: + if const_expr(swapAB): + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + + +def make_tiled_copy_B( + copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.TiledCopy: + if const_expr(swapAB): + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + + +def mma_make_fragment_A( + smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.Tensor: + if const_expr(swapAB): + return mma_make_fragment_B(smem, thr_mma) + else: + return thr_mma.make_fragment_A(thr_mma.partition_A(smem)) + + +def mma_make_fragment_B( + smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.Tensor: + if const_expr(swapAB): + return mma_make_fragment_A(smem, thr_mma) + else: + return thr_mma.make_fragment_B(thr_mma.partition_B(smem)) + + +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=2 * element_type.width, + ) + else: + return cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +@cute.jit +def warp_reduce( + val: cute.TensorSSA | cute.Numeric, + op: Callable, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.TensorSSA | cute.Numeric: + if const_expr(isinstance(val, cute.TensorSSA)): + res = cute.make_rmem_tensor(val.shape, val.dtype) + res.store(val) + for i in cutlass.range_constexpr(cute.size(val.shape)): + res[i] = warp_reduce(res[i], op, width) + return res.load() + else: + for i in cutlass.range_constexpr(int(math.log2(width))): + val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i)) + return val + + +@dsl_user_op +def fmax( + a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None +) -> Float32: + from cutlass import CUDA_VERSION + + # * NVVM call based on nvvm version + if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9: + # Old API: requires explicit result type as first positional argument + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + else: + # New API: infers result type automatically + return Float32( + nvvm.fmax( + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + + +@cute.jit +def fmax_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + res = cute.make_rmem_tensor(x.shape, Float32) + res.store(x) + local_max = [res[0], res[1], res[2], res[3]] + for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + local_max[0] = fmax(local_max[0], res[i + 0]) + local_max[1] = fmax(local_max[1], res[i + 1]) + local_max[2] = fmax(local_max[2], res[i + 2]) + local_max[3] = fmax(local_max[3], res[i + 3]) + local_max[0] = fmax(local_max[0], local_max[1]) + local_max[2] = fmax(local_max[2], local_max[3]) + local_max[0] = fmax(local_max[0], local_max[2]) + return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val) + else: + res = cute.make_rmem_tensor(x.shape, Float32) + res.store(x) + local_max_0 = ( + fmax(init_val, res[0], res[1]) + if const_expr(init_val is not None) + else fmax(res[0], res[1]) + ) + local_max = [ + local_max_0, + fmax(res[2], res[3]), + fmax(res[4], res[5]), + fmax(res[6], res[7]), + ] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_max[0] = fmax(local_max[0], res[i], res[i + 1]) + local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3]) + local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5]) + local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7]) + local_max[0] = fmax(local_max[0], local_max[1]) + return fmax(local_max[0], local_max[2], local_max[3]) + + +@cute.jit +def fadd_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + if const_expr(init_val is None): + init_val = Float32.zero + return x.reduce(cute.ReductionOp.ADD, init_val, 0) + else: + res = cute.make_rmem_tensor(x.shape, Float32) + res.store(x) + local_sum_0 = ( + cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1])) + if const_expr(init_val is not None) + else (res[0], res[1]) + ) + local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1])) + local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3])) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5])) + local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7])) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + + +@cute.jit +def fadd_exp2_scaled_reduce( + x: cute.Tensor, scale: Float32, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + assert cute.size(x.shape) % 2 == 0, "x must have an even number of elements" + if const_expr(arch < 100): + return fadd_reduce(cute.math.exp2(x.load() * scale, fastmath=True), arch=arch) + elif const_expr(cute.size(x.shape) % 8 == 0): + local_sum = [ + (Float32(0.0), Float32(0.0)), + (Float32(0.0), Float32(0.0)), + (Float32(0.0), Float32(0.0)), + (Float32(0.0), Float32(0.0)), + ] + for i in cutlass.range_constexpr(0, cute.size(x.shape), 8): + acc0, acc1 = cute.arch.mul_packed_f32x2( + (x[i + 0], x[i + 1]), (scale, scale) + ) + acc2, acc3 = cute.arch.mul_packed_f32x2( + (x[i + 2], x[i + 3]), (scale, scale) + ) + acc4, acc5 = cute.arch.mul_packed_f32x2( + (x[i + 4], x[i + 5]), (scale, scale) + ) + acc6, acc7 = cute.arch.mul_packed_f32x2( + (x[i + 6], x[i + 7]), (scale, scale) + ) + acc0 = cute.math.exp2(acc0, fastmath=True) + acc1 = cute.math.exp2(acc1, fastmath=True) + acc2 = cute.math.exp2(acc2, fastmath=True) + acc3 = cute.math.exp2(acc3, fastmath=True) + acc4 = cute.math.exp2(acc4, fastmath=True) + acc5 = cute.math.exp2(acc5, fastmath=True) + acc6 = cute.math.exp2(acc6, fastmath=True) + acc7 = cute.math.exp2(acc7, fastmath=True) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (acc0, acc1)) + local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (acc2, acc3)) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (acc4, acc5)) + local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (acc6, acc7)) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + else: + row_sum = Float32(0.0) + for i in cutlass.range_constexpr(0, cute.size(x.shape), 2): + acc0, acc1 = cute.arch.mul_packed_f32x2( + (x[i], x[i + 1]), (scale, scale) + ) + acc0 = cute.math.exp2(acc0, fastmath=True) + acc1 = cute.math.exp2(acc1, fastmath=True) + row_sum += acc0 + acc1 + return row_sum + + +@dsl_user_op +def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None: + nvvm.atomicrmw( + res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value() + ) + + +@dsl_user_op +def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + +@cute.jit +def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" + tApA = cute.make_rmem_tensor( + cute.make_layout( + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + return tApA + + +def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: + warp_group_idx = cute.arch.thread_idx()[0] // 128 + if const_expr(sync): + warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx) + return warp_group_idx + + +@cute.jit +def shuffle_sync( + value: cute.Numeric, + offset: cute.typing.Int, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.Numeric: + assert value.width % 32 == 0, "value type must be a multiple of 32 bits" + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + # important: need stride 1 and not 0 for recast_tensor to work + val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value)) + val[0] = value + val_i32 = cute.recast_tensor(val, cutlass.Int32) + for i in cutlass.range_constexpr(cute.size(val_i32)): + val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp) + return val[0] + + +@dsl_user_op +def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + """ + Left-shift val by shift bits using PTX shl.b32 (sign-agnostic). + + Named ``shl_u32`` (not ``shl_b32``) because python type annotations + distinguish signed/unsigned. + + PTX semantics (9.7.8.8): "Shift amounts greater than the register width N + are clamped to N." So ``shl.b32 d, a, 32`` is well-defined and yields 0. + + This differs from C/C++ and LLVM IR, where shifting by >= the type width is + undefined behavior. CuTeDSL compiles through MLIR -> LLVM IR, so a plain + Python-level ``Uint32(x) << Uint32(n)`` inherits LLVM's UB: the optimizer + may treat the result as poison and eliminate dependent code. Inline PTX + bypasses the LLVM IR shift entirely -- the instruction is emitted verbatim + into PTX where clamping makes it safe for all shift amounts. + """ + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], + "shl.b32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + """ + Unsigned right-shift val by shift bits using PTX shr.u32 (zero-fills). + + See ``shl_u32`` docstring for why inline PTX is used instead of plain + CuTeDSL shift operators (LLVM shift-by-type-width UB). + """ + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], + "shr.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32: + if const_expr(lane is None): + lane = cute.arch.lane_idx() + for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))): + offset = 1 << i + # Very important that we set mask_and_clamp to 0 + partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0) + if lane >= offset: + val += partial_sum + return val + + +@dsl_user_op +def cvt_f16x2_f32( + a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None +) -> cutlass.Int32: + assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16" + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)], + f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;", + "=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def cvt_fp8x4_e4m3_f32( + a: float | Float32, + b: float | Float32, + c: float | Float32, + d: float | Float32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [ + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + "{\n" + ".reg .b16 h0, h1;\n" + "cvt.rn.satfinite.e4m3x2.f32 h0, $2, $1;\n" + "cvt.rn.satfinite.e4m3x2.f32 h1, $4, $3;\n" + "mov.b32 $0, {h0, h1};\n" + "}\n", + "=r,f,f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def cvt_fp8x4_e4m3_bf16x4( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32]: + """Convert packed e4m3x4 bits into two packed bf16x2 registers.""" + out0 = cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b32 q, mant, out, bias, zero;\n\t" + "prmt.b32 q, $1, $1, 0x1302;\n\t" + "and.b32 out, q, 0x80008000;\n\t" + "and.b32 mant, q, 0x7f007f00;\n\t" + "shr.u32 mant, mant, 4;\n\t" + "or.b32 out, out, mant;\n\t" + "mov.b32 bias, 0x7b807b80;\n\t" + "mov.b32 zero, 0;\n\t" + "fma.rn.bf16x2 $0, out, bias, zero;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + out1 = cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b32 q, qs, mant, out, bias, zero;\n\t" + "prmt.b32 q, $1, $1, 0x1302;\n\t" + "shl.b32 qs, q, 8;\n\t" + "and.b32 out, qs, 0x80008000;\n\t" + "and.b32 mant, qs, 0x7f007f00;\n\t" + "shr.u32 mant, mant, 4;\n\t" + "or.b32 out, out, mant;\n\t" + "mov.b32 bias, 0x7b807b80;\n\t" + "mov.b32 zero, 0;\n\t" + "fma.rn.bf16x2 $0, out, bias, zero;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return out0, out1 + + +@dsl_user_op +def cvt_fp4x2_e2m1_f16x2( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Convert one packed E2M1 byte into one packed f16x2 register.""" + + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b8 byte0;\n\t" + "mov.b32 {byte0, _, _, _}, $1;\n\t" + "cvt.rn.f16x2.e2m1x2 $0, byte0;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def cvt_fp4x8_e2m1_f16x8( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Convert four packed E2M1 bytes into four packed f16x2 registers.""" + + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $4;\n\t" + "cvt.rn.f16x2.e2m1x2 $0, byte0;\n\t" + "cvt.rn.f16x2.e2m1x2 $1, byte1;\n\t" + "cvt.rn.f16x2.e2m1x2 $2, byte2;\n\t" + "cvt.rn.f16x2.e2m1x2 $3, byte3;\n\t" + "}\n", + "=r,=r,=r,=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + out2 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [2], loc=loc, ip=ip)) + out3 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [3], loc=loc, ip=ip)) + return out0, out1, out2, out3 + + +@dsl_user_op +def cvt_fp4x8_e2m1_bf16x8( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Convert four packed E2M1 bytes into four packed bf16x2 registers.""" + + from cutlass import CUDA_VERSION + + if CUDA_VERSION.major > 13 or (CUDA_VERSION.major == 13 and CUDA_VERSION.minor >= 2): + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $4;\n\t" + "cvt.rn.bf16x2.e2m1x2 $0, byte0;\n\t" + "cvt.rn.bf16x2.e2m1x2 $1, byte1;\n\t" + "cvt.rn.bf16x2.e2m1x2 $2, byte2;\n\t" + "cvt.rn.bf16x2.e2m1x2 $3, byte3;\n\t" + "}\n", + "=r,=r,=r,=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + out2 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [2], loc=loc, ip=ip)) + out3 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [3], loc=loc, ip=ip)) + return out0, out1, out2, out3 + + f16_pair0, f16_pair1, f16_pair2, f16_pair3 = cvt_fp4x8_e2m1_f16x8( + src, loc=loc, ip=ip + ) + return ( + cvt_f16x2_to_bf16x2(f16_pair0, loc=loc, ip=ip), + cvt_f16x2_to_bf16x2(f16_pair1, loc=loc, ip=ip), + cvt_f16x2_to_bf16x2(f16_pair2, loc=loc, ip=ip), + cvt_f16x2_to_bf16x2(f16_pair3, loc=loc, ip=ip), + ) + + +@dsl_user_op +def cvt_fp4x8_e2m1_scaled_e4m3x8( + src: cutlass.Int32, + scale_e4m3: cutlass.Int32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Int32, cutlass.Int32]: + """Scale eight packed E2M1 values by one E4M3 byte and convert to E4M3.""" + + from cutlass import CUDA_VERSION + + if CUDA_VERSION.major > 13 or (CUDA_VERSION.major == 13 and CUDA_VERSION.minor >= 2): + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32()]), + [ + cutlass.Int32(src).ir_value(loc=loc, ip=ip), + cutlass.Int32(scale_e4m3).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .b32 tmp, ra;\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + "prmt.b32 tmp, $3, 0, 0;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $2;\n\t" + "mov.b32 ra, {byte0, byte1, _, _};\n\t" + "mul.e4m3x4.e2m1x4.e4m3x4.satfinite $0, ra, tmp;\n\t" + "mov.b32 ra, {_, _, byte2, byte3};\n\t" + "mul.e4m3x4.e2m1x4.e4m3x4.satfinite $1, ra, tmp;\n\t" + "}\n", + "=r,=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + return out0, out1 + + out = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32()]), + [ + cutlass.Int32(src).ir_value(loc=loc, ip=ip), + cutlass.Int32(scale_e4m3).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .b32 sf_bytes, sf_f16x2;\n\t" + ".reg .b16 sf_pair, e0, e1, e2, e3;\n\t" + ".reg .b8 byte0, byte1, byte2, byte3;\n\t" + ".reg .b32 h0, h1, h2, h3;\n\t" + "prmt.b32 sf_bytes, $3, 0, 0;\n\t" + "mov.b32 {sf_pair, _}, sf_bytes;\n\t" + "cvt.rn.f16x2.e4m3x2 sf_f16x2, sf_pair;\n\t" + "mov.b32 {byte0, byte1, byte2, byte3}, $2;\n\t" + "cvt.rn.f16x2.e2m1x2 h0, byte0;\n\t" + "cvt.rn.f16x2.e2m1x2 h1, byte1;\n\t" + "cvt.rn.f16x2.e2m1x2 h2, byte2;\n\t" + "cvt.rn.f16x2.e2m1x2 h3, byte3;\n\t" + "mul.rn.f16x2 h0, h0, sf_f16x2;\n\t" + "mul.rn.f16x2 h1, h1, sf_f16x2;\n\t" + "mul.rn.f16x2 h2, h2, sf_f16x2;\n\t" + "mul.rn.f16x2 h3, h3, sf_f16x2;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e0, h0;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e1, h1;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e2, h2;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 e3, h3;\n\t" + "mov.b32 $0, {e0, e1};\n\t" + "mov.b32 $1, {e2, e3};\n\t" + "}\n", + "=r,=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip)) + out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip)) + return out0, out1 + + +@dsl_user_op +def cvt_f16x2_to_bf16x2( + src: cutlass.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Convert a packed f16x2 register into a packed bf16x2 register.""" + + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [cutlass.Int32(src).ir_value(loc=loc, ip=ip)], + "{\n\t" + ".reg .b16 h0, h1;\n\t" + ".reg .f32 f0, f1;\n\t" + "mov.b32 {h0, h1}, $1;\n\t" + "cvt.f32.f16 f0, h0;\n\t" + "cvt.f32.f16 f1, h1;\n\t" + "cvt.rn.bf16x2.f32 $0, f1, f0;\n\t" + "}\n", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def mul_bf16x2( + a: cutlass.Int32, + b: cutlass.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Multiply two packed bf16x2 registers.""" + + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Int32(a).ir_value(loc=loc, ip=ip), + cutlass.Int32(b).ir_value(loc=loc, ip=ip), + ], + "mul.rn.bf16x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def cvt_fp8_e4m3_to_bf16x2_replicated(src: cutlass.Int32) -> cutlass.Int32: + """Decode one E4M3 byte and replicate it into a packed bf16x2 register.""" + + src_u8 = src & cutlass.Int32(0xFF) + packed = src_u8 * cutlass.Int32(0x01010101) + out0, _ = cvt_fp8x4_e4m3_bf16x4(packed) + return out0 + + +@overload +def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ... + + +@overload +def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ... + + +@cute.jit +def cvt_f16(src: cute.Tensor, dst_or_dtype): + """Convert Float32 tensor to Float16/BFloat16. + + Args: + src: Source tensor with Float32 element type + dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16) + + Returns: + None if dst is a tensor, or a new tensor if dtype is provided + """ + if const_expr(isinstance(dst_or_dtype, type)): + # dtype variant: create new tensor and call the tensor variant + dtype = dst_or_dtype + dst = cute.make_rmem_tensor(src.shape, dtype) + cvt_f16(src, dst) + return dst + else: + # tensor variant: write to dst + dst = dst_or_dtype + assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size" + assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements" + assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], ( + "dst must be BFloat16 or Float16" + ) + assert src.element_type is Float32, "src must be Float32" + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) + for i in cutlass.range_constexpr(cute.size(dst_i32)): + dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) + + +@cute.jit +def cvt_f32(src: cute.Tensor, dst: cute.Tensor) -> None: + """Convert a Float32 rmem tensor to dst's element type. + + fp8 path uses the reference fp8 quantize pattern: fragment-by-fragment + ``.store(.load().to(fp8))`` over groups of ``frg_tile=4``. This lets the + DSL emit ``cvt.rn.satfinite.e4m3x2.f32`` pairs and pack the resulting fp8 + bytes within a 32-bit register cell in the order DSL chooses, which is + expected to match the K-adjacency that SM100 fp8 UMMA fragment_A reads. + """ + if const_expr(dst.element_type in [cutlass.BFloat16, cutlass.Float16]): + cvt_f16(src, dst) + elif const_expr(dst.element_type is cutlass.Float8E4M3FN): + assert src.element_type is Float32, "src must be Float32" + assert cute.size(src.shape) == cute.size(dst.shape), "dst and src must have the same size" + assert cute.size(src.shape) % 4 == 0, "src must have a multiple of 4 elements" + frg_tile = 4 + src_frg = cute.logical_divide(src, cute.make_layout(frg_tile)) + dst_frg = cute.logical_divide(dst, cute.make_layout(frg_tile)) + for i in cutlass.range_constexpr(cute.size(src_frg, mode=[1])): + dst_frg[None, i].store(src_frg[None, i].load().to(dst.element_type)) + else: + assert src.element_type is Float32, "src must be Float32" + dst_view = cute.make_tensor(dst.iterator, src.layout) + dst_view.store(src.load().to(dst.element_type)) + + +@dsl_user_op +@cute.jit +def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32: + deg = len(poly) - 1 + out = poly[deg] + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = out * x + poly[i] + return out + + +@dsl_user_op +@cute.jit +def evaluate_polynomial_2( + x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None +) -> Tuple[Float32, Float32]: + deg = len(poly) - 1 + out = (poly[deg], poly[deg]) + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i])) + return out + + +@dsl_user_op +def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32: + # There's probably a way to call llvm or nvvm to do this instead of ptx + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)], + "add.rm.ftz.f32 $0, $1, $2;", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32: + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [ + Float32(x_rounded).ir_value(loc=loc, ip=ip), + Float32(frac_ex2).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t" + "mov.b32 x_rounded_i, $1;\n\t" + "mov.b32 frac_ex_i, $2;\n\t" + "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t" + # add.u32 generates IMAD instruction and add.s32 generates LEA instruction + # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik + "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t" + "mov.b32 $0, out_i;\n\t" + "}\n", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32: + assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported" + # We assume x <= 127.0 + fp32_round_int = float(2**23 + 2**22) + x_clamped = cute.arch.fmax(x, -127.0) + # We want to round down here, so that the fractional part is in [0, 1) + x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip) + # The integer floor of x is now in the last 8 bits of x_rounded + # We assume the next 2 ops round to nearest even. The rounding mode is important. + x_rounded_back = x_rounded - fp32_round_int + x_frac = x_clamped - x_rounded_back + x_frac_ex2 = evaluate_polynomial(x_frac, POLY_EX2[poly_degree], loc=loc, ip=ip) + return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip) + + +@dsl_user_op +def ex2_emulation_2( + x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None +) -> Tuple[Float32, Float32]: + # We assume x <= 127.0 and y <= 127.0 + fp32_round_int = float(2**23 + 2**22) + xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0)) + # We want to round down here, so that the fractional part is in [0, 1) + xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd="rm") + # The integer floor of x & y are now in the last 8 bits of xy_rounded + # We want the next 2 ops to round to nearest even. The rounding mode is important. + xy_rounded_back = activation.sub_packed_f32x2( + xy_rounded, (fp32_round_int, fp32_round_int) + ) + xy_frac = activation.sub_packed_f32x2(xy_clamped, xy_rounded_back) + xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, POLY_EX2[poly_degree], loc=loc, ip=ip) + x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip) + y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip) + return x_out, y_out + + +@dsl_user_op +def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + out_f32x2 = llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32()]), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()], + "{\n\t" + ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t" + ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t" + ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t" + "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t" + "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t" + "mov.b64 l1, {f1, f2};\n\t" + "mov.f32 f3, 0f4B400000;\n\t" + "mov.b64 l2, {f3, f3};\n\t" + "add.rm.ftz.f32x2 l7, l1, l2;\n\t" + "sub.rn.ftz.f32x2 l8, l7, l2;\n\t" + "sub.rn.ftz.f32x2 l9, l1, l8;\n\t" + "mov.f32 f7, 0f3D9DF09D;\n\t" + "mov.b64 l6, {f7, f7};\n\t" + "mov.f32 f6, 0f3E6906A4;\n\t" + "mov.b64 l5, {f6, f6};\n\t" + "mov.f32 f5, 0f3F31F519;\n\t" + "mov.b64 l4, {f5, f5};\n\t" + "mov.f32 f4, 0f3F800000;\n\t" + "mov.b64 l3, {f4, f4};\n\t" + "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t" + "mov.b64 {r1, r2}, l7;\n\t" + "mov.b64 {r3, r4}, l10;\n\t" + "shl.b32 r5, r1, 23;\n\t" + "add.s32 r7, r5, r3;\n\t" + "shl.b32 r6, r2, 23;\n\t" + "add.s32 r8, r6, r4;\n\t" + "mov.b32 $0, r7;\n\t" + "mov.b32 $1, r8;\n\t" + "}\n", + "=r,=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip)) + out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip)) + return out0, out1 + + +@dsl_user_op +def domain_offset_aligned( + coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None +) -> cute.Tensor: + assert isinstance(tensor.iterator, cute.Pointer) + # We assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + elem_pointer(tensor, coord).toint(), + tensor.memspace, + assumed_align=tensor.iterator.alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@cute.jit +def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: + """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype""" + vec = cute.make_rmem_tensor(1, dtype) + vec[0] = a + return vec.load() + + +def ssa_to_scalar(val): + """Could inline but nice for reflecting the above api""" + return val[0] + + +# ------------------------------------------------------------------ +# Host-side Python helpers (not @cute.jit — called from PyTorch host code) +# ------------------------------------------------------------------ + +def default_softmax_scale(dim: int) -> float: + """Default softmax scale: 1 / sqrt(dim).""" + return 1.0 / math.sqrt(dim) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f23267fe73800d35db382a1919bc28196da5aa8c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 sparse attention kernels.""" diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/build_k2q_csr/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/build_k2q_csr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf19c60a32d2f57595c9666323b47738b878115 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/build_k2q_csr/__init__.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""q2k -> k2q CSR builder backed by the precompiled Torch ops. + +The CUDA implementation lives in ``csrc/build_k2q_csr.cu`` and is built +ahead of time by kernel-builder; it is reached through the ``_ops`` +namespace instead of being JIT-compiled at import time. + +The kernel pipeline is tuned and verified for SM100; other +architectures are not supported. +""" + +from __future__ import annotations + +import torch + +from ...._ops import ops + + +def run_build_k2q_csr( + q2k: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + row_ptr: torch.Tensor, + q_idx: torch.Tensor, + topk: int, + blk_kv: int, + total_rows: int, + max_kv_blocks: int, +) -> None: + """In-place fill of ``row_ptr`` and ``q_idx``. + + Args: + q2k: int32 [H, total_q, topK] contiguous (CUDA). + cu_seqlens_q: int32 [B+1] contiguous (CUDA). + cu_seqlens_k: int32 [B+1] contiguous (CUDA). + row_ptr: int32 [H, total_rows + 1] CUDA, written in place. + q_idx: int32 [H, total_q * topK] CUDA, written in place + (trailing slots set to -1). + topk: must be in {4, 8, 16, 32}. + blk_kv: must equal 128. + total_rows: sum over batches of ceil(seqlen_k / blk_kv). + max_kv_blocks: max over batches of ceil(seqlen_k / blk_kv); upper bound + used to size the row_map workspace and clamp valid kv ids. + """ + ops.run_build_k2q_csr( + q2k, + cu_seqlens_q, + cu_seqlens_k, + row_ptr, + q_idx, + int(topk), + int(blk_kv), + int(total_rows), + int(max_kv_blocks), + ) + + +def run_build_k2q_csr_with_schedule( + q2k: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + row_ptr: torch.Tensor, + q_idx: torch.Tensor, + scheduler_metadata: torch.Tensor, + work_count: torch.Tensor, + qsplit_idx: torch.Tensor, + split_counts: torch.Tensor, + topk: int, + blk_kv: int, + total_rows: int, + max_kv_blocks: int, + target_q_per_cta: int, + work_capacity: int, + max_seqlen_q: int, +) -> None: + """In-place fill of CSR plus fused sparse attention schedule metadata.""" + ops.run_build_k2q_csr_with_schedule( + q2k, + cu_seqlens_q, + cu_seqlens_k, + row_ptr, + q_idx, + scheduler_metadata, + work_count, + qsplit_idx, + split_counts, + int(topk), + int(blk_kv), + int(total_rows), + int(max_kv_blocks), + int(target_q_per_cta), + int(work_capacity), + int(max_seqlen_q), + ) + + +def is_supported(topk: int, blk_kv: int) -> bool: + return int(topk) in (4, 8, 16, 32) and int(blk_kv) == 128 + + +__all__ = ["run_build_k2q_csr", "run_build_k2q_csr_with_schedule", "is_supported"] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/decode_schedule.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/decode_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..037791818feb030a5969ebf6ac3cc3943cdb7dce --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/decode_schedule.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Split-KV schedule for paged fp8 decode attention. + +The public PageKV representation remains this repo's rectangular page table: +``page_table [B, max_pages]`` plus ``seqused_k [B]``. The schedule only +describes how query tiles and KV chunks are split into work items. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class DecodeAttentionSchedule: + split_kv: bool + cta_tile_q: int + num_q_tiles: int + kv_chunk_size_pages: int + kv_chunk_size_tokens: int + work_count: int + padded_work_count: int + partial_rows: int + max_split_count: int + max_grid_size: int + active_blocks_per_sm: int + num_sms: int + base_cta: int + request_indices: torch.Tensor + qo_tile_indices: torch.Tensor + kv_tile_indices: torch.Tensor + merge_indptr: torch.Tensor + o_indptr: torch.Tensor + block_valid_mask: torch.Tensor + kv_pages: torch.Tensor + split_counts: torch.Tensor + + +def _require_i32_cuda_1d(tensor: torch.Tensor, *, name: str) -> None: + if tensor.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32") + if tensor.ndim != 1: + raise ValueError(f"{name} must be rank-1") + if not tensor.is_cuda: + raise ValueError(f"{name} must be a CUDA tensor") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def prepare_decode_schedule( + *, + seqused_k: torch.Tensor, + page_size: int, + seqlen_q: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + max_seqlen_k: int, + enable_cuda_graph: bool = False, + max_grid_size: Optional[int] = None, + fixed_split_size: Optional[int] = None, + disable_split_kv: bool = False, +) -> DecodeAttentionSchedule: + """Build paged decode split-KV schedule on the GPU. + + A single CUDA kernel reads ``seqused_k`` on device and writes all + schedule index arrays. Only a small summary tensor is D2H-synced so + the wrapper can size O_partial / pick the kernel grid / choose the + split-vs-non-split compile path. + + ``max_seqlen_k`` is the host-side worst-case bound used to pad the + work-tile arrays. It must satisfy ``max(seqused_k) <= max_seqlen_k``. + """ + _require_i32_cuda_1d(seqused_k, name="seqused_k") + # Hard cap: current single-CTA schedule kernel stores per-batch state + # in shared memory. Larger batches require a multi-CTA cooperative + # scheduler (unimplemented). Fail fast at the Python boundary so the + # error doesn't surface from inside the CUDA extension. + if int(seqused_k.shape[0]) > 1024: + raise NotImplementedError( + "decode schedule currently supports batch <= 1024 " + f"(got batch={int(seqused_k.shape[0])}). Larger batches need " + "the multi-CTA scheduler — not yet implemented." + ) + # Two API-boundary checks tied to the kernel's packed-GQA layout + # (q_tokens_per_group = m_block_size / qhead_per_kv = 128/16 = 8): + # + # (1) seqused_k[b] >= seqlen_q. The kernel computes the causal mask as + # col_limit = row_idx + seqlen_k - seqlen_q + 1. For row 0 (first + # q-token in the packed group) this is col_limit = seqlen_k - seqlen_q + # + 1, which goes <= 0 whenever seqlen_k < seqlen_q. That all-masked + # row then enters a mask-codegen path with PTX-undefined shift counts + # and the kernel hangs. The condition is also semantically invalid + # in batched-decode: you can't emit seqlen_q new tokens with fewer + # than seqlen_q total context tokens (seqlen_k includes them). + # + # (2) seqused_k[b] % page_size ∈ {0, 8, 16, ..., 120}. Same hang fires + # when the LAST partial page has < q_tokens_per_group=8 valid + # columns, because then the *last MMA tile* hits the same all-masked + # row case for the trailing q-tokens. + # + # Both are tracked as a separate kernel-level TODO (un-pack the + # all-masked row → skip mask call, or saturate causal_col_limit at >= 1 + # in mask.py). Until then, fail fast at the Python boundary with a + # clear message rather than letting the kernel timeout. + seqlen_q_i = int(seqlen_q) + bad_q = seqused_k < seqlen_q_i + if bool(bad_q.any().item()): + bad_idx = int(torch.nonzero(bad_q, as_tuple=True)[0][0].item()) + bad_val = int(seqused_k[bad_idx].item()) + raise ValueError( + f"decode kernel requires seqused_k[b] >= seqlen_q (= {seqlen_q_i}) " + f"for every batch. Got seqused_k[{bad_idx}]={bad_val}. " + f"This is also a batched-decode invariant: seqlen_k must include " + f"the seqlen_q new tokens being emitted." + ) + rem = seqused_k % int(page_size) + bad_rem = (rem > 0) & (rem < seqlen_q_i) + if bool(bad_rem.any().item()): + bad_idx = int(torch.nonzero(bad_rem, as_tuple=True)[0][0].item()) + bad_val = int(seqused_k[bad_idx].item()) + raise ValueError( + f"decode kernel requires seqused_k[b] % page_size ∈ " + f"{{0, {seqlen_q_i}, {seqlen_q_i*2}, ..., {(page_size//seqlen_q_i)*seqlen_q_i}}}. " + f"Got seqused_k[{bad_idx}]={bad_val}, last partial page has " + f"{bad_val % int(page_size)} valid columns (< seqlen_q={seqlen_q_i}). " + f"Round seqused_k up to the next multiple of {seqlen_q_i} OR to " + f"a multiple of {page_size}." + ) + if int(page_size) <= 0: + raise ValueError("page_size must be positive") + if int(seqlen_q) <= 0: + raise ValueError("seqlen_q must be positive") + if int(num_qo_heads) <= 0 or int(num_kv_heads) <= 0: + raise ValueError("head counts must be positive") + if int(num_qo_heads) % int(num_kv_heads) != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads") + if int(num_qo_heads) // int(num_kv_heads) != 16: + raise NotImplementedError("decode schedule currently supports only qhead_per_kv=16") + if int(head_dim) != 128: + raise NotImplementedError("decode schedule currently supports only head_dim=128") + if int(max_seqlen_k) <= 0: + raise ValueError("max_seqlen_k must be positive") + + from ...src.sm100.fwd_decode.build_decode_schedule import build_decode_schedule + + raw = build_decode_schedule( + seqused_k, + page_size=int(page_size), + seqlen_q=int(seqlen_q), + num_qo_heads=int(num_qo_heads), + num_kv_heads=int(num_kv_heads), + head_dim=int(head_dim), + max_seqlen_k=int(max_seqlen_k), + enable_cuda_graph=bool(enable_cuda_graph), + max_grid_size=0 if max_grid_size is None else int(max_grid_size), + fixed_split_size=-1 if fixed_split_size is None else int(fixed_split_size), + disable_split_kv=bool(disable_split_kv), + ) + return DecodeAttentionSchedule( + split_kv=bool(raw["split_kv"]), + cta_tile_q=int(raw["cta_tile_q"]), + num_q_tiles=int(raw["num_q_tiles"]), + kv_chunk_size_pages=int(raw["kv_chunk_size_pages"]), + kv_chunk_size_tokens=int(raw["kv_chunk_size_tokens"]), + work_count=int(raw["work_count"]), + padded_work_count=int(raw["padded_work_count"]), + partial_rows=int(raw["partial_rows"]), + max_split_count=int(raw["max_split_count"]), + max_grid_size=int(raw["max_grid_size"]), + active_blocks_per_sm=int(raw["active_blocks_per_sm"]), + num_sms=int(raw["num_sms"]), + base_cta=int(raw["base_cta"]), + request_indices=raw["request_indices"], + qo_tile_indices=raw["qo_tile_indices"], + kv_tile_indices=raw["kv_tile_indices"], + merge_indptr=raw["merge_indptr"], + o_indptr=raw["o_indptr"], + block_valid_mask=raw["block_valid_mask"], + kv_pages=raw["kv_pages"], + split_counts=raw["split_counts"], + ) + + +__all__ = [ + "DecodeAttentionSchedule", + "prepare_decode_schedule", +] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fp4_indexer.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fp4_indexer.py new file mode 100644 index 0000000000000000000000000000000000000000..aa83e39a5504ac6cf8d732255e495e48b35fa20a --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fp4_indexer.py @@ -0,0 +1,1956 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 FP4 sparse-attention indexer kernels.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +import torch +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 + +from ...src.common import pipeline as common_pipeline + + +FP4_FORMAT = Literal["mxfp4", "nvfp4"] +_FP4_PACKED_D_BYTES = 64 +_HEAD_DIM = 128 +_BLOCK_K = 128 +_PAGE_SIZE = 128 +_MMA_TILER_MN = (128, 128) +_MMA_INST_SHAPE_K = 64 +_NON_CAUSAL_K_TILES_PER_CTA = 16 +_CAUSAL_K_TILES_PER_CTA = 16 +_DECODE_PACK_Q_LEN = 8 +_DECODE_QHEAD_PER_KV = 16 +_DECODE_K_TILES_PER_CTA = 16 +_AB_DTYPE = cutlass.Float4E2M1FN + + +@dataclass(frozen=True) +class Fp4FormatSpec: + name: FP4_FORMAT + sf_vec_size: int + scale_groups: int + torch_scale_dtype: torch.dtype + cutlass_scale_dtype: type + + +_FORMAT_SPECS: dict[str, Fp4FormatSpec] = { + "mxfp4": Fp4FormatSpec( + name="mxfp4", + sf_vec_size=32, + scale_groups=4, + torch_scale_dtype=torch.float8_e8m0fnu, + cutlass_scale_dtype=cutlass.Float8E8M0FNU, + ), + "nvfp4": Fp4FormatSpec( + name="nvfp4", + sf_vec_size=16, + scale_groups=8, + torch_scale_dtype=torch.float8_e4m3fn, + cutlass_scale_dtype=cutlass.Float8E4M3FN, + ), +} + + +def normalize_fp4_format(fmt: str) -> Fp4FormatSpec: + key = str(fmt).lower() + try: + return _FORMAT_SPECS[key] + except KeyError as exc: + raise ValueError(f"format must be one of {sorted(_FORMAT_SPECS)}, got {fmt!r}") from exc + + +def ceil_div(x: int, y: int) -> int: + return (int(x) + int(y) - 1) // int(y) + + +def k_tiles_per_cta_for(causal: bool) -> int: + return _CAUSAL_K_TILES_PER_CTA if bool(causal) else _NON_CAUSAL_K_TILES_PER_CTA + + +class Fp4IndexerScaleReorderSm100: + """Reorder public FP4 indexer scales to the 1CTA blockscaled MMA layout.""" + + def __init__(self, *, fmt: str): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.sf_dtype = spec.cutlass_scale_dtype + self.scale_groups = spec.scale_groups + self.threads_per_cta = 256 + + @cute.jit + def __call__( + self, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + q_scale_mma_ptr: cute.Pointer, + k_scale_mma_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + total_q, heads_q, page_count, heads_k = problem_size + rest_q_m = cute.ceil_div(total_q, 128) + rest_g = cute.ceil_div(self.scale_groups, 4) + k_l = page_count * heads_k + + q_scale = cute.make_tensor( + q_scale_ptr, + cute.make_layout( + (total_q, heads_q, self.scale_groups), + stride=(heads_q * self.scale_groups, self.scale_groups, 1), + ), + ) + k_scale = cute.make_tensor( + k_scale_ptr, + cute.make_layout( + (page_count, heads_k, _PAGE_SIZE, self.scale_groups), + stride=( + heads_k * _PAGE_SIZE * self.scale_groups, + _PAGE_SIZE * self.scale_groups, + self.scale_groups, + 1, + ), + ), + ) + + q_mma_layout = cute.make_ordered_layout( + (32, 4, rest_q_m, 4, rest_g, heads_q), + order=(2, 1, 4, 0, 3, 5), + ) + k_mma_layout = cute.make_ordered_layout( + (32, 4, 1, 4, rest_g, k_l), + order=(2, 1, 4, 0, 3, 5), + ) + q_scale_mma = cute.make_tensor(q_scale_mma_ptr, q_mma_layout) + k_scale_mma = cute.make_tensor(k_scale_mma_ptr, k_mma_layout) + q_scale_mma = cute.group_modes(q_scale_mma, 0, 3) + q_scale_mma = cute.group_modes(q_scale_mma, 1, 3) + k_scale_mma = cute.group_modes(k_scale_mma, 0, 3) + k_scale_mma = cute.group_modes(k_scale_mma, 1, 3) + + q_scale_count = total_q * heads_q * Int32(self.scale_groups) + k_scale_count = page_count * heads_k * Int32(_PAGE_SIZE * self.scale_groups) + total_scale_count = q_scale_count + k_scale_count + grid_ctas = cute.ceil_div(total_scale_count, self.threads_per_cta) + self.kernel( + q_scale, + k_scale, + q_scale_mma, + k_scale_mma, + heads_q, + heads_k, + q_scale_count, + total_scale_count, + ).launch( + grid=(grid_ctas, 1, 1), + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + q_scale: cute.Tensor, + k_scale: cute.Tensor, + q_scale_mma: cute.Tensor, + k_scale_mma: cute.Tensor, + heads_q: Int32, + heads_k: Int32, + q_scale_count: Int32, + total_scale_count: Int32, + ): + tidx, _, _ = cute.arch.thread_idx() + block_idx, _, _ = cute.arch.block_idx() + grid_dim, _, _ = cute.arch.grid_dim() + linear = block_idx * Int32(self.threads_per_cta) + tidx + stride = grid_dim * Int32(self.threads_per_cta) + + while linear < total_scale_count: + if linear < q_scale_count: + group = linear % Int32(self.scale_groups) + tmp = linear // Int32(self.scale_groups) + head = tmp % heads_q + row = tmp // heads_q + q_scale_mma[row, group, head] = q_scale[row, head, group] + else: + k_linear = linear - q_scale_count + group = k_linear % Int32(self.scale_groups) + tmp = k_linear // Int32(self.scale_groups) + row = tmp % Int32(_PAGE_SIZE) + tmp = tmp // Int32(_PAGE_SIZE) + head = tmp % heads_k + page = tmp // heads_k + scale_l = page * heads_k + head + k_scale_mma[row, group, scale_l] = k_scale[page, head, row, group] + linear += stride + + +class Fp4IndexerStagedMmaSm100: + """Single-kernel FP4 indexer for preordered MMA scale storage.""" + + def __init__( + self, + *, + fmt: str, + causal: bool, + preordered_q_scale_tma: bool = False, + compact_schedule: bool = False, + use_tmem_load_red: bool = False, + ): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.is_causal = bool(causal) + self.preordered_q_scale_tma = bool(preordered_q_scale_tma) + self.compact_schedule = bool(compact_schedule) + self.use_tmem_load_red = bool(use_tmem_load_red) + self.sf_vec_size = spec.sf_vec_size + self.sf_dtype = spec.cutlass_scale_dtype + self.scale_groups = spec.scale_groups + self.use_nvfp4 = spec.name == "nvfp4" + self.epi_threads_per_cta = 128 + self.epi_warps_per_group = 4 + self.num_epi_warpgroups = 2 + self.mma_warp_id = self.epi_warps_per_group * self.num_epi_warpgroups + self.load_warp_id = self.mma_warp_id + 1 + self.threads_per_cta = 384 + self.num_tmem_alloc_cols = 512 + self.num_q_stage = 1 + self.num_acc_stage = 3 + self.num_ab_stage = 3 + self.k_tiles_per_cta = k_tiles_per_cta_for(self.is_causal) + + @cute.jit + def __call__( + self, + q_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + ( + m, + _, + k, + _, + lk, + heads_q, + heads_k, + batch, + max_k_tiles, + total_q, + has_qo_offset, + compact_task_count, + ) = problem_size + page_count = lk // heads_k + self.mma_tiler = (_MMA_TILER_MN[0], _MMA_TILER_MN[1], _MMA_INST_SHAPE_K * 2) + self.cta_tile_shape_mnk = self.mma_tiler + + q_tma_tensor = cute.make_tensor( + cute.recast_ptr(q_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (total_q, _HEAD_DIM, heads_q), + stride=(heads_q * _HEAD_DIM, 1, _HEAD_DIM), + ), + ) + k_tma_tensor = cute.make_tensor( + cute.recast_ptr(k_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (_PAGE_SIZE, _HEAD_DIM, heads_k, page_count), + stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM, heads_k * _PAGE_SIZE * _HEAD_DIM), + ), + ) + q_scale_tensor = cute.make_tensor( + q_scale_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (total_q, _HEAD_DIM, heads_q), + self.sf_vec_size, + ), + ) + k_scale_tensor = cute.make_tensor( + k_scale_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (_PAGE_SIZE, _HEAD_DIM, page_count * heads_k), + self.sf_vec_size, + ), + ) + scores_tensor = cute.make_tensor( + scores_ptr, + cute.make_layout((heads_q, max_k_tiles, total_q), stride=(max_k_tiles * total_q, total_q, 1)), + ) + kv_indices_tensor = cute.make_tensor( + kv_indices_ptr, + cute.make_layout((page_count,), stride=(1,)), + ) + cu_layout = cute.make_layout((batch + 1,), stride=(1,)) + cu_q_tensor = cute.make_tensor(cu_seqlens_q_ptr, cu_layout) + cu_k_tensor = cute.make_tensor(cu_seqlens_k_ptr, cu_layout) + cu_page_offsets_tensor = cute.make_tensor(cu_page_offsets_ptr, cu_layout) + qo_offset_tensor = cute.make_tensor(qo_offset_ptr, cute.make_layout((batch,), stride=(1,))) + + if const_expr(self.use_nvfp4): + mma_op = tcgen05.MmaMXF4NVF4Op( + self.sf_dtype, + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + else: + mma_op = tcgen05.MmaMXF4Op( + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + q_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_q_stage) + k_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_ab_stage) + q_scale_smem_layout = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_q_stage, + ) + k_scale_smem_layout = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + cluster_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + q_smem_layout_stage = cute.slice_(q_smem_layout, (None, None, None, 0)) + k_smem_layout_stage = cute.slice_(k_smem_layout, (None, None, None, 0)) + tma_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_tma_tensor, + q_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + tma_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_tma_tensor, + k_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + if const_expr(self.preordered_q_scale_tma): + tma_qs = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_scale_tensor, + q_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + else: + tma_qs = tma_q + tma_ks = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_scale_tensor, + k_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + grid_q_tiles = cute.ceil_div(m, self.cta_tile_shape_mnk[0]) + grid_k_groups = cute.ceil_div(max_k_tiles, self.k_tiles_per_cta) + if const_expr(self.compact_schedule): + grid_x = compact_task_count + else: + grid_x = grid_q_tiles * grid_k_groups + self.kernel( + tiled_mma, + tma_q, + tma_qs, + tma_k, + tma_ks, + q_scale_tensor, + k_scale_tensor, + scores_tensor, + kv_indices_tensor, + cu_q_tensor, + cu_k_tensor, + cu_page_offsets_tensor, + qo_offset_tensor, + q_smem_layout, + k_smem_layout, + q_scale_smem_layout, + k_scale_smem_layout, + heads_q, + heads_k, + has_qo_offset, + max_k_tiles, + grid_k_groups, + ).launch( + grid=(grid_x, batch * heads_q, 1), + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.jit + def _group_has_visible( + self, + q_tile_start: Int32, + q_tile_last: Int32, + q_len: Int32, + group_first_ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = q_tile_start < q_len and group_first_ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and group_first_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + return visible + + @cute.jit + def _tile_has_visible( + self, + q_tile_start: Int32, + q_tile_last: Int32, + q_len: Int32, + ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = q_tile_start < q_len and ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + return visible + + @cute.jit + def _tile_mask_free(self, q_tile_start: Int32, ktile: Int32, causal_offset: Int32): + if const_expr(self.is_causal): + return ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) <= q_tile_start + causal_offset + return True + + @cute.jit + def _full_tile_coord_visible( + self, + coord_m: Int32, + target_m: Int32, + q_local: Int32, + k_local: Int32, + causal_offset: Int32, + ): + visible = coord_m == target_m + if const_expr(self.is_causal): + visible = visible and k_local <= q_local + causal_offset + return visible + + @cute.jit + def _partial_tile_coord_visible( + self, + coord_m: Int32, + target_m: Int32, + q_local: Int32, + k_local: Int32, + q_len: Int32, + k_len: Int32, + causal_offset: Int32, + ): + visible = coord_m == target_m and q_local < q_len and k_local < k_len + if const_expr(self.is_causal): + visible = visible and k_local <= q_local + causal_offset + return visible + + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_q: cpasync.TmaInfo, + tma_qs: cpasync.TmaInfo, + tma_k: cpasync.TmaInfo, + tma_ks: cpasync.TmaInfo, + mQS: cute.Tensor, + mKS: cute.Tensor, + mScores: cute.Tensor, + mKvIndices: cute.Tensor, + mCuQ: cute.Tensor, + mCuK: cute.Tensor, + mCuPages: cute.Tensor, + mQoOffset: cute.Tensor, + q_smem_layout: cute.ComposedLayout, + k_smem_layout: cute.ComposedLayout, + q_scale_smem_layout: cute.Layout, + k_scale_smem_layout: cute.Layout, + heads_q: Int32, + heads_k: Int32, + has_qo_offset: Int32, + max_k_tiles: Int32, + k_group_count: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + lane_idx = cute.arch.lane_idx() + epi_tidx = tidx % Int32(self.epi_threads_per_cta) + epi_warpgroup_idx = warp_idx // Int32(self.epi_warps_per_group) + task_idx, q_l, _ = cute.arch.block_idx() + batch_idx = q_l // heads_q + hq = q_l - batch_idx * heads_q + hk = hq // (heads_q // heads_k) + q_begin = mCuQ[batch_idx] + q_end = mCuQ[batch_idx + 1] + k_begin = mCuK[batch_idx] + k_end = mCuK[batch_idx + 1] + q_len = q_end - q_begin + k_len = k_end - k_begin + page_begin = mCuPages[batch_idx] + batch_k_tiles = (k_len + Int32(_PAGE_SIZE - 1)) // Int32(_PAGE_SIZE) + causal_offset = Int32(0) + if const_expr(self.is_causal): + causal_offset = k_len - q_len + if has_qo_offset != 0: + causal_offset = mQoOffset[batch_idx] + task_valid = True + q_tile_idx = Int32(0) + ktile_group = Int32(0) + if const_expr(self.compact_schedule): + remaining = task_idx + q_tile_count = (q_len + Int32(self.cta_tile_shape_mnk[0] - 1)) // Int32(self.cta_tile_shape_mnk[0]) + batch_k_group_count = (batch_k_tiles + Int32(self.k_tiles_per_cta - 1)) // Int32(self.k_tiles_per_cta) + q_scan = Int32(0) + task_valid = False + while q_scan < q_tile_count and not task_valid: + q_scan_start = q_scan * Int32(self.cta_tile_shape_mnk[0]) + q_scan_last = q_scan_start + Int32(self.cta_tile_shape_mnk[0] - 1) + if q_scan_last >= q_len: + q_scan_last = q_len - Int32(1) + visible_limit = q_scan_last + causal_offset + visible_group_count = Int32(0) + if visible_limit >= Int32(0): + visible_group_count = visible_limit // Int32(self.k_tiles_per_cta * _BLOCK_K) + Int32(1) + if visible_group_count > batch_k_group_count: + visible_group_count = batch_k_group_count + task_valid = remaining < visible_group_count + if not task_valid: + remaining -= visible_group_count + q_scan += Int32(1) + if task_valid: + q_tile_idx = q_scan + ktile_group = remaining + else: + q_len = Int32(0) + k_len = Int32(0) + else: + q_tile_idx = task_idx // k_group_count + ktile_group = task_idx - q_tile_idx * k_group_count + q_tile_start = q_tile_idx * Int32(self.cta_tile_shape_mnk[0]) + q_tile_last = q_tile_start + Int32(self.cta_tile_shape_mnk[0] - 1) + if q_tile_last >= q_len: + q_tile_last = q_len - Int32(1) + q_tile_full = q_tile_start + Int32(self.cta_tile_shape_mnk[0] - 1) < q_len + q_tile_global_start = q_begin + q_tile_start + q_scale_tma_safe = q_tile_global_start == (q_tile_global_start // Int32(128)) * Int32(128) + group_first_ktile = ktile_group * Int32(self.k_tiles_per_cta) + group_has_visible = self._group_has_visible( + q_tile_start, + q_tile_last, + q_len, + group_first_ktile, + batch_k_tiles, + causal_offset, + ) + + @cute.struct + class SharedStorage: + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + qs_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ_public = smem.allocate_tensor(_AB_DTYPE, q_smem_layout.outer, 128, swizzle=q_smem_layout.inner) + sK_public = smem.allocate_tensor(_AB_DTYPE, k_smem_layout.outer, 128, swizzle=k_smem_layout.inner) + sQS_public = smem.allocate_tensor(self.sf_dtype, q_scale_smem_layout, 128) + sKS_public = smem.allocate_tensor(self.sf_dtype, k_scale_smem_layout, 128) + mQ_tma = tma_q.tma_tensor + mQS_tma = tma_qs.tma_tensor + mK_tma = tma_k.tma_tensor + mKS_tma = tma_ks.tma_tensor + thr_mma = tiled_mma.get_slice(0) + tCsQ = thr_mma.partition_A(sQ_public) + tCsK = thr_mma.partition_B(sK_public) + mQ_tma_cur = cute.domain_offset((q_begin, 0, 0), mQ_tma) + gQ_tma = cute.local_tile( + mQ_tma_cur, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQ_tma = thr_mma.partition_A(gQ_tma) + tQsQ_tma, tQgQ_tma = cpasync.tma_partition( + tma_q.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQ_public, 0, 3), + cute.group_modes(tCgQ_tma, 0, 3), + ) + if const_expr(self.preordered_q_scale_tma): + mQS_tma_cur = cute.domain_offset((q_begin, 0, 0), mQS_tma) + gQS_tma = cute.local_tile( + mQS_tma_cur, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQS_tma = thr_mma.partition_A(gQS_tma) + tQsQS_tma, tQgQS_tma = cpasync.tma_partition( + tma_qs.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQS_public, 0, 3), + cute.group_modes(tCgQS_tma, 0, 3), + ) + tQsQS_tma = cute.filter_zeros(tQsQS_tma) + tQgQS_tma = cute.filter_zeros(tQgQS_tma) + gK_tma = cute.local_tile( + mK_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None, None), + ) + tCgK_tma = thr_mma.partition_B(gK_tma) + tKsK_tma, tKgK_tma = cpasync.tma_partition( + tma_k.atom, + 0, + cute.make_layout(1), + cute.group_modes(sK_public, 0, 3), + cute.group_modes(tCgK_tma, 0, 3), + ) + gKS_tma = cute.local_tile( + mKS_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + tCgKS_tma = thr_mma.partition_B(gKS_tma) + tKsKS_tma, tKgKS_tma = cpasync.tma_partition( + tma_ks.atom, + 0, + cute.make_layout(1), + cute.group_modes(sKS_public, 0, 3), + cute.group_modes(tCgKS_tma, 0, 3), + ) + tKsKS_tma = cute.filter_zeros(tKsKS_tma) + tKgKS_tma = cute.filter_zeros(tKgKS_tma) + sQS = sQS_public + sKS = sKS_public + + tCrQ = tiled_mma.make_fragment_A(sQ_public) + tCrK = tiled_mma.make_fragment_B(sK_public) + tCcC = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler[:2])) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + tmem = utils.TmemAllocator( + storage.tmem_holding_buf.ptr, + barrier_for_retrieve=pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * (self.mma_warp_id + 1), + ), + ) + + acc_pipeline = common_pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, self.epi_threads_per_cta), + defer_sync=True, + ) + acc_producer, _ = acc_pipeline.make_participants() + q_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_q.smem_layout) + k_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_k.smem_layout) + if const_expr(self.preordered_q_scale_tma): + qs_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_qs.smem_layout, mode=[0, 1, 2]), + ) + ks_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_ks.smem_layout, mode=[0, 1, 2]), + ) + k_pair_tma_copy_bytes = k_tma_copy_bytes + ks_tma_copy_bytes + q_producer, q_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.q_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=q_tma_copy_bytes, + defer_sync=True, + ).make_participants() + if const_expr(self.preordered_q_scale_tma): + qs_producer, qs_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.qs_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=qs_tma_copy_bytes, + defer_sync=True, + ).make_participants() + k_producer, k_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.k_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=k_pair_tma_copy_bytes, + defer_sync=True, + ).make_participants() + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + if warp_idx == self.load_warp_id: + if group_has_visible: + q_empty = q_producer.acquire_and_advance() + if const_expr(self.preordered_q_scale_tma): + if q_scale_tma_safe: + qs_empty = qs_producer.acquire_and_advance() + cute.copy( + tma_qs.atom, + tQgQS_tma[(None, q_tile_idx, 0, hq)], + tQsQS_tma[(None, qs_empty.index)], + tma_bar_ptr=qs_empty.barrier, + ) + qs_empty.commit() + else: + for row_base in cutlass.range(0, Int32(self.cta_tile_shape_mnk[0]), 32): + row = row_base + lane_idx + q_local = q_tile_start + row + row_major = row // Int32(32) + row_atom = row - row_major * Int32(32) + for group in cutlass.range_constexpr(self.scale_groups): + group_i = Int32(group) + mma_k = group_i // Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + group_in_mma_k = group_i - mma_k * Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + sf_coord = ((((row_atom, row_major), Int32(0)), (Int32(0), group_in_mma_k)), Int32(0), mma_k, Int32(0)) + q_scale_row = q_begin + q_local + if q_local >= q_len: + q_scale_row = q_begin + sQS[sf_coord] = mQS[q_scale_row, group_i * Int32(self.sf_vec_size), hq] + else: + for row_base in cutlass.range(0, Int32(self.cta_tile_shape_mnk[0]), 32): + row = row_base + lane_idx + q_local = q_tile_start + row + row_major = row // Int32(32) + row_atom = row - row_major * Int32(32) + for group in cutlass.range_constexpr(self.scale_groups): + group_i = Int32(group) + mma_k = group_i // Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + group_in_mma_k = group_i - mma_k * Int32(_MMA_INST_SHAPE_K // self.sf_vec_size) + sf_coord = ((((row_atom, row_major), Int32(0)), (Int32(0), group_in_mma_k)), Int32(0), mma_k, Int32(0)) + q_scale_row = q_begin + q_local + if q_local >= q_len: + q_scale_row = q_begin + sQS[sf_coord] = mQS[q_scale_row, group_i * Int32(self.sf_vec_size), hq] + cute.copy( + tma_q.atom, + tQgQ_tma[(None, q_tile_idx, 0, hq)], + tQsQ_tma[(None, q_empty.index)], + tma_bar_ptr=q_empty.barrier, + ) + q_empty.commit() + + if warp_idx == self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + # Move block scales into TMEM and issue one FP4 GEMM per visible K tile. + tCtQS_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(q_scale_smem_layout, (None, None, None, 0)), + ) + tCtKS_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(k_scale_smem_layout, (None, None, None, 0)), + ) + tCtQS = tmem_pool.allocate_tensor(tCtQS_layout, self.sf_dtype) + tCtKS = tmem_pool.allocate_tensor(tCtKS_layout, self.sf_dtype) + copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), self.sf_dtype) + tCsQS_compact = cute.filter_zeros(sQS) + tCtQS_compact = cute.filter_zeros(tCtQS) + tiled_copy_s2t_qs = tcgen05.make_s2t_copy(copy_atom_s2t, tCtQS_compact) + thr_copy_s2t_qs = tiled_copy_s2t_qs.get_slice(0) + tCsQS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_qs, + thr_copy_s2t_qs.partition_S(tCsQS_compact), + ) + tCtQS_compact_s2t = thr_copy_s2t_qs.partition_D(tCtQS_compact) + tCsKS_compact = cute.filter_zeros(sKS) + tCtKS_compact = cute.filter_zeros(tCtKS) + tiled_copy_s2t_ks = tcgen05.make_s2t_copy(copy_atom_s2t, tCtKS_compact) + thr_copy_s2t_ks = tiled_copy_s2t_ks.get_slice(0) + tCsKS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_ks, + thr_copy_s2t_ks.partition_S(tCsKS_compact), + ) + tCtKS_compact_s2t = thr_copy_s2t_ks.partition_D(tCtKS_compact) + if group_has_visible: + q_full = q_consumer.wait_and_advance() + if const_expr(self.preordered_q_scale_tma): + if q_scale_tma_safe: + qs_full = qs_consumer.wait_and_advance() + qs_full.release() + q_full.release() + cute.copy(tiled_copy_s2t_qs, tCsQS_compact_s2t[(None, None, None, None, 0)], tCtQS_compact_s2t) + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + q_tile_crd = (None, None, None, 0) + if const_expr(self.is_causal): + causal_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + causal_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + causal_group_full = causal_group_full and causal_group_last_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + ktile = Int32(0) + if causal_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_tile_start, + q_tile_last, + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + k_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + ktile = Int32(0) + if k_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < batch_k_tiles: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + acc_producer.tail() + + if warp_idx == self.load_warp_id: + if group_has_visible: + load_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + if const_expr(self.is_causal): + load_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + load_group_full = load_group_full and load_group_last_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset + ktile = Int32(0) + if load_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_tile_start, + q_tile_last, + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + k_producer.tail() + q_producer.tail() + if const_expr(self.preordered_q_scale_tma): + if q_scale_tma_safe: + qs_producer.tail() + + if warp_idx < self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + # Load accumulators from TMEM, reduce per-row max, and store scores. + if const_expr(self.use_tmem_load_red): + copy_atom_t2r = cute.make_copy_atom( + tcgen05.LdRed32x32bOp( + tcgen05.Repetition.x128, + tcgen05.Pack.NONE, + tcgen05.TmemLoadRedOp.MAX, + ), + Float32, + ) + else: + copy_atom_t2r = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE), + Float32, + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None, None, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx) + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + tTR_cC = thr_copy_t2r.partition_D(tCcC) + tTR_rAcc = cute.make_rmem_tensor(tTR_cC.shape, Float32) + if const_expr(self.use_tmem_load_red): + tTR_rRed = cute.make_rmem_tensor((1,), Float32) + q_local_store0 = q_tile_start + epi_tidx + q_global_store0 = q_begin + q_local_store0 + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + q_local_store1 = q_tile_start + epi_tidx + Int32(self.epi_threads_per_cta) + q_global_store1 = q_begin + q_local_store1 + if group_has_visible: + visible_tile_count = Int32(0) + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_tile_start, + q_tile_last, + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + epilogue_owns_tile = epi_warpgroup_idx == Int32( + ktile_inner % self.num_epi_warpgroups + ) + if epilogue_owns_tile: + acc_stage_index = visible_tile_count % Int32(self.num_acc_stage) + acc_stage_phase = (visible_tile_count // Int32(self.num_acc_stage)) % Int32(2) + tile_mask_free = self._tile_mask_free(q_tile_start, ktile, causal_offset) + k_tile_full = ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) < k_len + tile_full = q_tile_full and k_tile_full + acc_pipeline.consumer_wait_w_index_phase(acc_stage_index, acc_stage_phase) + tTR_tAcc_stage = tTR_tAcc[(None, None, None, None, acc_stage_index)] + if const_expr(self.use_tmem_load_red): + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, [tTR_rAcc, tTR_rRed]) + else: + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, tTR_rAcc) + row_max0 = -Float32.inf + row_max1 = -Float32.inf + if tile_mask_free: + if tile_full: + if const_expr(not self.use_tmem_load_red or self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, _ = tTR_cC[i] + if coord_m == epi_tidx: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if coord_m == epi_tidx + Int32(self.epi_threads_per_cta): + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + else: + row_max0 = tTR_rRed[0] + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + q_local = q_tile_start + coord_m + k_local = ktile * Int32(_BLOCK_K) + coord_n + if coord_m == epi_tidx and q_local < q_len and k_local < k_len: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if coord_m == epi_tidx + Int32(self.epi_threads_per_cta) and q_local < q_len and k_local < k_len: + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + else: + if tile_full: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + q_local = q_tile_start + coord_m + k_local = ktile * Int32(_BLOCK_K) + coord_n + if self._full_tile_coord_visible( + coord_m, + epi_tidx, + q_local, + k_local, + causal_offset, + ): + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if self._full_tile_coord_visible( + coord_m, + epi_tidx + Int32(self.epi_threads_per_cta), + q_local, + k_local, + causal_offset, + ): + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + q_local = q_tile_start + coord_m + k_local = ktile * Int32(_BLOCK_K) + coord_n + if self._partial_tile_coord_visible( + coord_m, + epi_tidx, + q_local, + k_local, + q_len, + k_len, + causal_offset, + ): + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if self._partial_tile_coord_visible( + coord_m, + epi_tidx + Int32(self.epi_threads_per_cta), + q_local, + k_local, + q_len, + k_len, + causal_offset, + ): + row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i]) + if q_tile_full: + mScores[hq, ktile, q_global_store0] = row_max0 + elif q_local_store0 < q_len: + mScores[hq, ktile, q_global_store0] = row_max0 + if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta): + if q_tile_full: + mScores[hq, ktile, q_global_store1] = row_max1 + elif q_local_store1 < q_len: + mScores[hq, ktile, q_global_store1] = row_max1 + cute.arch.fence_view_async_tmem_load() + acc_pipeline.consumer_release_w_index(acc_stage_index) + visible_tile_count += Int32(1) + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + if q_tile_full: + mScores[hq, ktile, q_global_store0] = -Float32.inf + elif q_local_store0 < q_len: + mScores[hq, ktile, q_global_store0] = -Float32.inf + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = ktile_group * Int32(self.k_tiles_per_cta) + Int32(ktile_inner) + if ktile < max_k_tiles: + if q_tile_full: + mScores[hq, ktile, q_global_store0] = -Float32.inf + elif q_local_store0 < q_len: + mScores[hq, ktile, q_global_store0] = -Float32.inf + cute.arch.barrier() + tmem.free(tmem_pool.base_ptr) + +class Fp4IndexerDecodeQPackSm100: + """Pack decode Q rows as ``[B * Hk, 128, 64]`` and pack Q scales to MMA storage.""" + + def __init__(self, *, fmt: str): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.sf_dtype = spec.cutlass_scale_dtype + self.scale_groups = spec.scale_groups + self.threads_per_cta = 256 + + @cute.jit + def __call__( + self, + q_ptr: cute.Pointer, + q_scale_ptr: cute.Pointer, + q_pack_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + total_q, heads_q, heads_k, batch = problem_size + rest_q_m = cute.ceil_div(total_q, 128) + rest_g = ceil_div(self.scale_groups, 4) + q = cute.make_tensor( + q_ptr, + cute.make_layout( + (total_q, heads_q, _FP4_PACKED_D_BYTES), + stride=(heads_q * _FP4_PACKED_D_BYTES, _FP4_PACKED_D_BYTES, 1), + ), + ) + q_scale = cute.make_tensor( + q_scale_ptr, + cute.make_layout( + (heads_q, rest_q_m, rest_g, 32, 4, 4), + stride=(512 * rest_q_m * rest_g, 512 * rest_g, 512, 16, 4, 1), + ), + ) + q_pack_l = batch * heads_k + q_pack = cute.make_tensor( + q_pack_ptr, + cute.make_layout( + (q_pack_l, _PAGE_SIZE, _FP4_PACKED_D_BYTES), + stride=(_PAGE_SIZE * _FP4_PACKED_D_BYTES, _FP4_PACKED_D_BYTES, 1), + ), + ) + q_scale_pack = cute.make_tensor( + q_scale_pack_ptr, + cute.make_layout( + (q_pack_l, 1, rest_g, 32, 4, 4), + stride=(512 * rest_g, 512 * rest_g, 512, 16, 4, 1), + ), + ) + cu_q = cute.make_tensor(cu_seqlens_q_ptr, cute.make_layout((batch + 1,), stride=(1,))) + self.kernel(q, q_scale, q_pack, q_scale_pack, cu_q, heads_q, heads_k).launch( + grid=(q_pack_l, 1, 1), + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mQS: cute.Tensor, + mQPack: cute.Tensor, + mQSPack: cute.Tensor, + mCuQ: cute.Tensor, + heads_q: Int32, + heads_k: Int32, + ): + tidx, _, _ = cute.arch.thread_idx() + q_pack_l, _, _ = cute.arch.block_idx() + batch_idx = q_pack_l // heads_k + hk = q_pack_l - batch_idx * heads_k + q_begin = mCuQ[batch_idx] + q_end = mCuQ[batch_idx + 1] + q_len = q_end - q_begin + qhead_per_kv = heads_q // heads_k + + linear = tidx + while linear < Int32(_PAGE_SIZE * _FP4_PACKED_D_BYTES): + row = linear // Int32(_FP4_PACKED_D_BYTES) + byte = linear - row * Int32(_FP4_PACKED_D_BYTES) + h_in_group = row // Int32(_DECODE_PACK_Q_LEN) + q_local = row - h_in_group * Int32(_DECODE_PACK_Q_LEN) + hq = hk * qhead_per_kv + h_in_group + if q_local < q_len and h_in_group < qhead_per_kv: + mQPack[q_pack_l, row, byte] = mQ[q_begin + q_local, hq, byte] + else: + mQPack[q_pack_l, row, byte] = cutlass.Uint8(0) + linear += Int32(self.threads_per_cta) + + scale_linear = tidx + while scale_linear < Int32(_PAGE_SIZE * self.scale_groups): + row = scale_linear // Int32(self.scale_groups) + group = scale_linear - row * Int32(self.scale_groups) + h_in_group = row // Int32(_DECODE_PACK_Q_LEN) + q_local = row - h_in_group * Int32(_DECODE_PACK_Q_LEN) + hq = hk * qhead_per_kv + h_in_group + q_abs = q_begin + q_local + if q_local >= q_len or h_in_group >= qhead_per_kv: + q_abs = q_begin + hq = hk * qhead_per_kv + src_rest_m = q_abs // Int32(128) + src_row = q_abs - src_rest_m * Int32(128) + src_row_atom = src_row % Int32(32) + src_row_major = src_row // Int32(32) + dst_row_atom = row % Int32(32) + dst_row_major = row // Int32(32) + rest_g = group // Int32(4) + group_in_rest = group - rest_g * Int32(4) + mQSPack[q_pack_l, Int32(0), rest_g, dst_row_atom, dst_row_major, group_in_rest] = mQS[ + hq, src_rest_m, rest_g, src_row_atom, src_row_major, group_in_rest + ] + scale_linear += Int32(self.threads_per_cta) + + +class Fp4IndexerDecodePackedQSm100: + """Decode score kernel with M packed as ``qhead_per_kv * q_len == 128``.""" + + def __init__(self, *, fmt: str, causal: bool, compact_schedule: bool, use_tmem_load_red: bool = False): + spec = normalize_fp4_format(fmt) + self.fmt = spec.name + self.is_causal = bool(causal) + self.compact_schedule = bool(compact_schedule) + self.use_tmem_load_red = bool(use_tmem_load_red) + self.sf_vec_size = spec.sf_vec_size + self.sf_dtype = spec.cutlass_scale_dtype + self.use_nvfp4 = spec.name == "nvfp4" + self.epi_threads_per_cta = 128 + self.epi_warps_per_group = 4 + self.num_epi_warpgroups = 2 + self.mma_warp_id = self.epi_warps_per_group * self.num_epi_warpgroups + self.load_warp_id = self.mma_warp_id + 1 + self.threads_per_cta = 384 + self.num_tmem_alloc_cols = 512 + self.num_q_stage = 1 + self.num_acc_stage = 3 + self.num_ab_stage = 3 + self.k_tiles_per_cta = _DECODE_K_TILES_PER_CTA + self.mma_tiler = (_MMA_TILER_MN[0], _MMA_TILER_MN[1], _MMA_INST_SHAPE_K * 2) + self.cta_tile_shape_mnk = self.mma_tiler + + @cute.jit + def __call__( + self, + q_pack_ptr: cute.Pointer, + k_ptr: cute.Pointer, + q_scale_pack_ptr: cute.Pointer, + k_scale_ptr: cute.Pointer, + scores_ptr: cute.Pointer, + kv_indices_ptr: cute.Pointer, + cu_seqlens_q_ptr: cute.Pointer, + cu_seqlens_k_ptr: cute.Pointer, + cu_page_offsets_ptr: cute.Pointer, + qo_offset_ptr: cute.Pointer, + problem_size: tuple, + stream: cuda.CUstream, + ): + ( + _, + _, + _, + _, + lk, + heads_q, + heads_k, + batch, + max_k_tiles, + total_q, + has_qo_offset, + ) = problem_size + page_count = lk // heads_k + q_pack_l = batch * heads_k + q_tma_tensor = cute.make_tensor( + cute.recast_ptr(q_pack_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (_PAGE_SIZE, _HEAD_DIM, q_pack_l), + stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM), + ), + ) + k_tma_tensor = cute.make_tensor( + cute.recast_ptr(k_ptr, dtype=_AB_DTYPE), + cute.make_layout( + (_PAGE_SIZE, _HEAD_DIM, heads_k, page_count), + stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM, heads_k * _PAGE_SIZE * _HEAD_DIM), + ), + ) + q_scale_tensor = cute.make_tensor( + q_scale_pack_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (_PAGE_SIZE, _HEAD_DIM, q_pack_l), + self.sf_vec_size, + ), + ) + k_scale_tensor = cute.make_tensor( + k_scale_ptr, + blockscaled_utils.tile_atom_to_shape_SF( + (_PAGE_SIZE, _HEAD_DIM, page_count * heads_k), + self.sf_vec_size, + ), + ) + scores_tensor = cute.make_tensor( + scores_ptr, + cute.make_layout((heads_q, max_k_tiles, total_q), stride=(max_k_tiles * total_q, total_q, 1)), + ) + kv_indices_tensor = cute.make_tensor(kv_indices_ptr, cute.make_layout((page_count,), stride=(1,))) + cu_layout = cute.make_layout((batch + 1,), stride=(1,)) + cu_q_tensor = cute.make_tensor(cu_seqlens_q_ptr, cu_layout) + cu_k_tensor = cute.make_tensor(cu_seqlens_k_ptr, cu_layout) + cu_page_offsets_tensor = cute.make_tensor(cu_page_offsets_ptr, cu_layout) + qo_offset_tensor = cute.make_tensor(qo_offset_ptr, cute.make_layout((batch,), stride=(1,))) + + if const_expr(self.use_nvfp4): + mma_op = tcgen05.MmaMXF4NVF4Op( + self.sf_dtype, + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + else: + mma_op = tcgen05.MmaMXF4Op( + (*_MMA_TILER_MN, _MMA_INST_SHAPE_K), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + ) + tiled_mma = cute.make_tiled_mma(mma_op) + q_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_q_stage) + k_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_ab_stage) + q_scale_smem_layout = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_q_stage, + ) + k_scale_smem_layout = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + cluster_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + q_smem_layout_stage = cute.slice_(q_smem_layout, (None, None, None, 0)) + k_smem_layout_stage = cute.slice_(k_smem_layout, (None, None, None, 0)) + tma_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_tma_tensor, + q_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + tma_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_tma_tensor, + k_smem_layout_stage, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + ) + tma_qs = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_scale_tensor, + q_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + tma_ks = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k_scale_tensor, + k_scale_smem_layout, + self.mma_tiler, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + grid_k_groups = cute.ceil_div(max_k_tiles, self.k_tiles_per_cta) + compact_k_groups = cute.ceil_div(page_count + batch * (self.k_tiles_per_cta - 1), self.k_tiles_per_cta) + if const_expr(self.compact_schedule): + grid = (compact_k_groups, heads_k, 1) + else: + grid = (grid_k_groups, batch * heads_k, 1) + self.kernel( + tiled_mma, + tma_q, + tma_qs, + tma_k, + tma_ks, + scores_tensor, + kv_indices_tensor, + cu_q_tensor, + cu_k_tensor, + cu_page_offsets_tensor, + qo_offset_tensor, + q_smem_layout, + k_smem_layout, + q_scale_smem_layout, + k_scale_smem_layout, + heads_q, + heads_k, + batch, + has_qo_offset, + max_k_tiles, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(1, 1, 1), + stream=stream, + ) + + @cute.jit + def _group_has_visible( + self, + q_len: Int32, + group_first_ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = q_len > Int32(0) and group_first_ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and group_first_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + return visible + + @cute.jit + def _tile_has_visible( + self, + q_len: Int32, + ktile: Int32, + batch_k_tiles: Int32, + causal_offset: Int32, + ): + visible = ktile < batch_k_tiles + if const_expr(self.is_causal): + visible = visible and ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + return visible + + @cute.jit + def _tile_mask_free(self, ktile: Int32, causal_offset: Int32): + if const_expr(self.is_causal): + return ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) <= causal_offset + return True + + @cute.jit + def _packed_coord_visible( + self, + coord_m: Int32, + target_m: Int32, + h_in_group: Int32, + qhead_per_kv: Int32, + q_local: Int32, + q_len: Int32, + k_local: Int32, + k_len: Int32, + causal_offset: Int32, + ): + visible = coord_m == target_m and h_in_group < qhead_per_kv and q_local < q_len and k_local < k_len + if const_expr(self.is_causal): + visible = visible and k_local <= q_local + causal_offset + return visible + + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_q: cpasync.TmaInfo, + tma_qs: cpasync.TmaInfo, + tma_k: cpasync.TmaInfo, + tma_ks: cpasync.TmaInfo, + mScores: cute.Tensor, + mKvIndices: cute.Tensor, + mCuQ: cute.Tensor, + mCuK: cute.Tensor, + mCuPages: cute.Tensor, + mQoOffset: cute.Tensor, + q_smem_layout: cute.ComposedLayout, + k_smem_layout: cute.ComposedLayout, + q_scale_smem_layout: cute.Layout, + k_scale_smem_layout: cute.Layout, + heads_q: Int32, + heads_k: Int32, + batch: Int32, + has_qo_offset: Int32, + max_k_tiles: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + epi_tidx = tidx % Int32(self.epi_threads_per_cta) + epi_warpgroup_idx = warp_idx // Int32(self.epi_warps_per_group) + task_x, task_y, _ = cute.arch.block_idx() + task_valid = True + batch_idx = Int32(0) + hk = Int32(0) + ktile_group = Int32(0) + q_l = Int32(0) + if const_expr(self.compact_schedule): + hk = task_y + group_base = Int32(0) + scan_batch = Int32(0) + task_valid = False + while scan_batch < batch and not task_valid: + batch_pages = mCuPages[scan_batch + Int32(1)] - mCuPages[scan_batch] + batch_groups = (batch_pages + Int32(self.k_tiles_per_cta - 1)) // Int32(self.k_tiles_per_cta) + task_valid = task_x < group_base + batch_groups + if not task_valid: + group_base += batch_groups + scan_batch += Int32(1) + if task_valid: + batch_idx = scan_batch + ktile_group = task_x - group_base + q_l = batch_idx * heads_k + hk + else: + ktile_group = task_x + q_l = task_y + batch_idx = q_l // heads_k + hk = q_l - batch_idx * heads_k + qhead_per_kv = heads_q // heads_k + q_begin = mCuQ[batch_idx] + q_end = mCuQ[batch_idx + 1] + k_begin = mCuK[batch_idx] + k_end = mCuK[batch_idx + 1] + q_len = q_end - q_begin + k_len = k_end - k_begin + if const_expr(self.compact_schedule): + if not task_valid: + q_len = Int32(0) + k_len = Int32(0) + page_begin = mCuPages[batch_idx] + batch_k_tiles = (k_len + Int32(_PAGE_SIZE - 1)) // Int32(_PAGE_SIZE) + causal_offset = Int32(0) + if const_expr(self.is_causal): + causal_offset = k_len - q_len + if has_qo_offset != 0: + causal_offset = mQoOffset[batch_idx] + group_first_ktile = ktile_group * Int32(self.k_tiles_per_cta) + group_has_visible = self._group_has_visible( + q_len, + group_first_ktile, + batch_k_tiles, + causal_offset, + ) + + @cute.struct + class SharedStorage: + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ_public = smem.allocate_tensor(_AB_DTYPE, q_smem_layout.outer, 128, swizzle=q_smem_layout.inner) + sK_public = smem.allocate_tensor(_AB_DTYPE, k_smem_layout.outer, 128, swizzle=k_smem_layout.inner) + sQS_public = smem.allocate_tensor(self.sf_dtype, q_scale_smem_layout, 128) + sKS_public = smem.allocate_tensor(self.sf_dtype, k_scale_smem_layout, 128) + mQ_tma = tma_q.tma_tensor + mQS_tma = tma_qs.tma_tensor + mK_tma = tma_k.tma_tensor + mKS_tma = tma_ks.tma_tensor + thr_mma = tiled_mma.get_slice(0) + tCrQ = tiled_mma.make_fragment_A(sQ_public) + tCrK = tiled_mma.make_fragment_B(sK_public) + tCcC = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler[:2])) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + gQ_tma = cute.local_tile( + mQ_tma, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQ_tma = thr_mma.partition_A(gQ_tma) + tQsQ_tma, tQgQ_tma = cpasync.tma_partition( + tma_q.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQ_public, 0, 3), + cute.group_modes(tCgQ_tma, 0, 3), + ) + gQS_tma = cute.local_tile( + mQS_tma, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + tCgQS_tma = thr_mma.partition_A(gQS_tma) + tQsQS_tma, tQgQS_tma = cpasync.tma_partition( + tma_qs.atom, + 0, + cute.make_layout(1), + cute.group_modes(sQS_public, 0, 3), + cute.group_modes(tCgQS_tma, 0, 3), + ) + tQsQS_tma = cute.filter_zeros(tQsQS_tma) + tQgQS_tma = cute.filter_zeros(tQgQS_tma) + gK_tma = cute.local_tile( + mK_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None, None), + ) + tCgK_tma = thr_mma.partition_B(gK_tma) + tKsK_tma, tKgK_tma = cpasync.tma_partition( + tma_k.atom, + 0, + cute.make_layout(1), + cute.group_modes(sK_public, 0, 3), + cute.group_modes(tCgK_tma, 0, 3), + ) + gKS_tma = cute.local_tile( + mKS_tma, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + tCgKS_tma = thr_mma.partition_B(gKS_tma) + tKsKS_tma, tKgKS_tma = cpasync.tma_partition( + tma_ks.atom, + 0, + cute.make_layout(1), + cute.group_modes(sKS_public, 0, 3), + cute.group_modes(tCgKS_tma, 0, 3), + ) + tKsKS_tma = cute.filter_zeros(tKsKS_tma) + tKgKS_tma = cute.filter_zeros(tKgKS_tma) + + tmem = utils.TmemAllocator( + storage.tmem_holding_buf.ptr, + barrier_for_retrieve=pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * (self.mma_warp_id + 1), + ), + ) + acc_pipeline = common_pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, self.epi_threads_per_cta), + defer_sync=True, + ) + acc_producer, _ = acc_pipeline.make_participants() + q_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_q.smem_layout) + qs_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_qs.smem_layout, mode=[0, 1, 2]), + ) + k_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_k.smem_layout) + ks_tma_copy_bytes = cute.size_in_bytes( + self.sf_dtype, + cute.select(tma_ks.smem_layout, mode=[0, 1, 2]), + ) + q_pair_tma_copy_bytes = q_tma_copy_bytes + qs_tma_copy_bytes + k_pair_tma_copy_bytes = k_tma_copy_bytes + ks_tma_copy_bytes + q_producer, q_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.q_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=q_pair_tma_copy_bytes, + defer_sync=True, + ).make_participants() + k_producer, k_consumer = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.k_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=k_pair_tma_copy_bytes, + defer_sync=True, + ).make_participants() + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + + if warp_idx == self.load_warp_id: + if group_has_visible: + q_pair_empty = q_producer.acquire_and_advance() + cute.copy( + tma_q.atom, + tQgQ_tma[(None, 0, 0, q_l)], + tQsQ_tma[(None, q_pair_empty.index)], + tma_bar_ptr=q_pair_empty.barrier, + ) + cute.copy( + tma_qs.atom, + tQgQS_tma[(None, 0, 0, q_l)], + tQsQS_tma[(None, q_pair_empty.index)], + tma_bar_ptr=q_pair_empty.barrier, + ) + q_pair_empty.commit() + load_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + if const_expr(self.is_causal): + load_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + load_group_full = load_group_full and load_group_last_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + if load_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_empty = k_producer.acquire_and_advance() + physical_page = mKvIndices[page_begin + ktile] + cute.copy( + tma_k.atom, + tKgK_tma[(None, 0, 0, hk, physical_page)], + tKsK_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + scale_l = physical_page * heads_k + hk + cute.copy( + tma_ks.atom, + tKgKS_tma[(None, 0, 0, scale_l)], + tKsKS_tma[(None, k_pair_empty.index)], + tma_bar_ptr=k_pair_empty.barrier, + ) + k_pair_empty.commit() + k_producer.tail() + q_producer.tail() + + if warp_idx == self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + tCtQS_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(q_scale_smem_layout, (None, None, None, 0)), + ) + tCtKS_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(k_scale_smem_layout, (None, None, None, 0)), + ) + tCtQS = tmem_pool.allocate_tensor(tCtQS_layout, self.sf_dtype) + tCtKS = tmem_pool.allocate_tensor(tCtKS_layout, self.sf_dtype) + copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), self.sf_dtype) + tCsQS_compact = cute.filter_zeros(sQS_public) + tCtQS_compact = cute.filter_zeros(tCtQS) + tiled_copy_s2t_qs = tcgen05.make_s2t_copy(copy_atom_s2t, tCtQS_compact) + thr_copy_s2t_qs = tiled_copy_s2t_qs.get_slice(0) + tCsQS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_qs, + thr_copy_s2t_qs.partition_S(tCsQS_compact), + ) + tCtQS_compact_s2t = thr_copy_s2t_qs.partition_D(tCtQS_compact) + tCsKS_compact = cute.filter_zeros(sKS_public) + tCtKS_compact = cute.filter_zeros(tCtKS) + tiled_copy_s2t_ks = tcgen05.make_s2t_copy(copy_atom_s2t, tCtKS_compact) + thr_copy_s2t_ks = tiled_copy_s2t_ks.get_slice(0) + tCsKS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_ks, + thr_copy_s2t_ks.partition_S(tCsKS_compact), + ) + tCtKS_compact_s2t = thr_copy_s2t_ks.partition_D(tCtKS_compact) + if group_has_visible: + q_pair_full = q_consumer.wait_and_advance() + q_pair_full.release() + cute.copy(tiled_copy_s2t_qs, tCsQS_compact_s2t[(None, None, None, None, 0)], tCtQS_compact_s2t) + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + q_tile_crd = (None, None, None, 0) + if const_expr(self.is_causal): + causal_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + causal_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1) + causal_group_full = causal_group_full and causal_group_last_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset + if causal_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + k_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles + if k_group_full: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + else: + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < batch_k_tiles: + k_pair_full = k_consumer.wait_and_advance() + acc_empty = acc_producer.acquire_and_advance() + cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t) + k_tile_crd = (None, None, None, k_pair_full.index) + tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)] + cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage) + acc_empty.commit() + k_pair_full.release() + acc_producer.tail() + + if warp_idx < self.mma_warp_id: + tmem_pool = tmem.reserve(self.num_tmem_alloc_cols) + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32) + if const_expr(self.use_tmem_load_red): + copy_atom_t2r = cute.make_copy_atom( + tcgen05.LdRed32x32bOp( + tcgen05.Repetition.x128, + tcgen05.Pack.NONE, + tcgen05.TmemLoadRedOp.MAX, + ), + Float32, + ) + else: + copy_atom_t2r = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE), + Float32, + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None, None, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx) + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + tTR_cC = thr_copy_t2r.partition_D(tCcC) + tTR_rAcc = cute.make_rmem_tensor(tTR_cC.shape, Float32) + if const_expr(self.use_tmem_load_red): + tTR_rRed = cute.make_rmem_tensor((1,), Float32) + h_store = epi_tidx // Int32(_DECODE_PACK_Q_LEN) + q_local_store = epi_tidx - h_store * Int32(_DECODE_PACK_Q_LEN) + h_global_store = hk * qhead_per_kv + h_store + q_global_store = q_begin + q_local_store + if group_has_visible: + visible_tile_count = Int32(0) + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + tile_has_visible = self._tile_has_visible( + q_len, + ktile, + batch_k_tiles, + causal_offset, + ) + if tile_has_visible: + epilogue_owns_tile = epi_warpgroup_idx == Int32( + ktile_inner % self.num_epi_warpgroups + ) + if epilogue_owns_tile: + acc_stage_index = visible_tile_count % Int32(self.num_acc_stage) + acc_stage_phase = (visible_tile_count // Int32(self.num_acc_stage)) % Int32(2) + tile_mask_free = self._tile_mask_free(ktile, causal_offset) + k_tile_full = ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) < k_len + q_pack_full = q_len == Int32(_DECODE_PACK_Q_LEN) + tile_full = q_pack_full and k_tile_full + acc_pipeline.consumer_wait_w_index_phase(acc_stage_index, acc_stage_phase) + tTR_tAcc_stage = tTR_tAcc[(None, None, None, None, acc_stage_index)] + if const_expr(self.use_tmem_load_red): + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, [tTR_rAcc, tTR_rRed]) + else: + cute.copy(tiled_copy_t2r, tTR_tAcc_stage, tTR_rAcc) + row_max0 = -Float32.inf + if tile_mask_free and tile_full: + if const_expr(self.use_tmem_load_red): + row_max0 = tTR_rRed[0] + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, _ = tTR_cC[i] + if coord_m == epi_tidx: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + else: + for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True): + coord_m, coord_n = tTR_cC[i] + h_in_group = coord_m // Int32(_DECODE_PACK_Q_LEN) + q_local = coord_m - h_in_group * Int32(_DECODE_PACK_Q_LEN) + k_local = ktile * Int32(_BLOCK_K) + coord_n + valid = self._packed_coord_visible( + coord_m, + epi_tidx, + h_in_group, + qhead_per_kv, + q_local, + q_len, + k_local, + k_len, + causal_offset, + ) + if valid: + row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i]) + if h_store < qhead_per_kv and q_local_store < q_len: + mScores[h_global_store, ktile, q_global_store] = row_max0 + cute.arch.fence_view_async_tmem_load() + acc_pipeline.consumer_release_w_index(acc_stage_index) + visible_tile_count += Int32(1) + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + if h_store < qhead_per_kv and q_local_store < q_len: + mScores[h_global_store, ktile, q_global_store] = -Float32.inf + else: + if const_expr(not self.compact_schedule): + if epi_warpgroup_idx == Int32(0): + for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta): + ktile = group_first_ktile + Int32(ktile_inner) + if ktile < max_k_tiles: + if h_store < qhead_per_kv and q_local_store < q_len: + mScores[h_global_store, ktile, q_global_store] = -Float32.inf + cute.arch.barrier() + tmem.free(tmem_pool.base_ptr) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18b99aea3f8b4915c03fe8147127374d920970f3 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 forward kernels and combine paths.""" + +from .atten_fwd_nvfp4_kv import SparseAttentionForwardNvfp4KvSm100 + +__all__ = ["SparseAttentionForwardNvfp4KvSm100"] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd/atten_fwd.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd/atten_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..531b27c9e6b4bd8c1bc74fb1f92ed98a192ca0b2 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd/atten_fwd.py @@ -0,0 +1,3020 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 sparse attention forward kernel. + +This kernel implements the delivered Sparse Attention / Sparse Page Attention +forward contract: +- CSR sparse metadata (`k2q_row_ptr`, `k2q_q_indices`) +- varlen Q metadata via `cu_seqlens_q` +- Sparse Attention with flat varlen K/V +- Sparse Page Attention with paged K/V +""" + +import math +from functools import partial +from typing import Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.pipeline as cutlass_pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.cutlass_dsl import BaseDSL + +from ....quack import copy_utils + +from ....src.common.cute_dsl_utils import assume_tensor_aligned +from ....src.common import utils +from ....src.common import pipeline +from ....src.common import mma_sm100_desc as sm100_desc +from ....src.common import blackwell_helpers as sm100_helpers +from ....src.common.softmax import SoftmaxSm100 +from ....src.common.named_barrier import NamedBarrierFwdSm100 +from ....src.common.mask import AttentionMask +from ....src.common.seqlen_info import SeqlenInfoQK +# Shared raw PTX helpers and layout conversions used by the lean kernel. +from ....src.common.paged_kv import PagedKVManager +from ....src.common.tma_utils import ( + tma_gather4_cached, + tma_gather4_prefetch, + prefetch_tma_desc_raw, + TMA_CACHE_EVICT_LAST, + make_16x256b_tensor_mn_view, + real_col_to_stg128_fake_col, + real_col_to_stg128_fp8_fake_col, + real_col_to_stg128_half_fake_col, + stg_128_cs, + stg_128_bf16_cs, + stg_128_f16_cs, + stg_128_fp8_e4m3_cs, +) + + +class SparseAttentionForwardSm100: + """SM100 sparse attention forward kernel.""" + + k_tile = 64 # UTCMMA bf16 K-tile (matches sparse_fwd_utcmma.py) + + def __init__( + self, + head_dim: int = 128, + qheadperkv: int = 16, + m_block_size: int = 128, + n_block_size: int = 128, + paged_kv: bool = False, + page_size: Optional[int] = None, + has_seqused_k: bool = False, + causal: bool = False, + use_prepare_scheduler: bool = True, + qk_dtype=None, + pv_dtype=None, + ): + if head_dim != 128: + raise NotImplementedError( + f"SparseAttentionForwardSm100 currently supports only D=128, got D={head_dim}" + ) + self.head_dim = 128 + self.qheadperkv = qheadperkv + self.use_q_gather4 = qheadperkv in (4, 2, 1) + if qheadperkv not in (16, 8, 4, 2, 1): + raise ValueError( + "SparseAttentionForwardSm100 supports qheadperkv in " + f"{{1, 2, 4, 8, 16}}, got {qheadperkv}" + ) + self.tokens_per_gather4 = 4 // qheadperkv if self.use_q_gather4 else 0 + self.m_block_size = m_block_size # 128 packed Q heads + self.n_block_size = n_block_size # 128 KV-block width + self.paged_kv = paged_kv + self.page_size = page_size + self.has_seqused_k = has_seqused_k + self.causal = causal + self.qk_dtype_param = qk_dtype + self.pv_dtype_param = pv_dtype + if not use_prepare_scheduler: + raise ValueError("SparseAttentionForwardSm100 requires prepare scheduler") + self.use_prepare_scheduler = True + if self.paged_kv: + if page_size is None: + raise ValueError("page_size must be provided when paged_kv=True") + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal blk_kv ({n_block_size})" + ) + else: + self.page_size = n_block_size + self.q_tokens_per_group = m_block_size // qheadperkv # 8 + + self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim) + self.mma_tiler_pv = (m_block_size, self.head_dim, n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + + # Pipeline configuration — deeper Q prefetch ring plus 2-slot S/O rings. + self.q_stage = 2 + self.s_stage = 2 + self.o_stage = 2 + self.kv_stage = 1 + # Sparse q_idx metadata ring bridging load -> epilogue. Sized larger + # than the in-flight group distance so epilogue can reuse q_idx + # without rereading mK2qIndices. + self.qidx_meta_stages = 16 + + self.k_stages = 2 + self.q_stage_stride_bytes = m_block_size * self.head_dim * 2 + self.k_tile_stride_bytes = m_block_size * self.k_tile * 2 + self.token_stride_bytes = qheadperkv * self.k_tile * 2 + + # Warp layout: two softmax WGs, one Q-load/epilogue WG, one + # MMA issue warp, two K/V load warps, and one empty warp. + self.warps_per_group = 4 + self.softmax0_warp_base = 0 + self.softmax1_warp_base = self.softmax0_warp_base + self.warps_per_group + self.store_warp_base = self.softmax1_warp_base + self.warps_per_group + self.mma_warp_id = self.store_warp_base + self.warps_per_group + self.load_warp_base = self.mma_warp_id + 1 + self.q_load_warp_base = self.store_warp_base + self.kv_load_warp_base = self.load_warp_base + self.num_kv_load_warps = 2 + self.num_q_load_warps = self.warps_per_group + self.total_warps = 16 + self.threads_per_cta = cute.arch.WARP_SIZE * self.total_warps # 512 + + # TMEM layout follows FA SM100: + # S0/S1: [0:128], [128:256] + # O0/O1: [256:384], [384:512] for hdim_v=128 + # P dtype follows the PV operand policy and is packed into each S tile. + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + self.tmem_s_offset = 0 + self.tmem_stage_stride = n_block_size + self.tmem_o_stage_stride = self.head_dim + self.tmem_o_offset = self.s_stage * n_block_size + self.tmem_s_to_p_offset = n_block_size // 2 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + raw_tmem_total = self.tmem_o_offset + self.o_stage * self.tmem_o_stage_stride + # SM100 TMEM allocation requires a power-of-two column count. The + # 128-wide path naturally uses 512 columns; 64-wide KV blocks use 384 + # columns and must round the allocation up while keeping the same + # logical offsets. + self.tmem_total = 1 << (raw_tmem_total - 1).bit_length() + + # Let PV start once the first 3/4 of P is visible in TMEM. The final + # split is synchronized by a separate mbarrier consumed inside PV MMA. + self.split_P_arrive = n_block_size // 4 * 3 + self.split_P_arrive = int(self.split_P_arrive / 32) * 32 + assert self.split_P_arrive % 32 == 0 + assert self.split_P_arrive < self.n_block_size + + # Register allocation per role. The causal hdim128 split gives the + # epilogue enough room for partial-O/LSE address generation while the + # two softmax WGs still have enough registers to avoid S/P spills. + self.num_regs_softmax = 176 if causal else 192 + self.num_regs_store = 112 if causal else 80 + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_store + self.num_regs_mma = self.num_regs_other + self.num_regs_load = self.num_regs_other + self.num_regs_empty = self.num_regs_other + self.store_reg_decrease = self.num_regs_store <= 128 + self.ex2_emu_freq = 16 if causal else 0 + self.ex2_emu_start_frg = 1 + self.buffer_align_bytes = 1024 + + # SM100 config. + self.use_2cta_instrs = False + self.cta_group_size = 1 + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnk = (1, 1, 1) + + self.arch = BaseDSL._get_dsl().get_arch_enum() + + @cute.jit + def _batch_q_offset( + self, + batch_idx: Int32, + mCuSeqlensQ, + ) -> Int32: + return mCuSeqlensQ[batch_idx] + + @cute.jit + def _logical_seqlen_k( + self, + batch_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + if const_expr(self.has_seqused_k): + return mSeqUsedK[batch_idx] + if const_expr(self.paged_kv): + return Int32(mPageTable.shape[1]) * Int32(self.page_size) + return mCuSeqlensK[batch_idx + Int32(1)] - mCuSeqlensK[batch_idx] + + @cute.jit + def _valid_cols_in_block( + self, + batch_idx: Int32, + kv_block_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + seqlen_k = self._logical_seqlen_k( + batch_idx, mPageTable, mSeqUsedK, mCuSeqlensK + ) + block_start = kv_block_idx * Int32(self.n_block_size) + remaining = seqlen_k - block_start + remaining = cutlass.max(remaining, Int32(0)) + return cutlass.min(remaining, Int32(self.n_block_size)) + + @cute.jit + def _load_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _load_qsplit_idx( + self, + mK2qQSplitIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qQSplitIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _decode_q_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return qsplit & Int32(0x00FF_FFFF) + + @cute.jit + def _decode_split_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return (qsplit >> Int32(24)) & Int32(0xFF) + + @cute.jit + def _lower_bound_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + count: Int32, + q_value: Int32, + ) -> Int32: + left = Int32(0) + right = count + # k2q_q_indices is sorted by q_idx within each CSR row. A fixed + # 32-step loop covers int32-sized rows and keeps this CTA-level. + for _ in cutlass.range(32, unroll=1): + if left < right: + mid = (left + right) // Int32(2) + q_idx = self._load_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + mid, + ) + if q_idx < q_value: + left = mid + Int32(1) + else: + right = mid + return left + + # ------------------------------------------------------------------ + # Host-side: TMA descriptors, SMEM layout, launch + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + mK: cute.Tensor, # Sparse Attention: [total_k, head_kv, dim] / Sparse Page Attention: prepared paged KV tensor + mV: cute.Tensor, # Sparse Attention: [total_k, head_kv, dim] / Sparse Page Attention: prepared paged KV tensor + mK2qIndices: cute.Tensor, # csr payload: [head_kv, nnz] + mK2qQSplitIndices: cute.Tensor, # csr payload: [head_kv, nnz] packed q_idx/split slot + mK2qCounts: cute.Tensor, # csr row_ptr: [head_kv, total_rows + 1] + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, # fp32 O_partial buffer (kept alive) + mLSE_partial: cute.Tensor, # fp32 LSE_partial + mLSE_temperature_partial: Optional[cute.Tensor], # fp32 temperature-scaled LSE_partial + mQ_flat: cute.Tensor, # [batch*Sq*head_q, dim] bf16, pre-flattened + mQ_gather4_desc: Optional[cute.Tensor], # [128] uint8 tensor map for gather4 Q load + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + softmax_scale: Float32, + lse_temperature_scale: Float32, + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + stream=None, + ): + self.q_dtype = mQ_flat.element_type + self.k_input_dtype = mK.element_type + self.v_input_dtype = mV.element_type + self.qk_dtype = ( + self.q_dtype if const_expr(self.qk_dtype_param is None) else self.qk_dtype_param + ) + if const_expr(self.pv_dtype_param is None): + legacy_fp8_kv_cache = ( + self.q_dtype == cutlass.BFloat16 + and self.k_input_dtype == cutlass.Float8E4M3FN + and self.v_input_dtype == cutlass.Float8E4M3FN + ) + self.pv_dtype = cutlass.BFloat16 if legacy_fp8_kv_cache else self.v_input_dtype + else: + self.pv_dtype = self.pv_dtype_param + self.k_dtype = self.qk_dtype + self.v_dtype = self.pv_dtype + self.p_dtype = self.pv_dtype + if const_expr(self.q_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"Unsupported Q/K/V dtype: {self.q_dtype}") + if const_expr(self.qk_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"Unsupported qk_dtype: {self.qk_dtype}") + if const_expr(self.pv_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"Unsupported pv_dtype: {self.pv_dtype}") + if const_expr(self.q_dtype != self.qk_dtype): + raise TypeError("Q storage dtype must match qk_dtype") + if const_expr( + self.k_input_dtype != self.k_dtype + and not (self.k_input_dtype == cutlass.Float8E4M3FN and self.k_dtype == cutlass.BFloat16) + ): + raise TypeError("Only FP8 K -> BF16 QK staging is supported") + if const_expr( + self.v_input_dtype != self.v_dtype + and not (self.v_input_dtype == cutlass.Float8E4M3FN and self.v_dtype == cutlass.BFloat16) + ): + raise TypeError("Only FP8 V -> BF16 PV staging is supported") + self.k_fp8_to_bf16 = ( + self.k_input_dtype == cutlass.Float8E4M3FN + and self.k_dtype == cutlass.BFloat16 + ) + self.v_fp8_to_bf16 = ( + self.v_input_dtype == cutlass.Float8E4M3FN + and self.v_dtype == cutlass.BFloat16 + ) + self.kv_fp8_to_bf16 = self.k_fp8_to_bf16 or self.v_fp8_to_bf16 + self.qk_mma_kind = "f8f6f4" if const_expr(self.qk_dtype.width == 8) else "f16" + self.pv_mma_kind = "f8f6f4" if const_expr(self.pv_dtype.width == 8) else "f16" + elem_bytes = const_expr(self.q_dtype.width // 8) + self.q_stage_stride_bytes = self.m_block_size * self.head_dim * elem_bytes + self.k_tile_stride_bytes = self.m_block_size * self.k_tile * elem_bytes + self.token_stride_bytes = self.qheadperkv * self.k_tile * elem_bytes + p_cols_as_fp32 = const_expr(self.n_block_size * self.p_dtype.width // Float32.width) + self.tmem_s_to_p_offset = self.n_block_size - p_cols_as_fp32 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + self.o_dtype = mO_partial.element_type + if const_expr( + self.o_dtype + not in [Float32, cutlass.BFloat16, cutlass.Float16, cutlass.Float8E4M3FN] + ): + raise TypeError(f"Unsupported O_partial dtype: {self.o_dtype}") + mK, mV = [assume_tensor_aligned(t) for t in (mK, mV)] + + if const_expr(not self.paged_kv): + # Flat varlen K/V use CUTE-managed TMA descriptors, matching FA: + # K: [total_k, h, d] -> [total_k, d, h]. + # V: [total_k, h, d] -> [d, total_k, h] for MN-major PV. + layout_t = [0, 2, 1] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2])) + else: + # Sparse Page Attention with page-sized blocks can use the blocked + # paged TMA layout directly. Host input is [page, head, token, dim]. + layout_t = [2, 3, 1, 0] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + # V: (s,d,h,b) -> (d,s,h,b) for MN-major + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2, 3])) + + # ------------------------------------------------------------------ + # UTCMMA TiledMma: QK^T and PV + # ------------------------------------------------------------------ + cta_group = tcgen05.CtaGroup.ONE + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, + Float32, cta_group, self.mma_tiler_qk[:2]) + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN, + Float32, cta_group, self.mma_tiler_pv[:2], tcgen05.OperandSource.TMEM) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM layouts: sQ/sK/sV only. O_partial is written directly from + # registers to GMEM in the epilogue. + # ------------------------------------------------------------------ + total_q_stages = self.q_stage + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, total_q_stages) + q_load_tile = ( + self.head_dim + if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_tile + ) + q_load_subtiles_per_token = const_expr(self.head_dim // q_load_tile) + num_subtiles_total = ( + total_q_stages * self.q_tokens_per_group * q_load_subtiles_per_token + ) + sQ_load_layout = sm100_utils.make_smem_layout( + tcgen05.OperandMajorMode.K, + (self.qheadperkv, q_load_tile), self.q_dtype, num_subtiles_total) + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage) + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage) + sK_fp8_layout = cute.append( + cute.make_layout( + (self.n_block_size, self.head_dim), + stride=(self.head_dim, 1), + ), + cute.make_layout((1,)), + ) + sV_fp8_layout = cute.append( + cute.make_layout( + (self.head_dim, self.n_block_size), + stride=(1, self.head_dim), + ), + cute.make_layout((1,)), + ) + # P SMEM layout metadata (no actual SMEM allocation — P lives in TMEM, + # overlaying the S region; this layout is only used to compute the PV + # A-operand TMEM descriptor shape at the MMA issue site.) + tP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, self.p_dtype, self.s_stage) + + # ------------------------------------------------------------------ + # TMA atoms + # ------------------------------------------------------------------ + k_tma_layout = ( + cute.select(sK_fp8_layout, mode=[0, 1]) + if const_expr(self.k_fp8_to_bf16) + else cute.select(sK_layout, mode=[0, 1, 2]) + ) + v_tma_layout = ( + cute.select(sV_fp8_layout, mode=[0, 1]) + if const_expr(self.v_fp8_to_bf16) + else cute.select(sV_layout, mode=[0, 1, 2]) + ) + kv_tma_bytes = ( + cute.size_in_bytes(self.k_input_dtype, k_tma_layout) + + cute.size_in_bytes(self.v_input_dtype, v_tma_layout)) + q_tma_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + if const_expr(self.k_fp8_to_bf16): + tma_atom_K, mK = cpasync.make_tiled_tma_atom( + tma_load_op, + mK, + cute.select(sK_fp8_layout, mode=[0, 1]), + (self.n_block_size, self.head_dim), + ) + else: + tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mK, cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, tiled_mma_qk, cta_layout_vmnk.shape) + if const_expr(self.v_fp8_to_bf16): + tma_atom_V, mV = cpasync.make_tiled_tma_atom( + tma_load_op, + mV, + cute.select(sV_fp8_layout, mode=[0, 1]), + (self.head_dim, self.n_block_size), + ) + else: + tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mV, cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_pv, tiled_mma_pv, cta_layout_vmnk.shape) + + # Q per-sub-tile TMA atom: bf16 uses two 64-element halves; fp8 uses + # one 128-element row because 128 fp8 elements occupy the same 128B + # swizzle span as 64 bf16 elements. + mQ_flat = assume_tensor_aligned(mQ_flat) + mQ_2d = cute.make_tensor( + mQ_flat.iterator, cute.select(mQ_flat.layout, mode=[0, 1])) + if const_expr(self.use_q_gather4): + # Placeholder atom for unified kernel signature. Small-GQA Q load + # uses raw gather4 and keeps mQ_2d as a plain row-major GMEM tensor. + tma_atom_Q = tma_atom_V + else: + tma_atom_Q, mQ_2d = cpasync.make_tiled_tma_atom( + tma_load_op, mQ_2d, + cute.select(sQ_load_layout, mode=[0, 1]), + (self.qheadperkv, q_load_tile)) + q_subtile_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_load_layout, mode=[0, 1])) + + softmax_scale_log2 = softmax_scale * Float32(math.log2(math.e)) + lse_temperature_scale_log2 = softmax_scale_log2 * lse_temperature_scale + + # ------------------------------------------------------------------ + # SharedStorage — lean: just the mbars and tiles we actually use. + # + # Mbarriers (all storage rings stay below the 64-per-CTA limit): + # mbar_kv [2] one-shot K/V load handshake (full + empty) + # mbar_q [q_stage * 2] Q producer/consumer ring + # mbar_s [2] QK UTCMMA -> softmax (full + empty) + # mbar_o [2] PV UTCMMA -> epilogue (full + empty) + # mbar_p [s_stage * 2] softmax early-P arrive -> PV + # mbar_p_lastsplit [s_stage * 2] softmax final-P arrive -> PV + # (used only when ``self.split_P_arrive > 0``) + # mbar_sm_stats [s_stage * 2] softmax row_sum/row_max + # publish -> epilogue consumer read. In lean + # 1-WG topology the producer and consumer are + # the same 128 WG_C threads, but we keep the + # barrier for structural parity with FA so the + # softmax body reads identically. + # ------------------------------------------------------------------ + @cute.struct + class SharedStorage: + mbar_k: cute.struct.MemRange[Int64, 2] + mbar_v: cute.struct.MemRange[Int64, 2] + if const_expr(self.k_fp8_to_bf16): + mbar_k_tma: cute.struct.MemRange[Int64, 2] + if const_expr(self.v_fp8_to_bf16): + mbar_v_tma: cute.struct.MemRange[Int64, 2] + mbar_q: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_s: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p_lastsplit: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_o: cute.struct.MemRange[Int64, self.o_stage * 2] + mbar_sm_stats: cute.struct.MemRange[Int64, self.o_stage * 2] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: Int32 + # Per-row softmax stats cache (for epilogue LSE + rescale): + # [0 : m_block_size) row_sum + # [m_block_size : 2*m_block_size) row_max + sScale: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size * 2] + # Per-row temperature LSE row_sum cache. The row_max is shared with + # sScale because lse_temperature_scale is positive. + sScaleTemperature: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size] + # Per-token split_id from prepare-time per-edge metadata. + sSplitIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Per-token q_idx cache to avoid reloading sparse indices in epilogue. + sQIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Prefix length of q_idx-sorted row entries that may need causal + # masking for this KV block. This is CTA-level metadata, not a + # token-count cap. + sDiagQCount: cute.struct.MemRange[Int32, 1] + # CTA-wide row metadata, published once by tidx 0 and reused by + # all warp-specialized roles: + # [0] batch_idx + # [1] kv_block_idx + # [2] row_start + # [3] count_raw + # [4] kv_valid_cols + # [5] q_batch_offset + # [6] k_batch_offset + # [7] causal_q_offset = seqlen_k - seqlen_q + sRowMeta: cute.struct.MemRange[Int32, 8] + sPagedKvIdx: cute.struct.MemRange[Int32, 1] + sQLoadMIdx: cute.struct.MemRange[ + Int32, self.q_stage * self.q_tokens_per_group] + # Packed per-edge q/split metadata: + # low 24 bits = q_idx, high 8 bits = split slot. + sQIdxMeta: cute.struct.MemRange[ + Int32, self.qidx_meta_stages * self.q_tokens_per_group] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes] + if const_expr(self.k_fp8_to_bf16): + sKFp8: cute.struct.Align[ + cute.struct.MemRange[ + self.k_input_dtype, cute.cosize(sK_fp8_layout) + ], + self.buffer_align_bytes] + if const_expr(self.v_fp8_to_bf16): + sVFp8: cute.struct.Align[ + cute.struct.MemRange[ + self.v_input_dtype, cute.cosize(sV_fp8_layout) + ], + self.buffer_align_bytes] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes] + + self.shared_storage = SharedStorage + num_ctas = work_capacity + + self.kernel( + mK, mV, mK2qIndices, mK2qQSplitIndices, mK2qCounts, + mSchedulerMetadata, mWorkCount, + mO_partial, mLSE_partial, mLSE_temperature_partial, mQ_2d, mQ_gather4_desc, + mPageTable, mSeqUsedK, mCuSeqlensQ, mCuSeqlensK, + softmax_scale_log2, lse_temperature_scale_log2, lse_temperature_scale, + sQ_layout, sQ_load_layout, sK_layout, sV_layout, + sK_fp8_layout, sV_fp8_layout, tP_layout, + tma_atom_K, tma_atom_V, tma_atom_Q, + tiled_mma_qk, tiled_mma_pv, + kv_tma_bytes, q_tma_bytes, q_subtile_bytes, + num_kv_blocks, num_heads_kv, seq_len_q, work_capacity, + ).launch( + grid=(num_ctas,), + block=[self.threads_per_cta, 1, 1], + smem=max(SharedStorage.size_in_bytes(), 49152), + stream=stream, + min_blocks_per_mp=1, + ) + + # ------------------------------------------------------------------ + # Device-side: kernel entry, dispatch by warpgroup + # ------------------------------------------------------------------ + + @cute.kernel + def kernel( + self, + # Runtime tensors + tma_K: cute.Tensor, + tma_V: cute.Tensor, + mK2qIndices: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mK2qCounts: cute.Tensor, + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + mQ_2d: cute.Tensor, + mQ_gather4_desc: Optional[cute.Tensor], + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + # Scalars + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + # Layouts + sQ_layout: cute.ComposedLayout, + sQ_load_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sK_fp8_layout: cute.Layout, + sV_fp8_layout: cute.Layout, + tP_layout: cute.ComposedLayout, + # TMA atoms + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_Q: cute.CopyAtom, + # MMA + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + # Transfer sizes + kv_tma_bytes: cutlass.Constexpr[int], + q_tma_bytes: cutlass.Constexpr[int], + q_subtile_bytes: cutlass.Constexpr[int], + # Iteration bounds + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + ): + # ------------------------------------------------------------------ + # Thread / warp identity, CTA coordinate + # ------------------------------------------------------------------ + bidx, _, _ = cute.arch.block_idx() + row_linear = Int32(0) + head_kv_idx = Int32(0) + batch_idx = Int32(0) + kv_block_idx = Int32(0) + work_q_begin = Int32(0) + work_q_count = Int32(0) + cta_valid_work = True + work_idx = bidx + cta_valid_work = work_idx < mWorkCount[Int32(0)] + if cta_valid_work: + head_kv_idx = mSchedulerMetadata[work_idx, Int32(0)] + row_linear = mSchedulerMetadata[work_idx, Int32(1)] + work_q_begin = mSchedulerMetadata[work_idx, Int32(2)] + work_q_count = mSchedulerMetadata[work_idx, Int32(3)] + batch_idx = mSchedulerMetadata[work_idx, Int32(4)] + kv_block_idx = mSchedulerMetadata[work_idx, Int32(5)] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + head_q = num_heads_kv * Int32(self.qheadperkv) + paged_kv_manager = ( + PagedKVManager.create( + mPageTable, + page_size=self.page_size, + n_block_size=self.n_block_size, + ) + if const_expr(self.paged_kv) + else None + ) + + # Prefetch TMA descriptors (warp 0 once). + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + if const_expr(not self.use_q_gather4): + cpasync.prefetch_descriptor(tma_atom_Q) + else: + with cute.arch.elect_one(): + prefetch_tma_desc_raw(mQ_gather4_desc.iterator) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM allocation (all warps — same SharedStorage type from __call__) + # ------------------------------------------------------------------ + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + if const_expr(self.k_fp8_to_bf16): + sKFp8 = storage.sKFp8.get_tensor(sK_fp8_layout) + mbar_k_tma_ptr = storage.mbar_k_tma.data_ptr() + if const_expr(self.v_fp8_to_bf16): + sVFp8 = storage.sVFp8.get_tensor(sV_fp8_layout) + mbar_v_tma_ptr = storage.mbar_v_tma.data_ptr() + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sQ_load = storage.sQ.get_tensor(sQ_load_layout.outer, swizzle=sQ_load_layout.inner) + sScale = storage.sScale.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size * 2)) + sScaleTemperature = storage.sScaleTemperature.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size)) + sSplitIdx = storage.sSplitIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sQIdx = storage.sQIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sDiagQCount = storage.sDiagQCount.get_tensor(cute.make_layout((1,))) + sRowMeta = storage.sRowMeta.get_tensor(cute.make_layout((8,))) + sPagedKvIdx = storage.sPagedKvIdx.get_tensor(cute.make_layout((1,))) + sQLoadMIdx = storage.sQLoadMIdx.get_tensor( + cute.make_layout((self.q_stage * self.q_tokens_per_group,))) + sQIdxMeta = storage.sQIdxMeta.get_tensor( + cute.make_layout((self.qidx_meta_stages * self.q_tokens_per_group,))) + mbar_k_ptr = storage.mbar_k.data_ptr() + mbar_v_ptr = storage.mbar_v.data_ptr() + + # ------------------------------------------------------------------ + # TMEM allocator — allocator warp 0 serves the whole CTA. + # ------------------------------------------------------------------ + tmem_alloc_warps: cutlass.Constexpr[int] = self.warps_per_group * 2 + 1 + tmem_alloc_threads = cute.arch.WARP_SIZE * tmem_alloc_warps + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), + num_threads=tmem_alloc_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # ------------------------------------------------------------------ + # Warp-specialized pipelines. + # ------------------------------------------------------------------ + ThreadCooperativeGroup = partial( + cutlass_pipeline.CooperativeGroup, cutlass_pipeline.Agent.Thread) + tma_thread = ThreadCooperativeGroup(1) + mma_thread = ThreadCooperativeGroup(1) + softmax_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group) + epilogue_threads = softmax_threads + + pipeline_q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=q_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_s = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_s.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=softmax_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p_lastsplit = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p_lastsplit.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_o = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_o.data_ptr(), + num_stages=self.o_stage, + producer_group=mma_thread, + consumer_group=epilogue_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_sm_stats = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_sm_stats.data_ptr(), + num_stages=self.o_stage, + producer_group=softmax_threads, + consumer_group=epilogue_threads, + defer_sync=True, + ) + # Cluster sync (no-op for 1CTA cluster). + pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + # ------------------------------------------------------------------ + # Work count: how many Q tokens reference this CTA's KV block + # ------------------------------------------------------------------ + k_tma_bytes = cute.size_in_bytes( + self.k_input_dtype, + cute.select(sK_fp8_layout, mode=[0, 1]) + if const_expr(self.k_fp8_to_bf16) + else cute.select(sK_layout, mode=[0, 1, 2])) + v_tma_bytes = cute.size_in_bytes( + self.v_input_dtype, + cute.select(sV_fp8_layout, mode=[0, 1]) + if const_expr(self.v_fp8_to_bf16) + else cute.select(sV_layout, mode=[0, 1, 2])) + if tidx == 0: + row_batch_idx = batch_idx + row_kv_block_idx = kv_block_idx + base_row_start = mK2qCounts[head_kv_idx, row_linear] + row_start = base_row_start + count_raw = ( + mK2qCounts[head_kv_idx, row_linear + Int32(1)] + - base_row_start + ) + row_start = base_row_start + work_q_begin + count_raw = work_q_count + kv_valid_cols = ( + self._valid_cols_in_block( + row_batch_idx, + row_kv_block_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + ) + q_batch_offset = self._batch_q_offset( + row_batch_idx, mCuSeqlensQ + ) + k_batch_offset = ( + Int32(0) + if const_expr(self.paged_kv) + else mCuSeqlensK[row_batch_idx] + ) + sRowMeta[0] = row_batch_idx + sRowMeta[1] = row_kv_block_idx + sRowMeta[2] = row_start + sRowMeta[3] = count_raw + sRowMeta[4] = kv_valid_cols + sRowMeta[5] = q_batch_offset + sRowMeta[6] = k_batch_offset + causal_q_offset = Int32(0) + if const_expr(self.causal): + seqlen_q = mCuSeqlensQ[row_batch_idx + Int32(1)] - q_batch_offset + seqlen_k = self._logical_seqlen_k( + row_batch_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + causal_q_offset = seqlen_k - seqlen_q + sRowMeta[7] = causal_q_offset + if const_expr(self.paged_kv): + sPagedKvIdx[0] = paged_kv_manager.physical_block_index( + row_batch_idx, row_kv_block_idx + ) + cute.arch.mbarrier_init(mbar_k_ptr, 1) + cute.arch.mbarrier_init(mbar_v_ptr, 1) + if const_expr(self.k_fp8_to_bf16): + cute.arch.mbarrier_init(mbar_k_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_k_tma_ptr, k_tma_bytes) + else: + cute.arch.mbarrier_expect_tx(mbar_k_ptr, k_tma_bytes) + if const_expr(self.v_fp8_to_bf16): + cute.arch.mbarrier_init(mbar_v_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_v_tma_ptr, v_tma_bytes) + else: + cute.arch.mbarrier_expect_tx(mbar_v_ptr, v_tma_bytes) + diag_q_count = Int32(0) + if const_expr(self.causal): + row_has_visible_cols = (count_raw > Int32(0)) & (kv_valid_cols > Int32(0)) + if row_has_visible_cols: + kv_valid_end = ( + row_kv_block_idx * Int32(self.n_block_size) + + kv_valid_cols + ) + q_threshold = kv_valid_end - causal_q_offset + diag_q_count = self._lower_bound_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + count_raw, + q_threshold, + ) + sDiagQCount[0] = diag_q_count + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + thr_mma_qk = tiled_mma_qk.get_slice(0) + thr_mma_pv = tiled_mma_pv.get_slice(0) + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_base = thr_mma_qk.make_fragment_C(qk_acc_shape) + tStS = cute.make_tensor( + tStS_base.iterator, + cute.append( + tStS_base.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + tP_width_ratio = Float32.width // self.v_dtype.width + tP_stage_stride = self.tmem_stage_stride * tP_width_ratio + tOrP = cute.make_tensor( + tOrP.iterator + self.tmem_p_offset * tP_width_ratio, + cute.append( + tOrP.layout, + cute.make_layout((self.s_stage,), stride=(tP_stage_stride,)), + ), + ) + + tmem_cols = self.tmem_total + + load_wg_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.LoadWG), + num_threads=cute.arch.WARP_SIZE * self.num_q_load_warps, + ) + if const_expr(self.kv_fp8_to_bf16): + kv_load_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvLoad), + num_threads=cute.arch.WARP_SIZE * self.num_kv_load_warps, + ) + if const_expr(self.k_fp8_to_bf16): + kv_dequant_k_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantK), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + if const_expr(self.v_fp8_to_bf16): + kv_dequant_v_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantV), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + sm_stats_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), + num_threads=cute.arch.WARP_SIZE * 2, + ) + epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.StoreEpilogue), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + if warp_idx == Int32(self.total_warps - 1): + cute.arch.setmaxregister_decrease(self.num_regs_empty) + + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + q_load_thread_end = Int32( + (self.q_load_warp_base + self.num_q_load_warps) * cute.arch.WARP_SIZE + ) + is_q_load_thread = tidx >= q_load_thread_base and tidx < q_load_thread_end + if is_q_load_thread and cta_valid_work: + if self.store_reg_decrease: + cute.arch.setmaxregister_decrease(self.num_regs_store) + else: + cute.arch.setmaxregister_increase(self.num_regs_store) + row_start_load = sRowMeta[2] + count_raw_load = sRowMeta[3] + q_batch_offset_load = sRowMeta[5] + # Do not gate on KV validity here; sparse entries past seqused_k + # still need the all-masked path to produce neutral partials. + has_work_load = count_raw_load > Int32(0) + num_q_groups_load = ( + count_raw_load + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + if const_expr(self.use_q_gather4): + self._wg_load_q_gather4( + mQ_2d, + mQ_gather4_desc, + mK2qQSplitIndices, + sQIdxMeta, + sQ, + pipeline_q, + load_wg_barrier, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + ) + else: + self._wg_load_q_tma( + tma_atom_Q, + mQ_2d, + mK2qQSplitIndices, + sQLoadMIdx, + sQIdxMeta, + sQ_load, + pipeline_q, + load_wg_barrier, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + ) + + if ( + warp_idx >= Int32(self.kv_load_warp_base) + and warp_idx < Int32(self.kv_load_warp_base + self.num_kv_load_warps) + and cta_valid_work + ): + cute.arch.setmaxregister_decrease(self.num_regs_load) + kv_block_idx_load = sRowMeta[1] + k_batch_offset_load = sRowMeta[6] + has_work_load = sRowMeta[3] > Int32(0) + if const_expr(self.kv_fp8_to_bf16): + self._wg_load_kv_maybe_cast( + tma_atom_K, tma_atom_V, + tma_K, tma_V, + sPagedKvIdx, + sK, sV, + sKFp8 if const_expr(self.k_fp8_to_bf16) else None, + sVFp8 if const_expr(self.v_fp8_to_bf16) else None, + tiled_mma_qk, tiled_mma_pv, + mbar_k_ptr, mbar_v_ptr, + mbar_k_tma_ptr if const_expr(self.k_fp8_to_bf16) else None, + mbar_v_tma_ptr if const_expr(self.v_fp8_to_bf16) else None, + kv_load_barrier, + has_work_load, + head_kv_idx, kv_block_idx_load, + k_batch_offset_load, + ) + else: + self._wg_load_kv( + tma_atom_K, tma_atom_V, + tma_K, tma_V, + sPagedKvIdx, + sK, sV, + tiled_mma_qk, tiled_mma_pv, + mbar_k_ptr, mbar_v_ptr, + has_work_load, + head_kv_idx, kv_block_idx_load, + k_batch_offset_load, + ) + + if warp_idx == Int32(self.mma_warp_id) and cta_valid_work: + cute.arch.setmaxregister_decrease(self.num_regs_mma) + count_raw_mma = sRowMeta[3] + has_work_mma = count_raw_mma > Int32(0) + num_q_groups_mma = ( + count_raw_mma + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + tmem.allocate(tmem_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_mma_issue( + tiled_mma_qk, tiled_mma_pv, + thr_mma_qk, thr_mma_pv, + tStS, tOrP, + sK, sV, sQ, + pipeline_q, pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + mbar_k_ptr, mbar_v_ptr, + num_q_groups_mma, has_work_mma, + ) + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr, num_columns=tmem_cols) + cute.arch.griddepcontrol_launch_dependents() + + if ( + warp_idx >= Int32(self.softmax0_warp_base) + and warp_idx < Int32(self.softmax1_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + if const_expr(self.k_fp8_to_bf16): + self._wg_convert_fp8_kv_to_bf16_smem( + sKFp8, + sK, + mbar_k_tma_ptr, + mbar_k_ptr, + kv_dequant_k_barrier, + has_work_softmax, + self.softmax0_warp_base, + self.warps_per_group, + False, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 0, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + + if ( + warp_idx >= Int32(self.softmax1_warp_base) + and warp_idx < Int32(self.store_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + if const_expr(self.v_fp8_to_bf16): + self._wg_convert_fp8_kv_to_bf16_smem( + sVFp8, + sV, + mbar_v_tma_ptr, + mbar_v_ptr, + kv_dequant_v_barrier, + has_work_softmax, + self.softmax1_warp_base, + self.warps_per_group, + True, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 1, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + # ------------------------------------------------------------------ + # Warp-specialized helpers + # ------------------------------------------------------------------ + + @cute.jit + def _convert_fp8x16_to_bf16x16( + self, + src: cute.Tensor, + dst: cute.Tensor, + ): + src_i32 = cute.recast_tensor(src, cutlass.Int32) + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + for word_idx in cutlass.range_constexpr(4): + ( + dst_i32[word_idx * 2], + dst_i32[word_idx * 2 + 1], + ) = utils.cvt_fp8x4_e4m3_bf16x4(src_i32[word_idx]) + + @cute.jit + def _convert_fp8_kv_to_bf16_smem( + self, + sFp8: cute.Tensor, + sBf16: cute.Tensor, + lane: Int32, + warp_idx_in_wg: Int32, + num_dequant_warps: cutlass.Constexpr[int], + is_v: cutlass.Constexpr[bool], + ): + elems_per_load: cutlass.Constexpr[int] = 16 + elems_per_store: cutlass.Constexpr[int] = 8 + chunks_per_row: cutlass.Constexpr[int] = self.head_dim // elems_per_load + r_fp8 = cute.make_rmem_tensor((elems_per_load,), cutlass.Float8E4M3FN) + r_bf16 = cute.make_rmem_tensor((elems_per_load,), cutlass.BFloat16) + total_tasks: cutlass.Constexpr[int] = self.n_block_size * chunks_per_row + task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(task, total_tasks, task_stride, unroll=1): + row = task_idx // Int32(chunks_per_row) + chunk = task_idx - row * Int32(chunks_per_row) + col = chunk * Int32(elems_per_load) + smem_offset = row * Int32(self.head_dim) + col + s_fp8_ptr = cute.make_ptr( + cutlass.Float8E4M3FN, + sFp8.iterator.toint() + Int64(smem_offset), + mem_space=sFp8.iterator.memspace, + assumed_align=elems_per_load, + ) + s_fp8_vec = cute.make_tensor( + s_fp8_ptr, + cute.make_layout(elems_per_load), + ) + cute.autovec_copy(s_fp8_vec, r_fp8) + self._convert_fp8x16_to_bf16x16(r_fp8, r_bf16) + if const_expr(is_v): + sBf16_view = sBf16[(None, row % Int32(16)), 0, row // Int32(16), 0] + sBf16_vec = cute.local_tile(sBf16_view, (elems_per_load,), (chunk,)) + else: + sBf16_vec = sBf16[ + (row, None), + 0, + (chunk % Int32(4), chunk // Int32(4)), + 0, + ] + r_tiles = cute.logical_divide(r_bf16, cute.make_layout(elems_per_store)) + s_tiles = cute.logical_divide(sBf16_vec, cute.make_layout(elems_per_store)) + for v in cutlass.range_constexpr(elems_per_load // elems_per_store): + cute.autovec_copy(r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + + @cute.jit + def _wg_load_kv_maybe_cast( + self, + tma_atom_K, + tma_atom_V, + tma_K: cute.Tensor, + tma_V: cute.Tensor, + sPagedKvIdx: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sKFp8: Optional[cute.Tensor], + sVFp8: Optional[cute.Tensor], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mbar_k_ptr, + mbar_v_ptr, + mbar_k_tma_ptr: Optional[cutlass.Pointer], + mbar_v_tma_ptr: Optional[cutlass.Pointer], + kv_load_barrier, + has_work: Int32, + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(self.kv_load_warp_base) + + if has_work: + if warp_idx_in_wg == Int32(0): + if const_expr(self.k_fp8_to_bf16): + if const_expr(self.paged_kv): + mK_cur = tma_K[None, None, head_kv_idx, sPagedKvIdx[0]] + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim), + (None, 0), + ) + src_idx = Int32(0) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + tma_K[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim), + (None, 0), + ) + src_idx = kv_block_idx + load_K_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, + 0, + cute.make_layout(1), + gK, + sKFp8, + ) + load_K_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_k_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_tma_ptr) + else: + thr_mma_qk = tiled_mma_qk.get_slice(0) + if const_expr(self.paged_kv): + mK_cur = tma_K[None, None, head_kv_idx, None] + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0, None), + ) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + tma_K[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0), + ) + tSgK = thr_mma_qk.partition_B(gK) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + gmem_k_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_K, + tKgK[(None, 0, gmem_k_idx)] + if const_expr(self.paged_kv) + else tKgK[(None, gmem_k_idx)], + tKsK[(None, 0)], + tma_bar_ptr=mbar_k_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_ptr) + + if warp_idx_in_wg == Int32(1): + if const_expr(self.v_fp8_to_bf16): + if const_expr(self.paged_kv): + mV_cur = tma_V[None, None, head_kv_idx, sPagedKvIdx[0]] + gV = cute.local_tile( + mV_cur, + (self.head_dim, self.n_block_size), + (0, None), + ) + src_idx = Int32(0) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + tma_V[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + (self.head_dim, self.n_block_size), + (0, None), + ) + src_idx = kv_block_idx + load_V_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, + 0, + cute.make_layout(1), + gV, + sVFp8, + ) + load_V_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_v_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_tma_ptr) + else: + thr_mma_pv = tiled_mma_pv.get_slice(0) + if const_expr(self.paged_kv): + mV_cur = tma_V[None, None, head_kv_idx, None] + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None, None), + ) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + tma_V[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None), + ) + tOgV = thr_mma_pv.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + gmem_v_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_V, + tVgV[(None, 0, gmem_v_idx)] + if const_expr(self.paged_kv) + else tVgV[(None, gmem_v_idx)], + tVsV[(None, 0)], + tma_bar_ptr=mbar_v_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_ptr) + + kv_load_barrier.arrive_and_wait() + + @cute.jit + def _wg_convert_fp8_kv_to_bf16_smem( + self, + sFp8: cute.Tensor, + sBf16: cute.Tensor, + mbar_tma_ptr, + mbar_ready_ptr, + dequant_barrier, + has_work: Int32, + dequant_warp_base: cutlass.Constexpr[int], + num_dequant_warps: cutlass.Constexpr[int], + is_v: cutlass.Constexpr[bool], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(dequant_warp_base) + lane = cute.arch.lane_idx() + if has_work: + cute.arch.mbarrier_wait(mbar_tma_ptr, 0) + self._convert_fp8_kv_to_bf16_smem( + sFp8, + sBf16, + lane, + warp_idx_in_wg, + num_dequant_warps, + is_v, + ) + dequant_barrier.arrive_and_wait() + if warp_idx_in_wg == Int32(0): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_ready_ptr) + + @cute.jit + def _wg_load_kv( + self, + tma_atom_K, + tma_atom_V, + tma_K: cute.Tensor, + tma_V: cute.Tensor, + sPagedKvIdx: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mbar_k_ptr, + mbar_v_ptr, + has_work: Int32, + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(self.kv_load_warp_base) + + if has_work: + if warp_idx_in_wg == Int32(0): + thr_mma_qk = tiled_mma_qk.get_slice(0) + if const_expr(self.paged_kv): + mK_cur = tma_K[None, None, head_kv_idx, None] + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0, None), + ) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + tma_K[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0), + ) + tSgK = thr_mma_qk.partition_B(gK) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + gmem_k_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_K, + tKgK[(None, 0, gmem_k_idx)] + if const_expr(self.paged_kv) + else tKgK[(None, gmem_k_idx)], + tKsK[(None, 0)], + tma_bar_ptr=mbar_k_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_ptr) + + if warp_idx_in_wg == Int32(1): + thr_mma_pv = tiled_mma_pv.get_slice(0) + if const_expr(self.paged_kv): + mV_cur = tma_V[None, None, head_kv_idx, None] + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None, None), + ) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + tma_V[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None), + ) + tOgV = thr_mma_pv.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + gmem_v_idx = ( + sPagedKvIdx[0] if const_expr(self.paged_kv) else kv_block_idx + ) + cute.copy( + tma_atom_V, + tVgV[(None, 0, gmem_v_idx)] + if const_expr(self.paged_kv) + else tVgV[(None, gmem_v_idx)], + tVsV[(None, 0)], + tma_bar_ptr=mbar_v_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_ptr) + + @cute.jit + def _wg_load_q_gather4( + self, + mQ_2d: cute.Tensor, + mQ_gather4_desc: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + load_wg_barrier, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + ): + tidx = cute.arch.thread_idx()[0] + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + group_tidx = tidx - q_load_thread_base + producer_warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + gathers_per_warp: cutlass.Constexpr[int] = self.m_block_size // ( + self.num_q_load_warps * 4) + + if has_work: + for qi_group in cutlass.range(num_q_groups, unroll=1): + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if producer_warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + group_tidx = tidx - q_load_thread_base + warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + lane_idx = group_tidx % Int32(cute.arch.WARP_SIZE) + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + meta_iters: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + + self.num_q_load_warps * cute.arch.WARP_SIZE - 1) + // (self.num_q_load_warps * cute.arch.WARP_SIZE) + ) + for meta_iter in cutlass.range_constexpr(meta_iters): + tok_idx_g4 = ( + (Int32(meta_iter) * Int32(self.num_q_load_warps) + + warp_idx_in_wg) + * Int32(cute.arch.WARP_SIZE) + + lane_idx + ) + if tok_idx_g4 < Int32(self.q_tokens_per_group): + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx_g4 + if qi < count_raw: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = Int32(0) + load_wg_barrier.arrive_and_wait() + + with cute.arch.elect_one(): + q_desc_ptr = mQ_gather4_desc.iterator + sQ_ptr = sQ.iterator + for gather_slot in cutlass.range_constexpr(gathers_per_warp): + gather_idx = ( + Int32(gather_slot) * Int32(self.num_q_load_warps) + + warp_idx_in_wg + ) + tok_base = gather_idx * Int32(self.tokens_per_gather4) + if const_expr(self.qheadperkv == 1): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + qi2 = qi0 + Int32(2) + qi3 = qi0 + Int32(3) + row0 = q_oob_m_idx + row1 = q_oob_m_idx + row2 = q_oob_m_idx + row3 = q_oob_m_idx + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row0 = (q_batch_offset + q_idx0) * num_heads_kv + head_kv_idx + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row1 = (q_batch_offset + q_idx1) * num_heads_kv + head_kv_idx + if qi2 < count_raw: + q_idx2 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(2)]) + row2 = (q_batch_offset + q_idx2) * num_heads_kv + head_kv_idx + if qi3 < count_raw: + q_idx3 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(3)]) + row3 = (q_batch_offset + q_idx3) * num_heads_kv + head_kv_idx + elif const_expr(self.qheadperkv == 2): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + row_base1 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row_base1 = ( + (q_batch_offset + q_idx1) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base1 + row3 = row_base1 + Int32(1) + else: + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base0 + Int32(2) + row3 = row_base0 + Int32(3) + group_byte_off = gather_idx * Int32( + 4 * self.k_tile * (self.q_dtype.width // 8) + ) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + stage_byte_off = slot * Int32(self.q_stage_stride_bytes) + full_group_byte_off = gather_idx * Int32( + 4 * self.head_dim * (self.q_dtype.width // 8) + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + full_group_byte_off, + q_desc_ptr, + Int32(0), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + else: + for ks_c in cutlass.range_constexpr(self.k_stages): + stage_idx = slot * Int32(self.k_stages) + Int32(ks_c) + stage_byte_off = stage_idx * Int32(self.k_tile_stride_bytes) + if const_expr(ks_c + 1 < self.k_stages): + tma_gather4_prefetch( + q_desc_ptr, + Int32((ks_c + 1) * self.k_tile), + row0, + row1, + row2, + row3, + TMA_CACHE_EVICT_LAST, + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + group_byte_off, + q_desc_ptr, + Int32(ks_c * self.k_tile), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + load_wg_barrier.arrive_and_wait() + + if producer_warp_idx_in_wg == Int32(0): + next_slot = num_q_groups % Int32(self.q_stage) + next_phase = ( + (num_q_groups // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_load_q_tma( + self, + tma_atom_Q, + mQ_2d: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQLoadMIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ_load: cute.Tensor, + pipeline_q, + load_wg_barrier, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + lane_idx = cute.arch.lane_idx() + warp_idx_in_wg = warp_idx - Int32(self.q_load_warp_base) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + gQ_full = cute.local_tile( + mQ_2d, (self.qheadperkv, self.head_dim), (None, 0)) + load_Q_fn_full, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_full, sQ_load) + load_Q_fn_k0, load_Q_fn_k1 = None, None + else: + gQ_k0 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 0)) + gQ_k1 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 1)) + load_Q_fn_k0, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k0, sQ_load) + load_Q_fn_k1, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k1, sQ_load) + load_Q_fn_full = None + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + tokens_per_warp: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + self.num_q_load_warps - 1) + // self.num_q_load_warps + ) + + if has_work: + for qi_group in cutlass.range(num_q_groups, unroll=1): + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + q_load_subtiles_per_token = ( + 1 if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_stages + ) + sub_stage_base = slot * Int32( + self.q_tokens_per_group * q_load_subtiles_per_token) + load_meta_slot = slot * Int32(self.q_tokens_per_group) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + if warp_idx_in_wg == Int32(0) and lane_idx < Int32(self.q_tokens_per_group): + tok_idx = lane_idx + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx + if qi < count_raw: + qsplit = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + q_idx = self._decode_q_idx_from_qsplit(qsplit) + q_abs = q_batch_offset + q_idx + sQIdxMeta[qidx_meta_slot + tok_idx] = qsplit + sQLoadMIdx[load_meta_slot + tok_idx] = ( + q_abs * num_heads_kv + head_kv_idx + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx] = Int32(0) + sQLoadMIdx[load_meta_slot + tok_idx] = q_oob_m_idx + load_wg_barrier.arrive_and_wait() + + for qi_slot in cutlass.range_constexpr(tokens_per_warp): + tok_idx = ( + warp_idx_in_wg * Int32(tokens_per_warp) + + Int32(qi_slot) + ) + if tok_idx < Int32(self.q_tokens_per_group): + m_tile_idx = sQLoadMIdx[load_meta_slot + tok_idx] + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + load_Q_fn_full( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + else: + load_Q_fn_k0( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + load_Q_fn_k1( + src_idx=m_tile_idx, + dst_idx=( + sub_stage_base + + Int32(self.q_tokens_per_group) + + tok_idx + ), + tma_bar_ptr=mbar_ptr, + ) + + if warp_idx_in_wg == Int32(0): + next_slot = num_q_groups % Int32(self.q_stage) + next_phase = ( + (num_q_groups // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_mma_issue( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + thr0_qk: cute.core.ThrMma, + thr0_pv: cute.core.ThrMma, + tStS: cute.Tensor, + tOrP: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + mbar_k_ptr, + mbar_v_ptr, + num_q_groups: Int32, + has_work: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + is_mma_warp = warp_idx == Int32(self.mma_warp_id) + + if is_mma_warp: + if has_work: + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tSrQ0 = tSrQ[None, None, None, 0] + tSrK0 = tSrK[None, None, None, 0] + tOrV = tiled_mma_pv.make_fragment_B(sV) + tOrV0 = tOrV[None, None, None, 0] + sV0 = sV[None, None, None, 0] + pv_mma_op = tiled_mma_pv.op + qk_mma_op = tiled_mma_qk.op + q_smem_base = sm100_desc.smem_desc_base_from_tensor( + sQ, sm100_desc.Major.K) + k_smem_base = sm100_desc.smem_desc_base_from_tensor( + sK, sm100_desc.Major.K) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[None, None, None, 0].iterator) + q_smem_start = sm100_desc.make_smem_desc_start_addr( + sQ[None, None, None, self.q_stage - 1].iterator) + sm100_helpers.declare_ptx_smem_desc( + q_smem_start, + q_smem_base, + tSrQ0.layout, + var_name_prefix="lean_q_desc", + ) + sm100_helpers.declare_ptx_idesc( + qk_mma_op, var_name="lean_qk_idesc") + sQ_stage_stride = ( + sQ.layout.stride[-1] * sQ.element_type.width // 8) >> 4 + if const_expr(self.q_stage == 1): + sQ_stage_stride = 0 + q_wrap_offset = -(self.q_stage - 1) * sQ_stage_stride + q_advance_offset = sQ_stage_stride + gemm_qk_s0_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_qk_s0_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_qk_s1_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_qk_s1_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.qk_mma_kind, + ) + gemm_pv_0 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset), + tOrP[None, None, None, 0], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.pv_mma_kind, + ) + gemm_pv_1 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset + self.tmem_o_stage_stride), + tOrP[None, None, None, 1], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_stage_stride + self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.pv_mma_kind, + ) + + cute.arch.mbarrier_wait(mbar_k_ptr, 0) + # Issue order: + # Q0K, Q1K, P0V, Q2K, P1V, Q3K, ... + # This reuses each slot as soon as its previous PV drains, + # instead of batching both PVs after both QKs of a pair. + # The schedule is still 2-slot safe: + # - QK(qi) consumes slot qi&1 + # - PV(qi-2) frees the same slot before QK(qi) reuses it + # - phases still toggle every 2 groups per slot + + # Prologue: issue up to the first two QK tiles. Q slots come + # from the q_stage ring; S slots remain a 2-slot ring. + pipeline_q.consumer_wait_w_index_phase(Int32(0), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(0), Int32(1)) + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(0)) + pipeline_q.consumer_release_w_index(Int32(0)) + + if num_q_groups > Int32(1): + pipeline_q.consumer_wait_w_index_phase(Int32(1), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(1), Int32(1)) + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(1)) + pipeline_q.consumer_release_w_index(Int32(1)) + + cute.arch.mbarrier_wait(mbar_v_ptr, 0) + + # Steady-state: for qi >= 2, reuse the S/P slot as a BLK128- + # style handoff. The MMA warp waits for softmax to release + # the slot after P is visible, issues PV, then immediately + # reuses the same acquired slot for the next QK. + for qi in cutlass.range(Int32(2), num_q_groups, unroll=1): + pv_qi = qi - Int32(2) + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + q_slot = qi % Int32(self.q_stage) + q_phase = (qi // Int32(self.q_stage)) & Int32(1) + s_slot = qi & Int32(1) + s_phase = (qi // Int32(2)) & Int32(1) + pipeline_q.consumer_wait_w_index_phase(q_slot, q_phase) + pipeline_s.producer_acquire_w_index_phase( + s_slot, s_phase ^ Int32(1)) + if s_slot == Int32(0): + if q_slot == Int32(0): + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s0_advance(smem_desc_start_b=k_smem_start) + else: + if q_slot == Int32(0): + gemm_qk_s1_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(s_slot) + pipeline_q.consumer_release_w_index(q_slot) + + # Drain the remaining one or two PV tiles. + drain_begin = Int32(0) if num_q_groups == Int32(1) else num_q_groups - Int32(2) + for pv_qi in cutlass.range(drain_begin, num_q_groups, unroll=1): + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + @cute.jit + def _softmax_step( + self, + slot: cutlass.Constexpr[int], + s_consumer_phase: Int32, + p_producer_phase: Int32, + sm_stats_producer_phase: Int32, + softmax: SoftmaxSm100, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx: Int32, + thr_tmem_load, + thr_tmem_store, + tStS_t2r: cute.Tensor, + tStP_r2t: cute.Tensor, + tScS_t2r: cute.Tensor, + tScP_shape, + sQIdxMeta: cute.Tensor, + qidx_meta_slot: Int32, + group_tidx: Int32, + masked_tok_count: Int32, + kv_block_col_start: Int32, + seq_len_q: Int32, + causal_q_offset: Int32, + kv_valid_cols: Int32, + lse_temperature_scale: Float32, + return_temperature_lse: cutlass.Constexpr[bool], + apply_causal_mask: cutlass.Constexpr[bool] = False, + signal_stats_barrier: cutlass.Constexpr[bool] = True, + ): + slot_rt = Int32(slot) + + pipeline_s.consumer_wait_w_index_phase(slot_rt, s_consumer_phase) + + tSrS_t2r = cute.make_rmem_tensor( + tScS_t2r.shape, + self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + + seqlen_info = SeqlenInfoQK( + Int32(0), + Int32(0), + Int32(0), + Int32(0), + seq_len_q, + seq_len_q + causal_q_offset, + False, + False, + False, + False, + ) + mask = AttentionMask( + self.m_block_size, + self.n_block_size, + seqlen_info, + ) + if const_expr(self.causal and apply_causal_mask): + need_causal_mask = masked_tok_count > Int32(0) + if need_causal_mask: + tok_idx = group_tidx // Int32(self.qheadperkv) + q_idx = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_idx]) + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=True, + row_idx=q_idx, + kv_valid_cols=kv_valid_cols, + kv_block_col_start=kv_block_col_start, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + + # Each sparse CTA computes exactly one KV block for the current Q group, + # so full-tile softmax is always the first and only online-softmax step. + row_max, _ = softmax.update_row_max(tSrS_t2r.load(), True) + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + if const_expr(return_temperature_lse): + lse_temperature_row_sum = softmax.compute_scaled_exp2_row_sum( + tSrS_t2r, + lse_temperature_scale, + ) + + if cutlass.const_expr(self.split_P_arrive > 0): + # This full barrier is the late-P handoff consumed inside + # gemm_ptx_partial after its early PV k-slices are issued. + pipeline_p_lastsplit.producer_acquire_w_index_phase( + slot_rt, p_producer_phase) + pipeline_p.producer_acquire_w_index_phase(slot_rt, p_producer_phase) + tSrP_r2t_f32 = cute.make_rmem_tensor( + thr_tmem_store.partition_S( + cute.make_identity_tensor(tScP_shape)).shape, + Float32) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.p_dtype), + tSrS_t2r.layout) + softmax.apply_exp2_convert( + tSrS_t2r, + tSrP_r2t, + ex2_emu_freq=self.ex2_emu_freq, + ex2_emu_start_frg=self.ex2_emu_start_frg, + ) + + for k in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): + cute.copy( + thr_tmem_store, + tSrP_r2t_f32[None, None, k], + tStP_r2t[None, None, k]) + if cutlass.const_expr(self.split_P_arrive > 0): + split_idx = ( + cute.size(tStP_r2t.shape[2]) + * self.split_P_arrive // self.n_block_size) + if cutlass.const_expr(k + 1 == split_idx): + cute.arch.fence_view_async_tmem_store() + pipeline_p.producer_commit_w_index(slot_rt) + cute.arch.fence_view_async_tmem_store() + if cutlass.const_expr(self.split_P_arrive == 0): + pipeline_p.producer_commit_w_index(slot_rt) + else: + pipeline_p_lastsplit.producer_commit_w_index(slot_rt) + pipeline_sm_stats.producer_acquire_w_index_phase( + slot_rt, sm_stats_producer_phase) + softmax.update_row_sum(tSrS_t2r.load(), Float32(0.0), True) + del tSrS_t2r + sScale_slot = cute.make_tensor( + sScale.iterator + slot_rt * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_slot[group_tidx] = softmax.row_sum[0] + sScale_slot[group_tidx + Int32(self.m_block_size)] = ( + softmax.row_max[0]) + if const_expr(return_temperature_lse): + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot_rt * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sScale_temperature_slot[group_tidx] = lse_temperature_row_sum + cute.arch.fence_view_async_shared() + + if const_expr(signal_stats_barrier): + sm_stats_barrier.arrive_w_index(index=stats_barrier_idx) + pipeline_s.consumer_release_w_index(slot_rt) + + @cute.jit + def _wg_softmax( + self, + stage: cutlass.Constexpr[int], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tStS: cute.Tensor, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + kv_block_idx: Int32, + kv_valid_cols: Int32, + diag_q_count: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + causal_q_offset: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + mQ_2d: cute.Tensor, + ): + tidx = cute.arch.thread_idx()[0] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx % Int32(self.warps_per_group) + group_tidx = ( + warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + stats_barrier_idx = ( + Int32(stage) * Int32(self.warps_per_group) + warp_idx_in_wg) + + thr0_qk = tiled_mma_qk.get_slice(0) + tScS = thr0_qk.partition_C( + cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScS = tScS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_tiler_qk[0] // thr0_qk.thr_id.shape, + self.mma_tiler_qk[1]) + tilePlikeFP32 = ( + self.mma_tiler_qk[1] // Float32.width * self.p_dtype.width) + tScP_shape = (cta_qk_tiler[0], tilePlikeFP32) + tSAcc = tStS[(None, None), 0, 0, stage] + + softmax = SoftmaxSm100.create(softmax_scale_log2) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype) + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tSAcc).get_slice(group_tidx) + tStS_t2r = thr_tmem_load.partition_S(tSAcc) + tScS_t2r = thr_tmem_load.partition_D(tScS) + tStP_layout = cute.composition( + tSAcc.layout, + cute.make_layout((self.m_block_size, tilePlikeFP32))) + tStP = cute.make_tensor( + tSAcc.iterator + self.tmem_s_to_p_offset, tStP_layout) + # P-store Repetition is dtype-aware: each PV MMA K-segment is + # ``32 / (p_dtype.width / 8)`` fp8/bf16 columns wide, which equals + # ``32 * Float32.width / p_dtype.width`` packed fp32 TMEM columns + # ``// (p_dtype.width / 8)``. Concretely, R=16 packs two bf16 PV K + # segments per chunk (shape[2]=4 ⇒ 3/4 publish boundary aligns), + # while fp8 (PV K=32 fp8 ⇒ 8 fp32 cols) needs R=8 so + # shape[2]=4 and split_idx=3 publishes exactly 24 fp32 cols + # (= 96 fp8 cols = 3 PV K segments) at the early-arrive edge. + store_rep = const_expr( + 8 if self.p_dtype == cutlass.Float8E4M3FN else 16 + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(store_rep)), + Float32) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom, tStP).get_slice(group_tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + + total_q = mQ_2d.shape[0] // head_q + thr0_pv = tiled_mma_pv.get_slice(0) + pv_acc_shape = thr0_pv.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO_base = thr0_pv.make_fragment_C(pv_acc_shape) + corr_tile_size = 64 + tOcO = thr0_pv.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + tOcO_i = cute.logical_divide( + tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + o_tmem_copy_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.pv_acc_dtype) + + if has_work: + kv_block_col_start = Int32(0) + if const_expr(self.causal): + kv_block_col_start = kv_block_idx * Int32(self.n_block_size) + + num_stage_groups = ( + num_q_groups + Int32(1 - stage)) // Int32(2) + for qi_iter in cutlass.range(num_stage_groups, unroll=1): + qi_group = qi_iter * Int32(2) + Int32(stage) + phase = qi_iter & Int32(1) + producer_phase = phase ^ Int32(1) + qidx_meta_slot = ( + (qi_group & Int32(self.qidx_meta_stages - 1)) + * Int32(self.q_tokens_per_group) + ) + + softmax.reset() + + if const_expr(self.causal): + qi_group_start = qi_group * Int32(self.q_tokens_per_group) + masked_tok_count = cutlass.max( + Int32(0), + cutlass.min( + Int32(self.q_tokens_per_group), + diag_q_count - qi_group_start, + ), + ) + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + masked_tok_count, + kv_block_col_start, + seq_len_q, + causal_q_offset, + kv_valid_cols, + lse_temperature_scale, + const_expr(mLSE_temperature_partial is not None), + True, + False, + ) + else: + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + Int32(0), + kv_block_col_start, + seq_len_q, + Int32(0), + kv_valid_cols, + lse_temperature_scale, + const_expr(mLSE_temperature_partial is not None), + False, + False, + ) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(stage)) + self._epilogue_step( + qi_group, + group_tidx, + warp_idx_in_wg, + tOtO_base, + tOcO_i, + o_tmem_copy_atom, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, + mLSE_partial, + mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + count_raw, + batch_idx, + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + q_batch_offset, + total_q, + False, + stage, + ) + + @cute.jit + def _store_o_partial_vec4( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + ): + stg_128_cs(ptr, v0, v1, v2, v3) + + @cute.jit + def _store_o_partial_vec8_half( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + ): + if cutlass.const_expr(self.o_dtype is cutlass.BFloat16): + stg_128_bf16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + else: + stg_128_f16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + + @cute.jit + def _store_o_partial_vec16_fp8( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + v8: Float32, + v9: Float32, + v10: Float32, + v11: Float32, + v12: Float32, + v13: Float32, + v14: Float32, + v15: Float32, + ): + stg_128_fp8_e4m3_cs( + ptr, + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7, + v8, + v9, + v10, + v11, + v12, + v13, + v14, + v15, + ) + + @cute.jit + def _epilogue_step( + self, + qi_group: Int32, + group_tidx: Int32, + warp_idx_in_wg: Int32, + tOtO_base: cute.Tensor, + tOcO_i: cute.Tensor, + o_tmem_copy_atom, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + count_raw: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + total_q: Int32, + use_stats_barrier: cutlass.Constexpr[bool], + softmax_stage: cutlass.Constexpr[int], + ): + slot = qi_group & Int32(1) + phase = (qi_group // Int32(2)) & Int32(1) + stage_base = slot * Int32(self.tmem_o_stage_stride) + corr_tile_size = 64 + sScale_slot = cute.make_tensor( + sScale.iterator + slot * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sSplitIdx_slot = cute.make_tensor( + sSplitIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + sQIdx_slot = cute.make_tensor( + sQIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + pipeline_o.consumer_wait_w_index_phase(slot, phase) + if const_expr(use_stats_barrier): + sm_stats_barrier.arrive_and_wait_w_index( + index=slot * Int32(self.warps_per_group) + warp_idx_in_wg) + + if group_tidx < Int32(self.q_tokens_per_group): + tok = group_tidx + qi = qi_group * Int32(self.q_tokens_per_group) + tok + if qi < count_raw: + qsplit = sQIdxMeta[qidx_meta_slot + tok] + q_idx = self._decode_q_idx_from_qsplit(qsplit) + sQIdx_slot[tok] = q_idx + sSplitIdx_slot[tok] = self._decode_split_idx_from_qsplit(qsplit) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + tOtO = cute.make_tensor( + tOtO_base.iterator + stage_base + Int32(self.tmem_o_offset), + tOtO_base.layout) + for col_pass_idx in cutlass.range(Int32(2), unroll=1): + col_pass = col_pass_idx * Int32(corr_tile_size) + tOtO_pass_ptr = cute.make_ptr( + self.pv_acc_dtype, + tOtO.iterator.toint() + col_pass, + cute.AddressSpace.tmem, + assumed_align=8, + ) + tOtO_pass = cute.make_tensor(tOtO_pass_ptr, tOtO.layout) + tOtO_pass_i = cute.logical_divide( + tOtO_pass, + cute.make_layout((self.m_block_size, corr_tile_size))) + tiled_tmem_load_pass = tcgen05.make_tmem_copy( + o_tmem_copy_atom, tOtO_pass_i[(None, None), 0]) + thr_tmem_load_pass = tiled_tmem_load_pass.get_slice(group_tidx) + tOtO_t2r_pass = thr_tmem_load_pass.partition_S( + tOtO_pass_i[(None, None), None]) + tOcO_t2r_pass = thr_tmem_load_pass.partition_D( + tOcO_i[(None, None), None]) + + tOtO_t2r_i = tOtO_t2r_pass[None, None, None, 0] + tOcO_t2r_i = tOcO_t2r_pass[None, None, None, 0] + tOrO_frg = cute.make_rmem_tensor_like( + tOcO_t2r_i, self.pv_acc_dtype) + cute.copy(tiled_tmem_load_pass, tOtO_t2r_i, tOrO_frg) + + tOrO_mn = make_16x256b_tensor_mn_view(tOrO_frg) + tOrO_mn = cute.make_tensor( + tOrO_mn.iterator, + cute.select(tOrO_mn.layout, mode=[0, 1])) + tOcO_mn = make_16x256b_tensor_mn_view(tOcO_t2r_i) + tOcO_mn = cute.make_tensor( + tOcO_mn.iterator, + cute.select(tOcO_mn.layout, mode=[0, 1])) + num_rows = cute.size(tOrO_mn, mode=[0]) + num_cols = cute.size(tOrO_mn, mode=[1]) + + for r in cutlass.range_constexpr(num_rows): + if const_expr(self.o_dtype is Float32): + for c4 in cutlass.range_constexpr(num_cols // 4): + c_base = Int32(c4) * Int32(4) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + fake_col = real_col_to_stg128_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec4( + ptr, + o0, + o1, + o2, + o3, + ) + elif const_expr(self.o_dtype in [cutlass.BFloat16, cutlass.Float16]): + assert num_cols % 8 == 0, ( + "half O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 8" + ) + for c8 in cutlass.range_constexpr(num_cols // 8): + c_base = Int32(c8) * Int32(8) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2( + (o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2( + (o6, o7), scale_pair) + fake_col = real_col_to_stg128_half_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec8_half( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + ) + else: + assert num_cols % 16 == 0, ( + "fp8 O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 16" + ) + for c16 in cutlass.range_constexpr(num_cols // 16): + c_base = Int32(c16) * Int32(16) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = ( + qi_group * Int32(self.q_tokens_per_group) + + tok + ) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) + * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val + ) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0) + ) + row_base_ptr = flat_row * Int64(self.head_dim) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + o8 = tOrO_mn[r, c_base + Int32(8)] + o9 = tOrO_mn[r, c_base + Int32(9)] + o10 = tOrO_mn[r, c_base + Int32(10)] + o11 = tOrO_mn[r, c_base + Int32(11)] + o12 = tOrO_mn[r, c_base + Int32(12)] + o13 = tOrO_mn[r, c_base + Int32(13)] + o14 = tOrO_mn[r, c_base + Int32(14)] + o15 = tOrO_mn[r, c_base + Int32(15)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2((o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2((o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2((o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2((o6, o7), scale_pair) + o8, o9 = cute.arch.mul_packed_f32x2((o8, o9), scale_pair) + o10, o11 = cute.arch.mul_packed_f32x2((o10, o11), scale_pair) + o12, o13 = cute.arch.mul_packed_f32x2((o12, o13), scale_pair) + o14, o15 = cute.arch.mul_packed_f32x2((o14, o15), scale_pair) + fake_col = real_col_to_stg128_fp8_fake_col(col) + ptr = ( + mO_partial.iterator + + row_base_ptr + + Int64(fake_col) + ) + self._store_o_partial_vec16_fp8( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + o8, + o9, + o10, + o11, + o12, + o13, + o14, + o15, + ) + cute.arch.fence_view_async_tmem_load() + + tok_local = Int32(group_tidx) // Int32(self.qheadperkv) + h_local = Int32(group_tidx) % Int32(self.qheadperkv) + qi_lse = qi_group * Int32(self.q_tokens_per_group) + tok_local + if qi_lse < count_raw: + row_sum_val = sScale_slot[group_tidx] + row_max_val = sScale_slot[group_tidx + Int32(self.m_block_size)] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + LN2 = Float32(math.log(2.0)) + lse_cur = ( + (row_max_val * softmax_scale_log2 + + cute.math.log2(row_sum_val, fastmath=True)) * LN2 + if not is_zero_or_nan else -Float32.inf) + q_idx_lse = sQIdx_slot[tok_local] + h_abs = head_kv_idx * Int32(self.qheadperkv) + h_local + split_lse = sSplitIdx_slot[tok_local] + q_abs_lse = q_batch_offset + q_idx_lse + mLSE_partial[split_lse, q_abs_lse, h_abs] = lse_cur + if const_expr(mLSE_temperature_partial is not None): + row_sum_temperature_val = sScale_temperature_slot[group_tidx] + is_temperature_zero_or_nan = ( + row_sum_temperature_val == Float32(0.0) + or row_sum_temperature_val != row_sum_temperature_val) + lse_temperature_cur = ( + (row_max_val * lse_temperature_scale_log2 + + cute.math.log2(row_sum_temperature_val, fastmath=True)) * LN2 + if not is_temperature_zero_or_nan else -Float32.inf) + mLSE_temperature_partial[split_lse, q_abs_lse, h_abs] = ( + lse_temperature_cur) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + pipeline_sm_stats.consumer_release_w_index(slot) + pipeline_o.consumer_release_w_index(slot) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd1a8d6bf92b16d2943aa5e40fd91e26224ac40 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py @@ -0,0 +1,3305 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""SM100 sparse attention forward kernel with NVFP4 K/V. + +This kernel implements the delivered Sparse Attention / Sparse Page Attention +forward contract: +- CSR sparse metadata (`k2q_row_ptr`, `k2q_q_indices`) +- varlen Q metadata via `cu_seqlens_q` +- BF16 Q +- packed NVFP4 K/V data +- E4M3 per-1x16 K/V scales in cuBLAS/cuDNN 128x4 tiled layout +- FP32 per-tensor K/V global scales +""" + +import math +from functools import partial +from typing import Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.pipeline as cutlass_pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.cutlass_dsl import BaseDSL + +from ....quack import copy_utils + +from ....src.common.cute_dsl_utils import assume_tensor_aligned +from ....src.common import utils +from ....src.common import pipeline +from ....src.common import mma_sm100_desc as sm100_desc +from ....src.common import blackwell_helpers as sm100_helpers +from ....src.common.softmax import SoftmaxSm100 +from ....src.common.named_barrier import NamedBarrierFwdSm100 +from ....src.common.mask import AttentionMask +from ....src.common.seqlen_info import SeqlenInfoQK +# Shared raw PTX helpers and layout conversions used by the lean kernel. +from ....src.common.paged_kv import PagedKVManager +from ....src.common.tma_utils import ( + tma_gather4_cached, + tma_gather4_prefetch, + prefetch_tma_desc_raw, + TMA_CACHE_EVICT_LAST, + make_16x256b_tensor_mn_view, + real_col_to_stg128_fake_col, + real_col_to_stg128_fp8_fake_col, + real_col_to_stg128_half_fake_col, + stg_128_cs, + stg_128_bf16_cs, + stg_128_f16_cs, + stg_128_fp8_e4m3_cs, +) + + +class SparseAttentionForwardNvfp4KvSm100: + """SM100 sparse attention forward kernel with NVFP4 K/V.""" + + k_tile = 64 # UTCMMA bf16 K-tile (matches sparse_fwd_utcmma.py) + + def __init__( + self, + head_dim: int = 128, + qheadperkv: int = 16, + m_block_size: int = 128, + n_block_size: int = 128, + paged_kv: bool = False, + page_size: Optional[int] = None, + has_seqused_k: bool = False, + causal: bool = False, + use_prepare_scheduler: bool = True, + fp8_pair_dequant: bool = True, + has_k_global_scale: bool = True, + has_v_global_scale: bool = True, + ): + if head_dim != 128: + raise NotImplementedError( + f"SparseAttentionForwardNvfp4KvSm100 currently supports only D=128, got D={head_dim}" + ) + self.head_dim = 128 + self.qheadperkv = qheadperkv + self.use_q_gather4 = qheadperkv in (4, 2, 1) + if qheadperkv not in (16, 8, 4, 2, 1): + raise ValueError( + "SparseAttentionForwardNvfp4KvSm100 supports qheadperkv in " + f"{{1, 2, 4, 8, 16}}, got {qheadperkv}" + ) + self.tokens_per_gather4 = 4 // qheadperkv if self.use_q_gather4 else 0 + self.m_block_size = m_block_size # 128 packed Q heads + self.n_block_size = n_block_size # 128 KV-block width + self.paged_kv = paged_kv + self.page_size = page_size + self.has_seqused_k = has_seqused_k + self.causal = causal + self.fp8_pair_dequant = fp8_pair_dequant + self.has_k_global_scale = has_k_global_scale + self.has_v_global_scale = has_v_global_scale + if not use_prepare_scheduler: + raise ValueError("SparseAttentionForwardNvfp4KvSm100 requires prepare scheduler") + self.use_prepare_scheduler = True + if self.paged_kv: + if page_size is None: + raise ValueError("page_size must be provided when paged_kv=True") + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal blk_kv ({n_block_size})" + ) + else: + self.page_size = n_block_size + self.q_tokens_per_group = m_block_size // qheadperkv # 8 + + self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim) + self.mma_tiler_pv = (m_block_size, self.head_dim, n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + + # Pipeline configuration — deeper Q prefetch ring plus 2-slot S/O rings. + self.q_stage = 2 + self.s_stage = 2 + self.o_stage = 2 + self.kv_stage = 1 + # Sparse q_idx metadata ring bridging load -> epilogue. Sized larger + # than the in-flight group distance so epilogue can reuse q_idx + # without rereading mK2qIndices. + self.qidx_meta_stages = 16 + + self.k_stages = 2 + self.q_stage_stride_bytes = m_block_size * self.head_dim * 2 + self.k_tile_stride_bytes = m_block_size * self.k_tile * 2 + self.token_stride_bytes = qheadperkv * self.k_tile * 2 + + # Warp layout: two softmax WGs, one Q-load/epilogue WG, one + # MMA issue warp, two K/V load warps, and one empty warp. + self.warps_per_group = 4 + self.softmax0_warp_base = 0 + self.softmax1_warp_base = self.softmax0_warp_base + self.warps_per_group + self.store_warp_base = self.softmax1_warp_base + self.warps_per_group + self.mma_warp_id = self.store_warp_base + self.warps_per_group + self.load_warp_base = self.mma_warp_id + 1 + self.q_load_warp_base = self.store_warp_base + self.kv_load_warp_base = self.load_warp_base + self.num_kv_load_warps = 2 + self.num_q_load_warps = self.warps_per_group + self.total_warps = 16 + self.threads_per_cta = cute.arch.WARP_SIZE * self.total_warps # 512 + + # TMEM layout follows FA SM100: + # S0/S1: [0:128], [128:256] + # O0/O1: [256:384], [384:512] for hdim_v=128 + # P is bf16 and starts halfway into each S tile. + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + self.tmem_s_offset = 0 + self.tmem_stage_stride = n_block_size + self.tmem_o_stage_stride = self.head_dim + self.tmem_o_offset = self.s_stage * n_block_size + self.tmem_s_to_p_offset = n_block_size // 2 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + raw_tmem_total = self.tmem_o_offset + self.o_stage * self.tmem_o_stage_stride + # SM100 TMEM allocation requires a power-of-two column count. The + # 128-wide path naturally uses 512 columns; 64-wide KV blocks use 384 + # columns and must round the allocation up while keeping the same + # logical offsets. + self.tmem_total = 1 << (raw_tmem_total - 1).bit_length() + + # Let PV start once the first 3/4 of P is visible in TMEM. The final + # split is synchronized by a separate mbarrier consumed inside PV MMA. + self.split_P_arrive = n_block_size // 4 * 3 + self.split_P_arrive = int(self.split_P_arrive / 32) * 32 + assert self.split_P_arrive % 32 == 0 + assert self.split_P_arrive < self.n_block_size + + # Register allocation per role. The causal hdim128 split gives the + # epilogue enough room for partial-O/LSE address generation while the + # two softmax WGs still have enough registers to avoid S/P spills. + self.num_regs_softmax = 176 if causal else 192 + self.num_regs_store = 112 if causal else 80 + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_store + self.num_regs_mma = self.num_regs_other + self.num_regs_load = self.num_regs_other + self.num_regs_empty = self.num_regs_other + self.store_reg_decrease = self.num_regs_store <= 128 + self.ex2_emu_freq = 16 if causal else 0 + self.ex2_emu_start_frg = 1 + self.buffer_align_bytes = 1024 + + # SM100 config. + self.use_2cta_instrs = False + self.cta_group_size = 1 + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnk = (1, 1, 1) + + self.arch = BaseDSL._get_dsl().get_arch_enum() + + @cute.jit + def _batch_q_offset( + self, + batch_idx: Int32, + mCuSeqlensQ, + ) -> Int32: + return mCuSeqlensQ[batch_idx] + + @cute.jit + def _logical_seqlen_k( + self, + batch_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + if const_expr(self.has_seqused_k): + return mSeqUsedK[batch_idx] + if const_expr(self.paged_kv): + return Int32(mPageTable.shape[1]) * Int32(self.page_size) + return mCuSeqlensK[batch_idx + Int32(1)] - mCuSeqlensK[batch_idx] + + @cute.jit + def _valid_cols_in_block( + self, + batch_idx: Int32, + kv_block_idx: Int32, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) -> Int32: + seqlen_k = self._logical_seqlen_k( + batch_idx, mPageTable, mSeqUsedK, mCuSeqlensK + ) + block_start = kv_block_idx * Int32(self.n_block_size) + remaining = seqlen_k - block_start + remaining = cutlass.max(remaining, Int32(0)) + return cutlass.min(remaining, Int32(self.n_block_size)) + + @cute.jit + def _load_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _load_qsplit_idx( + self, + mK2qQSplitIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + qi: Int32, + ) -> Int32: + return mK2qQSplitIndices[head_kv_idx, row_start + qi] + + @cute.jit + def _decode_q_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return qsplit & Int32(0x00FF_FFFF) + + @cute.jit + def _decode_split_idx_from_qsplit(self, qsplit: Int32) -> Int32: + return (qsplit >> Int32(24)) & Int32(0xFF) + + @cute.jit + def _lower_bound_q_idx( + self, + mK2qIndices: cute.Tensor, + head_kv_idx: Int32, + row_start: Int32, + count: Int32, + q_value: Int32, + ) -> Int32: + left = Int32(0) + right = count + # k2q_q_indices is sorted by q_idx within each CSR row. A fixed + # 32-step loop covers int32-sized rows and keeps this CTA-level. + for _ in cutlass.range(32, unroll=1): + if left < right: + mid = (left + right) // Int32(2) + q_idx = self._load_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + mid, + ) + if q_idx < q_value: + left = mid + Int32(1) + else: + right = mid + return left + + # ------------------------------------------------------------------ + # Host-side: TMA descriptors, SMEM layout, launch + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + mK: cute.Tensor, # packed NVFP4: flat [total_k, head_kv, dim/2] or paged [page, head_kv, token, dim/2] + mV: cute.Tensor, # packed NVFP4: flat [total_k, head_kv, dim/2] or paged [page, head_kv, token, dim/2] + mKScale: cute.Tensor, # E4M3 uint8, 128x4 tiled over flattened K rows and dim/16 cols + mVScale: cute.Tensor, # E4M3 uint8, 128x4 tiled over flattened V rows and dim/16 cols + mKGlobalScale: Optional[cute.Tensor], # optional FP32 tensor/global dequant scale + mVGlobalScale: Optional[cute.Tensor], # optional FP32 tensor/global dequant scale + mK2qIndices: cute.Tensor, # csr payload: [head_kv, nnz] + mK2qQSplitIndices: cute.Tensor, # csr payload: [head_kv, nnz] packed q_idx/split slot + mK2qCounts: cute.Tensor, # csr row_ptr: [head_kv, total_rows + 1] + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, # fp32 O_partial buffer (kept alive) + mLSE_partial: cute.Tensor, # fp32 LSE_partial + mLSE_temperature_partial: Optional[cute.Tensor], # fp32 temperature-scaled LSE_partial + mQ_flat: cute.Tensor, # [batch*Sq*head_q, dim] bf16, pre-flattened + mQ_gather4_desc: Optional[cute.Tensor], # [128] uint8 tensor map for gather4 Q load + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + softmax_scale: Float32, + lse_temperature_scale: Float32, + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + stream=None, + ): + self.q_dtype = mQ_flat.element_type + self.k_cache_dtype = mK.element_type + self.v_cache_dtype = mV.element_type + self.k_scale_dtype = mKScale.element_type + self.v_scale_dtype = mVScale.element_type + if const_expr(self.q_dtype not in [cutlass.BFloat16, cutlass.Float8E4M3FN]): + raise TypeError(f"KVFP4 forward requires BF16 or FP8 E4M3 Q, got {self.q_dtype}") + self.k_dtype = self.q_dtype + self.v_dtype = self.q_dtype + if const_expr(self.k_cache_dtype is not cutlass.Uint8 or self.v_cache_dtype is not cutlass.Uint8): + raise TypeError( + f"KVFP4 forward expects packed uint8 K/V, got {self.k_cache_dtype}, {self.v_cache_dtype}" + ) + if const_expr(self.k_scale_dtype is not cutlass.Uint8 or self.v_scale_dtype is not cutlass.Uint8): + raise TypeError( + f"KVFP4 forward expects uint8 E4M3 scales, got {self.k_scale_dtype}, {self.v_scale_dtype}" + ) + if const_expr(self.has_k_global_scale and mKGlobalScale.element_type is not Float32): + raise TypeError("KVFP4 forward expects FP32 K global scale") + if const_expr(self.has_v_global_scale and mVGlobalScale.element_type is not Float32): + raise TypeError("KVFP4 forward expects FP32 V global scale") + self.mma_kind = "f8f6f4" if const_expr(self.q_dtype.width == 8) else "f16" + elem_bytes = const_expr(self.q_dtype.width // 8) + self.q_stage_stride_bytes = self.m_block_size * self.head_dim * elem_bytes + self.k_tile_stride_bytes = self.m_block_size * self.k_tile * elem_bytes + self.token_stride_bytes = self.qheadperkv * self.k_tile * elem_bytes + p_cols_as_fp32 = const_expr(self.n_block_size * self.q_dtype.width // Float32.width) + self.tmem_s_to_p_offset = self.n_block_size - p_cols_as_fp32 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + self.o_dtype = mO_partial.element_type + if const_expr( + self.o_dtype + not in [Float32, cutlass.BFloat16, cutlass.Float16, cutlass.Float8E4M3FN] + ): + raise TypeError(f"Unsupported O_partial dtype: {self.o_dtype}") + mK, mV, mKScale, mVScale = [ + assume_tensor_aligned(t) for t in (mK, mV, mKScale, mVScale) + ] + + if const_expr(not self.paged_kv): + # Flat varlen K/V: + # K: [total_k, h, d/2] -> [total_k, d/2, h]. + # V: [total_k, h, d/2] -> [d/2, total_k, h] for MN-major PV. + layout_t = [0, 2, 1] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2])) + else: + # Host input is [page, head, token, dim/2]. + layout_t = [2, 3, 1, 0] + mK = cute.make_tensor(mK.iterator, cute.select(mK.layout, mode=layout_t)) + mV_kv = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=layout_t)) + # V: (s,d/2,h,b) -> (d/2,s,h,b) for MN-major + mV = cute.make_tensor(mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2, 3])) + + # ------------------------------------------------------------------ + # UTCMMA TiledMma: QK^T and PV + # ------------------------------------------------------------------ + cta_group = tcgen05.CtaGroup.ONE + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, + Float32, cta_group, self.mma_tiler_qk[:2]) + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN, + Float32, cta_group, self.mma_tiler_pv[:2], tcgen05.OperandSource.TMEM) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM layouts: sQ/sK/sV only. O_partial is written directly from + # registers to GMEM in the epilogue. + # ------------------------------------------------------------------ + total_q_stages = self.q_stage + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, total_q_stages) + q_load_tile = ( + self.head_dim + if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_tile + ) + q_load_subtiles_per_token = const_expr(self.head_dim // q_load_tile) + num_subtiles_total = ( + total_q_stages * self.q_tokens_per_group * q_load_subtiles_per_token + ) + sQ_load_layout = sm100_utils.make_smem_layout( + tcgen05.OperandMajorMode.K, + (self.qheadperkv, q_load_tile), self.q_dtype, num_subtiles_total) + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage) + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage) + sK_fp4_layout = cute.append( + cute.make_layout( + (self.n_block_size, self.head_dim // 2), + stride=(self.head_dim // 2, 1), + ), + cute.make_layout((1,)), + ) + sV_fp4_layout = cute.append( + cute.make_layout( + (self.head_dim // 2, self.n_block_size), + stride=(1, self.head_dim // 2), + ), + cute.make_layout((1,)), + ) + # P SMEM layout metadata (no actual SMEM allocation — P lives in TMEM, + # overlaying the S region; this layout is only used to compute the PV + # A-operand TMEM descriptor shape at the MMA issue site.) + tP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.s_stage) + + # ------------------------------------------------------------------ + # TMA atoms. Packed FP4 K/V are staged by TMA, then dequantized into + # BF16 MMA SMEM layout by the KV load warps. + # ------------------------------------------------------------------ + k_fp4_tma_bytes = cute.size_in_bytes( + self.k_cache_dtype, cute.select(sK_fp4_layout, mode=[0, 1])) + v_fp4_tma_bytes = cute.size_in_bytes( + self.v_cache_dtype, cute.select(sV_fp4_layout, mode=[0, 1])) + q_tma_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_atom_K_fp4, mK_tma = cpasync.make_tiled_tma_atom( + tma_load_op, + mK, + cute.select(sK_fp4_layout, mode=[0, 1]), + (self.n_block_size, self.head_dim // 2), + ) + tma_atom_V_fp4, mV_tma = cpasync.make_tiled_tma_atom( + tma_load_op, + mV, + cute.select(sV_fp4_layout, mode=[0, 1]), + (self.head_dim // 2, self.n_block_size), + ) + mK = mK_tma + mV = mV_tma + + # Q per-sub-tile TMA atom: bf16 uses two 64-element halves; fp8 uses + # one 128-element row because 128 fp8 elements occupy the same 128B + # swizzle span as 64 bf16 elements. + mQ_flat = assume_tensor_aligned(mQ_flat) + mQ_2d = cute.make_tensor( + mQ_flat.iterator, cute.select(mQ_flat.layout, mode=[0, 1])) + if const_expr(self.use_q_gather4): + # Placeholder atom for the unified kernel signature. Small-GQA Q + # loading uses raw gather4, so mQ_2d must stay as the plain GMEM + # tensor. The placeholder uses the natural SMEM top-level shape. + tma_atom_Q, _ = cpasync.make_tiled_tma_atom( + tma_load_op, mQ_2d, + cute.select(sQ_load_layout, mode=[0, 1]), + (8, q_load_tile)) + else: + tma_atom_Q, mQ_2d_tma = cpasync.make_tiled_tma_atom( + tma_load_op, mQ_2d, + cute.select(sQ_load_layout, mode=[0, 1]), + (self.qheadperkv, q_load_tile)) + mQ_2d = mQ_2d_tma + q_subtile_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_load_layout, mode=[0, 1])) + + softmax_scale_log2 = softmax_scale * Float32(math.log2(math.e)) + lse_temperature_scale_log2 = softmax_scale_log2 * lse_temperature_scale + + # ------------------------------------------------------------------ + # SharedStorage — lean: just the mbars and tiles we actually use. + # + # Mbarriers (all storage rings stay below the 64-per-CTA limit): + # mbar_kv [2] one-shot K/V load handshake (full + empty) + # mbar_q [q_stage * 2] Q producer/consumer ring + # mbar_s [2] QK UTCMMA -> softmax (full + empty) + # mbar_o [2] PV UTCMMA -> epilogue (full + empty) + # mbar_p [s_stage * 2] softmax early-P arrive -> PV + # mbar_p_lastsplit [s_stage * 2] softmax final-P arrive -> PV + # (used only when ``self.split_P_arrive > 0``) + # mbar_sm_stats [s_stage * 2] softmax row_sum/row_max + # publish -> epilogue consumer read. In lean + # 1-WG topology the producer and consumer are + # the same 128 WG_C threads, but we keep the + # barrier for structural parity with FA so the + # softmax body reads identically. + # ------------------------------------------------------------------ + @cute.struct + class SharedStorage: + mbar_k: cute.struct.MemRange[Int64, 2] + mbar_v: cute.struct.MemRange[Int64, 2] + mbar_k_tma: cute.struct.MemRange[Int64, 2] + mbar_v_tma: cute.struct.MemRange[Int64, 2] + mbar_q: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_s: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_p_lastsplit: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_o: cute.struct.MemRange[Int64, self.o_stage * 2] + mbar_sm_stats: cute.struct.MemRange[Int64, self.o_stage * 2] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: Int32 + # Per-row softmax stats cache (for epilogue LSE + rescale): + # [0 : m_block_size) row_sum + # [m_block_size : 2*m_block_size) row_max + sScale: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size * 2] + # Per-row temperature LSE row_sum cache. The row_max is shared with + # sScale because lse_temperature_scale is positive. + sScaleTemperature: cute.struct.MemRange[ + Float32, self.o_stage * self.m_block_size] + # Per-token split_id from prepare-time per-edge metadata. + sSplitIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Per-token q_idx cache to avoid reloading sparse indices in epilogue. + sQIdx: cute.struct.MemRange[ + Int32, self.o_stage * self.q_tokens_per_group] + # Prefix length of q_idx-sorted row entries that may need causal + # masking for this KV block. This is CTA-level metadata, not a + # token-count cap. + sDiagQCount: cute.struct.MemRange[Int32, 1] + # CTA-wide row metadata, published once by tidx 0 and reused by + # all warp-specialized roles: + # [0] batch_idx + # [1] kv_block_idx + # [2] row_start + # [3] count_raw + # [4] kv_valid_cols + # [5] q_batch_offset + # [6] k_batch_offset + # [7] causal_q_offset = seqlen_k - seqlen_q + sRowMeta: cute.struct.MemRange[Int32, 8] + sPagedKvIdx: cute.struct.MemRange[Int32, 1] + sQLoadMIdx: cute.struct.MemRange[ + Int32, self.q_stage * self.q_tokens_per_group] + # Packed per-edge q/split metadata: + # low 24 bits = q_idx, high 8 bits = split slot. + sQIdxMeta: cute.struct.MemRange[ + Int32, self.qidx_meta_stages * self.q_tokens_per_group] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes] + sKFp4: cute.struct.Align[ + cute.struct.MemRange[self.k_cache_dtype, cute.cosize(sK_fp4_layout)], + self.buffer_align_bytes] + sVFp4: cute.struct.Align[ + cute.struct.MemRange[self.v_cache_dtype, cute.cosize(sV_fp4_layout)], + self.buffer_align_bytes] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes] + + self.shared_storage = SharedStorage + num_ctas = work_capacity + + self.kernel( + mK, mV, mKScale, mVScale, mKGlobalScale, mVGlobalScale, + mK2qIndices, mK2qQSplitIndices, mK2qCounts, + mSchedulerMetadata, mWorkCount, + mO_partial, mLSE_partial, mLSE_temperature_partial, mQ_2d, mQ_gather4_desc, + mPageTable, mSeqUsedK, mCuSeqlensQ, mCuSeqlensK, + softmax_scale_log2, lse_temperature_scale_log2, lse_temperature_scale, + sQ_layout, sQ_load_layout, sK_layout, sV_layout, + sK_fp4_layout, sV_fp4_layout, tP_layout, + tma_atom_K_fp4, tma_atom_V_fp4, tma_atom_Q, + tiled_mma_qk, tiled_mma_pv, + k_fp4_tma_bytes, v_fp4_tma_bytes, q_tma_bytes, q_subtile_bytes, + num_kv_blocks, num_heads_kv, seq_len_q, work_capacity, + ).launch( + grid=(num_ctas,), + block=[self.threads_per_cta, 1, 1], + smem=max(SharedStorage.size_in_bytes(), 49152), + stream=stream, + min_blocks_per_mp=1, + ) + + # ------------------------------------------------------------------ + # Device-side: kernel entry, dispatch by warpgroup + # ------------------------------------------------------------------ + + @cute.kernel + def kernel( + self, + # Runtime tensors + mK: cute.Tensor, + mV: cute.Tensor, + mKScale: cute.Tensor, + mVScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + mVGlobalScale: Optional[cute.Tensor], + mK2qIndices: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mK2qCounts: cute.Tensor, + mSchedulerMetadata: Optional[cute.Tensor], + mWorkCount: Optional[cute.Tensor], + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + mQ_2d: cute.Tensor, + mQ_gather4_desc: Optional[cute.Tensor], + mPageTable, + mSeqUsedK, + mCuSeqlensQ, + mCuSeqlensK, + # Scalars + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + # Layouts + sQ_layout: cute.ComposedLayout, + sQ_load_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sK_fp4_layout: cute.Layout, + sV_fp4_layout: cute.Layout, + tP_layout: cute.ComposedLayout, + # TMA atom + tma_atom_K_fp4: cute.CopyAtom, + tma_atom_V_fp4: cute.CopyAtom, + tma_atom_Q: cute.CopyAtom, + # MMA + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + # Transfer sizes + k_fp4_tma_bytes: cutlass.Constexpr[int], + v_fp4_tma_bytes: cutlass.Constexpr[int], + q_tma_bytes: cutlass.Constexpr[int], + q_subtile_bytes: cutlass.Constexpr[int], + # Iteration bounds + num_kv_blocks: Int32, + num_heads_kv: Int32, + seq_len_q: Int32, + work_capacity: Int32, + ): + # ------------------------------------------------------------------ + # Thread / warp identity, CTA coordinate + # ------------------------------------------------------------------ + bidx, _, _ = cute.arch.block_idx() + row_linear = Int32(0) + head_kv_idx = Int32(0) + batch_idx = Int32(0) + kv_block_idx = Int32(0) + work_q_begin = Int32(0) + work_q_count = Int32(0) + cta_valid_work = True + work_idx = bidx + cta_valid_work = work_idx < mWorkCount[Int32(0)] + if cta_valid_work: + head_kv_idx = mSchedulerMetadata[work_idx, Int32(0)] + row_linear = mSchedulerMetadata[work_idx, Int32(1)] + work_q_begin = mSchedulerMetadata[work_idx, Int32(2)] + work_q_count = mSchedulerMetadata[work_idx, Int32(3)] + batch_idx = mSchedulerMetadata[work_idx, Int32(4)] + kv_block_idx = mSchedulerMetadata[work_idx, Int32(5)] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + head_q = num_heads_kv * Int32(self.qheadperkv) + paged_kv_manager = ( + PagedKVManager.create( + mPageTable, + page_size=self.page_size, + n_block_size=self.n_block_size, + ) + if const_expr(self.paged_kv) + else None + ) + + # Prefetch TMA descriptors (warp 0 once). + if warp_idx == 0: + if const_expr(not self.use_q_gather4): + cpasync.prefetch_descriptor(tma_atom_Q) + else: + with cute.arch.elect_one(): + prefetch_tma_desc_raw(mQ_gather4_desc.iterator) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # SMEM allocation (all warps — same SharedStorage type from __call__) + # ------------------------------------------------------------------ + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sKFp4 = storage.sKFp4.get_tensor(sK_fp4_layout) + sVFp4 = storage.sVFp4.get_tensor(sV_fp4_layout) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sQ_load = storage.sQ.get_tensor(sQ_load_layout.outer, swizzle=sQ_load_layout.inner) + sScale = storage.sScale.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size * 2)) + sScaleTemperature = storage.sScaleTemperature.get_tensor( + cute.make_layout(self.o_stage * self.m_block_size)) + sSplitIdx = storage.sSplitIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sQIdx = storage.sQIdx.get_tensor( + cute.make_layout((self.o_stage * self.q_tokens_per_group,))) + sDiagQCount = storage.sDiagQCount.get_tensor(cute.make_layout((1,))) + sRowMeta = storage.sRowMeta.get_tensor(cute.make_layout((8,))) + sPagedKvIdx = storage.sPagedKvIdx.get_tensor(cute.make_layout((1,))) + sQLoadMIdx = storage.sQLoadMIdx.get_tensor( + cute.make_layout((self.q_stage * self.q_tokens_per_group,))) + sQIdxMeta = storage.sQIdxMeta.get_tensor( + cute.make_layout((self.qidx_meta_stages * self.q_tokens_per_group,))) + mbar_k_ptr = storage.mbar_k.data_ptr() + mbar_v_ptr = storage.mbar_v.data_ptr() + mbar_k_tma_ptr = storage.mbar_k_tma.data_ptr() + mbar_v_tma_ptr = storage.mbar_v_tma.data_ptr() + + # ------------------------------------------------------------------ + # TMEM allocator — allocator warp 0 serves the whole CTA. + # ------------------------------------------------------------------ + tmem_alloc_warps: cutlass.Constexpr[int] = self.warps_per_group * 2 + 1 + tmem_alloc_threads = cute.arch.WARP_SIZE * tmem_alloc_warps + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), + num_threads=tmem_alloc_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # ------------------------------------------------------------------ + # Warp-specialized pipelines. + # ------------------------------------------------------------------ + ThreadCooperativeGroup = partial( + cutlass_pipeline.CooperativeGroup, cutlass_pipeline.Agent.Thread) + tma_thread = ThreadCooperativeGroup(1) + mma_thread = ThreadCooperativeGroup(1) + softmax_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group) + epilogue_threads = softmax_threads + + pipeline_q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=q_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_s = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_s.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=softmax_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p_lastsplit = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_p_lastsplit.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_threads, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_o = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_o.data_ptr(), + num_stages=self.o_stage, + producer_group=mma_thread, + consumer_group=epilogue_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_sm_stats = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_sm_stats.data_ptr(), + num_stages=self.o_stage, + producer_group=softmax_threads, + consumer_group=epilogue_threads, + defer_sync=True, + ) + # Cluster sync (no-op for 1CTA cluster). + pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + # ------------------------------------------------------------------ + # Work count: how many Q tokens reference this CTA's KV block + # ------------------------------------------------------------------ + k_smem_bytes = cute.size_in_bytes( + self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) + v_smem_bytes = cute.size_in_bytes( + self.v_dtype, cute.select(sV_layout, mode=[0, 1, 2])) + if tidx == 0: + row_batch_idx = batch_idx + row_kv_block_idx = kv_block_idx + base_row_start = mK2qCounts[head_kv_idx, row_linear] + row_start = base_row_start + count_raw = ( + mK2qCounts[head_kv_idx, row_linear + Int32(1)] + - base_row_start + ) + row_start = base_row_start + work_q_begin + count_raw = work_q_count + kv_valid_cols = ( + self._valid_cols_in_block( + row_batch_idx, + row_kv_block_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + ) + q_batch_offset = self._batch_q_offset( + row_batch_idx, mCuSeqlensQ + ) + k_batch_offset = ( + Int32(0) + if const_expr(self.paged_kv) + else mCuSeqlensK[row_batch_idx] + ) + sRowMeta[0] = row_batch_idx + sRowMeta[1] = row_kv_block_idx + sRowMeta[2] = row_start + sRowMeta[3] = count_raw + sRowMeta[4] = kv_valid_cols + sRowMeta[5] = q_batch_offset + sRowMeta[6] = k_batch_offset + causal_q_offset = Int32(0) + if const_expr(self.causal): + seqlen_q = mCuSeqlensQ[row_batch_idx + Int32(1)] - q_batch_offset + seqlen_k = self._logical_seqlen_k( + row_batch_idx, + mPageTable, + mSeqUsedK, + mCuSeqlensK, + ) + causal_q_offset = seqlen_k - seqlen_q + sRowMeta[7] = causal_q_offset + if const_expr(self.paged_kv): + sPagedKvIdx[0] = paged_kv_manager.physical_block_index( + row_batch_idx, row_kv_block_idx + ) + cute.arch.mbarrier_init(mbar_k_ptr, 1) + cute.arch.mbarrier_init(mbar_v_ptr, 1) + cute.arch.mbarrier_init(mbar_k_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_k_tma_ptr, k_fp4_tma_bytes) + cute.arch.mbarrier_init(mbar_v_tma_ptr, 1) + cute.arch.mbarrier_expect_tx(mbar_v_tma_ptr, v_fp4_tma_bytes) + diag_q_count = Int32(0) + if const_expr(self.causal): + row_has_visible_cols = (count_raw > Int32(0)) & (kv_valid_cols > Int32(0)) + if row_has_visible_cols: + kv_valid_end = ( + row_kv_block_idx * Int32(self.n_block_size) + + kv_valid_cols + ) + q_threshold = kv_valid_end - causal_q_offset + diag_q_count = self._lower_bound_q_idx( + mK2qIndices, + head_kv_idx, + row_start, + count_raw, + q_threshold, + ) + sDiagQCount[0] = diag_q_count + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + thr_mma_qk = tiled_mma_qk.get_slice(0) + thr_mma_pv = tiled_mma_pv.get_slice(0) + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_base = thr_mma_qk.make_fragment_C(qk_acc_shape) + tStS = cute.make_tensor( + tStS_base.iterator, + cute.append( + tStS_base.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + tP_width_ratio = Float32.width // self.v_dtype.width + tP_stage_stride = self.tmem_stage_stride * tP_width_ratio + tOrP = cute.make_tensor( + tOrP.iterator + self.tmem_p_offset * tP_width_ratio, + cute.append( + tOrP.layout, + cute.make_layout((self.s_stage,), stride=(tP_stage_stride,)), + ), + ) + + tmem_cols = self.tmem_total + + load_wg_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.LoadWG), + num_threads=cute.arch.WARP_SIZE * self.num_q_load_warps, + ) + kv_load_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvLoad), + num_threads=cute.arch.WARP_SIZE * self.num_kv_load_warps, + ) + kv_dequant_k_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantK), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + kv_dequant_v_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.KvDequantV), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + sm_stats_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), + num_threads=cute.arch.WARP_SIZE * 2, + ) + epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.StoreEpilogue), + num_threads=cute.arch.WARP_SIZE * self.warps_per_group, + ) + if ( + warp_idx == Int32(self.total_warps - 1) + and warp_idx >= Int32(self.kv_load_warp_base + self.num_kv_load_warps) + ): + cute.arch.setmaxregister_decrease(self.num_regs_empty) + + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + q_load_thread_end = Int32( + (self.q_load_warp_base + self.num_q_load_warps) * cute.arch.WARP_SIZE + ) + is_q_load_thread = tidx >= q_load_thread_base and tidx < q_load_thread_end + if is_q_load_thread and cta_valid_work: + if self.store_reg_decrease: + cute.arch.setmaxregister_decrease(self.num_regs_store) + else: + cute.arch.setmaxregister_increase(self.num_regs_store) + row_start_load = sRowMeta[2] + count_raw_load = sRowMeta[3] + q_batch_offset_load = sRowMeta[5] + # Do not gate on KV validity here; sparse entries past seqused_k + # still need the all-masked path to produce neutral partials. + has_work_load = count_raw_load > Int32(0) + num_q_groups_load = ( + count_raw_load + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + q_group_start = Int32(0) + if const_expr(self.use_q_gather4): + self._wg_load_q_gather4( + mQ_2d, + mQ_gather4_desc, + mK2qQSplitIndices, + sQIdxMeta, + sQ, + pipeline_q, + load_wg_barrier, + q_group_start, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + True, + ) + else: + self._wg_load_q_tma( + tma_atom_Q, + mQ_2d, + mK2qQSplitIndices, + sQLoadMIdx, + sQIdxMeta, + sQ_load, + pipeline_q, + load_wg_barrier, + q_group_start, + num_q_groups_load, + count_raw_load, + has_work_load, + head_kv_idx, + row_start_load, + q_batch_offset_load, + num_heads_kv, + True, + ) + + if ( + warp_idx >= Int32(self.kv_load_warp_base) + and warp_idx < Int32(self.kv_load_warp_base + self.num_kv_load_warps) + and cta_valid_work + ): + cute.arch.setmaxregister_decrease(self.num_regs_load) + kv_block_idx_load = sRowMeta[1] + k_batch_offset_load = sRowMeta[6] + has_work_load = sRowMeta[3] > Int32(0) + self._wg_load_kv( + tma_atom_K_fp4, tma_atom_V_fp4, + mK, mV, + mKScale, mVScale, + mKGlobalScale, mVGlobalScale, + sPagedKvIdx, + sKFp4, sVFp4, sK, sV, + mbar_k_tma_ptr, mbar_v_tma_ptr, + mbar_k_ptr, mbar_v_ptr, + kv_load_barrier, + has_work_load, + head_kv_idx, kv_block_idx_load, + k_batch_offset_load, + num_heads_kv, + ) + + if warp_idx == Int32(self.mma_warp_id) and cta_valid_work: + cute.arch.setmaxregister_decrease(self.num_regs_mma) + count_raw_mma = sRowMeta[3] + has_work_mma = count_raw_mma > Int32(0) + num_q_groups_mma = ( + count_raw_mma + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + tmem.allocate(tmem_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_mma_issue( + tiled_mma_qk, tiled_mma_pv, + thr_mma_qk, thr_mma_pv, + tStS, tOrP, + sK, sV, sQ, + pipeline_q, pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + mbar_k_ptr, mbar_v_ptr, + num_q_groups_mma, has_work_mma, + ) + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr, num_columns=tmem_cols) + cute.arch.griddepcontrol_launch_dependents() + + if ( + warp_idx >= Int32(self.softmax0_warp_base) + and warp_idx < Int32(self.softmax1_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + self._wg_dequant_k_from_tma_staging( + mKScale, + mKGlobalScale, + sPagedKvIdx, + sKFp4, sK, + mbar_k_tma_ptr, + mbar_k_ptr, + kv_dequant_k_barrier, + has_work_softmax, + self.softmax0_warp_base, + self.warps_per_group, + head_kv_idx, kv_block_idx_softmax, + sRowMeta[6], + num_heads_kv, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 0, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + mKGlobalScale, + mVGlobalScale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + + if ( + warp_idx >= Int32(self.softmax1_warp_base) + and warp_idx < Int32(self.store_warp_base) + and cta_valid_work + ): + cute.arch.setmaxregister_increase(self.num_regs_softmax) + kv_block_idx_softmax = sRowMeta[1] + count_raw_softmax = sRowMeta[3] + kv_valid_cols_softmax = sRowMeta[4] + causal_q_offset_softmax = sRowMeta[7] + has_work_softmax = count_raw_softmax > Int32(0) + num_q_groups_softmax = ( + count_raw_softmax + Int32(self.q_tokens_per_group - 1) + ) // Int32(self.q_tokens_per_group) + diag_q_count_softmax = sDiagQCount[0] + self._wg_dequant_v_from_tma_staging( + mVScale, + mVGlobalScale, + sPagedKvIdx, + sVFp4, sV, + mbar_v_tma_ptr, + mbar_v_ptr, + kv_dequant_v_barrier, + has_work_softmax, + self.softmax1_warp_base, + self.warps_per_group, + head_kv_idx, kv_block_idx_softmax, + sRowMeta[6], + num_heads_kv, + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + self._wg_softmax( + 1, + tiled_mma_qk, + tiled_mma_pv, + tStS, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_s, pipeline_p, pipeline_p_lastsplit, + pipeline_o, pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, mLSE_partial, mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + lse_temperature_scale, + mKGlobalScale, + mVGlobalScale, + kv_block_idx_softmax, + kv_valid_cols_softmax, + diag_q_count_softmax, + num_q_groups_softmax, count_raw_softmax, has_work_softmax, + causal_q_offset_softmax, + sRowMeta[0], + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + sRowMeta[5], + mQ_2d, + ) + tmem_alloc_barrier.arrive() + # ------------------------------------------------------------------ + # Warp-specialized helpers + # ------------------------------------------------------------------ + + @cute.jit + def _scale_128x4_offset( + self, + row: Int32, + col: Int32, + scale_cols: cutlass.Constexpr[int], + ) -> Int32: + tiles_n: cutlass.Constexpr[int] = (scale_cols + 3) // 4 + tile_m = row // Int32(128) + tile_n = col // Int32(4) + outer = row % Int32(128) + inner = col % Int32(4) + return ( + (tile_m * Int32(tiles_n) + tile_n) * Int32(512) + + (outer % Int32(32)) * Int32(16) + + (outer // Int32(32)) * Int32(4) + + inner + ) + + @cute.jit + def _load_scale_bf16x2( + self, + scale: cute.Tensor, + logical_row: Int32, + scale_col: Int32, + ) -> Int32: + scale_offset = self._scale_128x4_offset( + logical_row, + scale_col, + self.head_dim // 16, + ) + scale_ptr = cute.make_ptr( + cutlass.Uint8, + scale.iterator.toint() + Int64(scale_offset), + mem_space=scale.iterator.memspace, + assumed_align=1, + ) + scale_byte = cute.make_tensor(scale_ptr, cute.make_layout(1))[0] + return utils.cvt_fp8_e4m3_to_bf16x2_replicated(cutlass.Int32(scale_byte)) + + @cute.jit + def _load_scale_e4m3_u8( + self, + scale: cute.Tensor, + logical_row: Int32, + scale_col: Int32, + ) -> Int32: + scale_offset = self._scale_128x4_offset( + logical_row, + scale_col, + self.head_dim // 16, + ) + scale_ptr = cute.make_ptr( + cutlass.Uint8, + scale.iterator.toint() + Int64(scale_offset), + mem_space=scale.iterator.memspace, + assumed_align=1, + ) + scale_byte = cute.make_tensor(scale_ptr, cute.make_layout(1))[0] + return cutlass.Int32(scale_byte) + + @cute.jit + def _dequant_fp4x16_to_bf16( + self, + src_words: cute.Tensor, + combined_scale_bf16x2: Int32, + dst: cute.Tensor, + ): + r_bf16 = cute.make_rmem_tensor((2,), cutlass.BFloat16) + r_bf16_i32 = cute.recast_tensor(r_bf16, cutlass.Int32) + for word_idx in cutlass.range_constexpr(2): + bf16_pair0, bf16_pair1, bf16_pair2, bf16_pair3 = utils.cvt_fp4x8_e2m1_bf16x8( + src_words[word_idx] + ) + bf16_pairs = (bf16_pair0, bf16_pair1, bf16_pair2, bf16_pair3) + for pair_idx in cutlass.range_constexpr(4): + r_bf16_i32[0] = utils.mul_bf16x2( + bf16_pairs[pair_idx], + combined_scale_bf16x2, + ) + dst[word_idx * 8 + 2 * pair_idx + 0] = r_bf16[0] + dst[word_idx * 8 + 2 * pair_idx + 1] = r_bf16[1] + + @cute.jit + def _dequant_fp4x16_to_fp8( + self, + src_words: cute.Tensor, + scale_e4m3: Int32, + dst: cute.Tensor, + ): + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + for word_idx in cutlass.range_constexpr(2): + fp8_lo, fp8_hi = utils.cvt_fp4x8_e2m1_scaled_e4m3x8( + src_words[word_idx], + scale_e4m3, + ) + dst_i32[word_idx * 2] = fp8_lo + dst_i32[word_idx * 2 + 1] = fp8_hi + + @cute.jit + def _dequant_fp4x32_to_fp8( + self, + src_words: cute.Tensor, + scale_e4m3_lo: Int32, + scale_e4m3_hi: Int32, + dst: cute.Tensor, + ): + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + for word_idx in cutlass.range_constexpr(2): + fp8_lo, fp8_hi = utils.cvt_fp4x8_e2m1_scaled_e4m3x8( + src_words[word_idx], + scale_e4m3_lo, + ) + dst_i32[word_idx * 2] = fp8_lo + dst_i32[word_idx * 2 + 1] = fp8_hi + for word_idx in cutlass.range_constexpr(2): + fp8_lo, fp8_hi = utils.cvt_fp4x8_e2m1_scaled_e4m3x8( + src_words[word_idx + 2], + scale_e4m3_hi, + ) + dst_i32[word_idx * 2 + 4] = fp8_lo + dst_i32[word_idx * 2 + 5] = fp8_hi + + @cute.jit + def _flat_kv_scale_row( + self, + token_idx: Int32, + head_kv_idx: Int32, + num_heads_kv: Int32, + ) -> Int32: + return token_idx * num_heads_kv + head_kv_idx + + @cute.jit + def _paged_kv_scale_row( + self, + page_idx: Int32, + token_idx: Int32, + head_kv_idx: Int32, + num_heads_kv: Int32, + ) -> Int32: + return (page_idx * num_heads_kv + head_kv_idx) * Int32(self.page_size) + token_idx + + @cute.jit + def _load_k_fp4_to_smem( + self, + sKFp4: cute.Tensor, + mKScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sK: cute.Tensor, + lane: Int32, + warp_idx_in_wg: Int32, + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + elems_per_block: cutlass.Constexpr[int] = 16 + bytes_per_block: cutlass.Constexpr[int] = 8 + scale_cols: cutlass.Constexpr[int] = self.head_dim // elems_per_block + token_in_block_base = kv_block_idx * Int32(self.n_block_size) + if const_expr(self.k_dtype == cutlass.Float8E4M3FN and self.fp8_pair_dequant): + elems_per_pair: cutlass.Constexpr[int] = 32 + bytes_per_pair: cutlass.Constexpr[int] = 16 + scale_pairs: cutlass.Constexpr[int] = scale_cols // 2 + r_words_pair = cute.make_rmem_tensor((bytes_per_pair // 4,), cutlass.Int32) + r_vals_pair = cute.make_rmem_tensor((elems_per_pair,), cutlass.Float8E4M3FN) + g2r_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_pair * cutlass.Uint8.width, + ) + r2s_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float8E4M3FN, + num_bits_per_copy=16 * cutlass.Float8E4M3FN.width, + ) + total_pair_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_pairs + pair_task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + pair_task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(pair_task, total_pair_tasks, pair_task_stride, unroll=1): + row = task_idx // Int32(scale_pairs) + pair_col = task_idx - row * Int32(scale_pairs) + scale_col = pair_col * Int32(2) + byte_col = pair_col * Int32(bytes_per_pair) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sKFp4.iterator.toint() + Int64(smem_offset), + mem_space=sKFp4.iterator.memspace, + assumed_align=bytes_per_pair, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_pair // 4)) + cute.copy(g2r_pair_atom, s_vec, r_words_pair) + scale_e4m3_lo = self._load_scale_e4m3_u8(mKScale, scale_row, scale_col) + scale_e4m3_hi = self._load_scale_e4m3_u8( + mKScale, + scale_row, + scale_col + Int32(1), + ) + self._dequant_fp4x32_to_fp8( + r_words_pair, + scale_e4m3_lo, + scale_e4m3_hi, + r_vals_pair, + ) + sK_vec = sK[(row, None), 0, pair_col, 0] + r_tiles = cute.logical_divide(r_vals_pair, cute.make_layout(16)) + s_tiles = cute.logical_divide(sK_vec, cute.make_layout(16)) + for v in cutlass.range_constexpr(elems_per_pair // 16): + cute.copy(r2s_pair_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + return + r_words = cute.make_rmem_tensor((bytes_per_block // 4,), cutlass.Int32) + r_vals = cute.make_rmem_tensor((elems_per_block,), self.k_dtype) + g2r_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_block * cutlass.Uint8.width, + ) + elems_per_store: cutlass.Constexpr[int] = ( + 16 if self.k_dtype == cutlass.Float8E4M3FN else 8 + ) + r2s_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.k_dtype, + num_bits_per_copy=elems_per_store * self.k_dtype.width, + ) + total_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_cols + task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(task, total_tasks, task_stride, unroll=1): + row = task_idx // Int32(scale_cols) + scale_col = task_idx - row * Int32(scale_cols) + byte_col = scale_col * Int32(bytes_per_block) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sKFp4.iterator.toint() + Int64(smem_offset), + mem_space=sKFp4.iterator.memspace, + assumed_align=bytes_per_block, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_block // 4)) + cute.copy(g2r_atom, s_vec, r_words) + if const_expr(self.k_dtype == cutlass.Float8E4M3FN): + scale_e4m3 = self._load_scale_e4m3_u8(mKScale, scale_row, scale_col) + else: + combined_bf16x2 = self._load_scale_bf16x2(mKScale, scale_row, scale_col) + if const_expr(self.has_k_global_scale): + global_bf16x2 = utils.cvt_f16x2_f32( + mKGlobalScale[0], + mKGlobalScale[0], + cutlass.BFloat16, + ) + combined_bf16x2 = utils.mul_bf16x2(combined_bf16x2, global_bf16x2) + if const_expr(self.k_dtype == cutlass.Float8E4M3FN): + sK_cols = sK[(row, None), 0, scale_col // Int32(2), 0] + sK_vec = cute.local_tile( + sK_cols, + (elems_per_block,), + (scale_col % Int32(2),), + ) + else: + sK_vec = sK[ + (row, None), + 0, + (scale_col % Int32(4), scale_col // Int32(4)), + 0, + ] + if const_expr(self.k_dtype == cutlass.Float8E4M3FN): + self._dequant_fp4x16_to_fp8(r_words, scale_e4m3, r_vals) + else: + self._dequant_fp4x16_to_bf16( + r_words, + combined_bf16x2, + r_vals, + ) + r_tiles = cute.logical_divide(r_vals, cute.make_layout(elems_per_store)) + s_tiles = cute.logical_divide(sK_vec, cute.make_layout(elems_per_store)) + for v in cutlass.range_constexpr(elems_per_block // elems_per_store): + cute.copy(r2s_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + + @cute.jit + def _load_v_fp4_to_smem( + self, + sVFp4: cute.Tensor, + mVScale: cute.Tensor, + mVGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sV: cute.Tensor, + lane: Int32, + warp_idx_in_wg: Int32, + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + elems_per_block: cutlass.Constexpr[int] = 16 + bytes_per_block: cutlass.Constexpr[int] = 8 + scale_cols: cutlass.Constexpr[int] = self.head_dim // elems_per_block + token_in_block_base = kv_block_idx * Int32(self.n_block_size) + if const_expr(self.v_dtype == cutlass.Float8E4M3FN and self.fp8_pair_dequant): + elems_per_pair: cutlass.Constexpr[int] = 32 + bytes_per_pair: cutlass.Constexpr[int] = 16 + scale_pairs: cutlass.Constexpr[int] = scale_cols // 2 + r_words_pair = cute.make_rmem_tensor((bytes_per_pair // 4,), cutlass.Int32) + r_vals_pair = cute.make_rmem_tensor((elems_per_pair,), cutlass.Float8E4M3FN) + g2r_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_pair * cutlass.Uint8.width, + ) + r2s_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float8E4M3FN, + num_bits_per_copy=16 * cutlass.Float8E4M3FN.width, + ) + total_pair_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_pairs + pair_task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + pair_task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(pair_task, total_pair_tasks, pair_task_stride, unroll=1): + row = task_idx // Int32(scale_pairs) + pair_col = task_idx - row * Int32(scale_pairs) + scale_col = pair_col * Int32(2) + byte_col = pair_col * Int32(bytes_per_pair) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sVFp4.iterator.toint() + Int64(smem_offset), + mem_space=sVFp4.iterator.memspace, + assumed_align=bytes_per_pair, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_pair // 4)) + cute.copy(g2r_pair_atom, s_vec, r_words_pair) + scale_e4m3_lo = self._load_scale_e4m3_u8(mVScale, scale_row, scale_col) + scale_e4m3_hi = self._load_scale_e4m3_u8( + mVScale, + scale_row, + scale_col + Int32(1), + ) + self._dequant_fp4x32_to_fp8( + r_words_pair, + scale_e4m3_lo, + scale_e4m3_hi, + r_vals_pair, + ) + sV_cols = sV[(None, row % Int32(32)), 0, row // Int32(32), 0] + sV_vec = cute.local_tile(sV_cols, (elems_per_pair,), (pair_col,)) + r_tiles = cute.logical_divide(r_vals_pair, cute.make_layout(16)) + s_tiles = cute.logical_divide(sV_vec, cute.make_layout(16)) + for v in cutlass.range_constexpr(elems_per_pair // 16): + cute.copy(r2s_pair_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + return + r_words = cute.make_rmem_tensor((bytes_per_block // 4,), cutlass.Int32) + r_vals = cute.make_rmem_tensor((elems_per_block,), self.v_dtype) + g2r_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=bytes_per_block * cutlass.Uint8.width, + ) + elems_per_store: cutlass.Constexpr[int] = ( + 16 if self.v_dtype == cutlass.Float8E4M3FN else 8 + ) + r2s_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.v_dtype, + num_bits_per_copy=elems_per_store * self.v_dtype.width, + ) + total_tasks: cutlass.Constexpr[int] = self.n_block_size * scale_cols + task_stride: cutlass.Constexpr[int] = num_dequant_warps * cute.arch.WARP_SIZE + task = warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + lane + for task_idx in cutlass.range(task, total_tasks, task_stride, unroll=1): + row = task_idx // Int32(scale_cols) + scale_col = task_idx - row * Int32(scale_cols) + byte_col = scale_col * Int32(bytes_per_block) + token = token_in_block_base + row + if const_expr(self.paged_kv): + page = sPagedKvIdx[0] + scale_row = self._paged_kv_scale_row( + page, + row, + head_kv_idx, + num_heads_kv, + ) + else: + token = k_batch_offset + token + scale_row = self._flat_kv_scale_row( + token, + head_kv_idx, + num_heads_kv, + ) + smem_offset = row * Int32(self.head_dim // 2) + byte_col + s_ptr = cute.make_ptr( + cutlass.Int32, + sVFp4.iterator.toint() + Int64(smem_offset), + mem_space=sVFp4.iterator.memspace, + assumed_align=bytes_per_block, + ) + s_vec = cute.make_tensor(s_ptr, cute.make_layout(bytes_per_block // 4)) + cute.copy(g2r_atom, s_vec, r_words) + if const_expr(self.v_dtype == cutlass.Float8E4M3FN): + scale_e4m3 = self._load_scale_e4m3_u8(mVScale, scale_row, scale_col) + self._dequant_fp4x16_to_fp8(r_words, scale_e4m3, r_vals) + else: + combined_bf16x2 = self._load_scale_bf16x2(mVScale, scale_row, scale_col) + if const_expr(self.has_v_global_scale): + global_bf16x2 = utils.cvt_f16x2_f32( + mVGlobalScale[0], + mVGlobalScale[0], + cutlass.BFloat16, + ) + combined_bf16x2 = utils.mul_bf16x2(combined_bf16x2, global_bf16x2) + self._dequant_fp4x16_to_bf16( + r_words, + combined_bf16x2, + r_vals, + ) + if const_expr(self.v_dtype == cutlass.Float8E4M3FN): + sV_cols = sV[(None, row % Int32(32)), 0, row // Int32(32), 0] + else: + sV_cols = sV[(None, row % Int32(16)), 0, row // Int32(16), 0] + sV_vec = cute.local_tile(sV_cols, (elems_per_block,), (scale_col,)) + r_tiles = cute.logical_divide(r_vals, cute.make_layout(elems_per_store)) + s_tiles = cute.logical_divide(sV_vec, cute.make_layout(elems_per_store)) + for v in cutlass.range_constexpr(elems_per_block // elems_per_store): + cute.copy(r2s_atom, r_tiles[None, v], s_tiles[None, v]) + cute.arch.fence_view_async_shared() + + @cute.jit + def _wg_load_kv( + self, + tma_atom_K_fp4, + tma_atom_V_fp4, + mK: cute.Tensor, + mV: cute.Tensor, + mKScale: cute.Tensor, + mVScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + mVGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sKFp4: cute.Tensor, + sVFp4: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + mbar_k_tma_ptr, + mbar_v_tma_ptr, + mbar_k_ptr, + mbar_v_ptr, + kv_load_barrier, + has_work: Int32, + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(self.kv_load_warp_base) + + if has_work: + if warp_idx_in_wg == Int32(0): + if const_expr(self.paged_kv): + mK_cur = mK[None, None, head_kv_idx, sPagedKvIdx[0]] + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim // 2), + (None, 0), + ) + src_idx = Int32(0) + else: + mK_cur = cute.domain_offset( + (k_batch_offset, 0), + mK[None, None, head_kv_idx], + ) + gK = cute.local_tile( + mK_cur, + (self.n_block_size, self.head_dim // 2), + (None, 0), + ) + src_idx = kv_block_idx + load_K_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K_fp4, + 0, + cute.make_layout(1), + gK, + sKFp4, + ) + load_K_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_k_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_tma_ptr) + + if warp_idx_in_wg == Int32(1): + if const_expr(self.paged_kv): + mV_cur = mV[None, None, head_kv_idx, sPagedKvIdx[0]] + gV = cute.local_tile( + mV_cur, + (self.head_dim // 2, self.n_block_size), + (0, None), + ) + src_idx = Int32(0) + else: + mV_cur = cute.domain_offset( + (0, k_batch_offset), + mV[None, None, head_kv_idx], + ) + gV = cute.local_tile( + mV_cur, + (self.head_dim // 2, self.n_block_size), + (0, None), + ) + src_idx = kv_block_idx + load_V_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V_fp4, + 0, + cute.make_layout(1), + gV, + sVFp4, + ) + load_V_fn( + src_idx=src_idx, + dst_idx=0, + tma_bar_ptr=mbar_v_tma_ptr, + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_tma_ptr) + + kv_load_barrier.arrive_and_wait() + + @cute.jit + def _wg_dequant_k_from_tma_staging( + self, + mKScale: cute.Tensor, + mKGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sKFp4: cute.Tensor, + sK: cute.Tensor, + mbar_k_tma_ptr, + mbar_k_ptr, + dequant_barrier, + has_work: Int32, + dequant_warp_base: cutlass.Constexpr[int], + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(dequant_warp_base) + lane = cute.arch.lane_idx() + + if has_work: + cute.arch.mbarrier_wait(mbar_k_tma_ptr, 0) + self._load_k_fp4_to_smem( + sKFp4, + mKScale, + mKGlobalScale, + sPagedKvIdx, + sK, + lane, + warp_idx_in_wg, + num_dequant_warps, + head_kv_idx, + kv_block_idx, + k_batch_offset, + num_heads_kv, + ) + dequant_barrier.arrive_and_wait() + if warp_idx_in_wg == Int32(0): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_k_ptr) + + @cute.jit + def _wg_dequant_v_from_tma_staging( + self, + mVScale: cute.Tensor, + mVGlobalScale: Optional[cute.Tensor], + sPagedKvIdx: cute.Tensor, + sVFp4: cute.Tensor, + sV: cute.Tensor, + mbar_v_tma_ptr, + mbar_v_ptr, + dequant_barrier, + has_work: Int32, + dequant_warp_base: cutlass.Constexpr[int], + num_dequant_warps: cutlass.Constexpr[int], + head_kv_idx: Int32, + kv_block_idx: Int32, + k_batch_offset: Int32, + num_heads_kv: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx - Int32(dequant_warp_base) + lane = cute.arch.lane_idx() + + if has_work: + cute.arch.mbarrier_wait(mbar_v_tma_ptr, 0) + self._load_v_fp4_to_smem( + sVFp4, + mVScale, + mVGlobalScale, + sPagedKvIdx, + sV, + lane, + warp_idx_in_wg, + num_dequant_warps, + head_kv_idx, + kv_block_idx, + k_batch_offset, + num_heads_kv, + ) + dequant_barrier.arrive_and_wait() + if warp_idx_in_wg == Int32(0): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(mbar_v_ptr) + + @cute.jit + def _wg_load_q_gather4( + self, + mQ_2d: cute.Tensor, + mQ_gather4_desc: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + load_wg_barrier, + qi_group_start: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + do_final_acquire: cutlass.Constexpr[bool], + ): + tidx = cute.arch.thread_idx()[0] + q_load_thread_base = Int32(self.q_load_warp_base * cute.arch.WARP_SIZE) + group_tidx = tidx - q_load_thread_base + producer_warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + gathers_per_warp: cutlass.Constexpr[int] = self.m_block_size // ( + self.num_q_load_warps * 4) + + if has_work: + for qi_group_rel in cutlass.range(num_q_groups, unroll=1): + qi_group = qi_group_start + qi_group_rel + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if producer_warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + group_tidx = tidx - q_load_thread_base + warp_idx_in_wg = cute.arch.make_warp_uniform( + group_tidx // Int32(cute.arch.WARP_SIZE) + ) + lane_idx = group_tidx % Int32(cute.arch.WARP_SIZE) + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + meta_iters: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + + self.num_q_load_warps * cute.arch.WARP_SIZE - 1) + // (self.num_q_load_warps * cute.arch.WARP_SIZE) + ) + for meta_iter in cutlass.range_constexpr(meta_iters): + tok_idx_g4 = ( + (Int32(meta_iter) * Int32(self.num_q_load_warps) + + warp_idx_in_wg) + * Int32(cute.arch.WARP_SIZE) + + lane_idx + ) + if tok_idx_g4 < Int32(self.q_tokens_per_group): + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx_g4 + if qi < count_raw: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx_g4] = Int32(0) + load_wg_barrier.arrive_and_wait() + + with cute.arch.elect_one(): + q_desc_ptr = mQ_gather4_desc.iterator + sQ_ptr = sQ.iterator + for gather_slot in cutlass.range_constexpr(gathers_per_warp): + gather_idx = ( + Int32(gather_slot) * Int32(self.num_q_load_warps) + + warp_idx_in_wg + ) + tok_base = gather_idx * Int32(self.tokens_per_gather4) + if const_expr(self.qheadperkv == 1): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + qi2 = qi0 + Int32(2) + qi3 = qi0 + Int32(3) + row0 = q_oob_m_idx + row1 = q_oob_m_idx + row2 = q_oob_m_idx + row3 = q_oob_m_idx + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row0 = (q_batch_offset + q_idx0) * num_heads_kv + head_kv_idx + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row1 = (q_batch_offset + q_idx1) * num_heads_kv + head_kv_idx + if qi2 < count_raw: + q_idx2 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(2)]) + row2 = (q_batch_offset + q_idx2) * num_heads_kv + head_kv_idx + if qi3 < count_raw: + q_idx3 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(3)]) + row3 = (q_batch_offset + q_idx3) * num_heads_kv + head_kv_idx + elif const_expr(self.qheadperkv == 2): + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + qi1 = qi0 + Int32(1) + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + row_base1 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + if qi1 < count_raw: + q_idx1 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base + Int32(1)]) + row_base1 = ( + (q_batch_offset + q_idx1) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base1 + row3 = row_base1 + Int32(1) + else: + qi0 = qi_group * Int32(self.q_tokens_per_group) + tok_base + row_base0 = q_oob_m_idx * Int32(self.qheadperkv) + if qi0 < count_raw: + q_idx0 = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_base]) + row_base0 = ( + (q_batch_offset + q_idx0) * num_heads_kv + + head_kv_idx + ) * Int32(self.qheadperkv) + row0 = row_base0 + row1 = row_base0 + Int32(1) + row2 = row_base0 + Int32(2) + row3 = row_base0 + Int32(3) + group_byte_off = gather_idx * Int32( + 4 * self.k_tile * (self.q_dtype.width // 8) + ) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + stage_byte_off = slot * Int32(self.q_stage_stride_bytes) + full_group_byte_off = gather_idx * Int32( + 4 * self.head_dim * (self.q_dtype.width // 8) + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + full_group_byte_off, + q_desc_ptr, + Int32(0), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + else: + for ks_c in cutlass.range_constexpr(self.k_stages): + stage_idx = slot * Int32(self.k_stages) + Int32(ks_c) + stage_byte_off = stage_idx * Int32(self.k_tile_stride_bytes) + if const_expr(ks_c + 1 < self.k_stages): + tma_gather4_prefetch( + q_desc_ptr, + Int32((ks_c + 1) * self.k_tile), + row0, + row1, + row2, + row3, + TMA_CACHE_EVICT_LAST, + ) + tma_gather4_cached( + sQ_ptr, + stage_byte_off + group_byte_off, + q_desc_ptr, + Int32(ks_c * self.k_tile), + row0, + row1, + row2, + row3, + mbar_ptr, + TMA_CACHE_EVICT_LAST, + ) + load_wg_barrier.arrive_and_wait() + + if const_expr(do_final_acquire) and producer_warp_idx_in_wg == Int32(0): + next_qi_group = qi_group_start + num_q_groups + next_slot = next_qi_group % Int32(self.q_stage) + next_phase = ( + (next_qi_group // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_load_q_tma( + self, + tma_atom_Q, + mQ_2d: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + sQLoadMIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + sQ_load: cute.Tensor, + pipeline_q, + load_wg_barrier, + qi_group_start: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + head_kv_idx: Int32, + row_start: Int32, + q_batch_offset: Int32, + num_heads_kv: Int32, + do_final_acquire: cutlass.Constexpr[bool], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + lane_idx = cute.arch.lane_idx() + warp_idx_in_wg = warp_idx - Int32(self.q_load_warp_base) + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + gQ_full = cute.local_tile( + mQ_2d, (self.qheadperkv, self.head_dim), (None, 0)) + load_Q_fn_full, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_full, sQ_load) + load_Q_fn_k0, load_Q_fn_k1 = None, None + else: + gQ_k0 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 0)) + gQ_k1 = cute.local_tile(mQ_2d, (self.qheadperkv, self.k_tile), (None, 1)) + load_Q_fn_k0, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k0, sQ_load) + load_Q_fn_k1, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ_k1, sQ_load) + load_Q_fn_full = None + q_oob_m_idx = mQ_2d.shape[0] // Int32(self.qheadperkv) + tokens_per_warp: cutlass.Constexpr[int] = ( + (self.q_tokens_per_group + self.num_q_load_warps - 1) + // self.num_q_load_warps + ) + + if has_work: + for qi_group_rel in cutlass.range(num_q_groups, unroll=1): + qi_group = qi_group_start + qi_group_rel + slot = qi_group % Int32(self.q_stage) + phase = (qi_group // Int32(self.q_stage)) & Int32(1) + producer_phase = phase ^ Int32(1) + if warp_idx_in_wg == Int32(0): + pipeline_q.producer_acquire_w_index_phase( + slot, producer_phase) + load_wg_barrier.arrive_and_wait() + + mbar_ptr = pipeline_q.sync_object_full.get_barrier(slot) + q_load_subtiles_per_token = ( + 1 if const_expr(self.q_dtype == cutlass.Float8E4M3FN) + else self.k_stages + ) + sub_stage_base = slot * Int32( + self.q_tokens_per_group * q_load_subtiles_per_token) + load_meta_slot = slot * Int32(self.q_tokens_per_group) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + if warp_idx_in_wg == Int32(0) and lane_idx < Int32(self.q_tokens_per_group): + tok_idx = lane_idx + qi = qi_group * Int32(self.q_tokens_per_group) + tok_idx + if qi < count_raw: + qsplit = self._load_qsplit_idx( + mK2qQSplitIndices, head_kv_idx, row_start, qi + ) + q_idx = self._decode_q_idx_from_qsplit(qsplit) + q_abs = q_batch_offset + q_idx + sQIdxMeta[qidx_meta_slot + tok_idx] = qsplit + sQLoadMIdx[load_meta_slot + tok_idx] = ( + q_abs * num_heads_kv + head_kv_idx + ) + else: + sQIdxMeta[qidx_meta_slot + tok_idx] = Int32(0) + sQLoadMIdx[load_meta_slot + tok_idx] = q_oob_m_idx + load_wg_barrier.arrive_and_wait() + + for qi_slot in cutlass.range_constexpr(tokens_per_warp): + tok_idx = ( + warp_idx_in_wg * Int32(tokens_per_warp) + + Int32(qi_slot) + ) + if tok_idx < Int32(self.q_tokens_per_group): + m_tile_idx = sQLoadMIdx[load_meta_slot + tok_idx] + if const_expr(self.q_dtype == cutlass.Float8E4M3FN): + load_Q_fn_full( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + else: + load_Q_fn_k0( + src_idx=m_tile_idx, + dst_idx=sub_stage_base + tok_idx, + tma_bar_ptr=mbar_ptr, + ) + load_Q_fn_k1( + src_idx=m_tile_idx, + dst_idx=( + sub_stage_base + + Int32(self.q_tokens_per_group) + + tok_idx + ), + tma_bar_ptr=mbar_ptr, + ) + + if const_expr(do_final_acquire) and warp_idx_in_wg == Int32(0): + next_qi_group = qi_group_start + num_q_groups + next_slot = next_qi_group % Int32(self.q_stage) + next_phase = ( + (next_qi_group // Int32(self.q_stage)) & Int32(1) + ) ^ Int32(1) + pipeline_q.producer_acquire_w_index_phase(next_slot, next_phase) + + @cute.jit + def _wg_mma_issue( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + thr0_qk: cute.core.ThrMma, + thr0_pv: cute.core.ThrMma, + tStS: cute.Tensor, + tOrP: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sQ: cute.Tensor, + pipeline_q, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + mbar_k_ptr, + mbar_v_ptr, + num_q_groups: Int32, + has_work: Int32, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + is_mma_warp = warp_idx == Int32(self.mma_warp_id) + + if is_mma_warp: + if has_work: + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tSrQ0 = tSrQ[None, None, None, 0] + tSrK0 = tSrK[None, None, None, 0] + tOrV = tiled_mma_pv.make_fragment_B(sV) + tOrV0 = tOrV[None, None, None, 0] + sV0 = sV[None, None, None, 0] + pv_mma_op = tiled_mma_pv.op + qk_mma_op = tiled_mma_qk.op + q_smem_base = sm100_desc.smem_desc_base_from_tensor( + sQ, sm100_desc.Major.K) + k_smem_base = sm100_desc.smem_desc_base_from_tensor( + sK, sm100_desc.Major.K) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[None, None, None, 0].iterator) + q_smem_start = sm100_desc.make_smem_desc_start_addr( + sQ[None, None, None, self.q_stage - 1].iterator) + sm100_helpers.declare_ptx_smem_desc( + q_smem_start, + q_smem_base, + tSrQ0.layout, + var_name_prefix="lean_q_desc", + ) + sm100_helpers.declare_ptx_idesc( + qk_mma_op, var_name="lean_qk_idesc") + sQ_stage_stride = ( + sQ.layout.stride[-1] * sQ.element_type.width // 8) >> 4 + if const_expr(self.q_stage == 1): + sQ_stage_stride = 0 + q_wrap_offset = -(self.q_stage - 1) * sQ_stage_stride + q_advance_offset = sQ_stage_stride + gemm_qk_s0_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_qk_s0_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_qk_s1_wrap = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_wrap_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_qk_s1_advance = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + Int32(self.tmem_stage_stride + self.tmem_s_offset), + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0.layout, + smem_var_name_prefix="lean_q_desc", + idesc_var_name="lean_qk_idesc", + smem_offset=q_advance_offset, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_pv_0 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset), + tOrP[None, None, None, 0], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + gemm_pv_1 = partial( + sm100_helpers.gemm_ptx_partial, + pv_mma_op, + Int32(self.tmem_o_offset + self.tmem_o_stage_stride), + tOrP[None, None, None, 1], + sA=None, + split_arrive=( + self.split_P_arrive if self.split_P_arrive > 0 else None), + tA_addr=Int32(self.tmem_stage_stride + self.tmem_p_offset), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + + cute.arch.mbarrier_wait(mbar_k_ptr, 0) + # Issue order: + # Q0K, Q1K, P0V, Q2K, P1V, Q3K, ... + # This reuses each slot as soon as its previous PV drains, + # instead of batching both PVs after both QKs of a pair. + # The schedule is still 2-slot safe: + # - QK(qi) consumes slot qi&1 + # - PV(qi-2) frees the same slot before QK(qi) reuses it + # - phases still toggle every 2 groups per slot + + # Prologue: issue up to the first two QK tiles. Q slots come + # from the q_stage ring; S slots remain a 2-slot ring. + pipeline_q.consumer_wait_w_index_phase(Int32(0), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(0), Int32(1)) + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(0)) + pipeline_q.consumer_release_w_index(Int32(0)) + + if num_q_groups > Int32(1): + pipeline_q.consumer_wait_w_index_phase(Int32(1), Int32(0)) + pipeline_s.producer_acquire_w_index_phase(Int32(1), Int32(1)) + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(Int32(1)) + pipeline_q.consumer_release_w_index(Int32(1)) + + cute.arch.mbarrier_wait(mbar_v_ptr, 0) + + # Steady-state: for qi >= 2, reuse the S/P slot as a BLK128- + # style handoff. The MMA warp waits for softmax to release + # the slot after P is visible, issues PV, then immediately + # reuses the same acquired slot for the next QK. + for qi in cutlass.range(Int32(2), num_q_groups, unroll=1): + pv_qi = qi - Int32(2) + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + q_slot = qi % Int32(self.q_stage) + q_phase = (qi // Int32(self.q_stage)) & Int32(1) + s_slot = qi & Int32(1) + s_phase = (qi // Int32(2)) & Int32(1) + pipeline_q.consumer_wait_w_index_phase(q_slot, q_phase) + pipeline_s.producer_acquire_w_index_phase( + s_slot, s_phase ^ Int32(1)) + if s_slot == Int32(0): + if q_slot == Int32(0): + gemm_qk_s0_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s0_advance(smem_desc_start_b=k_smem_start) + else: + if q_slot == Int32(0): + gemm_qk_s1_wrap(smem_desc_start_b=k_smem_start) + else: + gemm_qk_s1_advance(smem_desc_start_b=k_smem_start) + pipeline_s.producer_commit_w_index(s_slot) + pipeline_q.consumer_release_w_index(q_slot) + + # Drain the remaining one or two PV tiles. + drain_begin = Int32(0) if num_q_groups == Int32(1) else num_q_groups - Int32(2) + for pv_qi in cutlass.range(drain_begin, num_q_groups, unroll=1): + pv_slot = pv_qi & Int32(1) + pv_phase = (pv_qi // Int32(2)) & Int32(1) + pipeline_p.consumer_wait_w_index_phase(pv_slot, pv_phase) + pipeline_o.producer_acquire_w_index_phase( + pv_slot, pv_phase ^ Int32(1)) + if pv_slot == Int32(0): + gemm_pv_0( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + else: + gemm_pv_1( + tCrB=tOrV0, + sB=sV0, + mbar_ptr=( + pipeline_p_lastsplit.sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None), + mbar_phase=(pv_phase if self.split_P_arrive > 0 else None), + zero_init=True, + ) + pipeline_o.producer_commit_w_index(pv_slot) + if cutlass.const_expr(self.split_P_arrive > 0): + pipeline_p_lastsplit.consumer_release_w_index(pv_slot) + pipeline_p.consumer_release_w_index(pv_slot) + + @cute.jit + def _softmax_step( + self, + slot: cutlass.Constexpr[int], + s_consumer_phase: Int32, + p_producer_phase: Int32, + sm_stats_producer_phase: Int32, + softmax: SoftmaxSm100, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx: Int32, + thr_tmem_load, + thr_tmem_store, + tStS_t2r: cute.Tensor, + tStP_r2t: cute.Tensor, + tScS_t2r: cute.Tensor, + tScP_shape, + sQIdxMeta: cute.Tensor, + qidx_meta_slot: Int32, + group_tidx: Int32, + masked_tok_count: Int32, + kv_block_col_start: Int32, + seq_len_q: Int32, + causal_q_offset: Int32, + kv_valid_cols: Int32, + lse_temperature_scale: Float32, + mKGlobalScale: Optional[cute.Tensor], + return_temperature_lse: cutlass.Constexpr[bool], + apply_causal_mask: cutlass.Constexpr[bool] = False, + signal_stats_barrier: cutlass.Constexpr[bool] = True, + ): + slot_rt = Int32(slot) + + pipeline_s.consumer_wait_w_index_phase(slot_rt, s_consumer_phase) + + tSrS_t2r = cute.make_rmem_tensor( + tScS_t2r.shape, + self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_k_global_scale + ): + k_global = mKGlobalScale[0] + for i in cutlass.range_constexpr(0, cute.size(tSrS_t2r.shape), 2): + tSrS_t2r[i], tSrS_t2r[i + 1] = cute.arch.mul_packed_f32x2( + (tSrS_t2r[i], tSrS_t2r[i + 1]), + (k_global, k_global), + ) + + seqlen_info = SeqlenInfoQK( + Int32(0), + Int32(0), + Int32(0), + Int32(0), + seq_len_q, + seq_len_q + causal_q_offset, + False, + False, + False, + False, + ) + mask = AttentionMask( + self.m_block_size, + self.n_block_size, + seqlen_info, + ) + if const_expr(self.causal and apply_causal_mask): + need_causal_mask = masked_tok_count > Int32(0) + if need_causal_mask: + tok_idx = group_tidx // Int32(self.qheadperkv) + q_idx = self._decode_q_idx_from_qsplit( + sQIdxMeta[qidx_meta_slot + tok_idx]) + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=True, + row_idx=q_idx, + kv_valid_cols=kv_valid_cols, + kv_block_col_start=kv_block_col_start, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + else: + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=Int32(0), + n_block=Int32(0), + mask_seqlen=True, + mask_causal=False, + kv_valid_cols=kv_valid_cols, + ) + + # Each sparse CTA computes exactly one KV block for the current Q group, + # so full-tile softmax is always the first and only online-softmax step. + row_max, _ = softmax.update_row_max(tSrS_t2r.load(), True) + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + if const_expr(return_temperature_lse): + lse_temperature_row_sum = softmax.compute_scaled_exp2_row_sum( + tSrS_t2r, + lse_temperature_scale, + ) + + if cutlass.const_expr(self.split_P_arrive > 0): + # This full barrier is the late-P handoff consumed inside + # gemm_ptx_partial after its early PV k-slices are issued. + pipeline_p_lastsplit.producer_acquire_w_index_phase( + slot_rt, p_producer_phase) + pipeline_p.producer_acquire_w_index_phase(slot_rt, p_producer_phase) + tSrP_r2t_f32 = cute.make_rmem_tensor( + thr_tmem_store.partition_S( + cute.make_identity_tensor(tScP_shape)).shape, + Float32) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r.layout) + softmax.apply_exp2_convert( + tSrS_t2r, + tSrP_r2t, + ex2_emu_freq=self.ex2_emu_freq, + ex2_emu_start_frg=self.ex2_emu_start_frg, + ) + + for k in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): + cute.copy( + thr_tmem_store, + tSrP_r2t_f32[None, None, k], + tStP_r2t[None, None, k]) + if cutlass.const_expr(self.split_P_arrive > 0): + split_idx = ( + cute.size(tStP_r2t.shape[2]) + * self.split_P_arrive // self.n_block_size) + if cutlass.const_expr(k + 1 == split_idx): + cute.arch.fence_view_async_tmem_store() + pipeline_p.producer_commit_w_index(slot_rt) + cute.arch.fence_view_async_tmem_store() + if cutlass.const_expr(self.split_P_arrive == 0): + pipeline_p.producer_commit_w_index(slot_rt) + else: + pipeline_p_lastsplit.producer_commit_w_index(slot_rt) + pipeline_sm_stats.producer_acquire_w_index_phase( + slot_rt, sm_stats_producer_phase) + softmax.update_row_sum(tSrS_t2r.load(), Float32(0.0), True) + del tSrS_t2r + sScale_slot = cute.make_tensor( + sScale.iterator + slot_rt * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_slot[group_tidx] = softmax.row_sum[0] + sScale_slot[group_tidx + Int32(self.m_block_size)] = ( + softmax.row_max[0]) + if const_expr(return_temperature_lse): + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot_rt * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sScale_temperature_slot[group_tidx] = lse_temperature_row_sum + cute.arch.fence_view_async_shared() + + if const_expr(signal_stats_barrier): + sm_stats_barrier.arrive_w_index(index=stats_barrier_idx) + pipeline_s.consumer_release_w_index(slot_rt) + + @cute.jit + def _wg_softmax( + self, + stage: cutlass.Constexpr[int], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tStS: cute.Tensor, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + lse_temperature_scale: Float32, + mKGlobalScale: Optional[cute.Tensor], + mVGlobalScale: Optional[cute.Tensor], + kv_block_idx: Int32, + kv_valid_cols: Int32, + diag_q_count: Int32, + num_q_groups: Int32, + count_raw: Int32, + has_work: Int32, + causal_q_offset: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + mQ_2d: cute.Tensor, + ): + tidx = cute.arch.thread_idx()[0] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + warp_idx_in_wg = warp_idx % Int32(self.warps_per_group) + group_tidx = ( + warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + stats_barrier_idx = ( + Int32(stage) * Int32(self.warps_per_group) + warp_idx_in_wg) + + thr0_qk = tiled_mma_qk.get_slice(0) + tScS = thr0_qk.partition_C( + cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScS = tScS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_tiler_qk[0] // thr0_qk.thr_id.shape, + self.mma_tiler_qk[1]) + tilePlikeFP32 = ( + self.mma_tiler_qk[1] // Float32.width * self.q_dtype.width) + tScP_shape = (cta_qk_tiler[0], tilePlikeFP32) + tSAcc = tStS[(None, None), 0, 0, stage] + + softmax = SoftmaxSm100.create(softmax_scale_log2) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype) + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tSAcc).get_slice(group_tidx) + tStS_t2r = thr_tmem_load.partition_S(tSAcc) + tScS_t2r = thr_tmem_load.partition_D(tScS) + tStP_layout = cute.composition( + tSAcc.layout, + cute.make_layout((self.m_block_size, tilePlikeFP32))) + tStP = cute.make_tensor( + tSAcc.iterator + self.tmem_s_to_p_offset, tStP_layout) + # P-store Repetition is dtype-aware: each PV MMA K-segment is + # ``32 / (q_dtype.width / 8)`` fp8/bf16 columns wide, which equals + # ``32 * Float32.width / q_dtype.width`` packed fp32 TMEM columns + # ``// (q_dtype.width / 8)``. Concretely, R=16 packs two bf16 PV K + # segments per chunk (shape[2]=4 ⇒ 3/4 publish boundary aligns), + # while fp8 (PV K=32 fp8 ⇒ 8 fp32 cols) needs R=8 so + # shape[2]=4 and split_idx=3 publishes exactly 24 fp32 cols + # (= 96 fp8 cols = 3 PV K segments) at the early-arrive edge. + store_rep = const_expr( + 8 if self.q_dtype == cutlass.Float8E4M3FN else 16 + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(store_rep)), + Float32) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom, tStP).get_slice(group_tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + + total_q = mQ_2d.shape[0] // head_q + thr0_pv = tiled_mma_pv.get_slice(0) + pv_acc_shape = thr0_pv.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO_base = thr0_pv.make_fragment_C(pv_acc_shape) + corr_tile_size = 64 + tOcO = thr0_pv.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + tOcO_i = cute.logical_divide( + tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + o_tmem_copy_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.pv_acc_dtype) + + if has_work: + kv_block_col_start = Int32(0) + if const_expr(self.causal): + kv_block_col_start = kv_block_idx * Int32(self.n_block_size) + + num_stage_groups = ( + num_q_groups + Int32(1 - stage)) // Int32(2) + for qi_iter in cutlass.range(num_stage_groups, unroll=1): + qi_group = qi_iter * Int32(2) + Int32(stage) + phase = qi_iter & Int32(1) + producer_phase = phase ^ Int32(1) + qidx_meta_slot = ( + (qi_group & Int32(self.qidx_meta_stages - 1)) + * Int32(self.q_tokens_per_group) + ) + + softmax.reset() + + if const_expr(self.causal): + qi_group_start = qi_group * Int32(self.q_tokens_per_group) + masked_tok_count = cutlass.max( + Int32(0), + cutlass.min( + Int32(self.q_tokens_per_group), + diag_q_count - qi_group_start, + ), + ) + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + masked_tok_count, + kv_block_col_start, + seq_len_q, + causal_q_offset, + kv_valid_cols, + lse_temperature_scale, + mKGlobalScale, + const_expr(mLSE_temperature_partial is not None), + True, + False, + ) + else: + self._softmax_step( + stage, + phase, + producer_phase, + producer_phase, + softmax, + sScale, + sScaleTemperature, + pipeline_s, + pipeline_p, + pipeline_p_lastsplit, + pipeline_sm_stats, + sm_stats_barrier, + stats_barrier_idx, + thr_tmem_load, + thr_tmem_store, + tStS_t2r, + tStP_r2t, + tScS_t2r, + tScP_shape, + sQIdxMeta, + qidx_meta_slot, + group_tidx, + Int32(0), + kv_block_col_start, + seq_len_q, + Int32(0), + kv_valid_cols, + lse_temperature_scale, + mKGlobalScale, + const_expr(mLSE_temperature_partial is not None), + False, + False, + ) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(stage)) + self._epilogue_step( + qi_group, + group_tidx, + warp_idx_in_wg, + tOtO_base, + tOcO_i, + o_tmem_copy_atom, + sScale, + sScaleTemperature, + sSplitIdx, + sQIdx, + sQIdxMeta, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial, + mLSE_partial, + mLSE_temperature_partial, + softmax_scale_log2, + lse_temperature_scale_log2, + mVGlobalScale, + count_raw, + batch_idx, + head_kv_idx, + seq_len_q, + head_q, + num_heads_kv, + q_batch_offset, + total_q, + False, + stage, + ) + + @cute.jit + def _store_o_partial_vec4( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + ): + stg_128_cs(ptr, v0, v1, v2, v3) + + @cute.jit + def _store_o_partial_vec8_half( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + ): + if cutlass.const_expr(self.o_dtype is cutlass.BFloat16): + stg_128_bf16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + else: + stg_128_f16_cs(ptr, v0, v1, v2, v3, v4, v5, v6, v7) + + @cute.jit + def _store_o_partial_vec16_fp8( + self, + ptr: cute.Pointer, + v0: Float32, + v1: Float32, + v2: Float32, + v3: Float32, + v4: Float32, + v5: Float32, + v6: Float32, + v7: Float32, + v8: Float32, + v9: Float32, + v10: Float32, + v11: Float32, + v12: Float32, + v13: Float32, + v14: Float32, + v15: Float32, + ): + stg_128_fp8_e4m3_cs( + ptr, + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7, + v8, + v9, + v10, + v11, + v12, + v13, + v14, + v15, + ) + + @cute.jit + def _epilogue_step( + self, + qi_group: Int32, + group_tidx: Int32, + warp_idx_in_wg: Int32, + tOtO_base: cute.Tensor, + tOcO_i: cute.Tensor, + o_tmem_copy_atom, + sScale: cute.Tensor, + sScaleTemperature: cute.Tensor, + sSplitIdx: cute.Tensor, + sQIdx: cute.Tensor, + sQIdxMeta: cute.Tensor, + pipeline_o, + pipeline_sm_stats, + sm_stats_barrier, + epilogue_barrier, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mLSE_temperature_partial: Optional[cute.Tensor], + softmax_scale_log2: Float32, + lse_temperature_scale_log2: Float32, + mVGlobalScale: Optional[cute.Tensor], + count_raw: Int32, + batch_idx: Int32, + head_kv_idx: Int32, + seq_len_q: Int32, + head_q: Int32, + num_heads_kv: Int32, + q_batch_offset: Int32, + total_q: Int32, + use_stats_barrier: cutlass.Constexpr[bool], + softmax_stage: cutlass.Constexpr[int], + ): + slot = qi_group & Int32(1) + phase = (qi_group // Int32(2)) & Int32(1) + stage_base = slot * Int32(self.tmem_o_stage_stride) + corr_tile_size = 64 + sScale_slot = cute.make_tensor( + sScale.iterator + slot * Int32(self.m_block_size * 2), + cute.make_layout(self.m_block_size * 2), + ) + sScale_temperature_slot = cute.make_tensor( + sScaleTemperature.iterator + slot * Int32(self.m_block_size), + cute.make_layout(self.m_block_size), + ) + sSplitIdx_slot = cute.make_tensor( + sSplitIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + sQIdx_slot = cute.make_tensor( + sQIdx.iterator + slot * Int32(self.q_tokens_per_group), + cute.make_layout((self.q_tokens_per_group,)), + ) + qidx_meta_slot = (qi_group + & Int32(self.qidx_meta_stages - 1)) * Int32( + self.q_tokens_per_group) + + pipeline_o.consumer_wait_w_index_phase(slot, phase) + if const_expr(use_stats_barrier): + sm_stats_barrier.arrive_and_wait_w_index( + index=slot * Int32(self.warps_per_group) + warp_idx_in_wg) + + if group_tidx < Int32(self.q_tokens_per_group): + tok = group_tidx + qi = qi_group * Int32(self.q_tokens_per_group) + tok + if qi < count_raw: + qsplit = sQIdxMeta[qidx_meta_slot + tok] + q_idx = self._decode_q_idx_from_qsplit(qsplit) + sQIdx_slot[tok] = q_idx + sSplitIdx_slot[tok] = self._decode_split_idx_from_qsplit(qsplit) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + tOtO = cute.make_tensor( + tOtO_base.iterator + stage_base + Int32(self.tmem_o_offset), + tOtO_base.layout) + for col_pass_idx in cutlass.range(Int32(2), unroll=1): + col_pass = col_pass_idx * Int32(corr_tile_size) + tOtO_pass_ptr = cute.make_ptr( + self.pv_acc_dtype, + tOtO.iterator.toint() + col_pass, + cute.AddressSpace.tmem, + assumed_align=8, + ) + tOtO_pass = cute.make_tensor(tOtO_pass_ptr, tOtO.layout) + tOtO_pass_i = cute.logical_divide( + tOtO_pass, + cute.make_layout((self.m_block_size, corr_tile_size))) + tiled_tmem_load_pass = tcgen05.make_tmem_copy( + o_tmem_copy_atom, tOtO_pass_i[(None, None), 0]) + thr_tmem_load_pass = tiled_tmem_load_pass.get_slice(group_tidx) + tOtO_t2r_pass = thr_tmem_load_pass.partition_S( + tOtO_pass_i[(None, None), None]) + tOcO_t2r_pass = thr_tmem_load_pass.partition_D( + tOcO_i[(None, None), None]) + + tOtO_t2r_i = tOtO_t2r_pass[None, None, None, 0] + tOcO_t2r_i = tOcO_t2r_pass[None, None, None, 0] + tOrO_frg = cute.make_rmem_tensor_like( + tOcO_t2r_i, self.pv_acc_dtype) + cute.copy(tiled_tmem_load_pass, tOtO_t2r_i, tOrO_frg) + + tOrO_mn = make_16x256b_tensor_mn_view(tOrO_frg) + tOrO_mn = cute.make_tensor( + tOrO_mn.iterator, + cute.select(tOrO_mn.layout, mode=[0, 1])) + tOcO_mn = make_16x256b_tensor_mn_view(tOcO_t2r_i) + tOcO_mn = cute.make_tensor( + tOcO_mn.iterator, + cute.select(tOcO_mn.layout, mode=[0, 1])) + num_rows = cute.size(tOrO_mn, mode=[0]) + num_cols = cute.size(tOrO_mn, mode=[1]) + + for r in cutlass.range_constexpr(num_rows): + if const_expr(self.o_dtype is Float32): + for c4 in cutlass.range_constexpr(num_cols // 4): + c_base = Int32(c4) * Int32(4) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_v_global_scale + ): + row_scale *= mVGlobalScale[0] + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + fake_col = real_col_to_stg128_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec4( + ptr, + o0, + o1, + o2, + o3, + ) + elif const_expr(self.o_dtype in [cutlass.BFloat16, cutlass.Float16]): + assert num_cols % 8 == 0, ( + "half O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 8" + ) + for c8 in cutlass.range_constexpr(num_cols // 8): + c_base = Int32(c8) * Int32(8) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = (qi_group + * Int32(self.q_tokens_per_group) + + tok) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0)) + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_v_global_scale + ): + row_scale *= mVGlobalScale[0] + row_base_ptr = ( + flat_row * Int64(self.head_dim)) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2( + (o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2( + (o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2( + (o6, o7), scale_pair) + fake_col = real_col_to_stg128_half_fake_col(col) + ptr = (mO_partial.iterator + row_base_ptr + + Int64(fake_col)) + self._store_o_partial_vec8_half( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + ) + else: + assert num_cols % 16 == 0, ( + "fp8 O_partial STG.128 requires the epilogue " + "TMEM fragment column count to be a multiple of 16" + ) + for c16 in cutlass.range_constexpr(num_cols // 16): + c_base = Int32(c16) * Int32(16) + row_col = tOcO_mn[r, c_base] + row = row_col[0] + col = row_col[1] + col_pass + if row < Int32(self.m_block_size): + tok = row // Int32(self.qheadperkv) + row_in_tok = row - tok * Int32(self.qheadperkv) + qi = ( + qi_group * Int32(self.q_tokens_per_group) + + tok + ) + if qi < count_raw: + q_idx = sQIdx_slot[tok] + split = sSplitIdx_slot[tok] + q_abs = q_batch_offset + q_idx + flat_row = ( + Int64(split) + * Int64(total_q) + * Int64(head_q) + + Int64(q_abs) * Int64(head_q) + + Int64(head_kv_idx) + * Int64(self.qheadperkv) + + Int64(row_in_tok) + ) + row_sum_val = sScale_slot[row] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val + ) + row_scale = cute.arch.rcp_approx( + row_sum_val + if not is_zero_or_nan + else Float32(1.0) + ) + if const_expr( + self.q_dtype == cutlass.Float8E4M3FN + and self.has_v_global_scale + ): + row_scale *= mVGlobalScale[0] + row_base_ptr = flat_row * Int64(self.head_dim) + o0 = tOrO_mn[r, c_base] + o1 = tOrO_mn[r, c_base + Int32(1)] + o2 = tOrO_mn[r, c_base + Int32(2)] + o3 = tOrO_mn[r, c_base + Int32(3)] + o4 = tOrO_mn[r, c_base + Int32(4)] + o5 = tOrO_mn[r, c_base + Int32(5)] + o6 = tOrO_mn[r, c_base + Int32(6)] + o7 = tOrO_mn[r, c_base + Int32(7)] + o8 = tOrO_mn[r, c_base + Int32(8)] + o9 = tOrO_mn[r, c_base + Int32(9)] + o10 = tOrO_mn[r, c_base + Int32(10)] + o11 = tOrO_mn[r, c_base + Int32(11)] + o12 = tOrO_mn[r, c_base + Int32(12)] + o13 = tOrO_mn[r, c_base + Int32(13)] + o14 = tOrO_mn[r, c_base + Int32(14)] + o15 = tOrO_mn[r, c_base + Int32(15)] + scale_pair = (row_scale, row_scale) + o0, o1 = cute.arch.mul_packed_f32x2((o0, o1), scale_pair) + o2, o3 = cute.arch.mul_packed_f32x2((o2, o3), scale_pair) + o4, o5 = cute.arch.mul_packed_f32x2((o4, o5), scale_pair) + o6, o7 = cute.arch.mul_packed_f32x2((o6, o7), scale_pair) + o8, o9 = cute.arch.mul_packed_f32x2((o8, o9), scale_pair) + o10, o11 = cute.arch.mul_packed_f32x2((o10, o11), scale_pair) + o12, o13 = cute.arch.mul_packed_f32x2((o12, o13), scale_pair) + o14, o15 = cute.arch.mul_packed_f32x2((o14, o15), scale_pair) + fake_col = real_col_to_stg128_fp8_fake_col(col) + ptr = ( + mO_partial.iterator + + row_base_ptr + + Int64(fake_col) + ) + self._store_o_partial_vec16_fp8( + ptr, + o0, + o1, + o2, + o3, + o4, + o5, + o6, + o7, + o8, + o9, + o10, + o11, + o12, + o13, + o14, + o15, + ) + cute.arch.fence_view_async_tmem_load() + + tok_local = Int32(group_tidx) // Int32(self.qheadperkv) + h_local = Int32(group_tidx) % Int32(self.qheadperkv) + qi_lse = qi_group * Int32(self.q_tokens_per_group) + tok_local + if qi_lse < count_raw: + row_sum_val = sScale_slot[group_tidx] + row_max_val = sScale_slot[group_tidx + Int32(self.m_block_size)] + is_zero_or_nan = ( + row_sum_val == Float32(0.0) + or row_sum_val != row_sum_val) + LN2 = Float32(math.log(2.0)) + lse_cur = ( + (row_max_val * softmax_scale_log2 + + cute.math.log2(row_sum_val, fastmath=True)) * LN2 + if not is_zero_or_nan else -Float32.inf) + q_idx_lse = sQIdx_slot[tok_local] + h_abs = head_kv_idx * Int32(self.qheadperkv) + h_local + split_lse = sSplitIdx_slot[tok_local] + q_abs_lse = q_batch_offset + q_idx_lse + mLSE_partial[split_lse, q_abs_lse, h_abs] = lse_cur + if const_expr(mLSE_temperature_partial is not None): + row_sum_temperature_val = sScale_temperature_slot[group_tidx] + is_temperature_zero_or_nan = ( + row_sum_temperature_val == Float32(0.0) + or row_sum_temperature_val != row_sum_temperature_val) + lse_temperature_cur = ( + (row_max_val * lse_temperature_scale_log2 + + cute.math.log2(row_sum_temperature_val, fastmath=True)) * LN2 + if not is_temperature_zero_or_nan else -Float32.inf) + mLSE_temperature_partial[split_lse, q_abs_lse, h_abs] = ( + lse_temperature_cur) + epilogue_barrier.arrive_and_wait_w_index(index=Int32(softmax_stage)) + + pipeline_sm_stats.consumer_release_w_index(slot) + pipeline_o.consumer_release_w_index(slot) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd/combine.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd/combine.py new file mode 100644 index 0000000000000000000000000000000000000000..a3894130432f6483291fe23c064efa7369f6d509 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd/combine.py @@ -0,0 +1,1498 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Sparse forward combine kernel and public launcher. + +This keeps the local fake-layout -> real-layout epilogue needed by the lean +sparse forward path. +""" + +# Modified Step 7: O_out write with SMEM fake->real column permutation. +# O_partial dim is in STG.128 fake layout; O_out dim is real layout. +import math +from typing import Type, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import torch +from cutlass.cute.nvgpu import cpasync +from cutlass import Float32, Int32, Int64, Boolean, const_expr + +from ....src.common import utils +from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map +from ....src.common.seqlen_info import SeqlenInfo +from cutlass.cute import FastDivmodDivisor + +from ....src.common.pack_gqa import PackGQAComb +from ....src.common.tma_utils import ( + stg128_fake_col_to_real_col, + stg128_fp8_fake_col_to_real_col, + stg128_half_fake_col_to_real_col, +) + + +class SparseAttentionForwardCombine: + def __init__( + self, + dtype: Type[cutlass.Numeric], + dtype_partial: Type[cutlass.Numeric], + head_dim: int, + tile_m: int = 8, + k_block_size: int = 64, + topk: int = 16, + num_threads: int = 256, + stages: int = 4, + use_pdl: bool = False, + min_blocks_per_mp: int = 0, + ): + """ + Forward combine kernel for split attention computation. + + :param dtype: output data type + :param dtype_partial: partial accumulation data type + :param head_dim: head dimension + :param tile_m: m block size + :param k_block_size: k block size + :param topk: exact number of split partials + :param num_threads: number of threads + :param varlen: whether using variable length sequences + :param stages: number of pipeline stages + """ + self.dtype = dtype + self.dtype_partial = dtype_partial + self.head_dim = head_dim + self.tile_m = tile_m + self.k_block_size = k_block_size + self.topk = topk + self.num_threads = num_threads + self.is_even_k = head_dim % k_block_size == 0 + self.stages = stages + self.use_pdl = use_pdl + self.min_blocks_per_mp = min_blocks_per_mp + self.use_stg128_half_layout = dtype_partial in (cutlass.BFloat16, cutlass.Float16) + self.use_stg128_fp8_layout = dtype_partial is cutlass.Float8E4M3FN + + @staticmethod + def can_implement( + dtype, + dtype_partial, + head_dim, + tile_m, + k_block_size, + topk, + num_threads, + ) -> bool: + """Check if the kernel can be implemented with the given parameters.""" + if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]: + return False + if dtype_partial not in [ + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + Float32, + ]: + return False + if head_dim % 8 != 0: + return False + if num_threads % 32 != 0: + return False + if tile_m % 8 != 0: + return False + if topk > 256: + return False + if (tile_m * topk) % num_threads != 0: + return False + return True + + def _setup_attributes(self): + # GMEM copy setup for O partial + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype_partial.width + assert self.k_block_size % async_copy_elems == 0 + + k_block_gmem = ( + 128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32) + ) + gmem_threads_per_row = k_block_gmem // async_copy_elems + assert self.num_threads % gmem_threads_per_row == 0 + + # Async copy atom for O partial load + atom_async_copy_partial = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype_partial, + num_bits_per_copy=universal_copy_bits, + ) + tOpartial_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + vOpartial_layout = cute.make_layout((1, async_copy_elems)) + self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv( + atom_async_copy_partial, tOpartial_layout, vOpartial_layout + ) + + # GMEM copy setup for final O (use universal copy for store). + # Keep this independent from O_partial: fp8 partial uses 16 elements + # per 128b transaction, while bf16/fp16 O stores must remain 8-wide. + output_copy_elems = universal_copy_bits // self.dtype.width + assert self.k_block_size % output_copy_elems == 0 + gmem_threads_per_row_o = k_block_gmem // output_copy_elems + assert self.num_threads % gmem_threads_per_row_o == 0 + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + tO_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_o, gmem_threads_per_row_o), + order=(1, 0), + ) + vO_layout = cute.make_layout((1, output_copy_elems)) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, + tO_layout, + vO_layout, + ) + # LSE copy setup with async copy (alignment = 1) + lse_copy_bits = Float32.width # 1 element per copy, width is in bits + m_block_smem = ( + 128 + if self.tile_m % 128 == 0 + else ( + 64 + if self.tile_m % 64 == 0 + else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8)) + ) + ) + gmem_threads_per_row_lse = m_block_smem + assert self.num_threads % gmem_threads_per_row_lse == 0 + + # Async copy atom for LSE load + atom_async_copy_lse = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + Float32, + num_bits_per_copy=lse_copy_bits, + ) + tLSE_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse), + order=(1, 0), + ) + vLSE_layout = cute.make_layout(1) + self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + atom_async_copy_lse, tLSE_layout, vLSE_layout + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory + # /////////////////////////////////////////////////////////////////////////////// + + # Shared memory to register copy for LSE + self.smem_threads_per_col_lse = self.num_threads // m_block_smem + assert 32 % self.smem_threads_per_col_lse == 0 # Must divide warp size + + s2r_layout_atom_lse = cute.make_ordered_layout( + (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse), + order=(0, 1), + ) + self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), + s2r_layout_atom_lse, + cute.make_layout(1), + ) + + # LSE shared memory layout with swizzling to avoid bank conflicts + # This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts + if const_expr(m_block_smem == 8): + smem_lse_swizzle = cute.make_swizzle(5, 0, 5) + elif const_expr(m_block_smem == 16): + smem_lse_swizzle = cute.make_swizzle(4, 0, 4) + else: + smem_lse_swizzle = cute.make_swizzle(3, 2, 3) + lse_atom_splits = min(self.topk, 8) + smem_layout_atom_lse = cute.make_composed_layout( + smem_lse_swizzle, + 0, + cute.make_ordered_layout((lse_atom_splits, m_block_smem), order=(1, 0)), + ) + self.smem_layout_lse = cute.tile_to_shape( + smem_layout_atom_lse, (self.topk, self.tile_m), (0, 1) + ) + + # O_partial staging layout. + if const_expr( + self.dtype_partial + in [cutlass.Float16, cutlass.BFloat16, cutlass.Float8E4M3FN] + ): + smem_layout_atom_o = _get_cpasync_smem_layout_atom( + self.dtype_partial, self.k_block_size + ) + self.smem_layout_o = cute.tile_to_shape( + smem_layout_atom_o, + (self.tile_m, self.k_block_size, self.stages), + (0, 1, 2), + ) + else: + self.smem_layout_o = cute.make_ordered_layout( + (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2) + ) + + @cute.jit + def __call__( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor] = None, + mLSE_temperature_partial: Optional[cute.Tensor] = None, + mLSE_temperature: Optional[cute.Tensor] = None, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + varlen_batch_idx: Optional[cute.Tensor] = None, + semaphore_to_reset: Optional[cute.Tensor] = None, + mSplitCounts: Optional[cute.Tensor] = None, + mOutputScale: Optional[cute.Tensor] = None, + qhead_per_kvhead: Int32 = Int32(1), + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, + ): + # Type checking + if const_expr(not (mO_partial.element_type == self.dtype_partial)): + raise TypeError("O partial tensor must match dtype_partial") + if const_expr(not (mO.element_type == self.dtype)): + raise TypeError("O tensor must match dtype") + if const_expr(mLSE_partial.element_type not in [Float32]): + raise TypeError("LSE partial tensor must be Float32") + if const_expr(mLSE is not None and mLSE.element_type not in [Float32]): + raise TypeError("LSE tensor must be Float32") + if const_expr( + mLSE_temperature_partial is not None + and mLSE_temperature_partial.element_type not in [Float32] + ): + raise TypeError("temperature LSE partial tensor must be Float32") + if const_expr(mLSE_temperature is not None and mLSE_temperature.element_type not in [Float32]): + raise TypeError("temperature LSE tensor must be Float32") + if const_expr((mLSE_temperature_partial is None) != (mLSE_temperature is None)): + raise ValueError( + "temperature LSE partial and output tensors must either both be provided or both be None" + ) + + # Shape validation - input tensors are in user format, need to be converted to kernel format + if const_expr(len(mO_partial.shape) not in [4, 5]): + raise ValueError( + "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)" + ) + if const_expr(len(mLSE_partial.shape) not in [3, 4]): + raise ValueError( + "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)" + ) + if const_expr(len(mO.shape) not in [3, 4]): + raise ValueError( + "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)" + ) + if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]): + raise ValueError( + "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)" + ) + if const_expr(mLSE_temperature_partial is not None and len(mLSE_temperature_partial.shape) not in [3, 4]): + raise ValueError( + "temperature LSE partial tensor must have 3 or 4 dimensions: " + "(num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)" + ) + if const_expr(mLSE_temperature is not None and len(mLSE_temperature.shape) not in [2, 3]): + raise ValueError( + "temperature LSE tensor must have 2 or 3 dimensions: " + "(batch, seqlen, nheads) or (total_q, nheads)" + ) + if const_expr(mSplitCounts is not None): + if const_expr(mSplitCounts.element_type not in [Int32]): + raise TypeError("split_counts tensor must be Int32") + if const_expr(cu_seqlens is not None): + if const_expr(len(mSplitCounts.shape) != 2): + raise ValueError("varlen split_counts tensor must have shape (total_q, nheads_kv)") + elif const_expr(len(mSplitCounts.shape) != 3): + raise ValueError("batched split_counts tensor must have shape (batch, seqlen, nheads_kv)") + if const_expr(mOutputScale is not None and mOutputScale.element_type not in [Float32]): + raise TypeError("output_scale tensor must be Float32") + + mO_partial, mO = [assume_tensor_aligned(t) for t in (mO_partial, mO)] + # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b) + # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h) + O_partial_layout_transpose = ( + [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2] + ) + # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h) + mO_partial = cute.make_tensor( + mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose) + ) + O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1] + mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose)) + # (num_splits, b, h, seqlen) -> (seqlen, num_splits, h, b) + # Input is pre-transposed: [topK, B, Hq, Sq] with Sq innermost for K2-friendly reads. + # or (num_splits, total_q, h) -> (total_q, num_splits, h) + LSE_partial_layout_transpose = [3, 0, 2, 1] if const_expr(cu_seqlens is None) else [1, 0, 2] + mLSE_partial = cute.make_tensor( + mLSE_partial.iterator, + cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose), + ) + # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h) + LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1] + mLSE = ( + cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + if mLSE is not None + else None + ) + mLSE_temperature_partial = ( + cute.make_tensor( + mLSE_temperature_partial.iterator, + cute.select(mLSE_temperature_partial.layout, mode=LSE_partial_layout_transpose), + ) + if mLSE_temperature_partial is not None + else None + ) + mLSE_temperature = ( + cute.make_tensor( + mLSE_temperature.iterator, + cute.select(mLSE_temperature.layout, mode=LSE_layout_transpose), + ) + if mLSE_temperature is not None + else None + ) + + # Determine if we have variable length sequences + varlen = const_expr(cu_seqlens is not None or seqused is not None) + + self._setup_attributes() + + # Output-dtype permutation buffer for Step 7 (tile_m × k_block_size). + # Accumulation stays fp32; the final dtype conversion happens before + # the fake→real SMEM scatter to reduce half-output SMEM pressure. + if const_expr(self.dtype in [cutlass.Float16, cutlass.BFloat16]): + smem_layout_perm = cute.make_layout( + (self.tile_m, self.k_block_size), + stride=(self.k_block_size + 16, 1), + ) + else: + smem_layout_perm = cute.make_ordered_layout( + (self.tile_m, self.k_block_size), order=(1, 0) + ) + + @cute.struct + class SharedStorage: + sLSE: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sLSETemperature: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.tile_m], 128] + sO: cute.struct.Align[ + cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 + ] + sO_perm: cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_perm)], 128 + ] + + smem_size = SharedStorage.size_in_bytes() + + # Grid: (ceil(seqlen/tile_m), ceil(dim/k_block), num_head * batch) + # Head separated from seqlen → enables future TMA (contiguous Sq tiles) + seqlen = mO_partial.shape[0] + num_head = mO_partial.shape[3] + batch_size = ( + mO_partial.shape[4] + if const_expr(cu_seqlens is None) + else Int32(cu_seqlens.shape[0] - 1) + ) + + seqlen_divmod = FastDivmodDivisor(seqlen) + head_divmod = FastDivmodDivisor(num_head) + + grid_dim = ( + cute.ceil_div(seqlen * num_head, self.tile_m), + cute.ceil_div(self.head_dim, self.k_block_size), + batch_size, + ) + + self.kernel( + mO_partial, + mLSE_partial, + mO, + mLSE, + mLSE_temperature_partial, + mLSE_temperature, + cu_seqlens, + seqused, + num_splits_dynamic_ptr, + varlen_batch_idx, + semaphore_to_reset, + mSplitCounts, + mOutputScale, + qhead_per_kvhead, + SharedStorage, + self.smem_layout_lse, + self.smem_layout_o, + smem_layout_perm, + self.gmem_tiled_copy_O_partial, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_LSE, + self.s2r_tiled_copy_LSE, + seqlen_divmod, + head_divmod, + self.use_pdl, + varlen, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=smem_size, + stream=stream, + min_blocks_per_mp=self.min_blocks_per_mp, + use_pdl=self.use_pdl, + ) + + @cute.jit + def decode_flat_row_idx( + self, + idx: Int32, + head_divmod: FastDivmodDivisor, + ): + """Decode flattened tile rows under the H_q-innermost contract.""" + q_idx_local, head_idx = divmod(idx, head_divmod) + return q_idx_local, head_idx + + @cute.kernel + def kernel( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mLSE_temperature_partial: Optional[cute.Tensor], + mLSE_temperature: Optional[cute.Tensor], + cu_seqlens: Optional[cute.Tensor], + seqused: Optional[cute.Tensor], + num_splits_dynamic_ptr: Optional[cute.Tensor], + varlen_batch_idx: Optional[cute.Tensor], + semaphore_to_reset: Optional[cute.Tensor], + mSplitCounts: Optional[cute.Tensor], + mOutputScale: Optional[cute.Tensor], + qhead_per_kvhead: Int32, + SharedStorage: cutlass.Constexpr, + smem_layout_lse: cute.Layout | cute.ComposedLayout, + smem_layout_o: cute.Layout | cute.ComposedLayout, + smem_layout_perm: cute.Layout, + gmem_tiled_copy_O_partial: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + s2r_tiled_copy_LSE: cute.TiledCopy, + seqlen_divmod: FastDivmodDivisor, + head_divmod: FastDivmodDivisor, + use_pdl: cutlass.Constexpr[bool], + varlen: cutlass.Constexpr[bool], + ): + # Thread and block indices + tidx, _, _ = cute.arch.thread_idx() + m_block, k_block, maybe_virtual_batch = cute.arch.block_idx() + + batch_idx = ( + varlen_batch_idx[maybe_virtual_batch] + if const_expr(varlen_batch_idx is not None) + else maybe_virtual_batch + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sLSE = storage.sLSE.get_tensor(smem_layout_lse) + sLSE_temperature = storage.sLSETemperature.get_tensor(smem_layout_lse) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,)) + sO = storage.sO.get_tensor(smem_layout_o) + sO_perm_buf = storage.sO_perm.get_tensor(smem_layout_perm) + + # Handle semaphore reset — wait for dependent grids first + if const_expr(use_pdl and semaphore_to_reset is not None): + if ( + tidx == 0 + and m_block == cute.arch.grid_dim()[0] - 1 + and k_block == cute.arch.grid_dim()[1] - 1 + and maybe_virtual_batch == cute.arch.grid_dim()[2] - 1 + ): + cute.arch.griddepcontrol_wait() + semaphore_to_reset[0] = 0 + + if const_expr(num_splits_dynamic_ptr is not None): + raise ValueError("K2 combine requires compile-time exact topK") + num_splits = Int32(self.topk) + # Handle variable length sequences using SeqlenInfo + seqlen_info = SeqlenInfo.create( + batch_idx=batch_idx, + seqlen_static=mO_partial.shape[0], + cu_seqlens=cu_seqlens, + seqused=seqused, + # Don't need to pass in tile size since we won't use offset_padded + ) + seqlen, offset = seqlen_info.seqlen, seqlen_info.offset + + num_head = mO_partial.shape[3] + max_idx = seqlen * num_head + output_scale = Float32(1.0) + if const_expr(mOutputScale is not None): + output_scale = mOutputScale[0] + + if const_expr(not varlen) or m_block * self.tile_m < max_idx: + # Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial) + if const_expr(use_pdl): + cute.arch.griddepcontrol_wait() + + # =============================== + # Step 1: Load LSE_partial from gmem to shared memory + # =============================== + # `cLSE` (identity tensor for row/split coord tracking) is reused + # later in steps 4-5, so it must be defined on both branches. + cLSE = cute.make_identity_tensor((self.topk, self.tile_m)) + # Reshape mLSE_partial to PackGQA packed layout and delegate the + # tile load to PackGQAComb.load_LSE. The packed form folds (H_q, Sq) + # into one compound dim with H_q innermost (stride 1), so thread + # rows that vary along h_pos produce one-sector coalesced reads. + # Non-varlen path only — varlen keeps the original inline loop. + if const_expr(not varlen): + mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3) + # mLSE_partial_cur: (H_q, topK, Sq) — after initial transpose + # [3,0,2,1] on [topK,B,Sq,H_q] and dropping B. + # Reorder to (H_q, Sq, topK) then group modes 0..1 for packed dim: + mLSE_partial_reord = cute.make_tensor( + mLSE_partial_cur.iterator, + cute.select(mLSE_partial_cur.layout, mode=[0, 2, 1]), + ) + mLSE_partial_packed = cute.group_modes(mLSE_partial_reord, 0, 2) + # shape ((H_q, Sq), topK) with H_q innermost. + packgqa = PackGQAComb( + m_block_size=self.tile_m, + head_dim_padded=0, # unused for LSE load + check_hdim_oob=False, # unused for LSE load + qhead_per_kvhead=1, # unused; num_heads_divmod is passed explicitly + ) + packgqa.load_LSE( + mLSE_partial_packed, + sLSE, + self.topk, + gmem_tiled_copy_LSE, + tidx, + m_block, + num_splits, + seqlen, + head_divmod, + mSplitCounts, + batch_idx, + qhead_per_kvhead, + ) + if const_expr(mLSE_temperature_partial is not None): + mLSE_temperature_partial_cur = seqlen_info.offset_batch( + mLSE_temperature_partial, batch_idx, dim=3) + mLSE_temperature_partial_reord = cute.make_tensor( + mLSE_temperature_partial_cur.iterator, + cute.select(mLSE_temperature_partial_cur.layout, mode=[0, 2, 1]), + ) + mLSE_temperature_partial_packed = cute.group_modes( + mLSE_temperature_partial_reord, 0, 2) + packgqa.load_LSE( + mLSE_temperature_partial_packed, + sLSE_temperature, + self.topk, + gmem_tiled_copy_LSE, + tidx, + m_block, + num_splits, + seqlen, + head_divmod, + mSplitCounts, + batch_idx, + qhead_per_kvhead, + ) + else: + # Varlen path keeps the same H_q-innermost flat-row contract: + # after transpose [1, 0, 2], mLSE_partial_cur is + # (q_local, split, head). + # mSplitCounts is the authoritative valid-split count per + # packed (q_abs, kv_head); masked splits stay at -inf and + # therefore drop out of the final kernel LSE_out reduction. + mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3) + mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,)) + gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) + tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) + tLSEsLSE_temperature = gmem_thr_copy_LSE.partition_D(sLSE_temperature) + tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) + if const_expr(mLSE_temperature_partial is not None): + mLSE_temperature_partial_cur = seqlen_info.offset_batch( + mLSE_temperature_partial, batch_idx, dim=3) + mLSE_temperature_partial_copy = cute.tiled_divide( + mLSE_temperature_partial_cur, (1,)) + + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] + idx = m_block * self.tile_m + mi + if idx < max_idx: + m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod) + row_count = ( + mSplitCounts[offset + m_idx, head_idx // qhead_per_kvhead] + if const_expr(mSplitCounts is not None) + else num_splits + ) + mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx] + if const_expr(mLSE_temperature_partial is not None): + mLSE_temperature_partial_cur_copy = ( + mLSE_temperature_partial_copy[None, m_idx, None, head_idx]) + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] + if si < num_splits and si < row_count: + cute.copy( + gmem_thr_copy_LSE, + mLSE_partial_cur_copy[None, si], + tLSEsLSE[None, s, m], + ) + if const_expr(mLSE_temperature_partial is not None): + cute.copy( + gmem_thr_copy_LSE, + mLSE_temperature_partial_cur_copy[None, si], + tLSEsLSE_temperature[None, s, m], + ) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + if const_expr(mLSE_temperature_partial is not None): + tLSEsLSE_temperature[None, s, m].fill(-Float32.inf) + else: + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + tLSEsLSE[None, s, m].fill(-Float32.inf) + if const_expr(mLSE_temperature_partial is not None): + tLSEsLSE_temperature[None, s, m].fill(-Float32.inf) + cute.arch.cp_async_commit_group() + + # =============================== + # Step 2: Load O_partial for pipeline stages + # =============================== + + gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) + cO = cute.make_identity_tensor((self.tile_m, self.k_block_size)) + tOcO = gmem_thr_copy_O_partial.partition_D(cO) + tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) + mO_partial_cur = seqlen_info.offset_batch(mO_partial, batch_idx, dim=4) + + # Precompute per-row values for flattened (q_local, head) tiles. + num_rows = const_expr(cute.size(tOcO, mode=[1])) + tOmidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOhidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOSplitCount = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOrOptr = cute.make_rmem_tensor(num_rows, cutlass.Int64) + for m in cutlass.range(num_rows, unroll_full=True): + mi = tOcO[0, m, 0][0] # m coordinate in tile + idx = m_block * self.tile_m + mi + if idx >= max_idx: + tOhidx[m] = -1 + tOmidx[m] = 0 + tOSplitCount[m] = 0 + tOrOptr[m] = cutlass.Int64(0) + else: + tOmidx[m], tOhidx[m] = self.decode_flat_row_idx(idx, head_divmod) + if const_expr(mSplitCounts is None): + tOSplitCount[m] = num_splits + elif const_expr(cu_seqlens is None): + tOSplitCount[m] = mSplitCounts[ + batch_idx, tOmidx[m], tOhidx[m] // qhead_per_kvhead + ] + else: + tOSplitCount[m] = mSplitCounts[ + offset + tOmidx[m], tOhidx[m] // qhead_per_kvhead + ] + tOrOptr[m] = utils.elem_pointer( + mO_partial_cur, + (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m]), + ).toint() + + tOpO = None + if const_expr(not self.is_even_k): + tOpO = cute.make_rmem_tensor(cute.size(tOcO, mode=[2]), Boolean) + for k in cutlass.range(cute.size(tOpO), unroll_full=True): + tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size + # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO) + + load_O_partial = partial( + self.load_O_partial, + gmem_tiled_copy_O_partial, + tOrOptr, + tOsO_partial, + tOhidx, + tOSplitCount, + tOpO, + tOcO, + mO_partial_cur.layout, + ) + + # Load first few stages of O_partial + for stage in cutlass.range(self.stages - 1, unroll_full=True): + if stage < num_splits: + load_O_partial(stage, stage) + cute.arch.cp_async_commit_group() + + # =============================== + # Step 3: Load and transpose LSE from smem to registers + # =============================== + + # Wait for LSE and initial O partial stages to complete + cute.arch.cp_async_wait_group(self.stages - 1) + cute.arch.sync_threads() + # if cute.arch.thread_idx()[0] == 0: + # # cute.print_tensor(sLSE) + # for i in range(64): + # cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0]) + # cute.arch.sync_threads() + + s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) + ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) + ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE) + cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) + if const_expr(mLSE_temperature_partial is not None): + ts2rsLSE_temperature = s2r_thr_copy_LSE.partition_S(sLSE_temperature) + ts2rrLSE_temperature = cute.make_rmem_tensor_like(ts2rsLSE_temperature) + cute.copy( + s2r_tiled_copy_LSE, + ts2rsLSE_temperature, + ts2rrLSE_temperature, + ) + + # =============================== + # Step 4: Compute final LSE along split dimension + # =============================== + + final_lse = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32) + ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) + # We compute the max valid split for each row to short-circuit the computation later + max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32) + assert cute.size(ts2rrLSE, mode=[0]) == 1 + # Compute max, scales, and final LSE for each row. Invalid splits + # have already been filled with -inf, so Step 5 can write the + # kernel-native LSE_out directly. + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + # Find max LSE value across splits + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_max = cute.arch.warp_reduction_max( + ts2rrLSE[None, None, m] + .load() + .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + threads_in_group=threads_per_col, + ) + # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max) + # Find max valid split index + max_valid_idx = -1 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + if ts2rrLSE[0, s, m] != -Float32.inf: + max_valid_idx = ts2rcLSE[0, s, 0][0] # Get split coordinate + # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx) + max_valid_split[m] = cute.arch.warp_reduction_max( + max_valid_idx, threads_in_group=threads_per_col + ) + # Compute exp scales and sum + lse_max_cur = ( + 0.0 if lse_max == -Float32.inf else lse_max + ) # In case all local LSEs are -inf + LOG2_E = math.log2(math.e) + lse_sum_cur = 0.0 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + scale = cute.math.exp2( + ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E), fastmath=True + ) + lse_sum_cur += scale + ts2rrLSE[0, s, m] = scale # Store scale for later use + lse_sum_cur = cute.arch.warp_reduction_sum( + lse_sum_cur, threads_in_group=threads_per_col + ) + # Normalize scales + inv_sum = 0.0 + if max_valid_split[m] < 0 or lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur: + final_lse[m] = -Float32.inf + else: + final_lse[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max + inv_sum = 1.0 / lse_sum_cur + ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) + # Store the scales exp(lse - lse_logsum) back to smem + cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) + + if const_expr(mLSE_temperature_partial is not None): + final_lse_temperature = cute.make_rmem_tensor( + cute.size(ts2rrLSE_temperature, mode=[2]), Float32) + for m in cutlass.range(cute.size(ts2rrLSE_temperature, mode=[2]), unroll_full=True): + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_temperature_max = cute.arch.warp_reduction_max( + ts2rrLSE_temperature[None, None, m] + .load() + .reduce( + cute.ReductionOp.MAX, + init_val=-Float32.inf, + reduction_profile=0, + ), + threads_in_group=threads_per_col, + ) + lse_temperature_max_cur = ( + 0.0 if lse_temperature_max == -Float32.inf else lse_temperature_max + ) + LOG2_E = math.log2(math.e) + lse_temperature_sum_cur = 0.0 + for s in cutlass.range( + cute.size(ts2rrLSE_temperature, mode=[1]), unroll_full=True): + scale = cute.math.exp2( + ts2rrLSE_temperature[0, s, m] * LOG2_E + - (lse_temperature_max_cur * LOG2_E), + fastmath=True, + ) + lse_temperature_sum_cur += scale + lse_temperature_sum_cur = cute.arch.warp_reduction_sum( + lse_temperature_sum_cur, threads_in_group=threads_per_col + ) + if ( + max_valid_split[m] < 0 + or lse_temperature_sum_cur == 0.0 + or lse_temperature_sum_cur != lse_temperature_sum_cur + ): + final_lse_temperature[m] = -Float32.inf + else: + final_lse_temperature[m] = ( + cute.math.log(lse_temperature_sum_cur, fastmath=True) + + lse_temperature_max + ) + + # Store max valid split to smem + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + if mi < self.tile_m: + sMaxValidSplit[mi] = max_valid_split[m] + + # =============================== + # Step 5: Store final LSE to gmem + # This writeback is the authoritative LSE_out returned by the + # public Sparse Attention / Sparse Page Attention interface. + # =============================== + + if const_expr(mLSE is not None): + if const_expr(cu_seqlens is None): + mLSE_cur = mLSE[None, None, batch_idx] + else: + mLSE_cur = cute.domain_offset((offset, 0), mLSE) + if const_expr(mLSE_temperature is not None): + if const_expr(cu_seqlens is None): + mLSE_temperature_cur = mLSE_temperature[None, None, batch_idx] + else: + mLSE_temperature_cur = cute.domain_offset( + (offset, 0), mLSE_temperature) + if k_block == 0: # Only first k_block writes LSE when mLSE is provided + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + idx = m_block * self.tile_m + mi + if idx < max_idx: + m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod) + mLSE_cur[m_idx, head_idx] = final_lse[m] + if const_expr(mLSE_temperature is not None): + mLSE_temperature_cur[m_idx, head_idx] = ( + final_lse_temperature[m]) + + # =============================== + # Step 6: Read O_partial and accumulate final O + # =============================== + + cute.arch.sync_threads() + + # Get max valid split for this thread + thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] + for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True): + thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]]) + + tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32) + tOrO.fill(0.0) + + stage_load = self.stages - 1 + stage_compute = 0 + + # Main accumulation loop + for s in cutlass.range(thr_max_valid_split + 1, unroll=4): + # Get scales for this split + scale = cute.make_rmem_tensor(num_rows, Float32) + for m in cutlass.range(num_rows, unroll_full=True): + scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem + + # Load next stage if needed + split_to_load = s + self.stages - 1 + if split_to_load <= thr_max_valid_split: + load_O_partial(split_to_load, stage_load) + cute.arch.cp_async_commit_group() + stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1 + + # Wait for the current stage to be ready + cute.arch.cp_async_wait_group(self.stages - 1) + # We don't need __syncthreads() because each thread is just reading its own data from smem + # Copy from smem to registers + cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial) + stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1 + + # Accumulate scaled partial results + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0 and scale[m] > 0.0: + tOrO[None, m, None].store( + tOrO[None, m, None].load() + + scale[m] * tOrO_partial[None, m, None].load().to(Float32) + ) + + # Flush any outstanding async-copy groups before the local Step-7 + # permutation buffer is read on the tail of the kernel. + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + + # =============================== + # Step 7: Write final O to gmem (fake→real via SMEM) + # =============================== + + mO_cur = seqlen_info.offset_batch(mO, batch_idx, dim=3) + if const_expr(cu_seqlens is None): + mO_cur = mO[None, None, None, batch_idx] + else: + mO_cur = cute.domain_offset((offset, 0, 0), mO) + mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur) + num_vals = const_expr(cute.size(tOcO, mode=[0])) + if const_expr(not use_pdl): + # Direct / standalone calls don't participate in the K1->K2 + # dependency chain. Use a simple per-element real-column store + # path here to keep mixed-shape launches stable. + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0: + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + if const_expr(self.is_even_k) or tOpO[k]: + for v in cutlass.range(num_vals, unroll_full=True): + fake_col = tOcO[v, 0, k][1] + if const_expr(self.use_stg128_fp8_layout): + real_col = stg128_fp8_fake_col_to_real_col(fake_col) + elif const_expr(self.use_stg128_half_layout): + real_col = stg128_half_fake_col_to_real_col(fake_col) + else: + real_col = stg128_fake_col_to_real_col(fake_col) + o_val = tOrO[v, m, k] + if const_expr(mOutputScale is not None): + o_val = o_val * output_scale + mO_cur[tOmidx[m], real_col, tOhidx[m]] = o_val.to(self.dtype) + else: + # 7a: fp32 accumulator -> output dtype SMEM with fake→real + # permutation. The dedicated permutation buffer stays separate + # from the O_partial pipeline staging buffer. + sO_perm = sO_perm_buf + + if const_expr(self.dtype in [cutlass.BFloat16, cutlass.Float16]): + # O_partial uses a dtype-specific STG.128 fake layout, but + # sO_perm is in the final O dtype. For all supported fake + # layouts, adjacent fake pairs map to adjacent real columns, + # so write the final BF16/F16 O pair as one 32-bit SMEM store. + assert num_vals % 2 == 0 + r2s_o_pair_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Int32, + num_bits_per_copy=32, + ) + rO_pair_word = cute.make_rmem_tensor((1,), cutlass.Int32) + sO_perm_i32_base = cute.make_ptr( + dtype=cutlass.Int32, + value=sO_perm.iterator.toint(), + mem_space=sO_perm.iterator.memspace, + assumed_align=4, + ) + sO_perm_i32_row_stride = Int32((self.k_block_size + 16) // 2) + for m in cutlass.range(num_rows, unroll_full=True): + row_local = tOcO[0, m, 0][0] + if tOhidx[m] >= 0: + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + for v_pair in cutlass.range(num_vals // 2, unroll_full=True): + v = v_pair * 2 + fake_col = tOcO[v, 0, k][1] + if const_expr(self.use_stg128_fp8_layout): + real_col = stg128_fp8_fake_col_to_real_col(fake_col) + elif const_expr(self.use_stg128_half_layout): + real_col = stg128_half_fake_col_to_real_col(fake_col) + else: + real_col = stg128_fake_col_to_real_col(fake_col) + o0 = tOrO[v, m, k] + o1 = tOrO[v + 1, m, k] + if const_expr(mOutputScale is not None): + o0, o1 = cute.arch.mul_packed_f32x2( + (o0, o1), + (output_scale, output_scale), + ) + rO_pair_word[0] = utils.cvt_f16x2_f32(o0, o1, self.dtype) + smem_pair_ptr = cute.make_ptr( + dtype=cutlass.Int32, + value=( + sO_perm_i32_base.toint() + + Int64( + row_local * sO_perm_i32_row_stride + + real_col // Int32(2) + ) + * Int64(4) + ), + mem_space=sO_perm.iterator.memspace, + assumed_align=4, + ) + sO_pair = cute.make_tensor( + smem_pair_ptr, + cute.make_layout((1,), stride=(1,)), + ) + cute.copy(r2s_o_pair_atom, rO_pair_word, sO_pair) + else: + # 7a: iterate over ALL val elements in mode[0]. + # tOcO[v, m, k][1] gives different fake_col for each v. + r2s_o_scalar_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=self.dtype.width, + ) + rO_scalar = cute.make_rmem_tensor((1,), self.dtype) + for m in cutlass.range(num_rows, unroll_full=True): + row_local = tOcO[0, m, 0][0] + if tOhidx[m] >= 0: + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + for v in cutlass.range(num_vals, unroll_full=True): + fake_col = tOcO[v, 0, k][1] + if const_expr(self.use_stg128_fp8_layout): + real_col = stg128_fp8_fake_col_to_real_col(fake_col) + elif const_expr(self.use_stg128_half_layout): + real_col = stg128_half_fake_col_to_real_col(fake_col) + else: + real_col = stg128_fake_col_to_real_col(fake_col) + o_val = tOrO[v, m, k] + if const_expr(mOutputScale is not None): + o_val = o_val * output_scale + rO_scalar[0] = o_val.to(self.dtype) + smem_ptr = utils.elem_pointer(sO_perm, (row_local, real_col)) + smem_scalar_ptr = cute.make_ptr( + dtype=self.dtype, + value=smem_ptr.toint(), + mem_space=sO_perm.iterator.memspace, + assumed_align=self.dtype.width // 8, + ) + sO_scalar = cute.make_tensor( + smem_scalar_ptr, + cute.make_layout((1,), stride=(1,)), + ) + cute.copy(r2s_o_scalar_atom, rO_scalar, sO_scalar) + + cute.arch.sync_threads() + + # 7b: SMEM (real order, output dtype) → GMEM + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOcO_store = gmem_thr_copy_O.partition_D(cO) + tOsO_store = gmem_thr_copy_O.partition_D(sO_perm) + rO = cute.make_rmem_tensor(tOcO_store.shape, self.dtype) + elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1])) + num_store_rows = const_expr(cute.size(tOcO_store, mode=[1])) + num_store_vals = const_expr(cute.size(tOcO_store, mode=[0])) + tOpO_store = None + if const_expr(not self.is_even_k): + tOpO_store = cute.make_rmem_tensor(cute.size(tOcO_store, mode=[2]), Boolean) + for k in cutlass.range(cute.size(tOpO_store), unroll_full=True): + tOpO_store[k] = ( + tOcO_store[0, 0, k][1] + < mO_partial.shape[1] - k_block * self.k_block_size + ) + + # Read output dtype from SMEM (now in real column order). + for m in cutlass.range(num_store_rows, unroll_full=True): + for k in cutlass.range(cute.size(tOcO_store, mode=[2]), unroll_full=True): + if const_expr(self.is_even_k) or tOpO_store[k]: + cute.autovec_copy(tOsO_store[None, m, k], rO[None, m, k]) + + # Write bf16 to GMEM using gmem_tiled_copy_O (same as original FA Step 7) + for m in cutlass.range(num_store_rows, unroll_full=True): + row_local = tOcO_store[0, m, 0][0] + idx = m_block * self.tile_m + row_local + if idx < max_idx: + m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod) + mO_cur_copy = cute.tiled_divide( + mO_cur[m_idx, None, head_idx], (elems_per_store,) + ) + for k in cutlass.range(cute.size(tOcO_store, mode=[2]), unroll_full=True): + k_idx = tOcO_store[0, 0, k][1] // elems_per_store + if const_expr(self.is_even_k) or tOpO_store[k]: + cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx]) + + @cute.jit + def load_O_partial( + self, + gmem_tiled_copy_O_partial: cute.TiledCopy, + tOrOptr: cute.Tensor, + tOsO_partial: cute.Tensor, + tOhidx: cute.Tensor, + tOSplitCount: cute.Tensor, + tOpO: Optional[cute.Tensor], + tOcO: cute.Tensor, + mO_cur_partial_layout: cute.Layout, + split: Int32, + stage: Int32, + ) -> None: + elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1])) + tOsO_partial_cur = tOsO_partial[None, None, None, stage] + for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True): + if tOhidx[m] >= 0: + o_gmem_ptr = cute.make_ptr( + tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16 + ) + mO_partial_cur = cute.make_tensor( + o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)) + ) + mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_load + if split < tOSplitCount[m] and (const_expr(tOpO is None) or tOpO[k]): + cute.copy( + gmem_tiled_copy_O_partial, + mO_partial_cur_copy[None, k_idx, split], + tOsO_partial_cur[None, m, k], + ) + else: + tOsO_partial_cur[None, m, k].fill(0) + + +def _get_cutlass_dtype(torch_dtype: torch.dtype): + if torch_dtype not in torch2cute_dtype_map: + raise TypeError(f"Unsupported dtype: {torch_dtype}") + return torch2cute_dtype_map[torch_dtype] + + +_combine_compile_cache = {} + + +def _get_cpasync_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout: + dtype_byte = const_expr(dtype.width // 8) + bytes_per_row = const_expr(k_dim * dtype_byte) + smem_k_block_size = ( + const_expr( + 128 + if bytes_per_row % 128 == 0 + else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16)) + ) + // dtype_byte + ) + swizzle_bits = ( + 4 + if smem_k_block_size == 128 + else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1)) + ) + swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4) + return cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base), + 0, + cute.make_ordered_layout( + (8 if const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), + order=(1, 0), + ), + ) + + +def combine( + o_partial_fake, + lse_partial, + o_out, + lse_out, + *, + lse_temperature_partial=None, + lse_temperature_out=None, + cu_seqlens=None, + seqused=None, + split_counts=None, + output_scale=None, + use_pdl=False, +): + """K2: merge sparse forward split partials into the final output. + + STG.128 fake-layout handling remains an internal implementation detail. + When lse_out is provided, the kernel writes the final authoritative + log-sum-exp for each query row/head directly into that tensor. + + Args: + o_partial_fake: + Batched: [num_splits, batch, Sq, head_q, dim] + Varlen: [num_splits, total_q, head_q, dim] + lse_partial: + Batched: [num_splits, batch, Sq, head_q] + Varlen: [num_splits, total_q, head_q] + o_out: + Batched: [batch, Sq, head_q, dim] + Varlen: [total_q, head_q, dim] + lse_out: + Batched: [batch, Sq, head_q] + Varlen: [total_q, head_q] + lse_temperature_partial: + Optional temperature-scaled LSE partial with the same shape as + lse_partial. + lse_temperature_out: + Optional temperature-scaled final LSE with the same shape as + lse_out. + cu_seqlens: Optional [batch + 1] int32 for varlen-Q combine. + seqused: Optional [batch] int32 effective lengths for combine. + split_counts: Optional int32 rowwise valid split counts prepared from + q2k metadata. Batched: [batch, seqlen, head_kv]. Varlen: + [total_q, head_kv]. + output_scale: Optional fp32 tensor with at least one element. When + provided, the final O accumulator is multiplied once before store. + use_pdl: When True, wait on PDL dependencies from the producer K1 + kernel. When False, launch without PDL waits. + """ + D = o_partial_fake.shape[-1] + num_splits = o_partial_fake.shape[0] + return_temperature_lse = ( + lse_temperature_partial is not None or lse_temperature_out is not None + ) + if (lse_temperature_partial is None) != (lse_temperature_out is None): + raise ValueError( + "lse_temperature_partial and lse_temperature_out must either both be provided or both be None" + ) + if lse_temperature_partial is not None and lse_temperature_partial.shape != lse_partial.shape: + raise ValueError( + "lse_temperature_partial must have the same shape as lse_partial, " + f"got {lse_temperature_partial.shape} vs {lse_partial.shape}" + ) + if lse_temperature_out is not None: + if lse_out is None: + raise ValueError("lse_temperature_out requires lse_out") + if lse_temperature_out.shape != lse_out.shape: + raise ValueError( + "lse_temperature_out must have the same shape as lse_out, " + f"got {lse_temperature_out.shape} vs {lse_out.shape}" + ) + if lse_temperature_out.dtype != torch.float32 or lse_temperature_partial.dtype != torch.float32: + raise TypeError("temperature LSE tensors must be torch.float32") + + partial_dtype = _get_cutlass_dtype(o_partial_fake.dtype) + out_dtype = _get_cutlass_dtype(o_out.dtype) + if output_scale is not None: + if output_scale.dtype != torch.float32: + raise TypeError(f"output_scale must be torch.float32, got {output_scale.dtype}") + if output_scale.numel() < 1: + raise ValueError("output_scale must contain at least one element") + if output_scale.device != o_out.device: + raise ValueError("output_scale must be on the same device as o_out") + output_scale = output_scale.contiguous() + if split_counts is not None: + if split_counts.dtype != torch.int32: + raise TypeError(f"split_counts must be torch.int32, got {split_counts.dtype}") + if o_out.ndim == 4: + if split_counts.ndim != 3: + raise ValueError( + f"batched split_counts must have shape [batch, seqlen, head_kv], got {split_counts.shape}" + ) + if split_counts.shape[:2] != o_out.shape[:2]: + raise ValueError( + f"split_counts shape {split_counts.shape} must match batch/seqlen of o_out {o_out.shape}" + ) + else: + if cu_seqlens is None: + raise ValueError("split_counts with varlen output requires cu_seqlens") + if split_counts.ndim != 2: + raise ValueError( + f"varlen split_counts must have shape [total_q, head_kv], got {split_counts.shape}" + ) + if split_counts.shape[0] != o_out.shape[0]: + raise ValueError( + f"split_counts total_q ({split_counts.shape[0]}) must match o_out total_q " + f"({o_out.shape[0]})" + ) + if o_out.shape[-2] % split_counts.shape[-1] != 0: + raise ValueError( + f"o_out heads ({o_out.shape[-2]}) must be divisible by split_counts heads ({split_counts.shape[-1]})" + ) + qheadperkv = o_out.shape[-2] // split_counts.shape[-1] + else: + qheadperkv = 1 + if cu_seqlens is not None: + if cu_seqlens.dtype != torch.int32: + raise TypeError(f"cu_seqlens must be torch.int32, got {cu_seqlens.dtype}") + if cu_seqlens.ndim != 1: + raise ValueError(f"cu_seqlens must be rank-1, got {cu_seqlens.shape}") + if not cu_seqlens.is_contiguous(): + raise ValueError("cu_seqlens must be contiguous") + if seqused is not None: + if seqused.dtype != torch.int32: + raise TypeError(f"seqused must be torch.int32, got {seqused.dtype}") + if seqused.ndim != 1: + raise ValueError(f"seqused must be rank-1, got {seqused.shape}") + if not seqused.is_contiguous(): + raise ValueError("seqused must be contiguous") + + k_block_size = 128 if D > 64 else 64 + tile_m = 64 + has_cu_seqlens = cu_seqlens is not None + has_seqused = seqused is not None + has_lse = lse_out is not None + has_split_counts = split_counts is not None + has_output_scale = output_scale is not None + min_blocks_per_mp = 3 if has_output_scale and use_pdl else 0 + + key = ( + "combine", + D, + k_block_size, + tile_m, + num_splits, + partial_dtype, + out_dtype, + has_cu_seqlens, + has_seqused, + has_lse, + bool(return_temperature_lse), + has_split_counts, + has_output_scale, + use_pdl, + min_blocks_per_mp, + ) + if key not in _combine_compile_cache: + from ....src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _combine_compile_cache[key] = loaded + else: + from ....quack.compile_utils import make_fake_tensor + + kernel = SparseAttentionForwardCombine( + dtype=out_dtype, + dtype_partial=partial_dtype, + head_dim=D, + tile_m=tile_m, + k_block_size=k_block_size, + topk=num_splits, + use_pdl=use_pdl, + min_blocks_per_mp=min_blocks_per_mp, + # stages=2 halves per-block SMEM (168 KB -> 103 KB) -> 2 blocks/SM, + # theoretical occupancy 12.5% -> 25%. NCU DRAM throughput 76.35% + # -> 88.64%. Runtime latency within noise (kernel already at HBM + # bandwidth ceiling in practice) but the cleaner SOL profile + # matters for downstream NCU comparison. + stages=2, + ) + div = 128 // partial_dtype.width + if has_cu_seqlens: + total_q, nheads = (cute.sym_int64() for _ in range(2)) + mO_partial = make_fake_tensor( + partial_dtype, (num_splits, total_q, nheads, D), divisibility=div + ) + mLSE_partial = make_fake_tensor( + Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=2 + ) + mO = make_fake_tensor( + out_dtype, (total_q, nheads, D), divisibility=128 // out_dtype.width + ) + mLSE = ( + make_fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=1) + if has_lse + else None + ) + mLSE_temperature_partial = ( + make_fake_tensor( + Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=2 + ) + if return_temperature_lse + else None + ) + mLSE_temperature = ( + make_fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=1) + if return_temperature_lse + else None + ) + else: + batch, sq, nheads = (cute.sym_int64() for _ in range(3)) + mO_partial = make_fake_tensor( + partial_dtype, (num_splits, batch, sq, nheads, D), divisibility=div + ) + mLSE_partial = make_fake_tensor( + Float32, (num_splits, batch, sq, nheads), divisibility=1, leading_dim=3 + ) + mO = make_fake_tensor( + out_dtype, (batch, sq, nheads, D), divisibility=128 // out_dtype.width + ) + mLSE = ( + make_fake_tensor(Float32, (batch, sq, nheads), divisibility=1, leading_dim=2) + if has_lse + else None + ) + mLSE_temperature_partial = ( + make_fake_tensor( + Float32, (num_splits, batch, sq, nheads), divisibility=1, leading_dim=3 + ) + if return_temperature_lse + else None + ) + mLSE_temperature = ( + make_fake_tensor(Float32, (batch, sq, nheads), divisibility=1, leading_dim=2) + if return_temperature_lse + else None + ) + if not has_split_counts: + mSplitCounts = None + elif has_cu_seqlens: + total_q_ctr, nheads_kv = (cute.sym_int64() for _ in range(2)) + mSplitCounts = make_fake_tensor( + Int32, (total_q_ctr, nheads_kv), divisibility=1, leading_dim=1 + ) + else: + nheads_kv = cute.sym_int64() + mSplitCounts = make_fake_tensor( + Int32, (batch, sq, nheads_kv), divisibility=1, leading_dim=2 + ) + mOutputScale = ( + make_fake_tensor(Float32, (cute.sym_int64(),), divisibility=1, leading_dim=0) + if has_output_scale + else None + ) + stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + _combine_compile_cache[key] = cute.compile( + kernel, + mO_partial, + mLSE_partial, + mO, + mLSE, + mLSE_temperature_partial, + mLSE_temperature, + None + if cu_seqlens is None + else make_fake_tensor(Int32, (cute.sym_int64(),), divisibility=1, leading_dim=0), + None + if seqused is None + else make_fake_tensor(Int32, (cute.sym_int64(),), divisibility=1, leading_dim=0), + None, + None, + None, + mSplitCounts, + mOutputScale, + Int32(qheadperkv), + stream, + options="--enable-tvm-ffi", + ) + save_aot(key, _combine_compile_cache[key]) + + with torch.cuda.nvtx.range("K2_Combine"): + _combine_compile_cache[key]( + o_partial_fake, + lse_partial, + o_out, + lse_out, + lse_temperature_partial, + lse_temperature_out, + cu_seqlens, + seqused, + None, + None, + None, + split_counts, + output_scale, + qheadperkv, + ) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d64a0616bd5bb9c987e43b87bcbf9e89001fbb36 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/__init__.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""CUTE DSL launchers for paged fp8 decode forward.""" + +from __future__ import annotations + +import torch + +from .atten_fwd import run_decode_attention +from .combine import run_decode_combine + + +def decode_forward_paged_fp8( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + request_indices: torch.Tensor, + qo_tile_indices: torch.Tensor, + kv_tile_indices: torch.Tensor, + block_valid_mask: torch.Tensor, + split_counts: torch.Tensor, + o_indptr: torch.Tensor, + merge_indptr: torch.Tensor, + O_partial: torch.Tensor | None, + LSE_partial: torch.Tensor | None, + *, + softmax_scale: float, + seqlen_q: int, + page_size: int, + kv_chunk_size_pages: int, + max_split_count: int, + split_kv: bool, + causal: bool, + return_lse: bool = True, + O_partial_dummy: torch.Tensor | None = None, + LSE_partial_dummy: torch.Tensor | None = None, +) -> None: + """Launch dense paged fp8 decode forward and optional compressed combine. + + ``O_partial_dummy`` / ``LSE_partial_dummy`` are caller-provided pre-allocated + placeholder buffers for the non-split path. When supplied, ``run_decode_attention`` + skips the per-call ``torch.empty`` it would otherwise need to satisfy the + kernel's positional arg signature, saving ~5us on small-kv calls. + """ + + run_decode_attention( + q, + k, + v, + page_table, + seqused_k, + request_indices, + qo_tile_indices, + kv_tile_indices, + block_valid_mask, + split_counts, + o_indptr, + out, + lse, + O_partial, + LSE_partial, + softmax_scale=float(softmax_scale), + seqlen_q=int(seqlen_q), + page_size=int(page_size), + kv_chunk_size_pages=int(kv_chunk_size_pages), + split_kv=bool(split_kv), + causal=bool(causal), + return_lse=bool(return_lse), + O_partial_dummy=O_partial_dummy, + LSE_partial_dummy=LSE_partial_dummy, + ) + if split_kv: + if O_partial is None or LSE_partial is None: + raise ValueError("split decode requires O_partial and LSE_partial") + qhead_per_kv = q.shape[1] // k.shape[1] + q_tokens_per_group = 128 // int(qhead_per_kv) + run_decode_combine( + O_partial, + LSE_partial, + split_counts, + o_indptr, + out, + lse, + seqlen_q=int(seqlen_q), + q_tokens_per_group=q_tokens_per_group, + max_split_count=int(max_split_count), + ) + + +__all__ = ["decode_forward_paged_fp8", "run_decode_attention", "run_decode_combine"] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..9a56bb20363deffd4c850533484427bc128b3c84 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py @@ -0,0 +1,2691 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Dense paged fp8 decode forward path. + +This file owns the CUTE DSL entry point for decode attention via +``SparseDecodeAttentionForwardSm100`` — SM100 UTCMMA + persistent +scheduling, paged fp8 Q/K/V, BSA blk128-style intra-warp overlap pipeline. +Forward only. +""" + +from __future__ import annotations + +import math +from functools import partial +from typing import Callable, Optional + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.pipeline as cutlass_pipeline +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cutlass_dsl import BaseDSL +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +from ....quack import copy_utils, layout_utils + +from ....src.common import pipeline +from ....src.common import blackwell_helpers as sm100_helpers +from ....src.common import mma_sm100_desc as sm100_desc +from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map +from ....src.common.named_barrier import NamedBarrierFwdSm100 +from ....src.common.softmax import SoftmaxSm100 +from ....src.common.mask import AttentionMask +from ....src.common.seqlen_info import SeqlenInfoQK +from ....src.common.pack_gqa import pack_gqa_layout +from ....src.common.tile_scheduler import SchedulingMode +from ....src.sm100.fwd_decode.tile_scheduler import ( + DecodeTileScheduler, + DecodeTileSchedulerArguments, +) + + +class SparseDecodeAttentionForwardSm100: + """SM100 dense paged fp8 decode forward attention (UTCMMA + CLC). + + Scope (Phase 1): + - Dense decode, ``split_kv=False``, single q-tile per work item + (``packed_q = seqlen_q * qhead_per_kv <= tile_m=128``). + - Causal only. KV reverse page loop; first reverse block applies + causal/seqlen mask, the rest is unmasked. + - fp8 Q/K/V, bf16 O, fp32 LSE. P is quantized to fp8_e4m3fn before PV + via ``SoftmaxSm100.apply_exp2_convert`` (mirror of prefill fp8 PV). + - per-batch ``mSeqUsedK[b]`` heterogeneous; no uniform-length assumptions. + + Production scope reached at Phase 4+: + - Multi q-tile (Phase 2), split-KV partial writeback (Phase 3), + CLC persistent scheduling (Phase 4), TC SOL >= 90% (Phase 7). + """ + + # UTCMMA K-tile width (matches prefill SparseAttentionForwardSm100). + k_tile = 64 + + def __init__( + self, + head_dim: int = 128, + qhead_per_kv: int = 16, + m_block_size: int = 128, + n_block_size: int = 128, + page_size: int = 128, + split_kv: bool = False, + causal: bool = True, + write_lse: bool = True, + disable_softmax_exp2: bool = False, + ): + # --- structural constraints (Phase 1 scope) ------------------------- + if head_dim != 128: + raise NotImplementedError( + f"SparseDecodeAttentionForwardSm100 currently supports only D=128, " + f"got D={head_dim}" + ) + if m_block_size != 128: + raise NotImplementedError( + f"decode UMMA forward requires tile_m=128, got {m_block_size}" + ) + if n_block_size != 128: + raise NotImplementedError( + f"decode UMMA forward requires n_block_size=128, got {n_block_size}" + ) + if page_size != n_block_size: + raise ValueError( + f"page_size ({page_size}) must equal n_block_size ({n_block_size})" + ) + if qhead_per_kv not in (16, 8, 4, 2, 1): + raise ValueError( + f"qhead_per_kv must be in {{1, 2, 4, 8, 16}}, got {qhead_per_kv}" + ) + if not causal: + raise NotImplementedError( + "decode UMMA forward currently supports only causal=True" + ) + + self.head_dim = int(head_dim) + self.qhead_per_kv = int(qhead_per_kv) + self.m_block_size = int(m_block_size) + self.n_block_size = int(n_block_size) + self.page_size = int(page_size) + self.tile_m = int(m_block_size) + self.split_kv = bool(split_kv) + self.causal = bool(causal) + self.write_lse = bool(write_lse) + self.disable_softmax_exp2 = bool(disable_softmax_exp2) + # FA fp8 SM100 fwd uses a threshold of 4.0 to avoid rescaling O for + # small row-max movements; correction receives acc_scale directly. + self.rescale_threshold = 4.0 + + # q tokens packed per (m_block_size) row group along M. + self.q_tokens_per_group = self.m_block_size // self.qhead_per_kv + + self.mma_tiler_qk = (self.m_block_size, self.n_block_size, self.head_dim) + self.mma_tiler_pv = (self.m_block_size, self.head_dim, self.n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + + # --- pipeline ring stages (BSA blk128 q_stage=1, s_stage=2) --- + self.q_stage = 1 + self.s_stage = 2 + self.o_stage = 2 + # Keep the fp8 decode KV ring deep enough to cover the K0/Q/K1/V0... + # order. This matches sage's fp8 setting and removes the underfed + # two-stage KV pipeline seen in the q8/16K non-split case. + self.kv_stage = 4 + self.k_stages = 2 + # Match prefill: PV is split at 3/4 of n_block_size for fp8. The + # producer (P store) must publish exactly 3N/4 fp8 columns at the + # signal point; that requires the TMEM-store atom Repetition to be + # ``8`` (one PV ``f8f6f4`` K=32 segment = 8 fp32 packed cols), so + # ``shape[2]=4`` chunks and ``split_idx=3`` lands on the 3N/4 + # boundary exactly. The previous N/2 cap was a workaround for + # ``Repetition(16)`` whose coarser chunk boundary could not + # represent 3N/4. + self.split_P_arrive = self.n_block_size // 4 * 3 + self.split_P_arrive = int(self.split_P_arrive / 32) * 32 + assert self.split_P_arrive % 32 == 0 + assert self.split_P_arrive < self.n_block_size + + # --- warp layout (16 warps / 512 threads) — BSA-aligned (Phase 1.10.6b) + # 0-3 softmax WG 0 + # 4-7 softmax WG 1 + # 8-11 correction WG (acc_O rescale across pages + final epilogue + # write-back; participates in TmemPtr barrier) + # 12 MMA issue warp + # 13 spare / future CLC scheduler + # 14 load warp (serial Q + K + V TMA loads) + # 15 empty / register-budget reserve + self.warps_per_group = 4 + self.softmax0_warp_base = 0 + self.softmax1_warp_base = self.softmax0_warp_base + self.warps_per_group + self.correction_warp_base = ( + self.softmax1_warp_base + self.warps_per_group) + self.mma_warp_id = self.correction_warp_base + self.warps_per_group + self.spare_warp_id = self.mma_warp_id + 1 + self.load_warp_id = self.spare_warp_id + 1 + self.empty_warp_id = self.load_warp_id + 1 + self.total_warps = 16 + self.threads_per_cta = cute.arch.WARP_SIZE * self.total_warps + + # --- TMEM layout (fp8 P width-pack: 4 fp8 lanes per fp32 column) --- + # S0/S1: [0:128], [128:256] + # O0/O1: [256:384], [384:512] for head_dim_v=128 + # P (fp8) overlays the second half of each S tile via recast_ptr. + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + self.tmem_s_offset = 0 + self.tmem_stage_stride = self.n_block_size + self.tmem_o_stage_stride = self.head_dim + self.tmem_o_offset = self.s_stage * self.n_block_size + # fp8 P occupies n_block_size * fp8_width / fp32_width = n/4 fp32 cols. + # P offset is set in __call__ once q_dtype is known (defer to Phase 1.3). + raw_tmem_total = self.tmem_o_offset + self.o_stage * self.tmem_o_stage_stride + # SM100 TMEM allocation requires a power-of-two column count. + self.tmem_total = 1 << (raw_tmem_total - 1).bit_length() + + # --- register budget per role (BSA hdim>=96 default) --- + self.num_regs_softmax = 184 + self.num_regs_correction = 88 + self.num_regs_other = 56 + self.num_regs_mma = self.num_regs_other + self.num_regs_load = self.num_regs_other + self.num_regs_epilogue = self.num_regs_other + self.num_regs_empty = self.num_regs_other + + # exp2 emulation for causal: matches prefill ex2_emu_freq=16. + # disable_softmax_exp2 (Phase 7 SOL gate) bypasses both emulation and + # native exp2 — the convert pass becomes a pure fp32 -> fp8 cast. + self.ex2_emu_freq = 16 if (self.causal and not self.disable_softmax_exp2) else 0 + self.ex2_emu_start_frg = 1 + self.buffer_align_bytes = 1024 + + # --- SM100 cluster config (single-CTA for decode, no 2-CTA pair) - + self.use_2cta_instrs = False + self.cta_group_size = 1 + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnk = (1, 1, 1) + self.use_clc_scheduler = True + self.scheduling_mode = SchedulingMode.CLC + self.sched_stages = 2 + self.clc_scheduler_warp_id = self.empty_warp_id + + self.arch = BaseDSL._get_dsl().get_arch_enum() + + # ------------------------------------------------------------------ + # Host-side: TMA descriptors, SMEM layout, launch + # Phase 1.2+ fills in the body. Phase 1.1 keeps signatures stable so + # the rest of the codepath (run_decode_attention dispatch in 1.10) + # can wire to this class without further churn. + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # [B, Sq, Hq, D] fp8 + mK: cute.Tensor, # [num_pages, Hkv, page_size, D] fp8 + mV: cute.Tensor, # [num_pages, Hkv, page_size, D] fp8 + mPageTable: cute.Tensor, # [B, max_pages] int32 + mSeqUsedK: cute.Tensor, # [B] int32 + mRequestIndices: cute.Tensor, # [work_capacity] int32 + mQoTileIndices: cute.Tensor, # [work_capacity] int32 + mKvTileIndices: cute.Tensor, # [work_capacity] int32 + mBlockValidMask: cute.Tensor, # [work_capacity] int32 + mSplitCounts: cute.Tensor, # [B] int32 + mOIndptr: cute.Tensor, # [B + 1] int32 + mO: cute.Tensor, # [total_q, Hq, D] bf16 + mLSE: cute.Tensor, # [total_q, Hq] fp32 + mO_partial: Optional[cute.Tensor], + mLSE_partial: Optional[cute.Tensor], + softmax_scale: Float32, + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + stream: cuda.CUstream = None, + ): + # --- dtype contract ------------------------------------------------ + if const_expr(mQ.element_type is not cutlass.Float8E4M3FN): + raise TypeError("decode UMMA Q must be Float8E4M3FN") + if const_expr(mK.element_type is not cutlass.Float8E4M3FN): + raise TypeError("decode UMMA K must be Float8E4M3FN") + if const_expr(mV.element_type is not cutlass.Float8E4M3FN): + raise TypeError("decode UMMA V must be Float8E4M3FN") + if const_expr(mO.element_type is not cutlass.BFloat16): + raise TypeError("decode UMMA output O must be BFloat16") + if const_expr(mLSE.element_type is not Float32): + raise TypeError("decode UMMA output LSE must be Float32") + if const_expr(self.split_kv): + if const_expr(mO_partial is None or mO_partial.element_type is not Float32): + raise TypeError("decode UMMA split path requires Float32 O_partial") + if const_expr(mLSE_partial is None or mLSE_partial.element_type is not Float32): + raise TypeError("decode UMMA split path requires Float32 LSE_partial") + + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type + self.o_dtype = ( + mO_partial.element_type if const_expr(self.split_kv) + else mO.element_type + ) + # f8f6f4 MMA descriptor kind for fp8 Q/K/V. + self.mma_kind = "f8f6f4" + # fp8 P width-pack ratio: each fp32 TMEM column holds 4 fp8 P lanes. + # Computed here so __init__ stays dtype-agnostic and the TMEM offsets + # can later be derived from this ratio in Phase 1.3. + elem_bytes = const_expr(self.q_dtype.width // 8) + p_cols_as_fp32 = const_expr( + self.n_block_size * self.q_dtype.width // Float32.width + ) + self.tmem_s_to_p_offset = self.n_block_size - p_cols_as_fp32 + self.tmem_p_offset = self.tmem_s_offset + self.tmem_s_to_p_offset + + mQ, mK, mV, mO, mLSE = [ + assume_tensor_aligned(t) for t in (mQ, mK, mV, mO, mLSE) + ] + if const_expr(mO_partial is not None): + mO_partial = assume_tensor_aligned(mO_partial) + if const_expr(mLSE_partial is not None): + mLSE_partial = assume_tensor_aligned(mLSE_partial) + mO_epilogue = mO_partial if const_expr(self.split_kv) else mO + self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO_epilogue) + self.epi_tile = (self.m_block_size, self.head_dim) + + # ------------------------------------------------------------------ + # UTCMMA TiledMma: QK^T + PV. PV uses MN-major V operand (V already + # transposed in the layout below) and a TMEM operand source for P. + # Phase 1.4 builds tiled_mma_qk; Phase 1.5 adds tiled_mma_pv so sV + # layout can derive the MN-major swizzle. + # ------------------------------------------------------------------ + cta_group = tcgen05.CtaGroup.ONE + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, + Float32, cta_group, self.mma_tiler_qk[:2]) + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN, + Float32, cta_group, self.mma_tiler_pv[:2], tcgen05.OperandSource.TMEM) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + # ------------------------------------------------------------------ + # Paged K/V tensor view permutation. + # Input layout [num_pages, Hkv, page_size, D] (nhsd) is permuted to + # [page_size, D, Hkv, num_pages] for the paged TMA descriptor (K). + # V gets an additional (s,d) swap to become MN-major: + # [D, page_size, Hkv, num_pages]. + # ------------------------------------------------------------------ + mK_paged = cute.make_tensor( + mK.iterator, cute.select(mK.layout, mode=[2, 3, 1, 0]) + ) + mV_kv = cute.make_tensor( + mV.iterator, cute.select(mV.layout, mode=[2, 3, 1, 0]) + ) + mV_paged = cute.make_tensor( + mV_kv.iterator, cute.select(mV_kv.layout, mode=[1, 0, 2, 3]) + ) + + # ------------------------------------------------------------------ + # Q SMEM layout + BSA/FA PackGQA full-tile TMA atom. + # + # Runtime Q is [B, Sq, Hq, D]. We transpose to [Sq, D, Hq, B], then + # fold qhead_per_kv into the M dimension: + # ((qhead_per_kv, Sq), D, Hkv, B) + # This lets one Q TMA load cover the whole packed (tile_m, D) tile + # instead of issuing one TMA per q token. + # ------------------------------------------------------------------ + total_q_stages = self.q_stage + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, total_q_stages) + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + mQ = cute.make_tensor( + mQ.iterator, cute.select(mQ.layout, mode=[1, 3, 2, 0])) + nheads_kv = mK.shape[1] + mQ = pack_gqa_layout(mQ, self.qhead_per_kv, nheads_kv, head_idx=2) + tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mQ, + cute.select(sQ_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + cta_layout_vmnk.shape, + ) + + # ------------------------------------------------------------------ + # K / V SMEM layouts + TMA atoms (paged). + # sK uses the QK MMA operand B swizzle; sV uses the PV MMA operand B + # swizzle (MN-major). tP_layout is the TMEM-side P descriptor — no + # SMEM is actually allocated for P, it overlays the S region in TMEM + # via cute.recast_ptr in Phase 1.7. + # ------------------------------------------------------------------ + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage) + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage) + tP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.s_stage) + + tma_atom_K, mK_paged = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mK_paged, cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, tiled_mma_qk, cta_layout_vmnk.shape) + tma_atom_V, mV_paged = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, mV_paged, cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_pv, tiled_mma_pv, cta_layout_vmnk.shape) + + # ------------------------------------------------------------------ + # Phase 1.10.6b-B-2: TMA-store atom for the epilogue write-back. + # Non-split writes bf16 final O; split-KV writes fp32 O_partial. + # sO follows FA/BSA epilogue layout: one full m_block x D tile in + # SMEM. Both paths expose global O as a packed-GQA tensor view so the + # final store is a full BSA-style m_block x D TMA tile. + # ------------------------------------------------------------------ + sO_layout = sm100_utils.make_smem_layout_epi( + self.o_dtype, + self.o_layout, + self.epi_tile, + self.q_stage, + ) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + num_heads_kv_tma = mK.shape[1] + total_o_rows_tma = ( + mO_epilogue.shape[0] + // (num_heads_kv_tma * self.qhead_per_kv) + ) + head_stride_tma = self.head_dim + o_row_stride_tma = ( + num_heads_kv_tma * self.qhead_per_kv * self.head_dim) + kv_head_stride_tma = self.qhead_per_kv * self.head_dim + mO_epilogue_tma = cute.make_tensor( + mO_epilogue.iterator, + cute.make_layout( + ((self.qhead_per_kv, total_o_rows_tma), self.head_dim, num_heads_kv_tma), + stride=((head_stride_tma, o_row_stride_tma), 1, kv_head_stride_tma), + ), + ) + tma_atom_O, mO_tma = cpasync.make_tiled_tma_atom( + tma_store_op, + mO_epilogue_tma, + cute.select(sO_layout, mode=[0, 1]), + self.epi_tile, + ) + + # Pre-multiply softmax scale by log2(e) so the inner exp2 path can + # operate without re-scaling at every iteration. Mirrors prefill. + softmax_scale_log2 = softmax_scale * Float32(math.log2(math.e)) + + work_capacity = mRequestIndices.shape[0] + num_heads_kv = mK.shape[1] + tile_sched_args = DecodeTileSchedulerArguments( + Int32(work_capacity), + Int32(num_heads_kv), + cluster_shape_mn=self.cluster_shape_mn, + ) + tile_sched_params = DecodeTileScheduler.to_underlying_arguments( + tile_sched_args, + scheduling_mode=self.scheduling_mode, + ) + self.tile_scheduler_cls = DecodeTileScheduler + grid = DecodeTileScheduler.get_grid_shape(tile_sched_params) + + clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 + clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0 + + # ------------------------------------------------------------------ + # SharedStorage mirrors BSA blk128's pipeline mesh for dense paged + # decode: Q, shared K/V, S/P/O, P-lastsplit, O-acc, O-epilogue and + # softmax stats mbarriers, plus the TMEM allocator state and SMEM + # staging tensors. + # ------------------------------------------------------------------ + @cute.struct + class SharedStorage: + mbar_load_Q: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_load_KV: cute.struct.MemRange[Int64, self.kv_stage * 2] + mbar_S_full_P_full_O_rescaled: cute.struct.MemRange[ + Int64, self.s_stage * 2] + mbar_P_full_lastsplit: cute.struct.MemRange[ + Int64, self.s_stage * 2] + mbar_O_full: cute.struct.MemRange[Int64, self.s_stage * 2] + mbar_softmax_stats0: cute.struct.MemRange[Int64, 2] + mbar_softmax_stats1: cute.struct.MemRange[Int64, 2] + mbar_O_epi: cute.struct.MemRange[Int64, self.s_stage * 2] + # Phase 1.10.6b-B-2: bf16 sO SMEM staging buffer for the TMA + # store epilogue. Sized for one full m_block_size × head_dim + # tile (single stage; overlap with sQ left for later perf tune). + sO: cute.struct.Align[ + cute.struct.MemRange[self.o_dtype, cute.cosize(sO_layout)], + self.buffer_align_bytes, + ] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: Int32 + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] + clc_response: cute.struct.MemRange[Int32, clc_response_size] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes, + ] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(sV_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # ------------------------------------------------------------------ + # Launch — decode tasks are consumed from the + # (work_idx, head_kv_idx) scheduler space. In CLC mode grid is the + # BSA-style hardware problem shape; in static mode it is capped to the + # SM count and each CTA walks the flattened task stream. + # ------------------------------------------------------------------ + # q_tma_bytes (and Phase 1.5+: kv_tma_bytes / q_subtile_bytes) are + # recomputed inside the kernel from the constexpr SMEM layouts. + # Passing them as Constexpr[int] kernel args ended up marshalling + # to dynamic Int32 here, which then tripped MbarrierArray's + # `if tx_count < 0` check inside PipelineTmaUmma.create. + self.kernel( + mQ, mK_paged, mV_paged, + mPageTable, mSeqUsedK, + mRequestIndices, mQoTileIndices, mKvTileIndices, mBlockValidMask, + mSplitCounts, mOIndptr, + mO, mO_tma, mLSE, + mO_partial, mLSE_partial, + softmax_scale_log2, + sQ_layout, sK_layout, sV_layout, tP_layout, sO_layout, + tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O, + tiled_mma_qk, tiled_mma_pv, + tile_sched_params, + seqlen_q, page_size, kv_chunk_size_pages, + Int32(num_heads_kv), + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=( + self.cluster_shape_mnk + if cute.size(self.cluster_shape_mnk) > 1 else None + ), + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + # --- runtime tensors ------------------------------------------------- + mQ: cute.Tensor, # [((qhead_per_kv, Sq), D, Hkv, B)] + mK_paged: cute.Tensor, # [page_size, D, Hkv, num_pages] fp8 + mV_paged: cute.Tensor, # [D, page_size, Hkv, num_pages] fp8 + mPageTable: cute.Tensor, + mSeqUsedK: cute.Tensor, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mBlockValidMask: cute.Tensor, + mSplitCounts: cute.Tensor, + mOIndptr: cute.Tensor, + mO: cute.Tensor, + mO_tma: cute.Tensor, + mLSE: cute.Tensor, + mO_partial: Optional[cute.Tensor], + mLSE_partial: Optional[cute.Tensor], + # --- scalars --------------------------------------------------------- + softmax_scale_log2: Float32, + # --- SMEM layouts ---------------------------------------------------- + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + tP_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + # --- TMA atoms ------------------------------------------------------- + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_O: cute.CopyAtom, + # --- TiledMma -------------------------------------------------------- + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tile_sched_params: DecodeTileScheduler.Params, + # --- Int32 iteration bounds ------------------------------------------ + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + num_heads_kv: Int32, + ): + # ------------------------------------------------------------------ + # Thread / warp identity, work-item dispatch. + # ------------------------------------------------------------------ + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + if warp_idx == Int32(0): + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_O) + + # ------------------------------------------------------------------ + # SMEM allocation — same SharedStorage type was registered on the + # class in __call__ (Phase 1.3). Every warp materialises the same + # storage view; later phases populate sQ/sK/sV/mbar contents. + # ------------------------------------------------------------------ + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + # sQ is the MMA-operand layout and now also the Q TMA load target: + # PackGQA makes the global Q view match the full BSA (tile_m, D) tile. + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) + + # ------------------------------------------------------------------ + # TMEM allocator — MMA warp performs the allocation, all softmax / + # store / MMA warps participate in the TmemPtr named barrier that + # broadcasts the allocator pointer. Spare warp and KV-load warps + # do not touch TMEM directly. + # ------------------------------------------------------------------ + # TmemPtr participants: 2 softmax WGs (8 warps) + correction WG + # (4 warps) + MMA warp = 13 warps × WARP_SIZE. Load / spare / + # empty warps don't touch TMEM and don't arrive on this barrier. + tmem_alloc_warps: cutlass.Constexpr[int] = ( + self.warps_per_group * 3 + 1) + tmem_alloc_threads = cute.arch.WARP_SIZE * tmem_alloc_warps + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), + num_threads=tmem_alloc_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + ) + tmem_cols = self.tmem_total + + # ------------------------------------------------------------------ + # Cluster layout + warp-specialized pipelines. + # Mirrors prefill (src/sm100/fwd/atten_fwd.py:617-683): cta_layout_vmnk + # is rebuilt in-kernel from tiled_mma_qk.thr_id.shape so its size is + # constexpr (the `cute.size(cta_layout_vmnk) == 1` check inside + # PipelineTmaUmma.create folds at compile time). pipeline_q is + # joined by the BSA S/P/O and shared K/V pipelines below. + # ------------------------------------------------------------------ + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)) + + ThreadCooperativeGroup = partial( + cutlass_pipeline.CooperativeGroup, cutlass_pipeline.Agent.Thread) + tma_thread = ThreadCooperativeGroup(1) + mma_thread = ThreadCooperativeGroup(1) + # One softmax WG participates per S/P/O stage; correction and the + # epilogue warp handle O rescale and TMA write-back. + softmax_warps = ThreadCooperativeGroup(self.warps_per_group) + softmax_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group) + + # Recompute TMA byte counts inside the kernel from the constexpr SMEM + # layouts — see note in __call__ above the self.kernel(...) call for + # why these can't be plumbed through as Constexpr[int] kernel args. + q_tma_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + k_tma_bytes = cute.size_in_bytes( + self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) + + pipeline_q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_load_Q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=q_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + # Decode KV follows BSA's single K/V ring: K0 is primed before Q, + # then K1, V0, K2, V1, ... share one PipelineTmaUmma state while + # landing in separate sK/sV SMEM tensors. For fp8 decode K/V TMA + # tiles have the same byte count, so the shared barrier uses K's count. + pipeline_kv = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.mbar_load_KV.data_ptr(), + num_stages=self.kv_stage, + producer_group=tma_thread, + consumer_group=mma_thread, + tx_count=k_tma_bytes, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + # ------------------------------------------------------------------ + # BSA pipeline mesh. + # pipeline_s_p_o — MMA→{softmax,correction} (8-warp cluster + # consumer). MMA producer_commit signals + # "S ready"; consumer_release signals "P stored + # and acc_O rescaled — MMA can issue next QK". + # pipeline_o_acc — MMA→correction (acc_O updated by PV). + # pipeline_sm_stats0/1 — softmax→correction stage-local stats. + # This avoids the per-warp NamedBarrier used by + # the BSA reference while preserving the same + # first/rescale/final signal sequence. + # pipeline_o_epi — correction→epilogue warp 13 (final O ready). + # ------------------------------------------------------------------ + softmax_correction_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE + * (self.warps_per_group + self.warps_per_group) # = 256 + ) + correction_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * self.warps_per_group # = 128 + ) + epilogue_warp_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE # warp 13 = 32 threads + ) + + pipeline_s_p_o = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_S_full_P_full_O_rescaled.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=softmax_correction_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p_lastsplit = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_P_full_lastsplit.data_ptr(), + num_stages=self.s_stage, + producer_group=softmax_warps, + consumer_group=mma_thread, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_o_acc = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_O_full.data_ptr(), + num_stages=self.s_stage, + producer_group=mma_thread, + consumer_group=correction_threads, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_sm_stats0 = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_softmax_stats0.data_ptr(), + num_stages=1, + producer_group=softmax_threads, + consumer_group=correction_threads, + defer_sync=True, + ) + pipeline_sm_stats1 = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_softmax_stats1.data_ptr(), + num_stages=1, + producer_group=softmax_threads, + consumer_group=correction_threads, + defer_sync=True, + ) + pipeline_o_epi = pipeline.PipelineAsync.create( + barrier_storage=storage.mbar_O_epi.data_ptr(), + num_stages=self.s_stage, + producer_group=correction_threads, + consumer_group=epilogue_warp_threads, + defer_sync=True, + ) + + # Fence mbar init across all regular pipelines. CLC pipeline setup + # follows the BSA ordering: arrive after mbar init, create scheduler + # state, then wait before TMEM allocation and role dispatch. + pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + + if const_expr(self.use_clc_scheduler): + clc_response_ptr = storage.clc_response.data_ptr() + clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() + clc_pipeline_producer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread + ) + num_clc_consumer_warps = ( + self.threads_per_cta // cute.arch.WARP_SIZE + ) * self.cta_group_size + clc_pipeline_consumer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread, + cute.arch.WARP_SIZE * num_clc_consumer_warps, + ) + clc_pipeline = cutlass_pipeline.PipelineClcFetchAsync.create( + barrier_storage=clc_mbar_ptr, + num_stages=self.sched_stages, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ) + tile_scheduler = self.tile_scheduler_cls.create( + tile_sched_params, clc_response_ptr=clc_response_ptr + ) + clc_consumer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, + self.sched_stages, + ) + tile_scheduler.set_clc_pipeline( + clc_pipeline, clc_consumer_state) + else: + clc_pipeline = None + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + + pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + # Single load warp issues Q + K + V TMA serially; no inter-warp + # broadcast / Q-load WG barrier needed (the BSA-aligned layout + # collapses the previous 4-warp Q-load fan-out into one warp). + + # ------------------------------------------------------------------ + # Phase 1.10.3: pre-dispatch TMEM partitions for softmax read/write. + # Mirrors prefill softmax body setup + # (src/sm100/fwd/atten_fwd.py:807-829, 1891-1921). Built once across + # all warps so each softmax WG can take its stage slice. + # ------------------------------------------------------------------ + thr_mma_qk_pre = tiled_mma_qk.get_slice(0) + qk_acc_shape_pre = thr_mma_qk_pre.partition_shape_C( + self.mma_tiler_qk[:2]) + tStS_base_pre = thr_mma_qk_pre.make_fragment_C(qk_acc_shape_pre) + tStS_pre = cute.make_tensor( + tStS_base_pre.iterator, + cute.append( + tStS_base_pre.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tScS_pre = thr_mma_qk_pre.partition_C( + cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScS_pre = tScS_pre[(None, None), 0, 0] + # fp8 P occupies n_block_size * fp8_width / fp32_width fp32 cols. + tilePlikeFP32 = const_expr( + self.mma_tiler_qk[1] * self.q_dtype.width // Float32.width) + tmem_load_atom_pre = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype, + ) + # Repetition(8) gives ``tStP_r2t.shape[2] = tilePlikeFP32 / 8 = 4`` + # chunks for fp8 (tilePlikeFP32=32), with each chunk publishing + # 8 fp32 cols = 32 fp8 cols = exactly one PV ``f8f6f4`` K=32 + # segment. ``split_idx = 4 * 3N/4 / N = 3`` aligns the early + # publish edge to the producer/consumer K boundary. Larger + # Repetition (e.g. 16) would coarsen shape[2] to 2 and force + # split_idx to floor to 1, publishing only N/2 of P before MMA's + # first three K=32 segments need cols 0..3N/4 — that mismatch is + # the NaN source the workaround used to dodge with split=N/2. + tmem_store_atom_pre = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), + Float32, + ) + tmem_store_vec_atom_pre = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + tmem_load_vec_atom_pre = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + + # ------------------------------------------------------------------ + # Warp role dispatch. Bodies are filled in Phase 1.3-1.9: + # softmax WG 0/1 (warps 0-3, 4-7) — softmax + P fp32->fp8 convert + # store / Q-load WG (warps 8-11) — Q TMA gather + epilogue store + # MMA warp (warp 12) — UTCMMA QK + PV issue + # correction WG (warps 8-11) — per-page acc_O rescale + epilogue + # MMA warp (warp 12) — UTCMMA QK + PV issue + # spare warp (warp 13) — empty / future CLC scheduler + # load warp (warp 14) — serial Q + K + V TMA loads + # empty warp (warp 15) — register-budget reserve + # ------------------------------------------------------------------ + is_softmax0_warp = ( + warp_idx >= Int32(self.softmax0_warp_base) + and warp_idx < Int32(self.softmax1_warp_base) + ) + is_softmax1_warp = ( + warp_idx >= Int32(self.softmax1_warp_base) + and warp_idx < Int32(self.correction_warp_base) + ) + is_correction_warp = ( + warp_idx >= Int32(self.correction_warp_base) + and warp_idx < Int32(self.mma_warp_id) + ) + is_mma_warp = warp_idx == Int32(self.mma_warp_id) + is_spare_warp = warp_idx == Int32(self.spare_warp_id) + is_load_warp = warp_idx == Int32(self.load_warp_id) + is_empty_warp = warp_idx == Int32(self.empty_warp_id) + + if const_expr(self.use_clc_scheduler): + if warp_idx == Int32(self.clc_scheduler_warp_id): + cute.arch.setmaxregister_decrease(self.num_regs_empty) + self.clc_scheduler_warp(clc_pipeline, tile_scheduler) + is_empty_warp = False + + if is_softmax0_warp: + cute.arch.setmaxregister_increase(self.num_regs_softmax) + tmem.wait_for_alloc() + tmem_ptr_wg0 = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr_wg0 + self.softmax_loop( + 0, + self.softmax0_warp_base, + softmax_scale_log2, + tStS_pre, + tScS_pre, + tilePlikeFP32, + tmem_load_atom_pre, + tmem_store_atom_pre, + tmem_store_vec_atom_pre, + thr_mma_qk_pre, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_sm_stats0, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + seqlen_q, + page_size, + kv_chunk_size_pages, + ) + tmem_alloc_barrier.arrive() + + if is_softmax1_warp: + cute.arch.setmaxregister_increase(self.num_regs_softmax) + tmem.wait_for_alloc() + tmem_ptr_wg1 = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr_wg1 + self.softmax_loop( + 1, + self.softmax1_warp_base, + softmax_scale_log2, + tStS_pre, + tScS_pre, + tilePlikeFP32, + tmem_load_atom_pre, + tmem_store_atom_pre, + tmem_store_vec_atom_pre, + thr_mma_qk_pre, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_sm_stats1, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + seqlen_q, + page_size, + kv_chunk_size_pages, + ) + tmem_alloc_barrier.arrive() + + if is_correction_warp: + cute.arch.setmaxregister_decrease(self.num_regs_correction) + # Participate in TmemPtr handshake so the MMA warp can free. + tmem.wait_for_alloc() + tmem_ptr_corr = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr_corr + + self.correction_loop( + tiled_mma_pv, + tStS_pre, + tScS_pre, + tmem_load_vec_atom_pre, + pipeline_s_p_o, + pipeline_sm_stats0, + pipeline_sm_stats1, + pipeline_o_acc, + pipeline_o_epi, + sO, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mSplitCounts, + mOIndptr, + mLSE, + mLSE_partial, + mBlockValidMask, + tile_scheduler, + seqlen_q, + page_size, + kv_chunk_size_pages, + num_heads_kv, + softmax_scale_log2, + ) + tmem_alloc_barrier.arrive() + + if is_spare_warp: + cute.arch.setmaxregister_decrease(self.num_regs_epilogue) + self.epilogue_s2g( + mO_tma, + sO, + tma_atom_O, + pipeline_o_epi, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mOIndptr, + mBlockValidMask, + tile_scheduler, + seqlen_q, + ) + + if is_load_warp: + self.load( + tiled_mma_qk, + tiled_mma_pv, + mQ, + mK_paged, + mV_paged, + sQ, + sK, + sV, + mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_q, + pipeline_kv, + mRequestIndices, + mQoTileIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + page_size, + kv_chunk_size_pages, + ) + + if is_empty_warp: + cute.arch.setmaxregister_decrease(self.num_regs_empty) + + if is_mma_warp: + cute.arch.setmaxregister_decrease(self.num_regs_mma) + # ---------------------------------------------------------------- + # MMA warp — Phase 1.6: QK fp8×fp8→fp32 UMMA. Phase 1.10.1 now + # wraps the body in the real TMEM allocator lifecycle: + # tmem.allocate(cols) -> wait_for_alloc -> retrieve_ptr + # -> ... QK work ... + # -> relinquish_alloc_permit -> tmem_alloc_barrier.arrive_and_wait + # -> free(ptr, cols) + # Softmax WG 0/1 participate via wait_for_alloc + retrieve_ptr + + # tmem_alloc_barrier.arrive (4+4+1 = 9 warps). + # ---------------------------------------------------------------- + tmem.allocate(tmem_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + _ = tmem_ptr # consumed by gemm_pv via raw TMEM offsets + + self.mma( + sQ, + sK, + sV, + tP_layout, + tiled_mma_qk, + tiled_mma_pv, + pipeline_q, + pipeline_kv, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_o_acc, + mRequestIndices, + mKvTileIndices, + mSeqUsedK, + mBlockValidMask, + tile_scheduler, + page_size, + kv_chunk_size_pages, + ) + + # Phase 1.10.1: TMEM allocator teardown. + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr, num_columns=tmem_cols) + + @cute.jit + def clc_scheduler_warp( + self, + clc_pipeline: cutlass_pipeline.PipelineClcFetchAsync, + tile_scheduler: DecodeTileScheduler, + ) -> None: + clc_producer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Producer, + self.sched_stages, + ) + clc_consumer_state = cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, + self.sched_stages, + ) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + clc_pipeline.producer_acquire(clc_producer_state) + mbarrier_addr = clc_pipeline.producer_get_barrier( + clc_producer_state) + tile_scheduler.advance_to_next_work( + mbarrier_addr=mbarrier_addr, + response_stage=clc_producer_state.index, + ) + clc_producer_state.advance() + + clc_pipeline.consumer_wait(clc_consumer_state) + work_tile = tile_scheduler.get_current_work( + response_stage=clc_consumer_state.index) + clc_pipeline.consumer_release(clc_consumer_state) + clc_consumer_state.advance() + clc_pipeline.producer_tail(clc_producer_state) + + @cute.jit + def correction_loop( + self, + tiled_mma_pv: cute.TiledMma, + tStS_pre: cute.Tensor, + tScS_pre: cute.Tensor, + tmem_load_vec_atom_pre: cute.CopyAtom, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_sm_stats0: pipeline.PipelineAsync, + pipeline_sm_stats1: pipeline.PipelineAsync, + pipeline_o_acc: pipeline.PipelineAsync, + pipeline_o_epi: pipeline.PipelineAsync, + sO: cute.Tensor, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mSplitCounts: cute.Tensor, + mOIndptr: cute.Tensor, + mLSE: cute.Tensor, + mLSE_partial: Optional[cute.Tensor], + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + num_heads_kv: Int32, + softmax_scale_log2: Float32, + ) -> None: + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + warp_idx_in_wg_corr = warp_idx - Int32(self.correction_warp_base) + group_tidx_corr = ( + warp_idx_in_wg_corr * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + + # First iter: no correction is required. Notify MMA that the + # initial O slots are available, matching BSA's correction_loop. + for stage_init in cutlass.range_constexpr(self.s_stage): + pipeline_s_p_o.consumer_release_w_index(Int32(stage_init)) + + o_corr_consumer_phase = Int32(0) + sm_stats0_consumer_phase = Int32(0) + sm_stats1_consumer_phase = Int32(0) + corr_epi_producer_phase = Int32(1) + + thr0_rs = tiled_mma_pv.get_slice(0) + pv_acc_shape_rs_c = thr0_rs.partition_shape_C( + self.mma_tiler_pv[:2]) + tOtO_base_rs_c = thr0_rs.make_fragment_C(pv_acc_shape_rs_c) + tOtO_rs_c = cute.make_tensor( + tOtO_base_rs_c.iterator + Int32(self.tmem_o_offset), + cute.append( + tOtO_base_rs_c.layout, + cute.make_layout( + (self.s_stage,), + stride=(self.tmem_o_stage_stride,), + ), + ), + ) + tScS_vec_layout_corr = cute.composition( + tScS_pre.layout, cute.make_layout((self.m_block_size, 2))) + tScS_vec_corr = cute.make_tensor( + tScS_pre.iterator, tScS_vec_layout_corr) + tSAcc_corr0 = tStS_pre[(None, None), 0, 0, 0] + tSAcc_corr1 = tStS_pre[(None, None), 0, 0, 1] + tStS_vec0_layout_corr = cute.composition( + tSAcc_corr0.layout, cute.make_layout((self.m_block_size, 2))) + tStS_vec1_layout_corr = cute.composition( + tSAcc_corr1.layout, cute.make_layout((self.m_block_size, 2))) + tStStats0_t2r_src = cute.make_tensor( + tSAcc_corr0.iterator, tStS_vec0_layout_corr) + tStStats1_t2r_src = cute.make_tensor( + tSAcc_corr1.iterator, tStS_vec1_layout_corr) + thr_tmem_load_vec = tcgen05.make_tmem_copy( + tmem_load_vec_atom_pre, + tStStats0_t2r_src, + ).get_slice(group_tidx_corr) + tStStats0_t2r = thr_tmem_load_vec.partition_S(tStStats0_t2r_src) + tStStats1_t2r = thr_tmem_load_vec.partition_S(tStStats1_t2r_src) + tScStats_t2r = thr_tmem_load_vec.partition_D(tScS_vec_corr) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, head_kv_idx, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx_corr = mRequestIndices[work_idx] + qo_tile_corr = mQoTileIndices[work_idx] + seqused_k_corr = mSeqUsedK[batch_idx_corr] + split_idx_corr = mKvTileIndices[work_idx] + kv_pages_corr = ( + seqused_k_corr + page_size - Int32(1)) // page_size + kv_page_begin_corr = split_idx_corr * kv_chunk_size_pages + kv_page_end_corr = cutlass.min( + kv_pages_corr, + kv_page_begin_corr + kv_chunk_size_pages, + ) + page_count_corr = kv_page_end_corr - kv_page_begin_corr + block_iter_count_corr = ( + page_count_corr + Int32(1)) & ~Int32(1) + stage0_count_corr = block_iter_count_corr // Int32(2) + stage1_count_corr = block_iter_count_corr // Int32(2) + + if stage0_count_corr > Int32(0): + pipeline_sm_stats0.consumer_wait_w_index_phase( + Int32(0), sm_stats0_consumer_phase) + sm_stats0_consumer_phase = ( + sm_stats0_consumer_phase ^ Int32(1)) + pipeline_sm_stats0.consumer_release_w_index(Int32(0)) + if stage1_count_corr > Int32(0): + pipeline_sm_stats1.consumer_wait_w_index_phase( + Int32(0), sm_stats1_consumer_phase) + sm_stats1_consumer_phase = ( + sm_stats1_consumer_phase ^ Int32(1)) + pipeline_sm_stats1.consumer_release_w_index(Int32(0)) + + for page_rel_corr in cutlass.range( + Int32(self.s_stage), block_iter_count_corr, unroll=1 + ): + # sm_stats[0] now holds the deferred-exp2 log2-delta: + # 0.0 means "no rescale needed", a negative value is the + # raw delta that needs exp2 to become a true scale factor. + if (page_rel_corr & Int32(1)) == Int32(0): + pipeline_sm_stats0.consumer_wait_w_index_phase( + Int32(0), sm_stats0_consumer_phase) + sm_stats0_consumer_phase = ( + sm_stats0_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats0_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + scale_corr_log2 = tSrStats[0] + pipeline_sm_stats0.consumer_release_w_index(Int32(0)) + should_rescale = ( + cute.arch.vote_ballot_sync( + scale_corr_log2 < Float32(0.0)) != 0 + ) + if should_rescale: + scale_corr = cute.math.exp2( + scale_corr_log2, fastmath=True) + self.correction_rescale( + tiled_mma_pv, + tOtO_rs_c[None, None, None, 0], + group_tidx_corr, + scale_corr, + ) + pipeline_s_p_o.consumer_release_w_index(Int32(0)) + else: + pipeline_sm_stats1.consumer_wait_w_index_phase( + Int32(0), sm_stats1_consumer_phase) + sm_stats1_consumer_phase = ( + sm_stats1_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats1_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + scale_corr_log2 = tSrStats[0] + pipeline_sm_stats1.consumer_release_w_index(Int32(0)) + should_rescale = ( + cute.arch.vote_ballot_sync( + scale_corr_log2 < Float32(0.0)) != 0 + ) + if should_rescale: + scale_corr = cute.math.exp2( + scale_corr_log2, fastmath=True) + self.correction_rescale( + tiled_mma_pv, + tOtO_rs_c[None, None, None, 1], + group_tidx_corr, + scale_corr, + ) + pipeline_s_p_o.consumer_release_w_index(Int32(1)) + + for stage_wait in cutlass.range_constexpr(self.s_stage): + stage_count_wait = ( + stage0_count_corr + if const_expr(stage_wait == 0) + else stage1_count_corr + ) + if stage_count_wait > Int32(0): + pipeline_o_acc.consumer_wait_w_index_phase( + Int32(stage_wait), o_corr_consumer_phase) + + row_sum0 = Float32(0.0) + row_sum1 = Float32(0.0) + row_max0 = -Float32.inf + row_max1 = -Float32.inf + for stage_final in cutlass.range_constexpr(self.s_stage): + if const_expr(stage_final == 0): + pipeline_sm_stats0.consumer_wait_w_index_phase( + Int32(0), sm_stats0_consumer_phase) + sm_stats0_consumer_phase = ( + sm_stats0_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats0_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + row_sum0 = tSrStats[0] + row_max0 = tSrStats[1] + pipeline_sm_stats0.consumer_release_w_index(Int32(0)) + else: + pipeline_sm_stats1.consumer_wait_w_index_phase( + Int32(0), sm_stats1_consumer_phase) + sm_stats1_consumer_phase = ( + sm_stats1_consumer_phase ^ Int32(1)) + tSrStats = cute.make_rmem_tensor( + tScStats_t2r.shape, self.qk_acc_dtype) + cute.copy( + thr_tmem_load_vec, tStStats1_t2r, tSrStats) + cute.arch.fence_view_async_tmem_load() + row_sum1 = tSrStats[0] + row_max1 = tSrStats[1] + pipeline_sm_stats1.consumer_release_w_index(Int32(0)) + + zero0 = row_sum0 == Float32(0.0) or row_sum0 != row_sum0 + zero1 = row_sum1 == Float32(0.0) or row_sum1 != row_sum1 + rm0 = -Float32.inf if zero0 else row_max0 + rm1 = -Float32.inf if zero1 else row_max1 + row_max_comb = cutlass.max(rm0, rm1) + row_max_safe = ( + Float32(0.0) + if row_max_comb == -Float32.inf + else row_max_comb + ) + scale0 = ( + Float32(0.0) + if zero0 + else cute.math.exp2( + (rm0 - row_max_safe) * softmax_scale_log2, + fastmath=True, + ) + ) + scale1 = ( + Float32(0.0) + if zero1 + else cute.math.exp2( + (rm1 - row_max_safe) * softmax_scale_log2, + fastmath=True, + ) + ) + row_sum_comb = row_sum0 * scale0 + row_sum1 * scale1 + combined_zero_or_nan = ( + row_sum_comb == Float32(0.0) + or row_sum_comb != row_sum_comb + ) + inv_sum = cute.arch.rcp_approx( + Float32(1.0) + if combined_zero_or_nan else row_sum_comb) + final_scale0 = scale0 * inv_sum + final_scale1 = scale1 * inv_sum + + pipeline_o_epi.producer_acquire_w_index_phase( + Int32(0), corr_epi_producer_phase) + self.correction_epilogue_combine( + tiled_mma_pv, + sO[None, None, 0], + group_tidx_corr, + final_scale0, + final_scale1, + ) + + if const_expr(self.write_lse or self.split_kv): + if group_tidx_corr < Int32(self.m_block_size): + is_bad_lse = ( + row_sum_comb == Float32(0.0) + or row_sum_comb != row_sum_comb + ) + LN2 = Float32(math.log(2.0)) + lse_val = ( + -Float32.inf if is_bad_lse + else ( + row_max_safe * softmax_scale_log2 + + cute.math.log2(row_sum_comb, fastmath=True) + ) * LN2 + ) + tok_lse = group_tidx_corr // Int32(self.qhead_per_kv) + if tok_lse < seqlen_q: + h_in_kv_lse = ( + group_tidx_corr + - tok_lse * Int32(self.qhead_per_kv)) + q_idx_lse = ( + qo_tile_corr * Int32(self.q_tokens_per_group) + + tok_lse + ) + h_abs_lse = ( + head_kv_idx * Int32(self.qhead_per_kv) + + h_in_kv_lse + ) + if const_expr(self.split_kv): + q_tokens_per_group = Int32( + self.q_tokens_per_group) + q_stride_partial = ( + (seqlen_q + q_tokens_per_group - Int32(1)) + // q_tokens_per_group + ) * q_tokens_per_group + partial_row_lse = ( + mOIndptr[batch_idx_corr] + + split_idx_corr * q_stride_partial + + q_idx_lse + ) + mLSE_partial[ + partial_row_lse, h_abs_lse] = lse_val + else: + q_abs_lse = ( + batch_idx_corr * seqlen_q + q_idx_lse) + mLSE[q_abs_lse, h_abs_lse] = lse_val + + for stage_release in cutlass.range_constexpr(self.s_stage): + stage_count_release = ( + stage0_count_corr + if const_expr(stage_release == 0) + else stage1_count_corr + ) + if stage_count_release > Int32(0): + pipeline_s_p_o.consumer_release_w_index( + Int32(stage_release)) + pipeline_o_acc.consumer_release_w_index( + Int32(stage_release)) + if block_iter_count_corr > Int32(0): + o_corr_consumer_phase = ( + o_corr_consumer_phase ^ Int32(1)) + + pipeline_o_epi.producer_commit_w_index(Int32(0)) + corr_epi_producer_phase = ( + corr_epi_producer_phase ^ Int32(1)) + + work_tile = tile_scheduler.consumer_advance() + + pipeline_o_epi.producer_acquire_w_index_phase( + Int32(self.q_stage - 1), corr_epi_producer_phase) + + @cute.jit + def epilogue_s2g( + self, + mO_tma: cute.Tensor, + sO: cute.Tensor, + tma_atom_O: cute.CopyAtom, + pipeline_o_epi: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mOIndptr: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + seqlen_q: Int32, + ) -> None: + epi_consumer_phase = Int32(0) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, head_kv_idx, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx = mRequestIndices[work_idx] + qo_tile = mQoTileIndices[work_idx] + split_idx = mKvTileIndices[work_idx] + + pipeline_o_epi.consumer_wait_w_index_phase( + Int32(0), epi_consumer_phase) + q_tokens_per_group = Int32(self.q_tokens_per_group) + gO = cute.local_tile( + mO_tma[None, None, head_kv_idx], + self.epi_tile, + (None, 0), + ) + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, cute.make_layout(1), sO, gO) + if const_expr(not self.split_kv): + q_abs = ( + batch_idx * seqlen_q + + qo_tile * q_tokens_per_group + ) + dst_idx = q_abs // q_tokens_per_group + else: + q_stride_partial = ( + (seqlen_q + q_tokens_per_group - Int32(1)) + // q_tokens_per_group + ) * q_tokens_per_group + partial_row = ( + mOIndptr[batch_idx] + + split_idx * q_stride_partial + + qo_tile * q_tokens_per_group + ) + dst_idx = partial_row // q_tokens_per_group + store_O(src_idx=Int32(0), dst_idx=dst_idx) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0) + pipeline_o_epi.consumer_release_w_index(Int32(0)) + epi_consumer_phase = epi_consumer_phase ^ Int32(1) + + work_tile = tile_scheduler.consumer_advance() + + @cute.jit + def correction_epilogue_combine( + self, + tiled_mma_pv: cute.TiledMma, + sO: cute.Tensor, + tidx: Int32, + scale0: Float32, + scale1: Float32, + ) -> None: + thr_mma = tiled_mma_pv.get_slice(0) + pv_acc_shape = thr_mma.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO_base = thr_mma.make_fragment_C(pv_acc_shape) + tOtO = cute.make_tensor( + tOtO_base.iterator + Int32(self.tmem_o_offset), + cute.append( + tOtO_base.layout, + cute.make_layout( + (self.s_stage,), + stride=(self.tmem_o_stage_stride,), + ), + ), + ) + tOsO = thr_mma.get_slice(0).partition_C(sO) + tOcO_full = thr_mma.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + corr_tile_size: cutlass.Constexpr[int] = ( + 8 * 32 // self.o_dtype.width + ) + tOsO_i = cute.logical_divide( + tOsO, + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + tOcO_i = cute.logical_divide( + tOcO_full, + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + tOtO0_i = cute.logical_divide( + tOtO[None, None, None, 0], + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + tOtO1_i = cute.logical_divide( + tOtO[None, None, None, 1], + cute.make_layout((self.m_block_size, corr_tile_size)), + ) + epi_subtile = (self.epi_tile[0], corr_tile_size) + tmem_load_atom = sm100_utils.get_tmem_load_op( + self.mma_tiler_pv, + self.o_layout, + self.o_dtype, + self.pv_acc_dtype, + epi_subtile, + use_2cta_instrs=self.use_2cta_instrs, + ) + tiled_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tOtO0_i[(None, None), 0]) + thr_tmem_load = tiled_tmem_load.get_slice(tidx) + smem_copy_atom = sm100_utils.get_smem_store_op( + self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load) + tiled_smem_store = cute.make_tiled_copy_D( + smem_copy_atom, tiled_tmem_load) + tOtO0_t2r = thr_tmem_load.partition_S( + tOtO0_i[(None, None), None]) + tOtO1_t2r = thr_tmem_load.partition_S( + tOtO1_i[(None, None), None]) + tOsO_s2r = copy_utils.partition_D_position_independent( + thr_tmem_load, tOsO_i[(None, None), None]) + tOcO_t2r = thr_tmem_load.partition_D( + tOcO_i[(None, None), None]) + + for col_pass_idx in cutlass.range( + self.head_dim // corr_tile_size, unroll_full=True): + tOtO0_t2r_i = tOtO0_t2r[None, 0, 0, col_pass_idx] + tOtO1_t2r_i = tOtO1_t2r[None, 0, 0, col_pass_idx] + tOsO_r2s_i = tOsO_s2r[None, 0, 0, col_pass_idx] + frg_shape = tOcO_t2r[None, 0, 0, col_pass_idx].shape + tOrO0_frg = cute.make_fragment(frg_shape, self.pv_acc_dtype) + tOrO1_frg = cute.make_fragment(frg_shape, self.pv_acc_dtype) + is_zero_output = ( + scale0 == Float32(0.0) and scale1 == Float32(0.0) + ) + if not is_zero_output: + cute.copy(tiled_tmem_load, tOtO0_t2r_i, tOrO0_frg) + cute.copy(tiled_tmem_load, tOtO1_t2r_i, tOrO1_frg) + for j in cutlass.range( + 0, cute.size(tOrO0_frg), 2, unroll_full=True + ): + o0_a, o0_b = cute.arch.mul_packed_f32x2( + (tOrO0_frg[j], tOrO0_frg[j + 1]), + (scale0, scale0), + ) + o1_a, o1_b = cute.arch.mul_packed_f32x2( + (tOrO1_frg[j], tOrO1_frg[j + 1]), + (scale1, scale1), + ) + tOrO0_frg[j], tOrO0_frg[j + 1] = ( + cute.arch.add_packed_f32x2( + (o0_a, o0_b), (o1_a, o1_b)) + ) + else: + tOrO0_frg.fill(Float32(0.0)) + copy_utils.cvt_copy(tiled_smem_store, tOrO0_frg, tOsO_r2s_i) + cute.arch.fence_view_async_shared() + + @cute.jit + def correction_rescale( + self, + tiled_mma_pv: cute.TiledMma, + tOtO: cute.Tensor, + tidx: Int32, + scale: Float32, + ) -> None: + thr_mma = tiled_mma_pv.get_slice(0) + tOcO = thr_mma.partition_C( + cute.make_identity_tensor(self.mma_tiler_pv[:2])) + corr_tile_size: cutlass.Constexpr[int] = 16 + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp( + tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp( + tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tOtO_i = cute.composition( + tOtO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOcO_i = cute.composition( + tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom, tOtO_i).get_slice(tidx) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom, tOtO_i).get_slice(tidx) + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i) + tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape + tOtO_r2t = thr_tmem_store.partition_D(tOtO_i) + + frg_count: cutlass.Constexpr[int] = self.head_dim // corr_tile_size + for fi in cutlass.range_constexpr(frg_count): + tOrO_frg = cute.make_fragment( + tOrO_t2r_shape, self.pv_acc_dtype) + tOtO_t2r_i = cute.make_tensor( + tOtO_t2r.iterator + fi * corr_tile_size, + tOtO_t2r.layout, + ) + cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg) + for j in cutlass.range( + 0, cute.size(tOrO_frg), 2, unroll_full=True + ): + tOrO_frg[j], tOrO_frg[j + 1] = ( + cute.arch.mul_packed_f32x2( + (tOrO_frg[j], tOrO_frg[j + 1]), + (scale, scale), + ) + ) + tOtO_r2t_i = cute.make_tensor( + tOtO_r2t.iterator + fi * corr_tile_size, + tOtO_r2t.layout, + ) + cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i) + cute.arch.fence_view_async_tmem_store() + + @cute.jit + def mma( + self, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tP_layout: cute.ComposedLayout, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + pipeline_q: pipeline.PipelineAsync, + pipeline_kv: pipeline.PipelineAsync, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_o_acc: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + page_size: Int32, + kv_chunk_size_pages: Int32, + ) -> None: + thr_mma_qk = tiled_mma_qk.get_slice(0) + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tSrQ0_layout = tSrQ[None, None, None, 0].layout + tSrK0_layout = tSrK[None, None, None, 0].layout + qk_mma_op = tiled_mma_qk.op + q_smem_base = sm100_desc.smem_desc_base_from_tensor( + sQ, sm100_desc.Major.K) + k_smem_base = sm100_desc.smem_desc_base_from_tensor( + sK, sm100_desc.Major.K) + q_smem_start = sm100_desc.make_smem_desc_start_addr( + sQ[None, None, None, 0].iterator) + sm100_helpers.declare_ptx_smem_desc( + q_smem_start, q_smem_base, tSrQ0_layout, + var_name_prefix="decode_q_smem_desc", + ) + sm100_helpers.declare_ptx_idesc( + qk_mma_op, var_name="decode_qk_idesc") + gemm_qk = partial( + sm100_helpers.gemm_ptx_precomputed_varname, + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK0_layout, + smem_var_name_prefix="decode_q_smem_desc", + idesc_var_name="decode_qk_idesc", + smem_offset=0, + zero_init=True, + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + + thr_mma_pv = tiled_mma_pv.get_slice(0) + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_base = thr_mma_qk.make_fragment_C(qk_acc_shape) + tStS = cute.make_tensor( + tStS_base.iterator, + cute.append( + tStS_base.layout, + cute.make_layout( + (self.s_stage,), stride=(self.tmem_stage_stride,)), + ), + ) + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP_base = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + tP_width_ratio = const_expr(Float32.width // self.v_dtype.width) + tP_stage_stride = const_expr( + self.tmem_stage_stride * tP_width_ratio) + tOrP = cute.make_tensor( + tOrP_base.iterator + self.tmem_p_offset * tP_width_ratio, + cute.append( + tOrP_base.layout, + cute.make_layout((self.s_stage,), stride=(tP_stage_stride,)), + ), + ) + tOrV = tiled_mma_pv.make_fragment_B(sV) + pv_mma_op = tiled_mma_pv.op + sm100_helpers.declare_ptx_idesc( + pv_mma_op, var_name="decode_pv_idesc") + + mma_q_consumer_phase = Int32(0) + mma_kv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kv_stage) + phase_s0 = Int32(0) + phase_s1 = Int32(0) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, _, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx_mma = mRequestIndices[work_idx] + split_idx_mma = mKvTileIndices[work_idx] + seqused_k_mma = mSeqUsedK[batch_idx_mma] + kv_pages_mma = ( + seqused_k_mma + page_size - Int32(1)) // page_size + kv_page_begin_mma = split_idx_mma * kv_chunk_size_pages + kv_page_end_mma = cutlass.min( + kv_pages_mma, + kv_page_begin_mma + kv_chunk_size_pages, + ) + page_count_mma = kv_page_end_mma - kv_page_begin_mma + block_iter_count_mma = ( + page_count_mma + Int32(1)) & ~Int32(1) + + pipeline_q.consumer_wait_w_index_phase( + Int32(0), mma_q_consumer_phase) + mma_q_consumer_phase = mma_q_consumer_phase ^ Int32(1) + if block_iter_count_mma > Int32(0): + pipeline_kv.consumer_wait(mma_kv_consumer_state) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[ + None, None, None, + mma_kv_consumer_state.index, + ].iterator + ) + gemm_qk( + Int32(self.tmem_s_offset), + smem_desc_start_b=k_smem_start, + ) + pipeline_s_p_o.producer_commit_w_index(Int32(0)) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + if block_iter_count_mma > Int32(1): + pipeline_kv.consumer_wait(mma_kv_consumer_state) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[ + None, None, None, + mma_kv_consumer_state.index, + ].iterator + ) + gemm_qk( + Int32(self.tmem_s_offset) + + Int32(self.tmem_stage_stride), + smem_desc_start_b=k_smem_start, + ) + pipeline_s_p_o.producer_commit_w_index(Int32(1)) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + if block_iter_count_mma > Int32(self.s_stage): + for page_rel_pv in cutlass.range( + Int32(0), + block_iter_count_mma - Int32(self.s_stage), + unroll=1, + ): + pv_slot = page_rel_pv & Int32(1) + pv_stage_iter = page_rel_pv // Int32(self.s_stage) + pv_phase = phase_s0 + if pv_slot != Int32(0): + pv_phase = phase_s1 + pipeline_s_p_o.producer_acquire_w_index_phase( + pv_slot, pv_phase) + pipeline_kv.consumer_wait(mma_kv_consumer_state) + v_idx = mma_kv_consumer_state.index + sm100_helpers.gemm_ptx_partial( + pv_mma_op, + Int32(self.tmem_o_offset) + + pv_slot * Int32(self.tmem_o_stage_stride), + tOrP[None, None, None, pv_slot], + tOrV[None, None, None, v_idx], + sA=None, + sB=sV[None, None, None, v_idx], + tA_addr=( + Int32(self.tmem_p_offset) + + pv_slot * Int32(self.tmem_stage_stride) + ), + zero_init=pv_stage_iter == Int32(0), + mbar_ptr=( + pipeline_p_lastsplit + .sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None + ), + mbar_phase=( + pv_phase + if self.split_P_arrive > 0 else None), + split_arrive=( + self.split_P_arrive + if self.split_P_arrive > 0 else None + ), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + if pv_slot == Int32(0): + phase_s0 = phase_s0 ^ Int32(1) + else: + phase_s1 = phase_s1 ^ Int32(1) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + page_rel_qk = page_rel_pv + Int32(self.s_stage) + qk_slot = page_rel_qk & Int32(1) + pipeline_kv.consumer_wait(mma_kv_consumer_state) + k_smem_start = sm100_desc.make_smem_desc_start_addr( + sK[ + None, None, None, + mma_kv_consumer_state.index, + ].iterator + ) + gemm_qk( + Int32(self.tmem_s_offset) + + qk_slot * Int32(self.tmem_stage_stride), + smem_desc_start_b=k_smem_start, + ) + pipeline_s_p_o.producer_commit_w_index(qk_slot) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + pipeline_q.consumer_release_w_index(Int32(0)) + + if block_iter_count_mma > Int32(0): + page_rel_epi_begin = cutlass.max( + Int32(0), + block_iter_count_mma - Int32(self.s_stage), + ) + for page_rel_epi in cutlass.range( + page_rel_epi_begin, block_iter_count_mma, unroll=1 + ): + pv_slot = page_rel_epi & Int32(1) + pv_stage_iter = page_rel_epi // Int32(self.s_stage) + pv_phase = phase_s0 + if pv_slot != Int32(0): + pv_phase = phase_s1 + pipeline_s_p_o.producer_acquire_w_index_phase( + pv_slot, pv_phase) + pipeline_kv.consumer_wait(mma_kv_consumer_state) + v_idx = mma_kv_consumer_state.index + sm100_helpers.gemm_ptx_partial( + pv_mma_op, + Int32(self.tmem_o_offset) + + pv_slot * Int32(self.tmem_o_stage_stride), + tOrP[None, None, None, pv_slot], + tOrV[None, None, None, v_idx], + sA=None, + sB=sV[None, None, None, v_idx], + tA_addr=( + Int32(self.tmem_p_offset) + + pv_slot * Int32(self.tmem_stage_stride) + ), + zero_init=pv_stage_iter == Int32(0), + mbar_ptr=( + pipeline_p_lastsplit + .sync_object_full.get_barrier(pv_slot) + if self.split_P_arrive > 0 else None + ), + mbar_phase=( + pv_phase + if self.split_P_arrive > 0 else None), + split_arrive=( + self.split_P_arrive + if self.split_P_arrive > 0 else None + ), + cta_group=self.cta_group_size, + mma_kind=self.mma_kind, + ) + pipeline_o_acc.producer_commit_w_index(pv_slot) + if pv_slot == Int32(0): + phase_s0 = phase_s0 ^ Int32(1) + else: + phase_s1 = phase_s1 ^ Int32(1) + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + + work_tile = tile_scheduler.consumer_advance() + + @cute.jit + def softmax_loop( + self, + stage: cutlass.Constexpr[int], + warp_base: cutlass.Constexpr[int], + softmax_scale_log2: Float32, + tStS_pre: cute.Tensor, + tScS_pre: cute.Tensor, + tilePlikeFP32: cutlass.Constexpr[int], + tmem_load_atom_pre: cute.CopyAtom, + tmem_store_atom_pre: cute.CopyAtom, + tmem_store_vec_atom_pre: cute.CopyAtom, + thr_mma_qk_pre: cute.core.ThrMma, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + seqlen_q: Int32, + page_size: Int32, + kv_chunk_size_pages: Int32, + ) -> None: + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + warp_idx_in_wg = warp_idx - Int32(warp_base) + group_tidx = ( + warp_idx_in_wg * Int32(cute.arch.WARP_SIZE) + + tidx % Int32(cute.arch.WARP_SIZE) + ) + stage_i32 = Int32(stage) + + tSAcc = tStS_pre[(None, None), 0, 0, stage] + thr_tmem_load = tcgen05.make_tmem_copy( + tmem_load_atom_pre, tSAcc).get_slice(group_tidx) + tStS_t2r = thr_tmem_load.partition_S(tSAcc) + tScS_t2r = thr_tmem_load.partition_D(tScS_pre) + tStP_layout = cute.composition( + tSAcc.layout, + cute.make_layout((self.m_block_size, tilePlikeFP32)), + ) + tStP = cute.make_tensor( + tSAcc.iterator + self.tmem_s_to_p_offset, + tStP_layout, + ) + thr_tmem_store = tcgen05.make_tmem_copy( + tmem_store_atom_pre, tStP).get_slice(group_tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + tScS_vec_layout = cute.composition( + tScS_pre.layout, cute.make_layout((self.m_block_size, 2))) + tScS_vec = cute.make_tensor(tScS_pre.iterator, tScS_vec_layout) + tStS_vec_layout = cute.composition( + tSAcc.layout, cute.make_layout((self.m_block_size, 2))) + tStStats_r2t_dst = cute.make_tensor( + tSAcc.iterator, tStS_vec_layout) + thr_tmem_store_vec = tcgen05.make_tmem_copy( + tmem_store_vec_atom_pre, + tStStats_r2t_dst, + ).get_slice(group_tidx) + tStStats_r2t = thr_tmem_store_vec.partition_D(tStStats_r2t_dst) + tScStats_r2t = thr_tmem_store_vec.partition_S(tScS_vec) + tScP_shape = ( + self.mma_tiler_qk[0] // thr_mma_qk_pre.thr_id.shape, + tilePlikeFP32, + ) + + tSrP_r2t_f32 = cute.make_rmem_tensor( + thr_tmem_store.partition_S( + cute.make_identity_tensor(tScP_shape)).shape, + Float32, + ) + s_consumer_phase = Int32(0) + sm_stats_producer_phase = Int32(1) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, _, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + softmax = SoftmaxSm100.create( + softmax_scale_log2, + rescale_threshold=self.rescale_threshold, + ) + softmax.reset() + batch_idx = mRequestIndices[work_idx] + qo_tile = mQoTileIndices[work_idx] + seqused_k = mSeqUsedK[batch_idx] + split_idx = mKvTileIndices[work_idx] + kv_pages = ( + seqused_k + page_size - Int32(1)) // page_size + kv_page_begin = split_idx * kv_chunk_size_pages + kv_page_end = cutlass.min( + kv_pages, kv_page_begin + kv_chunk_size_pages + ) + page_count = kv_page_end - kv_page_begin + block_iter_count = (page_count + Int32(1)) & ~Int32(1) + if const_expr(stage == 0): + stage_page_count = block_iter_count // Int32(2) + else: + stage_page_count = block_iter_count // Int32(2) + + seqlen_info = SeqlenInfoQK( + Int32(0), + Int32(0), + Int32(0), + Int32(0), + seqlen_q, + seqused_k, + False, + False, + False, + True, + ) + mask = AttentionMask( + self.m_block_size, + self.n_block_size, + seqlen_info, + qhead_per_kvhead_packgqa=self.qhead_per_kv, + ) + wg_count = stage_page_count + if wg_count > Int32(0): + page_rel0 = stage_i32 + page_rel0_clamped = cutlass.min( + page_rel0, page_count - Int32(1)) + page_idx_global = kv_page_end - Int32(1) - page_rel0_clamped + kv_valid_cols = cutlass.min( + Int32(self.n_block_size), + seqused_k - page_idx_global * page_size, + ) + if page_rel0 >= page_count: + kv_valid_cols = Int32(0) + sm_stats_producer_phase = self.softmax_step( + softmax, + mask, + stage_i32, + s_consumer_phase, + page_idx_global, + qo_tile, + kv_valid_cols, + tStS_t2r, + tScS_t2r, + tStP_r2t, + tSrP_r2t_f32, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_sm_stats, + group_tidx, + warp_idx_in_wg, + tStStats_r2t, + tScStats_r2t, + sm_stats_producer_phase, + is_first=True, + ) + s_consumer_phase = s_consumer_phase ^ Int32(1) + + for stage_iter in cutlass.range( + Int32(1), wg_count, unroll=1 + ): + page_rel = ( + stage_iter * Int32(self.s_stage) + stage_i32) + page_rel_clamped = cutlass.min( + page_rel, page_count - Int32(1)) + page_idx_global_n = ( + kv_page_end - Int32(1) - page_rel_clamped) + kv_valid_cols_n = cutlass.min( + Int32(self.n_block_size), + seqused_k - page_idx_global_n * page_size, + ) + # Dummy-iter analysis: with s_stage=2, the WG that + # handles stage_i32=0 only ever sees page_rel ≤ + # block_iter_count - 2 < page_count → NEVER dummy. + # The WG with stage_i32=1 sees page_rel = + # block_iter_count - 1 at its last iter, which + # equals page_count iff page_count is odd → only + # WG1 may need the runtime mask_dummy_only guard. + # Pass None for WG0 so the const_expr branch in + # softmax_step eliminates the runtime check + # entirely (compile-time disappears). + if const_expr(stage == 0): + sm_stats_producer_phase = self.softmax_step( + softmax, mask, stage_i32, s_consumer_phase, + page_idx_global_n, qo_tile, kv_valid_cols_n, + tStS_t2r, tScS_t2r, tStP_r2t, tSrP_r2t_f32, + thr_tmem_load, thr_tmem_store, thr_tmem_store_vec, + pipeline_s_p_o, pipeline_p_lastsplit, + pipeline_sm_stats, + group_tidx, warp_idx_in_wg, + tStStats_r2t, tScStats_r2t, + sm_stats_producer_phase, + is_first=False, + apply_mask=False, + # mask_dummy_only=None → no runtime check + ) + else: + is_dummy = page_rel >= page_count + if is_dummy: + kv_valid_cols_n = Int32(0) + sm_stats_producer_phase = self.softmax_step( + softmax, mask, stage_i32, s_consumer_phase, + page_idx_global_n, qo_tile, kv_valid_cols_n, + tStS_t2r, tScS_t2r, tStP_r2t, tSrP_r2t_f32, + thr_tmem_load, thr_tmem_store, thr_tmem_store_vec, + pipeline_s_p_o, pipeline_p_lastsplit, + pipeline_sm_stats, + group_tidx, warp_idx_in_wg, + tStStats_r2t, tScStats_r2t, + sm_stats_producer_phase, + is_first=False, + apply_mask=False, + mask_dummy_only=is_dummy, + ) + s_consumer_phase = s_consumer_phase ^ Int32(1) + + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + tSrStats = cute.make_rmem_tensor( + tScStats_r2t.shape, self.qk_acc_dtype) + tSrStats[0] = softmax.row_sum[0] + tSrStats[1] = softmax.row_max[0] + cute.copy(thr_tmem_store_vec, tSrStats, tStStats_r2t) + cute.arch.fence_view_async_tmem_store() + else: + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + tSrStats = cute.make_rmem_tensor( + tScStats_r2t.shape, self.qk_acc_dtype) + tSrStats[0] = Float32(0.0) + tSrStats[1] = -Float32.inf + cute.copy(thr_tmem_store_vec, tSrStats, tStStats_r2t) + cute.arch.fence_view_async_tmem_store() + pipeline_sm_stats.producer_commit_w_index(Int32(0)) + sm_stats_producer_phase = sm_stats_producer_phase ^ Int32(1) + + work_tile = tile_scheduler.consumer_advance() + + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + + @cute.jit + def softmax_step( + self, + softmax: SoftmaxSm100, + mask: AttentionMask, + stage: Int32, + s_phase: Int32, + page_idx: Int32, + qo_tile: Int32, + kv_valid_cols: Int32, + tStS_t2r: cute.Tensor, + tScS_t2r: cute.Tensor, + tStP_r2t: cute.Tensor, + tSrP_r2t_f32: cute.Tensor, + thr_tmem_load: cute.CopyAtom, + thr_tmem_store: cute.CopyAtom, + thr_tmem_store_vec: cute.CopyAtom, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + group_tidx: Int32, + warp_idx_in_wg: Int32, + tStStats_r2t: cute.Tensor, + tScStats_r2t: cute.Tensor, + sm_stats_producer_phase: Int32, + is_first: cutlass.Constexpr[bool], + apply_mask: cutlass.Constexpr[bool] = True, + mask_dummy_only: Optional[cutlass.Boolean] = None, + ) -> Int32: + # apply_mask=False is the inner-page fast path: skip both the seqlen + # bounds check and the causal-diagonal check, which together cost ~15 + # cyc per iter on the producer pre-publication critical path that + # gates correction WG's consumer_wait (top long_scoreboard PC in NCU). + # Callers must only set apply_mask=False when they can prove the tile + # is fully unmasked (no partial-page seqlen tail, no causal diagonal + # cut). + # + # mask_dummy_only (runtime bool, used only when apply_mask=False): + # when True the iter is a "dummy" rounded-up iter that needs the + # mask to zero out garbage S — runs the mask at runtime cost. For + # non-dummy iters it stays the fast no-mask path. + pipeline_s_p_o.consumer_wait_w_index_phase(stage, s_phase) + sm_stats_try_acquire = ( + pipeline_sm_stats.producer_try_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase) + ) + tSrS_t2r = cute.make_rmem_tensor( + tScS_t2r.shape, self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + if const_expr(apply_mask): + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=qo_tile, + n_block=page_idx, + mask_seqlen=True, + mask_causal=self.causal, + kv_valid_cols=kv_valid_cols, + ) + elif const_expr(mask_dummy_only is not None): + if mask_dummy_only: + # Dummy iter — zero everything via mask (kv_valid_cols=0 + # makes mask_r2p_lambda set all positions to -inf). + mask.apply_mask_sm100( + tSrS_t2r, + tScS_t2r, + m_block=qo_tile, + n_block=page_idx, + mask_seqlen=True, + mask_causal=self.causal, + kv_valid_cols=kv_valid_cols, + ) + # Publish acc_scale in log2-domain (un-exp2'd); correction WG does + # the exp2 only when an actual rescale fires. Removes MUFU.EX2 from + # the sm_stats publication critical path that gates correction's + # consumer_wait (the dominant long_scoreboard hot PC in NCU). + row_max, acc_scale_log2 = softmax.update_row_max_deferred_exp2( + tSrS_t2r.load(), is_first) + pipeline_sm_stats.producer_acquire_w_index_phase( + Int32(0), sm_stats_producer_phase, sm_stats_try_acquire) + tSrStats = cute.make_rmem_tensor( + tScStats_r2t.shape, self.qk_acc_dtype) + tSrStats[0] = acc_scale_log2 + tSrStats[1] = row_max + cute.copy(thr_tmem_store_vec, tSrStats, tStStats_r2t) + cute.arch.fence_view_async_tmem_store() + pipeline_sm_stats.producer_commit_w_index(Int32(0)) + sm_stats_producer_phase = sm_stats_producer_phase ^ Int32(1) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r.layout, + ) + # exp2 for the internal row_sum carry happens AFTER producer_commit, so + # it no longer extends correction's consumer-wait window. + # acc_scale_log2 == 0.0 in the threshold/first-iter paths makes + # exp2(0)=1.0, which is the no-rescale identity for the row_sum carry — + # semantically equivalent to the original ``acc_scale=1.0`` branch. + if const_expr(is_first): + row_sum_init = Float32(0.0) + else: + acc_scale_mult = cute.math.exp2(acc_scale_log2, fastmath=True) + row_sum_init = softmax.row_sum[0] * acc_scale_mult + # Bulk EX2 emulation parameters. + # + # ex2_emu_freq=16 emulate exp2 with FFMA2 polynomial on + # 15 of every 16 (j, k) positions; the + # remaining 1/16 still issues MUFU.EX2. + # This cuts the MUFU.EX2 throughput bottleneck + # in the softmax inner loop (≈22k cyc + # saved per stage at baseline). + # ex2_emu_res=3 degree-3 polynomial; res=4 broke + # kv=1024 close-tolerance even with + # poly_degree=5 — 3 is the most aggressive + # setting that still passes cos_sim ≥ 0.99 + # against the reference for the fp8 PV path. + # ex2_emu_start_frg=1 skip the emulation for fragment index 0 + # (preserves accuracy on the first iter + # where row_max is least settled). + # + # If you tune these, re-run the variable-kv self-consistency check + # (split vs non-split must stay at cos_min ≥ 0.99). + softmax.row_sum[0] = softmax.scale_apply_exp2_convert_sum( + tSrS_t2r, + row_max, + tSrP_r2t, + row_sum_init, + ex2_emu_freq=16, + ex2_emu_res=3, + ex2_emu_start_frg=1, + ) + for k in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): + cute.copy( + thr_tmem_store, + tSrP_r2t_f32[None, None, k], + tStP_r2t[None, None, k], + ) + if const_expr(self.split_P_arrive > 0): + split_P_arrive_idx = ( + cute.size(tStP_r2t.shape[2]) + * self.split_P_arrive + // self.n_block_size + ) + if const_expr(k + 1 == split_P_arrive_idx): + cute.arch.fence_view_async_tmem_store() + pipeline_s_p_o.consumer_release_w_index(stage) + cute.arch.fence_view_async_tmem_store() + if const_expr(self.split_P_arrive > 0): + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_p_lastsplit.producer_commit_w_index(stage) + else: + pipeline_s_p_o.consumer_release_w_index(stage) + return sm_stats_producer_phase + + @cute.jit + def load( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mQ: cute.Tensor, + mK_paged: cute.Tensor, + mV_paged: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + mPageTable: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + pipeline_q: pipeline.PipelineAsync, + pipeline_kv: pipeline.PipelineAsync, + mRequestIndices: cute.Tensor, + mQoTileIndices: cute.Tensor, + mKvTileIndices: cute.Tensor, + mSeqUsedK: cute.Tensor, + mBlockValidMask: cute.Tensor, + tile_scheduler: DecodeTileScheduler, + page_size: Int32, + kv_chunk_size_pages: Int32, + ) -> None: + cute.arch.setmaxregister_decrease(self.num_regs_load) + thr_mma_qk_ld = tiled_mma_qk.get_slice(0) + thr_mma_pv_ld = tiled_mma_pv.get_slice(0) + q_producer_phase = Int32(1) + kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kv_stage) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_idx, head_kv_idx, _, _ = work_tile.tile_idx + if mBlockValidMask[work_idx] != Int32(0): + batch_idx_ld = mRequestIndices[work_idx] + qo_tile_ld = mQoTileIndices[work_idx] + split_idx_ld = mKvTileIndices[work_idx] + seqused_k_ld = mSeqUsedK[batch_idx_ld] + kv_pages_ld = ( + seqused_k_ld + page_size - Int32(1)) // page_size + kv_page_begin_ld = split_idx_ld * kv_chunk_size_pages + kv_page_end_ld = cutlass.min( + kv_pages_ld, kv_page_begin_ld + kv_chunk_size_pages + ) + page_count_ld = kv_page_end_ld - kv_page_begin_ld + block_iter_count_ld = ( + page_count_ld + Int32(1)) & ~Int32(1) + physical_page_v0 = Int32(0) + physical_page_v1 = Int32(0) + + mQ_cur_ld = mQ[None, None, None, batch_idx_ld][ + None, None, head_kv_idx + ] + tiler_gQ_ld = ( + (self.mma_tiler_qk[0] * self.q_stage), + self.head_dim, + ) + gQ_ld = cute.local_tile( + mQ_cur_ld, tiler_gQ_ld, (qo_tile_ld, 0)) + gQ_ld = layout_utils.select( + cute.flat_divide(gQ_ld, (self.mma_tiler_qk[0],)), + mode=[0, 2, 1], + ) + tSgQ_ld = thr_mma_qk_ld.partition_A(gQ_ld) + load_Q_fn_full, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), tSgQ_ld, sQ + ) + mK_cur_ld = mK_paged[None, None, head_kv_idx, None] + gK_ld = cute.local_tile( + mK_cur_ld, + cute.select(self.mma_tiler_qk, mode=[1, 2]), + (None, 0, None), + ) + tSgK_ld = thr_mma_qk_ld.partition_B(gK_ld) + tKsK_ld, tKgK_ld = cpasync.tma_partition( + tma_atom_K, 0, cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_ld, 0, 3), + ) + mV_cur_ld = mV_paged[None, None, head_kv_idx, None] + gV_ld = cute.local_tile( + mV_cur_ld, + cute.select(self.mma_tiler_pv, mode=[1, 2]), + (0, None, None), + ) + tOgV_ld = thr_mma_pv_ld.partition_B(gV_ld) + tVsV_ld, tVgV_ld = cpasync.tma_partition( + tma_atom_V, 0, cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV_ld, 0, 3), + ) + + if block_iter_count_ld > Int32(0): + # Prime K0 before Q; then follow BSA order + # K1, V0, K2, V1, ... + page_idx_ld0 = kv_page_end_ld - Int32(1) + physical_page_v0 = mPageTable[batch_idx_ld, page_idx_ld0] + physical_page_v1 = physical_page_v0 + self.load_KV_physical( + tma_atom_K, + tKgK_ld, + tKsK_ld, + physical_page_v0, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + + self.load_Q( + load_Q_fn_full, + pipeline_q, + Int32(0), + q_producer_phase, + ) + q_producer_phase = q_producer_phase ^ Int32(1) + + if block_iter_count_ld > Int32(0): + if block_iter_count_ld > Int32(1): + page_rel_k1 = cutlass.min( + Int32(1), page_count_ld - Int32(1)) + page_idx_ld1 = kv_page_end_ld - Int32(1) - page_rel_k1 + physical_page_v1 = mPageTable[ + batch_idx_ld, page_idx_ld1] + self.load_KV_physical( + tma_atom_K, + tKgK_ld, + tKsK_ld, + physical_page_v1, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + + if block_iter_count_ld > Int32(2): + for page_rel in cutlass.range( + Int32(0), + block_iter_count_ld - Int32(2), + unroll=1, + ): + page_rel_v_ld = cutlass.min( + page_rel, page_count_ld - Int32(1)) + physical_page_v_ld = physical_page_v0 + if (page_rel & Int32(1)) != Int32(0): + physical_page_v_ld = physical_page_v1 + self.load_KV_physical( + tma_atom_V, + tVgV_ld, + tVsV_ld, + physical_page_v_ld, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + page_rel_k_ld = cutlass.min( + page_rel + Int32(2), + page_count_ld - Int32(1), + ) + page_idx_k_ld = ( + kv_page_end_ld - Int32(1) - page_rel_k_ld) + physical_page_k_ld = mPageTable[ + batch_idx_ld, page_idx_k_ld] + self.load_KV_physical( + tma_atom_K, + tKgK_ld, + tKsK_ld, + physical_page_k_ld, + pipeline_kv, + kv_producer_state, + ) + if (page_rel & Int32(1)) == Int32(0): + physical_page_v0 = physical_page_k_ld + else: + physical_page_v1 = physical_page_k_ld + kv_producer_state.advance() + + page_rel_epi_begin_ld = cutlass.max( + Int32(0), + block_iter_count_ld - Int32(2), + ) + for page_rel_epi in cutlass.range( + page_rel_epi_begin_ld, + block_iter_count_ld, + unroll=1, + ): + page_rel_v_ld = cutlass.min( + page_rel_epi, page_count_ld - Int32(1)) + physical_page_v_ld = physical_page_v0 + if (page_rel_epi & Int32(1)) != Int32(0): + physical_page_v_ld = physical_page_v1 + self.load_KV_physical( + tma_atom_V, + tVgV_ld, + tVsV_ld, + physical_page_v_ld, + pipeline_kv, + kv_producer_state, + ) + kv_producer_state.advance() + + tile_scheduler.prefetch_next_work() + work_tile = tile_scheduler.consumer_advance() + + pipeline_kv.producer_tail(kv_producer_state) + pipeline_q.producer_acquire_w_index_phase( + Int32(self.q_stage - 1), q_producer_phase) + + @cute.jit + def load_Q( + self, + load_Q_fn: Callable, + pipeline_q: pipeline.PipelineAsync, + stage: Int32, + phase: Int32, + ) -> None: + pipeline_q.producer_acquire_w_index_phase(stage, phase) + load_Q_fn( + src_idx=Int32(0), + dst_idx=stage, + tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(stage), + ) + + @cute.jit + def load_KV_physical( + self, + tma_atom: cute.CopyAtom, + tXgX: cute.Tensor, + tXsX: cute.Tensor, + physical_page: Int32, + pipeline_kv: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + ) -> None: + pipeline_kv.producer_acquire(producer_state) + cute.copy( + tma_atom, + tXgX[(None, 0, physical_page)], + tXsX[(None, producer_state.index)], + tma_bar_ptr=pipeline_kv.producer_get_barrier(producer_state), + ) + +_atten_compile_cache: dict[tuple[object, ...], object] = {} + + +def run_decode_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + request_indices: torch.Tensor, + qo_tile_indices: torch.Tensor, + kv_tile_indices: torch.Tensor, + block_valid_mask: torch.Tensor, + split_counts: torch.Tensor, + o_indptr: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + O_partial: Optional[torch.Tensor], + LSE_partial: Optional[torch.Tensor], + *, + softmax_scale: float, + seqlen_q: int, + page_size: int, + kv_chunk_size_pages: int, + split_kv: bool, + causal: bool, + return_lse: bool = True, + disable_softmax_exp2: bool = False, + O_partial_dummy: Optional[torch.Tensor] = None, + LSE_partial_dummy: Optional[torch.Tensor] = None, +) -> None: + """Launch the SM100 UMMA paged decode attention CUTE DSL kernel. + + qhead_per_kv is derived from input shapes (q.shape[1] // k.shape[1]). + disable_softmax_exp2 toggles the sage-style host flag (decision §1.7); + default False keeps full ex2 emulation. + + ``O_partial_dummy`` / ``LSE_partial_dummy`` let callers pre-allocate the + placeholder buffers for the non-split path, avoiding ~5us of per-call + ``torch.empty`` overhead in tight decoding loops. + """ + + q_dtype = torch2cute_dtype_map[q.dtype] + o_dtype = torch2cute_dtype_map[out.dtype] + qhead_per_kv = q.shape[1] // k.shape[1] + q_tokens_per_group = 128 // int(qhead_per_kv) + write_lse = bool(return_lse) or bool(split_kv) + if int(seqlen_q) != q_tokens_per_group: + raise NotImplementedError( + "decode fp8 currently assumes one full packed-q tile: " + f"seqlen_q must equal {q_tokens_per_group}, got {seqlen_q}" + ) + key = ( + "decode_attention", + q.shape[-1], + q_dtype, + o_dtype, + bool(split_kv), + bool(causal), + int(qhead_per_kv), + int(seqlen_q), + bool(write_lse), + bool(disable_softmax_exp2), + ) + if key not in _atten_compile_cache: + from ....quack.compile_utils import make_fake_tensor + + total_q = cute.sym_int64() + head_q = cute.sym_int64() + num_pages = cute.sym_int64() + head_kv = cute.sym_int64() + batch = cute.sym_int64() + batch_plus_one = cute.sym_int64() + max_pages = cute.sym_int64() + work_capacity = cute.sym_int64() + partial_rows = cute.sym_int64() + partial_rows_flat = cute.sym_int64() + head_dim = int(q.shape[-1]) + kernel = SparseDecodeAttentionForwardSm100( + head_dim=head_dim, + qhead_per_kv=int(qhead_per_kv), + page_size=int(page_size), + split_kv=bool(split_kv), + causal=bool(causal), + write_lse=bool(write_lse), + disable_softmax_exp2=bool(disable_softmax_exp2), + ) + # Always pass non-None fake tensors so the @cute.kernel positional + # arg marshalling stays stable; the kernel only reads these when + # split_kv=True (decision #10 epilogue branch). + fake_O_partial = make_fake_tensor( + Float32, (partial_rows_flat, head_dim), divisibility=4) + fake_LSE_partial = make_fake_tensor( + Float32, (partial_rows, head_q), divisibility=1, leading_dim=1) + # Q is passed as a [B, Sq, Hq, D] view so the kernel can build the same + # PackGQA TMA view used by FA/BSA and issue one full-tile Q TMA. + # O still uses the compact 2D view for the packed-GQA TMA epilogue. + total_q_flat = cute.sym_int64() + _atten_compile_cache[key] = cute.compile( + kernel, + make_fake_tensor( + q_dtype, (batch, int(seqlen_q), head_q, head_dim), + divisibility=16), + make_fake_tensor(q_dtype, (num_pages, head_kv, int(page_size), head_dim), divisibility=16), + make_fake_tensor(q_dtype, (num_pages, head_kv, int(page_size), head_dim), divisibility=16), + make_fake_tensor(Int32, (batch, max_pages), divisibility=1, leading_dim=1), + make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (work_capacity,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (batch_plus_one,), divisibility=1, leading_dim=0), + make_fake_tensor(o_dtype, (total_q_flat, head_dim), divisibility=128 // o_dtype.width), + make_fake_tensor(Float32, (total_q, head_q), divisibility=1, leading_dim=1), + fake_O_partial, + fake_LSE_partial, + Float32(float(softmax_scale)), + Int32(int(seqlen_q)), + Int32(int(page_size)), + Int32(int(kv_chunk_size_pages)), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + q_4d = q.view( + q.shape[0] // int(seqlen_q), int(seqlen_q), q.shape[1], q.shape[2]) + out_2d = out.view(out.shape[0] * out.shape[1], out.shape[2]) + # Compile keeps non-None fake partial buffers for positional stability + # (see fake_O_partial / fake_LSE_partial above). Runtime callers that + # don't need them (split_kv=False) pass None; allocate small uninitialized + # dummy buffers so the kernel signature still matches without launching + # torch fill kernels. + if O_partial is None: + # Reuse caller-cached dummy when available (e.g. the + # SparseDecodePagedAttentionWrapper plan() pre-allocation), else + # allocate a small placeholder on the fly. + O_partial_kernel = ( + O_partial_dummy + if O_partial_dummy is not None + else torch.empty( + (1, q.shape[2]), dtype=torch.float32, device=q.device) + ) + else: + O_partial_kernel = O_partial.view( + O_partial.shape[0] * O_partial.shape[1], O_partial.shape[2]) + if LSE_partial is None: + LSE_partial = ( + LSE_partial_dummy + if LSE_partial_dummy is not None + else torch.empty( + (1, q.shape[1]), dtype=torch.float32, device=q.device) + ) + with torch.cuda.nvtx.range("Decode_Attention"): + _atten_compile_cache[key]( + q_4d, k, v, page_table, seqused_k, + request_indices, qo_tile_indices, kv_tile_indices, block_valid_mask, + split_counts, o_indptr, + out_2d, lse, O_partial_kernel, LSE_partial, + softmax_scale, seqlen_q, page_size, kv_chunk_size_pages, + ) + + +__all__ = ["SparseDecodeAttentionForwardSm100", "run_decode_attention"] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6bab26c200fff9c62644849b18e55f060fa8783f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Paged decode split-KV scheduling backed by the precompiled Torch op. + +The CUDA implementation lives in ``csrc/build_decode_schedule.cu`` and is +built ahead of time by kernel-builder. The op returns the schedule arrays +plus a fixed-order scalar summary, which is reassembled into the schedule +dict here. +""" + +from __future__ import annotations + +import torch + +from ....._ops import ops + +# Order of the scalar summary returned by the op; must match +# csrc/build_decode_schedule.cu. +_SCALAR_KEYS = ( + "split_kv", + "cta_tile_q", + "num_q_tiles", + "kv_chunk_size_pages", + "kv_chunk_size_tokens", + "work_count", + "padded_work_count", + "partial_rows", + "max_split_count", + "max_grid_size", + "active_blocks_per_sm", + "num_sms", + "base_cta", +) + + +def build_decode_schedule( + seqused_k: torch.Tensor, + *, + page_size: int, + seqlen_q: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + max_seqlen_k: int, + enable_cuda_graph: bool = False, + max_grid_size: int = 0, + fixed_split_size: int = -1, + disable_split_kv: bool = False, +) -> dict[str, object]: + """GPU-only schedule build: single CUDA kernel produces all schedule + index arrays on device. Only a small summary tensor is D2H'd at the end + so the wrapper can size O_partial, pick the kernel grid, and choose + split/non-split compile path. + + ``max_seqlen_k`` is required as the host-side worst-case bound for + padding the work-tile arrays. + """ + + ( + request_indices, + qo_tile_indices, + kv_tile_indices, + block_valid_mask, + split_counts, + kv_pages, + merge_indptr, + o_indptr, + scalars, + ) = ops.build_decode_schedule( + seqused_k, + int(page_size), + int(seqlen_q), + int(num_qo_heads), + int(num_kv_heads), + int(head_dim), + int(max_seqlen_k), + bool(enable_cuda_graph), + int(max_grid_size), + int(fixed_split_size), + bool(disable_split_kv), + ) + + raw: dict[str, object] = dict(zip(_SCALAR_KEYS, (int(s) for s in scalars))) + raw["split_kv"] = bool(raw["split_kv"]) + raw["request_indices"] = request_indices + raw["qo_tile_indices"] = qo_tile_indices + raw["kv_tile_indices"] = kv_tile_indices + raw["block_valid_mask"] = block_valid_mask + raw["split_counts"] = split_counts + raw["kv_pages"] = kv_pages + raw["merge_indptr"] = merge_indptr + raw["o_indptr"] = o_indptr + + # The CUDA kernel writes into worst-case-padded buffers (size = + # batch * num_q_tiles * max_pages_global) but only the first + # ``padded_work_count`` entries are valid. Downstream consumers + # (tile_scheduler) take grid size from ``request_indices.shape[0]`` + # so we narrow the views to that count; the underlying allocation + # is unchanged so this is a view, no copy. + pad = int(raw["padded_work_count"]) + for key in ( + "request_indices", + "qo_tile_indices", + "kv_tile_indices", + "block_valid_mask", + ): + raw[key] = raw[key].narrow(0, 0, pad) + return raw + + +__all__ = ["build_decode_schedule"] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/combine.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/combine.py new file mode 100644 index 0000000000000000000000000000000000000000..3d308bd26c281e744cc7289b1265d8192c1f39e7 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/combine.py @@ -0,0 +1,680 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""LDGSTS split-KV combine for paged decode attention.""" + +import math +from functools import partial +from typing import Type + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32, Int64, const_expr +from cutlass.cute import FastDivmodDivisor +from cutlass.cute.nvgpu import cpasync + +from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map + + +class SparseDecodeForwardCombine: + """Combine split-KV decode partials with FA-style LDGSTS staging. + + ``mO_partial`` and ``mLSE_partial`` use the split-major padded layout: + ``partial_row = o_indptr[b] + split_idx * q_stride + q_token`` where + ``q_stride = ceil_div(seqlen_q, q_tokens_per_group) * q_tokens_per_group``. + A CTA covers ``tile_m`` flattened ``(q_token, q_head)`` rows and one + ``k_block_size`` slice of D. O_partial and LSE_partial are loaded to SMEM + via ``cpasync.CopyG2SOp`` before the split reduction. + """ + + def __init__( + self, + dtype: Type[cutlass.Numeric], + dtype_partial: Type[cutlass.Numeric], + head_dim: int, + *, + tile_m: int = 64, + k_block_size: int = 128, + max_splits: int = 4, + num_threads: int = 256, + stages: int = 2, + ): + if head_dim != 128: + raise NotImplementedError( + f"SparseDecodeForwardCombine currently supports only D=128, got D={head_dim}" + ) + if dtype not in [cutlass.BFloat16, cutlass.Float16, cutlass.Float32]: + raise TypeError(f"Unsupported output dtype: {dtype}") + if dtype_partial is not Float32: + raise TypeError("decode O_partial must be Float32") + if k_block_size != head_dim: + raise NotImplementedError("decode combine currently uses one D=128 k block") + if tile_m % 8 != 0: + raise ValueError("decode combine tile_m must be divisible by 8") + if max_splits < 1 or max_splits > 256: + raise ValueError("decode combine max_splits must be in [1, 256]") + + self.dtype = dtype + self.dtype_partial = dtype_partial + self.head_dim = head_dim + self.tile_m = tile_m + self.k_block_size = k_block_size + self.max_splits = max_splits + self.num_threads = num_threads + self.stages = stages + self.is_even_k = head_dim % k_block_size == 0 + + def _setup_attributes(self) -> None: + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype_partial.width + assert self.k_block_size % async_copy_elems == 0 + + k_block_gmem = ( + 128 + if self.k_block_size % 128 == 0 + else (64 if self.k_block_size % 64 == 0 else 32) + ) + gmem_threads_per_row = k_block_gmem // async_copy_elems + assert self.num_threads % gmem_threads_per_row == 0 + + atom_async_copy_partial = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype_partial, + num_bits_per_copy=universal_copy_bits, + ) + tOpartial_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + vOpartial_layout = cute.make_layout((1, async_copy_elems)) + self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv( + atom_async_copy_partial, tOpartial_layout, vOpartial_layout + ) + + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=async_copy_elems * self.dtype.width, + ) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, tOpartial_layout, vOpartial_layout + ) + + lse_copy_bits = Float32.width + m_block_smem = ( + 128 + if self.tile_m % 128 == 0 + else ( + 64 + if self.tile_m % 64 == 0 + else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8)) + ) + ) + gmem_threads_per_row_lse = m_block_smem + assert self.num_threads % gmem_threads_per_row_lse == 0 + + atom_async_copy_lse = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + Float32, + num_bits_per_copy=lse_copy_bits, + ) + tLSE_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse), + order=(1, 0), + ) + self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + atom_async_copy_lse, tLSE_layout, cute.make_layout(1) + ) + + self.smem_threads_per_col_lse = self.num_threads // m_block_smem + assert 32 % self.smem_threads_per_col_lse == 0 + s2r_layout_atom_lse = cute.make_ordered_layout( + (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse), + order=(0, 1), + ) + self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), + s2r_layout_atom_lse, + cute.make_layout(1), + ) + + if const_expr(m_block_smem == 8): + smem_lse_swizzle = cute.make_swizzle(5, 0, 5) + elif const_expr(m_block_smem == 16): + smem_lse_swizzle = cute.make_swizzle(4, 0, 4) + else: + smem_lse_swizzle = cute.make_swizzle(3, 2, 3) + lse_atom_splits = min(self.max_splits, 8) + smem_layout_atom_lse = cute.make_composed_layout( + smem_lse_swizzle, + 0, + cute.make_ordered_layout((lse_atom_splits, m_block_smem), order=(1, 0)), + ) + self.smem_layout_lse = cute.tile_to_shape( + smem_layout_atom_lse, (self.max_splits, self.tile_m), (0, 1) + ) + self.smem_layout_o = cute.make_ordered_layout( + (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2) + ) + + @cute.jit + def __call__( + self, + mO_partial: cute.Tensor, # [partial_rows, Hq, D] fp32 + mLSE_partial: cute.Tensor, # [partial_rows, Hq] fp32 + mSplitCounts: cute.Tensor, # [B] int32 + mOIndptr: cute.Tensor, # [B + 1] int32 + mO: cute.Tensor, # [total_q, Hq, D] + mLSE: cute.Tensor, # [total_q, Hq] fp32 + seqlen_q: Int32, + q_tokens_per_group: Int32, + stream: cuda.CUstream = None, + ): + if const_expr(mO_partial.element_type is not Float32): + raise TypeError("decode O_partial tensor must be Float32") + if const_expr(mLSE_partial.element_type is not Float32): + raise TypeError("decode LSE_partial tensor must be Float32") + if const_expr(mLSE.element_type is not Float32): + raise TypeError("decode LSE tensor must be Float32") + if const_expr(mO.element_type != self.dtype): + raise TypeError("decode O tensor dtype must match kernel dtype") + if const_expr(mSplitCounts.element_type is not Int32): + raise TypeError("decode split_counts tensor must be Int32") + if const_expr(mOIndptr.element_type is not Int32): + raise TypeError("decode o_indptr tensor must be Int32") + + mO_partial, mLSE_partial, mSplitCounts, mOIndptr, mO, mLSE = [ + assume_tensor_aligned(t) + for t in (mO_partial, mLSE_partial, mSplitCounts, mOIndptr, mO, mLSE) + ] + self._setup_attributes() + + @cute.struct + class SharedStorage: + sLSE: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sMaxValidSplit: cute.struct.Align[ + cute.struct.MemRange[Int32, self.tile_m], 128 + ] + sO: cute.struct.Align[ + cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 + ] + + total_q = mO.shape[0] + head_q = mO.shape[1] + batch = mSplitCounts.shape[0] + head_divmod = FastDivmodDivisor(head_q) + grid = ( + cute.ceil_div(seqlen_q * head_q, self.tile_m), + cute.ceil_div(self.head_dim, self.k_block_size), + batch, + ) + + self.kernel( + mO_partial, + mLSE_partial, + mSplitCounts, + mOIndptr, + mO, + mLSE, + SharedStorage, + self.smem_layout_lse, + self.smem_layout_o, + self.gmem_tiled_copy_O_partial, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_LSE, + self.s2r_tiled_copy_LSE, + head_divmod, + Int32(total_q), + Int32(head_q), + seqlen_q, + q_tokens_per_group, + ).launch( + grid=grid, + block=[self.num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mSplitCounts: cute.Tensor, + mOIndptr: cute.Tensor, + mO: cute.Tensor, + mLSE: cute.Tensor, + SharedStorage: cutlass.Constexpr, + smem_layout_lse: cute.Layout | cute.ComposedLayout, + smem_layout_o: cute.Layout, + gmem_tiled_copy_O_partial: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + s2r_tiled_copy_LSE: cute.TiledCopy, + head_divmod: FastDivmodDivisor, + total_q: Int32, + head_q: Int32, + seqlen_q: Int32, + q_tokens_per_group: Int32, + ): + tidx, _, _ = cute.arch.thread_idx() + m_block, k_block, batch_idx = cute.arch.block_idx() + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sLSE = storage.sLSE.get_tensor(smem_layout_lse) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,)) + sO = storage.sO.get_tensor(smem_layout_o) + + split_count = mSplitCounts[batch_idx] + q_stride = ( + (seqlen_q + q_tokens_per_group - Int32(1)) + // q_tokens_per_group + ) * q_tokens_per_group + max_idx = seqlen_q * head_q + + if m_block * Int32(self.tile_m) < max_idx: + gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) + tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) + cLSE = cute.make_identity_tensor((self.max_splits, self.tile_m)) + tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) + + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] + idx = m_block * Int32(self.tile_m) + mi + if idx < max_idx: + q_idx, q_head = divmod(idx, head_divmod) + partial_base = mOIndptr[batch_idx] + q_idx + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] + if si < split_count: + partial_row = partial_base + si * q_stride + lse_ptr = ( + mLSE_partial.iterator + + Int64(partial_row) * Int64(head_q) + + Int64(q_head) + ) + lse_gmem_ptr = cute.make_ptr( + Float32, + lse_ptr.toint(), + cute.AddressSpace.gmem, + assumed_align=4, + ) + lse_src = cute.make_tensor(lse_gmem_ptr, (1,)) + cute.copy( + gmem_thr_copy_LSE, + lse_src, + tLSEsLSE[None, s, m], + ) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + else: + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + tLSEsLSE[None, s, m].fill(-Float32.inf) + cute.arch.cp_async_commit_group() + + gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) + cO = cute.make_identity_tensor((self.tile_m, self.k_block_size)) + tOcO = gmem_thr_copy_O_partial.partition_D(cO) + tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) + + num_rows = const_expr(cute.size(tOcO, mode=[1])) + tOqidx = cute.make_rmem_tensor(num_rows, Int32) + tOhidx = cute.make_rmem_tensor(num_rows, Int32) + for m in cutlass.range(num_rows, unroll_full=True): + mi = tOcO[0, m, 0][0] + idx = m_block * Int32(self.tile_m) + mi + if idx >= max_idx: + tOqidx[m] = Int32(0) + tOhidx[m] = -Int32(1) + else: + tOqidx[m], tOhidx[m] = divmod(idx, head_divmod) + + load_O_partial = partial( + self.load_O_partial, + mO_partial, + mOIndptr, + gmem_tiled_copy_O_partial, + tOsO_partial, + tOqidx, + tOhidx, + tOcO, + batch_idx, + q_stride, + split_count, + head_q, + k_block, + ) + + for stage in cutlass.range(self.stages - 1, unroll_full=True): + if stage < split_count: + load_O_partial(stage, stage) + cute.arch.cp_async_commit_group() + + cute.arch.cp_async_wait_group(self.stages - 1) + cute.arch.sync_threads() + + s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) + ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) + ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE) + cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) + + lse_sum = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32) + ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) + max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32) + assert cute.size(ts2rrLSE, mode=[0]) == 1 + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_max = cute.arch.warp_reduction_max( + ts2rrLSE[None, None, m] + .load() + .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + threads_in_group=threads_per_col, + ) + max_valid_idx = -Int32(1) + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + if ts2rrLSE[0, s, m] != -Float32.inf: + max_valid_idx = ts2rcLSE[0, s, 0][0] + max_valid_split[m] = cute.arch.warp_reduction_max( + max_valid_idx, threads_in_group=threads_per_col + ) + + lse_max_cur = Float32(0.0) if lse_max == -Float32.inf else lse_max + LOG2_E = Float32(math.log2(math.e)) + lse_sum_cur = Float32(0.0) + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + scale = cute.math.exp2( + (ts2rrLSE[0, s, m] - lse_max_cur) * LOG2_E, + fastmath=True, + ) + lse_sum_cur += scale + ts2rrLSE[0, s, m] = scale + lse_sum_cur = cute.arch.warp_reduction_sum( + lse_sum_cur, threads_in_group=threads_per_col + ) + lse_sum[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max + inv_sum = ( + Float32(0.0) + if (lse_sum_cur == Float32(0.0) or lse_sum_cur != lse_sum_cur) + else cute.arch.rcp_approx(lse_sum_cur) + ) + ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) + cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) + + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == Int32(0): + mi = ts2rcLSE[0, 0, m][1] + if mi < Int32(self.tile_m): + sMaxValidSplit[mi] = max_valid_split[m] + + if k_block == Int32(0): + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == Int32(0): + mi = ts2rcLSE[0, 0, m][1] + idx = m_block * Int32(self.tile_m) + mi + if idx < max_idx: + q_idx, q_head = divmod(idx, head_divmod) + q_abs = batch_idx * seqlen_q + q_idx + mLSE[q_abs, q_head] = lse_sum[m] + + cute.arch.sync_threads() + + thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] + for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True): + thr_max_valid_split = max( + thr_max_valid_split, + sMaxValidSplit[tOcO[0, m, 0][0]], + ) + + tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32) + tOrO.fill(Float32(0.0)) + + stage_load = self.stages - 1 + stage_compute = 0 + for s in cutlass.range(thr_max_valid_split + Int32(1), unroll=4): + scale = cute.make_rmem_tensor(num_rows, Float32) + for m in cutlass.range(num_rows, unroll_full=True): + scale[m] = sLSE[s, tOcO[0, m, 0][0]] + + split_to_load = s + Int32(self.stages - 1) + if split_to_load <= thr_max_valid_split: + load_O_partial(split_to_load, stage_load) + cute.arch.cp_async_commit_group() + stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1 + + cute.arch.cp_async_wait_group(self.stages - 1) + cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial) + stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1 + + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= Int32(0) and scale[m] > Float32(0.0): + tOrO[None, m, None].store( + tOrO[None, m, None].load() + + scale[m] * tOrO_partial[None, m, None].load().to(Float32) + ) + + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + + rO = cute.make_rmem_tensor_like(tOrO, self.dtype) + rO.store(tOrO.load().to(self.dtype)) + elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1])) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= Int32(0): + q_abs = batch_idx * seqlen_q + tOqidx[m] + row_ptr = ( + mO.iterator + + ( + (Int64(q_abs) * Int64(head_q) + Int64(tOhidx[m])) + * Int64(self.head_dim) + + Int64(k_block * Int32(self.k_block_size)) + ) + ) + row_gmem_ptr = cute.make_ptr( + mO.element_type, + row_ptr.toint(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + mO_row = cute.make_tensor( + row_gmem_ptr, + cute.make_layout((self.k_block_size,)), + ) + mO_row_copy = cute.tiled_divide(mO_row, (elems_per_store,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_store + cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_row_copy[None, k_idx]) + + @cute.jit + def load_O_partial( + self, + mO_partial: cute.Tensor, + mOIndptr: cute.Tensor, + gmem_tiled_copy_O_partial: cute.TiledCopy, + tOsO_partial: cute.Tensor, + tOqidx: cute.Tensor, + tOhidx: cute.Tensor, + tOcO: cute.Tensor, + batch_idx: Int32, + q_stride: Int32, + split_count: Int32, + head_q: Int32, + k_block: Int32, + split: Int32, + stage: Int32, + ) -> None: + elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1])) + tOsO_partial_cur = tOsO_partial[None, None, None, stage] + for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True): + if tOhidx[m] >= Int32(0): + if split < split_count: + partial_row = mOIndptr[batch_idx] + split * q_stride + tOqidx[m] + row_ptr = ( + mO_partial.iterator + + ( + (Int64(partial_row) * Int64(head_q) + Int64(tOhidx[m])) + * Int64(self.head_dim) + + Int64(k_block * Int32(self.k_block_size)) + ) + ) + row_gmem_ptr = cute.make_ptr( + mO_partial.element_type, + row_ptr.toint(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + mO_partial_row = cute.make_tensor( + row_gmem_ptr, + cute.make_layout((self.k_block_size,)), + ) + mO_partial_row_copy = cute.tiled_divide( + mO_partial_row, (elems_per_load,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_load + cute.copy( + gmem_tiled_copy_O_partial, + mO_partial_row_copy[None, k_idx], + tOsO_partial_cur[None, m, k], + ) + else: + tOsO_partial_cur[None, m, None].fill(Float32(0.0)) + + +_combine_compile_cache: dict[tuple[object, ...], object] = {} + + +def _next_power_of_2(x: int) -> int: + return 1 << (max(int(x), 1) - 1).bit_length() + + +def run_decode_combine( + O_partial: torch.Tensor, + LSE_partial: torch.Tensor, + split_counts: torch.Tensor, + o_indptr: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + *, + seqlen_q: int, + q_tokens_per_group: int, + max_split_count: int, +) -> None: + """Launch LDGSTS decode split-KV combine.""" + + if O_partial.dtype != torch.float32: + raise TypeError(f"O_partial must be torch.float32, got {O_partial.dtype}") + if LSE_partial.dtype != torch.float32: + raise TypeError(f"LSE_partial must be torch.float32, got {LSE_partial.dtype}") + if lse.dtype != torch.float32: + raise TypeError(f"lse must be torch.float32, got {lse.dtype}") + if split_counts.dtype != torch.int32: + raise TypeError(f"split_counts must be torch.int32, got {split_counts.dtype}") + if o_indptr.dtype != torch.int32: + raise TypeError(f"o_indptr must be torch.int32, got {o_indptr.dtype}") + if out.ndim != 3 or O_partial.ndim != 3: + raise ValueError("decode combine expects O tensors with shape [rows, heads, D]") + if LSE_partial.ndim != 2 or lse.ndim != 2: + raise ValueError("decode combine expects LSE tensors with shape [rows, heads]") + if out.shape[1:] != O_partial.shape[1:]: + raise ValueError(f"O shape mismatch: out={out.shape}, O_partial={O_partial.shape}") + if lse.shape != out.shape[:2]: + raise ValueError(f"lse shape {lse.shape} must match out[:2] {out.shape[:2]}") + if LSE_partial.shape != O_partial.shape[:2]: + raise ValueError( + f"LSE_partial shape {LSE_partial.shape} must match O_partial[:2] {O_partial.shape[:2]}" + ) + if split_counts.ndim != 1 or o_indptr.ndim != 1: + raise ValueError("split_counts and o_indptr must be rank-1 tensors") + if o_indptr.shape != (split_counts.shape[0] + 1,): + raise ValueError( + f"o_indptr shape {o_indptr.shape} must be ({split_counts.shape[0] + 1},)" + ) + seqlen_q = int(seqlen_q) + q_tokens_per_group = int(q_tokens_per_group) + if seqlen_q <= 0: + raise ValueError("seqlen_q must be positive") + if q_tokens_per_group <= 0: + raise ValueError("q_tokens_per_group must be positive") + if out.shape[0] != split_counts.shape[0] * seqlen_q: + raise ValueError( + f"out rows {out.shape[0]} must equal batch*seqlen_q " + f"{split_counts.shape[0]}*{seqlen_q}" + ) + + max_split_count = int(max_split_count) + if max_split_count <= 0: + raise ValueError("max_split_count must be positive") + if max_split_count > 256: + raise NotImplementedError( + f"LDGSTS decode combine supports at most 256 splits, got {max_split_count}" + ) + max_splits = max(4, _next_power_of_2(max_split_count)) + tile_m = 64 + k_block_size = int(out.shape[-1]) + stages = 2 + + dtype = torch2cute_dtype_map[out.dtype] + key = ( + "decode_combine_ldgsts", + out.shape[-1], + dtype, + O_partial.dtype, + seqlen_q, + q_tokens_per_group, + tile_m, + k_block_size, + max_splits, + stages, + ) + if key not in _combine_compile_cache: + from ....quack.compile_utils import make_fake_tensor + + total_q = cute.sym_int64() + batch = cute.sym_int64() + batch_plus_one = cute.sym_int64() + partial_rows = cute.sym_int64() + head_q = cute.sym_int64() + head_dim = int(out.shape[-1]) + kernel = SparseDecodeForwardCombine( + dtype=dtype, + dtype_partial=Float32, + head_dim=head_dim, + tile_m=tile_m, + k_block_size=k_block_size, + max_splits=max_splits, + stages=stages, + ) + _combine_compile_cache[key] = cute.compile( + kernel, + make_fake_tensor(Float32, (partial_rows, head_q, head_dim), divisibility=4), + make_fake_tensor(Float32, (partial_rows, head_q), divisibility=1, leading_dim=1), + make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0), + make_fake_tensor(Int32, (batch_plus_one,), divisibility=1, leading_dim=0), + make_fake_tensor(dtype, (total_q, head_q, head_dim), divisibility=128 // dtype.width), + make_fake_tensor(Float32, (total_q, head_q), divisibility=1, leading_dim=1), + Int32(seqlen_q), + Int32(q_tokens_per_group), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + with torch.cuda.nvtx.range("Decode_Combine_LDGSTS"): + _combine_compile_cache[key]( + O_partial, + LSE_partial, + split_counts, + o_indptr, + out, + lse, + seqlen_q, + q_tokens_per_group, + ) + + +__all__ = ["SparseDecodeForwardCombine", "run_decode_combine"] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..13b487402bf52d008b7ff7edbe9d584f366256b9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py @@ -0,0 +1,300 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Decode-specific tile scheduler for paged fp8 attention. + +The pre-schedule step builds a dense worklist over decode KV chunks. Static +persistent scheduling walks a flattened ``(work_idx, head_kv_idx)`` task id. +CLC scheduling keeps BSA's hardware grid shape, ``(work_idx, head_kv_idx, 1)``, +and maps the canceled CTA coordinate back to the same logical task space. +""" + +from dataclasses import dataclass +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr +from cutlass.cute import FastDivmodDivisor + +from ....quack.cute_dsl_utils import ParamsBase + +from ....src.common.tile_scheduler import SchedulingMode, WorkTileInfo + + +@dataclass +class DecodeTileSchedulerArguments(ParamsBase): + work_capacity: Int32 + num_heads_kv: Int32 + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + + +class DecodeTileScheduler: + """Persistent scheduler over decode ``(work_idx, head_kv_idx)`` tasks.""" + + @dataclass + class Params(ParamsBase): + work_capacity: Int32 + num_heads_kv: Int32 + num_heads_kv_divmod: FastDivmodDivisor + total_tasks: Int32 + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + + def __init__( + self, + params: Params, + task_idx: Int32, + clc_scheduler=None, + clc_pipeline=None, + clc_consumer_state=None, + clc_response_ptr=None, + *, + loc=None, + ip=None, + ): + self.params = params + self._task_idx = task_idx + self._clc_scheduler = clc_scheduler + self._clc_pipeline = clc_pipeline + self._clc_consumer_state = clc_consumer_state + self._clc_response_ptr = clc_response_ptr + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: DecodeTileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert args.cluster_shape_mn[1] == 1, "Decode scheduler requires cluster N == 1" + total_tasks = args.work_capacity * args.num_heads_kv + return DecodeTileScheduler.Params( + args.work_capacity, + args.num_heads_kv, + FastDivmodDivisor(args.num_heads_kv), + total_tasks, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + ) + + @staticmethod + def _clc_grid_shape(params: Params): + return ( + cute.round_up(params.work_capacity, params.cluster_shape_m), + params.num_heads_kv, + Int32(1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, + clc_response_ptr=None, + *, + loc=None, + ip=None, + ) -> "DecodeTileScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + from cutlass.utils import ( + ClcDynamicPersistentTileScheduler, + ClcDynamicPersistentTileSchedulerParams, + ) + + cutlass_params = ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=DecodeTileScheduler._clc_grid_shape(params), + cluster_shape_mnk=(params.cluster_shape_m, 1, 1), + ) + block_idx = cute.arch.block_idx() + grid_dim = cute.arch.grid_dim() + clc_scheduler = ClcDynamicPersistentTileScheduler.create( + cutlass_params, + block_idx, + grid_dim, + clc_response_ptr, + ) + return DecodeTileScheduler( + params, + block_idx[0], + clc_scheduler, + clc_response_ptr=clc_response_ptr, + loc=loc, + ip=ip, + ) + + if const_expr(params.cluster_shape_m == 1): + task_idx = cute.arch.block_idx()[0] + else: + task_idx = cute.arch.cluster_idx()[0] + return DecodeTileScheduler(params, task_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return DecodeTileScheduler._clc_grid_shape(params) + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m + grid_x = cutlass.min(max_ctas, params.total_tasks * params.cluster_shape_m) + return (grid_x, Int32(1), Int32(1)) + + @cute.jit + def _task_to_work(self, task_idx: Int32, is_valid) -> WorkTileInfo: + work_idx, head_kv_idx = divmod(task_idx, self.params.num_heads_kv_divmod) + return WorkTileInfo( + (Int32(work_idx), Int32(head_kv_idx), Int32(0), Int32(0)), + is_valid, + ) + + @cute.jit + def _clc_work_to_coords(self, work) -> WorkTileInfo: + work_idx = work.tile_idx[0] + if const_expr(self.params.cluster_shape_m > 1): + work_idx = work_idx // self.params.cluster_shape_m + return WorkTileInfo( + ( + Int32(work_idx), + Int32(work.tile_idx[1]), + Int32(0), + Int32(0), + ), + work.is_valid_tile, + ) + + @cute.jit + def _clc_response_to_work( + self, + response_stage: Int32, + *, + loc=None, + ip=None, + ) -> WorkTileInfo: + # CLC responses are 16B opaque records. The scheduler warp can query + # the next stage before all consumer warps have read the current one, + # so each pipeline stage needs its own response slot. + response_ptr = self._clc_response_ptr + response_stage * Int32(4) + m_idx, n_idx, l_idx, is_valid = cute.arch.clc_response( + response_ptr, loc=loc, ip=ip) + cute.arch.fence_proxy("async.shared", space="cta") + cta_idx_in_cluster = cute.arch.block_idx()[0] % Int32( + self.params.cluster_shape_m) + return WorkTileInfo( + ( + Int32(m_idx) + cta_idx_in_cluster, + Int32(n_idx), + Int32(l_idx), + Int32(0), + ), + is_valid, + ) + + @cute.jit + def get_current_work( + self, + response_stage: Int32 = Int32(0), + *, + loc=None, + ip=None, + ) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self._clc_response_to_work( + response_stage, loc=loc, ip=ip) + self._task_idx = ( + work.tile_idx[0] * self.params.num_heads_kv + + work.tile_idx[1] + ) + return self._clc_work_to_coords(work) + is_valid = self._task_idx < self.params.total_tasks + return self._task_to_work(self._task_idx, is_valid) + + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self._clc_scheduler.initial_work_tile_info() + self._task_idx = ( + work.tile_idx[0] * self.params.num_heads_kv + + work.tile_idx[1] + ) + return self._clc_work_to_coords(work) + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work( + self, + *, + loc=None, + ip=None, + mbarrier_addr=None, + response_stage: Int32 = Int32(0), + ): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + assert mbarrier_addr is not None + response_ptr = self._clc_response_ptr + response_stage * Int32(4) + with cute.arch.elect_one(): + cute.arch.issue_clc_query( + mbarrier_addr, response_ptr, loc=loc, ip=ip) + else: + assert mbarrier_addr is None + if const_expr(self.params.cluster_shape_m == 1): + self._task_idx += cute.arch.grid_dim()[0] + else: + self._task_idx += cute.arch.cluster_dim()[0] + + def consumer_advance(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + response_stage = self._clc_consumer_state.index + self._clc_pipeline.consumer_wait(self._clc_consumer_state) + work_tile = self.get_current_work(response_stage=response_stage) + self._clc_pipeline.consumer_release(self._clc_consumer_state) + self._clc_consumer_state.advance() + return work_tile + self.advance_to_next_work() + return self.get_current_work() + + def set_clc_pipeline(self, clc_pipeline, clc_consumer_state): + self._clc_pipeline = clc_pipeline + self._clc_consumer_state = clc_consumer_state + + def producer_tail(self, *, loc=None, ip=None): + pass + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + objs = [self.params, self._task_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [ + self._clc_scheduler, + self._clc_pipeline, + self._clc_consumer_state, + self._clc_response_ptr, + ] + for obj in objs: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + objs = [self.params, self._task_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [ + self._clc_scheduler, + self._clc_pipeline, + self._clc_consumer_state, + self._clc_response_ptr, + ] + for obj, n_items in zip(objs, self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return DecodeTileScheduler(*obj_list, loc=self._loc) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/prepare_k2q_csr.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/prepare_k2q_csr.py new file mode 100644 index 0000000000000000000000000000000000000000..8e59b3d55bd3e9b164dac1e474dd648501c1aa51 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/prepare_k2q_csr.py @@ -0,0 +1,227 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Sparse k2q CSR builder for SM100. + +Thin dispatcher that calls the CUDA C++ kernel pipeline in +``src.sm100.build_k2q_csr``. Supports ``topK in {4, 8, 16, 32}`` and +``blk_kv == 128`` only — other shapes raise ``ValueError`` rather than +silently falling back to a torch-reference path. +""" + +from __future__ import annotations + +from typing import Optional + +import torch + +from ...src.sm100.prepare_scheduler import SparseAttentionSchedule, SPARSE_SCHEDULE_MODEL + + +_SUPPORTED_TOPK = (4, 8, 16, 32) +_SUPPORTED_BLK_KV = 128 + + +def _ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +class SparseK2qCsrBuilderSm100: + """Build the k2q CSR reverse index for sparse attention on SM100. + + The public API matches the historical CUTE DSL builder so callers + (``sparse_index_utils.build_k2q_csr``, attention kernels) need no + changes. Internally the kernel pipeline runs five CUDA C++ kernels: + ``build_row_map`` -> ``hist`` -> ``row_prefix`` -> ``tile_prefix_smem`` + -> ``scatter`` (5 kernels + 2 ``cudaMemsetAsync``). + """ + + def __init__(self) -> None: + # No persistent state — the JIT-compiled extension is loaded + # lazily by ``src.sm100.build_k2q_csr`` on first call. + self._run = None + self._run_with_schedule = None + + def _ensure_loaded(self) -> None: + if self._run is None: + from ...src.sm100.build_k2q_csr import ( + run_build_k2q_csr, + run_build_k2q_csr_with_schedule, + ) + self._run = run_build_k2q_csr + self._run_with_schedule = run_build_k2q_csr_with_schedule + + def __call__( + self, + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + *, + total_k: int, + blk_kv: int = 128, + max_seqlen_k: Optional[int] = None, + max_seqlen_q: Optional[int] = None, + total_rows: Optional[int] = None, + qhead_per_kv: int = 1, + return_schedule: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, SparseAttentionSchedule]: + # ---- Validation ---------------------------------------------------- + if blk_kv != _SUPPORTED_BLK_KV: + raise ValueError( + f"SparseK2qCsrBuilderSm100 only supports blk_kv == " + f"{_SUPPORTED_BLK_KV}, got {blk_kv}" + ) + if q2k_indices.dtype != torch.int32: + raise TypeError( + f"q2k_indices must be torch.int32, got {q2k_indices.dtype}" + ) + if q2k_indices.ndim != 3: + raise ValueError( + f"q2k_indices must be rank-3 [head_kv, total_q, topK], " + f"got shape {tuple(q2k_indices.shape)}" + ) + if not q2k_indices.is_contiguous(): + raise ValueError("q2k_indices must be contiguous") + if cu_seqlens_q.dtype != torch.int32 or cu_seqlens_k.dtype != torch.int32: + raise TypeError("cu_seqlens_q and cu_seqlens_k must be torch.int32") + if cu_seqlens_q.ndim != 1 or cu_seqlens_k.ndim != 1: + raise ValueError("cu_seqlens_q and cu_seqlens_k must be rank-1") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError( + "cu_seqlens_q and cu_seqlens_k must share shape [B + 1]" + ) + if not (q2k_indices.is_cuda and cu_seqlens_q.is_cuda and cu_seqlens_k.is_cuda): + raise ValueError("all inputs must be CUDA tensors") + if ( + q2k_indices.device != cu_seqlens_q.device + or q2k_indices.device != cu_seqlens_k.device + ): + raise ValueError("all inputs must share a device") + if not cu_seqlens_q.is_contiguous() or not cu_seqlens_k.is_contiguous(): + raise ValueError("cu_seqlens_q and cu_seqlens_k must be contiguous") + + total_k = int(total_k) + if total_k < 0: + raise ValueError(f"total_k must be non-negative, got {total_k}") + + head_kv, total_q, topk = (int(v) for v in q2k_indices.shape) + if topk not in _SUPPORTED_TOPK: + raise ValueError( + f"SparseK2qCsrBuilderSm100 only supports topK in " + f"{_SUPPORTED_TOPK}, got {topk}" + ) + + batch = int(cu_seqlens_q.shape[0] - 1) + if batch < 0: + raise ValueError("cu_seqlens tensors must have shape [B + 1]") + if return_schedule and max_seqlen_k is None: + raise ValueError("build_k2q_csr requires max_seqlen_k when return_schedule=True") + max_k_tokens = int(max_seqlen_k) if max_seqlen_k is not None else total_k + max_kv_blocks = _ceil_div(max(max_k_tokens, blk_kv), blk_kv) + if total_rows is not None: + total_rows = int(total_rows) + elif total_k % blk_kv == 0: + total_rows = total_k // blk_kv + else: + total_rows = _ceil_div(total_k + batch * (blk_kv - 1), blk_kv) + if total_rows < 0: + raise ValueError(f"total_rows must be non-negative, got {total_rows}") + total_rows = max(total_rows, 0) + nnz_upper_bound = total_q * topk + qhead_per_kv = int(qhead_per_kv) + if qhead_per_kv <= 0: + raise ValueError(f"qhead_per_kv must be positive, got {qhead_per_kv}") + if return_schedule: + if max_seqlen_q is None: + raise ValueError("build_k2q_csr requires max_seqlen_q when return_schedule=True") + max_seqlen_q = int(max_seqlen_q) + + # ---- Output tensors ------------------------------------------------ + device = q2k_indices.device + k2q_row_ptr = torch.empty( + (head_kv, total_rows + 1), dtype=torch.int32, device=device, + ) + k2q_q_indices = torch.empty( + (head_kv, nnz_upper_bound), dtype=torch.int32, device=device, + ) + schedule = None + if return_schedule: + target_q_per_cta = SPARSE_SCHEDULE_MODEL.balanced_target_q_per_cta( + total_q=total_q, + topk=topk, + blk_kv=blk_kv, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + ) + scheduler_metadata_capacity = SPARSE_SCHEDULE_MODEL.flat_schedule_capacity( + total_rows=total_rows, + total_q=total_q, + topk=topk, + head_kv=head_kv, + target_q_per_cta=target_q_per_cta, + ) + scheduler_metadata = torch.empty( + (scheduler_metadata_capacity, 6), dtype=torch.int32, device=device + ) + work_count = torch.empty((1,), dtype=torch.int32, device=device) + qsplit_indices = torch.empty_like(k2q_q_indices) + split_counts = torch.empty( + (total_q, head_kv), dtype=torch.int32, device=device + ) + schedule = SparseAttentionSchedule( + enabled=True, + scheduler_metadata=scheduler_metadata, + work_count=work_count, + qsplit_indices=qsplit_indices, + split_counts=split_counts, + target_q_per_cta=target_q_per_cta, + ) + + # Empty workload short-circuit (the CUDA path also handles this, + # but doing it here saves a JIT load for trivial calls). + if total_rows == 0 or total_q == 0 or head_kv == 0 or topk == 0: + k2q_row_ptr.zero_() + k2q_q_indices.fill_(-1) + if schedule is not None: + schedule.work_count.zero_() + schedule.split_counts.zero_() + return k2q_row_ptr, k2q_q_indices, schedule + return k2q_row_ptr, k2q_q_indices + + self._ensure_loaded() + with torch.cuda.nvtx.range("SparseK2qCsr_Pipeline"): + if schedule is None: + self._run( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + k2q_row_ptr, + k2q_q_indices, + topk, + blk_kv, + total_rows, + max_kv_blocks, + ) + else: + self._run_with_schedule( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + k2q_row_ptr, + k2q_q_indices, + schedule.scheduler_metadata, + schedule.work_count, + schedule.qsplit_indices, + schedule.split_counts, + topk, + blk_kv, + total_rows, + max_kv_blocks, + schedule.target_q_per_cta, + schedule.work_capacity, + max_seqlen_q, + ) + if schedule is not None: + return k2q_row_ptr, k2q_q_indices, schedule + return k2q_row_ptr, k2q_q_indices diff --git a/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/prepare_scheduler.py b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/prepare_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..662e48f905249913a381f5d11a3f0c49626e98bd --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/src/sm100/prepare_scheduler.py @@ -0,0 +1,752 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax +# SPDX-License-Identifier: MIT + +"""Prepare scheduler for SM100 sparse attention. + +The scheduler converts uneven CSR k2q row fanout into a flat worklist consumed +by sparse attention kernels. Each work item covers a contiguous q-index range +within one (head_kv, csr row) and carries the decoded batch/KV-block coordinate. +""" + +from dataclasses import dataclass +from typing import Optional + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Int32, const_expr + +from ...src.common import copy_utils, utils +from ...src.common.cute_dsl_utils import ( + assume_tensor_aligned, + to_cute_tensor as to_cute_tensor_kvouter, +) + + +_PREPARE_COMPILE_CACHE: dict = {} + + +@dataclass +class SparseAttentionSchedule: + enabled: bool + scheduler_metadata: Optional[torch.Tensor] + work_count: Optional[torch.Tensor] + qsplit_indices: Optional[torch.Tensor] = None + split_counts: Optional[torch.Tensor] = None + target_q_per_cta: int = 0 + + @property + def work_capacity(self) -> int: + return 0 if self.scheduler_metadata is None else int(self.scheduler_metadata.shape[0]) + + +SparseSchedulePlan = SparseAttentionSchedule + + +class SparseAttentionScheduleModel: + """Host-side helpers for sparse attention schedule sizing.""" + + @staticmethod + def _round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + @staticmethod + def _ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + def _target_q_per_cta( + self, + *, + total_q: int, + topk: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + usable_SM_count: int = -1, + ) -> int: + num_sm = torch.cuda.get_device_properties(device).multi_processor_count + if usable_SM_count > 0: + num_sm = min(int(usable_SM_count), num_sm) + q_tokens_per_group = 128 // qhead_per_kv + total_refs_upper = total_q * topk * head_kv + desired_work_items = max(num_sm * 2, 1) + total_groups_upper = self._ceil_div(max(total_refs_upper, 1), q_tokens_per_group) + target_groups_per_cta = min( + 512, + max(1, self._ceil_div(total_groups_upper, desired_work_items)), + ) + return target_groups_per_cta * q_tokens_per_group + + def balanced_target_q_per_cta( + self, + *, + total_q: int, + topk: int, + blk_kv: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + usable_SM_count: int = -1, + ) -> int: + q_tokens_per_group = 128 // qhead_per_kv + occupancy_target = self._target_q_per_cta( + total_q=total_q, + topk=topk, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + usable_SM_count=usable_SM_count, + ) + sink_balance_cap = max(q_tokens_per_group, int(topk) * int(blk_kv) * 2) + target = min(max(occupancy_target, q_tokens_per_group), sink_balance_cap) + return self._round_up(target, q_tokens_per_group) + + def flat_schedule_capacity( + self, + *, + total_rows: int, + total_q: int, + topk: int, + head_kv: int, + target_q_per_cta: int, + ) -> int: + row_upper = max(total_rows, 0) * max(head_kv, 1) + refs_upper = max(total_q, 0) * max(topk, 1) * max(head_kv, 1) + split_upper = self._ceil_div(max(refs_upper, 1), max(target_q_per_cta, 1)) + return max(1, row_upper + split_upper) + + +SPARSE_SCHEDULE_MODEL = SparseAttentionScheduleModel() + + +class SparseAttentionPrepareFlatScheduleSm100: + """Build a compact flat worklist by splitting each CSR row into chunks.""" + + def __init__( + self, + *, + num_threads: int = 128, + ): + if num_threads % 32 != 0: + raise ValueError(f"num_threads must be a multiple of 32, got {num_threads}") + self.num_threads = num_threads + self.warps_per_cta = num_threads // 32 + + @cute.jit + def _emit_work( + self, + mSchedulerMetadata: cute.Tensor, + work_idx: Int32, + work_capacity: Int32, + head_kv_idx: Int32, + row_linear: Int32, + q_begin: Int32, + q_count: Int32, + batch_idx: Int32, + kv_block_idx: Int32, + ): + if work_idx < work_capacity: + mSchedulerMetadata[work_idx, Int32(0)] = head_kv_idx + mSchedulerMetadata[work_idx, Int32(1)] = row_linear + mSchedulerMetadata[work_idx, Int32(2)] = q_begin + mSchedulerMetadata[work_idx, Int32(3)] = q_count + mSchedulerMetadata[work_idx, Int32(4)] = batch_idx + mSchedulerMetadata[work_idx, Int32(5)] = kv_block_idx + + @cute.jit + def _rows_in_batch( + self, + mCuSeqlensK: cute.Tensor, + batch_idx: Int32, + blk_kv: Int32, + ) -> Int32: + seqlen = mCuSeqlensK[batch_idx + Int32(1)] - mCuSeqlensK[batch_idx] + return (seqlen + blk_kv - Int32(1)) // blk_kv + + @cute.jit + def _rows_before_level( + self, + mCuSeqlensK: cute.Tensor, + level: Int32, + blk_kv: Int32, + ) -> Int32: + total = Int32(0) + batch = mCuSeqlensK.shape[0] - Int32(1) + for b in cutlass.range(batch, unroll=1): + rows = self._rows_in_batch(mCuSeqlensK, b, blk_kv) + total += cutlass.min(rows, level) + return total + + @cute.jit + def _max_rows_per_batch( + self, + mCuSeqlensK: cute.Tensor, + blk_kv: Int32, + ) -> Int32: + max_rows = Int32(0) + batch = mCuSeqlensK.shape[0] - Int32(1) + for b in cutlass.range(batch, unroll=1): + rows = self._rows_in_batch(mCuSeqlensK, b, blk_kv) + max_rows = cutlass.max(max_rows, rows) + return max_rows + + @cute.jit + def _decode_sparse_row_linear( + self, + mCuSeqlensK: cute.Tensor, + row_linear: Int32, + blk_kv: Int32, + ) -> tuple[Int32, Int32]: + lo = Int32(0) + hi = self._max_rows_per_batch(mCuSeqlensK, blk_kv) + while lo < hi: + mid = (lo + hi) // Int32(2) + rows_before_next = self._rows_before_level( + mCuSeqlensK, + mid + Int32(1), + blk_kv, + ) + if rows_before_next <= row_linear: + lo = mid + Int32(1) + else: + hi = mid + + level = lo + offset = row_linear - self._rows_before_level(mCuSeqlensK, level, blk_kv) + active_idx = Int32(0) + batch_idx = Int32(0) + found = Int32(0) + batch = mCuSeqlensK.shape[0] - Int32(1) + for b in cutlass.range(batch, unroll=1): + if found == Int32(0): + rows = self._rows_in_batch(mCuSeqlensK, b, blk_kv) + if rows > level: + if active_idx == offset: + batch_idx = b + found = Int32(1) + active_idx += Int32(1) + return batch_idx, level + + @cute.jit + def __call__( + self, + mK2qCounts: cute.Tensor, + mCuSeqlensK: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + target_q_per_cta: Int32, + work_capacity: Int32, + num_heads_kv: Int32, + blk_kv: Int32, + stream: cuda.CUstream = None, + ): + if const_expr(mK2qCounts.element_type != Int32): + raise TypeError("mK2qCounts must be Int32") + if const_expr(mCuSeqlensK.element_type != Int32): + raise TypeError("mCuSeqlensK must be Int32") + if const_expr(mSchedulerMetadata.element_type != Int32): + raise TypeError("mSchedulerMetadata must be Int32") + if const_expr(mWorkCount.element_type != Int32): + raise TypeError("mWorkCount must be Int32") + mK2qCounts, mCuSeqlensK, mSchedulerMetadata, mWorkCount = [ + assume_tensor_aligned(t) + for t in (mK2qCounts, mCuSeqlensK, mSchedulerMetadata, mWorkCount) + ] + total_rows = mK2qCounts.shape[1] - Int32(1) + total_row_heads = total_rows * num_heads_kv + grid_ctas = cute.ceil_div(total_row_heads, self.warps_per_cta) + + self.kernel( + mK2qCounts, + mCuSeqlensK, + mSchedulerMetadata, + mWorkCount, + target_q_per_cta, + work_capacity, + num_heads_kv, + total_rows, + blk_kv, + ).launch( + grid=(grid_ctas,), + block=[self.num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mK2qCounts: cute.Tensor, + mCuSeqlensK: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + target_q_per_cta: Int32, + work_capacity: Int32, + num_heads_kv: Int32, + total_rows: Int32, + blk_kv: Int32, + ): + tidx = cute.arch.thread_idx()[0] + block_idx = cute.arch.block_idx()[0] + lane_idx = tidx % Int32(32) + warp_idx = tidx // Int32(32) + row_head_idx = block_idx * Int32(self.warps_per_cta) + warp_idx + total_row_heads = total_rows * num_heads_kv + + head_kv_idx = Int32(0) + row_linear = Int32(0) + row_count = Int32(0) + num_chunks = Int32(0) + batch_idx = Int32(0) + kv_block_idx = Int32(0) + if row_head_idx < total_row_heads: + row_linear = row_head_idx // num_heads_kv + head_kv_idx = row_head_idx - row_linear * num_heads_kv + if lane_idx == Int32(0): + row_start = mK2qCounts[head_kv_idx, row_linear] + row_end = mK2qCounts[head_kv_idx, row_linear + Int32(1)] + row_count = row_end - row_start + batch_idx, kv_block_idx = self._decode_sparse_row_linear( + mCuSeqlensK, + row_linear, + blk_kv, + ) + if row_count > Int32(0): + num_chunks = ( + row_count + target_q_per_cta - Int32(1) + ) // target_q_per_cta + row_count = cute.arch.shuffle_sync(row_count, offset=0) + num_chunks = cute.arch.shuffle_sync(num_chunks, offset=0) + batch_idx = cute.arch.shuffle_sync(batch_idx, offset=0) + kv_block_idx = cute.arch.shuffle_sync(kv_block_idx, offset=0) + + chunk_idx = lane_idx + while chunk_idx < num_chunks: + work_idx = cute.arch.atomic_add( + mWorkCount.iterator.llvm_ptr, + Int32(1), + sem="relaxed", + scope="gpu", + ) + q_begin = chunk_idx * target_q_per_cta + q_count = cutlass.min(target_q_per_cta, row_count - q_begin) + self._emit_work( + mSchedulerMetadata, + work_idx, + work_capacity, + head_kv_idx, + row_linear, + q_begin, + q_count, + batch_idx, + kv_block_idx, + ) + chunk_idx += Int32(32) + + +class SparseAttentionPrepareFwdSplitAtomicSm100: + """Build packed q_idx/split_slot metadata for fwd K1 without K1 atomics.""" + + def __init__( + self, + *, + num_threads: int = 256, + ): + if num_threads % 32 != 0: + raise ValueError(f"num_threads must be a multiple of 32, got {num_threads}") + self.num_threads = num_threads + + @cute.struct + class SharedStorage: + sRow: cute.struct.MemRange[Int32, 3] + + self.shared_storage = SharedStorage + + @cute.jit + def __call__( + self, + mK2qCounts: cute.Tensor, + mK2qIndices: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mSplitCounts: cute.Tensor, + mCuSeqlensQ: cute.Tensor, + work_capacity: Int32, + max_seqlen_q: Int32, + topk: Int32, + stream: cuda.CUstream = None, + ): + if const_expr(mK2qCounts.element_type != Int32): + raise TypeError("mK2qCounts must be Int32") + if const_expr(mK2qIndices.element_type != Int32): + raise TypeError("mK2qIndices must be Int32") + if const_expr(mSchedulerMetadata.element_type != Int32): + raise TypeError("mSchedulerMetadata must be Int32") + if const_expr(mWorkCount.element_type != Int32): + raise TypeError("mWorkCount must be Int32") + if const_expr(mK2qQSplitIndices.element_type != Int32): + raise TypeError("mK2qQSplitIndices must be Int32") + if const_expr(mSplitCounts.element_type != Int32): + raise TypeError("mSplitCounts must be Int32") + if const_expr(mCuSeqlensQ.element_type != Int32): + raise TypeError("mCuSeqlensQ must be Int32") + ( + mK2qCounts, + mK2qIndices, + mSchedulerMetadata, + mWorkCount, + mK2qQSplitIndices, + mSplitCounts, + mCuSeqlensQ, + ) = [ + assume_tensor_aligned(t) + for t in ( + mK2qCounts, + mK2qIndices, + mSchedulerMetadata, + mWorkCount, + mK2qQSplitIndices, + mSplitCounts, + mCuSeqlensQ, + ) + ] + self.kernel( + mK2qCounts, + mK2qIndices, + mSchedulerMetadata, + mWorkCount, + mK2qQSplitIndices, + mSplitCounts, + mCuSeqlensQ, + max_seqlen_q, + topk, + ).launch( + grid=(work_capacity,), + block=[self.num_threads, 1, 1], + smem=self.shared_storage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mK2qCounts: cute.Tensor, + mK2qIndices: cute.Tensor, + mSchedulerMetadata: cute.Tensor, + mWorkCount: cute.Tensor, + mK2qQSplitIndices: cute.Tensor, + mSplitCounts: cute.Tensor, + mCuSeqlensQ: cute.Tensor, + max_seqlen_q: Int32, + topk: Int32, + ): + tidx = cute.arch.thread_idx()[0] + block_idx = cute.arch.block_idx()[0] + if block_idx < mWorkCount[Int32(0)]: + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sRow = storage.sRow.get_tensor(cute.make_layout((3,))) + head_kv_idx = mSchedulerMetadata[block_idx, Int32(0)] + row_linear = mSchedulerMetadata[block_idx, Int32(1)] + q_begin = mSchedulerMetadata[block_idx, Int32(2)] + q_count = mSchedulerMetadata[block_idx, Int32(3)] + batch_idx_t0 = mSchedulerMetadata[block_idx, Int32(4)] + + if tidx == Int32(0): + row_start_t0 = mK2qCounts[head_kv_idx, row_linear] + q_begin + sRow[0] = row_start_t0 + sRow[1] = q_count + sRow[2] = batch_idx_t0 + cute.arch.barrier() + row_start = sRow[0] + row_count = sRow[1] + batch_idx = sRow[2] + qi = tidx + while qi < row_count: + edge = row_start + qi + q_idx = mK2qIndices[head_kv_idx, edge] + if q_idx >= Int32(0) and q_idx < max_seqlen_q: + q_abs = mCuSeqlensQ[batch_idx] + q_idx + split_ptr = utils.elem_pointer( + mSplitCounts, + (q_abs, head_kv_idx), + ) + split_slot = copy_utils.atomic_add_i32(split_ptr) + if split_slot < topk: + mK2qQSplitIndices[head_kv_idx, edge] = ( + q_idx | ((split_slot & Int32(0xFF)) << Int32(24)) + ) + qi += Int32(self.num_threads) + + +def _get_sparse_prepare_fwd_split_atomic( + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + scheduler_metadata: torch.Tensor, + work_count: torch.Tensor, + k2q_qsplit_indices: torch.Tensor, + split_counts: torch.Tensor, + cu_seqlens_q: torch.Tensor, + work_capacity: int, + max_seqlen_q: int, + topk: int, +): + key = ( + "sparse_prepare_fwd_split_atomic_sm100_csr_varlen", + ) + if key not in _PREPARE_COMPILE_CACHE: + from ...src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _PREPARE_COMPILE_CACHE[key] = loaded + else: + kernel = SparseAttentionPrepareFwdSplitAtomicSm100() + _PREPARE_COMPILE_CACHE[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k2q_row_ptr), + to_cute_tensor_kvouter(k2q_q_indices), + to_cute_tensor_kvouter(scheduler_metadata), + to_cute_tensor_kvouter(work_count), + to_cute_tensor_kvouter(k2q_qsplit_indices), + to_cute_tensor_kvouter(split_counts), + to_cute_tensor_kvouter(cu_seqlens_q), + Int32(work_capacity), + Int32(max_seqlen_q), + Int32(topk), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _PREPARE_COMPILE_CACHE[key]) + return _PREPARE_COMPILE_CACHE[key] + + +def _get_sparse_prepare_flat_schedule( + k2q_row_ptr: torch.Tensor, + cu_seqlens_k: torch.Tensor, + scheduler_metadata: torch.Tensor, + work_count: torch.Tensor, + target_q_per_cta: int, + scheduler_metadata_capacity: int, + head_kv: int, + blk_kv: int, +): + key = ( + "sparse_prepare_flat_schedule_sm100_csr_varlen", + ) + if key not in _PREPARE_COMPILE_CACHE: + from ...src.common.aot_cache import try_load_aot, save_aot + + loaded = try_load_aot(key) + if loaded is not None: + _PREPARE_COMPILE_CACHE[key] = loaded + else: + kernel = SparseAttentionPrepareFlatScheduleSm100() + _PREPARE_COMPILE_CACHE[key] = cute.compile( + kernel, + to_cute_tensor_kvouter(k2q_row_ptr), + to_cute_tensor_kvouter(cu_seqlens_k), + to_cute_tensor_kvouter(scheduler_metadata), + to_cute_tensor_kvouter(work_count), + Int32(target_q_per_cta), + Int32(scheduler_metadata_capacity), + Int32(head_kv), + Int32(blk_kv), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + save_aot(key, _PREPARE_COMPILE_CACHE[key]) + return _PREPARE_COMPILE_CACHE[key] + + +def prepare_sparse_flat_schedule( + *, + k2q_row_ptr: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_q: int, + topk: int, + blk_kv: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + enabled: bool, + usable_SM_count: int = -1, +) -> SparseSchedulePlan: + if not enabled: + return SparseSchedulePlan(enabled=False, scheduler_metadata=None, work_count=None) + + total_rows = int(k2q_row_ptr.shape[1] - 1) + if total_rows <= 0 or head_kv <= 0: + return SparseSchedulePlan(enabled=False, scheduler_metadata=None, work_count=None) + if cu_seqlens_k.dtype != torch.int32: + raise TypeError("cu_seqlens_k must be torch.int32") + + target_q_per_cta = SPARSE_SCHEDULE_MODEL.balanced_target_q_per_cta( + total_q=total_q, + topk=topk, + blk_kv=blk_kv, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + usable_SM_count=usable_SM_count, + ) + scheduler_metadata_capacity = SPARSE_SCHEDULE_MODEL.flat_schedule_capacity( + total_rows=total_rows, + total_q=total_q, + topk=topk, + head_kv=head_kv, + target_q_per_cta=target_q_per_cta, + ) + scheduler_metadata = torch.empty( + (scheduler_metadata_capacity, 6), + dtype=torch.int32, + device=device, + ) + work_count = torch.zeros((1,), dtype=torch.int32, device=device) + scheduler_metadata.zero_() + + compiled_prepare = _get_sparse_prepare_flat_schedule( + k2q_row_ptr, + cu_seqlens_k, + scheduler_metadata, + work_count, + target_q_per_cta, + scheduler_metadata_capacity, + head_kv, + blk_kv, + ) + with torch.cuda.nvtx.range("SparseAttention_PrepareFlatSchedule"): + compiled_prepare( + k2q_row_ptr, + cu_seqlens_k, + scheduler_metadata, + work_count, + target_q_per_cta, + scheduler_metadata_capacity, + head_kv, + blk_kv, + ) + + return SparseSchedulePlan( + enabled=True, + scheduler_metadata=scheduler_metadata, + work_count=work_count, + target_q_per_cta=target_q_per_cta, + ) + +def prepare_sparse_fwd_schedule_and_split( + *, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + k2q_qsplit_indices: torch.Tensor, + split_counts: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_q: int, + max_seqlen_q: int, + topk: int, + head_kv: int, + qhead_per_kv: int, + blk_kv: int, + device: torch.device, + enabled: bool, + usable_SM_count: int = -1, +) -> SparseSchedulePlan: + plan = prepare_sparse_fwd_schedule( + k2q_row_ptr=k2q_row_ptr, + cu_seqlens_k=cu_seqlens_k, + total_q=total_q, + topk=topk, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + blk_kv=blk_kv, + device=device, + enabled=enabled, + usable_SM_count=usable_SM_count, + ) + if not plan.enabled: + return plan + if plan.scheduler_metadata is None or plan.work_count is None: + raise RuntimeError("fwd GPU schedule requires metadata") + if topk > 255: + raise ValueError(f"packed qsplit metadata supports topK <= 255, got {topk}") + if max_seqlen_q >= (1 << 24): + raise ValueError( + "packed qsplit metadata supports batch-local q_idx < 2^24, " + f"got max_seqlen_q={max_seqlen_q}" + ) + if k2q_qsplit_indices.shape != k2q_q_indices.shape: + raise ValueError("k2q_qsplit_indices shape must match k2q_q_indices") + if split_counts.dtype != torch.int32 or k2q_qsplit_indices.dtype != torch.int32: + raise TypeError("split metadata tensors must be torch.int32") + if split_counts.shape != (total_q, head_kv): + raise ValueError( + f"split_counts must have shape ({total_q}, {head_kv}), got {tuple(split_counts.shape)}" + ) + if cu_seqlens_q.dtype != torch.int32: + raise TypeError("cu_seqlens_q must be torch.int32") + if cu_seqlens_q.ndim != 1 or not cu_seqlens_q.is_contiguous(): + raise ValueError("cu_seqlens_q must be a contiguous rank-1 tensor") + if cu_seqlens_k.dtype != torch.int32: + raise TypeError("cu_seqlens_k must be torch.int32") + + with torch.cuda.nvtx.range("SparseAttention_InitFwdSplitState"): + split_counts.zero_() + + compiled_split = _get_sparse_prepare_fwd_split_atomic( + k2q_row_ptr, + k2q_q_indices, + plan.scheduler_metadata, + plan.work_count, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + plan.work_capacity, + max_seqlen_q, + topk, + ) + with torch.cuda.nvtx.range("SparseAttention_PrepareFwdSplit_Atomic"): + compiled_split( + k2q_row_ptr, + k2q_q_indices, + plan.scheduler_metadata, + plan.work_count, + k2q_qsplit_indices, + split_counts, + cu_seqlens_q, + plan.work_capacity, + max_seqlen_q, + topk, + ) + plan.qsplit_indices = k2q_qsplit_indices + plan.split_counts = split_counts + return plan + + +def prepare_sparse_fwd_schedule( + *, + k2q_row_ptr: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_q: int, + topk: int, + blk_kv: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + enabled: bool, + usable_SM_count: int = -1, +) -> SparseSchedulePlan: + return prepare_sparse_flat_schedule( + k2q_row_ptr=k2q_row_ptr, + cu_seqlens_k=cu_seqlens_k, + total_q=int(total_q), + topk=int(topk), + blk_kv=int(blk_kv), + head_kv=int(head_kv), + qhead_per_kv=int(qhead_per_kv), + device=device, + enabled=bool(enabled), + usable_SM_count=int(usable_SM_count), + )