kernels-bot committed on
Commit
a022715
·
verified ·
1 Parent(s): 48b1bf2

Uploaded using `kernel-builder`.

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build/torch211-cxx11-cu128-x86_64-linux/__init__.py +60 -0
  2. build/torch211-cxx11-cu128-x86_64-linux/_msa_cuda_09d7851.abi3.so +3 -0
  3. build/torch211-cxx11-cu128-x86_64-linux/_ops.py +9 -0
  4. build/torch211-cxx11-cu128-x86_64-linux/fp4_indexer_interface.py +1061 -0
  5. build/torch211-cxx11-cu128-x86_64-linux/interface.py +2011 -0
  6. build/torch211-cxx11-cu128-x86_64-linux/metadata.json +71 -0
  7. build/torch211-cxx11-cu128-x86_64-linux/metadata.json.sigstore +1 -0
  8. build/torch211-cxx11-cu128-x86_64-linux/msa/__init__.py +26 -0
  9. build/torch211-cxx11-cu128-x86_64-linux/quack/__init__.py +0 -0
  10. build/torch211-cxx11-cu128-x86_64-linux/quack/activation.py +532 -0
  11. build/torch211-cxx11-cu128-x86_64-linux/quack/compile_utils.py +19 -0
  12. build/torch211-cxx11-cu128-x86_64-linux/quack/copy_utils.py +890 -0
  13. build/torch211-cxx11-cu128-x86_64-linux/quack/cute_dsl_utils.py +104 -0
  14. build/torch211-cxx11-cu128-x86_64-linux/quack/layout_utils.py +295 -0
  15. build/torch211-cxx11-cu128-x86_64-linux/quantize.py +362 -0
  16. build/torch211-cxx11-cu128-x86_64-linux/sparse_index_utils.py +411 -0
  17. build/torch211-cxx11-cu128-x86_64-linux/src/__init__.py +3 -0
  18. build/torch211-cxx11-cu128-x86_64-linux/src/common/__init__.py +3 -0
  19. build/torch211-cxx11-cu128-x86_64-linux/src/common/aot_cache.py +72 -0
  20. build/torch211-cxx11-cu128-x86_64-linux/src/common/barrier.py +74 -0
  21. build/torch211-cxx11-cu128-x86_64-linux/src/common/blackwell_helpers.py +1093 -0
  22. build/torch211-cxx11-cu128-x86_64-linux/src/common/block_info.py +61 -0
  23. build/torch211-cxx11-cu128-x86_64-linux/src/common/copy_utils.py +1179 -0
  24. build/torch211-cxx11-cu128-x86_64-linux/src/common/cute_dsl_utils.py +190 -0
  25. build/torch211-cxx11-cu128-x86_64-linux/src/common/fast_math.py +22 -0
  26. build/torch211-cxx11-cu128-x86_64-linux/src/common/mask.py +189 -0
  27. build/torch211-cxx11-cu128-x86_64-linux/src/common/mma_sm100_desc.py +304 -0
  28. build/torch211-cxx11-cu128-x86_64-linux/src/common/named_barrier.py +22 -0
  29. build/torch211-cxx11-cu128-x86_64-linux/src/common/pack_gqa.py +320 -0
  30. build/torch211-cxx11-cu128-x86_64-linux/src/common/paged_kv.py +67 -0
  31. build/torch211-cxx11-cu128-x86_64-linux/src/common/pipeline.py +372 -0
  32. build/torch211-cxx11-cu128-x86_64-linux/src/common/seqlen_info.py +203 -0
  33. build/torch211-cxx11-cu128-x86_64-linux/src/common/softmax.py +498 -0
  34. build/torch211-cxx11-cu128-x86_64-linux/src/common/tile_scheduler.py +967 -0
  35. build/torch211-cxx11-cu128-x86_64-linux/src/common/tma_utils.py +515 -0
  36. build/torch211-cxx11-cu128-x86_64-linux/src/common/utils.py +1088 -0
  37. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/__init__.py +4 -0
  38. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/build_k2q_csr/__init__.py +103 -0
  39. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/decode_schedule.py +193 -0
  40. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fp4_indexer.py +1956 -0
  41. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/__init__.py +8 -0
  42. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd.py +0 -0
  43. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py +0 -0
  44. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/combine.py +1498 -0
  45. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/__init__.py +95 -0
  46. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py +0 -0
  47. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py +112 -0
  48. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/combine.py +680 -0
  49. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py +300 -0
  50. build/torch211-cxx11-cu128-x86_64-linux/src/sm100/prepare_k2q_csr.py +227 -0
build/torch211-cxx11-cu128-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """MiniMax Sparse Attention (MSA) CuTe-DSL kernels for NVIDIA SM100.
5
+
6
+ Hub-kernel packaging of the CuTe-DSL sparse attention stack from
7
+ https://github.com/MiniMax-AI/MSA (``python/fmha_sm100/cute``). The
8
+ host-side helper kernels (CSR builder, decode scheduler) are precompiled
9
+ Torch ops; the attention kernels are compiled at runtime through
10
+ nvidia-cutlass-dsl.
11
+ """
12
+
13
+ # Sparse attention forward / decode.
14
+ from .interface import (
15
+ SparseDecodePagedAttentionWrapper,
16
+ sparse_atten_func,
17
+ sparse_atten_nvfp4_kv_func,
18
+ sparse_decode_atten_func,
19
+ )
20
+
21
+ # CSR + schedule construction.
22
+ from .sparse_index_utils import build_k2q_csr
23
+
24
+ # SM100 fused CSR builder.
25
+ from .src.sm100.prepare_k2q_csr import SparseK2qCsrBuilderSm100
26
+
27
+ # FP4 block-score indexer. Returns per-(Hq, kv_block, q) max scores; topK
28
+ # selection + q2k construction remain caller-owned downstream steps.
29
+ from .fp4_indexer_interface import fp4_indexer_block_scores
30
+
31
+ # NVFP4 quantization helpers used to feed the FP4 indexer / NVFP4 attention.
32
+ from .quantize import (
33
+ Nvfp4QuantizedTensor,
34
+ dequantize_nvfp4_128x4_to_bf16,
35
+ nvfp4_global_scale_from_amax,
36
+ quantize_bf16_to_nvfp4_128x4,
37
+ quantize_kv_bf16_to_nvfp4_128x4,
38
+ swizzle_nvfp4_scale_to_128x4,
39
+ )
40
+
41
+ __version__ = "0.1.1"
42
+
43
+ __all__ = [
44
+ # attention
45
+ "sparse_atten_func",
46
+ "sparse_atten_nvfp4_kv_func",
47
+ "sparse_decode_atten_func",
48
+ "SparseDecodePagedAttentionWrapper",
49
+ # indexing / CSR
50
+ "fp4_indexer_block_scores",
51
+ "build_k2q_csr",
52
+ "SparseK2qCsrBuilderSm100",
53
+ # nvfp4 quantization helpers
54
+ "Nvfp4QuantizedTensor",
55
+ "quantize_bf16_to_nvfp4_128x4",
56
+ "quantize_kv_bf16_to_nvfp4_128x4",
57
+ "dequantize_nvfp4_128x4_to_bf16",
58
+ "swizzle_nvfp4_scale_to_128x4",
59
+ "nvfp4_global_scale_from_amax",
60
+ ]
build/torch211-cxx11-cu128-x86_64-linux/_msa_cuda_09d7851.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8dcd8c86e512f3ddd5acb95f6fdcad3cfaa1579bb6f874a714fba066e6877161
3
+ size 1169368
build/torch211-cxx11-cu128-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _msa_cuda_09d7851
3
+ ops = torch.ops._msa_cuda_09d7851
4
+
5
def add_op_namespace_prefix(op_name: str):
    """Qualify ``op_name`` with this build's Torch op namespace.

    The namespace is the name of the compiled extension module, so the result
    has the form ``"_msa_cuda_09d7851::<op_name>"``.
    """
    namespace = "_msa_cuda_09d7851"
    return "::".join((namespace, f"{op_name}"))
build/torch211-cxx11-cu128-x86_64-linux/fp4_indexer_interface.py ADDED
@@ -0,0 +1,1061 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Public FP4 sparse-attention indexer block-score interface."""
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import Optional
9
+
10
+ import cuda.bindings.driver as cuda
11
+ import cutlass
12
+ import cutlass.cute as cute
13
+ import torch
14
+ from cutlass import Int32
15
+ from cutlass.cute.runtime import make_ptr
16
+
17
+ from .src.sm100.fp4_indexer import (
18
+ Fp4FormatSpec,
19
+ Fp4IndexerDecodePackedQSm100,
20
+ Fp4IndexerDecodeQPackSm100,
21
+ Fp4IndexerScaleReorderSm100,
22
+ Fp4IndexerStagedMmaSm100,
23
+ _BLOCK_K,
24
+ _DECODE_K_TILES_PER_CTA,
25
+ _DECODE_PACK_Q_LEN,
26
+ _DECODE_QHEAD_PER_KV,
27
+ _FP4_PACKED_D_BYTES,
28
+ _HEAD_DIM,
29
+ _MMA_TILER_MN,
30
+ _PAGE_SIZE,
31
+ ceil_div,
32
+ k_tiles_per_cta_for,
33
+ normalize_fp4_format,
34
+ )
35
+
36
+
37
+ _PUBLIC_SCALE_LAYOUT = "public"
38
+ _PREORDERED_MMA_SCALE_LAYOUT = "preordered_mma"
39
+ _FP4_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
40
+
41
+
42
def _device_arch(device: torch.device) -> tuple[int, int]:
    """Return the CUDA compute capability of ``device`` as ``(major, minor)`` ints."""
    capability = torch.cuda.get_device_capability(device)
    major, minor = capability
    return (int(major), int(minor))
45
+
46
+
47
+ def _supports_tmem_load_red(device_arch: tuple[int, int]) -> bool:
48
+ return device_arch >= (10, 3)
49
+
50
+
51
def normalize_scale_layout(scale_layout: str) -> str:
    """Coerce ``scale_layout`` to ``str`` and check it names a known layout.

    Parameters
    ----------
    scale_layout : str
        ``"public"`` selects logical scale tensors; ``"preordered_mma"``
        selects tensors already laid out with
        ``fp4_indexer_mma_scale_storage_*``.

    Returns
    -------
    str
        The layout name as a plain string.

    Raises
    ------
    ValueError
        If the value is neither of the two supported layout names.
    """

    layout = str(scale_layout)
    if layout == _PUBLIC_SCALE_LAYOUT or layout == _PREORDERED_MMA_SCALE_LAYOUT:
        return layout
    raise ValueError(f"scale_layout must be 'public' or 'preordered_mma', got {layout!r}")
70
+
71
+
72
def _causal_compact_task_count(q_len: int, k_len: int, k_tiles_per_cta: int) -> int:
    """Count (Q tile, K group) tasks that survive the causal bound.

    Q is split into tiles of ``_MMA_TILER_MN[0]`` rows and K into groups of
    ``k_tiles_per_cta`` pages. For each Q tile, only K groups that start at or
    before the last visible key of the tile's last query row are counted.
    Returns 0 when either length is non-positive.
    """
    if q_len <= 0 or k_len <= 0:
        return 0

    tile_m = _MMA_TILER_MN[0]
    q_tiles = ceil_div(q_len, tile_m)
    k_groups = ceil_div(ceil_div(k_len, _PAGE_SIZE), k_tiles_per_cta)
    tokens_per_group = k_tiles_per_cta * _BLOCK_K
    last_q_row = int(q_len) - 1
    # Query row i may see keys up to index i + (k_len - q_len).
    offset = int(k_len) - int(q_len)

    total = 0
    for tile_start in range(0, q_tiles * tile_m, tile_m):
        limit = min(tile_start + tile_m - 1, last_q_row) + offset
        if limit < 0:
            # The whole tile precedes the first key: nothing to schedule.
            continue
        total += min(k_groups, limit // tokens_per_group + 1)
    return total
87
+
88
+
89
def _causal_compact_task_bound(max_q_len: int, max_k_len: int, k_tiles_per_cta: int) -> int:
    """Conservative X-grid bound for per-batch causal prefill compact mapping.

    Evaluates ``_causal_compact_task_count`` at ``max_q_len`` and at every
    query length of the form ``tile_index * tile_m + 1`` that does not exceed
    ``max_q_len``, and returns the largest count. Returns 0 when either
    maximum is non-positive.
    """

    if max_q_len <= 0 or max_k_len <= 0:
        return 0

    tile_m = _MMA_TILER_MN[0]
    tile_first_lengths = (idx * tile_m + 1 for idx in range(ceil_div(max_q_len, tile_m)))
    lengths = {int(max_q_len)} | {length for length in tile_first_lengths if length <= max_q_len}
    counts = [_causal_compact_task_count(length, max_k_len, k_tiles_per_cta) for length in lengths]
    return max(counts)
101
+
102
+
103
+ def _require_cuda_tensor(tensor: torch.Tensor, *, name: str) -> None:
104
+ if not tensor.is_cuda:
105
+ raise ValueError(f"{name} must be a CUDA tensor")
106
+ if not tensor.is_contiguous():
107
+ raise ValueError(f"{name} must be contiguous")
108
+
109
+
110
+ def _require_int32_vector(tensor: torch.Tensor, *, name: str, device: torch.device) -> None:
111
+ if tensor.device != device:
112
+ raise ValueError(f"{name} must be on the same CUDA device")
113
+ if tensor.dtype != torch.int32:
114
+ raise TypeError(f"{name} must be torch.int32")
115
+ if tensor.ndim != 1:
116
+ raise ValueError(f"{name} must be rank-1")
117
+ if not tensor.is_contiguous():
118
+ raise ValueError(f"{name} must be contiguous")
119
+
120
+
121
+ def _require_fp4_packed_dtype(tensor: torch.Tensor, *, name: str) -> None:
122
+ fp4_x2_dtype = getattr(torch, "float4_e2m1fn_x2", None)
123
+ allowed = {torch.uint8, torch.int8}
124
+ if fp4_x2_dtype is not None:
125
+ allowed.add(fp4_x2_dtype)
126
+ if tensor.dtype not in allowed:
127
+ raise TypeError(f"{name} must use packed FP4 storage dtype, got {tensor.dtype}")
128
+
129
+
130
def _as_fp4_thd_bytes(tensor: torch.Tensor, *, name: str) -> torch.Tensor:
    """Validate a rank-3 packed-FP4 tensor and return it viewed as ``uint8``.

    The last dimension must equal ``_FP4_PACKED_D_BYTES`` and the dtype must
    pass ``_require_fp4_packed_dtype``. A ``uint8`` input is returned as-is;
    any other accepted dtype is reinterpreted with ``Tensor.view``.
    """
    if tensor.ndim != 3:
        raise ValueError(f"{name} must have shape [total_q, Hq, 64]")
    packed_width = int(tensor.shape[-1])
    if packed_width != _FP4_PACKED_D_BYTES:
        raise ValueError(f"{name}.shape[-1] must be 64 packed bytes for D=128")
    _require_fp4_packed_dtype(tensor, name=name)
    return tensor if tensor.dtype == torch.uint8 else tensor.view(torch.uint8)
139
+
140
+
141
def _as_fp4_paged_hnd_bytes(tensor: torch.Tensor, *, name: str) -> torch.Tensor:
    """Validate a rank-4 paged packed-FP4 tensor and return it viewed as ``uint8``.

    The trailing dimensions must equal ``(_PAGE_SIZE, _FP4_PACKED_D_BYTES)``
    and the dtype must pass ``_require_fp4_packed_dtype``. A ``uint8`` input
    is returned as-is; any other accepted dtype is reinterpreted with
    ``Tensor.view``.
    """
    if tensor.ndim != 4:
        raise ValueError(f"{name} must have shape [total_pages, Hk, 128, 64]")
    page_rows = int(tensor.shape[-2])
    packed_width = int(tensor.shape[-1])
    if page_rows != _PAGE_SIZE:
        raise ValueError(f"{name}.shape[-2] must be 128")
    if packed_width != _FP4_PACKED_D_BYTES:
        raise ValueError(f"{name}.shape[-1] must be 64 packed bytes for D=128")
    _require_fp4_packed_dtype(tensor, name=name)
    return tensor if tensor.dtype == torch.uint8 else tensor.view(torch.uint8)
152
+
153
+
154
def validate_q_scale_thg(
    scale: torch.Tensor,
    *,
    name: str,
    fmt: Fp4FormatSpec,
    total_q: int,
    heads: int,
) -> None:
    """Check that ``scale`` is a public Q scale tensor of shape ``[total_q, Hq, G]``.

    Parameters
    ----------
    scale : torch.Tensor
        Logical Q scale tensor.
    name : str
        Label used in validation error messages.
    fmt : Fp4FormatSpec
        FP4 format specification from ``normalize_fp4_format``; supplies
        ``scale_groups`` (``G``) and ``torch_scale_dtype``.
    total_q : int
        Total query token count.
    heads : int
        Number of Q heads.

    Raises
    ------
    ValueError
        If the shape is not ``(total_q, heads, fmt.scale_groups)`` or the
        tensor is not contiguous.
    TypeError
        If the dtype is not ``fmt.torch_scale_dtype``.
    """

    wanted_shape = (int(total_q), int(heads), fmt.scale_groups)
    actual_shape = tuple(scale.shape)
    if actual_shape != wanted_shape:
        raise ValueError(f"{name} must have shape {wanted_shape}, got {actual_shape}")
    wanted_dtype = fmt.torch_scale_dtype
    if scale.dtype != wanted_dtype:
        raise TypeError(f"{name} must have dtype {wanted_dtype}, got {scale.dtype}")
    if not scale.is_contiguous():
        raise ValueError(f"{name} must be contiguous")
185
+
186
+
187
def validate_k_scale_phsg(
    scale: torch.Tensor,
    *,
    name: str,
    fmt: Fp4FormatSpec,
    page_count: int,
    heads: int,
) -> None:
    """Check that ``scale`` is a public K scale tensor of shape ``[page_count, Hk, 128, G]``.

    Parameters
    ----------
    scale : torch.Tensor
        Logical K scale tensor.
    name : str
        Label used in validation error messages.
    fmt : Fp4FormatSpec
        FP4 format specification from ``normalize_fp4_format``; supplies
        ``scale_groups`` (``G``) and ``torch_scale_dtype``.
    page_count : int
        Number of physical KV pages.
    heads : int
        Number of KV heads.

    Raises
    ------
    ValueError
        If the shape is not ``(page_count, heads, _PAGE_SIZE,
        fmt.scale_groups)`` or the tensor is not contiguous.
    TypeError
        If the dtype is not ``fmt.torch_scale_dtype``.
    """

    wanted_shape = (int(page_count), int(heads), _PAGE_SIZE, fmt.scale_groups)
    actual_shape = tuple(scale.shape)
    if actual_shape != wanted_shape:
        raise ValueError(f"{name} must have shape {wanted_shape}, got {actual_shape}")
    wanted_dtype = fmt.torch_scale_dtype
    if scale.dtype != wanted_dtype:
        raise TypeError(f"{name} must have dtype {wanted_dtype}, got {scale.dtype}")
    if not scale.is_contiguous():
        raise ValueError(f"{name} must be contiguous")
218
+
219
+
220
def fp4_indexer_mma_scale_shape(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]:
    """Return semantic MMA scale view shape ``(32,4,restM,4,restG,L)``.

    ``restM`` is ``ceil(mn / 128)`` and ``restG`` is the format's scale-group
    count divided by 4, rounded up.
    """

    scale_groups = normalize_fp4_format(fp4_format).scale_groups
    rest_m = ceil_div(mn, 128)
    rest_g = ceil_div(scale_groups, 4)
    return (32, 4, rest_m, 4, rest_g, int(l))
225
+
226
+
227
def fp4_indexer_mma_scale_stride(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]:
    """Return element strides for ``fp4_indexer_mma_scale_shape``.

    ``l`` does not influence the strides; it is accepted so the signature
    mirrors the shape helper.
    """

    scale_groups = normalize_fp4_format(fp4_format).scale_groups
    rest_m = ceil_div(mn, 128)
    rest_g = ceil_div(scale_groups, 4)
    # One (128-row, 4-group) tile holds 32 * 4 * 4 = 512 scale elements.
    tile_elems = 512
    return (16, 4, tile_elems * rest_g, 1, tile_elems, tile_elems * rest_m * rest_g)
234
+
235
+
236
def fp4_indexer_mma_scale_storage_shape(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]:
    """Return contiguous storage shape ``(L, restM, restG, 32, 4, 4)`` for preordered MMA scales."""

    scale_groups = normalize_fp4_format(fp4_format).scale_groups
    rest_m = ceil_div(mn, 128)
    rest_g = ceil_div(scale_groups, 4)
    return (int(l), rest_m, rest_g, 32, 4, 4)
241
+
242
+
243
def fp4_indexer_mma_scale_storage_stride(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]:
    """Return element strides for ``fp4_indexer_mma_scale_storage_shape``.

    These are the row-major strides of that shape; ``l`` does not influence
    them.
    """

    scale_groups = normalize_fp4_format(fp4_format).scale_groups
    rest_m = ceil_div(mn, 128)
    rest_g = ceil_div(scale_groups, 4)
    # The innermost (32, 4, 4) block is 512 contiguous elements.
    tile_elems = 512
    return (tile_elems * rest_m * rest_g, tile_elems * rest_g, tile_elems, 16, 4, 1)
250
+
251
+
252
def validate_mma_scale_storage(
    scale: torch.Tensor,
    *,
    name: str,
    fmt: Fp4FormatSpec,
    mn: int,
    l: int,
) -> None:
    """Check that ``scale`` matches the preordered MMA scale storage layout.

    Parameters
    ----------
    scale : torch.Tensor
        Tensor view whose shape and stride must equal
        ``fp4_indexer_mma_scale_storage_shape`` and
        ``fp4_indexer_mma_scale_storage_stride`` for ``(mn, l)``.
    name : str
        Label used in validation error messages.
    fmt : Fp4FormatSpec
        FP4 format specification from ``normalize_fp4_format``.
    mn : int
        Logical M/N extent of the scale domain.
    l : int
        Logical batch/head extent folded into the final layout dimension.

    Raises
    ------
    ValueError
        If the shape or the strides differ from the expected storage layout.
    TypeError
        If the dtype is not ``fmt.torch_scale_dtype``.
    """

    wanted_shape = fp4_indexer_mma_scale_storage_shape(mn, l, fp4_format=fmt.name)
    wanted_stride = fp4_indexer_mma_scale_storage_stride(mn, l, fp4_format=fmt.name)
    actual_shape = tuple(scale.shape)
    if actual_shape != wanted_shape:
        raise ValueError(f"{name} must have MMA storage shape {wanted_shape}, got {actual_shape}")
    actual_stride = tuple(scale.stride())
    if actual_stride != wanted_stride:
        raise ValueError(f"{name} must have MMA storage stride {wanted_stride}, got {actual_stride}")
    wanted_dtype = fmt.torch_scale_dtype
    if scale.dtype != wanted_dtype:
        raise TypeError(f"{name} must have dtype {wanted_dtype}, got {scale.dtype}")
286
+
287
+
288
def _empty_mma_scale_tensor(
    *,
    mn: int,
    l: int,
    spec: Fp4FormatSpec,
    device: torch.device,
) -> torch.Tensor:
    """Allocate uninitialised MMA scale storage and return its semantic view.

    Contiguous storage of shape ``(L, restM, restG, 32, 4, 4)`` is allocated
    with ``spec.torch_scale_dtype`` on ``device`` and returned permuted to
    ``(32, 4, restM, 4, restG, L)``. The view shares memory with the storage
    and starts at the same data pointer.
    """
    rest_m = ceil_div(mn, 128)
    rest_g = ceil_div(spec.scale_groups, 4)
    storage_shape = (int(l), rest_m, rest_g, 32, 4, 4)
    storage = torch.empty(storage_shape, dtype=spec.torch_scale_dtype, device=device)
    # (L, restM, restG, 32, 4, 4) -> (32, 4, restM, 4, restG, L)
    semantic_order = (3, 4, 1, 5, 2, 0)
    return storage.permute(*semantic_order)
303
+
304
+
305
def _compile_fp4_scale_reorder_kernel(
    *,
    fmt: Fp4FormatSpec,
    q_scale_ptr: cute.Pointer,
    k_scale_ptr: cute.Pointer,
    q_scale_mma_ptr: cute.Pointer,
    k_scale_mma_ptr: cute.Pointer,
    problem_size: tuple,
    stream: cuda.CUstream,
):
    """Return the compiled scale-reorder kernel for ``fmt``, compiling on first use.

    The result of ``cute.compile`` is memoised in ``_FP4_COMPILE_CACHE`` under
    a key made of the kernel tag and ``fmt.name``. The pointer, problem-size
    and stream arguments are only used as example arguments for the first
    compilation.
    """
    cache_key = ("fp4_indexer_scale_reorder_sm100_1cta", fmt.name)
    try:
        return _FP4_COMPILE_CACHE[cache_key]
    except KeyError:
        pass

    compiled = cute.compile(
        Fp4IndexerScaleReorderSm100(fmt=fmt.name),
        q_scale_ptr,
        k_scale_ptr,
        q_scale_mma_ptr,
        k_scale_mma_ptr,
        problem_size,
        stream,
    )
    _FP4_COMPILE_CACHE[cache_key] = compiled
    return compiled
331
+
332
+
333
def fp4_indexer_reorder_scales_for_mma_cute(
    q_scale: torch.Tensor,
    k_scale: torch.Tensor,
    *,
    fp4_format: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reorder public Q/K FP4 scales to MMA-friendly storage.

    Parameters
    ----------
    q_scale : torch.Tensor
        Public Q scale tensor with shape ``[total_q, Hq, G]``.
    k_scale : torch.Tensor
        Public K scale tensor with shape ``[page_count, Hk, 128, G]``.
    fp4_format : str
        ``"mxfp4"`` or ``"nvfp4"``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        ``(q_scale_mma, k_scale_mma)``: freshly allocated tensors from
        ``_empty_mma_scale_tensor`` whose pointers are handed to the reorder
        kernel as outputs. Intended for ``fp4_indexer_block_scores`` with
        ``scale_layout="preordered_mma"``.
        NOTE(review): these are permuted views, so ``.shape`` is the semantic
        ``(32, 4, restM, 4, restG, L)`` form over contiguous
        ``(L, restM, restG, 32, 4, 4)`` storage; confirm which form
        ``validate_mma_scale_storage`` callers expect.

    Raises
    ------
    ValueError
        If the inputs are on different devices, are not contiguous CUDA
        tensors, or have the wrong rank/shape.
    TypeError
        If a scale dtype does not match the format's scale dtype.
    """

    spec = normalize_fp4_format(fp4_format)
    if q_scale.device != k_scale.device:
        raise ValueError("q_scale and k_scale must be on the same CUDA device")
    for label, tensor in (("q_scale", q_scale), ("k_scale", k_scale)):
        _require_cuda_tensor(tensor, name=label)
    if q_scale.ndim != 3:
        raise ValueError(f"q_scale must have shape [total_q, Hq, G], got {tuple(q_scale.shape)}")
    if k_scale.ndim != 4:
        raise ValueError(f"k_scale must have shape [page_count, Hk, 128, G], got {tuple(k_scale.shape)}")

    total_q, heads_q = int(q_scale.shape[0]), int(q_scale.shape[1])
    page_count, heads_k = int(k_scale.shape[0]), int(k_scale.shape[1])
    validate_q_scale_thg(q_scale, name="q_scale", fmt=spec, total_q=total_q, heads=heads_q)
    validate_k_scale_phsg(k_scale, name="k_scale", fmt=spec, page_count=page_count, heads=heads_k)

    # Q scales: one layout batch per head. K scales: one per (page, head).
    q_scale_mma = _empty_mma_scale_tensor(mn=total_q, l=heads_q, spec=spec, device=q_scale.device)
    k_scale_mma = _empty_mma_scale_tensor(
        mn=_PAGE_SIZE,
        l=page_count * heads_k,
        spec=spec,
        device=k_scale.device,
    )

    def scale_ptr(tensor: torch.Tensor, alignment: int):
        # Wrap a tensor's base address as a global-memory scale pointer.
        return make_ptr(
            spec.cutlass_scale_dtype,
            tensor.data_ptr(),
            cute.AddressSpace.gmem,
            assumed_align=alignment,
        )

    q_scale_ptr = scale_ptr(q_scale, 16)
    k_scale_ptr = scale_ptr(k_scale, 16)
    q_scale_mma_ptr = scale_ptr(q_scale_mma, 32)
    k_scale_mma_ptr = scale_ptr(k_scale_mma, 32)
    problem_size = (Int32(total_q), Int32(heads_q), Int32(page_count), Int32(heads_k))
    # Launch on the stream PyTorch currently uses for this device.
    stream = cuda.CUstream(torch.cuda.current_stream(q_scale.device).cuda_stream)

    launch_args = (q_scale_ptr, k_scale_ptr, q_scale_mma_ptr, k_scale_mma_ptr, problem_size, stream)
    compiled = _compile_fp4_scale_reorder_kernel(
        fmt=spec,
        q_scale_ptr=q_scale_ptr,
        k_scale_ptr=k_scale_ptr,
        q_scale_mma_ptr=q_scale_mma_ptr,
        k_scale_mma_ptr=k_scale_mma_ptr,
        problem_size=problem_size,
        stream=stream,
    )
    compiled(*launch_args)
    return q_scale_mma, k_scale_mma
434
+
435
+
436
def _compile_fp4_decode_q_pack_kernel(
    *,
    fmt: Fp4FormatSpec,
    q_ptr: cute.Pointer,
    q_scale_ptr: cute.Pointer,
    q_pack_ptr: cute.Pointer,
    q_scale_pack_ptr: cute.Pointer,
    cu_seqlens_q_ptr: cute.Pointer,
    problem_size: tuple,
    stream: cuda.CUstream,
):
    """Return the compiled decode Q-pack kernel for ``fmt``, compiling on first use.

    The result of ``cute.compile`` is memoised in ``_FP4_COMPILE_CACHE`` under
    a key made of the kernel tag and ``fmt.name``. The remaining arguments are
    only used as example arguments for the first compilation.
    """
    cache_key = ("fp4_indexer_decode_q_pack_sm100", fmt.name)
    try:
        return _FP4_COMPILE_CACHE[cache_key]
    except KeyError:
        pass

    compiled = cute.compile(
        Fp4IndexerDecodeQPackSm100(fmt=fmt.name),
        q_ptr,
        q_scale_ptr,
        q_pack_ptr,
        q_scale_pack_ptr,
        cu_seqlens_q_ptr,
        problem_size,
        stream,
    )
    _FP4_COMPILE_CACHE[cache_key] = compiled
    return compiled
464
+
465
+
466
def _pack_decode_q_for_mma(
    q_bytes: torch.Tensor,
    q_scale_storage: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    *,
    fmt: Fp4FormatSpec,
    heads_q: int,
    heads_k: int,
    batch: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the decode Q-pack kernel and return the packed Q and Q-scale buffers.

    Allocates a ``uint8`` buffer of shape
    ``(batch * heads_k, _PAGE_SIZE, _FP4_PACKED_D_BYTES)`` and a scale buffer
    with the MMA storage shape for ``(_PAGE_SIZE, batch * heads_k)``, then
    launches the kernel on the current CUDA stream of ``q_bytes.device``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        ``(q_pack, q_scale_pack)``.

    Raises
    ------
    ValueError
        If a freshly allocated buffer does not meet the alignment the kernel
        pointers assume (128 bytes for ``q_pack``, 32 for ``q_scale_pack``).
    """
    device = q_bytes.device
    packed_rows = batch * heads_k
    q_pack = torch.empty(
        (packed_rows, _PAGE_SIZE, _FP4_PACKED_D_BYTES),
        dtype=torch.uint8,
        device=device,
    )
    q_scale_pack = torch.empty(
        fp4_indexer_mma_scale_storage_shape(_PAGE_SIZE, batch * heads_k, fp4_format=fmt.name),
        dtype=fmt.torch_scale_dtype,
        device=device,
    )
    # The pointers below declare these alignments, so verify the allocations.
    if q_pack.data_ptr() % 128 != 0:
        raise ValueError("internal decode q_pack data pointer must be 128B aligned for TMA")
    if q_scale_pack.data_ptr() % 32 != 0:
        raise ValueError("internal decode q_scale_pack data pointer must be 32B aligned")

    gmem = cute.AddressSpace.gmem

    def byte_ptr(tensor: torch.Tensor):
        return make_ptr(cutlass.Uint8, tensor.data_ptr(), gmem, assumed_align=128)

    def scale_ptr(tensor: torch.Tensor):
        return make_ptr(fmt.cutlass_scale_dtype, tensor.data_ptr(), gmem, assumed_align=32)

    q_ptr = byte_ptr(q_bytes)
    q_scale_ptr = scale_ptr(q_scale_storage)
    q_pack_ptr = byte_ptr(q_pack)
    q_scale_pack_ptr = scale_ptr(q_scale_pack)
    cu_seqlens_q_ptr = make_ptr(cutlass.Int32, cu_seqlens_q.data_ptr(), gmem, assumed_align=4)
    problem_size = (Int32(q_bytes.shape[0]), Int32(heads_q), Int32(heads_k), Int32(batch))
    stream = cuda.CUstream(torch.cuda.current_stream(q_bytes.device).cuda_stream)

    launch_args = (
        q_ptr,
        q_scale_ptr,
        q_pack_ptr,
        q_scale_pack_ptr,
        cu_seqlens_q_ptr,
        problem_size,
        stream,
    )
    compiled = _compile_fp4_decode_q_pack_kernel(
        fmt=fmt,
        q_ptr=q_ptr,
        q_scale_ptr=q_scale_ptr,
        q_pack_ptr=q_pack_ptr,
        q_scale_pack_ptr=q_scale_pack_ptr,
        cu_seqlens_q_ptr=cu_seqlens_q_ptr,
        problem_size=problem_size,
        stream=stream,
    )
    compiled(*launch_args)
    return q_pack, q_scale_pack
537
+
538
+
539
def _compile_fp4_decode_packed_q_kernel(
    *,
    fmt: Fp4FormatSpec,
    causal: bool,
    compact_schedule: bool,
    device_arch: tuple[int, int],
    use_tmem_load_red: bool,
    q_pack_ptr: cute.Pointer,
    k_ptr: cute.Pointer,
    q_scale_pack_ptr: cute.Pointer,
    k_scale_ptr: cute.Pointer,
    scores_ptr: cute.Pointer,
    kv_indices_ptr: cute.Pointer,
    cu_seqlens_q_ptr: cute.Pointer,
    cu_seqlens_k_ptr: cute.Pointer,
    cu_page_offsets_ptr: cute.Pointer,
    qo_offset_ptr: cute.Pointer,
    problem_size: tuple,
    stream: cuda.CUstream,
):
    """Return the compiled packed-Q decode indexer kernel, compiling on first use.

    The result of ``cute.compile`` is memoised in ``_FP4_COMPILE_CACHE``. The
    key covers every option the kernel object is constructed with
    (``fmt.name``, ``causal``, ``compact_schedule``, ``use_tmem_load_red``)
    plus ``device_arch``. The pointer, problem-size and stream arguments are
    only used as example arguments for the first compilation.
    """
    key = (
        "fp4_indexer_decode_packed_q_sm100",
        fmt.name,
        bool(causal),
        bool(compact_schedule),
        device_arch,
        # The kernel is specialised on this flag, so two calls that differ
        # only in it must not share a cached kernel.
        bool(use_tmem_load_red),
    )
    if key not in _FP4_COMPILE_CACHE:
        kernel = Fp4IndexerDecodePackedQSm100(
            fmt=fmt.name,
            causal=causal,
            compact_schedule=compact_schedule,
            use_tmem_load_red=use_tmem_load_red,
        )
        _FP4_COMPILE_CACHE[key] = cute.compile(
            kernel,
            q_pack_ptr,
            k_ptr,
            q_scale_pack_ptr,
            k_scale_ptr,
            scores_ptr,
            kv_indices_ptr,
            cu_seqlens_q_ptr,
            cu_seqlens_k_ptr,
            cu_page_offsets_ptr,
            qo_offset_ptr,
            problem_size,
            stream,
        )
    return _FP4_COMPILE_CACHE[key]
589
+
590
+
591
def _run_fp4_decode_packed_q_scores(
    q_pack: torch.Tensor,
    k_bytes: torch.Tensor,
    q_scale_pack: torch.Tensor,
    k_scale_storage: torch.Tensor,
    scores: torch.Tensor,
    kv_indices: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    cu_page_offsets: torch.Tensor,
    qo_offset_arg: torch.Tensor,
    *,
    fmt: Fp4FormatSpec,
    causal: bool,
    has_qo_offset: int,
    heads_q: int,
    heads_k: int,
    batch: int,
    max_k_tiles: int,
    total_q: int,
    device_arch: tuple[int, int],
    use_tmem_load_red: bool,
) -> None:
    """Launch the packed-Q decode indexer kernel; ``scores`` is passed as its output.

    Two schedules are compared by CTA-group count: a rectangular one
    (``batch * ceil(max_k_tiles / tiles_per_cta)``) and a compact one derived
    from the physical page count. The compact schedule is chosen when it needs
    strictly fewer groups; in that case ``scores`` is pre-filled with ``-inf``
    (presumably because it does not write every entry; TODO confirm).
    The kernel runs on the current CUDA stream of ``q_pack.device``.
    """
    page_count = int(k_bytes.shape[0])
    tiles_per_cta = _DECODE_K_TILES_PER_CTA
    rectangular_groups = batch * ceil_div(max_k_tiles, tiles_per_cta)
    compact_groups = ceil_div(page_count + batch * (tiles_per_cta - 1), tiles_per_cta)
    compact_schedule = compact_groups < rectangular_groups
    if compact_schedule:
        scores.fill_(float("-inf"))

    gmem = cute.AddressSpace.gmem

    def byte_ptr(tensor: torch.Tensor):
        return make_ptr(cutlass.Uint8, tensor.data_ptr(), gmem, assumed_align=128)

    def scale_ptr(tensor: torch.Tensor):
        return make_ptr(fmt.cutlass_scale_dtype, tensor.data_ptr(), gmem, assumed_align=32)

    def int32_ptr(tensor: torch.Tensor):
        return make_ptr(cutlass.Int32, tensor.data_ptr(), gmem, assumed_align=4)

    # Insertion order is the kernel's positional argument order.
    pointers = {
        "q_pack_ptr": byte_ptr(q_pack),
        "k_ptr": byte_ptr(k_bytes),
        "q_scale_pack_ptr": scale_ptr(q_scale_pack),
        "k_scale_ptr": scale_ptr(k_scale_storage),
        "scores_ptr": make_ptr(cutlass.Float32, scores.data_ptr(), gmem, assumed_align=16),
        "kv_indices_ptr": int32_ptr(kv_indices),
        "cu_seqlens_q_ptr": int32_ptr(cu_seqlens_q),
        "cu_seqlens_k_ptr": int32_ptr(cu_seqlens_k),
        "cu_page_offsets_ptr": int32_ptr(cu_page_offsets),
        "qo_offset_ptr": int32_ptr(qo_offset_arg),
    }
    problem_size = (
        Int32(_PAGE_SIZE),
        Int32(max_k_tiles * _PAGE_SIZE),
        Int32(_HEAD_DIM),
        Int32(batch * heads_k),
        Int32(page_count * heads_k),
        Int32(heads_q),
        Int32(heads_k),
        Int32(batch),
        Int32(max_k_tiles),
        Int32(total_q),
        Int32(has_qo_offset),
    )
    stream = cuda.CUstream(torch.cuda.current_stream(q_pack.device).cuda_stream)

    compiled = _compile_fp4_decode_packed_q_kernel(
        fmt=fmt,
        causal=causal,
        compact_schedule=compact_schedule,
        device_arch=device_arch,
        use_tmem_load_red=use_tmem_load_red,
        problem_size=problem_size,
        stream=stream,
        **pointers,
    )
    compiled(*pointers.values(), problem_size, stream)
688
+
689
+
690
def _compile_fp4_qk_kernel(
    *,
    fmt: Fp4FormatSpec,
    causal: bool,
    preordered_q_scale_tma: bool,
    compact_schedule: bool,
    device_arch: tuple[int, int],
    use_tmem_load_red: bool,
    q_ptr: cute.Pointer,
    k_ptr: cute.Pointer,
    q_scale_ptr: cute.Pointer,
    k_scale_ptr: cute.Pointer,
    scores_ptr: cute.Pointer,
    kv_indices_ptr: cute.Pointer,
    cu_seqlens_q_ptr: cute.Pointer,
    cu_seqlens_k_ptr: cute.Pointer,
    cu_page_offsets_ptr: cute.Pointer,
    qo_offset_ptr: cute.Pointer,
    problem_size: tuple,
    stream: cuda.CUstream,
):
    """Return the compiled staged-MMA QK indexer kernel, compiling on first use.

    The result of ``cute.compile`` is memoised in ``_FP4_COMPILE_CACHE``. The
    key covers every option the kernel object is constructed with
    (``fmt.name``, ``causal``, ``preordered_q_scale_tma``,
    ``compact_schedule``, ``use_tmem_load_red``) plus ``device_arch``. The
    pointer, problem-size and stream arguments are only used as example
    arguments for the first compilation.
    """
    key = (
        "fp4_indexer_staged_mma_sm100",
        fmt.name,
        bool(causal),
        bool(preordered_q_scale_tma),
        bool(compact_schedule),
        device_arch,
        # The kernel is specialised on this flag, so two calls that differ
        # only in it must not share a cached kernel.
        bool(use_tmem_load_red),
    )
    if key not in _FP4_COMPILE_CACHE:
        kernel = Fp4IndexerStagedMmaSm100(
            fmt=fmt.name,
            causal=causal,
            preordered_q_scale_tma=preordered_q_scale_tma,
            compact_schedule=compact_schedule,
            use_tmem_load_red=use_tmem_load_red,
        )
        _FP4_COMPILE_CACHE[key] = cute.compile(
            kernel,
            q_ptr,
            k_ptr,
            q_scale_ptr,
            k_scale_ptr,
            scores_ptr,
            kv_indices_ptr,
            cu_seqlens_q_ptr,
            cu_seqlens_k_ptr,
            cu_page_offsets_ptr,
            qo_offset_ptr,
            problem_size,
            stream,
        )
    return _FP4_COMPILE_CACHE[key]
743
+
744
+
745
def fp4_indexer_block_scores(
    q_fp4: torch.Tensor,
    k_fp4: torch.Tensor,
    q_scale: torch.Tensor,
    k_scale: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    cu_page_offsets: torch.Tensor,
    *,
    max_seqlen_q: int,
    max_seqlen_k: int,
    kv_indices: torch.Tensor,
    fp4_format: str,
    causal: bool = False,
    qo_offset: Optional[torch.Tensor] = None,
    scale_layout: str = _PREORDERED_MMA_SCALE_LAYOUT,
) -> torch.Tensor:
    """Compute the maximum FP4 QK score for every 128-token KV page.

    Parameters
    ----------
    q_fp4 : torch.Tensor
        Packed FP4 queries, shape ``[total_qo_len, Hq, 64]``.  Each byte of
        the last dimension holds two FP4 values (logical head dimension 128).
    k_fp4 : torch.Tensor
        Packed, paged FP4 keys, shape ``[total_pages, Hk, 128, 64]``.
    q_scale : torch.Tensor
        Query scales.  For ``scale_layout="public"`` the shape is
        ``[total_qo_len, Hq, G]``; for ``"preordered_mma"`` pass the layout
        produced by ``fp4_indexer_reorder_scales_for_mma_cute``.
    k_scale : torch.Tensor
        Key scales.  For ``scale_layout="public"`` the shape is
        ``[total_pages, Hk, 128, G]``; for ``"preordered_mma"`` pass the
        preordered MMA scale layout.
    cu_seqlens_q : torch.Tensor
        int32, shape ``[batch_size + 1]``.  Prefix sums of query lengths.
    cu_seqlens_k : torch.Tensor
        int32, shape ``[batch_size + 1]``.  Prefix sums of KV lengths.
    cu_page_offsets : torch.Tensor
        int32, shape ``[batch_size + 1]``.  Prefix sums of per-request page
        counts.
    max_seqlen_q : int
        Largest query sequence length.
    max_seqlen_k : int
        Largest KV sequence length.
    kv_indices : torch.Tensor
        int32, shape ``[sum_pages]``.  Flattened physical page indices; must
        be contiguous.
    fp4_format : str
        ``"mxfp4"`` or ``"nvfp4"``.
    causal : bool, optional
        Apply causal masking when true.
    qo_offset : torch.Tensor, optional
        int32, shape ``[batch_size]``.  Per-request causal offset; only
        accepted together with ``causal=True``.
    scale_layout : str, optional
        ``"public"`` takes logical public scale tensors and launches a scale
        reorder kernel first.  ``"preordered_mma"`` takes already reordered
        MMA scale tensors and skips that step.

    Returns
    -------
    torch.Tensor
        float32, shape ``[Hq, ceil(max_seqlen_k / 128), total_qo_len]``.
        Entries beyond the valid KV page range are ``-inf``.

    Raises
    ------
    ValueError
        If a shape, dtype, device, contiguity or pointer-alignment
        requirement checked below is violated.
    """

    spec = normalize_fp4_format(fp4_format)
    causal = bool(causal)
    scale_layout = normalize_scale_layout(scale_layout)
    m_extent = int(max_seqlen_q)
    use_preordered_q_scale_tma = m_extent >= _PAGE_SIZE

    # ---- operand shapes -------------------------------------------------
    q_bytes = _as_fp4_thd_bytes(q_fp4, name="q_fp4")
    k_bytes = _as_fp4_paged_hnd_bytes(k_fp4, name="k_fp4")
    total_q, heads_q, _ = map(int, q_bytes.shape)
    page_count, heads_k, page_size, _ = map(int, k_bytes.shape)
    if page_size != _PAGE_SIZE:
        raise ValueError(f"k_fp4 page_size must be 128, got {page_size}")
    if heads_q % heads_k != 0:
        raise ValueError("num_qo_heads must be divisible by num_kv_heads")

    # ---- device / metadata ----------------------------------------------
    _require_cuda_tensor(q_fp4, name="q_fp4")
    _require_cuda_tensor(k_fp4, name="k_fp4")
    device = q_fp4.device
    device_arch = _device_arch(device)
    use_tmem_load_red = _supports_tmem_load_red(device_arch)
    for vec_name, vec in (
        ("cu_seqlens_q", cu_seqlens_q),
        ("cu_seqlens_k", cu_seqlens_k),
        ("cu_page_offsets", cu_page_offsets),
    ):
        _require_int32_vector(vec, name=vec_name, device=device)
    if not (q_scale.device == device and k_scale.device == device):
        raise ValueError("q_scale and k_scale must be on the same CUDA device as q_fp4")

    public_layout = scale_layout == _PUBLIC_SCALE_LAYOUT
    if public_layout:
        validate_q_scale_thg(q_scale, name="q_scale", fmt=spec, total_q=total_q, heads=heads_q)
        validate_k_scale_phsg(k_scale, name="k_scale", fmt=spec, page_count=page_count, heads=heads_k)
    else:
        validate_mma_scale_storage(q_scale, name="q_scale", fmt=spec, mn=total_q, l=heads_q)
        validate_mma_scale_storage(k_scale, name="k_scale", fmt=spec, mn=_PAGE_SIZE, l=page_count * heads_k)

    batch = int(cu_seqlens_q.shape[0]) - 1
    if batch < 0:
        raise ValueError("cu_seqlens_q must have shape [B + 1]")
    if cu_seqlens_q.shape != cu_seqlens_k.shape or cu_seqlens_q.shape != cu_page_offsets.shape:
        raise ValueError("cu_seqlens_q, cu_seqlens_k, and cu_page_offsets must have shape [B + 1]")

    # Q and K are later handed to the kernel with assumed_align=128.
    for operand_name, operand in (("q_fp4", q_bytes), ("k_fp4", k_bytes)):
        if operand.data_ptr() % 128 != 0:
            raise ValueError(f"{operand_name} data pointer must be 128B aligned for TMA")

    if kv_indices is None:
        raise ValueError("kv_indices is required")
    if kv_indices.device != device or kv_indices.dtype != torch.int32 or kv_indices.ndim != 1:
        raise ValueError("kv_indices must have shape [sum_pages], dtype torch.int32, and match q_fp4.device")
    if not kv_indices.is_contiguous():
        raise ValueError("kv_indices must be contiguous")

    if qo_offset is not None:
        if not causal:
            raise ValueError("qo_offset is only valid when causal=True")
        if qo_offset.device != device or qo_offset.dtype != torch.int32 or qo_offset.shape != (batch,):
            raise ValueError("qo_offset must have shape [B], dtype torch.int32, and match q_fp4.device")
        if not qo_offset.is_contiguous():
            raise ValueError("qo_offset must be contiguous")

    # ---- output allocation ----------------------------------------------
    max_k_tiles = ceil_div(int(max_seqlen_k), _PAGE_SIZE)
    n_aligned = max_k_tiles * _PAGE_SIZE
    if max_k_tiles == 0:
        # No KV pages at all: nothing to launch, return the empty result.
        return torch.full(
            (heads_q, 0, total_q),
            float("-inf"),
            dtype=torch.float32,
            device=device,
        )

    scores = torch.empty(
        (heads_q, max_k_tiles, total_q),
        dtype=torch.float32,
        device=device,
    )

    # The kernel always receives a qo_offset pointer; when the caller gave
    # none, an uninitialized placeholder is passed and has_qo_offset=0 is
    # forwarded in problem_size.
    has_qo_offset = int(qo_offset is not None)
    if has_qo_offset:
        qo_offset_arg = qo_offset
    else:
        qo_offset_arg = torch.empty((batch,), dtype=torch.int32, device=device)

    if public_layout:
        q_scale_arg, k_scale_arg = fp4_indexer_reorder_scales_for_mma_cute(
            q_scale,
            k_scale,
            fp4_format=spec.name,
        )
    else:
        q_scale_arg, k_scale_arg = q_scale, k_scale

    scale_assumed_align = 32
    for scale_name, scale_tensor in (("q_scale", q_scale_arg), ("k_scale", k_scale_arg)):
        if scale_tensor.data_ptr() % scale_assumed_align != 0:
            raise ValueError(
                f"{scale_name} data pointer must be {scale_assumed_align}B aligned for MMA storage scale"
            )

    # ---- decode fast path -------------------------------------------------
    use_decode_packed_q = m_extent <= _DECODE_PACK_Q_LEN and heads_q // heads_k == _DECODE_QHEAD_PER_KV
    if use_decode_packed_q:
        q_pack, q_scale_pack = _pack_decode_q_for_mma(
            q_bytes,
            q_scale_arg,
            cu_seqlens_q,
            fmt=spec,
            heads_q=heads_q,
            heads_k=heads_k,
            batch=batch,
        )
        _run_fp4_decode_packed_q_scores(
            q_pack,
            k_bytes,
            q_scale_pack,
            k_scale_arg,
            scores,
            kv_indices,
            cu_seqlens_q,
            cu_seqlens_k,
            cu_page_offsets,
            qo_offset_arg,
            fmt=spec,
            causal=causal,
            has_qo_offset=has_qo_offset,
            heads_q=heads_q,
            heads_k=heads_k,
            batch=batch,
            max_k_tiles=max_k_tiles,
            total_q=total_q,
            device_arch=device_arch,
            use_tmem_load_red=use_tmem_load_red,
        )
        return scores

    # ---- prefill schedule selection ---------------------------------------
    prefill_compact_task_count = 0
    prefill_compact_schedule = False
    if causal and has_qo_offset == 0:
        k_tiles_per_cta = k_tiles_per_cta_for(causal)
        q_tile_count = ceil_div(m_extent, _MMA_TILER_MN[0])
        k_group_count = ceil_div(max_k_tiles, k_tiles_per_cta)
        rectangular_task_count = q_tile_count * k_group_count
        prefill_compact_task_count = min(
            rectangular_task_count,
            _causal_compact_task_bound(m_extent, int(max_seqlen_k), k_tiles_per_cta),
        )
        # Pick the compact schedule only when it needs at most 95% of the
        # tasks of the rectangular one.
        prefill_compact_schedule = prefill_compact_task_count * 20 <= rectangular_task_count * 19
    if prefill_compact_schedule:
        # NOTE(review): presumably the compact schedule does not write every
        # output entry, hence the -inf pre-fill -- confirm against the kernel.
        scores.fill_(float("-inf"))

    # ---- kernel arguments -------------------------------------------------
    def _gmem_ptr(element_type, tensor, align):
        # Wrap a tensor's raw address as a global-memory CuTe pointer.
        return make_ptr(
            element_type,
            tensor.data_ptr(),
            cute.AddressSpace.gmem,
            assumed_align=align,
        )

    # Insertion order matches the positional order expected by the kernel.
    kernel_ptrs = {
        "q_ptr": _gmem_ptr(cutlass.Uint8, q_bytes, 128),
        "k_ptr": _gmem_ptr(cutlass.Uint8, k_bytes, 128),
        "q_scale_ptr": _gmem_ptr(spec.cutlass_scale_dtype, q_scale_arg, scale_assumed_align),
        "k_scale_ptr": _gmem_ptr(spec.cutlass_scale_dtype, k_scale_arg, scale_assumed_align),
        "scores_ptr": _gmem_ptr(cutlass.Float32, scores, 16),
        "kv_indices_ptr": _gmem_ptr(cutlass.Int32, kv_indices, 4),
        "cu_seqlens_q_ptr": _gmem_ptr(cutlass.Int32, cu_seqlens_q, 4),
        "cu_seqlens_k_ptr": _gmem_ptr(cutlass.Int32, cu_seqlens_k, 4),
        "cu_page_offsets_ptr": _gmem_ptr(cutlass.Int32, cu_page_offsets, 4),
        "qo_offset_ptr": _gmem_ptr(cutlass.Int32, qo_offset_arg, 4),
    }
    problem_size = (
        Int32(m_extent),
        Int32(n_aligned),
        Int32(_HEAD_DIM),
        Int32(batch * heads_q),
        Int32(page_count * heads_k),
        Int32(heads_q),
        Int32(heads_k),
        Int32(batch),
        Int32(max_k_tiles),
        Int32(total_q),
        Int32(has_qo_offset),
        Int32(prefill_compact_task_count),
    )
    stream = cuda.CUstream(torch.cuda.current_stream(device).cuda_stream)

    compiled = _compile_fp4_qk_kernel(
        fmt=spec,
        causal=causal,
        preordered_q_scale_tma=use_preordered_q_scale_tma,
        compact_schedule=prefill_compact_schedule,
        device_arch=device_arch,
        use_tmem_load_red=use_tmem_load_red,
        **kernel_ptrs,
        problem_size=problem_size,
        stream=stream,
    )
    compiled(*kernel_ptrs.values(), problem_size, stream)
    return scores
1057
+
1058
+
1059
# Public surface of this module: only the block-score entry point is exported.
__all__ = ["fp4_indexer_block_scores"]
build/torch211-cxx11-cu128-x86_64-linux/interface.py ADDED
@@ -0,0 +1,2011 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Sparse attention interface.
5
+
6
+ Current delivery scope:
7
+ - head dimension is supported only for D=128
8
+
9
+ Public API:
10
+ sparse_atten_func(...)
11
+ sparse_decode_atten_func(...)
12
+ SparseDecodePagedAttentionWrapper
13
+
14
+ Internal forward core:
15
+ _sparse_atten_csr_varlen_forward(...)
16
+
17
+ Preprocessing (external, done once):
18
+ q2k_indices [head_kv, total_q, topK] -> sparse_index_utils.build_k2q_csr()
19
+ -> k2q_row_ptr [head_kv, total_rows + 1] int32
20
+ -> k2q_q_indices [head_kv, total_q * topK] int32
21
+ """
22
+
23
+ import math
24
+ import os
25
+ from typing import Optional
26
+
27
+ import cutlass
28
+ import cutlass.cute as cute
29
+ import torch
30
+ from cutlass import Float32, Int32
31
+ from cutlass.cute.runtime import from_dlpack
32
+
33
+ from .src.sm100.fwd.combine import combine
34
+ from .src.sm100.fwd.atten_fwd import SparseAttentionForwardSm100
35
+ from .src.sm100.fwd.atten_fwd_nvfp4_kv import SparseAttentionForwardNvfp4KvSm100
36
+ from .src.sm100.prepare_scheduler import (
37
+ SparseAttentionSchedule,
38
+ prepare_sparse_fwd_schedule_and_split,
39
+ )
40
+ from .src.sm100.decode_schedule import (
41
+ DecodeAttentionSchedule,
42
+ prepare_decode_schedule,
43
+ )
44
+ from .src.common.cute_dsl_utils import to_cute_tensor as to_cute_tensor_kvouter
45
+ from .src.common.tma_utils import (
46
+ create_q_gather4_tma_desc,
47
+ )
48
+
49
# Module-level cache dictionary.
# NOTE(review): presumably memoizes compiled kernels -- its users are outside
# this part of the file; confirm.
_compile_cache: dict = {}
# NOTE(review): absolute tolerance for a temperature/LSE fast path; the
# consumer is not visible here -- confirm semantics.
_TEMPERATURE_LSE_FAST_PATH_ABS_TOL = 1e-12
# topK values accepted by _validate_csr_varlen_inputs.
_SUPPORTED_SPARSE_TOPK = (4, 8, 16, 32)
# Q storage dtypes accepted by _validate_csr_varlen_inputs.
_SUPPORTED_FWD_DTYPES = (torch.bfloat16, torch.float8_e4m3fn)
# MMA compute dtypes accepted by _normalize_forward_mma_dtype.
_SUPPORTED_FWD_MMA_DTYPES = (torch.bfloat16, torch.float8_e4m3fn)
# NOTE(review): looks like the only Q-heads-per-KV-head ratio the decode path
# supports; the consumer is not visible here -- confirm.
_SUPPORTED_DECODE_QHEAD_PER_KV = 16
55
+
56
+
57
+ def _normalize_partial_dtype(partial_dtype: torch.dtype) -> torch.dtype:
58
+ supported = {torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn}
59
+ if partial_dtype not in supported:
60
+ raise TypeError(
61
+ "partial_dtype must be one of torch.float32 / torch.bfloat16 / "
62
+ "torch.float16 / torch.float8_e4m3fn, "
63
+ f"got {partial_dtype}"
64
+ )
65
+ return partial_dtype
66
+
67
+
68
def _normalize_forward_mma_dtype(dtype: Optional[torch.dtype], fallback: torch.dtype, name: str) -> torch.dtype:
    """Resolve an optional MMA dtype against ``fallback`` and validate it.

    ``fallback`` is used when ``dtype`` is ``None``.  The resolved dtype must
    be a member of ``_SUPPORTED_FWD_MMA_DTYPES``.

    Raises
    ------
    TypeError
        If the resolved dtype is not supported; ``name`` labels the argument
        in the error message.
    """
    resolved = dtype if dtype is not None else fallback
    if resolved in _SUPPORTED_FWD_MMA_DTYPES:
        return resolved
    raise TypeError(
        f"{name} must be one of torch.bfloat16 / torch.float8_e4m3fn, got {resolved}"
    )
75
+
76
+
77
def _resolve_forward_mma_dtypes(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qk_dtype: Optional[torch.dtype],
    pv_dtype: Optional[torch.dtype],
) -> tuple[torch.dtype, torch.dtype]:
    """Pick and validate the compute dtypes for the QK and PV MMAs.

    ``qk_dtype`` defaults to the storage dtype of ``q``.  ``pv_dtype``
    defaults to the storage dtype of ``v``, except for BF16 ``q`` with FP8
    ``k``/``v``, where it defaults to BF16.

    Returns
    -------
    tuple[torch.dtype, torch.dtype]
        ``(qk_dtype, pv_dtype)`` after normalization.

    Raises
    ------
    TypeError
        If a resolved dtype is not an accepted MMA dtype.
    ValueError
        If a storage dtype differs from its compute dtype in a way other than
        FP8 storage staged to BF16 compute (allowed for K and V only).
    """
    bf16 = torch.bfloat16
    fp8 = torch.float8_e4m3fn

    qk_dtype = _normalize_forward_mma_dtype(qk_dtype, q.dtype, "qk_dtype")
    if pv_dtype is None:
        # Preserve the historical fp8 KV-cache path: BF16 Q with FP8 K/V
        # stages both K and V as BF16 compute operands.
        legacy_fp8_kv_cache = q.dtype == bf16 and k.dtype == fp8 and v.dtype == fp8
        pv_dtype = bf16 if legacy_fp8_kv_cache else v.dtype
    pv_dtype = _normalize_forward_mma_dtype(pv_dtype, pv_dtype, "pv_dtype")

    if q.dtype != qk_dtype:
        raise ValueError(
            "qk_dtype must match q storage dtype; Q fp8->bf16 staging is not supported"
        )
    if k.dtype != qk_dtype and not (k.dtype == fp8 and qk_dtype == bf16):
        raise ValueError(
            "unsupported K storage/qk_dtype combination; only fp8 K -> bf16 QK staging is supported"
        )
    if v.dtype != pv_dtype and not (v.dtype == fp8 and pv_dtype == bf16):
        raise ValueError(
            "unsupported V storage/pv_dtype combination; only fp8 V -> bf16 PV staging is supported"
        )
    return qk_dtype, pv_dtype
113
+
114
+
115
def _to_cute_tensor_meta(t: torch.Tensor, assumed_align: int = 4):
    """Convert ``t`` to a CuTe tensor via DLPack with a dynamic layout.

    The tensor is detached, wrapped with ``from_dlpack`` using
    ``assumed_align`` and ``enable_tvm_ffi=True``, and its layout is marked
    dynamic with ``leading_dim=0``.
    """
    detached = t.detach()
    cute_tensor = from_dlpack(
        detached,
        assumed_align=assumed_align,
        enable_tvm_ffi=True,
    )
    return cute_tensor.mark_layout_dynamic(leading_dim=0)
118
+
119
+
120
+ def _torch_dtype_to_cutlass_dtype(dtype: torch.dtype):
121
+ if dtype == torch.bfloat16:
122
+ return cutlass.BFloat16
123
+ if dtype == torch.float16:
124
+ return cutlass.Float16
125
+ if dtype == torch.float8_e4m3fn:
126
+ return cutlass.Float8E4M3FN
127
+ raise TypeError(
128
+ f"Only torch.bfloat16, torch.float16, torch.float8_e4m3fn supported, got {dtype}"
129
+ )
130
+
131
+
132
+ def _prepare_paged_kv_for_tma(k, v, blk_kv: int):
133
+ page_size = int(k.shape[2])
134
+ if page_size != blk_kv:
135
+ raise ValueError(f"Sparse Page Attention requires page_size == blk_kv, got {page_size} vs {blk_kv}")
136
+ return k, v
137
+
138
+
139
+ def _validate_cu_seqlens(
140
+ cu_seqlens: torch.Tensor,
141
+ *,
142
+ name: str,
143
+ device: torch.device,
144
+ ) -> None:
145
+ if cu_seqlens.device != device:
146
+ raise ValueError(f"{name} must be on the same device as q")
147
+ if cu_seqlens.dtype != torch.int32:
148
+ raise TypeError(f"{name} must be torch.int32")
149
+ if cu_seqlens.ndim != 1:
150
+ raise ValueError(f"{name} must have shape [B + 1]")
151
+ if cu_seqlens.shape[0] < 1:
152
+ raise ValueError(f"{name} must have at least one element")
153
+ if not cu_seqlens.is_contiguous():
154
+ raise ValueError(f"{name} must be contiguous")
155
+
156
+
157
+ def _csr_row_capacity(k2q_row_ptr: torch.Tensor) -> int:
158
+ return int(k2q_row_ptr.shape[1] - 1)
159
+
160
+
161
def _validate_csr_varlen_inputs(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k2q_row_ptr: torch.Tensor,
    k2q_q_indices: torch.Tensor,
    topK: int,
    blk_kv: int,
    page_table: Optional[torch.Tensor],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    seqused_k: Optional[torch.Tensor],
) -> tuple[int, int]:
    """Validate the inputs of the CSR varlen sparse forward pass.

    Checks Q/K/V rank, dtype, device and head dimension (only D=128), the CSR
    metadata tensors, the cumulative sequence lengths, the optional
    ``page_table`` / ``seqused_k`` and finally ``topK``.  Without
    ``page_table`` K/V must be ``[total_k, Hkv, D]``; with it they must be
    ``[num_pages, Hkv, page_size, D]`` and ``page_size`` must equal
    ``blk_kv``.

    Returns
    -------
    tuple[int, int]
        ``(batch, head_kv)`` where ``batch`` is ``len(cu_seqlens_q) - 1``.

    Raises
    ------
    ValueError, TypeError, NotImplementedError
        On the first violated requirement, in the order checked below.
    """
    fp8 = torch.float8_e4m3fn
    paged = page_table is not None
    paged_rank_message = (
        "Sparse Page Attention requires k and v to have shape "
        "[num_pages, Hkv, page_size, D]"
    )

    # ---- Q / K / V basics -------------------------------------------------
    if q.ndim != 3:
        raise ValueError("CSR sparse forward requires q to have shape [total_q, Hq, D]")
    if q.dtype not in _SUPPORTED_FWD_DTYPES:
        raise TypeError(
            "CSR sparse forward supports only torch.bfloat16 and "
            f"torch.float8_e4m3fn Q/K/V, got {q.dtype}"
        )
    if not (q.device == k.device and q.device == v.device):
        raise ValueError("q, k, v must be on the same device")
    bf16_q_with_fp8_kv = q.dtype == torch.bfloat16 and k.dtype == fp8 and v.dtype == fp8
    uniform_dtype = q.dtype == k.dtype and q.dtype == v.dtype
    if not (bf16_q_with_fp8_kv or uniform_dtype):
        raise ValueError(
            "q, k, v must have the same dtype, except q=bf16 with fp8_e4m3 K/V cache"
        )
    head_dim = q.shape[-1]
    if head_dim != k.shape[-1] or head_dim != v.shape[-1]:
        raise ValueError("q, k, v must have the same head dimension")
    if head_dim != 128:
        raise NotImplementedError(
            f"CSR sparse forward currently supports only D=128, got D={head_dim}"
        )

    # ---- KV head count ----------------------------------------------------
    if paged:
        if k.ndim != 4 or v.ndim != 4:
            raise ValueError(paged_rank_message)
        if k.shape[1] != v.shape[1] or k.shape[-1] != v.shape[-1]:
            raise ValueError(
                "Sparse Page Attention k and v must have the same Hkv and D"
            )
        head_kv = k.shape[1]
    else:
        if k.shape[-2] != v.shape[-2] or k.shape[-1] != v.shape[-1]:
            raise ValueError("k and v must have the same [Hkv, D] tail dimensions")
        head_kv = k.shape[-2]

    # ---- CSR metadata basics ----------------------------------------------
    if q.device != k2q_row_ptr.device or q.device != k2q_q_indices.device:
        raise ValueError("CSR metadata must be on the same device as q")
    if k2q_row_ptr.dtype != torch.int32 or k2q_q_indices.dtype != torch.int32:
        raise TypeError("CSR metadata tensors must be torch.int32")
    if k2q_row_ptr.ndim != 2 or k2q_q_indices.ndim != 2:
        raise ValueError("k2q_row_ptr and k2q_q_indices must be rank-2")

    # ---- cumulative sequence lengths ---------------------------------------
    _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q", device=q.device)
    _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k", device=q.device)
    if cu_seqlens_k.shape != cu_seqlens_q.shape:
        raise ValueError("cu_seqlens_k must have shape [B + 1] matching cu_seqlens_q")
    batch = int(cu_seqlens_q.shape[0] - 1)
    total_q = q.shape[0]

    # ---- head grouping and CSR sizes ---------------------------------------
    head_q = q.shape[1]
    if head_q % head_kv != 0:
        raise ValueError("q.shape[1] must be divisible by Hkv")
    qhead_per_kv = head_q // head_kv
    if qhead_per_kv not in (1, 2, 4, 8, 16):
        raise NotImplementedError(
            "CSR forward is currently supported only for qhead_per_kv in {1, 2, 4, 8, 16}"
        )
    if k2q_row_ptr.shape[0] != head_kv or k2q_q_indices.shape[0] != head_kv:
        raise ValueError("CSR metadata head dimension must match KV head count")
    if k2q_q_indices.shape[1] < total_q * topK:
        raise ValueError(
            f"k2q_q_indices.shape[1] ({k2q_q_indices.shape[1]}) must be >= total_q * topK ({total_q * topK})"
        )
    if k2q_row_ptr.shape[1] < 1:
        raise ValueError("k2q_row_ptr must contain at least one row pointer column")

    # ---- layout-specific K/V checks ----------------------------------------
    if not paged:
        if seqused_k is not None:
            raise ValueError("seqused_k is only supported together with page_table")
        total_k = k.shape[0]
        if k.ndim != 3 or v.ndim != 3:
            raise ValueError("Sparse Attention requires k and v to have shape [total_k, Hkv, D]")
        expected_kv_shape = (total_k, head_kv, q.shape[-1])
        if k.shape != expected_kv_shape or v.shape != expected_kv_shape:
            raise ValueError("Sparse Attention k and v must match [total_k, Hkv, D]")
    else:
        if page_table.device != q.device:
            raise ValueError("page_table must be on the same device as q")
        if page_table.dtype != torch.int32:
            raise TypeError("page_table must be torch.int32")
        if page_table.ndim != 2 or page_table.shape[0] != batch:
            raise ValueError("page_table must have shape [B, max_num_pages_per_seq]")
        if page_table.stride(-1) != 1:
            raise ValueError("page_table must be contiguous in the last dimension")
        if k.ndim != 4 or v.ndim != 4:
            raise ValueError(paged_rank_message)
        if k.shape != v.shape:
            raise ValueError(f"k and v must have the same shape, got {k.shape} and {v.shape}")
        if k.shape[1] != head_kv or k.shape[3] != q.shape[-1]:
            raise ValueError(
                "Sparse Page Attention k and v must match "
                "[num_pages, Hkv, page_size, D]"
            )
        page_size = int(k.shape[2])
        if page_size != blk_kv:
            raise ValueError(
                f"Unsupported Sparse Page Attention page_size={page_size} for blk_kv={blk_kv}; "
                "require page_size == blk_kv"
            )
        if seqused_k is not None:
            if seqused_k.device != q.device:
                raise ValueError("seqused_k must be on the same device as q")
            if seqused_k.dtype != torch.int32:
                raise TypeError("seqused_k must be torch.int32")
            if seqused_k.shape != (batch,):
                raise ValueError("seqused_k must have shape [B]")
            if not seqused_k.is_contiguous():
                raise ValueError("seqused_k must be contiguous")

    # ---- top-K --------------------------------------------------------------
    if topK not in _SUPPORTED_SPARSE_TOPK:
        raise ValueError(
            f"CSR sparse forward supports topK in {_SUPPORTED_SPARSE_TOPK}, got {topK}"
        )
    return batch, head_kv
300
+
301
+
302
+ def _validate_csr_varlen_nvfp4_kv_inputs(
303
+ q: torch.Tensor,
304
+ k: torch.Tensor,
305
+ v: torch.Tensor,
306
+ k_scale_128x4: torch.Tensor,
307
+ v_scale_128x4: torch.Tensor,
308
+ k_global_scale: Optional[torch.Tensor],
309
+ v_global_scale: Optional[torch.Tensor],
310
+ k2q_row_ptr: torch.Tensor,
311
+ k2q_q_indices: torch.Tensor,
312
+ topK: int,
313
+ blk_kv: int,
314
+ page_table: Optional[torch.Tensor],
315
+ cu_seqlens_q: torch.Tensor,
316
+ cu_seqlens_k: torch.Tensor,
317
+ seqused_k: Optional[torch.Tensor],
318
+ ) -> tuple[int, int]:
319
+ if q.ndim != 3:
320
+ raise ValueError("KVFP4 CSR sparse forward requires q to have shape [total_q, Hq, D]")
321
+ if q.dtype not in (torch.bfloat16, torch.float8_e4m3fn):
322
+ raise TypeError(f"KVFP4 CSR sparse forward requires BF16 or FP8 E4M3 q, got {q.dtype}")
323
+ if q.shape[-1] != 128:
324
+ raise NotImplementedError(
325
+ f"KVFP4 CSR sparse forward currently supports only D=128, got {q.shape[-1]}"
326
+ )
327
+ if k.dtype != torch.uint8 or v.dtype != torch.uint8:
328
+ raise TypeError(f"KVFP4 k/v must be torch.uint8, got {k.dtype} and {v.dtype}")
329
+ if k_scale_128x4.dtype != torch.uint8 or v_scale_128x4.dtype != torch.uint8:
330
+ raise TypeError(
331
+ "KVFP4 block scales must be torch.uint8 E4M3 tensors, got "
332
+ f"{k_scale_128x4.dtype} and {v_scale_128x4.dtype}"
333
+ )
334
+ if k_global_scale is not None and k_global_scale.dtype != torch.float32:
335
+ raise TypeError("KVFP4 K global scale must be a torch.float32 tensor or None")
336
+ if v_global_scale is not None and v_global_scale.dtype != torch.float32:
337
+ raise TypeError("KVFP4 V global scale must be a torch.float32 tensor or None")
338
+ tensors = (
339
+ k,
340
+ v,
341
+ k_scale_128x4,
342
+ v_scale_128x4,
343
+ k2q_row_ptr,
344
+ k2q_q_indices,
345
+ cu_seqlens_q,
346
+ cu_seqlens_k,
347
+ )
348
+ optional_tensors = tuple(t for t in (k_global_scale, v_global_scale) if t is not None)
349
+ if any(t.device != q.device for t in tensors + optional_tensors):
350
+ raise ValueError("KVFP4 inputs and metadata must be on the same device as q")
351
+ if k.shape != v.shape:
352
+ raise ValueError(f"KVFP4 k and v must have the same shape, got {k.shape} and {v.shape}")
353
+ packed_dim = q.shape[-1] // 2
354
+ scale_cols = q.shape[-1] // 16
355
+ if k_scale_128x4.ndim != 2 or v_scale_128x4.ndim != 2:
356
+ raise ValueError("KVFP4 block scales must be rank-2 128x4 tiled tensors")
357
+ if k_scale_128x4.shape[1] < scale_cols or v_scale_128x4.shape[1] < scale_cols:
358
+ raise ValueError(
359
+ "KVFP4 block scales must have at least D/16 columns; "
360
+ f"need {scale_cols}, got {k_scale_128x4.shape[1]} and {v_scale_128x4.shape[1]}"
361
+ )
362
+ if k_global_scale is not None and k_global_scale.numel() < 1:
363
+ raise ValueError("KVFP4 K global scale must contain at least one element")
364
+ if v_global_scale is not None and v_global_scale.numel() < 1:
365
+ raise ValueError("KVFP4 V global scale must contain at least one element")
366
+
367
+ if page_table is None:
368
+ if seqused_k is not None:
369
+ raise ValueError("seqused_k is only supported together with page_table")
370
+ if k.ndim != 3:
371
+ raise ValueError("KVFP4 Sparse Attention requires k/v shape [total_k, Hkv, D/2]")
372
+ if k.shape[-1] != packed_dim:
373
+ raise ValueError(f"KVFP4 packed K/V last dimension must be D/2={packed_dim}")
374
+ total_k = int(k.shape[0])
375
+ head_kv = int(k.shape[1])
376
+ required_scale_rows = total_k * head_kv
377
+ else:
378
+ if k.ndim != 4:
379
+ raise ValueError(
380
+ "KVFP4 Sparse Page Attention requires k/v shape "
381
+ "[num_pages, Hkv, page_size, D/2]"
382
+ )
383
+ if k.shape[-1] != packed_dim:
384
+ raise ValueError(f"KVFP4 packed K/V last dimension must be D/2={packed_dim}")
385
+ page_size = int(k.shape[2])
386
+ if page_size != int(blk_kv):
387
+ raise ValueError(
388
+ f"KVFP4 Sparse Page Attention requires page_size == blk_kv, got {page_size} vs {blk_kv}"
389
+ )
390
+ head_kv = int(k.shape[1])
391
+ required_scale_rows = int(k.shape[0]) * head_kv * page_size
392
+ if page_table.device != q.device:
393
+ raise ValueError("page_table must be on the same device as q")
394
+ if page_table.dtype != torch.int32:
395
+ raise TypeError("page_table must be torch.int32")
396
+ if page_table.ndim != 2:
397
+ raise ValueError("page_table must have shape [B, max_num_pages_per_seq]")
398
+ if page_table.stride(-1) != 1:
399
+ raise ValueError("page_table must be contiguous in the last dimension")
400
+ if seqused_k is not None:
401
+ if seqused_k.device != q.device:
402
+ raise ValueError("seqused_k must be on the same device as q")
403
+ if seqused_k.dtype != torch.int32:
404
+ raise TypeError("seqused_k must be torch.int32")
405
+ if not seqused_k.is_contiguous():
406
+ raise ValueError("seqused_k must be contiguous")
407
+
408
+ padded_scale_rows = ((required_scale_rows + 127) // 128) * 128
409
+ padded_scale_cols = ((scale_cols + 3) // 4) * 4
410
+ for name, scale in (("k_scale_128x4", k_scale_128x4), ("v_scale_128x4", v_scale_128x4)):
411
+ if scale.shape[0] < padded_scale_rows or scale.shape[1] < padded_scale_cols:
412
+ raise ValueError(
413
+ f"{name} is too small for 128x4 layout: got {tuple(scale.shape)}, "
414
+ f"need at least {(padded_scale_rows, padded_scale_cols)}"
415
+ )
416
+
417
+ if k2q_row_ptr.device != q.device or k2q_q_indices.device != q.device:
418
+ raise ValueError("CSR metadata must be on the same device as q")
419
+ if k2q_row_ptr.dtype != torch.int32 or k2q_q_indices.dtype != torch.int32:
420
+ raise TypeError("CSR metadata tensors must be torch.int32")
421
+ if k2q_row_ptr.ndim != 2 or k2q_q_indices.ndim != 2:
422
+ raise ValueError("k2q_row_ptr and k2q_q_indices must be rank-2")
423
+ _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q", device=q.device)
424
+ _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k", device=q.device)
425
+ if cu_seqlens_k.shape != cu_seqlens_q.shape:
426
+ raise ValueError("cu_seqlens_k must have shape [B + 1] matching cu_seqlens_q")
427
+ batch = int(cu_seqlens_q.shape[0] - 1)
428
+ if page_table is not None and page_table.shape[0] != batch:
429
+ raise ValueError("page_table must have shape [B, max_num_pages_per_seq]")
430
+ if seqused_k is not None and seqused_k.shape != (batch,):
431
+ raise ValueError("seqused_k must have shape [B]")
432
+ head_q = int(q.shape[1])
433
+ if head_q % head_kv != 0:
434
+ raise ValueError("q.shape[1] must be divisible by Hkv")
435
+ qhead_per_kv = head_q // head_kv
436
+ if qhead_per_kv not in (1, 2, 4, 8, 16):
437
+ raise NotImplementedError(
438
+ "KVFP4 CSR forward is currently supported only for qhead_per_kv in {1, 2, 4, 8, 16}"
439
+ )
440
+ if k2q_row_ptr.shape[0] != head_kv or k2q_q_indices.shape[0] != head_kv:
441
+ raise ValueError("CSR metadata head dimension must match KV head count")
442
+ if k2q_q_indices.shape[1] < q.shape[0] * topK:
443
+ raise ValueError(
444
+ f"k2q_q_indices.shape[1] ({k2q_q_indices.shape[1]}) must be >= total_q * topK ({q.shape[0] * topK})"
445
+ )
446
+ if k2q_row_ptr.shape[1] < 1:
447
+ raise ValueError("k2q_row_ptr must contain at least one row pointer column")
448
+ if topK not in _SUPPORTED_SPARSE_TOPK:
449
+ raise ValueError(
450
+ f"KVFP4 CSR sparse forward supports topK in {_SUPPORTED_SPARSE_TOPK}, got {topK}"
451
+ )
452
+ return batch, head_kv
453
+
454
+
455
+ def _validate_sparse_decode_inputs(
456
+ q: torch.Tensor,
457
+ k: torch.Tensor,
458
+ v: torch.Tensor,
459
+ q2k_indices: Optional[torch.Tensor],
460
+ *,
461
+ page_table: torch.Tensor,
462
+ seqused_k: torch.Tensor,
463
+ seqlen_q: int,
464
+ max_seqlen_k: int,
465
+ blk_kv: int,
466
+ causal: bool,
467
+ ) -> tuple[int, int]:
468
+ if q.ndim != 3:
469
+ raise ValueError("decode attention requires q to have shape [total_q, Hq, D]")
470
+ if k.ndim != 4 or v.ndim != 4:
471
+ raise ValueError(
472
+ "decode attention requires paged k/v with shape [num_pages, Hkv, page_size, D]"
473
+ )
474
+ if q.device != k.device or q.device != v.device:
475
+ raise ValueError("decode q, k, and v must be on the same device")
476
+ if q.dtype != torch.float8_e4m3fn or k.dtype != q.dtype or v.dtype != q.dtype:
477
+ raise TypeError(
478
+ "decode attention currently supports only torch.float8_e4m3fn Q/K/V"
479
+ )
480
+ if k.shape != v.shape:
481
+ raise ValueError(f"decode k and v must have the same shape, got {k.shape} and {v.shape}")
482
+ if q.shape[-1] != 128 or k.shape[-1] != 128:
483
+ raise NotImplementedError(
484
+ f"decode attention currently supports only D=128, got q={q.shape[-1]} k={k.shape[-1]}"
485
+ )
486
+ if not bool(causal):
487
+ raise NotImplementedError("decode attention currently supports only causal=True")
488
+ page_size = int(k.shape[2])
489
+ if page_size != int(blk_kv):
490
+ raise ValueError(f"decode attention requires page_size == blk_kv, got {page_size} vs {blk_kv}")
491
+
492
+ head_kv = int(k.shape[1])
493
+ head_q = int(q.shape[1])
494
+ if head_q % head_kv != 0:
495
+ raise ValueError("decode q.shape[1] must be divisible by Hkv")
496
+ qhead_per_kv = head_q // head_kv
497
+ if qhead_per_kv != _SUPPORTED_DECODE_QHEAD_PER_KV:
498
+ raise NotImplementedError(
499
+ "decode attention currently supports only "
500
+ f"qhead_per_kv={_SUPPORTED_DECODE_QHEAD_PER_KV}, got {qhead_per_kv}"
501
+ )
502
+
503
+ if page_table is None:
504
+ raise ValueError("decode attention requires page_table")
505
+ if page_table.device != q.device:
506
+ raise ValueError("decode page_table must be on the same device as q")
507
+ if page_table.dtype != torch.int32:
508
+ raise TypeError("decode page_table must be torch.int32")
509
+ if page_table.ndim != 2:
510
+ raise ValueError("decode page_table must have shape [B, max_num_pages_per_seq]")
511
+ batch = int(page_table.shape[0])
512
+ if page_table.stride(-1) != 1:
513
+ raise ValueError("decode page_table must be contiguous in the last dimension")
514
+
515
+ if seqused_k is None:
516
+ raise ValueError("decode attention requires seqused_k")
517
+ if seqused_k.device != q.device:
518
+ raise ValueError("decode seqused_k must be on the same device as q")
519
+ if seqused_k.dtype != torch.int32:
520
+ raise TypeError("decode seqused_k must be torch.int32")
521
+ if seqused_k.shape != (batch,):
522
+ raise ValueError("decode seqused_k must have shape [B]")
523
+ if not seqused_k.is_contiguous():
524
+ raise ValueError("decode seqused_k must be contiguous")
525
+
526
+ seqlen_q = int(seqlen_q)
527
+ max_seqlen_k = int(max_seqlen_k)
528
+ if seqlen_q <= 0 or max_seqlen_k <= 0:
529
+ raise ValueError("decode seqlen_q and max_seqlen_k must be positive")
530
+ if int(q.shape[0]) != batch * seqlen_q:
531
+ raise ValueError("decode q.shape[0] must equal batch * seqlen_q")
532
+
533
+ if q2k_indices is not None:
534
+ if q2k_indices.device != q.device:
535
+ raise ValueError("decode q2k_indices must be on the same device as q")
536
+ if q2k_indices.dtype != torch.int32:
537
+ raise TypeError("decode q2k_indices must be torch.int32")
538
+ if q2k_indices.ndim != 3:
539
+ raise ValueError("decode q2k_indices must have shape [Hkv, total_q, topK]")
540
+ if q2k_indices.shape[0] != head_kv or q2k_indices.shape[1] != q.shape[0]:
541
+ raise ValueError("decode q2k_indices must match [Hkv, total_q, topK]")
542
+ if not q2k_indices.is_contiguous():
543
+ raise ValueError("decode q2k_indices must be contiguous")
544
+ return batch, head_kv
545
+
546
+
547
+ def _validate_schedule_common(
548
+ schedule: SparseAttentionSchedule,
549
+ *,
550
+ device: torch.device,
551
+ ) -> None:
552
+ if schedule.scheduler_metadata is None:
553
+ raise ValueError("schedule.scheduler_metadata is required")
554
+ if schedule.work_count is None:
555
+ raise ValueError("schedule.work_count is required")
556
+ metadata = schedule.scheduler_metadata
557
+ work_count = schedule.work_count
558
+ if metadata.device != device or work_count.device != device:
559
+ raise ValueError("schedule tensors must be on the same device as q")
560
+ if metadata.dtype != torch.int32 or work_count.dtype != torch.int32:
561
+ raise TypeError("schedule.scheduler_metadata and schedule.work_count must be torch.int32")
562
+ if metadata.ndim != 2 or metadata.shape[1] != 6:
563
+ raise ValueError("schedule.scheduler_metadata must have shape [capacity, 6]")
564
+ if work_count.shape != (1,):
565
+ raise ValueError("schedule.work_count must have shape [1]")
566
+ if not metadata.is_contiguous() or not work_count.is_contiguous():
567
+ raise ValueError("schedule.scheduler_metadata and schedule.work_count must be contiguous")
568
+
569
+
570
def _validate_fwd_schedule(
    schedule: SparseAttentionSchedule,
    *,
    q: torch.Tensor,
    k2q_q_indices: torch.Tensor,
    head_kv: int,
) -> None:
    """Validate a prebuilt schedule for use by the sparse forward pass.

    Runs the shared schedule checks first, then the forward-only tensors
    ``qsplit_indices`` and ``split_counts``.

    Parameters
    ----------
    schedule : SparseAttentionSchedule
        Schedule to validate.
    q : torch.Tensor
        Query tensor; supplies the expected device and ``q.shape[0]``.
    k2q_q_indices : torch.Tensor
        CSR query indices; ``schedule.qsplit_indices`` must share its shape.
    head_kv : int
        Number of KV heads; second dimension expected of ``split_counts``.

    Raises
    ------
    ValueError
        On a missing tensor or a device, shape or contiguity mismatch.
    TypeError
        If either forward tensor is not ``torch.int32``.
    """
    device = q.device
    _validate_schedule_common(schedule, device=device)

    qsplit = schedule.qsplit_indices
    if qsplit is None:
        raise ValueError("schedule.qsplit_indices is required for forward")
    split_counts = schedule.split_counts
    if split_counts is None:
        raise ValueError("schedule.split_counts is required for forward")

    pair = (qsplit, split_counts)
    if not all(tensor.device == device for tensor in pair):
        raise ValueError("forward schedule tensors must be on the same device as q")
    if not all(tensor.dtype == torch.int32 for tensor in pair):
        raise TypeError("schedule.qsplit_indices and schedule.split_counts must be torch.int32")
    if qsplit.shape != k2q_q_indices.shape:
        raise ValueError("schedule.qsplit_indices shape must match k2q_q_indices")

    total_q = q.shape[0]
    expected_counts_shape = (total_q, head_kv)
    if split_counts.shape != expected_counts_shape:
        raise ValueError(
            "schedule.split_counts must have shape "
            f"({total_q}, {head_kv}), got {tuple(split_counts.shape)}"
        )
    if not (qsplit.is_contiguous() and split_counts.is_contiguous()):
        raise ValueError("schedule.qsplit_indices and schedule.split_counts must be contiguous")
598
+
599
+
600
def sparse_atten_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k2q_row_ptr: torch.Tensor,
    k2q_q_indices: torch.Tensor,
    topK: int,
    *,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    blk_kv: int = 128,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
    lse_temperature_scale: float = 1.0,
    return_temperature_lse: bool = False,
    partial_dtype: torch.dtype = torch.bfloat16,
    return_softmax_lse: bool = False,
    page_table: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    schedule: Optional[SparseAttentionSchedule] = None,
    usable_SM_count: int = -1,
    qk_dtype: Optional[torch.dtype] = None,
    pv_dtype: Optional[torch.dtype] = None,
):
    """Run SM100 CSR block-sparse varlen attention (forward only).

    Public entry point of the sparse forward path. The block selection is
    supplied as CSR metadata (as produced by ``build_k2q_csr``); K/V may be
    in dense or paged layout.

    Parameters
    ----------
    q : torch.Tensor
        ``[total_q, Hq, 128]`` on CUDA, BF16 or FP8 E4M3.
    k : torch.Tensor
        Dense ``[total_k, Hkv, 128]`` or paged
        ``[num_pages, Hkv, blk_kv, 128]``. With BF16 Q and an FP8 K/V cache,
        K may be FP8 E4M3 while QK compute uses BF16 staging.
    v : torch.Tensor
        Same layout and head count as ``k``.
    k2q_row_ptr : torch.Tensor
        int32 CSR row pointers, ``[Hkv, total_rows + 1]``.
    k2q_q_indices : torch.Tensor
        int32 CSR query indices, ``[Hkv, >= total_q * topK]``.
    topK : int
        Selected KV blocks per query; supported values are ``4, 8, 16, 32``.
    cu_seqlens_q, cu_seqlens_k : torch.Tensor
        int32 prefix sums of Q / KV lengths, ``[batch_size + 1]``. Required.
    max_seqlen_q, max_seqlen_k : int
        Maximum Q / KV sequence length in the batch.
    blk_kv : int, optional
        KV block size; paged KV requires ``k.shape[2] == blk_kv``.
    causal : bool, optional
        Whether to apply causal masking.
    softmax_scale : float, optional
        Softmax scale; when ``None`` it is ``q.shape[-1] ** -0.5``.
    lse_temperature_scale : float, optional
        Extra divisor used only for the temperature-scaled LSE output. Must
        be finite and strictly positive.
    return_temperature_lse : bool, optional
        Also return the LSE computed with logits scaled by
        ``softmax_scale / lse_temperature_scale``. Requires
        ``return_softmax_lse=True``.
    partial_dtype : torch.dtype, optional
        Accumulation dtype for per-block partial O (FP32, BF16, FP16 or
        FP8 E4M3).
    return_softmax_lse : bool, optional
        Return ``(out, softmax_lse)`` or
        ``(out, softmax_lse, temperature_lse)`` instead of ``out`` alone.
    page_table : torch.Tensor, optional
        int32 physical page table, ``[batch_size, max_num_pages_per_seq]``.
    seqused_k : torch.Tensor, optional
        int32 effective KV length per request, ``[batch_size]``, for paged
        causal attention.
    schedule : SparseAttentionSchedule, optional
        Prebuilt forward schedule; built during the call when omitted.
    usable_SM_count : int, optional
        Maximum number of SMs the scheduler may use; ``-1`` means all.
    qk_dtype, pv_dtype : torch.dtype, optional
        Compile-time MMA operand dtypes for QK and PV. Default to the Q and
        V storage dtypes, except in supported FP8 K/V cache staging modes.

    Returns
    -------
    torch.Tensor or tuple of torch.Tensor
        BF16 output ``[total_q, Hq, 128]``; optional LSE outputs are float32
        ``[total_q, Hq]``.

    Raises
    ------
    ValueError
        If ``lse_temperature_scale`` is not finite and positive, if
        ``return_temperature_lse`` is set without ``return_softmax_lse``, or
        if either ``cu_seqlens`` tensor is ``None``. Input validation may
        raise further errors.

    Notes
    -----
    ``Hq / Hkv`` must be one of ``1, 2, 4, 8, 16``. Current kernels support
    head dimension 128 only.
    """
    softmax_scale = q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale

    lse_temperature_scale = float(lse_temperature_scale)
    if not (math.isfinite(lse_temperature_scale) and lse_temperature_scale > 0.0):
        raise ValueError(
            f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}"
        )
    return_temperature_lse = bool(return_temperature_lse)
    if return_temperature_lse and not return_softmax_lse:
        raise ValueError("return_temperature_lse=True requires return_softmax_lse=True")

    partial_dtype = _normalize_partial_dtype(partial_dtype)
    qk_dtype, pv_dtype = _resolve_forward_mma_dtypes(q, k, v, qk_dtype, pv_dtype)

    if None in (cu_seqlens_q, cu_seqlens_k):
        raise ValueError(
            "sparse_atten_func requires CSR varlen metadata: pass cu_seqlens_q and cu_seqlens_k"
        )
    batch, head_kv = _validate_csr_varlen_inputs(
        q, k, v,
        k2q_row_ptr, k2q_q_indices,
        topK, blk_kv, page_table,
        cu_seqlens_q, cu_seqlens_k,
        seqused_k,
    )
    max_seqlen_q = int(max_seqlen_q)
    max_seqlen_k = int(max_seqlen_k)

    def _contiguous_or_none(tensor):
        # Optional tensors are forwarded as None rather than materialised.
        return None if tensor is None else tensor.contiguous()

    # The forward implementation takes everything positionally, in this order.
    forward_args = (
        q.contiguous(),
        k.contiguous(),
        v.contiguous(),
        k2q_row_ptr.contiguous(),
        k2q_q_indices.contiguous(),
        int(topK),
        int(blk_kv),
        bool(causal),
        float(softmax_scale),
        lse_temperature_scale,
        return_temperature_lse,
        partial_dtype,
        bool(return_softmax_lse),
        cu_seqlens_q.contiguous(),
        cu_seqlens_k.contiguous(),
        _contiguous_or_none(page_table),
        _contiguous_or_none(seqused_k),
        schedule,
        int(usable_SM_count),
        int(batch),
        int(head_kv),
        int(max_seqlen_q),
        int(max_seqlen_k),
        qk_dtype,
        pv_dtype,
    )
    return _sparse_atten_csr_varlen_forward(*forward_args)
766
+
767
+
768
def sparse_atten_nvfp4_kv_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_scale_128x4: torch.Tensor,
    v_scale_128x4: torch.Tensor,
    k_global_scale: Optional[torch.Tensor],
    v_global_scale: Optional[torch.Tensor],
    k2q_row_ptr: torch.Tensor,
    k2q_q_indices: torch.Tensor,
    topK: int,
    *,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    blk_kv: int = 128,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
    lse_temperature_scale: float = 1.0,
    return_temperature_lse: bool = False,
    partial_dtype: torch.dtype = torch.bfloat16,
    return_softmax_lse: bool = False,
    page_table: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    schedule: Optional[SparseAttentionSchedule] = None,
):
    """Run SM100 CSR sparse attention with packed NVFP4 K/V.

    The function validates its inputs, allocates the per-block partial
    buffers and the final outputs, launches the NVFP4 forward kernel and
    then reduces the partials with ``combine``.

    Parameters
    ----------
    q : torch.Tensor
        ``[total_q, Hq, 128]`` on CUDA, BF16 or FP8 E4M3.
    k : torch.Tensor
        Packed NVFP4 K data, uint8 (two FP4 values per byte). Dense layout
        ``[total_k, Hkv, 64]``; paged layout ``[num_pages, Hkv, blk_kv, 64]``.
    v : torch.Tensor
        Packed NVFP4 V data with the same shape as ``k``.
    k_scale_128x4, v_scale_128x4 : torch.Tensor
        K / V block scales in cuBLAS/cuDNN 128x4 tiled storage; uint8
        holding FP8 E4M3 scale values.
    k_global_scale, v_global_scale : torch.Tensor, optional
        FP32 tensor/global dequant scales; either may be ``None``. The V
        scale is handed to ``combine`` as ``output_scale``.
    k2q_row_ptr : torch.Tensor
        int32 CSR row pointers, ``[Hkv, total_rows + 1]``.
    k2q_q_indices : torch.Tensor
        int32 CSR query indices, ``[Hkv, >= total_q * topK]``.
    topK : int
        Selected KV blocks per query; supported values are ``4, 8, 16, 32``.
    cu_seqlens_q, cu_seqlens_k : torch.Tensor
        int32 prefix sums of Q / KV lengths, ``[batch_size + 1]``. Required.
    max_seqlen_q, max_seqlen_k : int
        Maximum Q / KV sequence lengths in the batch. ``max_seqlen_k`` is
        accepted but not read by this function.
    blk_kv : int, optional
        KV block/page size; paged KV requires ``k.shape[2] == blk_kv``.
    causal : bool, optional
        Whether to apply causal masking.
    softmax_scale : float, optional
        Softmax scale; when ``None`` it is ``q.shape[-1] ** -0.5``.
    lse_temperature_scale : float, optional
        Extra divisor used only for the temperature-scaled LSE output. Must
        be finite and strictly positive.
    return_temperature_lse : bool, optional
        Also return the temperature-scaled LSE. Requires
        ``return_softmax_lse=True``.
    partial_dtype : torch.dtype, optional
        Accumulation dtype for per-block partial O.
    return_softmax_lse : bool, optional
        Return the LSE together with the output.
    page_table : torch.Tensor, optional
        int32 physical page table, ``[batch_size, max_num_pages_per_seq]``.
    seqused_k : torch.Tensor, optional
        Effective KV length per request for paged causal attention.
    schedule : SparseAttentionSchedule, optional
        Prebuilt forward schedule; validated and reused when given.

    Returns
    -------
    torch.Tensor or tuple of torch.Tensor
        BF16 output ``[total_q, Hq, 128]``; with ``return_softmax_lse`` a
        float32 LSE ``[total_q, Hq]`` follows, and with
        ``return_temperature_lse`` a third, temperature-scaled LSE. When the
        temperature scale is within the fast-path tolerance of 1.0 the third
        element is the same tensor object as the second.

    Raises
    ------
    ValueError
        If ``lse_temperature_scale`` is not finite and positive, if
        ``return_temperature_lse`` is set without ``return_softmax_lse``, or
        if either ``cu_seqlens`` tensor is ``None``. Input and schedule
        validation may raise further errors.
    """
    softmax_scale = q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale

    lse_temperature_scale = float(lse_temperature_scale)
    if not (math.isfinite(lse_temperature_scale) and lse_temperature_scale > 0.0):
        raise ValueError(
            f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}"
        )
    return_temperature_lse = bool(return_temperature_lse)
    if return_temperature_lse and not return_softmax_lse:
        raise ValueError("return_temperature_lse=True requires return_softmax_lse=True")
    partial_dtype = _normalize_partial_dtype(partial_dtype)

    if None in (cu_seqlens_q, cu_seqlens_k):
        raise ValueError(
            "sparse_atten_nvfp4_kv_func requires CSR varlen metadata: pass cu_seqlens_q and cu_seqlens_k"
        )
    batch, head_kv = _validate_csr_varlen_nvfp4_kv_inputs(
        q, k, v,
        k_scale_128x4, v_scale_128x4,
        k_global_scale, v_global_scale,
        k2q_row_ptr, k2q_q_indices,
        topK, blk_kv, page_table,
        cu_seqlens_q, cu_seqlens_k,
        seqused_k,
    )
    total_q, head_q, dim = q.shape
    max_num_kv_blocks = _csr_row_capacity(k2q_row_ptr)

    # A temperature scale of (almost exactly) 1 makes the temperature LSE
    # equal to the plain LSE, so the kernel does not have to produce it.
    if return_temperature_lse:
        reuse_plain_lse = math.isclose(
            float(lse_temperature_scale),
            1.0,
            rel_tol=0.0,
            abs_tol=_TEMPERATURE_LSE_FAST_PATH_ABS_TOL,
        )
    else:
        reuse_plain_lse = False
    kernel_wants_temperature_lse = return_temperature_lse and not reuse_plain_lse

    device = q.device
    partial_lse_shape = (topK, total_q, head_q)

    # Per-selected-block partial results, reduced by ``combine`` below.
    partial_out = torch.empty(
        (topK, total_q, head_q, dim), dtype=partial_dtype, device=device
    )
    partial_lse = torch.empty(partial_lse_shape, dtype=torch.float32, device=device)
    partial_temperature_lse = None
    if kernel_wants_temperature_lse:
        partial_temperature_lse = torch.empty(
            partial_lse_shape, dtype=torch.float32, device=device
        )

    # Final outputs.
    final_out = torch.empty((total_q, head_q, dim), dtype=torch.bfloat16, device=device)
    final_lse = torch.empty((total_q, head_q), dtype=torch.float32, device=device)
    final_temperature_lse = None
    if kernel_wants_temperature_lse:
        final_temperature_lse = torch.empty_like(final_lse)

    if schedule is not None:
        _validate_fwd_schedule(
            schedule,
            q=q,
            k2q_q_indices=k2q_q_indices,
            head_kv=head_kv,
        )
        k2q_qsplit_indices = schedule.qsplit_indices
        split_counts = schedule.split_counts
    else:
        k2q_qsplit_indices = torch.empty_like(k2q_q_indices)
        split_counts = torch.zeros((total_q, head_kv), dtype=torch.int32, device=device)

    def _contiguous_or_none(tensor):
        # Optional tensors are forwarded as None rather than materialised.
        return None if tensor is None else tensor.contiguous()

    # The kernel's return value stays bound to ``schedule`` but is not read
    # again in this function.
    schedule = _call_sparse_forward_sm100_csr_varlen_nvfp4_kv(
        q.contiguous(),
        k.contiguous(),
        v.contiguous(),
        k_scale_128x4.contiguous(),
        v_scale_128x4.contiguous(),
        _contiguous_or_none(k_global_scale),
        _contiguous_or_none(v_global_scale),
        k2q_row_ptr.contiguous(),
        k2q_q_indices.contiguous(),
        k2q_qsplit_indices.contiguous(),
        split_counts.contiguous(),
        cu_seqlens_q.contiguous(),
        cu_seqlens_k.contiguous(),
        _contiguous_or_none(page_table),
        _contiguous_or_none(seqused_k),
        partial_out,
        partial_lse,
        partial_temperature_lse,
        float(softmax_scale),
        lse_temperature_scale,
        kernel_wants_temperature_lse,
        max_num_kv_blocks,
        int(blk_kv),
        head_kv,
        int(max_seqlen_q),
        causal=bool(causal),
        schedule=schedule,
    )

    combine(
        partial_out,
        partial_lse,
        final_out,
        final_lse,
        lse_temperature_partial=partial_temperature_lse,
        lse_temperature_out=final_temperature_lse,
        cu_seqlens=cu_seqlens_q,
        split_counts=split_counts,
        output_scale=v_global_scale,
        use_pdl=True,
    )
    if reuse_plain_lse:
        final_temperature_lse = final_lse

    if not return_softmax_lse:
        return final_out
    if return_temperature_lse:
        return final_out, final_lse, final_temperature_lse
    return final_out, final_lse
992
+
993
+
994
def sparse_decode_atten_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q2k_indices: Optional[torch.Tensor] = None,
    *,
    page_table: torch.Tensor,
    seqused_k: torch.Tensor,
    seqlen_q: int,
    max_seqlen_k: int,
    blk_kv: int = 128,
    causal: bool = True,
    softmax_scale: Optional[float] = None,
    return_softmax_lse: bool = False,
    schedule: Optional[DecodeAttentionSchedule] = None,
    O_partial: Optional[torch.Tensor] = None,
    LSE_partial: Optional[torch.Tensor] = None,
):
    """Run forward-only paged FP8 decode attention.

    Parameters
    ----------
    q : torch.Tensor
        ``[batch_size * seqlen_q, Hq, 128]``, FP8 E4M3.
    k : torch.Tensor
        Paged K cache, ``[num_pages, Hkv, blk_kv, 128]``, FP8 E4M3.
    v : torch.Tensor
        Paged V cache with the same shape and dtype as ``k``.
    q2k_indices : torch.Tensor, optional
        int32 sparse selection of KV blocks, ``[Hkv, total_q, topK]``.
        ``None`` selects the dense all-KV decode path.
    page_table : torch.Tensor
        int32 physical page table, ``[batch_size, max_num_pages_per_seq]``.
    seqused_k : torch.Tensor
        int32 effective KV length per request, ``[batch_size]``.
    seqlen_q : int
        Uniform query length per request. Ragged Q lengths should use the
        prefill or append paths instead.
    max_seqlen_k : int
        Maximum KV sequence length in the batch.
    blk_kv : int, optional
        Page size; must match ``k.shape[2]``.
    causal : bool, optional
        Whether to apply causal masking. The current decode kernel requires
        ``True``.
    softmax_scale : float, optional
        Softmax scale; when ``None`` it is ``q.shape[-1] ** -0.5``.
    return_softmax_lse : bool, optional
        Return ``(out, lse)`` instead of ``out`` alone.
    schedule : DecodeAttentionSchedule, optional
        Prebuilt decode schedule; built with ``prepare_decode_schedule``
        when omitted.
    O_partial, LSE_partial : torch.Tensor, optional
        Split-KV partial workspaces, normally owned by
        ``SparseDecodePagedAttentionWrapper``. Allocated here (float32) only
        when the schedule uses split-KV and they were not supplied.

    Returns
    -------
    torch.Tensor or tuple[torch.Tensor, torch.Tensor]
        BF16 output with shape ``q.shape``; the optional LSE is float32 with
        shape ``[batch_size * seqlen_q, Hq]``.
    """
    softmax_scale = q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale

    batch, head_kv = _validate_sparse_decode_inputs(
        q,
        k,
        v,
        q2k_indices,
        page_table=page_table,
        seqused_k=seqused_k,
        seqlen_q=seqlen_q,
        max_seqlen_k=max_seqlen_k,
        blk_kv=blk_kv,
        causal=causal,
    )
    head_q = int(q.shape[1])
    head_dim = int(q.shape[2])
    device = q.device

    if schedule is None:
        schedule = prepare_decode_schedule(
            seqused_k=seqused_k.contiguous(),
            page_size=int(blk_kv),
            seqlen_q=int(seqlen_q),
            num_qo_heads=head_q,
            num_kv_heads=head_kv,
            head_dim=head_dim,
            max_seqlen_k=int(max_seqlen_k),
        )

    # Split-KV needs scratch space for the per-chunk partial results.
    if schedule.split_kv:
        rows = schedule.partial_rows
        if O_partial is None:
            O_partial = torch.empty(
                (rows, head_q, head_dim), dtype=torch.float32, device=device
            )
        if LSE_partial is None:
            LSE_partial = torch.empty((rows, head_q), dtype=torch.float32, device=device)

    out = torch.empty(q.shape, dtype=torch.bfloat16, device=device)
    # A full-size LSE is allocated when it is returned or when split-KV is
    # active; otherwise a (1, Hq) placeholder is passed to the kernel.
    needs_full_lse = return_softmax_lse or schedule.split_kv
    lse_shape = q.shape[:2] if needs_full_lse else (1, head_q)
    lse = torch.empty(lse_shape, dtype=torch.float32, device=device)

    sparse_selection = None if q2k_indices is None else q2k_indices.contiguous()
    _call_sparse_decode_forward_sm100_paged_fp8(
        q.contiguous(),
        k.contiguous(),
        v.contiguous(),
        sparse_selection,
        page_table.contiguous(),
        seqused_k.contiguous(),
        out,
        lse,
        schedule,
        O_partial,
        LSE_partial,
        softmax_scale=float(softmax_scale),
        seqlen_q=int(seqlen_q),
        max_seqlen_k=int(max_seqlen_k),
        blk_kv=int(blk_kv),
        causal=bool(causal),
        return_lse=bool(return_softmax_lse),
    )
    return (out, lse) if return_softmax_lse else out
1123
+
1124
+
1125
+ class SparseDecodePagedAttentionWrapper:
1126
+ """Plan/run helper for paged FP8 decode attention.
1127
+
1128
+ Use this wrapper when the same page table shape and sequence metadata are
1129
+ reused across multiple decode layers. ``plan`` validates metadata and
1130
+ allocates persistent schedules/workspaces; ``run`` then launches the decode
1131
+ kernel with lower per-call overhead than ``sparse_decode_atten_func``.
1132
+ """
1133
+
1134
+ def __init__(self, *, blk_kv: int = 128, causal: bool = True):
1135
+ self.blk_kv = int(blk_kv)
1136
+ self.causal = bool(causal)
1137
+ self.batch: Optional[int] = None
1138
+ self.num_qo_heads: Optional[int] = None
1139
+ self.num_kv_heads: Optional[int] = None
1140
+ self.head_dim: Optional[int] = None
1141
+ self.page_table: Optional[torch.Tensor] = None
1142
+ self.seqused_k: Optional[torch.Tensor] = None
1143
+ self.q2k_indices: Optional[torch.Tensor] = None
1144
+ self.seqlen_q: Optional[int] = None
1145
+ self.max_seqlen_k: Optional[int] = None
1146
+ self.is_sparse: bool = False
1147
+ self.decode_schedule: Optional[DecodeAttentionSchedule] = None
1148
+ self.request_indices: Optional[torch.Tensor] = None
1149
+ self.qo_tile_indices: Optional[torch.Tensor] = None
1150
+ self.kv_tile_indices: Optional[torch.Tensor] = None
1151
+ self.merge_indptr: Optional[torch.Tensor] = None
1152
+ self.o_indptr: Optional[torch.Tensor] = None
1153
+ self.block_valid_mask: Optional[torch.Tensor] = None
1154
+ self.kv_pages: Optional[torch.Tensor] = None
1155
+ self.split_counts: Optional[torch.Tensor] = None
1156
+ self.split_kv: bool = False
1157
+ self.cta_tile_q: int = 0
1158
+ self.num_q_tiles: int = 0
1159
+ self.kv_chunk_size_pages: int = 0
1160
+ self.kv_chunk_size_tokens: int = 0
1161
+ self.work_count: int = 0
1162
+ self.padded_work_count: int = 0
1163
+ self.O_partial: Optional[torch.Tensor] = None
1164
+ self.LSE_partial: Optional[torch.Tensor] = None
1165
+ # Cached dummy buffers used in non-split path to satisfy the kernel's
1166
+ # positional arg signature without per-call torch.empty (saves ~5us
1167
+ # on every run() for small kv).
1168
+ self._O_partial_dummy: Optional[torch.Tensor] = None
1169
+ self._LSE_partial_dummy: Optional[torch.Tensor] = None
1170
+ # When the caller doesn't ask for LSE, the kernel still needs a valid
1171
+ # tensor pointer to write to. Cache a small placeholder so run() can
1172
+ # skip the per-call torch.empty for it as well.
1173
+ self._lse_dummy: Optional[torch.Tensor] = None
1174
+
1175
def plan(
    self,
    *,
    page_table: torch.Tensor,
    seqused_k: torch.Tensor,
    seqlen_q: int,
    max_seqlen_k: int,
    q2k_indices: Optional[torch.Tensor] = None,
    num_qo_heads: Optional[int] = None,
    num_kv_heads: Optional[int] = None,
    head_dim: Optional[int] = 128,
    enable_cuda_graph: bool = False,
    max_grid_size: Optional[int] = None,
    fixed_split_size: Optional[int] = None,
    disable_split_kv: bool = False,
) -> "SparseDecodePagedAttentionWrapper":
    """Prepare decode scheduling metadata and reusable workspaces.

    Parameters
    ----------
    page_table : torch.Tensor
        Shape ``[batch_size, max_num_pages_per_seq]``, dtype int32. Maps
        logical pages to physical KV-cache pages.
    seqused_k : torch.Tensor
        Shape ``[batch_size]``, dtype int32. Effective KV length per
        request.
    seqlen_q : int
        Uniform query length per request. Must be positive.
    max_seqlen_k : int
        Maximum KV sequence length in the batch. Must be positive.
    q2k_indices : torch.Tensor, optional
        Sparse selected KV blocks with shape ``[Hkv, total_q, topK]`` and
        dtype int32. ``None`` selects the dense all-KV path.
    num_qo_heads : int
        Number of Q/O heads. Required (the ``None`` default is rejected)
        and must be a positive multiple of ``num_kv_heads``.
    num_kv_heads : int
        Number of KV heads. Required and must be positive. Current decode
        kernel requires ``num_qo_heads / num_kv_heads == 16`` at run time.
    head_dim : int, optional
        Head dimension. Must be 128.
    enable_cuda_graph : bool, optional
        Build schedule metadata compatible with CUDA graph capture.
    max_grid_size : int, optional
        Override maximum CTA count used by the scheduler.
    fixed_split_size : int, optional
        Force a fixed split-KV chunk size in pages.
    disable_split_kv : bool, optional
        Disable split-KV even for long KV sequences.

    Returns
    -------
    SparseDecodePagedAttentionWrapper
        ``self``, planned and ready for ``run``.

    Raises
    ------
    ValueError
        If a tensor has the wrong rank/shape/device/stride, if
        ``seqlen_q``/``max_seqlen_k`` or a head count is not positive, if
        a head count or ``head_dim`` is ``None``, or if ``num_qo_heads``
        is not divisible by ``num_kv_heads``.
    TypeError
        If ``page_table``, ``seqused_k`` or ``q2k_indices`` is not int32.
    NotImplementedError
        If ``head_dim`` is not 128.
    """
    if page_table.ndim != 2:
        raise ValueError("decode plan requires page_table with shape [B, max_num_pages_per_seq]")
    if page_table.dtype != torch.int32:
        raise TypeError("decode plan requires page_table to be torch.int32")
    if seqused_k.dtype != torch.int32:
        raise TypeError("decode plan requires seqused_k to be torch.int32")
    if not page_table.is_cuda or not seqused_k.is_cuda:
        raise ValueError("decode plan requires page_table and seqused_k to be CUDA tensors")
    if page_table.device != seqused_k.device:
        raise ValueError("decode plan requires page_table and seqused_k on the same device")
    if page_table.stride(-1) != 1:
        raise ValueError("decode plan requires page_table contiguous in the last dimension")
    if seqused_k.shape != (int(page_table.shape[0]),):
        raise ValueError("decode plan requires seqused_k with shape [B]")
    if q2k_indices is not None and q2k_indices.dtype != torch.int32:
        raise TypeError("decode plan requires q2k_indices to be torch.int32")
    if int(seqlen_q) <= 0 or int(max_seqlen_k) <= 0:
        raise ValueError("decode plan requires positive seqlen_q and max_seqlen_k")
    if num_qo_heads is None or num_kv_heads is None or head_dim is None:
        raise ValueError("decode plan requires num_qo_heads, num_kv_heads, and head_dim")
    if int(head_dim) != 128:
        raise NotImplementedError("decode plan currently supports only head_dim=128")
    qo_heads = int(num_qo_heads)
    kv_heads = int(num_kv_heads)
    # Reject non-positive head counts explicitly: a zero num_kv_heads would
    # otherwise surface as a bare ZeroDivisionError from the modulo below,
    # and zero/negative counts could slip through the divisibility test.
    if qo_heads <= 0 or kv_heads <= 0:
        raise ValueError("decode plan requires positive num_qo_heads and num_kv_heads")
    if qo_heads % kv_heads != 0:
        raise ValueError("decode plan requires num_qo_heads divisible by num_kv_heads")

    self.batch = int(page_table.shape[0])
    self.num_qo_heads = qo_heads
    self.num_kv_heads = kv_heads
    self.head_dim = int(head_dim)
    self.page_table = page_table.contiguous()
    self.seqused_k = seqused_k.contiguous()
    self.q2k_indices = None if q2k_indices is None else q2k_indices.contiguous()
    self.seqlen_q = int(seqlen_q)
    self.max_seqlen_k = int(max_seqlen_k)
    self.is_sparse = q2k_indices is not None

    # max_grid_size is hardcoded to num_sms (1 CTA/SM) inside the C++
    # schedule launcher because the decode attn kernel always runs at
    # 1 CTA/SM (its register/smem budget saturates the SM). Callers
    # can still override via the explicit max_grid_size kwarg.
    schedule = prepare_decode_schedule(
        seqused_k=self.seqused_k,
        page_size=self.blk_kv,
        seqlen_q=self.seqlen_q,
        num_qo_heads=self.num_qo_heads,
        num_kv_heads=self.num_kv_heads,
        head_dim=self.head_dim,
        max_seqlen_k=self.max_seqlen_k,
        enable_cuda_graph=bool(enable_cuda_graph),
        max_grid_size=max_grid_size,
        fixed_split_size=fixed_split_size,
        disable_split_kv=bool(disable_split_kv),
    )

    # Mirror the schedule fields on the wrapper so run() and callers can
    # read them without going through the schedule object.
    self.decode_schedule = schedule
    self.request_indices = schedule.request_indices
    self.qo_tile_indices = schedule.qo_tile_indices
    self.kv_tile_indices = schedule.kv_tile_indices
    self.merge_indptr = schedule.merge_indptr
    self.o_indptr = schedule.o_indptr
    self.block_valid_mask = schedule.block_valid_mask
    self.kv_pages = schedule.kv_pages
    self.split_counts = schedule.split_counts
    self.split_kv = schedule.split_kv
    self.cta_tile_q = schedule.cta_tile_q
    self.num_q_tiles = schedule.num_q_tiles
    self.kv_chunk_size_pages = schedule.kv_chunk_size_pages
    self.kv_chunk_size_tokens = schedule.kv_chunk_size_tokens
    self.work_count = schedule.work_count
    self.padded_work_count = schedule.padded_work_count

    device = page_table.device
    if schedule.split_kv:
        self.O_partial = torch.empty(
            (schedule.partial_rows, self.num_qo_heads, self.head_dim),
            dtype=torch.float32,
            device=device,
        )
        self.LSE_partial = torch.empty(
            (schedule.partial_rows, self.num_qo_heads),
            dtype=torch.float32,
            device=device,
        )
        self._O_partial_dummy = None
        self._LSE_partial_dummy = None
    else:
        self.O_partial = None
        self.LSE_partial = None
        # decode_forward_paged_fp8 always wants non-None partial buffers
        # for the kernel's positional arg layout (compile keeps the slot
        # alive even when split_kv=False). Allocate once here and reuse.
        self._O_partial_dummy = torch.empty(
            (1, self.head_dim),
            dtype=torch.float32,
            device=device,
        )
        self._LSE_partial_dummy = torch.empty(
            (1, self.num_qo_heads),
            dtype=torch.float32,
            device=device,
        )
    # LSE dummy is shape (1, head_q) — used when caller doesn't request
    # LSE and the schedule isn't split-KV (split-KV always writes LSE).
    self._lse_dummy = torch.empty(
        (1, self.num_qo_heads),
        dtype=torch.float32,
        device=device,
    )
    return self
1335
+
1336
def run(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    softmax_scale: Optional[float] = None,
    return_softmax_lse: bool = False,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
):
    """Launch decode attention with the metadata cached by ``plan``.

    Parameters
    ----------
    q : torch.Tensor
        Shape ``[batch_size * seqlen_q, Hq, 128]`` and dtype FP8 E4M3.
    k : torch.Tensor
        Paged K cache with shape ``[num_pages, Hkv, blk_kv, 128]``.
    v : torch.Tensor
        Paged V cache with the same shape as ``k``.
    softmax_scale : float, optional
        Softmax scale. When omitted on the dense path it is computed as
        ``q.shape[-1] ** -0.5``.
    return_softmax_lse : bool, optional
        If True, return ``(out, lse)``.
    out : torch.Tensor, optional
        Preallocated BF16 output buffer with shape ``q.shape``.
    lse : torch.Tensor, optional
        Preallocated float32 LSE buffer with shape ``[total_q, Hq]``.

    Returns
    -------
    torch.Tensor or tuple[torch.Tensor, torch.Tensor]
        BF16 output, optionally with float32 LSE.

    Raises
    ------
    RuntimeError
        If ``plan`` has not been called yet.

    Notes
    -----
    When the wrapper was planned with ``q2k_indices``, the call is
    delegated to ``sparse_decode_atten_func``; ``out`` and ``lse`` are not
    forwarded on that branch.
    """
    planned = self.decode_schedule
    if planned is None:
        raise RuntimeError("decode wrapper must be planned before run")

    if self.is_sparse:
        # The sparse path still goes through the validating wrapper for
        # now; only the dense fast path below is collapsed.
        return sparse_decode_atten_func(
            q,
            k,
            v,
            self.q2k_indices,
            page_table=self.page_table,
            seqused_k=self.seqused_k,
            seqlen_q=self.seqlen_q,
            max_seqlen_k=self.max_seqlen_k,
            blk_kv=self.blk_kv,
            causal=self.causal,
            softmax_scale=softmax_scale,
            return_softmax_lse=return_softmax_lse,
            schedule=planned,
            O_partial=self.O_partial,
            LSE_partial=self.LSE_partial,
        )

    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    if out is None:
        out = torch.empty(q.shape, dtype=torch.bfloat16, device=q.device)
    if lse is None:
        needs_real_lse = return_softmax_lse or self.split_kv
        if needs_real_lse:
            # A real LSE is required, so it has to be allocated per call
            # (its shape depends on q).
            lse = torch.empty(q.shape[:2], dtype=torch.float32, device=q.device)
        else:
            # The kernel only needs a valid pointer; reuse the cached dummy.
            lse = self._lse_dummy

    from .src.sm100.fwd_decode import decode_forward_paged_fp8

    decode_forward_paged_fp8(
        q,
        k,
        v,
        self.page_table,
        self.seqused_k,
        out,
        lse,
        planned.request_indices,
        planned.qo_tile_indices,
        planned.kv_tile_indices,
        planned.block_valid_mask,
        planned.split_counts,
        planned.o_indptr,
        planned.merge_indptr,
        self.O_partial,
        self.LSE_partial,
        softmax_scale=float(softmax_scale),
        seqlen_q=self.seqlen_q,
        page_size=self.blk_kv,
        kv_chunk_size_pages=int(planned.kv_chunk_size_pages),
        max_split_count=int(planned.max_split_count),
        split_kv=bool(planned.split_kv),
        causal=self.causal,
        return_lse=bool(return_softmax_lse),
        # Cached dummies avoid a per-call torch.empty inside
        # run_decode_attention.
        O_partial_dummy=self._O_partial_dummy,
        LSE_partial_dummy=self._LSE_partial_dummy,
    )
    return (out, lse) if return_softmax_lse else out
1424
+
1425
+
1426
def _sparse_atten_csr_varlen_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k2q_row_ptr: torch.Tensor,
    k2q_q_indices: torch.Tensor,
    topK: int,
    blk_kv: int,
    causal: bool,
    softmax_scale: float,
    lse_temperature_scale: float,
    return_temperature_lse: bool,
    partial_dtype: torch.dtype,
    return_softmax_lse: bool,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    page_table: Optional[torch.Tensor],
    seqused_k: Optional[torch.Tensor],
    schedule: Optional[SparseAttentionSchedule],
    usable_SM_count: int,
    batch: int,
    head_kv: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    qk_dtype: torch.dtype,
    pv_dtype: torch.dtype,
):
    """Run the CSR varlen sparse forward: partial kernel, then combine.

    Allocates the per-split partial buffers and the final outputs, launches
    ``_call_sparse_forward_sm100_csr_varlen`` to fill the partials, and
    reduces them with ``combine``.

    When a temperature LSE is requested and ``lse_temperature_scale`` is
    within ``_TEMPERATURE_LSE_FAST_PATH_ABS_TOL`` of 1.0, no separate
    temperature buffers are allocated and the regular LSE tensor is
    returned in the temperature slot (the same tensor object, not a copy).

    Returns
    -------
    ``O_out`` alone when ``return_softmax_lse`` is false; otherwise
    ``(O_out, LSE_out)``, extended with the temperature LSE as a third
    element when ``return_temperature_lse`` is set.

    Raises
    ------
    ValueError
        If ``q.shape[1]`` is not divisible by ``head_kv``.
    """
    total_q, head_q, dim = q.shape
    if head_q % head_kv != 0:
        raise ValueError("q.shape[1] must be divisible by head_kv")
    max_num_kv_blocks = _csr_row_capacity(k2q_row_ptr)

    # A temperature scale of (almost exactly) 1 makes the temperature LSE
    # coincide with the plain LSE, so the kernel need not compute it.
    temperature_lse_fast_path = return_temperature_lse and math.isclose(
        float(lse_temperature_scale),
        1.0,
        rel_tol=0.0,
        abs_tol=_TEMPERATURE_LSE_FAST_PATH_ABS_TOL,
    )
    kernel_return_temperature_lse = (
        return_temperature_lse and not temperature_lse_fast_path
    )

    dev = q.device
    O_partial = torch.empty(topK, total_q, head_q, dim, dtype=partial_dtype, device=dev)
    LSE_partial = torch.empty(topK, total_q, head_q, dtype=torch.float32, device=dev)
    if kernel_return_temperature_lse:
        LSE_temperature_partial = torch.empty(
            topK, total_q, head_q, dtype=torch.float32, device=dev
        )
    else:
        LSE_temperature_partial = None
    O_out = torch.empty(total_q, head_q, dim, dtype=torch.bfloat16, device=dev)
    LSE_out = torch.empty(total_q, head_q, dtype=torch.float32, device=dev)
    if kernel_return_temperature_lse:
        LSE_temperature_out = torch.empty_like(LSE_out)
    else:
        LSE_temperature_out = None

    if schedule is not None:
        # A caller-supplied schedule brings its own split metadata.
        _validate_fwd_schedule(
            schedule,
            q=q,
            k2q_q_indices=k2q_q_indices,
            head_kv=head_kv,
        )
        k2q_qsplit_indices = schedule.qsplit_indices
        split_counts = schedule.split_counts
    else:
        # Scratch buffers handed to the launcher, which builds the schedule.
        k2q_qsplit_indices = torch.empty_like(k2q_q_indices)
        split_counts = torch.zeros(
            (total_q, head_kv),
            dtype=torch.int32,
            device=dev,
        )

    # The launcher returns the schedule it used; it is not needed afterwards.
    _call_sparse_forward_sm100_csr_varlen(
        q,
        k,
        v,
        k2q_row_ptr,
        k2q_q_indices,
        k2q_qsplit_indices,
        split_counts,
        cu_seqlens_q,
        cu_seqlens_k,
        page_table,
        seqused_k,
        O_partial,
        LSE_partial,
        LSE_temperature_partial,
        softmax_scale,
        lse_temperature_scale,
        kernel_return_temperature_lse,
        max_num_kv_blocks,
        blk_kv,
        head_kv,
        max_seqlen_q,
        usable_SM_count,
        causal=causal,
        schedule=schedule,
        qk_dtype=qk_dtype,
        pv_dtype=pv_dtype,
    )

    # Sparse Attention and Sparse Page Attention both use the varlen-Q
    # combine path; the kernel-written LSE_out is the final contract.
    combine(
        O_partial,
        LSE_partial,
        O_out,
        LSE_out,
        lse_temperature_partial=LSE_temperature_partial,
        lse_temperature_out=LSE_temperature_out,
        cu_seqlens=cu_seqlens_q,
        split_counts=split_counts,
        use_pdl=True,
    )
    if temperature_lse_fast_path:
        LSE_temperature_out = LSE_out

    if not return_softmax_lse:
        return O_out
    if return_temperature_lse:
        return O_out, LSE_out, LSE_temperature_out
    return O_out, LSE_out
1551
+
1552
+
1553
def _call_sparse_decode_forward_sm100_paged_fp8(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q2k_indices: Optional[torch.Tensor],
    page_table: torch.Tensor,
    seqused_k: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    schedule: DecodeAttentionSchedule,
    O_partial: Optional[torch.Tensor],
    LSE_partial: Optional[torch.Tensor],
    *,
    softmax_scale: float,
    seqlen_q: int,
    max_seqlen_k: int,
    blk_kv: int,
    causal: bool,
    return_lse: bool = True,
) -> None:
    """Validate inputs and launch the SM100 paged FP8 decode forward kernel.

    Only dense decode (``q2k_indices=None``) is supported here; sparse
    decode is meant to reuse the same schedule wrapper but needs a separate
    q2k gather path. ``max_seqlen_k`` is accepted but not used by this
    function. Results are written into ``out`` and ``lse``.

    Raises
    ------
    NotImplementedError
        If ``q2k_indices`` is given, or ``schedule.cta_tile_q`` is not 128.
    ValueError
        If the schedule is split-KV and a partial buffer is missing.
    TypeError
        If the schedule is split-KV and a partial buffer is not float32.
    """
    if q2k_indices is not None:
        raise NotImplementedError("SM100 paged fp8 sparse decode forward is not implemented yet")
    if schedule.cta_tile_q != 128:
        raise NotImplementedError(f"decode forward requires cta_tile_q=128, got {schedule.cta_tile_q}")
    if schedule.split_kv:
        if O_partial is None or LSE_partial is None:
            raise ValueError("split decode forward requires O_partial and LSE_partial")
        # Both split-KV workspaces must be float32.
        for label, buf in (("O_partial", O_partial), ("LSE_partial", LSE_partial)):
            if buf.dtype != torch.float32:
                raise TypeError(f"{label} must be torch.float32, got {buf.dtype}")

    from .src.sm100.fwd_decode import decode_forward_paged_fp8

    launch_options = dict(
        softmax_scale=float(softmax_scale),
        seqlen_q=int(seqlen_q),
        page_size=int(blk_kv),
        kv_chunk_size_pages=int(schedule.kv_chunk_size_pages),
        max_split_count=int(schedule.max_split_count),
        split_kv=bool(schedule.split_kv),
        causal=bool(causal),
        return_lse=bool(return_lse),
    )
    decode_forward_paged_fp8(
        q,
        k,
        v,
        page_table,
        seqused_k,
        out,
        lse,
        schedule.request_indices,
        schedule.qo_tile_indices,
        schedule.kv_tile_indices,
        schedule.block_valid_mask,
        schedule.split_counts,
        schedule.o_indptr,
        schedule.merge_indptr,
        O_partial,
        LSE_partial,
        **launch_options,
    )
1618
+
1619
+
1620
def _call_sparse_forward_sm100_csr_varlen(
    q,
    k,
    v,
    k2q_row_ptr,
    k2q_q_indices,
    k2q_qsplit_indices,
    split_counts,
    cu_seqlens_q,
    cu_seqlens_k,
    page_table,
    seqused_k,
    O_partial,
    LSE_partial,
    LSE_temperature_partial,
    softmax_scale,
    lse_temperature_scale,
    return_temperature_lse,
    max_num_kv_blocks,
    blk_kv,
    head_kv,
    max_seqlen_q,
    usable_SM_count=-1,
    *,
    causal=False,
    use_prepare_scheduler=True,
    schedule: Optional[SparseAttentionSchedule] = None,
    qk_dtype: torch.dtype,
    pv_dtype: torch.dtype,
):
    """Compile (or fetch from cache) and launch the SM100 sparse forward K1 kernel.

    The kernel works on CSR metadata and writes per-split results into
    ``O_partial``, ``LSE_partial`` and, when requested,
    ``LSE_temperature_partial``. If ``schedule`` is ``None`` one is built
    with ``prepare_sparse_fwd_schedule_and_split``. Compiled kernels are
    memoised in ``_compile_cache`` and looked up in / saved to the AOT
    cache under the same key. ``usable_SM_count`` is accepted but not used
    by this function.

    Returns
    -------
    The schedule that was used (the one passed in, or the freshly built one).

    Raises
    ------
    ValueError
        If ``return_temperature_lse`` disagrees with the presence of
        ``LSE_temperature_partial``, or ``lse_temperature_scale`` is not a
        finite positive number.
    RuntimeError
        If ``use_prepare_scheduler`` is false, or the schedule is disabled
        or empty.
    """
    head_dim = q.shape[-1]
    q_dtype = q.dtype
    qk_dtype = _normalize_forward_mma_dtype(qk_dtype, q.dtype, "qk_dtype")
    pv_dtype = _normalize_forward_mma_dtype(pv_dtype, v.dtype, "pv_dtype")
    partial_dtype = O_partial.dtype

    return_temperature_lse = bool(return_temperature_lse)
    has_temperature_buffer = LSE_temperature_partial is not None
    if return_temperature_lse != has_temperature_buffer:
        raise ValueError(
            "return_temperature_lse must match LSE_temperature_partial presence"
        )

    lse_temperature_scale = float(lse_temperature_scale)
    if not (math.isfinite(lse_temperature_scale) and lse_temperature_scale > 0.0):
        raise ValueError(
            f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}"
        )
    # The kernel receives the reciprocal of the temperature scale.
    lse_temperature_inv_scale = 1.0 / lse_temperature_scale

    n_block_size = int(blk_kv)
    qhead_per_kv = q.shape[1] // head_kv
    paged_kv = page_table is not None
    if not bool(use_prepare_scheduler):
        raise RuntimeError("sparse forward requires prepare scheduler")
    schedule_enabled = k2q_row_ptr.shape[1] > 1

    if paged_kv:
        page_size = int(k.shape[2])
        k_kernel, v_kernel = _prepare_paged_kv_for_tma(k, v, n_block_size)
    else:
        page_size = None
        k_kernel, v_kernel = k, v

    O_partial_flat = O_partial.reshape(-1, head_dim).contiguous()
    Q_flat = q.reshape(-1, head_dim).contiguous()
    # The gather4 descriptor is only created for these Q-heads-per-KV ratios.
    Q_gather4_desc = None
    if qhead_per_kv in (1, 2, 4):
        Q_gather4_desc = create_q_gather4_tma_desc(
            Q_flat,
            box_x=128 if q.dtype == torch.float8_e4m3fn else 64,
        )

    if schedule is None:
        schedule = prepare_sparse_fwd_schedule_and_split(
            k2q_row_ptr=k2q_row_ptr,
            k2q_q_indices=k2q_q_indices,
            k2q_qsplit_indices=k2q_qsplit_indices,
            split_counts=split_counts,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            total_q=int(q.shape[0]),
            max_seqlen_q=max_seqlen_q,
            topk=int(O_partial.shape[0]),
            head_kv=head_kv,
            qhead_per_kv=qhead_per_kv,
            blk_kv=n_block_size,
            device=q.device,
            enabled=schedule_enabled,
        )
    use_prepare_scheduler = schedule.enabled
    scheduler_metadata = schedule.scheduler_metadata
    work_count = schedule.work_count
    work_capacity = schedule.work_capacity
    schedule_usable = (
        use_prepare_scheduler
        and scheduler_metadata is not None
        and work_count is not None
        and work_capacity > 0
    )
    if not schedule_usable:
        raise RuntimeError("sparse forward requires a non-empty prepared schedule")

    # Element order is part of the cache identity (in-memory and AOT).
    key = (
        "sparse_forward_sm100_csr_varlen",
        head_dim,
        n_block_size,
        qhead_per_kv,
        q_dtype,
        k.dtype,
        v.dtype,
        qk_dtype,
        pv_dtype,
        partial_dtype,
        bool(causal),
        bool(paged_kv),
        bool(use_prepare_scheduler),
        page_size,
        bool(seqused_k is not None),
        bool(return_temperature_lse),
    )
    if key not in _compile_cache:
        from .src.common.aot_cache import try_load_aot, save_aot

        loaded = try_load_aot(key)
        if loaded is not None:
            _compile_cache[key] = loaded
        else:
            as_cute = to_cute_tensor_kvouter

            def optional_cute(tensor):
                # Optional kernel operands stay None instead of being wrapped.
                return None if tensor is None else as_cute(tensor)

            kernel = SparseAttentionForwardSm100(
                head_dim=head_dim,
                qheadperkv=qhead_per_kv,
                n_block_size=n_block_size,
                paged_kv=paged_kv,
                page_size=page_size,
                has_seqused_k=seqused_k is not None,
                causal=bool(causal),
                use_prepare_scheduler=use_prepare_scheduler,
                qk_dtype=_torch_dtype_to_cutlass_dtype(qk_dtype),
                pv_dtype=_torch_dtype_to_cutlass_dtype(pv_dtype),
            )
            _compile_cache[key] = cute.compile(
                kernel,
                as_cute(k_kernel),
                as_cute(v_kernel),
                as_cute(k2q_q_indices),
                as_cute(k2q_qsplit_indices),
                as_cute(k2q_row_ptr),
                optional_cute(scheduler_metadata),
                optional_cute(work_count),
                as_cute(O_partial_flat),
                as_cute(LSE_partial),
                optional_cute(LSE_temperature_partial),
                as_cute(Q_flat),
                optional_cute(Q_gather4_desc),
                optional_cute(page_table),
                optional_cute(seqused_k),
                as_cute(cu_seqlens_q),
                as_cute(cu_seqlens_k),
                Float32(softmax_scale),
                Float32(lse_temperature_inv_scale),
                Int32(max_num_kv_blocks),
                Int32(head_kv),
                Int32(max_seqlen_q),
                Int32(work_capacity),
                cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
                options="--enable-tvm-ffi",
            )
            save_aot(key, _compile_cache[key])

    compiled_kernel = _compile_cache[key]
    with torch.cuda.nvtx.range("Fwd_SparseAttn_Sm100_CsrVarlen"):
        # Positional order mirrors the operand order used at compile time.
        compiled_kernel(
            k_kernel,
            v_kernel,
            k2q_q_indices,
            k2q_qsplit_indices,
            k2q_row_ptr,
            scheduler_metadata,
            work_count,
            O_partial_flat,
            LSE_partial,
            LSE_temperature_partial,
            Q_flat,
            Q_gather4_desc,
            page_table,
            seqused_k,
            cu_seqlens_q,
            cu_seqlens_k,
            softmax_scale,
            lse_temperature_inv_scale,
            max_num_kv_blocks,
            head_kv,
            max_seqlen_q,
            work_capacity,
        )
    return schedule
1808
+
1809
+
1810
def _call_sparse_forward_sm100_csr_varlen_nvfp4_kv(
    q,
    k,
    v,
    k_scale_128x4,
    v_scale_128x4,
    k_global_scale,
    v_global_scale,
    k2q_row_ptr,
    k2q_q_indices,
    k2q_qsplit_indices,
    split_counts,
    cu_seqlens_q,
    cu_seqlens_k,
    page_table,
    seqused_k,
    O_partial,
    LSE_partial,
    LSE_temperature_partial,
    softmax_scale,
    lse_temperature_scale,
    return_temperature_lse,
    max_num_kv_blocks,
    blk_kv,
    head_kv,
    max_seqlen_q,
    *,
    causal=False,
    use_prepare_scheduler=True,
    schedule: Optional[SparseAttentionSchedule] = None,
):
    """Compile (or fetch from cache) and launch the SM100 sparse forward K1 kernel with NVFP4 K/V.

    Counterpart of the CSR varlen launcher for NVFP4 K/V, which come with
    block scales (``k_scale_128x4`` / ``v_scale_128x4``) and an optional K
    global scale. ``v_global_scale`` is not handed to this kernel: it is
    linear in the final output and is applied once in the K2 combine step.
    The ``MINIMAX_KVFP4_FP8_PAIR_DEQUANT`` environment variable (any value
    other than ``"0"`` enables it; default on) selects the kernel's
    ``fp8_pair_dequant`` option. For paged KV, ``_prepare_paged_kv_for_tma``
    is called but its return value is discarded and the original ``k`` /
    ``v`` go to the kernel (presumably a layout check only — TODO confirm).

    Returns
    -------
    The schedule that was used (the one passed in, or the freshly built one).

    Raises
    ------
    ValueError
        If ``return_temperature_lse`` disagrees with the presence of
        ``LSE_temperature_partial``, or ``lse_temperature_scale`` is not a
        finite positive number.
    RuntimeError
        If ``use_prepare_scheduler`` is false, or the schedule is disabled
        or empty.
    """
    head_dim = q.shape[-1]
    q_dtype = q.dtype
    partial_dtype = O_partial.dtype

    return_temperature_lse = bool(return_temperature_lse)
    has_temperature_buffer = LSE_temperature_partial is not None
    if return_temperature_lse != has_temperature_buffer:
        raise ValueError(
            "return_temperature_lse must match LSE_temperature_partial presence"
        )

    lse_temperature_scale = float(lse_temperature_scale)
    if not (math.isfinite(lse_temperature_scale) and lse_temperature_scale > 0.0):
        raise ValueError(
            f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}"
        )
    # The kernel receives the reciprocal of the temperature scale.
    lse_temperature_inv_scale = 1.0 / lse_temperature_scale

    n_block_size = int(blk_kv)
    qhead_per_kv = q.shape[1] // head_kv
    fp8_pair_dequant = os.environ.get("MINIMAX_KVFP4_FP8_PAIR_DEQUANT", "1") != "0"

    k_global_scale_kernel = k_global_scale
    # V global scale is linear in the final output. Keep K1 on block-scale-only V
    # and apply the tensor scale once in K2 combine.
    v_global_scale_kernel = None
    has_k_global_scale = k_global_scale_kernel is not None
    has_v_global_scale = v_global_scale_kernel is not None

    paged_kv = page_table is not None
    if not bool(use_prepare_scheduler):
        raise RuntimeError("KVFP4 sparse forward requires prepare scheduler")
    schedule_enabled = k2q_row_ptr.shape[1] > 1
    if paged_kv:
        page_size = int(k.shape[2])
        # Return value intentionally unused; k/v are passed through as-is.
        _prepare_paged_kv_for_tma(k, v, n_block_size)
    else:
        page_size = None
    k_kernel, v_kernel = k, v

    O_partial_flat = O_partial.reshape(-1, head_dim).contiguous()
    Q_flat = q.reshape(-1, head_dim).contiguous()
    # The gather4 descriptor is only created for these Q-heads-per-KV ratios.
    Q_gather4_desc = None
    if qhead_per_kv in (1, 2, 4):
        Q_gather4_desc = create_q_gather4_tma_desc(
            Q_flat,
            box_x=128 if q.dtype == torch.float8_e4m3fn else 64,
        )

    if schedule is None:
        schedule = prepare_sparse_fwd_schedule_and_split(
            k2q_row_ptr=k2q_row_ptr,
            k2q_q_indices=k2q_q_indices,
            k2q_qsplit_indices=k2q_qsplit_indices,
            split_counts=split_counts,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            total_q=int(q.shape[0]),
            max_seqlen_q=max_seqlen_q,
            topk=int(O_partial.shape[0]),
            head_kv=head_kv,
            qhead_per_kv=qhead_per_kv,
            blk_kv=n_block_size,
            device=q.device,
            enabled=schedule_enabled,
        )
    use_prepare_scheduler = schedule.enabled
    scheduler_metadata = schedule.scheduler_metadata
    work_count = schedule.work_count
    work_capacity = schedule.work_capacity
    schedule_usable = (
        use_prepare_scheduler
        and scheduler_metadata is not None
        and work_count is not None
        and work_capacity > 0
    )
    if not schedule_usable:
        raise RuntimeError("KVFP4 sparse forward requires a non-empty prepared schedule")

    # Element order is part of the cache identity (in-memory and AOT).
    key = (
        "sparse_forward_sm100_csr_varlen_nvfp4_kv",
        head_dim,
        n_block_size,
        qhead_per_kv,
        q_dtype,
        partial_dtype,
        bool(causal),
        bool(paged_kv),
        bool(use_prepare_scheduler),
        page_size,
        bool(seqused_k is not None),
        bool(return_temperature_lse),
        bool(fp8_pair_dequant),
        bool(has_k_global_scale),
        bool(has_v_global_scale),
    )
    if key not in _compile_cache:
        from .src.common.aot_cache import try_load_aot, save_aot

        loaded = try_load_aot(key)
        if loaded is not None:
            _compile_cache[key] = loaded
        else:
            as_cute = to_cute_tensor_kvouter

            def optional_cute(tensor):
                # Optional kernel operands stay None instead of being wrapped.
                return None if tensor is None else as_cute(tensor)

            kernel = SparseAttentionForwardNvfp4KvSm100(
                head_dim=head_dim,
                qheadperkv=qhead_per_kv,
                n_block_size=n_block_size,
                paged_kv=paged_kv,
                page_size=page_size,
                has_seqused_k=seqused_k is not None,
                causal=bool(causal),
                use_prepare_scheduler=use_prepare_scheduler,
                fp8_pair_dequant=bool(fp8_pair_dequant),
                has_k_global_scale=bool(has_k_global_scale),
                has_v_global_scale=bool(has_v_global_scale),
            )
            _compile_cache[key] = cute.compile(
                kernel,
                as_cute(k_kernel),
                as_cute(v_kernel),
                as_cute(k_scale_128x4),
                as_cute(v_scale_128x4),
                optional_cute(k_global_scale_kernel),
                optional_cute(v_global_scale_kernel),
                as_cute(k2q_q_indices),
                as_cute(k2q_qsplit_indices),
                as_cute(k2q_row_ptr),
                optional_cute(scheduler_metadata),
                optional_cute(work_count),
                as_cute(O_partial_flat),
                as_cute(LSE_partial),
                optional_cute(LSE_temperature_partial),
                as_cute(Q_flat),
                optional_cute(Q_gather4_desc),
                optional_cute(page_table),
                optional_cute(seqused_k),
                as_cute(cu_seqlens_q),
                as_cute(cu_seqlens_k),
                Float32(softmax_scale),
                Float32(lse_temperature_inv_scale),
                Int32(max_num_kv_blocks),
                Int32(head_kv),
                Int32(max_seqlen_q),
                Int32(work_capacity),
                cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
                options="--enable-tvm-ffi",
            )
            save_aot(key, _compile_cache[key])

    compiled_kernel = _compile_cache[key]
    with torch.cuda.nvtx.range("Fwd_SparseAttn_Sm100_CsrVarlen_KVFP4"):
        # Positional order mirrors the operand order used at compile time.
        compiled_kernel(
            k_kernel,
            v_kernel,
            k_scale_128x4,
            v_scale_128x4,
            k_global_scale_kernel,
            v_global_scale_kernel,
            k2q_q_indices,
            k2q_qsplit_indices,
            k2q_row_ptr,
            scheduler_metadata,
            work_count,
            O_partial_flat,
            LSE_partial,
            LSE_temperature_partial,
            Q_flat,
            Q_gather4_desc,
            page_table,
            seqused_k,
            cu_seqlens_q,
            cu_seqlens_k,
            softmax_scale,
            lse_temperature_inv_scale,
            max_num_kv_blocks,
            head_kv,
            max_seqlen_q,
            work_capacity,
        )
    return schedule
build/torch211-cxx11-cu128-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "msa",
3
+ "id": "_msa_cuda_09d7851",
4
+ "version": 0,
5
+ "license": "other",
6
+ "upstream": "https://github.com/MiniMax-AI/MSA",
7
+ "python-depends": [
8
+ "tvm-ffi",
9
+ "nvidia-cutlass-dsl"
10
+ ],
11
+ "backend": {
12
+ "type": "cuda",
13
+ "archs": [
14
+ "10.0"
15
+ ]
16
+ },
17
+ "digest": {
18
+ "algorithm": "sha256",
19
+ "files": {
20
+ "__init__.py": "+W+3U1Z5ZKc/dTA+JUG+6dMjfe9H/d9J+8fN+936wbI=",
21
+ "_msa_cuda_09d7851.abi3.so": "jc2MhuUS893VrLlfb9ytPPqhV5u2+HSnFPugZuaHcWE=",
22
+ "_ops.py": "o9RBC1FB95LP9Sp+GkBILumbSek9oEtxb8F7XXO0F0g=",
23
+ "fp4_indexer_interface.py": "M+0e93gWG8CGOrhY5bm1hEQJU+TT5PrCmwJzTofaDAg=",
24
+ "interface.py": "B4AHQfNyO+vl6MdyMAHW0GhArl7HGufAEa0ATxsWorY=",
25
+ "msa/__init__.py": "DFYPlrhXwYjEqCl/8n0SmWGZV8NFml5DPhMjKfv98GY=",
26
+ "quack/__init__.py": "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=",
27
+ "quack/activation.py": "T/ypcXoz6a4wPPNZW2gKZuEj8JeucaKtKxQiQl5XrXc=",
28
+ "quack/compile_utils.py": "qJ3oTsDlbAiddrJHtEO7LPYVqn/s+neNfiw+/KvfXZU=",
29
+ "quack/copy_utils.py": "rdohXm9bKXqDHkMHf8lWQJQnCb0hMLvhzIudkj0Bxeg=",
30
+ "quack/cute_dsl_utils.py": "4uQx5aYDG9UvVzbWwJTjjJLrnoympz70/CD8b37FQWo=",
31
+ "quack/layout_utils.py": "69N1aTy+840X3seMuLfLxiV3BW8SaVsM3Tf0Vf4NCSI=",
32
+ "quantize.py": "1jePLbJngji8ANfnDK6PCG829AMSd+XOMqYVuJ5pXyY=",
33
+ "sparse_index_utils.py": "kzYMdtFPRBfaL6Vfw9xLLre7ph8svtEQrB/txC+52Fc=",
34
+ "src/__init__.py": "ZdCpznblq9UgGSNgI0hJoDXpk+evcvGjv3GGthxD/nM=",
35
+ "src/common/__init__.py": "ZdCpznblq9UgGSNgI0hJoDXpk+evcvGjv3GGthxD/nM=",
36
+ "src/common/aot_cache.py": "ya1OHE6Lqx/pb9UhH++Bu8a98Huhmdl084C6cgWdH1s=",
37
+ "src/common/barrier.py": "Godvhwwaf9iyDA/A78VoQMMRRn6ZSnq2YPosr7K2SVE=",
38
+ "src/common/blackwell_helpers.py": "BYJYCeNQ9cYVhWZlfjv0IgNaNqlnoD21nX3gAA5pRB4=",
39
+ "src/common/block_info.py": "U7qL3AZ5ROkNZdL6RTPlLlnLJ6tZ4b2VFVufZLyuuq8=",
40
+ "src/common/copy_utils.py": "bEtyb8O7Z7jIKNjN5ESlnh4WVvdf8vr5ZfQxA6vS6zA=",
41
+ "src/common/cute_dsl_utils.py": "nd8vII+r49Kk185ja3+VM6dwJlvMqCkjMBRh0WEHakw=",
42
+ "src/common/fast_math.py": "nqt6MtzAt7uplC4+kczgBfin4gHNs+QSoufR1TuMZ88=",
43
+ "src/common/mask.py": "l9v4End+9k3ZHRO6DCnuOD9K9iOCiN81osRATKvK41k=",
44
+ "src/common/mma_sm100_desc.py": "C1PqBdp6CNPA9xadQ2xBnf4wvQlE93SS/7CU+LZBQkA=",
45
+ "src/common/named_barrier.py": "5ktJiO+hP80fjTR797CslUGfm2PyhpcW6WJZrNyI5bQ=",
46
+ "src/common/pack_gqa.py": "UrAAIge5XLmilqXWGtCZJobgpuA6B0N1Vw3tDhyUi7s=",
47
+ "src/common/paged_kv.py": "j0/6stT1A5uEVALEX/GaQhYWIie+6LpGseAW8aQiHbk=",
48
+ "src/common/pipeline.py": "MIFfoDDD8Fs//SQSR+JzI/0MJ1qPGml297RtbC2qPRU=",
49
+ "src/common/seqlen_info.py": "EX2W8MTGcnAZ+J60tGG9D7IzvdLeIVQshztntGDkPMQ=",
50
+ "src/common/softmax.py": "ePjb2TUcr4fHLmw0zx9Lt+vvR6hSm2mQwiENf2J/AoQ=",
51
+ "src/common/tile_scheduler.py": "f8UknoE0j9BfPomRI/QCsDJoRk+1IpJrLfBOAh2mlls=",
52
+ "src/common/tma_utils.py": "gpAmBh58VOfHRghZTCbQ5SQpbAYy0lFnpvIcFSLBNb8=",
53
+ "src/common/utils.py": "eGGo5Ul+0XpKtiw6JLofVdFDj6s2xe4LWqDmlqp9AKk=",
54
+ "src/sm100/__init__.py": "JQpQtL58fso8B2Xwvn0XVevVqIjnk15wVQE0UUGGLCs=",
55
+ "src/sm100/build_k2q_csr/__init__.py": "75ICu6BIZir0OeyEgZ1TEYNY7pn+lA4P6McCSSC20rI=",
56
+ "src/sm100/decode_schedule.py": "/VRAmvrMX+oYLzWK1sqve86tprXkqX0/f4o5WMVeU4I=",
57
+ "src/sm100/fp4_indexer.py": "1lc9/rgU09wwF08WBRaXIE0CE2b19pBRwXekDduFs0o=",
58
+ "src/sm100/fwd/__init__.py": "A0uq2t4n5Y34mEgxb9Nzxk9sKsYr2FZ4sF+RoEilOmo=",
59
+ "src/sm100/fwd/atten_fwd.py": "4LJaUh2pn3QiwcMr+8QOVUJjNIAQqYal1xFJ/1takQY=",
60
+ "src/sm100/fwd/atten_fwd_nvfp4_kv.py": "EqU+ehJasAa9NvpDWipMPxaptOw+vcojprVas+b+x18=",
61
+ "src/sm100/fwd/combine.py": "7rQW4rUpzy0M19u+/iLfHHGMbAIQhi4HEnYeLu/qmi4=",
62
+ "src/sm100/fwd_decode/__init__.py": "XQJdwvLQm29RwVqVZvCstEnTx+dhUrwmH6RcW675pR8=",
63
+ "src/sm100/fwd_decode/atten_fwd.py": "3S4iE9h6fXUBjas51fRbakqnOzN79f0QUJ/EBRm+Ckg=",
64
+ "src/sm100/fwd_decode/build_decode_schedule/__init__.py": "qUElKK/HC03N9ntOA0sc8LB08jF5MWd7wq3MUnu4wgM=",
65
+ "src/sm100/fwd_decode/combine.py": "wIvKZzHissMLe83PUbybUoM39HTMIAexHw5I1yfJH94=",
66
+ "src/sm100/fwd_decode/tile_scheduler.py": "OWdID5fCFmwXqz6RtseFphfJtezOOQ091K+bJFcD6bc=",
67
+ "src/sm100/prepare_k2q_csr.py": "nCeG6m24dLNwJeQDFppjqR3wVCDxMY0we+20zEEeMy8=",
68
+ "src/sm100/prepare_scheduler.py": "CQuJI6Fn0uR0oMcfzmlIH+bjg+2uKTzqCXbw5H0YgSw="
69
+ }
70
+ }
71
+ }
build/torch211-cxx11-cu128-x86_64-linux/metadata.json.sigstore ADDED
@@ -0,0 +1 @@
 
 
1
+ {"mediaType":"application/vnd.dev.sigstore.bundle.v0.3+json","verificationMaterial":{"certificate":{"rawBytes":"MIIHTDCCBtGgAwIBAgIUXQHYSDFOSO1tjFUUICxJvOGeZcMwCgYIKoZIzj0EAwMwNzEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MR4wHAYDVQQDExVzaWdzdG9yZS1pbnRlcm1lZGlhdGUwHhcNMjYwNjMwMTc0NDA4WhcNMjYwNjMwMTc1NDA4WjAAMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEPXM0K6Fgcg5CUSklxxl2csu3F3KVSv8zPaW2wSeCwTB487WjsTVM+EqcLz/LSKUD5XL4tCAc1+gFBa30H4iDgKOCBfAwggXsMA4GA1UdDwEB/wQEAwIHgDATBgNVHSUEDDAKBggrBgEFBQcDAzAdBgNVHQ4EFgQUfsSvN2oaJ+OmV0cSOHDNe9Nc/qUwHwYDVR0jBBgwFoAU39Ppz1YkEZb5qNjpKFWixi4YZD8wawYDVR0RAQH/BGEwX4ZdaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5Ly5naXRodWIvd29ya2Zsb3dzL2J1aWxkLnlhbWxAcmVmcy9oZWFkcy9tYWluMDkGCisGAQQBg78wAQEEK2h0dHBzOi8vdG9rZW4uYWN0aW9ucy5naXRodWJ1c2VyY29udGVudC5jb20wHwYKKwYBBAGDvzABAgQRd29ya2Zsb3dfZGlzcGF0Y2gwNgYKKwYBBAGDvzABAwQoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTATBgorBgEEAYO/MAEEBAVCdWlsZDArBgorBgEEAYO/MAEFBB1odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eTAdBgorBgEEAYO/MAEGBA9yZWZzL2hlYWRzL21haW4wOwYKKwYBBAGDvzABCAQtDCtodHRwczovL3Rva2VuLmFjdGlvbnMuZ2l0aHVidXNlcmNvbnRlbnQuY29tMG0GCisGAQQBg78wAQkEXwxdaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5Ly5naXRodWIvd29ya2Zsb3dzL2J1aWxkLnlhbWxAcmVmcy9oZWFkcy9tYWluMDgGCisGAQQBg78wAQoEKgwoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTAbBgorBgEEAYO/MAELBA0MC3NlbGYtaG9zdGVkMEAGCisGAQQBg78wAQwEMgwwaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5MDgGCisGAQQBg78wAQ0EKgwoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTAfBgorBgEEAYO/MAEOBBEMD3JlZnMvaGVhZHMvbWFpbjAaBgorBgEEAYO/MAEPBAwMCjEwNzE0NzU1MjkwLgYKKwYBBAGDvzABEAQgDB5odHRwczovL2dpdGh1Yi5jb20vaHVnZ2luZ2ZhY2UwGAYKKwYBBAGDvzABEQQKDAgyNTcyMDc0MzBtBgorBgEEAYO/MAESBF8MXWh0dHBzOi8vZ2l0aHViLmNvbS9odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eS8uZ2l0aHViL3dvcmtmbG93cy9idWlsZC55YW1sQHJlZnMvaGVhZHMvbWFpbjA4BgorBgEEAYO/MAETBCoMKDA5ZDc4NTE1YzU1MzJlNzAwMjcwZTllMTM1NTZhMmFkMDJlOWY1ZjkwIQYKKwYBBAGDvzABFAQTDBF3b3JrZmxvd19kaXNwYXRjaD
BkBgorBgEEAYO/MAEVBFYMVGh0dHBzOi8vZ2l0aHViLmNvbS9odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eS9hY3Rpb25zL3J1bnMvMjg0NjM5NjE5NTUvYXR0ZW1wdHMvMTAWBgorBgEEAYO/MAEWBAgMBnB1YmxpYzBGBgorBgEEAYO/MAEYBDgMNnJlcG86aHVnZ2luZ2ZhY2Uva2VybmVscy1jb21tdW5pdHk6cmVmOnJlZnMvaGVhZHMvbWFpbjCBigYKKwYBBAHWeQIEAgR8BHoAeAB2AN09MGrGxxEyYxkeHJlnNwKiSl643jyt/4eKcoAvKe6OAAABnxmhlrEAAAQDAEcwRQIhAN6iYC5242Rjj5dTsIgyISVMIPYWL2i81TwWknEvZur+AiAt30f5Wif9ZHR/wsWh+ve5O9GtVpL2jPTURJTl0u2xMjAKBggqhkjOPQQDAwNpADBmAjEA4i2QuFAcvw5KQAQADHbn8kVwmCTVfjK5xdQ1bJEu5eVu4PY4Br1zC9GVk7p6opFmAjEAm7jnPQ2jC5BL90FIlwMdeEVPgNmR7svFEElrkQme43Rqt6pvdGksMAzAqaWXQFqT"},"tlogEntries":[{"logIndex":"2024793345","logId":{"keyId":"wNI9atQGlz+VWfO6LRygH4QUfY/8W4RFwiT5i5WRgB0="},"kindVersion":{"kind":"hashedrekord","version":"0.0.1"},"integratedTime":"1782841448","inclusionPromise":{"signedEntryTimestamp":"MEUCIQDoWovnRcuj8EsCnxn/h18ObLX1W2EowGsjOnjj31tjKgIgE1bqiVYG2avTTL3CutjFGVSxSQtlXFYWVfl+DRCyVUk="},"inclusionProof":{"logIndex":"1902889083","rootHash":"rTzAPs80Dh6PVJ0tfFBFa06/Bp0jBkLOYrqCKGcj2Jw=","treeSize":"1902889093","hashes":["o6DK+OhTtiUAKd3yIcR79MoEH+e/lGDEz7/klBOgQgQ=","QFE69AbxzyZT6lYixktLCZ3SnTobLI2F6l/FFy7U7bE=","euXxtVgM7AeowPy83tQZihH1C4RDec9dw20k4Rjy7X8=","mCF45aBQkD6Ga0kRgUZm/6GIWnlvuDEwC1rsiDj7r9A=","wCaOWjILsSS/Bc8GMCLLwZ/lR4z6kHhhDwjBR489Drg=","oREPAC441YAiXLkRB+S3slZaG/rywypoRAOWh9Onh28=","tdRUnZp2XzgIgMBhnUUzZKRYmgMR9VRE4EFRMnBcvN4=","SRE7OpzsmEEBrnt2NvwSO2YvAQJHxIzVKMjw7ssvt3A=","5DB/VRMbICRg24kfvBoq+aFOMwCKvhr1zQj5SpDh5Ck=","NRxwUF55kxkZUtVui8nzfzj4LLT960XpxpXnY6C7pqs=","KTak07KIu/wsxelNu7DaqjZg2G0WnevWjQkjflcCfjI=","o03232Stm2HWKs2uG6lq2NP4O1Zym1pjI+LbQCbPISY=","nGtXNKgDUZj+ZjPgQKuKFp9orlBq81iSk8yjysQUTIU=","+/rlNRIrSvbSLthLGxHY8saYzo8HTl12uoWcFuXbbE0=","tC4XX6tUr8g/3yF+0T8f2DfrTWQmbDBfMxTOmNuWyzI=","E8u2TYaBleTNUd9vupjpxhOMu+bExC1kpTjfOk2GAUA=","cJbCQtmuzzN6T9df9SuhiY4cyCN7ezf1n+yFrgRkcgE=","+/VZ56MsIPxMiyLAodzKXo5TEWdQp36z89qLhpzloAo=","daxmZaajRpZV+JxHiOYZhJBiSKN5ucqjh2WnGbHhirw=","DOCeoSMovIvLExkhIvisow9Au
NXgeWs4ECkyR6EcqYU="],"checkpoint":{"envelope":"rekor.sigstore.dev - 1193050959916656506\n1902889093\nrTzAPs80Dh6PVJ0tfFBFa06/Bp0jBkLOYrqCKGcj2Jw=\n\n— rekor.sigstore.dev wNI9ajBFAiBuldB8XClfqbEMlZnWsMAPF1CWf+PfKW6kiBU0RaE3YwIhAKQGXPHErozLpsxzvdgVeeJVRUx9RGAtRP5qoXqfKhJm\n"}},"canonicalizedBody":"eyJhcGlWZXJzaW9uIjoiMC4wLjEiLCJraW5kIjoiaGFzaGVkcmVrb3JkIiwic3BlYyI6eyJkYXRhIjp7Imhhc2giOnsiYWxnb3JpdGhtIjoic2hhMjU2IiwidmFsdWUiOiIyMzVlYjhiNGYxZmIyOWIzZWU4OTNlNzI4ODU1NDc3N2E3YzE3ZTVhNzNkNDM3YTc0M2JlNzAxOGYyOWQ5OGI4In19LCJzaWduYXR1cmUiOnsiY29udGVudCI6Ik1FVUNJQ1dkOUxlZ3ZSb0oxWDZIQUwway9SV1BvTG1sbS9YU3c3VXhOWmNpSFMwc0FpRUE3U1phSlJXVGlHdlJIWWh2d0pLS0RwRDVnRUNZT25GMGMzRURMT0VTOWNNPSIsInB1YmxpY0tleSI6eyJjb250ZW50IjoiTFMwdExTMUNSVWRKVGlCRFJWSlVTVVpKUTBGVVJTMHRMUzB0Q2sxSlNVaFVSRU5EUW5SSFowRjNTVUpCWjBsVldGRklXVk5FUms5VFR6RjBha1pWVlVsRGVFcDJUMGRsV21OTmQwTm5XVWxMYjFwSmVtb3dSVUYzVFhjS1RucEZWazFDVFVkQk1WVkZRMmhOVFdNeWJHNWpNMUoyWTIxVmRWcEhWakpOVWpSM1NFRlpSRlpSVVVSRmVGWjZZVmRrZW1SSE9YbGFVekZ3WW01U2JBcGpiVEZzV2tkc2FHUkhWWGRJYUdOT1RXcFpkMDVxVFhkTlZHTXdUa1JCTkZkb1kwNU5hbGwzVG1wTmQwMVVZekZPUkVFMFYycEJRVTFHYTNkRmQxbElDa3R2V2tsNmFqQkRRVkZaU1V0dldrbDZhakJFUVZGalJGRm5RVVZRV0Uwd1N6WkdaMk5uTlVOVlUydHNlSGhzTW1OemRUTkdNMHRXVTNZNGVsQmhWeklLZDFObFEzZFVRalE0TjFkcWMxUldUU3RGY1dOTWVpOU1VMHRWUkRWWVREUjBRMEZqTVN0blJrSmhNekJJTkdsRVowdFBRMEptUVhkbloxaHpUVUUwUndwQk1WVmtSSGRGUWk5M1VVVkJkMGxJWjBSQlZFSm5UbFpJVTFWRlJFUkJTMEpuWjNKQ1owVkdRbEZqUkVGNlFXUkNaMDVXU0ZFMFJVWm5VVlZtYzFOMkNrNHliMkZLSzA5dFZqQmpVMDlJUkU1bE9VNWpMM0ZWZDBoM1dVUldVakJxUWtKbmQwWnZRVlV6T1ZCd2VqRlphMFZhWWpWeFRtcHdTMFpYYVhocE5Ga0tXa1E0ZDJGM1dVUldVakJTUVZGSUwwSkhSWGRZTkZwa1lVaFNNR05JVFRaTWVUbHVZVmhTYjJSWFNYVlpNamwwVERKb01Wb3laSEJpYldSdFdWZE9iQXBNTW5Sc1kyMDFiR0pJVFhSWk1qbDBZbGhXZFdGWVVqVk1lVFZ1WVZoU2IyUlhTWFprTWpsNVlUSmFjMkl6WkhwTU1rb3hZVmQ0YTB4dWJHaGlWM2hCQ21OdFZtMWplVGx2V2xkR2EyTjVPWFJaVjJ4MVRVUnJSME5wYzBkQlVWRkNaemM0ZDBGUlJVVkxNbWd3WkVoQ2VrOXBPSFprUnpseVdsYzBkVmxYVGpBS1lWYzVkV041Tlc1aFdGSnZaRmRLTVdNeVZubFpNamwxWkVkV2RXUkROV3BpTWpCM1NIZFpTMHQzV1VKQ1
FVZEVkbnBCUWtGblVWSmtNamw1WVRKYWN3cGlNMlJtV2tkc2VtTkhSakJaTW1kM1RtZFpTMHQzV1VKQ1FVZEVkbnBCUWtGM1VXOU5SR3hyVG5wbk1VMVVWbXBPVkZWNlRXMVZNMDFFUVhsT2VrSnNDazlYVlhoTmVsVXhUbTFGZVZsWFVYZE5iVlUxV21wV2JVOVVRVlJDWjI5eVFtZEZSVUZaVHk5TlFVVkZRa0ZXUTJSWGJITmFSRUZ5UW1kdmNrSm5SVVVLUVZsUEwwMUJSVVpDUWpGdlpGZGtibUZYTlc1YWJVWnFXbE01Y2xwWVNuVmFWM2g2VEZkT2RtSlhNVEZpYld3d1pWUkJaRUpuYjNKQ1owVkZRVmxQTHdwTlFVVkhRa0U1ZVZwWFducE1NbWhzV1ZkU2Vrd3lNV2hoVnpSM1QzZFpTMHQzV1VKQ1FVZEVkbnBCUWtOQlVYUkVRM1J2WkVoU2QyTjZiM1pNTTFKMkNtRXlWblZNYlVacVpFZHNkbUp1VFhWYU1td3dZVWhXYVdSWVRteGpiVTUyWW01U2JHSnVVWFZaTWpsMFRVY3dSME5wYzBkQlVWRkNaemM0ZDBGUmEwVUtXSGQ0WkdGSVVqQmpTRTAyVEhrNWJtRllVbTlrVjBsMVdUSTVkRXd5YURGYU1tUndZbTFrYlZsWFRteE1NblJzWTIwMWJHSklUWFJaTWpsMFlsaFdkUXBoV0ZJMVRIazFibUZZVW05a1YwbDJaREk1ZVdFeVduTmlNMlI2VERKS01XRlhlR3RNYm14b1lsZDRRV050Vm0xamVUbHZXbGRHYTJONU9YUlpWMngxQ2sxRVowZERhWE5IUVZGUlFtYzNPSGRCVVc5RlMyZDNiMDFFYkd0T2VtY3hUVlJXYWs1VVZYcE5iVlV6VFVSQmVVNTZRbXhQVjFWNFRYcFZNVTV0UlhrS1dWZFJkMDF0VlRWYWFsWnRUMVJCWWtKbmIzSkNaMFZGUVZsUEwwMUJSVXhDUVRCTlF6Tk9iR0pIV1hSaFJ6bDZaRWRXYTAxRlFVZERhWE5IUVZGUlFncG5OemgzUVZGM1JVMW5kM2RoU0ZJd1kwaE5Oa3g1T1c1aFdGSnZaRmRKZFZreU9YUk1NbWd4V2pKa2NHSnRaRzFaVjA1c1RESjBiR050Tld4aVNFMTBDbGt5T1hSaVdGWjFZVmhTTlUxRVowZERhWE5IUVZGUlFtYzNPSGRCVVRCRlMyZDNiMDFFYkd0T2VtY3hUVlJXYWs1VVZYcE5iVlV6VFVSQmVVNTZRbXdLVDFkVmVFMTZWVEZPYlVWNVdWZFJkMDF0VlRWYWFsWnRUMVJCWmtKbmIzSkNaMFZGUVZsUEwwMUJSVTlDUWtWTlJETktiRnB1VFhaaFIxWm9Xa2hOZGdwaVYwWndZbXBCWVVKbmIzSkNaMFZGUVZsUEwwMUJSVkJDUVhkTlEycEZkMDU2UlRCT2VsVXhUV3ByZDB4bldVdExkMWxDUWtGSFJIWjZRVUpGUVZGbkNrUkNOVzlrU0ZKM1kzcHZka3d5WkhCa1IyZ3hXV2sxYW1JeU1IWmhTRlp1V2pKc2RWb3lXbWhaTWxWM1IwRlpTMHQzV1VKQ1FVZEVkbnBCUWtWUlVVc0tSRUZuZVU1VVkzbE5SR013VFhwQ2RFSm5iM0pDWjBWRlFWbFBMMDFCUlZOQ1JqaE5XRmRvTUdSSVFucFBhVGgyV2pKc01HRklWbWxNYlU1MllsTTVid3BrVjJSdVlWYzFibHB0Um1wYVV6bHlXbGhLZFZwWGVIcE1WMDUyWWxjeE1XSnRiREJsVXpoMVdqSnNNR0ZJVm1sTU0yUjJZMjEwYldKSE9UTmplVGxwQ21SWGJITmFRelUxV1ZjeGMxRklTbXhhYmsxMllVZFdhRnBJVFhaaVYwWndZbXBCTkVKbmIzSkNaMFZGUVZsUEwwMUJSVlJDUTI5TlMwUkJOVnBFWXpRS1
RsUkZNVmw2VlRGTmVrcHNUbnBCZDAxcVkzZGFWR3hzVFZSTk1VNVVXbWhOYlVaclRVUktiRTlYV1RGYWFtdDNTVkZaUzB0M1dVSkNRVWRFZG5wQlFncEdRVkZVUkVKR00ySXpTbkphYlhoMlpERTVhMkZZVG5kWldGSnFZVVJDYTBKbmIzSkNaMFZGUVZsUEwwMUJSVlpDUmxsTlZrZG9NR1JJUW5wUGFUaDJDbG95YkRCaFNGWnBURzFPZG1KVE9XOWtWMlJ1WVZjMWJscHRSbXBhVXpseVdsaEtkVnBYZUhwTVYwNTJZbGN4TVdKdGJEQmxVemxvV1ROU2NHSXlOWG9LVEROS01XSnVUWFpOYW1jd1RtcE5OVTVxUlRWT1ZGVjJXVmhTTUZwWE1YZGtTRTEyVFZSQlYwSm5iM0pDWjBWRlFWbFBMMDFCUlZkQ1FXZE5RbTVDTVFwWmJYaHdXWHBDUjBKbmIzSkNaMFZGUVZsUEwwMUJSVmxDUkdkTlRtNUtiR05IT0RaaFNGWnVXakpzZFZveVdtaFpNbFYyWVRKV2VXSnRWbk5qZVRGcUNtSXlNWFJrVnpWd1pFaHJObU50Vm0xUGJrcHNXbTVOZG1GSFZtaGFTRTEyWWxkR2NHSnFRMEpwWjFsTFMzZFpRa0pCU0ZkbFVVbEZRV2RTT0VKSWIwRUtaVUZDTWtGT01EbE5SM0pIZUhoRmVWbDRhMlZJU214dVRuZExhVk5zTmpRemFubDBMelJsUzJOdlFYWkxaVFpQUVVGQlFtNTRiV2hzY2tWQlFVRlJSQXBCUldOM1VsRkphRUZPTm1sWlF6VXlOREpTYW1vMVpGUnpTV2Q1U1ZOV1RVbFFXVmRNTW1rNE1WUjNWMnR1UlhaYWRYSXJRV2xCZERNd1pqVlhhV1k1Q2xwSVVpOTNjMWRvSzNabE5VODVSM1JXY0V3eWFsQlVWVkpLVkd3d2RUSjRUV3BCUzBKblozRm9hMnBQVUZGUlJFRjNUbkJCUkVKdFFXcEZRVFJwTWxFS2RVWkJZM1ozTlV0UlFWRkJSRWhpYmpoclZuZHRRMVJXWm1wTE5YaGtVVEZpU2tWMU5XVldkVFJRV1RSQ2NqRjZRemxIVm1zM2NEWnZjRVp0UVdwRlFRcHROMnB1VUZFeWFrTTFRa3c1TUVaSmJIZE5aR1ZGVmxCblRtMVNOM04yUmtWRmJISnJVVzFsTkROU2NYUTJjSFprUjJ0elRVRjZRWEZoVjFoUlJuRlVDaTB0TFMwdFJVNUVJRU5GVWxSSlJrbERRVlJGTFMwdExTMEsifX19fQ=="}],"timestampVerificationData":{"rfc3161Timestamps":[{"signedTimestamp":"MIICyDADAgEAMIICvwYJKoZIhvcNAQcCoIICsDCCAqwCAQMxDTALBglghkgBZQMEAgEwgbcGCyqGSIb3DQEJEAEEoIGnBIGkMIGhAgEBBgkrBgEEAYO/MAIwMTANBglghkgBZQMEAgEFAAQghcKBnsFCpVtXbanqDCSR8zDubO5wb4xvtguYuZJRTKMCFGXfBMQDzomI8IngRpeuarmPZQoDGA8yMDI2MDYzMDE3NDQwOFowAwIBAaAypDAwLjEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MRUwEwYDVQQDEwxzaWdzdG9yZS10c2GgADGCAdowggHWAgEBMFEwOTEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MSAwHgYDVQQDExdzaWdzdG9yZS10c2Etc2VsZnNpZ25lZAIUOhNULwyQYe68wUMvy4qOiyojiwwwCwYJYIZIAWUDBAIBoIH8MBoGCSqGSIb3DQEJAzENBgsqhkiG9w0BCRABBDAcBgkqhkiG9w0BCQUxDxcNMjYwNjMwMTc0NDA4WjAvBgkqhkiG9w0BCQQxIgQgczwr9pKyxDMc0eur+DGt9Mdetezf8UQKp2Sn
3wspffwwgY4GCyqGSIb3DQEJEAIvMX8wfTB7MHkEIIX5J7wHq2LKw7RDVsEO/IGyxog/2nq55thw2dE6zQW3MFUwPaQ7MDkxFTATBgNVBAoTDHNpZ3N0b3JlLmRldjEgMB4GA1UEAxMXc2lnc3RvcmUtdHNhLXNlbGZzaWduZWQCFDoTVC8MkGHuvMFDL8uKjosqI4sMMAoGCCqGSM49BAMCBGYwZAIwJmfpM3hVIBsGwNTieyT54BZfQTwFye2f0/les1QzRFpXz5nu59C0tKLFYqcNPDdQAjBI9y5eNjjl9yo9BtpcZmIjURLuYioqzrjahNDmiThJZgRNROaVkPWrE5dlDJoFe58="}]}},"messageSignature":{"messageDigest":{"algorithm":"SHA2_256","digest":"I164tPH7KbPuiT5yiFVHd6fBflpz1DenQ75wGPKdmLg="},"signature":"MEUCICWd9LegvRoJ1X6HAL0k/RWPoLmlm/XSw7UxNZciHS0sAiEA7SZaJRWTiGvRHYhvwJKKDpD5gECYOnF0c3EDLOES9cM="}}
build/torch211-cxx11-cu128-x86_64-linux/msa/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
def _import_from_path(file_path: Path) -> ModuleType:
    """Load the Python source file at ``file_path`` and return it as a module.

    The module is registered in ``sys.modules`` under a name derived from the
    hash of the absolute path, then executed.

    Args:
        file_path: Location of the Python file to execute.

    Returns:
        The fully executed module object.

    Raises:
        ImportError: If no import spec (or no loader) can be built for the file.
        Any exception raised while executing the file is propagated unchanged.
    """
    # We cannot use the module name as-is, after adding it to `sys.modules`,
    # it would also be used for other imports. So, we make a module name that
    # depends on the path for it to be unique using the hex-encoded hash of
    # the path.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the file can find itself in `sys.modules`.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a partially initialised module registered.
        sys.modules.pop(module_name, None)
        raise
    return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch211-cxx11-cu128-x86_64-linux/quack/__init__.py ADDED
File without changes
build/torch211-cxx11-cu128-x86_64-linux/quack/activation.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ import math
4
+ from typing import Tuple
5
+ from functools import partial
6
+
7
+ import cutlass.cute as cute
8
+ from cutlass import Float32, Boolean, const_expr
9
+ from cutlass.cutlass_dsl import T, dsl_user_op
10
+ from cutlass._mlir.dialects import llvm, nvvm
11
+
12
+
13
# An activation operand: either a single Float32 or a pair of Float32 values
# (the pair form is what the packed f32x2 operations in this file consume).
F32_or_F32x2 = Float32 | Tuple[Float32, Float32]


# Packed two-lane subtraction: calc_packed_f32x2_op bound to
# nvvm.sub_packed_f32x2 with no third operand (src_c=None).
sub_packed_f32x2 = partial(
    cute.arch.calc_packed_f32x2_op,
    src_c=None,
    calc_func=nvvm.sub_packed_f32x2,
)
21
+
22
+
23
@dsl_user_op
def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return tanh(a) evaluated with the ``tanh.approx.f32`` inline-asm instruction.

    ``a`` is converted to Float32 before being passed to the instruction.
    """
    result_type = T.f32()
    operands = [Float32(a).ir_value(loc=loc, ip=ip)]
    asm_value = llvm.inline_asm(
        result_type,
        operands,
        "tanh.approx.f32 $0, $1;",
        "=f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return Float32(asm_value)
36
+
37
+
38
@dsl_user_op
def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """Sigmoid expressed through tanh: sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x).

    Accepts a single Float32 or a pair; a pair is processed with the packed
    f32x2 multiply / fused multiply-add and returned as a pair.
    """
    if const_expr(isinstance(x, tuple)):
        half_x = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
        tanh_half = (tanh(half_x[0]), tanh(half_x[1]))
        # 0.5 * tanh(x / 2) + 0.5 on both lanes
        return cute.arch.fma_packed_f32x2(tanh_half, (0.5, 0.5), (0.5, 0.5))
    half_x = 0.5 * x
    return 0.5 + 0.5 * tanh(half_x)
47
+
48
+
49
@dsl_user_op
def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
    """Sigmoid backward from the forward output: dout * out * (1 - out).

    Evaluated in the expanded form dout * (out - out * out).
    """
    out_squared = out * out
    return dout * (out - out_squared)
53
+
54
+
55
@dsl_user_op
def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """ReLU, max(x, 0), for a single Float32 or element-wise for a pair."""
    if const_expr(isinstance(x, tuple)):
        lane0 = cute.arch.fmax(x[0], Float32(0.0))
        lane1 = cute.arch.fmax(x[1], Float32(0.0))
        return lane0, lane1
    return cute.arch.fmax(x, Float32(0.0))
61
+
62
+
63
@dsl_user_op
@cute.jit
def drelu(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """
    ReLU backward pass: computes gradient w.r.t. x and recomputes forward
    Returns: (dx, relu_out) where:
    - dx = dout if x > 0, else 0
    - relu_out = max(x, 0)
    For a pair input both results are pairs, computed per element.
    """
    if const_expr(not isinstance(x, tuple)):
        x_pos = Boolean(x > 0)
        # Note: this is the 2-tuple (dx, relu(x)); the conditional only covers dx.
        return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
    else:
        x0_pos = Boolean(x[0] > 0)
        x1_pos = Boolean(x[1] > 0)
        dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0))
        return dx, relu(x)
76
+
77
+
78
@dsl_user_op
def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """Squared ReLU written as max(x, 0) * x, for a single Float32 or a pair."""
    if const_expr(isinstance(x, tuple)):
        clamped = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
        return cute.arch.mul_packed_f32x2(clamped, x)
    clamped = cute.arch.fmax(x, Float32(0.0))
    return clamped * x
85
+
86
+
87
@dsl_user_op
@cute.jit
def drelu_sq(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """
    ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
    Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
    Returns: (dx, relu_sq_out) where:
    - dx = dout * 2 * x if x > 0, else 0
    - relu_sq_out = max(x, 0) * x
    """
    if const_expr(not isinstance(x, tuple)):
        relu_x = relu(x)
        relu_sq_out = relu_x * x
        # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
        # relu_x is x for x > 0 and 0 otherwise, so no explicit branch is needed.
        dx = 2.0 * (dout * relu_x)
        return dx, relu_sq_out
    else:
        relu_x = relu(x)
        relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x)
        # Same formula as the scalar path, on both elements at once.
        dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x))
        return dx, relu_sq_out
110
+
111
+
112
@dsl_user_op
def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """
    Tanh approximation of GELU for a single Float32 or a pair.

    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
    """
    c1 = math.sqrt(2 / math.pi)  # ~0.797885
    c2 = 0.044715 * c1  # ~0.0356774
    if const_expr(isinstance(x, tuple)):
        x_squared = cute.arch.mul_packed_f32x2(x, x)
        # c2 * x^2 + c1
        poly = cute.arch.fma_packed_f32x2(x_squared, (c2, c2), (c1, c1))
        arg = cute.arch.mul_packed_f32x2(x, poly)
        tanh_arg = (tanh(arg[0]), tanh(arg[1]))
        # x * tanh(arg) + x == x * (1 + tanh(arg))
        x_times_one_plus_tanh = cute.arch.fma_packed_f32x2(tanh_arg, x, x)
        return cute.arch.mul_packed_f32x2((0.5, 0.5), x_times_one_plus_tanh)
    # Currently cute.math.tanh(x, fastmath=True) generates very slow code,
    # so the local tanh helper is used instead.
    arg = x * (c1 + c2 * (x * x))
    one_plus_tanh = 1.0 + tanh(arg)
    return 0.5 * (x * one_plus_tanh)
136
+
137
+
138
@dsl_user_op
def dgelu_tanh_approx(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """
    GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
    Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
    Returns: (dx, gelu_out)

    Derivative uses the chain rule:
    d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
    where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
    and sech^2(z) = 1 - tanh^2(z)
    """
    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # c1 ~0.797885
    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
    sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 = 3 * c2 ~0.107032

    if const_expr(not isinstance(x, tuple)):
        # Compute z = x * (c1 + c2 * x^2)
        x_sq = x * x
        # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
        tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
        # 0.5 * (1 + tanh(z)); reused for both the forward value and the gradient
        half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
        gelu_out = x * half_tanh_z_plus_one

        # Compute gradient
        # sech^2(z) = 1 - tanh^2(z)
        sech2_z = 1 - tanh_z * tanh_z
        # dz/dx = c1 + 3 * c2 * x^2
        dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
        dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))

        dx = dout * dgelu
        return dx, gelu_out
    else:
        # Compute z = x * (c1 + c2 * x^2)
        x_sq = cute.arch.mul_packed_f32x2(x, x)
        x_sq_scaled = cute.arch.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
        tanh_z = (tanh(z[0]), tanh(z[1]))
        # 0.5 * tanh(z) + 0.5 on both elements
        half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
        gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one)

        # Compute gradient
        # sech^2(z) = 1 - tanh^2(z)
        sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
        # dz/dx = c1 + 3 * c2 * x^2
        dz_dx = cute.arch.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
        sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx)
        x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx)
        dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)

        dx = cute.arch.mul_packed_f32x2(dout, dgelu)
        return dx, gelu_out
199
+
200
+
201
@dsl_user_op
@cute.jit
def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """
    Softplus: softplus(x) = log(1 + exp(x)), for a single Float32 or a pair.

    Inputs greater than 20.0 are returned unchanged (linear region).
    The pair path works in base 2: exp(x) is evaluated as 2^(x * log2(e)) and
    the natural log as log2(.) * ln(2), so both paths compute the same function.
    """
    if const_expr(not isinstance(x, tuple)):
        use_linear = Boolean(x > 20.0)
        return (
            cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True)
            if not use_linear
            else x
        )
    else:
        log2_e = math.log2(math.e)
        x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e))
        # x has been pre-scaled by log2(e), so the base-2 exponential is required:
        # 2^(x * log2(e)) == e^x. Using the natural exp here would compute
        # e^(x * log2(e)) and no longer match the scalar path.
        x_exp = (
            cute.math.exp2(x_log2e[0], fastmath=True),
            cute.math.exp2(x_log2e[1], fastmath=True),
        )
        x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0))
        log_x_exp_p1 = (
            cute.math.log2(x_exp_p1[0], fastmath=True),
            cute.math.log2(x_exp_p1[1], fastmath=True),
        )
        # Convert log2 back to the natural log: ln(v) = log2(v) * ln(2)
        ln2 = math.log(2.0)
        softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
        use_linear_0 = Boolean(x[0] > 20.0)
        use_linear_1 = Boolean(x[1] > 20.0)
        return (
            softplus_x[0] if not use_linear_0 else x[0],
            softplus_x[1] if not use_linear_1 else x[1],
        )
228
+
229
+
230
@dsl_user_op
@cute.jit
def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
    """
    Softplus backward pass computed from the forward output.
    Returns dx = dout * (1 - exp(-out)), or dout unchanged when out > 20.0.
    """
    use_linear = Boolean(out > 20.0)
    # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout
    # Expanded form of the expression above: dout - dout * exp(-out)
    dx = dout - dout * cute.math.exp(-out, fastmath=True)
    return dx if not use_linear else dout
237
+
238
+
239
@dsl_user_op
def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2:
    """
    silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
    This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.

    When ``already_halved`` is true the input is used as 0.5 * x directly and
    the halving multiply is skipped.
    """
    if const_expr(isinstance(x, tuple)):
        if const_expr(already_halved):
            half = x
        else:
            half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
        tanh_half = (tanh(half[0]), tanh(half[1]))
        return cute.arch.fma_packed_f32x2(half, tanh_half, half)
    if const_expr(already_halved):
        half = x
    else:
        half = 0.5 * x
    return half * tanh(half) + half
253
+
254
+
255
@dsl_user_op
def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """SwiGLU: silu(x) * y, for single Float32 operands or pairs."""
    gate = silu(x)
    if const_expr(isinstance(x, tuple)):
        return cute.arch.mul_packed_f32x2(gate, y)
    return gate * y
261
+
262
+
263
@dsl_user_op
def dswiglu(
    x: F32_or_F32x2,
    y: F32_or_F32x2,
    dout: F32_or_F32x2,
    *,
    already_halved: bool = False,
    loc=None,
    ip=None,
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
    Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x)

    d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))

    This has been optimized to use fewer instructions (i.e. we expand things out
    to use FFMA instead of FADD and FMUL).

    When already_halved is true, tanh is applied to x directly (no multiply by 0.5)
    and silu is formed as x * tanh(x) + x.
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
        # FMUL, MUFU.TANH, then FFMA
        if const_expr(not already_halved):
            sigmoid_x = sigmoid(x)
            silu_x = x * sigmoid_x  # FMUL
        else:
            tanh_x = tanh(x)  # MUFU.TANH
            sigmoid_x = 0.5 * tanh_x + 0.5  # FFMA
            silu_x = x * tanh_x + x  # FFMA
        silu_x_dout = silu_x * dout  # FMUL
        # d_silu(x) * dout
        # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
        # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
        d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
        dx = d_silu_x_dout * y  # FMUL
        dy = silu_x_dout
        swiglu_out = silu_x * y  # FMUL
        # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
        return dx, dy, swiglu_out
    else:
        # Compute sigmoid(x) and silu(x)
        if const_expr(not already_halved):
            sigmoid_x = sigmoid(x)
            silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x)
        else:
            tanh_x = (tanh(x[0]), tanh(x[1]))
            sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
            silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x)
        silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
        # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
        # sigmoid_x * (-silu_x) + sigmoid_x
        sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2(
            sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
        )
        d_silu_x_dout = cute.arch.fma_packed_f32x2(
            sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout
        )
        dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y)
        dy = silu_x_dout
        swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y)
        return dx, dy, swiglu_out
327
+
328
+
329
@dsl_user_op
def swiglu_oai(
    x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
) -> F32_or_F32x2:
    """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y.
    https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
    x * sigmoid(alpha * x) * (y + 1)
    Compile down to FMUL, FMUL, TANH, FFMA, FFMA
    """
    # sigmoid(alpha * x) is obtained from tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
    if const_expr(isinstance(x, tuple)):
        half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
        scaled_half = cute.arch.mul_packed_f32x2((alpha, alpha), half)
        tanh_scaled = (tanh(scaled_half[0]), tanh(scaled_half[1]))
        gated = cute.arch.fma_packed_f32x2(half, tanh_scaled, half)
        # gated * y + gated == gated * (y + 1)
        return cute.arch.fma_packed_f32x2(gated, y, gated)
    half = 0.5 * x
    tanh_scaled = tanh(alpha * half)
    gated = half * tanh_scaled + half
    # gated * y + gated == gated * (y + 1)
    return gated * y + gated
350
+
351
+
352
@dsl_user_op
def dswiglu_oai(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    Swiglu OAI backward pass: computes gradients w.r.t. x and y
    Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
    Returns: (dx, dy, swiglu_oai_out)

    Derivative of x * sigmoid(alpha * x) w.r.t. x:
    d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
        alpha_x_half = (0.5 * alpha) * x  # FMUL
        # MUFU.TANH, then FFMA
        # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
        sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)
        # Named silu_x, but here it is x * sigmoid(alpha * x)
        silu_x = x * sigmoid_alpha_x  # FMUL
        silu_x_dout = silu_x * dout  # FMUL
        # FFMA, FFMA, FMUL
        d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
        dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiply by y + 1
        dy = silu_x_dout
        swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiply by y + 1
        # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
        return dx, dy, swiglu_out
    else:
        # Compute sigmoid(alpha * x)
        alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
        tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
        sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
        silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x)
        silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
        # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
        # silu_x * (-sigmoid_alpha_x) + silu_x
        silu_x_minus_product = cute.arch.fma_packed_f32x2(
            silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
        )
        # alpha * (silu_x - silu_x * sigmoid_alpha_x) + sigmoid_alpha_x
        sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2(
            (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
        )
        d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
        # Multiply by (y + 1) as a single fused multiply-add
        dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
        dy = silu_x_dout
        swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
        return dx, dy, swiglu_out
398
+
399
+
400
@dsl_user_op
def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """GLU: Gated Linear Unit
    glu(x, y) = sigmoid(x) * y
    Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
    """
    gate = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
    if const_expr(isinstance(x, tuple)):
        return cute.arch.mul_packed_f32x2(gate, y)
    return gate * y  # FMUL
412
+
413
+
414
@dsl_user_op
def dglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
    Returns: (dx, dy, glu_out) where:
    - dx = dout * y * sigmoid(x) * (1 - sigmoid(x))
    - dy = dout * sigmoid(x)
    - glu_out = sigmoid(x) * y
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
        sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
        sigmoid_x_dout = sigmoid_x * dout  # FMUL
        glu_out = sigmoid_x * y  # FMUL
        # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
        #    = y * (1 - sigmoid(x)) * sigmoid_x_dout
        #    = (y - y * sigmoid(x)) * sigmoid_x_dout
        #    = (y - glu_out) * sigmoid_x_dout
        dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
        dy = sigmoid_x_dout
        # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
        return dx, dy, glu_out
    else:
        # Pair path: same algebra as the scalar path, using packed f32x2 operations.
        sigmoid_x = sigmoid(x)
        sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout)
        glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y)
        # dx = (y - glu_out) * sigmoid_x_dout
        y_minus_glu_out = sub_packed_f32x2(y, glu_out)
        dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
        dy = sigmoid_x_dout
        return dx, dy, glu_out
448
+
449
+
450
@dsl_user_op
def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """ReGLU: ReLU Gated Linear Unit
    reglu(x, y) = relu(x) * y = max(x, 0) * y
    """
    if const_expr(isinstance(x, tuple)):
        gate = relu(x)
        return cute.arch.mul_packed_f32x2(gate, y)
    gate = cute.arch.fmax(x, Float32(0.0))
    return gate * y
460
+
461
+
462
@dsl_user_op
@cute.jit
def dreglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
    Returns: (dx, dy, reglu_out) where:
    - dx = dout * y if x > 0, else 0
    - dy = dout * relu(x)
    - reglu_out = relu(x) * y
    """
    if const_expr(not isinstance(x, tuple)):
        x_pos = Boolean(x > 0)
        relu_x = cute.arch.fmax(x, Float32(0.0))
        # Gradient reaches x only where the gate is positive
        dx = (dout * y) if x_pos else Float32(0.0)
        dy = dout * relu_x
        reglu_out = relu_x * y
        return dx, dy, reglu_out
    else:
        x0_pos = Boolean(x[0] > 0)
        x1_pos = Boolean(x[1] > 0)
        relu_x = relu(x)
        dout_y = cute.arch.mul_packed_f32x2(dout, y)
        # Per-element selection of dout * y or 0, based on the sign of x
        dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
        dy = cute.arch.mul_packed_f32x2(dout, relu_x)
        reglu_out = cute.arch.mul_packed_f32x2(relu_x, y)
        return dx, dy, reglu_out
491
+
492
+
493
@dsl_user_op
def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """GeGLU: GELU Gated Linear Unit
    geglu(x, y) = gelu(x) * y
    Uses the tanh approximation of GELU
    """
    gate = gelu_tanh_approx(x)
    if const_expr(isinstance(x, tuple)):
        return cute.arch.mul_packed_f32x2(gate, y)
    return gate * y
503
+
504
+
505
@dsl_user_op
def dgeglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
    Returns: (dx, dy, geglu_out) where:
    - dx = dout * y * d_gelu(x)
    - dy = dout * gelu(x)
    - geglu_out = gelu(x) * y
    """
    # dgelu_tanh_approx yields d_gelu(x) * dout together with the recomputed gelu(x)
    grad_gate, gate = dgelu_tanh_approx(x, dout)
    if const_expr(isinstance(x, tuple)):
        dx = cute.arch.mul_packed_f32x2(grad_gate, y)
        dy = cute.arch.mul_packed_f32x2(gate, dout)
        geglu_out = cute.arch.mul_packed_f32x2(gate, y)
        return dx, dy, geglu_out
    dx = grad_gate * y
    dy = gate * dout
    geglu_out = gate * y
    return dx, dy, geglu_out
build/torch211-cxx11-cu128-x86_64-linux/quack/compile_utils.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+
3
+ from typing import Optional
4
+
5
+ import cutlass.cute as cute
6
+
7
+
8
def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]:
    """Build a fake tensor whose non-leading strides are symbolic.

    The mode ``leading_dim`` (negative values count from the end of ``shape``)
    gets stride 1; every other mode gets a fresh ``cute.sym_int64`` carrying
    the given ``divisibility``. The assumed alignment in bytes is
    ``divisibility * dtype.width // 8``.

    Returns ``None`` when ``dtype`` is ``None``.
    """
    if leading_dim < 0:
        leading_dim += len(shape)
    if dtype is None:
        return None
    strides = []
    for mode in range(len(shape)):
        if mode == leading_dim:
            strides.append(1)
        else:
            strides.append(cute.sym_int64(divisibility=divisibility))
    return cute.runtime.make_fake_tensor(
        dtype, shape, stride=tuple(strides), assumed_align=divisibility * dtype.width // 8
    )
build/torch211-cxx11-cu128-x86_64-linux/quack/copy_utils.py ADDED
@@ -0,0 +1,890 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+
3
+ import re
4
+ from typing import Optional, Type, Tuple, Callable, Sequence
5
+ from functools import partial
6
+
7
+ import cutlass
8
+ import cutlass.cute as cute
9
+
10
+ from cutlass import Int32, Int16, Boolean, const_expr
11
+ from cutlass.cute.nvgpu import cpasync, warp, warpgroup
12
+ from cutlass.cute.nvgpu.tcgen05.mma import CtaGroup # noqa
13
+ from cutlass.cutlass_dsl import dsl_user_op
14
+ import cutlass.pipeline
15
+ from cutlass._mlir.dialects import llvm
16
+ from cutlass._mlir import ir
17
+ from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
18
+
19
+
20
+ Sm100MmaPeerBitMask = 0xFEFFFFFF
21
+
22
+
23
@dsl_user_op
def cvt_copy(
    tiled_copy: cute.TiledCopy,
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    retile: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy a register tensor into ``dst``, casting the elements first if needed.

    ``src`` must be a pointer-backed tensor in register memory. If its element
    type differs from ``dst``'s, the values are converted into a new fragment
    of the destination type before copying. With ``retile=True`` the source is
    passed through ``tiled_copy.retile`` first. ``pred`` and any extra keyword
    arguments are forwarded to ``cute.copy``.
    """
    assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
    needs_cast = src.element_type != dst.element_type
    if const_expr(needs_cast):
        converted = cute.make_fragment_like(src, dst.element_type)
        converted.store(src.load().to(dst.element_type))
        src = converted
    if const_expr(retile):
        src = tiled_copy.retile(src)
    cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
43
+
44
+
45
@dsl_user_op
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
    """Return a new fragment shaped like ``src`` holding a copy of its contents.

    The copy is done with ``cute.autovec_copy``. Judging by the name, ``src``
    is expected to live in shared memory (s2r) — the code does not check this.
    """
    fragment = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip)
    cute.autovec_copy(src, fragment, loc=loc, ip=ip)
    return fragment
50
+
51
+
52
@dsl_user_op
def load_s2r_retile(
    tiled_copy: cute.TiledCopy,
    src: cute.Tensor,
    dst_shape: cute.Tensor | cute.Shape,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Copy ``src`` into a register tensor through a retiled view and return it.

    ``dst_shape`` may be a shape, in which case a new register tensor with
    ``src``'s element type is created, or an existing tensor, in which case
    that tensor is written and returned.
    """
    if const_expr(isinstance(dst_shape, cute.Tensor)):
        # Caller supplied the destination tensor itself.
        dst = dst_shape
    else:
        dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip)
    cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
    return dst
68
+
69
+
70
@dsl_user_op
def get_copy_atom(
    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
) -> cute.CopyAtom:
    """Create a copy atom moving up to ``num_copy_elems`` elements of ``dtype``.

    The per-copy width is ``num_copy_elems * dtype.width`` bits, capped at 128.
    ``is_async`` selects ``cpasync.CopyG2SOp`` instead of the universal copy
    op. ``loc``/``ip`` are forwarded to ``cute.make_copy_atom`` so the created
    op carries the caller's source location.
    """
    num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    return cute.make_copy_atom(
        copy_op, dtype, num_bits_per_copy=num_copy_bits, loc=loc, ip=ip
    )
77
+
78
+
79
@dsl_user_op
def copy(
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    is_async: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy ``src`` to ``dst`` using an atom sized from ``src``'s leading mode.

    The vector length is read from ``src.shape[0][0]`` and turned into a copy
    atom by ``get_copy_atom`` (asynchronous when ``is_async`` is set). ``pred``
    and extra keyword arguments are passed on to ``cute.copy``.
    """
    vec_elems = src.shape[0][0]
    atom = get_copy_atom(src.element_type, vec_elems, is_async)
    cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
93
+
94
+
95
def tiled_copy_1d(
    dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
) -> cute.TiledCopy:
    """Build a 1-D tiled copy: ``num_threads`` threads, ``num_copy_elems`` values each.

    Each copy moves ``num_copy_elems * dtype.width`` bits; ``is_async`` picks
    ``cpasync.CopyG2SOp`` over the universal copy op.
    """
    bits_per_copy = num_copy_elems * dtype.width
    if is_async:
        op = cpasync.CopyG2SOp()
    else:
        op = cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=bits_per_copy)
    threads = cute.make_layout(num_threads)
    values = cute.make_layout(num_copy_elems)
    return cute.make_tiled_copy_tv(atom, threads, values)
104
+
105
+
106
def tiled_copy_2d(
    dtype: Type[cutlass.Numeric],
    threads_per_row: int,
    num_threads: int,
    num_copy_elems: int = 1,
    is_async: bool = False,
) -> cute.TiledCopy:
    """Build a 2-D tiled copy with ``threads_per_row`` threads along each row.

    Threads form a ``(num_threads // threads_per_row, threads_per_row)`` layout
    ordered ``(1, 0)``, and each thread owns a ``(1, num_copy_elems)`` value
    tile. ``num_threads`` must be a multiple of ``threads_per_row``.
    """
    bits_per_copy = num_copy_elems * dtype.width
    if is_async:
        op = cpasync.CopyG2SOp()
    else:
        op = cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=bits_per_copy)
    assert num_threads % threads_per_row == 0
    num_rows = num_threads // threads_per_row
    threads = cute.make_ordered_layout((num_rows, threads_per_row), order=(1, 0))
    values = cute.make_layout((1, num_copy_elems))
    return cute.make_tiled_copy_tv(atom, threads, values)
123
+
124
+
125
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
    """Build a Boolean predicate tensor for the "k" dimension only.

    Each entry is ``elem_less(coord[1], limit)`` for the coordinate stored in
    ``tAcA``. The middle mode of the result has stride 0, so one predicate is
    shared across it; bounds along the mn dimension are expected to be handled
    with an ``if`` by the caller.
    """
    k_extent = cute.size(tAcA, mode=[2])
    pred_layout = cute.make_layout(
        (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), k_extent),
        stride=(k_extent, 0, 1),
    )
    tApA = cute.make_rmem_tensor(pred_layout, Boolean)
    for v in cutlass.range_constexpr(tApA.shape[0]):
        for k in cutlass.range_constexpr(tApA.shape[2]):
            # Component [1] of the coordinate is the one compared against limit.
            tApA[v, 0, k] = cute.elem_less(tAcA[(0, v), 0, k][1], limit)
    return tApA
139
+
140
+
141
+ # def tiled_copy_2d(
142
+ # dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
143
+ # ) -> cute.TiledCopy:
144
+ # num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
145
+ # copy_elems = num_copy_bits // dtype.width
146
+ # copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
147
+ # copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
148
+ # gmem_threads_per_row = major_mode_size // copy_elems
149
+ # assert num_threads % gmem_threads_per_row == 0
150
+ # thr_layout = cute.make_ordered_layout(
151
+ # (num_threads // gmem_threads_per_row, gmem_threads_per_row),
152
+ # order=(1, 0),
153
+ # )
154
+ # val_layout = cute.make_layout((1, copy_elems))
155
+ # return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
156
+
157
+
158
def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]:
    """Extract the swizzle parameters from a pointer's ``swizzle_type``.

    The string form of the swizzle type looks like
    ``'!cute.swizzle<"S<b,m,s>">'``, where b, m, s are the swizzle parameters
    (bits, base, shift).

    Returns:
        The tuple ``(b, m, s)`` as plain Python ints. This is NOT a
        ``cute.Swizzle``; pass the tuple to ``cute.make_swizzle`` to build one.

    Raises:
        ValueError: If the swizzle_type string cannot be parsed.
    """
    # Ideally there should be a better API to get swizzle parameters, but for
    # now the string representation is parsed.
    swizzle_str = str(ptr.type.swizzle_type)
    match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
    if match is None:
        raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
    b, m, s = (int(field) for field in match.groups())
    return b, m, s
180
+
181
+
182
def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
    """Apply an ``S<b,m,s>`` swizzle to an integer address.

    The ``b`` bits starting at bit position ``m + s`` are shifted right by
    ``s`` and XOR-ed into the address.
    """
    source_mask = ((1 << b) - 1) << (m + s)
    shifted_bits = (ptr_int & source_mask) >> s
    return ptr_int ^ shifted_bits
186
+
187
+
188
def swizzle_ptr(ptr: cute.Pointer):
    """Return a new pointer whose address is ``ptr``'s address with its swizzle applied.

    The swizzle parameters are parsed from ``ptr``'s type; dtype, memory space
    and alignment are carried over from ``ptr``.
    """
    bits, base, shift = parse_swizzle_from_pointer(ptr)
    swizzled_addr = swizzle_int(ptr.toint(), bits, base, shift)
    return cute.make_ptr(ptr.dtype, swizzled_addr, ptr.memspace, assumed_align=ptr.alignment)
192
+
193
+
194
def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
    """Move a tensor's pointer swizzle into its layout.

    The swizzle parsed from the tensor's iterator is composed with the
    tensor's layout, and the returned tensor uses a recast pointer so that the
    swizzle no longer lives on the pointer.
    """
    elem_bits = tensor.element_type.width
    swizzle = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator))
    # The swizzle is expressed in bytes (e.g. <3, 4, 3>), so compose it with a
    # byte-unit view of the layout and recast back to element units (which
    # corresponds to e.g. <3, 3, 3> for 16-bit and <3, 2, 3> for 32-bit types).
    byte_layout = cute.recast_layout(8, elem_bits, tensor.layout)
    swizzled_bytes = cute.make_composed_layout(swizzle, 0, byte_layout)
    elem_layout = cute.recast_layout(elem_bits, 8, swizzled_bytes)
    # recast_ptr removes the pointer swizzle.
    plain_ptr = cute.recast_ptr(tensor.iterator, dtype=tensor.element_type)
    return cute.make_tensor(plain_ptr, elem_layout)
205
+
206
+
207
def partition_D_position_independent(
    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
) -> cute.Tensor:
    """Destination-partition ``tensor`` with the swizzle carried by the layout.

    The iterator comes from the ordinary ``partition_D`` result with its
    swizzle applied to the address; the layout comes from partitioning the
    position-independent form of ``tensor``.
    """
    swizzled_iter = swizzle_ptr(thr_copy.partition_D(tensor).iterator)
    pi_layout = thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout
    return cute.make_tensor(swizzled_iter, pi_layout)
214
+
215
+
216
def partition_S_position_independent(
    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
) -> cute.Tensor:
    """Source-partition ``tensor`` with the swizzle carried by the layout.

    The iterator comes from the ordinary ``partition_S`` result with its
    swizzle applied to the address; the layout comes from partitioning the
    position-independent form of ``tensor``.
    """
    swizzled_iter = swizzle_ptr(thr_copy.partition_S(tensor).iterator)
    pi_layout = thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout
    return cute.make_tensor(swizzled_iter, pi_layout)
223
+
224
+
225
@dsl_user_op
def sm90_get_smem_load_op(
    layout_c: cutlass.utils.LayoutEnum,
    elem_ty_c: Type[cutlass.Numeric],
    *,
    loc=None,
    ip=None,
) -> cute.CopyAtom:
    """
    Pick the shared-memory load atom for a tensor of the given layout and type.

    Parameters:
    -----------
    layout_c : LayoutEnum
        Layout enum of the tensor; only its m-major flag is consulted.

    elem_ty_c : Type[Numeric]
        Element type of the tensor.

    Returns:
    --------
    A 4-matrix ``LdMatrix8x8x16bOp`` atom for 16-bit element types, otherwise a
    ``CopyUniversalOp`` atom.

    Raises:
    -------
    TypeError if ``elem_ty_c`` is not a Numeric type.
    """
    if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
        raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
    is_m_major = layout_c.is_m_major_c()
    if elem_ty_c.width != 16:
        return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
    return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip)
256
+
257
+
258
def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Return the copy atom used to store registers to shared memory.

    For ``arch >= 90`` with a 16-bit element type this is a 4-matrix
    ``StMatrix8x8x16bOp`` (optionally transposed). Otherwise it is a universal
    copy moving two elements per copy, or one element when ``transpose`` is set.
    """
    use_stmatrix = arch >= 90 and element_type.width == 16
    if const_expr(use_stmatrix):
        return cute.make_copy_atom(
            warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )
    elems_per_copy = 1 if transpose else 2
    return cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        element_type,
        num_bits_per_copy=elems_per_copy * element_type.width,
    )
272
+
273
+
274
def get_smem_load_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Return the copy atom used to load shared memory into registers.

    For ``arch >= 90`` with a 16-bit element type this is a 4-matrix
    ``LdMatrix8x8x16bOp`` (optionally transposed). Otherwise it is a universal
    copy moving two elements per copy, or one element when ``transpose`` is set.
    """
    use_ldmatrix = arch >= 90 and element_type.width == 16
    if const_expr(use_ldmatrix):
        return cute.make_copy_atom(
            warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )
    elems_per_copy = 1 if transpose else 2
    return cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        element_type,
        num_bits_per_copy=elems_per_copy * element_type.width,
    )
288
+
289
+
290
def get_smem_store_C(
    tiled_mma: cute.TiledMma,
    sC: cute.Tensor,
    tidx: Int32,
    arch: int,
    transpose: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up the register-to-shared store of a C tile for thread ``tidx``.

    Returns ``(copy_fn, thr_copy, tRS_sC)``: ``copy_fn(src, dst_idx=None,
    **kw)`` converts/retiles ``src`` and stores it into the partitioned shared
    tensor ``tRS_sC`` (indexed by ``dst_idx`` in its last mode when given),
    and ``thr_copy`` is this thread's slice of the tiled copy.
    """
    store_atom = get_smem_store_atom(arch, sC.element_type, transpose)
    tiled_copy = cute.make_tiled_copy_C(store_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(position_independent):
        tRS_sC = partition_D_position_independent(thr_copy, sC)
    else:
        tRS_sC = thr_copy.partition_D(sC)

    def copy_fn(src: cute.Tensor, dst_idx: Optional[Int32] = None, **new_kwargs):
        if const_expr(dst_idx is None):
            dst_tensor = tRS_sC
        else:
            dst_tensor = tRS_sC[None, None, None, dst_idx]
        cvt_copy(tiled_copy, src, dst_tensor, retile=True, **new_kwargs)

    return copy_fn, thr_copy, tRS_sC
312
+
313
+
314
def get_smem_load_C(
    tiled_mma: cute.TiledMma,
    sC: cute.Tensor,
    tidx: Int32,
    arch: int,
    transpose: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up the shared-to-register load of a C tile for thread ``tidx``.

    Returns ``(copy_fn, thr_copy, tSR_sC)``: ``copy_fn(src_idx=None, **kw)``
    loads the partitioned shared tensor ``tSR_sC`` (indexed by ``src_idx`` in
    its last mode when given) into a new register tensor and returns it.
    """
    dtype = sC.element_type
    tiled_copy = cute.make_tiled_copy_C(get_smem_load_atom(arch, dtype, transpose), tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(position_independent):
        tSR_sC = partition_S_position_independent(thr_copy, sC)
    else:
        tSR_sC = thr_copy.partition_S(sC)
    # The register tensor's shape is taken from how the matching *store* copy
    # partitions the first two modes of sC.
    store_atom = get_smem_store_atom(arch, dtype, transpose)
    thr_copy_RS = cute.make_tiled_copy_C(store_atom, tiled_mma).get_slice(tidx)
    tile_coords = cute.make_identity_tensor(sC.shape[:2])
    tRS_shape = thr_copy_RS.partition_S(tile_coords).shape

    def copy_fn(src_idx: Optional[Int32] = None, **new_kwargs):
        if const_expr(src_idx is None):
            src_tensor = tSR_sC
        else:
            src_tensor = tSR_sC[None, None, None, src_idx]
        return load_s2r_retile(tiled_copy, src_tensor, dst_shape=tRS_shape, **new_kwargs)

    return copy_fn, thr_copy, tSR_sC
339
+
340
+
341
def epilog_smem_copy_atom(
    tiled_mma: cute.TiledMma, epi_tile: cute.Shape, transpose: bool = False
) -> cute.TiledCopy:
    """Build the tiled C copy atom for the epilogue from an stmatrix atom.

    Four matrices per instruction are used when ``epi_tile[1]`` is a multiple
    of 16, otherwise two.
    """
    num_matrices = 4 if epi_tile[1] % 16 == 0 else 2
    stmatrix_atom = cute.make_copy_atom(
        warp.StMatrix8x8x16bOp(transpose, num_matrices=num_matrices),
        cutlass.Float16,  # this is just to get the right source layout
    )
    return cute.make_tiled_copy_C_atom(stmatrix_atom, tiled_mma)
350
+
351
+
352
def get_smem_store_epi(
    tiled_mma: cute.TiledMma,
    epi_tile: cute.Shape,
    sC: Optional[cute.Tensor],
    tidx: Int32,
    arch: int,
    transpose: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor, cute.Tensor]:
    """Set up the epilogue register-to-shared store for thread ``tidx``.

    Returns ``(copy_fn, thr_copy, tRS_sC, tRS_rC)``. ``tRS_rC`` is a new
    register tensor in the MMA accumulator dtype shaped for one epilogue tile.
    When ``sC`` is ``None`` there is nothing to store into: ``copy_fn`` and
    ``tRS_sC`` are ``None``, Float16 stands in as the element type, and
    ``epi_tile`` supplies the tile shape.
    """
    has_smem = sC is not None
    dtype = sC.element_type if const_expr(has_smem) else cutlass.Float16
    # NOTE(review): `transpose` is not forwarded to epilog_smem_copy_atom here,
    # only to get_smem_store_atom — confirm this is intended.
    tiled_copy_C_atom = epilog_smem_copy_atom(tiled_mma, epi_tile)
    store_atom = get_smem_store_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_S(store_atom, tiled_copy_C_atom)
    thr_copy = tiled_copy.get_slice(tidx)
    tRS_sC = None
    if const_expr(has_smem):
        if const_expr(position_independent):
            tRS_sC = partition_D_position_independent(thr_copy, sC)
        else:
            tRS_sC = thr_copy.partition_D(sC)
    sC_shape = sC.shape[:2] if has_smem else epi_tile
    # (R2S, R2S_M, R2S_N, PIPE_C)
    tRS_rC_shape = thr_copy.partition_S(cute.make_identity_tensor(sC_shape)).shape
    tRS_rC = cute.make_rmem_tensor(tRS_rC_shape, tiled_mma.op.acc_dtype)

    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
        cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], **new_kwargs)

    store_fn = copy_fn if const_expr(has_smem) else None
    return store_fn, thr_copy, tRS_sC, tRS_rC
381
+
382
+
383
def get_smem_store_A(
    tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up the register-to-shared store of operand A for thread ``tidx``.

    The store atom is transposed when the MMA's A operand is MN-major.
    Returns ``(copy_fn, thr_copy, tRS_sA)`` where ``copy_fn(src, dst_idx,
    **kw)`` converts/retiles ``src`` and stores it into slot ``dst_idx`` of
    the partitioned shared tensor ``tRS_sA``.
    """
    is_mn_major = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
    store_atom = get_smem_store_atom(arch, sA.element_type, is_mn_major)
    tiled_copy = cute.make_tiled_copy_A(store_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(position_independent):
        tRS_sA = partition_D_position_independent(thr_copy, sA)
    else:
        tRS_sA = thr_copy.partition_D(sA)

    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
        stage = tRS_sA[None, None, None, dst_idx]
        cvt_copy(tiled_copy, src, stage, retile=True, **new_kwargs)

    return copy_fn, thr_copy, tRS_sA
400
+
401
+
402
def get_smem_load_A(
    tiled_mma: cute.TiledMma,
    sA: cute.Tensor,
    tidx: Int32,
    arch: int,
    with_dst_tensor: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Set up the shared-to-register load of operand A for thread ``tidx``.

    The load atom is transposed when the MMA's A operand is MN-major.
    Returns ``(copy_fn, thr_copy, tSR_sA)``. By default ``copy_fn(src_idx,
    **kw)`` loads slot ``src_idx`` into a new register tensor; with
    ``with_dst_tensor=True`` the returned function instead has the signature
    ``(src_idx, dst, **kw)`` and writes into the supplied ``dst`` tensor.
    """
    is_mn_major = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
    load_atom = get_smem_load_atom(arch, sA.element_type, is_mn_major)
    tiled_copy = cute.make_tiled_copy_A(load_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(position_independent):
        tSR_sA = partition_S_position_independent(thr_copy, sA)
    else:
        tSR_sA = thr_copy.partition_S(sA)
    tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])

    def copy_fn(src_idx: Int32, **new_kwargs):
        stage = tSR_sA[None, None, None, src_idx]
        return load_s2r_retile(tiled_copy, stage, dst_shape=tRS_shape, **new_kwargs)

    def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs):
        stage = tSR_sA[None, None, None, src_idx]
        return load_s2r_retile(tiled_copy, stage, dst, **new_kwargs)

    chosen_fn = copy_fn_w_dst_tensor if with_dst_tensor else copy_fn
    return chosen_fn, thr_copy, tSR_sA
430
+
431
+
432
@dsl_user_op
def cpasync_reduce_bulk_add_f32(
    smem_ptr: cute.Pointer,
    gmem_ptr: cute.Pointer,
    store_bytes: int | Int32,
    *,
    loc=None,
    ip=None,
):
    """Emit a bulk asynchronous f32 add-reduction from shared into global memory.

    Issues the PTX instruction
    ``cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [gmem], [smem], bytes;``
    as inline assembly.

    :param smem_ptr: Source pointer (shared memory), passed as a 32-bit address
    :param gmem_ptr: Destination pointer (global memory)
    :param store_bytes: Number of bytes to reduce, a Python int or an Int32
    """
    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    # NOTE: an L2 cache-hint variant exists but is not used here. It would add
    # a 64-bit hint operand (e.g. cutlass.Int64(0x14F0000000000000) for
    # EVICT_LAST), the instruction
    # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;"
    # and the constraint string "l,r,r,l".
    llvm.inline_asm(
        None,
        [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()],
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;",
        "l,r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        # Forward the caller's location/insertion point, as tma_gather4_load does.
        loc=loc,
        ip=ip,
    )
455
+
456
+
457
@dsl_user_op
def get_tma_desc_addr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer:
    """
    Get the address of the TMA descriptor embedded in a TMA Copy Atom.

    The returned pointer is typed as a 128-byte-aligned pointer to a tiled TMA
    descriptor and is meant for use with custom PTX instructions such as
    ``tma_gather4_load``.

    :param tma_atom: TMA Copy Atom from make_tiled_tma_atom
    :return: Pointer to the TMA descriptor

    Example:
        >>> desc_ptr = get_tma_desc_addr(tma_atom)
    """
    exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip)
    tma_desc_ptr_type = ir.Type.parse(
        "!cute.ptr<!cute_nvgpu.tma_descriptor_tiled, generic, align<128>>"
    )
    return _cute_nvgpu_ir.get_tma_desc_addr(tma_desc_ptr_type, exec_atom, loc=loc, ip=ip)
476
+
477
+
478
@dsl_user_op
def tma_gather4_load(
    tma_desc_ptr: cute.Pointer,
    dst_smem_ptr: cute.Pointer,
    mbarrier_ptr: cute.Pointer,
    col_idx: Int32,
    row_indices: Sequence[Int32],
    *,
    num_cta: int = 1,
    multicast_mask=None,
    loc=None,
    ip=None,
) -> None:
    """
    Perform TMA gather4 load from global memory to shared memory.

    Issues PTX instruction:
        cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::<num_cta>
            [dstMem], [tensorMap, {col_idx, row0, row1, row2, row3}], [smem_bar];

    This loads 4 rows (specified by row_indices) from a 2D tensor at the given
    column index into shared memory, using the TMA descriptor.

    :param tma_desc_ptr: Pointer to the TMA descriptor (128-byte aligned),
        e.g. as returned by ``get_tma_desc_addr``
    :type tma_desc_ptr: Pointer
    :param dst_smem_ptr: Destination address in shared memory
    :type dst_smem_ptr: Pointer
    :param mbarrier_ptr: Pointer to mbarrier in shared memory for completion tracking
    :type mbarrier_ptr: Pointer
    :param col_idx: Column index
    :type col_idx: Int32
    :param row_indices: Sequence of exactly 4 row indices
    :type row_indices: Sequence[Int32]
    :param num_cta: Number of CTAs participating (default: 1). When greater
        than 1 the peer bit of the mbarrier address is cleared so the
        transaction bytes update CTA0's barrier.
    :type num_cta: int
    :param multicast_mask: Reserved; must be None because multicast is not
        supported yet
    :type multicast_mask: Int16

    :raises ValueError: If row_indices does not contain exactly 4 elements
    :raises NotImplementedError: If multicast_mask is not None

    Requirements:
        - row_indices must contain exactly 4 elements
        - Compute capability >= SM_100 (Blackwell)
        - TMA descriptor must be properly initialized for 2D tensor

    Example:
        >>> from cutlass.cute.nvgpu import cpasync
        >>>
        >>> # Create TMA descriptor
        >>> tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(...)
        >>> tma_desc_ptr = get_tma_desc_addr(tma_atom)
        >>>
        >>> # Gather 4 rows at a computed column
        >>> tma_gather4_load(
        ...     tma_desc_ptr=tma_desc_ptr,
        ...     dst_smem_ptr=smem_ptr,
        ...     mbarrier_ptr=barrier_ptr,
        ...     col_idx=col_idx,
        ...     row_indices=row_indices
        ... )
    """
    if len(row_indices) != 4:
        raise ValueError(f"gather4 requires exactly 4 row indices, got {len(row_indices)}")
    # Reject multicast up front with a real exception: an assert would be
    # stripped under `python -O` and the mask would then be silently ignored.
    if multicast_mask is not None:
        raise NotImplementedError("multicast is not supported yet")
    col_val = Int32(col_idx).ir_value()
    row_vals = [Int32(row_idx).ir_value() for row_idx in row_indices]
    # Convert pointers to integer addresses
    desc_addr = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value()
    dst_addr = dst_smem_ptr.toint(loc=loc, ip=ip).ir_value()
    mbar_addr = mbarrier_ptr.toint(loc=loc, ip=ip)
    if num_cta > 1:
        # Executed by both CTAs. Set peer bit to 0 so that the
        # transaction bytes will update CTA0's barrier.
        mbar_addr = mbar_addr & Sm100MmaPeerBitMask
    mbar_addr = mbar_addr.ir_value()
    # Emit inline PTX for TMA gather4
    ptx = (
        f"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::{num_cta} "
        "[$0], [$1, {$2, $3, $4, $5, $6}], [$7];"
    )

    llvm.inline_asm(
        None,
        [
            dst_addr,
            desc_addr,
            col_val,
            row_vals[0],
            row_vals[1],
            row_vals[2],
            row_vals[3],
            mbar_addr,
        ],
        ptx,
        "r,l,r,r,r,r,r,r",  # constraints: register, long, 6x register
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
588
+
589
+
590
def cpasync_bulk_get_copy_fn(
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    """Return a closure that issues a bulk (``CopyBulkG2SOp``) copy from one elected thread.

    All modes except the last are grouped into one mode on both tensors (all
    modes when ``single_stage`` is set). The multi-stage closure has the
    signature ``(src_idx, dst_idx, tma_bar_ptr, **kw)`` and indexes the last
    mode of each tensor; the single-stage closure is ``(tma_bar_ptr, **kw)``.
    Keyword arguments given here are forwarded to every ``cute.copy`` call.
    """
    stage_modes = 0 if single_stage else 1
    group_rank_src = const_expr(cute.rank(src_tensor) - stage_modes)
    group_rank_dst = const_expr(cute.rank(dst_tensor) - stage_modes)
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    src = cute.group_modes(src_tensor, 0, group_rank_src)
    dst = cute.group_modes(dst_tensor, 0, group_rank_dst)

    def copy_bulk(src_idx, dst_idx, tma_bar_ptr: cute.Pointer, **new_kwargs):
        bulk_atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type)
        with cute.arch.elect_one():
            cute.copy(
                bulk_atom,
                src[None, src_idx],
                dst[None, dst_idx],
                mbar_ptr=tma_bar_ptr,
                **new_kwargs,
                **kwargs,
            )

    def copy_bulk_single_stage(tma_bar_ptr: cute.Pointer, **new_kwargs):
        bulk_atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type)
        with cute.arch.elect_one():
            cute.copy(bulk_atom, src, dst, mbar_ptr=tma_bar_ptr, **new_kwargs, **kwargs)

    if const_expr(single_stage):
        return copy_bulk_single_stage
    return copy_bulk
620
+
621
+
622
def tma_get_copy_fn(
    atom: cute.CopyAtom,
    cta_coord: cute.Coord,
    cta_layout: cute.Layout,
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    filter_zeros: bool = False,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    """Partition a tensor pair for TMA and return a copy closure plus the partitions.

    Whichever of ``src_tensor``/``dst_tensor`` is a pointer-backed shared
    memory tensor is treated as the smem side, so both copy directions are
    supported. Returns ``(copy_fn, s, g)`` where ``s`` and ``g`` are the
    partitioned smem and gmem tensors. The multi-stage ``copy_fn(src_idx,
    dst_idx, **kw)`` indexes the last mode of each side; the single-stage
    variant takes keyword arguments only.
    """
    src_is_smem = const_expr(
        isinstance(src_tensor.iterator, cute.Pointer)
        and src_tensor.memspace == cute.AddressSpace.smem
    )
    if src_is_smem:
        smem_tensor, gmem_tensor = src_tensor, dst_tensor
    else:
        smem_tensor, gmem_tensor = dst_tensor, src_tensor
    stage_modes = 0 if single_stage else 1
    group_rank_smem = const_expr(cute.rank(smem_tensor) - stage_modes)
    group_rank_gmem = const_expr(cute.rank(gmem_tensor) - stage_modes)
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    s, g = cpasync.tma_partition(
        atom,
        cta_coord,
        cta_layout,
        cute.group_modes(smem_tensor, 0, group_rank_smem),
        cute.group_modes(gmem_tensor, 0, group_rank_gmem),
    )
    if const_expr(filter_zeros):
        s = cute.filter_zeros(s)
        g = cute.filter_zeros(g)
    if src_is_smem:
        src, dst = s, g
    else:
        src, dst = g, s

    def copy_tma(src_idx, dst_idx, **new_kwargs):
        cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)

    def copy_tma_single_stage(**new_kwargs):
        cute.copy(atom, src, dst, **new_kwargs, **kwargs)

    chosen_fn = copy_tma_single_stage if const_expr(single_stage) else copy_tma
    return chosen_fn, s, g
659
+
660
+
661
def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
    """Bind a staged copy function to a pipeline's producer state.

    The returned ``copy_fn(src_idx, producer_state, **kw)`` calls ``copy``
    with the destination index taken from ``producer_state.index`` and the
    barrier obtained from ``pipeline.producer_get_barrier(producer_state)``.
    """

    def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
        stage = producer_state.index
        barrier = pipeline.producer_get_barrier(producer_state)
        copy(src_idx=src_idx, dst_idx=stage, tma_bar_ptr=barrier, **new_kwargs)

    return copy_fn
671
+
672
+
673
@cute.jit
def gather_m_get_copy_fn(
    thr_copy_A: cute.ThrCopy,
    mA: cute.Tensor,  # (whatever, K)
    sA: cute.Tensor,  # (tile_M, tile_K, STAGE)
    gsAIdx: cute.Tensor,  # (tile_M), either gmem or smem
    limit_m: Int32,
    limit_k: Int32,
) -> Callable:
    """Build a closure that copies rows of ``mA``, gathered by index, into ``sA``.

    Row indices are read once from ``gsAIdx`` and cached in registers; rows at
    or beyond ``limit_m`` fall back to index 0. The returned
    ``copy_fn(src_idx, dst_idx, pred=False)`` copies k-tile ``src_idx`` of the
    gathered rows into stage ``dst_idx`` of ``sA``; with ``pred=True`` the
    columns are additionally predicated against ``limit_k``.
    """
    tile_m = cute.size(sA, mode=[0])
    tile_k = cute.size(sA, mode=[1])
    tAsA = thr_copy_A.partition_D(sA)
    # k-major
    assert tAsA.shape[2] == 1
    tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)

    is_even_m_smem = tile_m % thr_copy_A.tiler_mn[0].shape == 0
    if const_expr(not is_even_m_smem):
        limit_m = min(limit_m, tile_m)
    elems_per_load = cute.size(tAsA.shape[0][0])
    cA = cute.make_identity_tensor((tile_m, tile_k))
    tAcA = thr_copy_A.partition_S(cA)
    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
    # Rather than comparing tAcA against limit_m, compare thread 0's
    # coordinates (t0AcA) against limit_m - tAcA[0][0]; this relies on
    # tAcA[m][0] = t0AcA[m][0] + tAcA[0][0] and keeps the t0AcA side of the
    # comparison known at compile time. The same shift is applied to limit_k.
    limit_m = limit_m - tAcA[0][0]
    limit_k = limit_k - tAcA[0][1]
    # Read and cache indices for A
    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
    tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
    m_idx = cute.make_rmem_tensor(rows_per_thread, Int32)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        row_idx = tAcA[0, m, 0][0]
        if tApA_m[m]:
            m_idx[m] = gsAIdx[row_idx]
        else:
            m_idx[m] = 0  # It's ok to load row 0 in the case of OOB

    mA_k = cute.logical_divide(mA, (None, tile_k))

    def copy_fn(src_idx, dst_idx, pred: bool = False):
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_k
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        mA_cur = mA_k[None, (None, src_idx)]
        for m in cutlass.range_constexpr(tAcA.shape[1]):
            # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape
            # ((elems_per_load), thread_per_row)
            # but the shape wanted here is ((elems_per_load, 1), thread_per_row) to match
            # tAsA, so 1s are appended to the last dimension before tiled_divide and the
            # result is sliced.
            mA_row = cute.tiled_divide(
                cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
            )[None, None, 0]
            if const_expr(is_even_m_smem) or tApA_m[m]:
                # There's only 1 load per row
                assert cute.size(tAcA.shape, mode=[2]) == 1
                ki = tAcA[0, 0, 0][1] // elems_per_load
                cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k)

    return copy_fn
739
+
740
+
741
@cute.jit
def gather_k_get_copy_fn(
    thr_copy_A: cute.ThrCopy,
    mA: cute.Tensor,  # (tile_M, whatever)
    sA: cute.Tensor,  # (tile_M, tile_K, STAGE)
    gsAIdx: cute.Tensor,  # (tile_K, RestK), either gmem or smem
    limit_m: Int32,
    limit_k: Int32,
) -> Callable:
    """Build closures that gather columns of ``mA`` into ``sA`` through an index tensor.

    The column indices are read from ``gsAIdx``; whether it lives in gmem or
    smem is decided at trace time from ``gsAIdx.memspace`` and selects which
    prefetch closure is returned.

    Returns a 2-tuple ``(copy_fn, prefetch_fn)``:

    * ``prefetch_fn`` loads this thread's column indices (and, when ``pred`` is
      set, the matching K predicates) and returns ``(k_idx, tApA_k)``.
    * ``copy_fn(src_idx, dst_idx, (k_idx, tApA_k), pred=False)`` issues one copy
      per (row, column) of this thread, guarded by the cached row predicate.

    NOTE(review): the annotation says ``Callable`` but the value returned is a
    tuple of two callables.
    """
    gAIdx, sAIdx = None, None
    if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem):
        gAIdx = gsAIdx
    else:
        assert gsAIdx.memspace == cute.AddressSpace.smem
        sAIdx = gsAIdx
    tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
    # (atom_v, CPY_M, 1, STAGE)
    tAsA = thr_copy_A.partition_D(sA)
    # m-major
    tAsA = cute.group_modes(tAsA, 0, 3)

    is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
    if const_expr(not is_even_m_smem):
        # The copy tiler does not divide tile_M evenly, so also clamp to the tile extent.
        limit_m = min(limit_m, tile_shape_mk[0])
    elems_per_load = cute.size(tAsA.shape[0][0])
    cA = cute.make_identity_tensor(tile_shape_mk)
    tAcA = thr_copy_A.partition_S(cA)
    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
    # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
    # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
    # This is so that when we do the comparison, t0AcA is known at compile time.
    limit_m = limit_m - tAcA[0][0]
    limit_k = limit_k - tAcA[0][1]
    # Read and cache indices for A
    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
    # Row predicates are computed once here and reused by every copy_fn call.
    tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
    threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
    # This is very convoluted but idk a better way
    # for tile_M=128, flat_divide gives (8, 16, K),
    # then logical_divide gives ((8, 1), (8, 2), K).
    tidx = thr_copy_A.thr_idx
    tAmA = cute.logical_divide(
        cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
    )[None, (tidx % threads_per_col, None), None]  # ((8, 1), 2, K)

    # Index prefetch used when the index tensor is in gmem.
    def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]:
        # Prefetch mAIdx early, even before smem is free
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        gAIdx_cur = gAIdx[None, src_idx]
        k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
        for k in cutlass.range(cols_per_thread):
            col_idx = tAcA[0, 0, k][1]
            if const_expr(not pred):
                k_idx[k] = gAIdx_cur[col_idx]
            else:
                if tApA_k[k]:
                    k_idx[k] = gAIdx_cur[col_idx]
                else:
                    # Predicated-off column: skip the index load and store a -1 placeholder.
                    k_idx[k] = -1
        return k_idx, tApA_k

    # Index prefetch used when the index tensor is staged in smem; it waits on
    # and then releases the given pipeline stage around the reads.
    def prefetch_from_smem_fn(
        a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False
    ) -> Tuple[cute.Tensor, cute.Tensor]:
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
        # Unlike the gmem variant, the smem index tensor is sliced by dst_idx.
        sAIdx_cur = sAIdx[None, dst_idx]
        k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
        for k in cutlass.range(cols_per_thread):
            col_idx = tAcA[0, 0, k][1]
            k_idx[k] = sAIdx_cur[col_idx]
        # Sync the warp before a single elected thread releases the stage.
        cute.arch.sync_warp()
        with cute.arch.elect_one():
            a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
        return k_idx, tApA_k

    # Copy one stage using indices/predicates produced by a prefetch closure.
    # NOTE(review): src_idx is not read in this body.
    def copy_fn(
        src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False
    ):
        k_idx, tApA_k = k_idx_tApA_k
        tApA_k_pred = None
        if const_expr(pred):
            tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2)  # (1, cols_per_thread)
        for k in cutlass.range_constexpr(tAcA.shape[2]):
            # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2))
            for m in cutlass.range_constexpr(tAcA.shape[1]):
                if tApA_m[m]:
                    cute.copy(
                        thr_copy_A,
                        tAmA[None, m, k_idx[k]],
                        tAsA[(None, m, k), dst_idx],
                        pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k],
                    )

    # The conditional expression binds only to the second tuple element.
    return copy_fn, prefetch_from_gmem_fn if const_expr(
        gAIdx is not None
    ) else prefetch_from_smem_fn
851
+
852
+
853
@cute.jit
def gather_m_get_tma_copy_fn(
    tma_atom: cute.CopyAtom,
    mA: cute.Tensor,  # (whatever, K)
    sA: cute.Tensor,  # ((4, 32), (64, 1), STAGE)
    sAIdx: cute.Tensor,  # (tile_M),
    warp_idx: Int32,
    num_warps: int,
    num_cta: int = 1,
) -> Callable:
    """Build a closure that gathers rows into ``sA`` with ``tma_gather4_load``.

    The row indices are loaded once from ``sAIdx`` into registers (four per
    step for this warp's slice). The returned
    ``copy_fn(src_idx, dst_idx, tma_bar_ptr)`` then issues one gather per step
    from a single elected thread, at column offset ``tile_K * src_idx`` and
    into stage ``dst_idx`` of ``sA``.

    NOTE(review): ``mA`` is not read in this body; the source is presumably
    addressed through the descriptor taken from ``tma_atom`` -- confirm.
    """
    tile_M = cute.size(sAIdx, mode=[0])
    tile_K = cute.size(sA[None, None, 0]) // tile_M
    # Rows are gathered in groups of four.
    assert tile_M % 4 == 0
    # cta_group = 1 if tma_atom.op.cta_group == CtaGroup.ONE else 2
    cta_group = num_cta  # Somehow all tma_atom has CtaGroup.ONE inside the kernel

    # Tiled copy that hands each warp four consecutive Int32 indices (128 bits) at a time.
    copy_AIdx_s2r = cute.make_tiled_copy_tv(
        cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128),
        cute.make_layout(num_warps),  # thr_layout
        cute.make_layout(4),  # val_layout
    )
    warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx)
    tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx)
    # ((4, 1), 8, (64, 1), STAGE)
    tSR_sA = warp_copy_AIdx_s2r.partition_S(sA)
    tSR_rAIdx = load_s2r(tSR_sAIdx)
    tma_desc_ptr = get_tma_desc_addr(tma_atom)
    tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)

    def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer):
        col_idx = tile_K * src_idx
        for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
            row_indices = [tSR_rAIdx[v, m] for v in range(4)]
            smem_ptr = tSR_sA[None, m, None, dst_idx].iterator
            # Only one elected thread issues the gather.
            with cute.arch.elect_one():
                tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices)

    return copy_fn
build/torch211-cxx11-cu128-x86_64-linux/quack/cute_dsl_utils.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Tuple
4
+ from functools import lru_cache
5
+ from dataclasses import dataclass, fields
6
+
7
+ import torch
8
+
9
+ try:
10
+ from triton.tools.disasm import extract
11
+ except ImportError:
12
+ extract = None
13
+
14
+ import cutlass
15
+ import cutlass.cute as cute
16
+ from cutlass import Int32, Int64, Float16, BFloat16, Float32
17
+ from cutlass.base_dsl.typing import JitArgument
18
+ from cutlass.cutlass_dsl import NumericMeta
19
+
20
+
21
+ StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None))
22
+
23
+
24
+ load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
25
+ cute_compile_og = cute.compile
26
+
27
+
28
+ torch2cute_dtype_map = {
29
+ torch.float16: Float16,
30
+ torch.bfloat16: BFloat16,
31
+ torch.float32: Float32,
32
+ torch.int32: Int32,
33
+ torch.int64: Int64,
34
+ }
35
+
36
+
37
@lru_cache
def get_max_active_clusters(cluster_size):
    """Return ``HardwareInfo().get_max_active_clusters`` for ``cluster_size``.

    Results are memoized per ``cluster_size`` by ``lru_cache``.
    """
    hardware_info = cutlass.utils.HardwareInfo()
    return hardware_info.get_max_active_clusters(cluster_size=cluster_size)
40
+
41
+
42
@lru_cache
def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
    """Return ``torch.cuda.get_device_capability(device)``, memoized per argument.

    NOTE(review): the ``device=None`` result is cached as well, so it is not
    refreshed if the process later switches its current CUDA device.
    """
    capability = torch.cuda.get_device_capability(device)
    return capability
45
+
46
+
47
@dataclass
class ParamsBase:
    """Dataclass base implementing the MLIR value extract/rebuild hooks.

    Fields whose value is an instance of ``StaticTypes`` are treated as
    compile-time constants and carried over unchanged; every other field is
    flattened with ``cutlass.extract_mlir_values`` and rebuilt with
    ``cutlass.new_from_mlir_values``.
    """

    def __extract_mlir_values__(self):
        """Return the flattened MLIR values of all non-static fields.

        Also records, in ``self._values_pos``, how many values each non-static
        field contributed so that ``__new_from_mlir_values__`` can split the
        flat list again.
        """
        dynamic_members = [
            member
            for member in (getattr(self, spec.name) for spec in fields(self))
            if not isinstance(member, StaticTypes)
        ]
        self._values_pos = []
        flat_values = []
        for member in dynamic_members:
            member_values = cutlass.extract_mlir_values(member)
            flat_values += member_values
            self._values_pos.append(len(member_values))
        return flat_values

    def __new_from_mlir_values__(self, values):
        """Return a new instance whose non-static fields are rebuilt from ``values``.

        Relies on ``self._values_pos`` having been set by a previous
        ``__extract_mlir_values__`` call on this instance.
        """
        static_kwargs = {}
        dynamic_kwargs = {}
        for spec in fields(self):
            member = getattr(self, spec.name)
            if isinstance(member, StaticTypes):
                static_kwargs[spec.name] = member
            else:
                dynamic_kwargs[spec.name] = member
        remaining = values
        # zip stops at the shorter sequence; only existing keys are reassigned.
        for name, count in zip(list(dynamic_kwargs), self._values_pos):
            dynamic_kwargs[name] = cutlass.new_from_mlir_values(
                dynamic_kwargs[name], remaining[:count]
            )
            remaining = remaining[count:]
        return self.__class__(**dynamic_kwargs, **static_kwargs)
69
+
70
+
71
@dataclass
class ArgumentsBase(JitArgument):
    """Dataclass base implementing the ``JitArgument`` hooks field by field.

    Fields whose value is an instance of ``StaticTypes`` are treated as
    compile-time constants; the remaining fields supply C pointers and MLIR
    types through their own ``__c_pointers__`` / ``__get_mlir_types__`` hooks
    when they define them.
    """

    def _unused_doc_anchor(self):  # pragma: no cover
        raise NotImplementedError

    del _unused_doc_anchor

    def __c_pointers__(self):
        """Return the concatenated C pointers of all non-static fields that provide them."""
        dynamic_members = [
            member
            for member in (getattr(self, spec.name) for spec in fields(self))
            if not isinstance(member, StaticTypes)
        ]
        pointers = []
        for member in dynamic_members:
            if not hasattr(member, "__c_pointers__"):
                continue
            pointers.extend(member.__c_pointers__())
        return pointers

    def __get_mlir_types__(self):
        """Return the concatenated MLIR types of all non-static fields.

        Records the per-field type count in ``self._values_pos`` (0 for fields
        without a ``__get_mlir_types__`` hook) for later use by
        ``__new_from_mlir_values__``.
        """
        dynamic_members = [
            member
            for member in (getattr(self, spec.name) for spec in fields(self))
            if not isinstance(member, StaticTypes)
        ]
        self._values_pos = []
        mlir_types = []
        for member in dynamic_members:
            if not hasattr(member, "__get_mlir_types__"):
                self._values_pos.append(0)
                continue
            member_types = member.__get_mlir_types__()
            mlir_types.extend(member_types)
            self._values_pos.append(len(member_types))
        return mlir_types

    def __new_from_mlir_values__(self, values):
        """Return a new instance whose non-static fields are rebuilt from ``values``.

        Relies on ``self._values_pos`` having been set by a previous
        ``__get_mlir_types__`` call on this instance.
        """
        static_kwargs = {}
        dynamic_kwargs = {}
        for spec in fields(self):
            member = getattr(self, spec.name)
            if isinstance(member, StaticTypes):
                static_kwargs[spec.name] = member
            else:
                dynamic_kwargs[spec.name] = member
        remaining = values
        # zip stops at the shorter sequence; only existing keys are reassigned.
        for name, count in zip(list(dynamic_kwargs), self._values_pos):
            dynamic_kwargs[name] = cutlass.new_from_mlir_values(
                dynamic_kwargs[name], remaining[:count]
            )
            remaining = remaining[count:]
        return self.__class__(**dynamic_kwargs, **static_kwargs)
build/torch211-cxx11-cu128-x86_64-linux/quack/layout_utils.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+
7
+ from cutlass import Int32, const_expr
8
+
9
+
10
def transpose_view(a: cute.Tensor) -> cute.Tensor:
    """Transpose the first two dimensions of a tensor on smem.

    The result is ``a`` composed with an ordered layout whose first two modes
    are swapped in both shape and order; any further modes are kept as is.
    """
    dims = a.shape
    swapped_shape = (dims[1], dims[0], *dims[2:])
    swapped_order = (1, 0, *range(2, cute.rank(a)))
    swapped_layout = cute.make_ordered_layout(swapped_shape, order=swapped_order)
    return cute.composition(a, swapped_layout)
15
+
16
+
17
def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
    """Return a tensor on ``a``'s iterator whose layout is ``cute.select(a.layout, mode)``."""
    selected_layout = cute.select(a.layout, mode)
    return cute.make_tensor(a.iterator, selected_layout)
19
+
20
+
21
def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
    """Insert a stride-0 mode of extent ``size`` at position ``dim`` of ``a``.

    The returned tensor shares ``a``'s iterator; only the layout changes.
    """
    old_shape, old_stride = a.shape, a.layout.stride
    expanded_layout = cute.make_layout(
        (*old_shape[:dim], size, *old_shape[dim:]),
        stride=(*old_stride[:dim], 0, *old_stride[dim:]),
    )
    return cute.make_tensor(a.iterator, expanded_layout)
25
+
26
+
27
@cute.jit
def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
    """Permute 16-bit register values of ``t`` across the four lanes of each quad.

    ``t`` is recast to 32-bit words; each pair of words is exchanged with other
    lanes of the quad via ``shuffle_sync`` and recombined with ``prmt``. The
    results are written back through the recast tensor.
    NOTE(review): the exact before/after register layout is not stated in this
    function -- confirm against the gated epilogue that calls it.
    """
    assert t.element_type.width == 16
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
    t_u32 = cute.recast_tensor(t, Int32)

    quad_idx = cute.arch.lane_idx() % 4
    lane_03 = quad_idx == 0 or quad_idx == 3
    # prmt selectors differ between the outer lanes (0, 3) and the inner lanes (1, 2).
    selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
    selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
    # upper_map = [0, 3, 1, 2]
    # lower_map = [1, 2, 0, 3]
    # upper_idx = upper_map[quad_idx]
    # indexing isn't supported so we have to do arithmetic
    upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
    lower_idx = upper_idx ^ 1

    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    # Process the 32-bit words two at a time.
    for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
        upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
        # Inner lanes swap the two words before shuffling.
        upper0 = upper if lane_03 else lower
        lower0 = lower if lane_03 else upper
        upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
        lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
        t_u32[i * 2 + 0] = cute.arch.prmt(upper0, lower0, selector_upper)
        t_u32[i * 2 + 1] = cute.arch.prmt(upper0, lower0, selector_lower)
58
+
59
+
60
@cute.jit
def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None:
    """Permute and shuffle within 4 threads to change the layout from
        T0  |  T1  |  T2  |  T3
        a b |  c d |  e f |  g h
    to
        T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
        a  | b  | c  | d  | e  | f  | g  | h
    This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict.

    ``t`` must hold 32-bit elements and its size must be a multiple of 4; it is
    modified in place.
    """

    assert t.element_type.width == 32
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"

    quad_idx = cute.arch.lane_idx() % 4
    # left_map = [0, 2, 1, 3]
    # right_map = [2, 0, 3, 1]
    # indexing isn't supported so we have to do arithmetic
    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
    right_idx = left_idx ^ 0b10

    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    # Elements are handled in groups of four: two (left, right) pairs per group.
    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
        for r in cutlass.range(2, unroll_full=True):
            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
            # a b | c d | e f | g h -> a b | c d | f e | h g
            left0 = left if quad_idx < 2 else right
            right0 = right if quad_idx < 2 else left
            # a b | c d | f e | h g -> a b | f d | c e | h g
            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
            # a b | f d | c e | h g -> a e | f b | c g | h d
            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
            # a e | f b | c g | h d -> a e | b f | c g | d h
            t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0
            t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0
        # Finally swap the two middle elements of the group.
        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
101
+
102
+
103
@cute.jit
def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None:
    """Permute and shuffle within 4 threads to change the layout from
        T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
        a  | b  | c  | d  | e  | f  | g  | h
    to
        T0  |  T1  |  T2  |  T3
        a b |  c d |  e f |  g h
    This is so that we can use LDSM (instead of LDS.64) to load C registers without bank conflict.

    ``t`` must hold 32-bit elements and its size must be a multiple of 4; it is
    modified in place.
    """

    assert t.element_type.width == 32
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"

    quad_idx = cute.arch.lane_idx() % 4
    # left_map = [0, 2, 1, 3]
    # right_map = [1, 3, 0, 2]
    # indexing isn't supported so we have to do arithmetic
    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
    right_idx = left_idx ^ 0b01

    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    # This is just the inverse of permute_Cregs_b32_for_stsm
    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
        # Swap the two middle elements of the group first (the stsm variant does this last).
        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
        for r in cutlass.range(2, unroll_full=True):
            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
            # a e | b f | c g | d h -> a e | f b | c g | h d
            left0 = left if quad_idx % 2 == 0 else right
            right0 = right if quad_idx % 2 == 0 else left
            # a e | f b | c g | h d -> a b | f d | c e | h g
            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
            # a b | f d | c e | h g -> a b | c d | f e | h g
            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
            # a b | c d | f e | h g -> a b | c d | e f | g h
            t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0
            t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0
145
+
146
+
147
@cute.jit
def concat_layout(*layouts: cute.Layout) -> cute.Layout:
    """Return one layout whose i-th mode has the shape and stride of ``layouts[i]``."""
    mode_shapes = tuple(layout.shape for layout in layouts)
    mode_strides = tuple(layout.stride for layout in layouts)
    return cute.make_layout(mode_shapes, stride=mode_strides)
153
+
154
+
155
def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout:
    """
    For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
    For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).

    With ``transpose=True`` the two leading (M, N) modes of the result are swapped.
    """
    # Compact column-major layout over the same shape; its strides act as
    # coordinates that are composed with ``acc_layout`` at the end.
    col_major = cute.make_layout(acc_layout.shape)
    shp, strd = col_major.shape, col_major.stride

    m_shape = (shp[0][1], shp[1])  # MMA_M
    n_shape = (shp[0][0], *shp[0][2:], shp[2])  # MMA_N
    m_stride = (strd[0][1], strd[1])  # MMA_M
    n_stride = (strd[0][0], *strd[0][2:], strd[2])  # MMA_N

    if const_expr(transpose):
        m_shape, n_shape = n_shape, m_shape
        m_stride, n_stride = n_stride, m_stride

    acc_layout_mn = cute.make_layout(
        (m_shape, n_shape, *shp[3:]),
        stride=(m_stride, n_stride, *strd[3:]),
    )
    return cute.composition(acc_layout, acc_layout_mn)
184
+
185
+
186
def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
    """Return a tensor on ``acc``'s iterator using the ``convert_layout_acc_mn`` layout."""
    mn_layout = convert_layout_acc_mn(acc.layout, transpose=transpose)
    return cute.make_tensor(acc.iterator, mn_layout)
188
+
189
+
190
def reshape_acc_to_mn(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
    """Return the (M, N) view of ``acc``; same result as ``make_acc_tensor_mn_view``."""
    return make_acc_tensor_mn_view(acc, transpose=transpose)
192
+
193
+
194
@cute.jit
def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
    """Regroup an accumulator layout into the A-operand fragment layout of a following GEMM.

    The branch is chosen at trace time from the rank of the first mode of
    ``acc_layout``: rank 3 takes the Sm90 path, anything else the Sm80 path.
    """
    # For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
    # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
    # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
    # TODO: Sm90 FP8
    if const_expr(cute.rank(acc_layout.shape[0]) == 3):  # Sm90
        l = cute.logical_divide(
            acc_layout, ((None, None, 2), None, None)
        )  # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
        # Keep the factor of 2 in the first mode and move the remainder into the last mode.
        rA_mma_view = cute.make_layout(
            (
                (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]),
                l.shape[1],
                (l.shape[0][2][1], l.shape[2]),
            ),
            stride=(
                (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]),
                l.stride[1],
                (l.stride[0][2][1], l.stride[2]),
            ),
        )
    else:  # Sm80
        # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
        l = cute.logical_divide(acc_layout, (None, None, 2))
        # Fold the factor of 2 from the N mode into the first mode.
        rA_mma_view = cute.make_layout(
            (
                (l.shape[0], l.shape[2][0]),
                l.shape[1],
                l.shape[2][1],
            ),
            stride=(
                (l.stride[0], l.stride[2][0]),
                l.stride[1],
                l.stride[2][1],
            ),
        )
    return rA_mma_view
232
+
233
+
234
def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor:
    """Return a tensor on ``acc``'s iterator using the ``convert_layout_acc_frgA`` layout."""
    frg_a_layout = convert_layout_acc_frgA(acc.layout)
    return cute.make_tensor(acc.iterator, frg_a_layout)
236
+
237
+
238
def convert_layout_zero_stride(
    input: cute.Tensor | cute.Layout, ref_layout: cute.Layout
) -> cute.Tensor | cute.Layout:
    """Regroup the flattened modes of ``input`` into (non-zero-stride, zero-stride) groups.

    Which group a mode belongs to is decided by the stride of the mode at the
    same flattened position in ``ref_layout``; the shapes and strides
    themselves come from ``input``. Returns a tensor on the same iterator when
    ``input`` is a tensor, otherwise a layout.
    """
    layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input
    # Group the modes with non-zero stride in the ref_layout together,
    # and the modes with zero stride together
    layout_flat = cute.flatten(layout)
    ref_layout_flat = cute.flatten(ref_layout)
    nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0]
    zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0]
    # There's an edge case when all modes are zero stride
    # (a size-1, stride-0 placeholder then stands in for the first group).
    new_shape = (
        tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,),
        tuple(layout_flat[i].shape for i in zero_modes),
    )
    new_stride = (
        tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,),
        tuple(layout_flat[i].stride for i in zero_modes),
    )
    out_layout = cute.make_layout(new_shape, stride=new_stride)
    if const_expr(isinstance(input, cute.Tensor)):
        return cute.make_tensor(input.iterator, out_layout)
    else:
        return out_layout
262
+
263
+
264
def mma_partition_C_vec(
    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
) -> cute.Tensor:
    """Partition a staged vector like the C operand of ``thr_mma``.

    ``sVec`` must be rank 2 with unit stride in its first mode. It is broadcast
    with a stride-0 mode of extent ``expand_shape`` (second mode for a column
    vector, first mode for a row vector), partitioned with ``partition_C``,
    viewed in (M, N) form, and the broadcast mode is sliced at index 0.
    """
    assert cute.rank(sVec) == 2
    assert sVec.stride[0] == 1
    vec_len, num_stages = sVec.shape[0], sVec.shape[1]
    stage_stride = sVec.stride[1]
    if const_expr(is_colvec):
        bcast_shape = (vec_len, expand_shape, num_stages)
        bcast_stride = (1, 0, stage_stride)
    else:
        bcast_shape = (expand_shape, vec_len, num_stages)
        bcast_stride = (0, 1, stage_stride)
    bcast_vec = cute.make_tensor(sVec.iterator, cute.make_layout(bcast_shape, stride=bcast_stride))
    partitioned = make_acc_tensor_mn_view(thr_mma.partition_C(bcast_vec))
    if const_expr(is_colvec):
        return partitioned[None, 0, None]
    return partitioned[0, None, None]
279
+
280
+
281
def mma_partition_A_vec(
    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
) -> cute.Tensor:
    """Partition a staged vector like the A operand of ``thr_mma``.

    ``sVec`` must be rank 2 with unit stride in its first mode. It is broadcast
    with a stride-0 mode of extent ``expand_shape`` (second mode for a column
    vector, first mode for a row vector), partitioned with ``partition_A``,
    viewed in (M, N) form, and the broadcast mode is sliced at index 0.
    """
    assert cute.rank(sVec) == 2
    assert sVec.stride[0] == 1
    vec_len, num_stages = sVec.shape[0], sVec.shape[1]
    stage_stride = sVec.stride[1]
    if const_expr(is_colvec):
        bcast_shape = (vec_len, expand_shape, num_stages)
        bcast_stride = (1, 0, stage_stride)
    else:
        bcast_shape = (expand_shape, vec_len, num_stages)
        bcast_stride = (0, 1, stage_stride)
    bcast_vec = cute.make_tensor(sVec.iterator, cute.make_layout(bcast_shape, stride=bcast_stride))
    partitioned = make_acc_tensor_mn_view(thr_mma.partition_A(bcast_vec))
    if const_expr(is_colvec):
        return partitioned[None, 0, None]
    return partitioned[0, None, None]
build/torch211-cxx11-cu128-x86_64-linux/quantize.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Transformer Engine NVFP4 quantization helper.
5
+
6
+ This file is intended as a customer-facing example for preparing KV tensors
7
+ for the KVFP4 attention kernel:
8
+ - BF16/FP16 K/V input
9
+ - packed E2M1 FP4 data from Transformer Engine
10
+ - E4M3 block scales in cuBLAS/cuDNN 128x4 tiled layout
11
+ - one FP32 tensor/global scale per tensor
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Tuple
18
+
19
+ import torch
20
+
21
+
22
+ NVFP4_BLOCK_SIZE = 16
23
+ NVFP4_FP4_MAX = 6.0
24
+ NVFP4_FP8_E4M3_MAX = 448.0
25
+
26
+
27
@dataclass(frozen=True)
class Nvfp4QuantizedTensor:
    """Immutable bundle of packed NVFP4 data and the metadata needed to decode it.

    Attributes
    ----------
    data : torch.Tensor
        Packed E2M1 FP4 values produced by Transformer Engine. Two FP4 values
        share one byte, so the last dimension is half of the logical one.
    scale_128x4 : torch.Tensor
        Per-block E4M3 scales stored in the cuBLAS/cuDNN 128x4 tiled rowwise
        layout.
    global_scale : torch.Tensor
        FP32 tensor-wide (global) dequantization scale.
    logical_scale_shape : tuple[int, int]
        ``(rows, cols)`` of the scale matrix before 128x4 swizzling.
    original_shape : tuple[int, ...]
        Shape of the BF16/FP16 tensor that was quantized.
    """

    data: torch.Tensor
    scale_128x4: torch.Tensor
    global_scale: torch.Tensor
    logical_scale_shape: Tuple[int, int]
    original_shape: Tuple[int, ...]
52
+
53
+
54
+ def _round_up(x: int, multiple: int) -> int:
55
+ return ((int(x) + multiple - 1) // multiple) * multiple
56
+
57
+
58
def nvfp4_scale_128x4_offset(
    row: torch.Tensor,
    col: torch.Tensor,
    scale_cols: int,
) -> torch.Tensor:
    """Map logical (row, col) scale coordinates to 128x4 tiled flat offsets.

    Parameters
    ----------
    row : torch.Tensor
        Logical row indices.
    col : torch.Tensor
        Logical scale-column indices.
    scale_cols : int
        Logical number of scale columns; it is rounded up to a multiple of 4
        to obtain the number of tiles per tile-row.

    Returns
    -------
    torch.Tensor
        Flat offsets into the padded cuBLAS/cuDNN 128x4 tiled rowwise storage.
    """

    tiles_per_row = _round_up(scale_cols, 4) // 4
    # Each 128x4 tile holds 512 scales.
    tile_base = ((row // 128) * tiles_per_row + (col // 4)) * 512
    row_in_tile = row % 128
    # Inside a tile: rows 0..31 step by 16, each 32-row group steps by 4,
    # and the column within the tile is the innermost index.
    within_tile = (row_in_tile % 32) * 16 + (row_in_tile // 32) * 4 + (col % 4)
    return tile_base + within_tile
91
+
92
+
93
def swizzle_nvfp4_scale_to_128x4(
    scale: torch.Tensor,
    *,
    rows: int,
    cols: int,
) -> torch.Tensor:
    """Convert TE logical rowwise scales to cuBLAS/cuDNN 128x4 tiled layout.

    Parameters
    ----------
    scale : torch.Tensor
        Logical rowwise scale tensor with at least shape ``[rows, cols]``.
    rows : int
        Number of logical rows to convert. Must be non-negative.
    cols : int
        Number of logical scale columns to convert. Must be non-negative.

    Returns
    -------
    torch.Tensor
        Scale tensor padded to ``round_up(rows, 128)`` by ``round_up(cols, 4)``
        and swizzled into 128x4 tiled storage.

    Raises
    ------
    ValueError
        If ``scale`` is not 2D, if ``rows`` or ``cols`` is negative, or if
        ``scale`` is smaller than ``[rows, cols]``.
    """

    if scale.ndim != 2:
        raise ValueError(f"scale must be 2D, got shape {tuple(scale.shape)}")

    rows = int(rows)
    cols = int(cols)
    # A negative extent would slice from the end and then be "padded" with a
    # bogus amount, silently producing garbage instead of an error.
    if rows < 0 or cols < 0:
        raise ValueError(f"rows and cols must be non-negative, got rows={rows}, cols={cols}")
    padded_rows = _round_up(rows, 128)
    padded_cols = _round_up(cols, 4)
    if scale.shape[0] < rows or scale.shape[1] < cols:
        raise ValueError(
            "scale is smaller than the requested logical shape: "
            f"got {tuple(scale.shape)}, need at least {(rows, cols)}"
        )

    logical = scale[:rows, :cols].contiguous()
    if logical.shape != (padded_rows, padded_cols):
        # Zero-pad through float32 and cast back to the original dtype.
        logical = torch.nn.functional.pad(
            logical.to(torch.float32),
            (0, padded_cols - cols, 0, padded_rows - rows),
        ).to(scale.dtype)
    swizzled = torch.empty_like(logical)

    # Scatter every padded logical element to its tiled offset.
    row = torch.arange(padded_rows, device=scale.device, dtype=torch.int64)[:, None]
    col = torch.arange(padded_cols, device=scale.device, dtype=torch.int64)[None, :]
    offset = nvfp4_scale_128x4_offset(row, col, padded_cols).reshape(-1)
    swizzled.reshape(-1)[offset] = logical.reshape(-1)
    return swizzled
143
+
144
+
145
def nvfp4_global_scale_from_amax(amax: torch.Tensor) -> torch.Tensor:
    """Derive the NVFP4 tensor/global dequant scale from a TE amax tensor.

    Parameters
    ----------
    amax : torch.Tensor
        Rowwise absolute maxima returned by Transformer Engine.

    Returns
    -------
    torch.Tensor
        FP32 tensor equal to ``amax / (448 * 6)``, i.e. amax divided by the
        product of the E4M3 and FP4 maximum magnitudes.
    """

    max_representable = NVFP4_FP8_E4M3_MAX * NVFP4_FP4_MAX
    amax_fp32 = amax.to(torch.float32)
    return amax_fp32 / max_representable
160
+
161
+
162
+ def _import_te_nvfp4_quantizer():
163
+ try:
164
+ from transformer_engine.pytorch.tensor import NVFP4Quantizer
165
+ except Exception as exc: # pragma: no cover - environment dependent
166
+ raise RuntimeError(
167
+ "Transformer Engine NVFP4 quantization is unavailable. Install a "
168
+ "Transformer Engine build with its PyTorch dependencies, including "
169
+ "FlashAttention v3 when required by that TE build."
170
+ ) from exc
171
+ return NVFP4Quantizer
172
+
173
+
174
def quantize_bf16_to_nvfp4_128x4(x: torch.Tensor) -> Nvfp4QuantizedTensor:
    """Quantize a BF16/FP16 CUDA tensor to NVFP4 through Transformer Engine.

    Transformer Engine hands back rowwise scales in a logical padded layout;
    this helper re-packs them into physical 128x4 tiled storage so that they
    can be addressed with ``nvfp4_scale_128x4_offset``.

    Parameters
    ----------
    x : torch.Tensor
        CUDA BF16 or FP16 tensor with at least two dimensions. Both the last
        dimension and the flattened row count ``prod(x.shape[:-1])`` must be
        divisible by 16.

    Returns
    -------
    Nvfp4QuantizedTensor
        Packed FP4 bytes, 128x4-swizzled block scales, the global scale, and
        the shape metadata needed for the KVFP4 attention kernel or for
        reference dequantization.

    Raises
    ------
    ValueError
        If ``x`` is not on CUDA, has fewer than two dimensions, or violates
        one of the divisibility requirements.
    TypeError
        If ``x`` is neither BF16 nor FP16.
    """

    if not x.is_cuda:
        raise ValueError("NVFP4 quantization requires a CUDA tensor")
    if x.dtype not in (torch.bfloat16, torch.float16):
        raise TypeError(f"x must be bf16 or fp16, got {x.dtype}")
    if x.ndim < 2:
        raise ValueError(f"x must have at least 2 dimensions, got {x.ndim}")
    if x.shape[-1] % NVFP4_BLOCK_SIZE != 0:
        raise ValueError(
            f"last dimension must be divisible by {NVFP4_BLOCK_SIZE}, got {x.shape[-1]}"
        )

    # Number of rows once every leading dimension is flattened.
    rows = 1
    for leading_dim in x.shape[:-1]:
        rows *= int(leading_dim)
    if rows % NVFP4_BLOCK_SIZE != 0:
        raise ValueError(
            "flattened row dimension must be divisible by "
            f"{NVFP4_BLOCK_SIZE}, got {rows}"
        )

    quantizer_cls = _import_te_nvfp4_quantizer()
    rowwise_quantizer = quantizer_cls(rowwise=True, columnwise=False)
    meta = rowwise_quantizer.quantize(x.contiguous()).get_metadata()

    packed = meta["rowwise_data"]
    if packed.dtype is not torch.uint8:
        # Reinterpret the payload as raw bytes.
        packed = packed.view(torch.uint8)
    logical_scale = meta["rowwise_scale_inv"]
    amax = meta["amax_rowwise"]
    scale_cols = int(x.shape[-1]) // NVFP4_BLOCK_SIZE
    tiled_scale = swizzle_nvfp4_scale_to_128x4(logical_scale, rows=rows, cols=scale_cols)
    global_scale = nvfp4_global_scale_from_amax(amax).contiguous()

    return Nvfp4QuantizedTensor(
        data=packed,
        scale_128x4=tiled_scale,
        global_scale=global_scale,
        logical_scale_shape=(rows, scale_cols),
        original_shape=tuple(int(extent) for extent in x.shape),
    )
241
+
242
+
243
def quantize_kv_bf16_to_nvfp4_128x4(
    k: torch.Tensor,
    v: torch.Tensor,
) -> tuple[Nvfp4QuantizedTensor, Nvfp4QuantizedTensor]:
    """Quantize K and V separately with ``quantize_bf16_to_nvfp4_128x4``.

    Parameters
    ----------
    k : torch.Tensor
        CUDA BF16 or FP16 K tensor.
    v : torch.Tensor
        CUDA BF16 or FP16 V tensor.

    Returns
    -------
    tuple[Nvfp4QuantizedTensor, Nvfp4QuantizedTensor]
        ``(quantized_k, quantized_v)``; each carries its own scales.
    """

    quantized_k = quantize_bf16_to_nvfp4_128x4(k)
    quantized_v = quantize_bf16_to_nvfp4_128x4(v)
    return quantized_k, quantized_v
263
+
264
+
265
def dequantize_nvfp4_128x4_to_bf16(
    qx: Nvfp4QuantizedTensor,
    *,
    include_global_scale: bool = True,
) -> torch.Tensor:
    """Reference NVFP4 dequantization, intended for validation.

    Implements the kernel contract
        x = e2m1 * E4M3_block_scale_1x16 * FP32_global_scale

    Parameters
    ----------
    qx : Nvfp4QuantizedTensor
        Quantized tensor returned by ``quantize_bf16_to_nvfp4_128x4``.
    include_global_scale : bool, optional
        When True, the first element of ``qx.global_scale`` is applied after
        the per-block scales.

    Returns
    -------
    torch.Tensor
        BF16 tensor with shape ``qx.original_shape``.

    Raises
    ------
    ValueError
        If the packed width or the scale column count is inconsistent with
        ``qx.original_shape``.
    """

    data = qx.data if qx.data.dtype is torch.uint8 else qx.data.view(torch.uint8)
    if data.shape[-1] * 2 != qx.original_shape[-1]:
        raise ValueError(
            "packed data last dimension does not match original shape: "
            f"{data.shape[-1]} packed vs {qx.original_shape[-1]} logical"
        )

    rows, scale_cols = qx.logical_scale_shape
    logical_dim = int(qx.original_shape[-1])
    if scale_cols * NVFP4_BLOCK_SIZE != logical_dim:
        raise ValueError(
            "logical scale columns do not match original last dimension: "
            f"{scale_cols} scale cols vs dim {logical_dim}"
        )

    # E2M1 decode table: codes 0-7 are the magnitudes, codes 8-15 their negatives.
    magnitudes = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
    fp4_lut = torch.tensor(
        magnitudes + [-magnitude for magnitude in magnitudes],
        dtype=torch.float32,
        device=data.device,
    )

    # Each byte holds two codes: low nibble -> even column, high nibble -> odd column.
    packed = data.reshape(rows, logical_dim // 2)
    even_values = fp4_lut[(packed & 0x0F).long()]
    odd_values = fp4_lut[(packed >> 4).long()]
    values = torch.stack((even_values, odd_values), dim=-1).reshape(rows, logical_dim)

    # Gather the per-block scales back out of the 128x4 tiled storage.
    row = torch.arange(rows, device=data.device, dtype=torch.int64)[:, None]
    col = torch.arange(scale_cols, device=data.device, dtype=torch.int64)[None, :]
    flat_offset = nvfp4_scale_128x4_offset(row, col, scale_cols).reshape(-1)
    scale_u8 = qx.scale_128x4.reshape(-1)[flat_offset].reshape(rows, scale_cols)
    block_scale = scale_u8.view(torch.float8_e4m3fn).to(torch.float32)
    out = values * block_scale.repeat_interleave(NVFP4_BLOCK_SIZE, dim=1)
    if include_global_scale:
        out = out * qx.global_scale.reshape(-1)[0].to(torch.float32)
    return out.reshape(qx.original_shape).to(torch.bfloat16)
344
+
345
+
346
def _example() -> None:
    """Quantize random BF16 K/V tensors on CUDA and print shape/dtype of each result."""
    cuda = torch.device("cuda")
    k = torch.randn(128, 2, 128, device=cuda, dtype=torch.bfloat16)
    v = torch.randn_like(k)
    k_q, v_q = quantize_kv_bf16_to_nvfp4_128x4(k, v)
    report = (
        ("K FP4 data:", k_q.data),
        ("K scale 128x4:", k_q.scale_128x4),
        ("K global scale:", k_q.global_scale),
        ("V FP4 data:", v_q.data),
        ("V scale 128x4:", v_q.scale_128x4),
        ("V global scale:", v_q.global_scale),
    )
    for label, tensor in report:
        print(label, tuple(tensor.shape), tensor.dtype)
357
+
358
+
359
# Script entry point: run the example only when executed directly, and fail
# early with a clear error when no CUDA device is available.
if __name__ == "__main__":
    if not torch.cuda.is_available():
        raise RuntimeError("quantize.py requires CUDA")
    _example()
+ _example()
build/torch211-cxx11-cu128-x86_64-linux/sparse_index_utils.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Host-side q2k <-> k2q index conversion for sparse attention.
5
+
6
+ These utilities prepare sparse metadata on the Python side for tests,
7
+ benchmarks, and other offline preprocessing flows. They are not kernel
8
+ runtime helpers, so they intentionally live outside `src/common`.
9
+
10
+ Sparse attention pattern:
11
+ - Each Q token independently selects up to topK KV blocks (blk_kv tokens each).
12
+ - Under GQA, all Q heads in one group share the same sparsity pattern,
13
+ so indices are defined at the head_kv level.
14
+
15
+ Shapes:
16
+ q2k_indices: [batch, head_kv, Sq, topK] int32, valid values in [0, num_kv_blocks),
17
+ trailing unused slots padded with -1
18
+ k2q_indices: [batch, head_kv, Nkv, Sq] int32, padded with -1
19
+ k2q_counts: [batch, head_kv, Nkv] int32
20
+
21
+ CSR reverse-index format:
22
+ q2k_indices: [head_kv, total_q, topK] int32, values are batch-local kv_block indices
23
+ k2q_row_ptr: [head_kv, total_rows + 1] int32
24
+ k2q_q_indices: [head_kv, total_q * topK] int32, values are batch-local q_idx
25
+ """
26
+
27
+ from typing import Optional, Tuple
28
+
29
+ import torch
30
+
31
+ from .src.sm100.prepare_k2q_csr import SparseK2qCsrBuilderSm100
32
+
33
+
34
def q2k_to_k2q(
    q2k_indices: torch.Tensor,
    num_kv_blocks: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert q2k sparse indices to the dense k2q representation.

    For each KV block, collect the Q tokens that attend to it.

    Args:
        q2k_indices: [batch, head_kv, Sq, topK] integer tensor. For each Q
            token, the KV blocks it attends to; unused slots must be padded
            with -1. A Q token must not list the same KV block twice.
        num_kv_blocks: Total number of KV blocks (= Skv / blk_kv).

    Returns:
        k2q_indices: [batch, head_kv, num_kv_blocks, Sq] int32. For each KV
            block, the Q token indices that attend to it in ascending order,
            left-packed and padded with -1.
        k2q_counts: [batch, head_kv, num_kv_blocks] int32. Number of valid
            Q tokens per KV block.

    Raises:
        ValueError: If ``q2k_indices`` is not rank-4, ``num_kv_blocks`` is
            negative, an index is ``>= num_kv_blocks``, or duplicated indices
            make a KV block receive more than ``Sq`` entries.
    """
    if q2k_indices.ndim != 4:
        raise ValueError(
            "q2k_indices must have shape [batch, head_kv, Sq, topK], "
            f"got {tuple(q2k_indices.shape)}"
        )
    if num_kv_blocks < 0:
        raise ValueError(f"num_kv_blocks must be >= 0, got {num_kv_blocks}")

    B, H, Sq, topK = q2k_indices.shape
    device = q2k_indices.device
    N = Sq * topK

    kv_flat = q2k_indices.reshape(B, H, N).long()
    valid_flat = kv_flat >= 0

    k2q_counts = torch.zeros(B, H, num_kv_blocks, dtype=torch.int32, device=device)
    k2q_indices = torch.full(
        (B, H, num_kv_blocks, Sq), -1, dtype=torch.int32, device=device
    )
    if kv_flat.numel() == 0:
        return k2q_indices, k2q_counts

    # Validate up front: an out-of-range index would otherwise surface as an
    # opaque scatter error (or a device-side assert on CUDA).
    max_kv = int(kv_flat.max().item())
    if max_kv >= num_kv_blocks:
        raise ValueError(
            f"q2k_indices references kv_block {max_kv}, "
            f"but num_kv_blocks is {num_kv_blocks}"
        )
    if num_kv_blocks == 0:
        # Only padding is present (checked above); nothing to scatter.
        return k2q_indices, k2q_counts

    # Q token id owning each flattened (q, slot) entry.
    q_flat = (
        torch.arange(Sq, device=device)
        .unsqueeze(-1)
        .expand(Sq, topK)
        .reshape(N)
        .unsqueeze(0)
        .unsqueeze(0)
        .expand(B, H, N)
    )

    # Padding is redirected to block 0 but contributes a count of 0.
    safe_kv_flat = torch.where(valid_flat, kv_flat, torch.zeros_like(kv_flat))
    k2q_counts.scatter_add_(2, safe_kv_flat, valid_flat.to(torch.int32))
    if int(k2q_counts.max().item()) > Sq:
        raise ValueError(
            "q2k_indices lists the same kv_block more than once for a Q token; "
            f"a kv_block received more than Sq ({Sq}) entries"
        )

    # Stable sort by KV block keeps Q ids ascending inside each block; padding
    # gets key num_kv_blocks so it sorts after every real block.
    sort_keys = torch.where(
        valid_flat,
        kv_flat,
        torch.full_like(kv_flat, num_kv_blocks),
    )
    sorted_kv, sort_idx = sort_keys.sort(dim=-1, stable=True)
    sorted_q = q_flat.gather(-1, sort_idx)
    sorted_valid = valid_flat.gather(-1, sort_idx)

    # Exclusive prefix sum of counts = first sorted position of each block.
    offsets = torch.zeros(B, H, num_kv_blocks, dtype=torch.int64, device=device)
    offsets[:, :, 1:] = k2q_counts[:, :, :-1].cumsum(dim=-1).long()

    global_pos = torch.arange(N, device=device).unsqueeze(0).unsqueeze(0).expand(B, H, N)
    clamped_kv = sorted_kv.clamp(max=num_kv_blocks - 1)
    pos_in_group = global_pos - offsets.gather(2, clamped_kv)
    flat_idx = clamped_kv * Sq + pos_in_group

    # Single vectorised write of every valid entry; flat_k2q is a view of
    # k2q_indices, so the assignment lands in the output tensor.
    flat_k2q = k2q_indices.view(B, H, num_kv_blocks * Sq)
    b_idx, h_idx, n_idx = sorted_valid.nonzero(as_tuple=True)
    flat_k2q[b_idx, h_idx, flat_idx[b_idx, h_idx, n_idx]] = sorted_q[
        b_idx, h_idx, n_idx
    ].to(torch.int32)

    return k2q_indices, k2q_counts
106
+
107
+
108
def k2q_to_q2k(
    k2q_indices: torch.Tensor,
    k2q_counts: torch.Tensor,
    Sq: int,
    topK: int,
) -> torch.Tensor:
    """Rebuild the per-Q-token KV block lists from a dense k2q index.

    Parameters
    ----------
    k2q_indices : torch.Tensor
        Shape ``[batch, head_kv, num_kv_blocks, Sq]``. Q token indices per KV
        block; negative entries are skipped.
    k2q_counts : torch.Tensor
        Shape ``[batch, head_kv, num_kv_blocks]``. How many leading entries of
        each KV block row are read.
    Sq : int
        Q sequence length of the output.
    topK : int
        Number of slots per Q token in the output. KV blocks beyond the first
        ``topK`` found for a Q token are dropped.

    Returns
    -------
    torch.Tensor
        Shape ``[batch, head_kv, Sq, topK]``, dtype int32. Each row is sorted
        by KV block index with ``-1`` padding at the tail.
    """
    num_batch, num_head, num_blocks, _ = k2q_indices.shape
    device = k2q_indices.device

    q2k = torch.full((num_batch, num_head, Sq, topK), -1, dtype=torch.int32, device=device)
    counters = torch.zeros(num_batch, num_head, Sq, dtype=torch.int64, device=device)

    for b in range(num_batch):
        for h in range(num_head):
            # Views into the output / slot counters for this (batch, head).
            out_rows = q2k[b, h]
            slots_used = counters[b, h]
            for kv_blk in range(num_blocks):
                n_listed = k2q_counts[b, h, kv_blk].item()
                for j in range(n_listed):
                    q_tok = k2q_indices[b, h, kv_blk, j].item()
                    if q_tok >= 0:
                        slot = slots_used[q_tok].item()
                        if slot < topK:
                            out_rows[q_tok, slot] = kv_blk
                        # The counter advances even when the entry is dropped.
                        slots_used[q_tok] += 1

    # Map padding to a key larger than any block id so it sorts to the tail.
    sort_key = torch.where(q2k < 0, torch.full_like(q2k, num_blocks), q2k)
    order = sort_key.sort(dim=-1).indices
    return q2k.gather(-1, order)
158
+
159
+
160
def _validate_cu_seqlens(cu_seqlens: torch.Tensor, *, name: str) -> None:
    """Check that ``cu_seqlens`` is a non-empty, contiguous, rank-1 int32 tensor.

    ``name`` is only used to label the error message. Raises ``TypeError`` for
    a wrong dtype and ``ValueError`` for every other violation.
    """
    dtype = cu_seqlens.dtype
    if dtype != torch.int32:
        raise TypeError(f"{name} must be torch.int32, got {cu_seqlens.dtype}")
    if cu_seqlens.dim() != 1:
        raise ValueError(f"{name} must be rank-1, got shape {tuple(cu_seqlens.shape)}")
    if cu_seqlens.numel() == 0:
        raise ValueError(f"{name} must have at least one element")
    if cu_seqlens.is_contiguous():
        return
    raise ValueError(f"{name} must be contiguous")
169
+
170
+
171
def _rows_per_batch(cu_seqlens_k: torch.Tensor, kv_block_size: int) -> torch.Tensor:
    """Return the number of KV blocks per batch item (ceil of length / block size).

    ``cu_seqlens_k`` holds prefix sums, so consecutive differences are the
    per-item KV lengths.
    """
    lengths = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
    padded = lengths + kv_block_size - 1
    return torch.div(padded, kv_block_size, rounding_mode="floor")
174
+
175
+
176
def _build_packed_row_map(rows_per_batch: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Assign a packed linear row id to every (batch, kv_block) pair.

    Rows are numbered block-index-major: all batches' block 0 first, then all
    batches' block 1, and so on, skipping batches that have fewer blocks.

    Returns ``(row_map, total_rows)`` where ``row_map`` has shape
    ``[batch, max_rows]`` on ``rows_per_batch``'s device, with ``-1`` for
    blocks a batch does not have.
    """
    block_counts = rows_per_batch.to("cpu", non_blocking=False).tolist()
    widest = max(block_counts, default=0)
    # int32 unless the total row count would not fit.
    fits_int32 = sum(block_counts) < torch.iinfo(torch.int32).max
    row_map_cpu = torch.full(
        (len(block_counts), widest),
        -1,
        dtype=torch.int32 if fits_int32 else torch.int64,
    )

    next_row = 0
    for kv_block_idx in range(widest):
        for batch_idx, n_blocks in enumerate(block_counts):
            if kv_block_idx >= n_blocks:
                continue
            row_map_cpu[batch_idx, kv_block_idx] = next_row
            next_row += 1

    return row_map_cpu.to(rows_per_batch.device), next_row
193
+
194
+
195
def build_k2q_csr_torch_reference(
    q2k_indices: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    kv_block_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Torch reference for q2k -> k2q CSR conversion.

    Parameters
    ----------
    q2k_indices : torch.Tensor
        Shape ``[head_kv, total_q, topK]``, dtype int32. Values are
        batch-local KV block indices padded with ``-1``.
    cu_seqlens_q : torch.Tensor
        Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths.
    cu_seqlens_k : torch.Tensor
        Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths.
    kv_block_size : int
        Number of KV tokens per sparse block.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        ``(k2q_row_ptr, k2q_q_indices)`` where ``k2q_row_ptr`` has shape
        ``[head_kv, total_rows + 1]`` and ``k2q_q_indices`` has shape
        ``[head_kv, total_q * topK]``. Rows follow the packed order produced
        by ``_build_packed_row_map``; within a row the batch-local Q indices
        are ascending, and unused tail slots of ``k2q_q_indices`` hold ``-1``.

    Raises
    ------
    TypeError
        If ``q2k_indices`` or a ``cu_seqlens`` tensor is not int32.
    ValueError
        For a non-positive block size, wrong ranks/shapes, mismatched devices,
        ``total_q`` disagreeing with ``cu_seqlens_q[-1]``, or a KV block index
        beyond the number of blocks of its batch item.
    """
    if kv_block_size <= 0:
        raise ValueError(f"kv_block_size must be > 0, got {kv_block_size}")
    if q2k_indices.dtype != torch.int32:
        raise TypeError(f"q2k_indices must be torch.int32, got {q2k_indices.dtype}")
    if q2k_indices.ndim != 3:
        raise ValueError(
            "q2k_indices must have shape [head_kv, total_q, topK], "
            f"got {tuple(q2k_indices.shape)}"
        )
    _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q")
    _validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k")
    if cu_seqlens_q.shape != cu_seqlens_k.shape:
        raise ValueError("cu_seqlens_q and cu_seqlens_k must have the same shape [B + 1]")
    if q2k_indices.device != cu_seqlens_q.device or q2k_indices.device != cu_seqlens_k.device:
        raise ValueError("q2k_indices, cu_seqlens_q, and cu_seqlens_k must be on the same device")

    device = q2k_indices.device
    head_kv, total_q, topk = q2k_indices.shape
    expected_total_q = int(cu_seqlens_q[-1].item())
    if total_q != expected_total_q:
        raise ValueError(
            f"q2k_indices.shape[1] ({total_q}) must equal cu_seqlens_q[-1] "
            f"({expected_total_q})"
        )

    rows_per_batch = _rows_per_batch(cu_seqlens_k, kv_block_size)
    row_map, total_rows = _build_packed_row_map(rows_per_batch)
    total_entries = total_q * topk

    k2q_row_ptr = torch.zeros((head_kv, total_rows + 1), dtype=torch.int32, device=device)
    k2q_q_indices = torch.full((head_kv, total_entries), -1, dtype=torch.int32, device=device)
    if total_rows == 0 or total_entries == 0:
        return k2q_row_ptr, k2q_q_indices

    counts = torch.zeros((head_kv, total_rows), dtype=torch.int32, device=device)
    row_dtype = torch.int32 if total_rows < torch.iinfo(torch.int32).max else torch.int64
    # Zero-initialised staging buffers: any entry the batch loop does not
    # cover stays marked invalid instead of exposing uninitialised memory.
    row_all = torch.zeros((head_kv, total_entries), dtype=row_dtype, device=device)
    q_all = torch.zeros((head_kv, total_entries), dtype=torch.int32, device=device)
    valid_all = torch.zeros((head_kv, total_entries), dtype=torch.bool, device=device)
    rows_per_batch_cpu = rows_per_batch.to("cpu", non_blocking=False).tolist()
    q_cu_cpu = cu_seqlens_q.to("cpu", non_blocking=False).tolist()
    entry_cursor = 0

    for batch_idx, kv_rows in enumerate(rows_per_batch_cpu):
        q_start = q_cu_cpu[batch_idx]
        q_end = q_cu_cpu[batch_idx + 1]
        q_len = q_end - q_start
        if q_len == 0:
            continue
        num_entries = q_len * topk
        q2k_batch = q2k_indices[:, q_start:q_end, :]
        valid_batch = q2k_batch >= 0
        if valid_batch.any():
            max_valid_kv = int(q2k_batch[valid_batch].max().item())
            if max_valid_kv >= kv_rows:
                raise ValueError(
                    f"q2k_indices references kv_block {max_valid_kv} for batch {batch_idx}, "
                    f"but that batch only has {kv_rows} logical kv blocks"
                )
        kv_flat = q2k_batch.reshape(head_kv, num_entries).long()
        valid_flat = valid_batch.reshape(head_kv, num_entries)
        safe_kv_flat = torch.where(valid_flat, kv_flat, torch.zeros_like(kv_flat))
        row_flat = row_map[batch_idx][safe_kv_flat]
        # Padding looks up block 0, which maps to -1 when this batch item has
        # no KV blocks at all. Pin such entries to row 0 (they add 0 below) so
        # scatter_add_ never sees an out-of-range index.
        row_flat = torch.where(valid_flat, row_flat, torch.zeros_like(row_flat))
        q_flat = (
            torch.arange(q_len, device=device, dtype=torch.int32)
            .view(1, q_len, 1)
            .expand(head_kv, q_len, topk)
            .reshape(head_kv, num_entries)
        )
        entry_slice = slice(entry_cursor, entry_cursor + num_entries)
        row_all[:, entry_slice] = row_flat
        q_all[:, entry_slice] = q_flat
        valid_all[:, entry_slice] = valid_flat
        counts.scatter_add_(1, row_flat.to(torch.int64), valid_flat.to(torch.int32))
        entry_cursor += num_entries

    k2q_row_ptr[:, 1:] = counts.cumsum(dim=1, dtype=torch.int32)

    # Sort by (packed row, batch-local q); invalid entries get the largest key
    # so they end up after every valid entry.
    sort_stride = max(total_q, 1)
    invalid_key = total_rows * sort_stride
    max_sort_key = invalid_key + max(total_q - 1, 0)
    if max_sort_key < torch.iinfo(torch.int32).max:
        sort_keys = torch.full_like(row_all, invalid_key, dtype=torch.int32)
        sort_keys[valid_all] = row_all[valid_all] * sort_stride + q_all[valid_all]
    else:
        sort_keys = torch.full_like(row_all, invalid_key, dtype=torch.int64)
        sort_keys[valid_all] = (
            row_all[valid_all].to(torch.int64) * sort_stride
            + q_all[valid_all].to(torch.int64)
        )
    _, sort_idx = sort_keys.sort(dim=1, stable=True)
    sorted_q = q_all.gather(1, sort_idx)

    # Only the first valid_counts[h] sorted entries of each head are real.
    valid_counts = valid_all.sum(dim=1)
    write_mask = (
        torch.arange(total_entries, device=device)
        .unsqueeze(0)
        .expand(head_kv, -1)
        < valid_counts.unsqueeze(1)
    )
    k2q_q_indices[write_mask] = sorted_q[write_mask]

    return k2q_row_ptr, k2q_q_indices
326
+
327
+
328
# Module-level builder instance shared by every build_k2q_csr call.
_K2Q_CSR_BUILDER = SparseK2qCsrBuilderSm100()


def build_k2q_csr(
    q2k_indices: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    kv_block_size: int,
    *,
    total_k: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    max_seqlen_q: Optional[int] = None,
    total_rows: Optional[int] = None,
    qhead_per_kv: int = 1,
    return_schedule: bool = False,
) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, object]:
    """Validate the inputs and build the k2q CSR reverse index with the SM100 builder.

    This wrapper only checks dtypes, ranks, contiguity and devices; it never
    reads the contents of ``cu_seqlens_*``. Size information therefore has to
    come from the caller (``total_k`` is mandatory).

    Parameters
    ----------
    q2k_indices : torch.Tensor
        Shape ``[head_kv, total_q, topK]``, dtype int32, contiguous. Values are
        batch-local KV block indices with trailing ``-1`` padding.
    cu_seqlens_q : torch.Tensor
        Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths.
    cu_seqlens_k : torch.Tensor
        Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths.
    kv_block_size : int
        Number of KV tokens per sparse block; forwarded as ``blk_kv``.
    total_k : int
        Total KV token count. Required despite the ``None`` default; normally
        ``k.shape[0]`` for dense KV or ``sum(kv_segment_lens)`` for paged KV.
    max_seqlen_k : int, optional
        Maximum KV sequence length, forwarded unchanged.
    max_seqlen_q : int, optional
        Maximum Q sequence length, forwarded unchanged.
    total_rows : int, optional
        Total number of packed KV-block rows across the batch, forwarded
        unchanged.
    qhead_per_kv : int, optional
        Number of Q heads per KV head under GQA.
    return_schedule : bool, optional
        If True, the builder is asked to also return its schedule object.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor] or tuple[torch.Tensor, torch.Tensor, object]
        Whatever the builder returns: ``(k2q_row_ptr, k2q_q_indices)`` or
        ``(k2q_row_ptr, k2q_q_indices, schedule)``.

    Raises
    ------
    TypeError
        If ``q2k_indices`` or a ``cu_seqlens`` tensor is not int32.
    ValueError
        If ``total_k`` is missing, ``kv_block_size`` is not positive, or a
        rank, contiguity, shape or device check fails.
    """
    if total_k is None:
        raise ValueError("build_k2q_csr requires total_k from k.shape[0]")
    if kv_block_size <= 0:
        raise ValueError(f"kv_block_size must be > 0, got {kv_block_size}")
    if q2k_indices.dtype != torch.int32:
        raise TypeError(f"q2k_indices must be torch.int32, got {q2k_indices.dtype}")
    if q2k_indices.ndim != 3:
        raise ValueError(f"q2k_indices must be rank-3, got shape {tuple(q2k_indices.shape)}")
    if not q2k_indices.is_contiguous():
        raise ValueError("q2k_indices must be contiguous with layout [head_kv, total_q, topK]")

    for seq_name, seq in (("cu_seqlens_q", cu_seqlens_q), ("cu_seqlens_k", cu_seqlens_k)):
        _validate_cu_seqlens(seq, name=seq_name)
    if cu_seqlens_q.shape != cu_seqlens_k.shape:
        raise ValueError("cu_seqlens_q and cu_seqlens_k must have the same shape [B + 1]")

    index_device = q2k_indices.device
    if not (index_device == cu_seqlens_q.device and index_device == cu_seqlens_k.device):
        raise ValueError("q2k_indices, cu_seqlens_q, and cu_seqlens_k must be on the same device")

    builder_kwargs = dict(
        total_k=int(total_k),
        blk_kv=int(kv_block_size),
        max_seqlen_k=max_seqlen_k,
        max_seqlen_q=max_seqlen_q,
        total_rows=total_rows,
        qhead_per_kv=qhead_per_kv,
        return_schedule=return_schedule,
    )
    return _K2Q_CSR_BUILDER(q2k_indices, cu_seqlens_q, cu_seqlens_k, **builder_kwargs)
build/torch211-cxx11-cu128-x86_64-linux/src/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
build/torch211-cxx11-cu128-x86_64-linux/src/common/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
build/torch211-cxx11-cu128-x86_64-linux/src/common/aot_cache.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Persistent AOT cache for CuTe DSL compiled kernels.
5
+
6
+ Saves compiled TVM FFI kernels as .o files on first compile,
7
+ loads them on subsequent runs to skip JIT compilation.
8
+
9
+ Environment variables:
10
+ MM_SPARSE_ATTN_AOT_CACHE: Override cache directory
11
+ (default: ~/.cache/minfer/mm_sparse_attn)
12
+ MM_SPARSE_ATTN_AOT_DISABLE=1: Disable AOT cache entirely
13
+ """
14
+
15
+ import hashlib
16
+ import os
17
+ import time
18
+
19
+ import cutlass.cute as cute
20
+
21
# Directory holding cached kernel objects; MM_SPARSE_ATTN_AOT_CACHE overrides
# the per-user default.
_DEFAULT_AOT_CACHE_DIR = os.path.expanduser("~/.cache/minfer/mm_sparse_attn")
_AOT_CACHE_DIR = os.getenv("MM_SPARSE_ATTN_AOT_CACHE", _DEFAULT_AOT_CACHE_DIR)

# The cache is switched off only when the variable is exactly "1".
_AOT_DISABLE = os.getenv("MM_SPARSE_ATTN_AOT_DISABLE", "0") == "1"

# Modules already loaded in this process, keyed by object-file path.
_loaded_modules: dict[str, object] = {}
28
+
29
+
30
def _key_to_path(key: tuple) -> str:
    """Map a cache key to an extension-less path inside the AOT cache directory.

    The file stem is ``key[0]`` (with ``/`` replaced by ``_``) followed by the
    first 16 hex digits of the SHA-256 of ``repr(key)``.
    """
    digest = hashlib.sha256(repr(key).encode()).hexdigest()
    stem = str(key[0]).replace("/", "_")
    return os.path.join(_AOT_CACHE_DIR, stem + "_" + digest[:16])
34
+
35
+
36
def try_load_aot(key: tuple):
    """Return the cached compiled function for ``key``, or ``None`` if unavailable.

    ``None`` is returned when the cache is disabled, when no ``.o`` file exists
    for the key, or when loading fails (the failure is printed, not raised).
    Loaded modules are memoised per object path for the life of the process.
    """
    if _AOT_DISABLE:
        return None

    obj_path = _key_to_path(key) + ".o"
    if not os.path.isfile(obj_path):
        return None

    func_name = str(key[0])
    try:
        if obj_path in _loaded_modules:
            module = _loaded_modules[obj_path]
        else:
            module = cute.runtime.load_module(obj_path, enable_tvm_ffi=True)
            _loaded_modules[obj_path] = module
        return getattr(module, func_name)
    except Exception as e:
        # Best effort: a broken cache entry must not stop the caller.
        print(f"[aot_cache] Failed to load {obj_path}: {e}")
        return None
52
+
53
+
54
def save_aot(key: tuple, compiled) -> None:
    """Export ``compiled`` to the AOT cache under ``key`` (best effort).

    Does nothing when the cache is disabled or ``compiled`` has no
    ``export_to_c`` method. The object is written to a per-process temporary
    file and then moved into place with ``os.replace``; export failures are
    printed and the temporary file is removed.
    """
    if _AOT_DISABLE or not hasattr(compiled, "export_to_c"):
        return

    obj_path = _key_to_path(key) + ".o"
    os.makedirs(_AOT_CACHE_DIR, exist_ok=True)
    tmp_path = f"{obj_path}.tmp.{os.getpid()}"
    func_name = str(key[0])
    try:
        started = time.time()
        compiled.export_to_c(tmp_path, function_name=func_name)
        os.replace(tmp_path, obj_path)
        dt = time.time() - started
        print(f"[aot_cache] Saved {func_name} -> {obj_path} ({dt:.1f}s)")
    except Exception as e:
        print(f"[aot_cache] Failed to save {func_name}: {e}")
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
build/torch211-cxx11-cu128-x86_64-linux/src/common/barrier.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+ from cutlass import Int32
7
+ from cutlass.cutlass_dsl import T, dsl_user_op
8
+ from cutlass._mlir.dialects import llvm
9
+
10
+
11
@dsl_user_op
def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
    """Load a 32-bit value from ``lock_ptr`` via ``ld.global.acquire.gpu.b32``."""
    addr = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    ptx = "ld.global.acquire.gpu.b32 $0, [$1];"
    # "=r": 32-bit register output, "l": 64-bit register input (the address).
    constraints = "=r,l"
    loaded = llvm.inline_asm(
        T.i32(),
        [addr],
        ptx,
        constraints,
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(loaded)
24
+
25
+
26
@dsl_user_op
def red_relaxed(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    """Add ``val`` to the 32-bit word at ``lock_ptr`` via ``red.relaxed.gpu.global.add.s32``."""
    addr = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    addend = Int32(val).ir_value(loc=loc, ip=ip)
    ptx = "red.relaxed.gpu.global.add.s32 [$0], $1;"
    llvm.inline_asm(
        None,  # the reduction produces no result value
        [addr, addend],
        ptx,
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
40
+
41
+
42
@dsl_user_op
def red_release(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    """Add ``val`` to the 32-bit word at ``lock_ptr`` via ``red.release.gpu.global.add.s32``."""
    addr = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    addend = Int32(val).ir_value(loc=loc, ip=ip)
    ptx = "red.release.gpu.global.add.s32 [$0], $1;"
    llvm.inline_asm(
        None,  # the reduction produces no result value
        [addr, addend],
        ptx,
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
56
+
57
+
58
@cute.jit
def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None:
    """Spin on thread 0 until the flag at ``lock_ptr + flag_offset`` equals ``val``.

    Other threads fall straight through. The flag is polled with
    ``ld_acquire``; the initial observed value is 0, so no load is issued when
    ``val`` is 0.
    """
    flag_addr = lock_ptr + flag_offset
    if thread_idx == 0:
        observed = Int32(0)
        while observed != val:
            observed = ld_acquire(flag_addr)
66
+
67
@cute.jit
def arrive_inc(
    lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32]
) -> None:
    """Have thread 0 add ``val`` to the flag at ``lock_ptr + flag_offset``.

    The add goes through ``red_release``; ``red_relaxed`` is the alternative
    the original author left disabled.
    """
    flag_addr = lock_ptr + flag_offset
    if thread_idx == 0:
        red_release(flag_addr, val)
build/torch211-cxx11-cu128-x86_64-linux/src/common/blackwell_helpers.py ADDED
@@ -0,0 +1,1093 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import cutlass
7
+ import cutlass.cute as cute
8
+ from cutlass import Int32, Boolean, const_expr
9
+ from cutlass.cute.nvgpu import tcgen05
10
+ from cutlass._mlir.dialects import llvm
11
+
12
+ from . import mma_sm100_desc as sm100_desc
13
+
14
+
15
@cute.jit
def gemm_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    swap_AB: bool = False,
    num_unroll_groups: int = 1,
) -> None:
    """Accumulate ``A @ B`` into ``acc``, one MMA per slice of mode 2 of ``tCrA``.

    Args:
        tiled_mma: Tiled MMA whose ``op`` is used to build the MMA atom.
        acc: Accumulator tensor, used as both input and output of each MMA.
        tCrA, tCrB: Operand fragments. When ``A_idx`` / ``B_idx`` is given the
            fragment is first indexed along its fourth mode with it.
        A_idx, B_idx: Optional index into the fourth mode of the respective
            fragment (presumably a pipeline stage; confirm with callers).
        zero_init: When true, the first slice overwrites ``acc`` instead of
            accumulating into it.
        swap_AB: Swap the roles of A and B (fragments and indices).
        num_unroll_groups: The loop unroll factor is the slice count divided
            by this value.
    """
    if const_expr(swap_AB):
        # Forward num_unroll_groups too; the swapped call used to fall back to
        # the default and ignore the caller's unroll setting.
        return gemm_w_idx(
            tiled_mma,
            acc,
            tCrB,
            tCrA,
            B_idx,
            A_idx,
            zero_init=zero_init,
            swap_AB=False,
            num_unroll_groups=num_unroll_groups,
        )
    else:
        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]

        mma_atom = cute.make_mma_atom(tiled_mma.op)
        for k in cutlass.range(
            cute.size(tCrA.shape[2]), unroll=cute.size(tCrA.shape[2]) // num_unroll_groups
        ):
            # Accumulate on every slice except the first one of a zero-init GEMM.
            mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)
            cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc)
41
+
42
+
43
@cute.jit
def gemm_ptx_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    cta_group: int = 1,
    **kwargs,
) -> None:
    """Select the indexed operand slices and hand them to ``gemm_ptx_partial``.

    ``A_idx`` / ``B_idx``, when given, index the fourth mode of both the
    fragment and the shared-memory tensor of that operand. ``sA`` may be
    ``None``, in which case ``None`` is forwarded. The accumulator is passed
    as the integer value of its iterator; extra keyword arguments are
    forwarded unchanged.
    """
    frag_a = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
    frag_b = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]

    smem_a = None
    if const_expr(sA is not None):
        smem_a = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx]
    smem_b = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx]

    atom = cute.make_mma_atom(tiled_mma.op)
    acc_addr = acc.iterator.toint()
    gemm_ptx_partial(
        atom.op,
        acc_addr,
        frag_a,
        frag_b,
        smem_a,
        smem_b,
        zero_init=zero_init,
        cta_group=cta_group,
        **kwargs,
    )
76
+
77
+
78
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    """Accumulate ``A @ B`` into ``acc`` with a fully unrolled loop over mode 2 of ``tCrA``.

    With ``zero_init`` true the first slice overwrites ``acc``; all later
    slices accumulate.
    """
    atom = cute.make_mma_atom(tiled_mma.op)
    num_k_slices = cute.size(tCrA.shape[2])
    for k_slice in cutlass.range_constexpr(num_k_slices):
        atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k_slice != 0)
        cute.gemm(atom, acc, tCrA[None, None, k_slice], tCrB[None, None, k_slice], acc)
90
+
91
+
92
def i64_to_i32x2(i: int) -> Tuple[int, int]:
    """Split an integer into its ``(low, high)`` unsigned 32-bit words."""
    mask32 = 0xFFFF_FFFF
    low_word = i & mask32
    high_word = (i >> 32) & mask32
    return low_word, high_word
95
+
96
+
97
@cute.jit
def gemm_ptx(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    """Issue one ``tcgen05.mma`` inline-PTX instruction per slice of mode 2 of ``tCrA``.

    Operand B is always described by a shared-memory descriptor built from
    ``sB``. Operand A uses a descriptor built from ``sA`` unless
    ``op.a_src`` is TMEM, in which case the address of the ``tCrA`` slice is
    passed instead and ``sA`` may be ``None``. With ``zero_init`` true the
    first slice does not accumulate.
    """
    a_in_tmem = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not a_in_tmem):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    layout_a = sA.layout if sA is not None else None
    layout_b = sB.layout
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))

    # Static part of the A descriptor, split into 32-bit halves.
    if const_expr(not a_in_tmem):
        swizzle_a = sA.iterator.type.swizzle_type
        desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, layout_a[0]),
                swizzle_a,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(desc_base_a)
        desc_base_a_lo = const_expr(desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        desc_base_a = None
        desc_base_a_lo, smem_desc_a_hi = None, None

    # Static part of the B descriptor, split the same way.
    swizzle_b = sB.iterator.type.swizzle_type
    desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, layout_b[0]),
            swizzle_b,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(desc_base_b)
    desc_base_b_lo = const_expr(desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    # Low descriptor words for slice 0: static bits OR-ed with the start address.
    if const_expr(not a_in_tmem):
        desc_start_a_lo = Int32(desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr(
            sA[None, None, 0].iterator
        )
    else:
        desc_start_a_lo = None
    desc_start_b_lo = Int32(desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr(
        sB[None, None, 0].iterator
    )

    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        # Advance the low word by the slice's byte offset shifted right by 4.
        if const_expr(not a_in_tmem):
            desc_a_lo = desc_start_a_lo + (
                (cute.crd2idx((0, 0, k), layout_a) * sA.element_type.width // 8) >> 4
            )
        desc_b_lo = desc_start_b_lo + (
            (cute.crd2idx((0, 0, k), layout_b) * sB.element_type.width // 8) >> 4
        )
        with cute.arch.elect_one():
            if const_expr(not a_in_tmem):
                # A and B both come from shared-memory descriptors.
                llvm.inline_asm(
                    None,
                    [
                        acc.iterator.toint().ir_value(),
                        desc_a_lo.ir_value(),
                        desc_b_lo.ir_value(),
                        Int32(not zero_init or k != 0).ir_value(),
                    ],
                    "{\n\t"
                    ".reg .pred p;\n\t"
                    ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
                    ".reg .b32 idesc;\n\t"
                    f"mov.b32 idesc, {hex(idesc)};\n\t"
                    f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t"
                    f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
                    "setp.ne.b32 p, $3, 0;\n\t"
                    f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t"
                    "}\n",
                    "r,r,r,r",
                    has_side_effects=True,
                    is_align_stack=False,
                    asm_dialect=llvm.AsmDialect.AD_ATT,
                )
            else:
                # A is addressed directly through the tCrA slice; only B uses a descriptor.
                llvm.inline_asm(
                    None,
                    [
                        acc.iterator.toint().ir_value(),
                        tCrA[None, None, k].iterator.toint().ir_value(),
                        desc_b_lo.ir_value(),
                        Int32(not zero_init or k != 0).ir_value(),
                    ],
                    "{\n\t"
                    ".reg .pred p;\n\t"
                    ".reg .b64 smem_desc_b;\n\t"
                    f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
                    "setp.ne.b32 p, $3, 0;\n\t"
                    f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t"
                    "}\n",
                    "r,r,r,r",
                    has_side_effects=True,
                    is_align_stack=False,
                    asm_dialect=llvm.AsmDialect.AD_ATT,
                )
210
+
211
+
212
@cute.jit
def gemm_ptx_loop(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    """Emit one inline-PTX block issuing a ``tcgen05.mma`` per K-block.

    The instruction descriptor and the upper 32 bits of each shared-memory
    descriptor are baked into the PTX text as hex immediates; only the
    accumulator address, the descriptor low words (or A's TMEM address) and
    the ``zero_init`` flag are passed as register operands.  Every MMA is
    predicated on a leader thread chosen with ``elect.sync``.

    Args:
        op: MMA op; ``op.a_src`` selects the TS path (A read through a TMEM
            address) versus the SS path (A read through a smem descriptor).
        acc: Accumulator tensor; its iterator address is asm operand ``$0``.
        tCrA: A fragment.  ``shape[2]`` gives the number of MMAs emitted; on
            the TS path its iterator supplies A's address.
        tCrB: B fragment; ``shape[2]`` sizes the B offset table.
        sA: Shared-memory A tensor.  Required when A is not sourced from TMEM.
        sB: Shared-memory B tensor.
        zero_init: Controls the trailing predicate operand of the *first* MMA
            only: a Python bool is folded into the PTX text (``0`` when
            truthy, ``1`` otherwise), a dynamic ``Boolean`` is evaluated at
            run time through predicate ``p``.  Later MMAs always pass ``1``.
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    # Without sA, the K-offsets of A are derived from the fragment layout instead.
    sA_layout = sA.layout if sA is not None else tCrA.layout
    sB_layout = sB.layout
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = sA.iterator.type.swizzle_type
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        # Split the 64-bit descriptor template into its two 32-bit words.
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = sB.iterator.type.swizzle_type
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    if const_expr(not is_ts):
        # Element index -> byte offset (width // 8) -> 16-byte units (>> 4).
        offset_a = [
            (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4
            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
        ]
    else:
        # Element index scaled to 32-bit units for the TMEM address of A.
        offset_a = [
            cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
        ]
    # Per-step increments; only the SS path consumes offset_a_diff.
    offset_a_diff = [
        offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
    ]
    offset_b = [
        (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
        for k in cutlass.range_constexpr(cute.size(tCrB.shape[2]))
    ]
    offset_b_diff = [
        offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2]))
    ]

    if const_expr(not is_ts):
        # Low descriptor word for K-block 0: static bits OR-ed with the start address.
        smem_desc_start_a_lo = Int32(
            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
        )
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(
        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
    )
    # "p" -> runtime predicate register; "0"/"1" -> immediate folded at trace time.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        # SS path.  Operands: $0 acc address, $1 A desc low word, $2 B desc low word,
        # $3 (not zero_init).
        llvm.inline_asm(
            None,
            [
                acc.iterator.toint().ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 smem_desc_a_lo, $1;\n\t"
            "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            # Remaining K-blocks: bump both low words by the per-step delta and
            # issue with the trailing predicate fixed to 1.
            + "".join(
                (
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # TS path.  Operands: $0 acc address, $1 TMEM address of A, $2 B desc low word,
        # $3 (not zero_init).
        llvm.inline_asm(
            None,
            [
                acc.iterator.toint().ir_value(),
                Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
                Int32(smem_desc_start_b_lo).ir_value(),
                Int32(not zero_init).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 tmem_a, $1;\n\t"
            "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            # A is addressed as an immediate offset from tmem_a (absolute offset_a[k]);
            # only B's low word is advanced incrementally.
            + "".join(
                (
                    # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
373
+
374
+
375
@cute.jit
def gemm_ptx_partial(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc_tmem_addr: Int32,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    mbar_ptr: Optional[cutlass.Pointer] = None,
    mbar_phase: Optional[Int32] = None,
    split_arrive: Optional[int] = None,
    zero_init: bool | Boolean = False,
    # sA_offset: Int32 = 0,
    # acc_offset: Int32 = 0,
    tA_addr: Optional[Int32] = None,
    cta_group: int = 1,
    mma_kind: str = "f16",
) -> None:
    """Emit one inline-PTX block issuing a ``tcgen05.mma`` per K-block.

    Unlike ``gemm_ptx_loop`` the accumulator is given as a raw TMEM address,
    and on the TS path the K loop can be split around an mbarrier wait.

    Args:
        op: MMA op; ``op.a_src`` selects TS (A from TMEM) versus SS (A from
            shared memory).
        acc_tmem_addr: TMEM address of the accumulator.
        tCrA: A fragment.  ``shape[2]`` is the number of K-blocks; its layout
            provides the per-K-block offsets of A.
        tCrB: B fragment; its layout provides the per-K-block offsets added
            to the low word of B's descriptor.
        sA: Shared-memory A tensor.  Required when A is not sourced from TMEM.
        sB: Shared-memory B tensor.
        mbar_ptr: Optional mbarrier to wait on in the middle of the K loop.
            Only supported on the TS path.
        mbar_phase: Parity to wait for; required with ``mbar_ptr``.
        split_arrive: K extent after which the wait is inserted; required
            with ``mbar_ptr``.  Converted to a K-block index by dividing by
            ``op.shape_mnk[2]``.
        zero_init: Controls the trailing predicate operand of the first MMA
            (Python bool -> immediate, ``Boolean`` -> runtime predicate).
        tA_addr: Explicit TMEM address of A for the TS path.
        cta_group: Value substituted into ``.cta_group::N``.
        mma_kind: Value substituted into ``.kind::X``.
    """
    # acc_tmem_addr += acc_offset
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else tCrA.layout
    sB_layout = sB.layout
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = sA.iterator.type.swizzle_type
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        # Split the 64-bit descriptor template into its two 32-bit words.
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = sB.iterator.type.swizzle_type
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    num_k_tile = cute.size(tCrA.shape[2])
    # On the TS path A's offsets are TMEM addresses, so recast to 32-bit units.
    tCrA_layout = (
        tCrA.layout
        if const_expr(not is_ts)
        else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout)
    )
    # Absolute per-K-block offsets relative to K-block 0.
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)]
    offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))]

    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(
            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
        )
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(
        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
    )
    # "p" -> runtime predicate register; "0"/"1" -> immediate folded at trace time.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM"
        # SS path.  Operands: $0 A desc low word, $1 B desc low word,
        # $2 (not zero_init), $3 accumulator TMEM address.
        llvm.inline_asm(
            None,
            [
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 tmem_acc, $3;\n\t"
            "mov.b32 smem_desc_a_lo_start, $0;\n\t"
            "mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            "mov.b64 smem_desc_a, {smem_desc_a_lo_start, smem_desc_a_hi};\n\t"
            "mov.b64 smem_desc_b, {smem_desc_b_lo_start, smem_desc_b_hi};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    "mov.b64 smem_desc_a, {smem_desc_a_lo, smem_desc_a_hi};\n\t"
                    "mov.b64 smem_desc_b, {smem_desc_b_lo, smem_desc_b_hi};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to
        # explicitly pass in the tA_addr for correctness.
        tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr
        # Operands: $0 TMEM address of A, $1 B desc low word, $2 (not zero_init),
        # $3 accumulator TMEM address, and with an mbarrier: $4 barrier address, $5 phase.
        input_args = [
            Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(),
            Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
            Int32(not zero_init).ir_value(),
            Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
        ]
        if const_expr(mbar_ptr is not None):
            assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None"
            assert split_arrive is not None, (
                "split_arrive must be provided when mbar_ptr is not None"
            )
            # K-block 0 is always issued before the wait, so the split point is
            # clamped to [1, num_k_tile]: a smaller value would issue K-block 0 a
            # second time, a larger one would index past the offset tables.
            split_k = min(max(split_arrive // op.shape_mnk[2], 1), num_k_tile)
            input_args.append(mbar_ptr.toint().ir_value())
            input_args.append(Int32(mbar_phase).ir_value())
            mbar_wait_str = (
                ".reg .pred P1; \n\t"
                "LAB_WAIT: \n\t"
                "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t"
                "@P1 bra DONE; \n\t"
                "bra LAB_WAIT; \n\t"
                "DONE: \n\t"
            )
        else:
            # No barrier: every K-block is issued by the first loop.
            split_k = num_k_tile
            mbar_wait_str = ""
        llvm.inline_asm(
            None,
            input_args,
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 tmem_acc, $3;\n\t"
            "mov.b32 tmem_a, $0;\n\t"
            "mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            "mov.b64 smem_desc_b, {smem_desc_b_lo_start, smem_desc_b_hi};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            # K-blocks issued before the barrier wait.
            + "".join(
                (
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    "mov.b64 smem_desc_b, {smem_desc_b_lo, smem_desc_b_hi};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(1, split_k)
            )
            + mbar_wait_str
            # K-blocks issued after the wait.  The low word is rebuilt from
            # smem_desc_b_lo_start rather than advanced incrementally, because
            # smem_desc_b_lo has not been written yet when the loop above is empty
            # (split_k == 1).
            + "".join(
                (
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    "mov.b64 smem_desc_b, {smem_desc_b_lo, smem_desc_b_hi};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(split_k, num_k_tile)
            )
            + "}\n",
            "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
594
+
595
+
596
@cute.jit
def gemm_ptx_partial1(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc_tmem_addr: cutlass.Constexpr[int],
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA_base_addr_for_desc: Int32,
    sA_addr_offset_for_desc: cutlass.Constexpr[int],
    sA_stage: Int32,
    sB_base_addr_for_desc: Int32,
    sB_addr_offset_for_desc: cutlass.Constexpr[int],
    sB_stage: Int32,
    sA_layout: Optional[cute.Layout],
    sB_layout: Optional[cute.Layout],
    sA_swizzle: Optional[cute.Swizzle],
    sB_swizzle: cute.Swizzle,
    zero_init: bool | Boolean = False,
) -> None:
    """Emit a K loop of ``tcgen05.mma`` with stage-indexed smem descriptors.

    The low word of each shared-memory descriptor is computed inside the PTX
    as ``stage * addr_offset + base_addr`` (``mad.lo.u32``); the high word and
    the instruction descriptor are hex immediates.  Every MMA also receives a
    four-element lane mask operand, currently all zeros.

    Args:
        op: MMA op; ``op.a_src`` selects TS (A from TMEM) versus SS.
        acc_tmem_addr: Compile-time accumulator TMEM address, written into the
            PTX text as an immediate.
        tCrA: A fragment; ``shape[2]`` is the number of K-blocks, and on the
            TS path its iterator gives A's TMEM address.
        tCrB: B fragment; ``shape[2]`` sizes the B offset table.
        sA_base_addr_for_desc: Low descriptor word of A at stage 0 (SS only).
            NOTE(review): the low word returned by ``make_smem_desc_base`` is
            not OR-ed in here, so this value presumably already contains it.
        sA_addr_offset_for_desc: Per-stage increment of A's low word.
        sA_stage: Runtime stage index for A.
        sB_base_addr_for_desc: Low descriptor word of B at stage 0 (same
            caveat as for A).
        sB_addr_offset_for_desc: Per-stage increment of B's low word.
        sB_stage: Runtime stage index for B.
        sA_layout: Layout of A in shared memory; required on the SS path.  On
            the TS path ``tCrA.layout`` is used when this is ``None``.
        sB_layout: Layout of B in shared memory.
        sA_swizzle: Swizzle of A; required on the SS path.
        sB_swizzle: Swizzle of B.
        zero_init: Controls the trailing predicate operand of the first MMA
            (Python bool -> immediate, ``Boolean`` -> runtime predicate).
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM"
        assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM"
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        # Only the high word is used; the low word comes from the caller.
        _, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_a_hi = None
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    _, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)
    mask = [Int32(0)] * 4

    num_k_tile = cute.size(tCrA.shape[2])
    if const_expr(not is_ts):
        # Element index -> byte offset (width // 8) -> 16-byte units (>> 4).
        offset_a = [
            (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4
            for k in range(num_k_tile)
        ]
    else:
        # sA_layout is optional on the TS path; fall back to the fragment layout
        # (as gemm_ptx_loop does) instead of handing None to crd2idx.
        tA_layout = sA_layout if const_expr(sA_layout is not None) else tCrA.layout
        offset_a = [
            cute.crd2idx((0, 0, k), tA_layout) * op.a_dtype.width // 32
            for k in range(num_k_tile)
        ]
    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, num_k_tile)]
    offset_b = [
        (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4
        for k in range(cute.size(tCrB.shape[2]))
    ]
    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]

    # "p" -> runtime predicate register; "0"/"1" -> immediate folded at trace time.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        # SS path.  Operands: $0 A base, $1 A stage, $2 B base, $3 B stage,
        # $4 (not zero_init), $5-$8 lane mask.
        llvm.inline_asm(
            None,
            [
                Int32(sA_base_addr_for_desc).ir_value(),
                Int32(sA_stage).ir_value(),
                Int32(sB_base_addr_for_desc).ir_value(),
                Int32(sB_stage).ir_value(),
                Int32(not zero_init).ir_value(),
                mask[0].ir_value(),
                mask[1].ir_value(),
                mask[2].ir_value(),
                mask[3].ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t"
            f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            "mov.b64 smem_desc_a, {smem_desc_a_lo, smem_desc_a_hi};\n\t"
            "mov.b64 smem_desc_b, {smem_desc_b_lo, smem_desc_b_hi};\n\t"
            "setp.ne.b32 p, $4, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    "mov.b64 smem_desc_a, {smem_desc_a_lo, smem_desc_a_hi};\n\t"
                    "mov.b64 smem_desc_b, {smem_desc_b_lo, smem_desc_b_hi};\n\t"
                    "@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {$5, $6, $7, $8}, 1;\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "}\n",
            "r,r,r,r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # TS path.  Operands: $0 TMEM address of A, $1 B base, $2 B stage,
        # $3 (not zero_init), $4-$7 lane mask.
        #
        # The previous version listed seven operands against eight constraints and
        # still numbered them as if the accumulator were $0, so "[$0]" was really
        # A's address; B's low word also carried no shared-memory address.  The
        # accumulator now comes from the constexpr address and B's descriptor is
        # formed exactly as on the SS path.
        llvm.inline_asm(
            None,
            [
                Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
                Int32(sB_base_addr_for_desc).ir_value(),
                Int32(sB_stage).ir_value(),
                Int32(not zero_init).ir_value(),
                mask[0].ir_value(),
                mask[1].ir_value(),
                mask[2].ir_value(),
                mask[3].ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            "mov.b32 tmem_a, $0;\n\t"
            f"mad.lo.u32 smem_desc_b_lo, $2, {hex(sB_addr_offset_for_desc)}, $1;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            "mov.b64 smem_desc_b, {smem_desc_b_lo, smem_desc_b_hi};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    "mov.b64 smem_desc_b, {smem_desc_b_lo, smem_desc_b_hi};\n\t"
                    "@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {$4, $5, $6, $7}, 1;\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "}\n",
            "r,r,r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
773
+
774
+
775
@cute.jit
def gemm_ptx_precomputed(
    acc_tmem_addr: Int32,
    smem_desc_start_a: Int32,  # If TS, then this is the tmem start address for A
    smem_desc_start_b: Int32,
    idesc: int,
    smem_desc_base_a: Optional[int],
    smem_desc_base_b: int,
    tCrA_layout: cute.Layout,
    tCrB_layout: cute.Layout,
    mbar_ptr: Optional[cutlass.Pointer] = None,
    mbar_phase: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    cta_group: int = 1,
) -> None:
    """Emit a K loop of ``tcgen05.mma`` from caller-precomputed descriptors.

    The TS path (A read from TMEM) is selected by ``smem_desc_base_a is None``.
    On that path an optional mbarrier wait is inserted after roughly the first
    three quarters of the K-blocks.

    Args:
        acc_tmem_addr: TMEM address of the accumulator.
        smem_desc_start_a: Start-address bits OR-ed into A's low descriptor
            word; on the TS path, the TMEM start address of A.
        smem_desc_start_b: Start-address bits OR-ed into B's low descriptor
            word.
        idesc: Instruction descriptor, written into the PTX as an immediate.
        smem_desc_base_a: 64-bit descriptor template for A, or ``None`` for TS.
        smem_desc_base_b: 64-bit descriptor template for B.
        tCrA_layout: Layout providing A's per-K-block offsets; ``shape[2]`` is
            the number of K-blocks.
        tCrB_layout: Layout providing B's per-K-block offsets.
        mbar_ptr: Optional mbarrier to wait on mid-loop (TS path only).
        mbar_phase: Parity to wait for; required with ``mbar_ptr``.
        zero_init: Controls the trailing predicate operand of the first MMA
            (Python bool -> immediate, ``Boolean`` -> runtime predicate).
        cta_group: Value substituted into ``.cta_group::N``.
    """
    # acc_tmem_addr += acc_offset
    is_ts = const_expr(smem_desc_base_a is None)
    num_k_tile = cute.size(tCrA_layout.shape[2])
    if const_expr(not is_ts):
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
    else:
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)

    tCrA_layout = (
        tCrA_layout
        if const_expr(not is_ts)
        # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout)
        # currently hard-coding the width to 16
        else cute.recast_layout(32, 16, tCrA_layout)
    )
    # Absolute per-K-block offsets relative to K-block 0.
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)]
    offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)]

    smem_desc_start_a_lo = None
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a)
    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b)
    # "p" -> runtime predicate register; "0"/"1" -> immediate folded at trace time.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM"
        # SS path.  Operands: $0 A desc low word, $1 B desc low word,
        # $2 (not zero_init), $3 accumulator TMEM address.
        llvm.inline_asm(
            None,
            [
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 tmem_acc, $3;\n\t"
            "mov.b32 smem_desc_a_lo_start, $0;\n\t"
            "mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            "mov.b64 smem_desc_a, {smem_desc_a_lo_start, smem_desc_a_hi};\n\t"
            "mov.b64 smem_desc_b, {smem_desc_b_lo_start, smem_desc_b_hi};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    f"add.s32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t"
                    f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    "mov.b64 smem_desc_a, {smem_desc_a_lo, smem_desc_a_hi};\n\t"
                    "mov.b64 smem_desc_b, {smem_desc_b_lo, smem_desc_b_hi};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # TS path.  Operands: $0 TMEM address of A, $1 B desc low word,
        # $2 (not zero_init), $3 accumulator TMEM address, and with an mbarrier:
        # $4 barrier address, $5 phase.
        input_args = [
            Int32(cute.arch.make_warp_uniform(smem_desc_start_a)).ir_value(),
            Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
            Int32(not zero_init).ir_value(),
            Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
        ]
        if const_expr(mbar_ptr is not None):
            assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None"
            # Wait after 3/4 of the K-blocks.  K-block 0 is always issued before the
            # wait, so the split is kept at >= 1: for num_k_tile < 4 the raw value
            # is 0, which made the post-wait loop issue K-block 0 a second time.
            split_k = max(1, num_k_tile // 4 * 3)
            input_args.append(mbar_ptr.toint().ir_value())
            input_args.append(Int32(mbar_phase).ir_value())
            mbar_wait_str = (
                ".reg .pred P1; \n\t"
                "LAB_WAIT: \n\t"
                "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t"
                "@P1 bra DONE; \n\t"
                "bra LAB_WAIT; \n\t"
                "DONE: \n\t"
            )
        else:
            # No barrier: every K-block is issued by the first loop.
            split_k = num_k_tile
            mbar_wait_str = ""
        llvm.inline_asm(
            None,
            input_args,
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 tmem_acc, $3;\n\t"
            "mov.b32 tmem_a, $0;\n\t"
            "mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            "mov.b64 smem_desc_b, {smem_desc_b_lo_start, smem_desc_b_hi};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            # K-blocks issued before the barrier wait.
            + "".join(
                (
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    "mov.b64 smem_desc_b, {smem_desc_b_lo, smem_desc_b_hi};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(1, split_k)
            )
            + mbar_wait_str
            # K-blocks issued after the wait (empty when there is no barrier).
            + "".join(
                (
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    "mov.b64 smem_desc_b, {smem_desc_b_lo, smem_desc_b_hi};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(split_k, num_k_tile)
            )
            + "}\n",
            "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
952
+
953
+
954
@cute.jit
def declare_ptx_smem_desc(
    smem_desc_start_a: Int32,  # If TS, then this is the tmem start address for A
    smem_desc_base_a: Optional[int],
    tCrA_layout: cute.Layout,
    var_name_prefix: str = "smem_desc",
) -> None:
    """Declare and fill one 64-bit PTX descriptor register per K-block.

    Emits ``.reg .b64 {var_name_prefix}_<N>`` (N = ``size(tCrA_layout.shape[2])``)
    plus a scratch ``.reg .b32 {var_name_prefix}_lo`` and assigns
    ``{var_name_prefix}_k = {low word + offset of K-block k, high word}``.
    Nothing is emitted when ``smem_desc_base_a`` is ``None`` (the TS case).

    NOTE(review): the asm text is not wrapped in ``{ }``, presumably so the
    declared registers remain visible to later inline-asm blocks that refer
    to them by name — confirm against the consumers of ``var_name_prefix``.

    Args:
        smem_desc_start_a: Start-address bits OR-ed into the low word.
        smem_desc_base_a: 64-bit descriptor template, or ``None`` for TS.
        tCrA_layout: Layout providing the per-K-block offsets.
        var_name_prefix: Prefix of the PTX register names that are declared.
    """
    is_ts = const_expr(smem_desc_base_a is None)
    num_k_tile = cute.size(tCrA_layout.shape[2])
    smem_desc_base_a_lo, smem_desc_a_hi = None, None
    if const_expr(not is_ts):
        # Split the 64-bit descriptor template into its two 32-bit words.
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
    tCrA_layout = (
        tCrA_layout
        if const_expr(not is_ts)
        # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout)
        # currently hard-coding the width to 16
        else cute.recast_layout(32, 16, tCrA_layout)
    )
    # Absolute per-K-block offsets relative to K-block 0.
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)]
    smem_desc_start_a_lo = None
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a)
    if const_expr(not is_ts):
        # Single operand $0: the low descriptor word for K-block 0.
        llvm.inline_asm(
            None,
            [Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value()],
            f".reg .b32 {var_name_prefix}_lo;\n\t"
            f".reg .b64 {var_name_prefix}_<{num_k_tile}>;\n\t"
            f"mov.b64 {var_name_prefix}_0, {{$0, {hex(smem_desc_a_hi)}}};\n\t"
            + "".join(
                (
                    f"add.s32 {var_name_prefix}_lo, $0, {hex(offset_a[k])};\n\t"
                    f"mov.b64 {var_name_prefix}_{k}, {{{var_name_prefix}_lo, {hex(smem_desc_a_hi)}}};\n\t"
                )
                for k in range(1, num_k_tile)
            ),
            "r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
996
+
997
+
998
@cute.jit
def declare_ptx_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp, var_name: str = "idesc") -> None:
    """Declare a b32 PTX register named ``var_name`` holding the descriptor of ``op``.

    The value comes from ``sm100_desc.mma_op_to_idesc(op)`` at trace time and is
    embedded as a hex immediate. The asm is not wrapped in braces, so the
    register is declared in the enclosing PTX scope.
    """
    idesc_value = const_expr(sm100_desc.mma_op_to_idesc(op))
    # Two statements, each terminated by "\n\t": the declaration and the load.
    statements = (
        f".reg .b32 {var_name};",
        f"mov.b32 {var_name}, {hex(idesc_value)};",
    )
    asm_text = "".join(stmt + "\n\t" for stmt in statements)
    llvm.inline_asm(
        None,
        [],
        asm_text,
        constraints="",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
1011
+
1012
+
1013
@cute.jit
def gemm_ptx_precomputed_varname(
    acc_tmem_addr: Int32,
    smem_desc_start_b: Int32,
    # idesc: int,
    smem_desc_base_b: int,
    tCrB_layout: cute.Layout,
    smem_var_name_prefix: str,
    idesc_var_name: str,
    smem_offset: int,
    zero_init: bool | Boolean = False,
    cta_group: int = 1,
    mma_kind: str = "f16",
) -> None:
    """Emit one ``tcgen05.mma`` per K tile using pre-declared, named PTX registers.

    Operand A and the instruction descriptor are not passed as values: the asm
    refers to PTX registers ``{smem_var_name_prefix}_{k}`` and
    ``{idesc_var_name}`` by name, so they must already have been declared in an
    enclosing PTX scope (presumably via ``declare_ptx_smem_desc`` /
    ``declare_ptx_idesc`` with matching names — confirm at the call site).

    Side effect on the named A registers: for every k, ``smem_offset`` is added
    to the low 32 bits of ``{smem_var_name_prefix}_{k}`` and the result is
    written back, so the registers stay modified after this call.

    Operand B descriptors are built locally as
    ``{lo: start_lo + offset_b[k], hi: base_hi}``.

    Args:
        acc_tmem_addr: Address used as ``[tmem_acc]`` in every MMA (asm operand $2).
        smem_desc_start_b: Runtime value OR-ed into the low half of ``smem_desc_base_b``.
        smem_desc_base_b: Compile-time 64-bit base for operand B.
        tCrB_layout: Layout whose mode 2 enumerates the K tiles of B.
        smem_var_name_prefix: Name prefix of the pre-declared A registers.
        idesc_var_name: Name of the pre-declared instruction-descriptor register.
        smem_offset: Compile-time constant added to each A register's low word.
        zero_init: Controls the last operand of the first MMA only (see ``pred_str``).
        cta_group: Value substituted into ``.cta_group::N``.
        mma_kind: Value substituted into ``.kind::X``.
    """
    # Only the non-TS path is implemented here; the guard below is always taken.
    is_ts = False
    num_k_tile = cute.size(tCrB_layout.shape[2])
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)]

    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b)
    # Last operand of the first MMA: the runtime predicate "p" when zero_init is a
    # DSL Boolean, otherwise the literal "0" (zero_init=True) or "1" (zero_init=False).
    # NOTE(review): per the PTX tcgen05.mma description this operand should be the
    # "accumulate into D" enable — confirm against the PTX ISA.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        llvm.inline_asm(
            None,
            [
                # $0: low word of the K-tile-0 descriptor for B.
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                # $1: non-zero when zero_init is false; turned into predicate p below.
                Int32(not zero_init).ir_value(),
                # $2: accumulator address.
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            # The braces scope the scratch registers declared here; the A registers
            # and the idesc register are resolved from the enclosing scope.
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            # ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            # ".reg .b64 smem_desc_b;\n\t"
            f".reg .b64 smem_desc_b_<{num_k_tile}>;\n\t"
            # Elect a single thread; every MMA below is predicated on it.
            "elect.sync _|leader_thread, -1;\n\t"
            # f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $2;\n\t"
            "mov.b32 smem_desc_b_lo_start, $0;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            # K tile 0: bump the A register in place, build B descriptor 0.
            f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_0;\n\t"
            f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
            f"mov.b64 {smem_var_name_prefix}_0, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b_0, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            # K tiles 1..n-1: same update, all done before the first MMA is issued.
            + "".join(
                (
                    f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t"
                    f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
                    f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b_{k}, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "setp.ne.b32 p, $1, 0;\n\t"
            # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t"
            # Remaining K tiles always pass the literal 1 as the last operand.
            + "".join(
                (
                    # f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t"
                    # f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
                    # f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    # f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    # f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "}\n",
            "r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
build/torch211-cxx11-cu128-x86_64-linux/src/common/block_info.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ from typing import Tuple
5
+ from dataclasses import dataclass
6
+
7
+ import cutlass
8
+ import cutlass.cute as cute
9
+ from cutlass import Int32, const_expr
10
+
11
+ from ...src.common.seqlen_info import SeqlenInfoQK
12
+
13
+
14
@dataclass(frozen=True)
class BlockInfo:
    """Compile-time tile configuration plus helpers that bound the block loops.

    Both helpers return a ``(min, max)`` pair of block indices; ``max`` is
    computed with ``ceil_div`` and is therefore an exclusive upper bound.
    """

    # Tile extent along the query (M) dimension.
    tile_m: cutlass.Constexpr[int]
    # Tile extent along the key (N) dimension.
    tile_n: cutlass.Constexpr[int]
    # When true, the block ranges are additionally clipped by the causal bound.
    is_causal: cutlass.Constexpr[bool]
    # When > 1, M indices are scaled by this factor before being compared with
    # sequence positions (see the ceil_div / multiply in the two methods).
    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1

    @cute.jit
    def get_n_block_min_max(
        self,
        seqlen_info: SeqlenInfoQK,
        m_block: Int32,
        split_idx: Int32 = 0,
        num_splits: Int32 = 1,
    ) -> Tuple[Int32, Int32]:
        """Return the ``[n_block_min, n_block_max)`` range to visit for ``m_block``.

        Args:
            seqlen_info: Provides ``seqlen_q`` and ``seqlen_k``.
            m_block: Index of the M tile being processed.
            split_idx: Which split of the N range this caller owns.
            num_splits: Number of splits the N range is divided into.
        """
        n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
        if const_expr(self.is_causal):
            # One past the last M index covered by this tile.
            m_idx_max = (m_block + 1) * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa > 1):
                m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)
            # Shift by (seqlen_k - seqlen_q) to turn the M bound into an N bound.
            n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
            n_block_max = min(n_block_max, cute.ceil_div(n_idx, self.tile_n))
        n_block_min = 0
        if num_splits > 1:
            # Ceil-divide the range across splits; an empty range yields 0 per split.
            num_n_blocks_per_split = (
                Int32(0)
                if n_block_max <= n_block_min
                else (n_block_max - n_block_min + num_splits - 1) // num_splits
            )
            n_block_min = n_block_min + split_idx * num_n_blocks_per_split
            # The last split may be shorter: clamp to the unsplit upper bound.
            n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max)
        return n_block_min, n_block_max

    @cute.jit
    def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]:
        """Return the ``[m_block_min, m_block_max)`` range to visit for ``n_block``.

        Args:
            seqlen_info: Provides ``seqlen_q`` and ``seqlen_k``.
            n_block: Index of the N tile being processed.
        """
        m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m)
        if const_expr(self.qhead_per_kvhead_packgqa > 1):
            # With packing, the M extent is seqlen_q * qhead_per_kvhead_packgqa.
            m_block_max = cute.ceil_div(
                seqlen_info.seqlen_q * self.qhead_per_kvhead_packgqa, self.tile_m
            )
        m_block_min = 0
        if const_expr(self.is_causal):
            # First N index of this tile, shifted by (seqlen_q - seqlen_k) into M space.
            n_idx_min = n_block * self.tile_n
            m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k
            if const_expr(self.qhead_per_kvhead_packgqa > 1):
                m_idx *= self.qhead_per_kvhead_packgqa
            # m_idx may be negative; the max with 0 keeps the lower bound valid.
            m_block_min = cutlass.max(m_block_min, m_idx // self.tile_m)
        return m_block_min, m_block_max
build/torch211-cxx11-cu128-x86_64-linux/src/common/copy_utils.py ADDED
@@ -0,0 +1,1179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Copy, store, and layout execution helpers.
5
+
6
+ `copy_utils.py` is the canonical owner for generic copy primitives, async
7
+ bulk copy orchestration, TMA copy adapters, and non-TMA store/layout helpers.
8
+ """
9
+
10
+ import math
11
+ from typing import Optional, Type, Callable
12
+
13
+ import cutlass
14
+ import cutlass.cute as cute
15
+ from cutlass import Float32, Int32, const_expr
16
+ from cutlass.cute.nvgpu import cpasync
17
+ import cutlass.utils.blackwell_helpers as sm100_utils
18
+ from cutlass.cutlass_dsl import T, dsl_user_op
19
+ from cutlass._mlir.dialects import llvm
20
+ import cutlass.pipeline
21
+
22
+
23
+ # Generic Copy Primitives
24
+
25
@dsl_user_op
def cvt_copy(
    atom: cute.CopyAtom,
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy register tensor ``src`` to ``dst``, converting the element type if it differs.

    ``src`` must be a pointer-backed tensor in register memory. When the element
    types differ, the data is first converted into a temporary register tensor
    of ``dst``'s element type, and that temporary is what gets copied.
    """
    assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
    payload = src
    if const_expr(src.element_type != dst.element_type):
        # Stage the converted values in a fresh register tensor shaped like src.
        payload = cute.make_rmem_tensor_like(src, dst.element_type, loc=loc, ip=ip)
        payload.store(src.load().to(dst.element_type))
    cute.copy(atom, payload, dst, pred=pred, loc=loc, ip=ip, **kwargs)
42
+
43
+
44
@dsl_user_op
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
    """Return a new register tensor, shaped like ``src``, filled from ``src``.

    The data is moved with ``cute.autovec_copy``.
    """
    reg_copy = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip)
    cute.autovec_copy(src, reg_copy, loc=loc, ip=ip)
    return reg_copy
49
+
50
+
51
@dsl_user_op
def get_copy_atom(
    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
) -> cute.CopyAtom:
    """Build a copy atom moving up to ``num_copy_elems`` elements of ``dtype`` per copy.

    The width is capped at 128 bits. ``is_async`` selects ``cpasync.CopyG2SOp``
    instead of the universal copy op.
    """
    requested_bits = num_copy_elems * dtype.width
    num_copy_bits = const_expr(min(128, requested_bits))
    if is_async:
        copy_op = cpasync.CopyG2SOp()
    else:
        copy_op = cute.nvgpu.CopyUniversalOp()
    return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
58
+
59
+
60
@dsl_user_op
def make_tmem_copy(
    tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None
) -> cute.CopyAtom:
    """Wrap ``tmem_copy_atom`` in a tiled copy spanning ``num_wg`` warpgroups.

    Only atoms that ``get_tmem_copy_properties`` reports as 32 data paths and
    32 bits are supported; anything else trips the assertions below. The return
    value is whatever ``cute.make_tiled_copy`` produces.
    """
    props = sm100_utils.get_tmem_copy_properties(tmem_copy_atom)
    num_dp, num_bits, num_rep = props[0], props[1], props[2]
    assert num_dp == 32
    assert num_bits == 32
    # Tiler: a single mode of shape (rows, 32) with strides (32, 1).
    tiler_rows = 128 * num_rep * num_wg // 32
    tiler_mn = (cute.make_layout((tiler_rows, 32), stride=(32, 1)),)
    # Thread/value layout: ((32, 4, num_wg) threads, (num_rep, 32) values).
    tv_shape = ((32, 4, num_wg), (num_rep, 32))
    tv_stride = ((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg))
    layout_tv = cute.make_layout(tv_shape, stride=tv_stride)
    return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn)
72
+
73
+
74
@dsl_user_op
def copy(
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    num_copy_elems: int = 1,
    is_async: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy ``src`` to ``dst`` with an atom derived from ``src``'s element type.

    The atom comes from ``get_copy_atom(src.element_type, num_copy_elems, is_async)``.
    ``pred`` and any extra keyword arguments are forwarded to ``cute.copy``.
    """
    atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
    cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
88
+
89
+
90
def tiled_copy_1d(
    dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
) -> cute.TiledCopy:
    """Build a 1-D tiled copy: ``num_threads`` threads, ``num_copy_elems`` values each.

    Unlike ``get_copy_atom``, the atom width here is ``num_copy_elems * dtype.width``
    with no 128-bit cap.
    """
    if is_async:
        copy_op = cpasync.CopyG2SOp()
    else:
        copy_op = cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(
        copy_op, dtype, num_bits_per_copy=num_copy_elems * dtype.width
    )
    threads = cute.make_layout(num_threads)
    values = cute.make_layout(num_copy_elems)
    return cute.make_tiled_copy_tv(atom, threads, values)
99
+
100
+
101
def tiled_copy_2d(
    dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
) -> cute.TiledCopy:
    """Build a 2-D tiled copy whose threads are laid out along the major mode first.

    Each thread copies ``gcd(major_mode_size, 128 // dtype.width)`` contiguous
    elements, i.e. the widest copy of at most 128 bits that divides the major
    mode. ``num_threads`` must be a multiple of the threads needed per row.
    """
    elems_per_copy = math.gcd(major_mode_size, 128 // dtype.width)
    bits_per_copy = elems_per_copy * dtype.width
    if is_async:
        copy_op = cpasync.CopyG2SOp()
    else:
        copy_op = cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=bits_per_copy)
    threads_per_row = major_mode_size // elems_per_copy
    assert num_threads % threads_per_row == 0
    num_rows = num_threads // threads_per_row
    thread_layout = cute.make_ordered_layout((num_rows, threads_per_row), order=(1, 0))
    value_layout = cute.make_layout((1, elems_per_copy))
    return cute.make_tiled_copy_tv(atom, thread_layout, value_layout)
116
+
117
+
118
@dsl_user_op
def atomic_add_fp32x4(
    a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None
) -> None:
    """Add four f32 values to four consecutive global floats with one vector reduction.

    Emits a single ``red.global.add.v4.f32`` on ``gmem_ptr``; no value is
    returned. PTX requires vector accesses to be aligned to the vector size,
    so the address is expected to be 16-byte aligned.

    ``loc``/``ip`` are forwarded to every IR-building call, including the
    inline-asm op itself, matching the other helpers in this module. This keeps
    the operands and the asm at the same insertion point when a caller supplies
    one.

    Args:
        a, b, c, d: Values added to elements 0..3 at ``gmem_ptr``.
        gmem_ptr: Destination pointer in global memory.
    """
    # $0 is the address; $1..$4 are the four addends in order.
    operands = [gmem_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip)]
    operands.extend(Float32(value).ir_value(loc=loc, ip=ip) for value in (a, b, c, d))
    llvm.inline_asm(
        None,
        operands,
        "{\n\t"
        ".reg .v4 .f32 abcd;\n\t"
        # Fill the vector register component by component.
        "mov.f32 abcd.x, $1;\n\t"
        "mov.f32 abcd.y, $2;\n\t"
        "mov.f32 abcd.z, $3;\n\t"
        "mov.f32 abcd.w, $4;\n\t"
        "red.global.add.v4.f32 [$0], abcd;\n\t"
        "}\n",
        "l,f,f,f,f",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
154
+
155
+
156
+ # Store/Layout Helpers
157
+
158
@dsl_user_op
def atomic_add_i32(gmem_ptr, *, loc=None, ip=None):
    """Atomically add 1 to the u32 at ``gmem_ptr`` and return the ``atom`` result.

    Intended for use under a single-thread guard. Per PTX ``atom`` semantics the
    result is the value stored at the address before the add.
    """
    address = gmem_ptr.toint().ir_value(loc=loc, ip=ip)
    asm_kwargs = dict(
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
    previous = llvm.inline_asm(
        T.i32(),
        [address],
        "atom.global.add.u32 $0, [$1], 1;\n",
        "=r,l",
        **asm_kwargs,
    )
    return Int32(previous)
173
+
174
+
175
@dsl_user_op
def atomic_add_broadcast_i32(gmem_ptr, *, loc=None, ip=None):
    """Have lane 0 atomically add 1 at ``gmem_ptr`` and broadcast its result to the warp.

    Only lane 0 performs the ``atom``. Every lane then reads lane 0's value via
    ``shfl.sync.idx`` with a full ``0xffffffff`` member mask, so all 32 lanes
    must execute this together.
    """
    ptx_lines = (
        "{",
        ".reg .pred p;",
        ".reg .u32 lane, r;",
        "mov.u32 lane, %laneid;",
        # Lanes other than 0 keep r == 0 until the shuffle overwrites it.
        "mov.u32 r, 0;",
        "setp.eq.u32 p, lane, 0;",
        "@p atom.global.add.u32 r, [$1], 1;",
        "shfl.sync.idx.b32 r, r, 0, 31, 0xffffffff;",
        "mov.u32 $0, r;",
        "}",
    )
    address = gmem_ptr.toint().ir_value(loc=loc, ip=ip)
    broadcast = llvm.inline_asm(
        T.i32(),
        [address],
        "".join(line + "\n" for line in ptx_lines),
        "=r,l",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
    return Int32(broadcast)
199
+
200
+
201
@dsl_user_op
def stg_128(
    gmem_ptr: cute.Pointer,
    v0: Float32,
    v1: Float32,
    v2: Float32,
    v3: Float32,
    *,
    loc=None,
    ip=None,
):
    """Store four f32 values to global memory with one ``st.global.v4.f32``.

    The asm also declares four f32 outputs ($0..$3) that are set to zero; the
    result of the asm op is discarded.
    NOTE(review): the purpose of these dummy outputs is not visible in this file.
    """
    values = (v0, v1, v2, v3)
    count = len(values)
    # Operand numbering: $0..$3 dummy outputs, $4 address, $5..$8 the values.
    operands = [gmem_ptr.toint().ir_value(loc=loc, ip=ip)]
    operands.extend(Float32(v).ir_value(loc=loc, ip=ip) for v in values)
    clear_outputs = " ".join(f"mov.f32 ${i}, 0f00000000;" for i in range(count))
    llvm.inline_asm(
        llvm.StructType.get_literal([T.f32()] * count),
        operands,
        "st.global.v4.f32 [$4], {$5, $6, $7, $8}; " + clear_outputs,
        ",".join(["=f"] * count + ["l"] + ["f"] * count),
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
231
+
232
+
233
@dsl_user_op
def stg_128_cs(
    gmem_ptr: cute.Pointer,
    v0: Float32,
    v1: Float32,
    v2: Float32,
    v3: Float32,
    *,
    loc=None,
    ip=None,
):
    """Store four f32 values to global memory with one ``st.global.cs.v4.f32``.

    Same as ``stg_128`` except for the ``.cs`` cache operator on the store. The
    four zeroed f32 outputs ($0..$3) are dummies whose result is discarded.
    """
    inputs = [v0, v1, v2, v3]
    n = len(inputs)
    # Operand numbering: $0..$3 dummy outputs, $4 address, $5..$8 the values.
    asm_operands = [gmem_ptr.toint().ir_value(loc=loc, ip=ip)] + [
        Float32(v).ir_value(loc=loc, ip=ip) for v in inputs
    ]
    store = "st.global.cs.v4.f32 [$4], {$5, $6, $7, $8}; "
    zeroing = " ".join(f"mov.f32 ${i}, 0f00000000;" for i in range(n))
    constraint_list = ["=f"] * n + ["l"] + ["f"] * n
    llvm.inline_asm(
        llvm.StructType.get_literal([T.f32()] * n),
        asm_operands,
        store + zeroing,
        ",".join(constraint_list),
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
263
+
264
+
265
@dsl_user_op
def stg_64_bf16(
    gmem_ptr: cute.Pointer,
    v0: Float32,
    v1: Float32,
    v2: Float32,
    v3: Float32,
    *,
    loc=None,
    ip=None,
):
    """Convert four f32 values to bf16 (round-to-nearest) and store them as 64 bits.

    The four halves are packed pairwise into two b32 words and written with a
    single ``st.global.v2.b32``. The four zeroed f32 outputs ($0..$3) are
    dummies whose result is discarded.
    """
    values = (v0, v1, v2, v3)
    count = len(values)
    # Operand numbering: $0..$3 dummy outputs, $4 address, $5..$8 the values.
    operands = [gmem_ptr.toint().ir_value(loc=loc, ip=ip)]
    operands += [Float32(v).ir_value(loc=loc, ip=ip) for v in values]
    convert = "".join(f"cvt.rn.bf16.f32 h{i}, ${i + 5};\n" for i in range(count))
    clear_outputs = " ".join(f"mov.f32 ${i}, 0f00000000;" for i in range(count))
    asm_text = (
        "{\n"
        ".reg .b16 h0, h1, h2, h3;\n"
        ".reg .b32 p0, p1;\n"
        + convert
        + "mov.b32 p0, {h0, h1};\n"
        "mov.b32 p1, {h2, h3};\n"
        "st.global.v2.b32 [$4], {p0, p1};\n"
        "}\n"
        + clear_outputs
    )
    llvm.inline_asm(
        llvm.StructType.get_literal([T.f32()] * count),
        operands,
        asm_text,
        ",".join(["=f"] * count + ["l"] + ["f"] * count),
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
305
+
306
+
307
@dsl_user_op
def stg_64_f16(
    gmem_ptr: cute.Pointer,
    v0: Float32,
    v1: Float32,
    v2: Float32,
    v3: Float32,
    *,
    loc=None,
    ip=None,
):
    """Convert four f32 values to fp16 (round-to-nearest) and store them as 64 bits.

    The four halves are packed pairwise into two b32 words and written with a
    single ``st.global.v2.b32``. The four zeroed f32 outputs ($0..$3) are
    dummies whose result is discarded.
    """
    inputs = [v0, v1, v2, v3]
    n = len(inputs)
    # Operand numbering: $0..$3 dummy outputs, $4 address, $5..$8 the values.
    asm_operands = [gmem_ptr.toint().ir_value(loc=loc, ip=ip)]
    for value in inputs:
        asm_operands.append(Float32(value).ir_value(loc=loc, ip=ip))
    body_lines = ["{", ".reg .f16 h0, h1, h2, h3;", ".reg .b32 p0, p1;"]
    body_lines += [f"cvt.rn.f16.f32 h{i}, ${i + 5};" for i in range(n)]
    body_lines += [
        "mov.b32 p0, {h0, h1};",
        "mov.b32 p1, {h2, h3};",
        "st.global.v2.b32 [$4], {p0, p1};",
        "}",
    ]
    zeroing = " ".join(f"mov.f32 ${i}, 0f00000000;" for i in range(n))
    llvm.inline_asm(
        llvm.StructType.get_literal([T.f32()] * n),
        asm_operands,
        "".join(line + "\n" for line in body_lines) + zeroing,
        ",".join(["=f"] * n + ["l"] + ["f"] * n),
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
347
+
348
+
349
@dsl_user_op
def stg_32_fp8_e4m3(
    gmem_ptr: cute.Pointer,
    v0: Float32,
    v1: Float32,
    v2: Float32,
    v3: Float32,
    *,
    loc=None,
    ip=None,
):
    """Convert four f32 values to e4m3 (saturating) and store them as one 32-bit word.

    Each ``cvt ... e4m3x2`` takes its two sources in swapped order
    (``$6, $5`` and ``$8, $7``).
    NOTE(review): per the PTX ``cvt`` description the first source should land
    in the upper byte, which would keep v0..v3 in ascending byte order — confirm.
    The four zeroed f32 outputs ($0..$3) are dummies whose result is discarded.
    """
    values = (v0, v1, v2, v3)
    count = len(values)
    # Operand numbering: $0..$3 dummy outputs, $4 address, $5..$8 the values.
    operands = [gmem_ptr.toint().ir_value(loc=loc, ip=ip)]
    operands.extend(Float32(v).ir_value(loc=loc, ip=ip) for v in values)
    body_lines = (
        "{",
        ".reg .b16 h0, h1;",
        ".reg .b32 p0;",
        "cvt.rn.satfinite.e4m3x2.f32 h0, $6, $5;",
        "cvt.rn.satfinite.e4m3x2.f32 h1, $8, $7;",
        "mov.b32 p0, {h0, h1};",
        "st.global.b32 [$4], p0;",
        "}",
    )
    clear_outputs = " ".join(f"mov.f32 ${i}, 0f00000000;" for i in range(count))
    llvm.inline_asm(
        llvm.StructType.get_literal([T.f32()] * count),
        operands,
        "\n".join(body_lines) + "\n" + clear_outputs,
        ",".join(["=f"] * count + ["l"] + ["f"] * count),
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
386
+
387
+
388
@dsl_user_op
def sts_32_bf16(
    smem_ptr: cute.Pointer,
    v0: Float32,
    v1: Float32,
    *,
    loc=None,
    ip=None,
):
    """Store two bf16 values to shared memory as one 32-bit transaction.

    The address operand is passed with a 64-bit (``l``) constraint and narrowed
    to 32 bits inside the asm before the ``st.shared``.
    """
    ptx_lines = (
        "{",
        ".reg .u32 sa;",
        ".reg .b16 h0, h1;",
        ".reg .b32 p0;",
        "cvt.u32.u64 sa, $0;",
        "cvt.rn.bf16.f32 h0, $1;",
        "cvt.rn.bf16.f32 h1, $2;",
        "mov.b32 p0, {h0, h1};",
        "st.shared.b32 [sa], p0;",
        "}",
    )
    # $0 is the address; $1 and $2 are the two values.
    operands = [smem_ptr.toint().ir_value(loc=loc, ip=ip)]
    operands += [Float32(v).ir_value(loc=loc, ip=ip) for v in (v0, v1)]
    llvm.inline_asm(
        None,
        operands,
        "".join(line + "\n" for line in ptx_lines),
        "l,f,f",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
422
+
423
+
424
@dsl_user_op
def sts_32_f16(
    smem_ptr: cute.Pointer,
    v0: Float32,
    v1: Float32,
    *,
    loc=None,
    ip=None,
):
    """Store two fp16 values to shared memory as one 32-bit transaction.

    The address operand is passed with a 64-bit (``l``) constraint and narrowed
    to 32 bits inside the asm before the ``st.shared``.
    """
    # $0 is the address; $1 and $2 are the two values.
    address = smem_ptr.toint().ir_value(loc=loc, ip=ip)
    first = Float32(v0).ir_value(loc=loc, ip=ip)
    second = Float32(v1).ir_value(loc=loc, ip=ip)
    statements = [
        "{",
        ".reg .u32 sa;",
        ".reg .f16 h0, h1;",
        ".reg .b32 p0;",
        "cvt.u32.u64 sa, $0;",
        "cvt.rn.f16.f32 h0, $1;",
        "cvt.rn.f16.f32 h1, $2;",
        "mov.b32 p0, {h0, h1};",
        "st.shared.b32 [sa], p0;",
        "}",
    ]
    llvm.inline_asm(
        None,
        [address, first, second],
        "\n".join(statements) + "\n",
        "l,f,f",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
458
+
459
+
460
@dsl_user_op
def stg_128_bf16(
    gmem_ptr: cute.Pointer,
    v0: Float32,
    v1: Float32,
    v2: Float32,
    v3: Float32,
    v4: Float32,
    v5: Float32,
    v6: Float32,
    v7: Float32,
    *,
    loc=None,
    ip=None,
):
    """Convert eight f32 values to bf16 (round-to-nearest) and store them as 128 bits.

    The eight halves are packed pairwise into four b32 words and written with a
    single ``st.global.v4.b32``. The eight zeroed f32 outputs ($0..$7) are
    dummies whose result is discarded.
    """
    values = (v0, v1, v2, v3, v4, v5, v6, v7)
    count = len(values)
    # Operand numbering: $0..$7 dummy outputs, $8 address, $9..$16 the values.
    operands = [gmem_ptr.toint().ir_value(loc=loc, ip=ip)]
    operands.extend(Float32(v).ir_value(loc=loc, ip=ip) for v in values)
    convert = "".join(f"cvt.rn.bf16.f32 h{i}, ${i + 9};\n" for i in range(count))
    pack = "".join(
        f"mov.b32 p{j}, {{h{2 * j}, h{2 * j + 1}}};\n" for j in range(count // 2)
    )
    clear_outputs = " ".join(f"mov.f32 ${i}, 0f00000000;" for i in range(count))
    asm_text = (
        "{\n"
        ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n"
        ".reg .b32 p0, p1, p2, p3;\n"
        + convert
        + pack
        + "st.global.v4.b32 [$8], {p0, p1, p2, p3};\n"
        "}\n"
        + clear_outputs
    )
    llvm.inline_asm(
        llvm.StructType.get_literal([T.f32()] * count),
        operands,
        asm_text,
        ",".join(["=f"] * count + ["l"] + ["f"] * count),
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
518
+
519
+
520
@dsl_user_op
def stg_128_bf16_cs(
    gmem_ptr: cute.Pointer,
    v0: Float32,
    v1: Float32,
    v2: Float32,
    v3: Float32,
    v4: Float32,
    v5: Float32,
    v6: Float32,
    v7: Float32,
    *,
    loc=None,
    ip=None,
):
    """Convert eight f32 values to bf16 and store them as 128 bits with ``.cs``.

    Same as ``stg_128_bf16`` except that the store is ``st.global.cs.v4.b32``.
    The eight zeroed f32 outputs ($0..$7) are dummies whose result is discarded.
    """
    inputs = [v0, v1, v2, v3, v4, v5, v6, v7]
    n = len(inputs)
    # Operand numbering: $0..$7 dummy outputs, $8 address, $9..$16 the values.
    asm_operands = [gmem_ptr.toint().ir_value(loc=loc, ip=ip)] + [
        Float32(v).ir_value(loc=loc, ip=ip) for v in inputs
    ]
    lines = [
        "{",
        ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;",
        ".reg .b32 p0, p1, p2, p3;",
    ]
    lines += [f"cvt.rn.bf16.f32 h{i}, ${i + 9};" for i in range(n)]
    lines += [f"mov.b32 p{j}, {{h{2 * j}, h{2 * j + 1}}};" for j in range(n // 2)]
    lines += ["st.global.cs.v4.b32 [$8], {p0, p1, p2, p3};", "}"]
    zeroing = " ".join(f"mov.f32 ${i}, 0f00000000;" for i in range(n))
    llvm.inline_asm(
        llvm.StructType.get_literal([T.f32()] * n),
        asm_operands,
        "".join(line + "\n" for line in lines) + zeroing,
        ",".join(["=f"] * n + ["l"] + ["f"] * n),
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
578
+
579
+
580
@dsl_user_op
def stg_128_f16(
    gmem_ptr: cute.Pointer,
    v0: Float32,
    v1: Float32,
    v2: Float32,
    v3: Float32,
    v4: Float32,
    v5: Float32,
    v6: Float32,
    v7: Float32,
    *,
    loc=None,
    ip=None,
):
    """Convert eight f32 values to fp16 (round-to-nearest) and store them as 128 bits.

    The eight halves are packed pairwise into four b32 words and written with a
    single ``st.global.v4.b32``. The eight zeroed f32 outputs ($0..$7) are
    dummies whose result is discarded.
    """
    values = (v0, v1, v2, v3, v4, v5, v6, v7)
    count = len(values)
    # Operand numbering: $0..$7 dummy outputs, $8 address, $9..$16 the values.
    operands = [gmem_ptr.toint().ir_value(loc=loc, ip=ip)]
    operands += [Float32(v).ir_value(loc=loc, ip=ip) for v in values]
    convert = "".join(f"cvt.rn.f16.f32 h{i}, ${i + 9};\n" for i in range(count))
    pack = "".join(
        f"mov.b32 p{j}, {{h{2 * j}, h{2 * j + 1}}};\n" for j in range(count // 2)
    )
    clear_outputs = " ".join(f"mov.f32 ${i}, 0f00000000;" for i in range(count))
    asm_text = (
        "{\n"
        ".reg .f16 h0, h1, h2, h3, h4, h5, h6, h7;\n"
        ".reg .b32 p0, p1, p2, p3;\n"
        + convert
        + pack
        + "st.global.v4.b32 [$8], {p0, p1, p2, p3};\n"
        "}\n"
        + clear_outputs
    )
    llvm.inline_asm(
        llvm.StructType.get_literal([T.f32()] * count),
        operands,
        asm_text,
        ",".join(["=f"] * count + ["l"] + ["f"] * count),
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
638
+
639
+
640
@dsl_user_op
def stg_128_f16_cs(
    gmem_ptr: cute.Pointer,
    v0: Float32,
    v1: Float32,
    v2: Float32,
    v3: Float32,
    v4: Float32,
    v5: Float32,
    v6: Float32,
    v7: Float32,
    *,
    loc=None,
    ip=None,
):
    """Convert eight f32 values to fp16 and store them as 128 bits with ``.cs``.

    Same as ``stg_128_f16`` except that the store is ``st.global.cs.v4.b32``.
    The eight zeroed f32 outputs ($0..$7) are dummies whose result is discarded.
    """
    inputs = [v0, v1, v2, v3, v4, v5, v6, v7]
    n = len(inputs)
    # Operand numbering: $0..$7 dummy outputs, $8 address, $9..$16 the values.
    asm_operands = [gmem_ptr.toint().ir_value(loc=loc, ip=ip)] + [
        Float32(v).ir_value(loc=loc, ip=ip) for v in inputs
    ]
    lines = [
        "{",
        ".reg .f16 h0, h1, h2, h3, h4, h5, h6, h7;",
        ".reg .b32 p0, p1, p2, p3;",
    ]
    lines += [f"cvt.rn.f16.f32 h{i}, ${i + 9};" for i in range(n)]
    lines += [f"mov.b32 p{j}, {{h{2 * j}, h{2 * j + 1}}};" for j in range(n // 2)]
    lines += ["st.global.cs.v4.b32 [$8], {p0, p1, p2, p3};", "}"]
    zeroing = " ".join(f"mov.f32 ${i}, 0f00000000;" for i in range(n))
    llvm.inline_asm(
        llvm.StructType.get_literal([T.f32()] * n),
        asm_operands,
        "".join(line + "\n" for line in lines) + zeroing,
        ",".join(["=f"] * n + ["l"] + ["f"] * n),
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
698
+
699
+
700
@dsl_user_op
def stg_128_fp8_e4m3_cs(
    gmem_ptr: cute.Pointer,
    v0: Float32,
    v1: Float32,
    v2: Float32,
    v3: Float32,
    v4: Float32,
    v5: Float32,
    v6: Float32,
    v7: Float32,
    v8: Float32,
    v9: Float32,
    v10: Float32,
    v11: Float32,
    v12: Float32,
    v13: Float32,
    v14: Float32,
    v15: Float32,
    *,
    loc=None,
    ip=None,
):
    """Convert 16 fp32 values to e4m3 and store them with one ``st.global.cs.v4.b32``.

    The emitted PTX converts the inputs pairwise with
    ``cvt.rn.satfinite.e4m3x2.f32`` (the odd-indexed value is the first source
    operand, the even-indexed value the second), packs the eight 16-bit
    results into four 32-bit registers and stores them at ``gmem_ptr``.

    The inline asm declares 16 f32 outputs that are all set to zero; the
    result of ``llvm.inline_asm`` is discarded and nothing is returned.
    """
    num_values = 16
    values = (v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15)

    # Operand numbering: $0..$15 are the dummy outputs, $16 is the address,
    # $17..$32 are v0..v15.
    result_type = llvm.StructType.get_literal([T.f32() for _ in range(num_values)])
    operands = [gmem_ptr.toint().ir_value(loc=loc, ip=ip)]
    operands.extend(Float32(v).ir_value(loc=loc, ip=ip) for v in values)

    convert_ptx = "".join(
        f"cvt.rn.satfinite.e4m3x2.f32 h{i}, ${18 + 2 * i}, ${17 + 2 * i};\n"
        for i in range(num_values // 2)
    )
    pack_ptx = "".join(
        f"mov.b32 p{i}, {{h{2 * i}, h{2 * i + 1}}};\n" for i in range(num_values // 4)
    )
    zero_outputs_ptx = " ".join(f"mov.f32 ${i}, 0f00000000;" for i in range(num_values))
    ptx = (
        "{\n"
        ".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n"
        ".reg .b32 p0, p1, p2, p3;\n"
        + convert_ptx
        + pack_ptx
        + "st.global.cs.v4.b32 [$16], {p0, p1, p2, p3};\n"
        "}\n" + zero_outputs_ptx
    )
    constraints = ",".join(["=f"] * num_values + ["l"] + ["f"] * num_values)

    llvm.inline_asm(
        result_type,
        operands,
        ptx,
        constraints,
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
798
+ )
799
+
800
+
801
def convert_layout_from_tmem16x256b_to_acc_sm90(acc_layout: cute.Layout) -> cute.Layout:
    """Compose ``acc_layout`` with a view that splits mode 0 into two top-level modes.

    A compact layout is built from ``acc_layout.shape``; its mode-0 sub-modes
    ``[0][0]`` and ``[0][1]`` become the first two modes of the view, followed
    by mode 1 and every remaining mode in order.
    """
    compact = cute.make_layout(acc_layout.shape)
    shp = compact.shape
    strd = compact.stride
    view_shape = (shp[0][0], shp[0][1], shp[1], *shp[2:])
    view_stride = (strd[0][0], strd[0][1], strd[1], *strd[2:])
    view = cute.make_layout(view_shape, stride=view_stride)
    return cute.composition(acc_layout, view)
818
+
819
+
820
def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
    """Compose ``acc_layout`` with a view whose first two modes are nested pairs.

    Using a compact layout built from ``acc_layout.shape``:

    * the first mode of the view is ``(mode0[1], mode1)``;
    * the second mode is ``(mode0[0], *mode0[2:], mode2)``;
    * modes 3 and beyond are appended unchanged.
    """
    compact = cute.make_layout(acc_layout.shape)
    shp = compact.shape
    strd = compact.stride

    first_shape = (shp[0][1], shp[1])
    second_shape = (shp[0][0], *shp[0][2:], shp[2])
    first_stride = (strd[0][1], strd[1])
    second_stride = (strd[0][0], *strd[0][2:], strd[2])

    view = cute.make_layout(
        (first_shape, second_shape, *shp[3:]),
        stride=(first_stride, second_stride, *strd[3:]),
    )
    return cute.composition(acc_layout, view)
843
+
844
+
845
def make_16x256b_tensor_mn_view(tensor: cute.Tensor) -> cute.Tensor:
    """Return a tensor over the same iterator with the two layout conversions applied.

    The layout is first passed through
    ``convert_layout_from_tmem16x256b_to_acc_sm90`` and the result through
    ``convert_layout_acc_mn``.
    """
    sm90_style_layout = convert_layout_from_tmem16x256b_to_acc_sm90(tensor.layout)
    mn_layout = convert_layout_acc_mn(sm90_style_layout)
    return cute.make_tensor(tensor.iterator, mn_layout)
850
+
851
+
852
def real_col_to_stg128_fake_col(col: Int32) -> Int32:
    """Map a real column index to its permuted ("fake") column, in tiles of 16.

    Within a tile, the real offset ``2 * rank + 8 * hi + lo`` (``rank`` in
    0..3, ``hi`` and ``lo`` in 0..1) is mapped to ``4 * rank + 2 * hi + lo``.
    The tile base (``col // 16 * 16``) is preserved.
    """
    tile = col // Int32(16)
    offset = col - tile * Int32(16)
    pair_idx = offset // Int32(2)
    quad = pair_idx % Int32(4)
    slot = (pair_idx // Int32(4)) * Int32(2) + (offset % Int32(2))
    return tile * Int32(16) + quad * Int32(4) + slot
859
+
860
+
861
def stg128_fake_col_to_real_col(fake_col: Int32) -> Int32:
    """Map a permuted ("fake") column index back to the real column, in tiles of 16.

    Within a tile, the fake offset ``4 * rank + slot`` (``rank`` and ``slot``
    in 0..3) is mapped to ``2 * rank + 8 * (slot // 2) + slot % 2``.
    """
    tile = fake_col // Int32(16)
    offset = fake_col - tile * Int32(16)
    quad = offset // Int32(4)
    slot = offset % Int32(4)
    return tile * Int32(16) + quad * Int32(2) + (slot // Int32(2)) * Int32(8) + (slot % Int32(2))
867
+
868
+
869
def real_col_to_stg128_half_fake_col(col: Int32) -> Int32:
    """Map a real column index to its permuted ("fake") column, in tiles of 32.

    Within a tile, the real offset ``8 * group + 2 * lane + elem`` (``group``
    and ``lane`` in 0..3, ``elem`` in 0..1) is mapped to
    ``8 * lane + 2 * group + elem``.
    """
    tile = col // Int32(32)
    offset = col - tile * Int32(32)
    lane_idx = (offset % Int32(8)) // Int32(2)
    group_idx = offset // Int32(8)
    parity = offset % Int32(2)
    return tile * Int32(32) + lane_idx * Int32(8) + group_idx * Int32(2) + parity
876
+
877
+
878
def stg128_half_fake_col_to_real_col(fake_col: Int32) -> Int32:
    """Map a permuted ("fake") column index back to the real column, in tiles of 32.

    Within a tile, the fake offset ``8 * lane + 2 * group + elem`` is mapped
    to ``8 * group + 2 * lane + elem``.
    """
    tile = fake_col // Int32(32)
    offset = fake_col - tile * Int32(32)
    lane_idx = offset // Int32(8)
    within_lane = offset - lane_idx * Int32(8)
    group_idx = within_lane // Int32(2)
    parity = within_lane - group_idx * Int32(2)
    return tile * Int32(32) + group_idx * Int32(8) + lane_idx * Int32(2) + parity
886
+
887
+
888
def real_col_to_stg128_fp8_fake_col(col: Int32) -> Int32:
    """Map a real column index to its permuted ("fake") column, in tiles of 64.

    Within a tile, the real offset ``8 * group + 2 * lane + elem`` (``group``
    in 0..7, ``lane`` in 0..3, ``elem`` in 0..1) is mapped to
    ``16 * lane + 2 * group + elem``.
    """
    tile = col // Int32(64)
    offset = col - tile * Int32(64)
    lane_idx = (offset % Int32(8)) // Int32(2)
    group_idx = offset // Int32(8)
    parity = offset % Int32(2)
    return tile * Int32(64) + lane_idx * Int32(16) + group_idx * Int32(2) + parity
895
+
896
+
897
def stg128_fp8_fake_col_to_real_col(fake_col: Int32) -> Int32:
    """Map a permuted ("fake") column index back to the real column, in tiles of 64.

    Within a tile, the fake offset ``16 * lane + 2 * group + elem`` is mapped
    to ``8 * group + 2 * lane + elem``.
    """
    tile = fake_col // Int32(64)
    offset = fake_col - tile * Int32(64)
    lane_idx = offset // Int32(16)
    within_lane = offset - lane_idx * Int32(16)
    group_idx = within_lane // Int32(2)
    parity = within_lane - group_idx * Int32(2)
    return tile * Int32(64) + group_idx * Int32(8) + lane_idx * Int32(2) + parity
905
+
906
+
907
+ # Cluster & Bulk Async Ops
908
+
909
@dsl_user_op
def set_block_rank(
    smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
) -> Int32:
    """Map the given smem pointer to the address at another CTA rank in the cluster.

    Emits a ``mapa.shared::cluster.u32`` instruction and returns its 32-bit
    result wrapped in ``Int32``.
    """
    local_addr = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    peer_rank = peer_cta_rank_in_cluster.ir_value()
    mapped_addr = llvm.inline_asm(
        T.i32(),
        [local_addr, peer_rank],
        "mapa.shared::cluster.u32 $0, $1, $2;",
        "=r,r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return Int32(mapped_addr)
926
+
927
+
928
@dsl_user_op
def store_shared_remote_fp32x4(
    a: Float32,
    b: Float32,
    c: Float32,
    d: Float32,
    smem_ptr: cute.Pointer,
    mbar_ptr: cute.Pointer,
    peer_cta_rank_in_cluster: Int32,
    *,
    loc=None,
    ip=None,
) -> None:
    """Store four fp32 values into a peer CTA's shared memory with ``st.async``.

    Both ``smem_ptr`` and ``mbar_ptr`` are first mapped to the peer CTA via
    ``set_block_rank``; the values are then written with a single
    ``st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32``
    instruction that names the mapped mbarrier address.
    """
    remote_data_addr = set_block_rank(
        smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    remote_barrier_addr = set_block_rank(
        mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    payload = [Float32(value).ir_value(loc=loc, ip=ip) for value in (a, b, c, d)]
    ptx = (
        "{\n\t"
        ".reg .v4 .f32 abcd;\n\t"
        "mov.f32 abcd.x, $2;\n\t"
        "mov.f32 abcd.y, $3;\n\t"
        "mov.f32 abcd.z, $4;\n\t"
        "mov.f32 abcd.w, $5;\n\t"
        "st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, [$1];\n\t"
        "}\n"
    )
    llvm.inline_asm(
        None,
        [remote_data_addr, remote_barrier_addr, *payload],
        ptx,
        "r,r,f,f,f,f",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
970
+
971
+
972
@dsl_user_op
def cpasync_bulk_s2cluster(
    smem_src_ptr: cute.Pointer,
    smem_dst_ptr: cute.Pointer,
    mbar_ptr: cute.Pointer,
    size: int | Int32,
    peer_cta_rank_in_cluster: Int32,
    *,
    loc=None,
    ip=None,
):
    """Issue a bulk async copy from this CTA's shared memory to a peer CTA's.

    The destination and mbarrier pointers are mapped to the peer CTA with
    ``set_block_rank``; the source pointer is used as-is. ``size`` is passed
    as the instruction's size operand.
    """
    local_src_addr = smem_src_ptr.toint(loc=loc, ip=ip).ir_value()
    remote_dst_addr = set_block_rank(
        smem_dst_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    remote_barrier_addr = set_block_rank(
        mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    size_operand = Int32(size).ir_value(loc=loc, ip=ip)
    llvm.inline_asm(
        None,
        [remote_dst_addr, local_src_addr, remote_barrier_addr, size_operand],
        "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [$0], [$1], $3, [$2];",
        "r,r,r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
1002
+
1003
+
1004
@dsl_user_op
def cpasync_bulk_g2s(
    gmem_ptr: cute.Pointer,
    smem_ptr: cute.Pointer,
    tma_bar_ptr: cute.Pointer,
    size: int | Int32,
    *,
    loc=None,
    ip=None,
):
    """Issue a bulk async copy from global memory into this CTA's shared memory.

    Emits ``cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes``
    with ``smem_ptr`` as destination, ``gmem_ptr`` as source, ``size`` as the
    size operand and ``tma_bar_ptr`` as the mbarrier.
    """
    src_addr = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
    dst_addr = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    barrier_addr = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value()
    size_operand = Int32(size).ir_value()
    llvm.inline_asm(
        None,
        [src_addr, dst_addr, barrier_addr, size_operand],
        "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];",
        "l,r,r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
1026
+
1027
+
1028
@dsl_user_op
def cpasync_reduce_bulk_add_f32(
    smem_ptr: cute.Pointer,
    gmem_ptr: cute.Pointer,
    store_bytes: int | Int32,
    *,
    loc=None,
    ip=None,
):
    """Issue a bulk async f32 add-reduction from shared memory into global memory.

    Emits ``cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32`` with
    ``gmem_ptr`` as destination, ``smem_ptr`` as source and ``store_bytes`` as
    the size operand.
    """
    # NOTE: an ``.L2::cache_hint`` variant (extra 64-bit operand
    # 0x14F0000000000000, EVICT_LAST, constraints "l,r,r,l") was tried here
    # previously and is currently disabled.
    src_addr = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    operands = [gmem_ptr.llvm_ptr, src_addr, Int32(store_bytes).ir_value()]
    llvm.inline_asm(
        None,
        operands,
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;",
        "l,r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
1051
+
1052
+
1053
def cpasync_bulk_get_copy_fn(
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    """Build a closure that copies ``src_tensor`` to ``dst_tensor`` via ``cpasync_bulk_g2s``.

    Unless ``single_stage`` is set, all modes except the last are grouped
    into mode 0 and the returned function takes ``(src_idx, dst_idx)`` to
    select a slice of the last mode on each side. With ``single_stage`` all
    modes are grouped and the returned function takes no indices.

    Extra keyword arguments given here and at call time are forwarded to
    ``cpasync_bulk_g2s``.
    """
    # Number of trailing modes kept out of the grouped mode 0.
    trailing_modes = 0 if single_stage else 1
    src_group_rank = const_expr(cute.rank(src_tensor) - trailing_modes)
    dst_group_rank = const_expr(cute.rank(dst_tensor) - trailing_modes)
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    src = cute.group_modes(src_tensor, 0, src_group_rank)
    dst = cute.group_modes(dst_tensor, 0, dst_group_rank)

    def copy_bulk(src_idx, dst_idx, **new_kwargs):
        # Bytes in one slice: element count of all modes but the last.
        num_bytes = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8)
        src_slice = src[None, src_idx]
        dst_slice = dst[None, dst_idx]
        cpasync_bulk_g2s(
            src_slice.iterator,
            dst_slice.iterator,
            size=num_bytes,
            **new_kwargs,
            **kwargs,
        )

    def copy_bulk_single_stage(**new_kwargs):
        num_bytes = const_expr(cute.size(src.shape) * src.element_type.width // 8)
        cpasync_bulk_g2s(src.iterator, dst.iterator, size=num_bytes, **new_kwargs, **kwargs)

    if const_expr(not single_stage):
        return copy_bulk
    return copy_bulk_single_stage
1084
+
1085
+
1086
+ # TMA Copy Adapters
1087
+
1088
def tma_get_copy_fn(
    atom: cute.CopyAtom,
    cta_coord: cute.Coord,
    cta_layout: cute.Layout,
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    filter_zeros: bool = False,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    """Partition a TMA copy and return ``(copy_fn, s, g)``.

    The shared-memory side is whichever of ``src_tensor``/``dst_tensor`` has
    a ``cute.Pointer`` iterator in the smem address space (checked on the
    source; otherwise the destination is taken as the smem side). Both
    tensors have their leading modes grouped and are passed to
    ``cpasync.tma_partition``; ``s`` and ``g`` are its smem and gmem results,
    optionally passed through ``cute.filter_zeros``.

    ``copy_fn`` takes ``(src_idx, dst_idx)`` unless ``single_stage`` is set,
    in which case it takes no indices. Extra keyword arguments given here and
    at call time are forwarded to ``cute.copy``.
    """
    src_is_smem = const_expr(
        isinstance(src_tensor.iterator, cute.Pointer)
        and src_tensor.memspace == cute.AddressSpace.smem
    )
    if src_is_smem:
        smem_tensor, gmem_tensor = src_tensor, dst_tensor
    else:
        smem_tensor, gmem_tensor = dst_tensor, src_tensor

    # Number of trailing modes kept out of the grouped mode 0.
    trailing_modes = 0 if single_stage else 1
    smem_group_rank = const_expr(cute.rank(smem_tensor) - trailing_modes)
    gmem_group_rank = const_expr(cute.rank(gmem_tensor) - trailing_modes)
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    s, g = cpasync.tma_partition(
        atom,
        cta_coord,
        cta_layout,
        cute.group_modes(smem_tensor, 0, smem_group_rank),
        cute.group_modes(gmem_tensor, 0, gmem_group_rank),
    )
    if const_expr(filter_zeros):
        s = cute.filter_zeros(s)
        g = cute.filter_zeros(g)

    if src_is_smem:
        src, dst = s, g
    else:
        src, dst = g, s

    def copy_tma(src_idx, dst_idx, **new_kwargs):
        cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)

    def copy_tma_single_stage(**new_kwargs):
        cute.copy(atom, src, dst, **new_kwargs, **kwargs)

    copy_fn = copy_tma if const_expr(not single_stage) else copy_tma_single_stage
    return copy_fn, s, g
1125
+
1126
+
1127
def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
    """Wrap ``copy`` so the destination stage and barrier come from a pipeline state.

    The returned function takes ``(src_idx, producer_state, **kwargs)`` and
    calls ``copy`` with ``dst_idx=producer_state.index`` and
    ``tma_bar_ptr=pipeline.producer_get_barrier(producer_state)``.
    """

    def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
        stage_index = producer_state.index
        stage_barrier = pipeline.producer_get_barrier(producer_state)
        copy(
            src_idx=src_idx,
            dst_idx=stage_index,
            tma_bar_ptr=stage_barrier,
            **new_kwargs,
        )

    return copy_fn
1137
+
1138
+
1139
# Names exported by ``from <this module> import *``.
__all__ = [
    "atomic_add_broadcast_i32",
    "atomic_add_fp32x4",
    "atomic_add_i32",
    "convert_layout_acc_mn",
    "convert_layout_from_tmem16x256b_to_acc_sm90",
    "copy",
    "cpasync_bulk_g2s",
    "cpasync_bulk_get_copy_fn",
    "cpasync_bulk_s2cluster",
    "cpasync_reduce_bulk_add_f32",
    "cvt_copy",
    "get_copy_atom",
    "load_s2r",
    "make_16x256b_tensor_mn_view",
    "make_tmem_copy",
    "real_col_to_stg128_fake_col",
    "real_col_to_stg128_fp8_fake_col",
    "real_col_to_stg128_half_fake_col",
    "set_block_rank",
    "stg128_fake_col_to_real_col",
    "stg128_fp8_fake_col_to_real_col",
    "stg128_half_fake_col_to_real_col",
    "stg_128",
    "stg_128_cs",
    "stg_128_bf16",
    "stg_128_bf16_cs",
    "stg_128_f16",
    "stg_128_f16_cs",
    "stg_128_fp8_e4m3_cs",
    "stg_32_fp8_e4m3",
    "stg_64_bf16",
    "stg_64_f16",
    "sts_32_bf16",
    "sts_32_f16",
    "store_shared_remote_fp32x4",
    "tiled_copy_1d",
    "tiled_copy_2d",
    "tma_get_copy_fn",
    "tma_producer_copy_fn",
]
build/torch211-cxx11-cu128-x86_64-linux/src/common/cute_dsl_utils.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ import logging
5
+ import os
6
+ import pathlib
7
+ import time
8
+ from typing import Tuple
9
+ from functools import partial, lru_cache
10
+ from dataclasses import dataclass, fields
11
+
12
+ import torch
13
+
14
+ logger = logging.getLogger("minimax")
15
+
16
+ try:
17
+ from triton.tools.disasm import extract
18
+ except ImportError:
19
+ extract = None
20
+
21
+ import cutlass
22
+ import cutlass.cute as cute
23
+ from cutlass.base_dsl.typing import JitArgument
24
+ from cutlass.cutlass_dsl import NumericMeta
25
+ from cutlass.cute.runtime import from_dlpack
26
+
27
# Field values of these types are treated as static by ArgumentsBase: they are
# skipped when collecting C pointers / MLIR types and passed through unchanged.
StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None))


# References to the unpatched originals, captured before cute.compile is
# replaced further down in this module.
load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
cute_compile_og = cute.compile


# Maps torch dtypes to the corresponding cutlass numeric types.
torch2cute_dtype_map = {
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
    torch.float32: cutlass.Float32,
    torch.float8_e4m3fn: cutlass.Float8E4M3FN,
}
40
+
41
+
42
@lru_cache
def get_max_active_clusters(cluster_size):
    """Return ``HardwareInfo().get_max_active_clusters`` for ``cluster_size`` (cached)."""
    hardware = cutlass.utils.HardwareInfo()
    return hardware.get_max_active_clusters(cluster_size=cluster_size)
45
+
46
+
47
@lru_cache
def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
    """Return ``torch.cuda.get_device_capability(device)``, cached per ``device`` argument."""
    capability = torch.cuda.get_device_capability(device)
    return capability
50
+
51
+
52
@dataclass
class ArgumentsBase(JitArgument):
    """Dataclass base implementing the JIT-argument protocol field by field.

    Fields whose current value is an instance of ``StaticTypes`` are treated
    as static and are carried over unchanged; every other field is delegated
    to the value's own ``__c_pointers__`` / ``__get_mlir_types__`` when it
    provides them.
    """

    def __c_pointers__(self):
        """Concatenate the C pointers of all non-static fields that expose them."""
        pointers = []
        for spec in fields(self):
            value = getattr(self, spec.name)
            if isinstance(value, StaticTypes):
                continue
            if hasattr(value, "__c_pointers__"):
                pointers.extend(value.__c_pointers__())
        return pointers

    def __get_mlir_types__(self):
        """Concatenate the MLIR types of all non-static fields.

        Also records, in ``self._values_pos``, how many MLIR values each
        non-static field contributes (0 for fields without
        ``__get_mlir_types__``); ``__new_from_mlir_values__`` relies on it.
        """
        mlir_types = []
        self._values_pos = []
        for spec in fields(self):
            value = getattr(self, spec.name)
            if isinstance(value, StaticTypes):
                continue
            if not hasattr(value, "__get_mlir_types__"):
                self._values_pos.append(0)
                continue
            value_types = value.__get_mlir_types__()
            mlir_types.extend(value_types)
            self._values_pos.append(len(value_types))
        return mlir_types

    def __new_from_mlir_values__(self, values):
        """Rebuild an instance from a flat list of MLIR values.

        ``values`` is consumed front to back using the per-field counts stored
        by ``__get_mlir_types__``; static fields are copied from ``self``.
        """
        static_kwargs = {}
        dynamic_kwargs = {}
        for spec in fields(self):
            value = getattr(self, spec.name)
            if isinstance(value, StaticTypes):
                static_kwargs[spec.name] = value
            else:
                dynamic_kwargs[spec.name] = value

        cursor = 0
        for name, count in zip(list(dynamic_kwargs), self._values_pos):
            dynamic_kwargs[name] = cutlass.new_from_mlir_values(
                dynamic_kwargs[name], values[cursor : cursor + count]
            )
            cursor += count
        return self.__class__(**dynamic_kwargs, **static_kwargs)
86
+
87
+
88
def load_cubin_module_data_patched(cubin_data, filepath):
    """Write ``cubin_data`` to ``filepath``, then load it with the original loader."""
    dump_path = pathlib.Path(filepath)
    dump_path.write_bytes(cubin_data)
    module = load_cubin_module_data_og(cubin_data)
    return module
91
+
92
+
93
def cute_compile_patched(*args, **kwargs):
    """A patched version of cute.compile.

    Behaviour:
    - If ``CUTE_CUBIN_PATH`` is set, the cubin loader is temporarily replaced
      so the cubin bytes are written to that path; when triton's ``extract``
      is importable, the disassembly is additionally written next to it with
      the ``.annotated.sass`` suffix.
    - Logs JIT compile wall time at DEBUG level via the ``minimax`` logger,
      tagged with the kernel's class name when available. Enable with
      ``logging.getLogger("minimax").setLevel(logging.DEBUG)`` or env
      ``MINIMAX_LOG_COMPILE=1``; this helps distinguish a slow JIT from a
      kernel hang.

    The original cubin loader is restored even when compilation raises, so a
    failed compile cannot leave the dumping loader installed for later calls.

    Returns:
        Whatever the original ``cute.compile`` returns.
    """
    cubin_path = os.getenv("CUTE_CUBIN_PATH", None)
    if cubin_path is not None:
        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
            load_cubin_module_data_patched, filepath=cubin_path
        )
    kernel_obj = args[0] if args else kwargs.get("op")
    kernel_name = type(kernel_obj).__name__ if kernel_obj is not None else "<unknown>"
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments during a multi-second compile.
    t0 = time.perf_counter()
    try:
        output = cute_compile_og(*args, **kwargs)
        dt = time.perf_counter() - t0
    finally:
        # Always undo the loader patch, including on compile failure.
        if cubin_path is not None:
            cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
    logger.debug("[%s] compiled in %.1fs", kernel_name, dt)
    if cubin_path is not None and extract is not None:
        sass = extract(cubin_path, None)
        pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
    return output
121
+
122
+
123
# Opt-in compile logging: MINIMAX_LOG_COMPILE=1 attaches a stream handler to
# the module logger (only if it has no handlers yet) and sets it to DEBUG.
if os.getenv("MINIMAX_LOG_COMPILE", "0") == "1":
    if not logger.handlers:
        _h = logging.StreamHandler()
        _h.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s"))
        logger.addHandler(_h)
    logger.setLevel(logging.DEBUG)


# Monkey-patch cute.compile so every JIT compile across the repo gets timed
# without touching individual call sites. Idempotent: only patches once.
# NOTE(review): the identity check only guards against this same function
# object; a reload of this module would define a new function and capture the
# previously patched cute.compile as the "original" - confirm reloads are not
# expected.
if cute.compile is not cute_compile_patched:
    cute.compile = cute_compile_patched
135
+
136
+
137
def assume_strides_aligned(t):
    """Return ``t``'s strides with 128-bit divisibility assumed on all but the last.

    Each non-``int`` stride except the last is wrapped in ``cute.assume`` with
    ``divby = 128 // element width in bits``. Python ``int`` strides (e.g.
    stride=0 from GQA expand) are static and kept as they are, and the last
    stride is always passed through untouched.
    """
    elems_per_128_bits = 128 // t.element_type.width
    leading = []
    for stride in t.stride[:-1]:
        if isinstance(stride, int):
            leading.append(stride)
        else:
            leading.append(cute.assume(stride, divby=elems_per_128_bits))
    return (*leading, t.stride[-1])
146
+
147
+
148
def assume_tensor_aligned(t):
    """Rebuild ``t`` with the strides from ``assume_strides_aligned``; ``None`` passes through."""
    if t is None:
        return None
    aligned_layout = cute.make_layout(t.shape, stride=assume_strides_aligned(t))
    return cute.make_tensor(t.iterator, aligned_layout)
153
+
154
+
155
def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True):
    """Convert a torch tensor to a cute tensor with a dynamic layout.

    The detached tensor goes through ``from_dlpack``. With ``fully_dynamic``
    the layout is marked dynamic without a leading dimension; otherwise
    ``leading_dim`` is used, where -1 means the last dimension (``t.ndim - 1``).
    """
    cute_tensor = from_dlpack(
        t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi
    )
    if fully_dynamic:
        return cute_tensor.mark_layout_dynamic()
    resolved_leading_dim = t.ndim - 1 if leading_dim == -1 else leading_dim
    return cute_tensor.mark_layout_dynamic(leading_dim=resolved_leading_dim)
163
+
164
+
165
def to_cute_aux_tensor(t, enable_tvm_ffi=True):
    """Convert a torch tensor to a cute tensor, tailored to FlexAttention aux tensors.

    Alignment and leading dimension are read from the optional
    ``__assumed_align__`` and ``__leading_dim__`` attributes of ``t`` (both
    default to ``None``). When no leading dimension is given the layout is
    marked fully dynamic.
    """
    align_hint = getattr(t, "__assumed_align__", None)
    leading_dim_hint = getattr(t, "__leading_dim__", None)
    return to_cute_tensor(
        t,
        assumed_align=align_hint,
        leading_dim=leading_dim_hint,
        fully_dynamic=leading_dim_hint is None,
        enable_tvm_ffi=enable_tvm_ffi,
    )
181
+
182
+
183
def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]:
    """Return one bool per dimension: True where the stride is 0 (broadcast).

    This is useful for compile keys since CuTe's mark_layout_dynamic() keeps
    stride=0 as static, meaning kernels compiled with different broadcast
    patterns are not interchangeable.
    """
    flags = []
    for step in tensor.stride():
        flags.append(step == 0)
    return tuple(flags)
build/torch211-cxx11-cu128-x86_64-linux/src/common/fast_math.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+ from cutlass import Int32
7
+
8
+
9
@cute.jit
def clz(x: Int32) -> Int32:
    """Count the leading zero bits of ``x``; returns 32 when no bit is set.

    Bits are tested from bit 31 downwards and the index of the first set bit
    found is kept.
    """
    # Early exit is not supported yet, so the loop always runs all 32
    # iterations and a flag freezes the result after the first hit.
    leading_zeros = Int32(32)
    found = False
    for bit in cutlass.range(32):
        if ((1 << (31 - bit)) & x) and not found:
            leading_zeros = Int32(bit)
            found = True
    return leading_zeros
build/torch211-cxx11-cu128-x86_64-linux/src/common/mask.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ from typing import Callable, Optional, TypeAlias
5
+ from dataclasses import dataclass
6
+
7
+ import cutlass
8
+ import cutlass.cute as cute
9
+ from cutlass import Float32, Int32, Uint32, const_expr
10
+
11
+ from ...src.common import utils as utils
12
+ from ...src.common.seqlen_info import SeqlenInfoQK
13
+
14
# A mask generator takes a chunk index and returns the Uint32 bitmask for it.
MaskGenFn: TypeAlias = Callable[[int], Uint32]
# Number of columns handled per bitmask chunk.
MASK_R2P_CHUNK_SIZE: int = 32
16
+
17
+
18
@cute.jit
def r2p_bitmask_below(limit: Int32, s: int) -> Uint32:
    """Bitmask for chunk ``s``: all-ones shifted by how far the chunk end exceeds ``limit``.

    The shift amount is ``max((s + 1) * MASK_R2P_CHUNK_SIZE - limit, 0)`` and
    is applied to ``0xFFFFFFFF`` through ``utils.shr_u32``.
    """
    chunk_end = (s + 1) * MASK_R2P_CHUNK_SIZE
    shift = max(chunk_end - limit, 0)
    return utils.shr_u32(Uint32(0xFFFFFFFF), Uint32(shift))
22
+
23
+
24
@cute.jit
def r2p_bitmask_above(limit: Int32, s: int) -> Uint32:
    """Bitmask for chunk ``s``: all-ones shifted by how far ``limit`` exceeds the chunk start.

    The shift amount is ``max(limit - s * MASK_R2P_CHUNK_SIZE, 0)`` and is
    applied to ``0xFFFFFFFF`` through ``utils.shl_u32``.
    """
    chunk_start = s * MASK_R2P_CHUNK_SIZE
    shift = max(limit - chunk_start, 0)
    return utils.shl_u32(Uint32(0xFFFFFFFF), Uint32(shift))
28
+
29
+
30
@cute.jit
def mask_r2p_lambda(
    X: cute.Tensor,
    mask_gen_fn: cutlass.Constexpr[MaskGenFn],
    rank1: bool = False,
) -> None:
    """Set elements of ``X`` to ``-inf`` wherever the generated bitmask has a 0 bit.

    Columns are processed in chunks of ``MASK_R2P_CHUNK_SIZE``;
    ``mask_gen_fn(chunk)`` supplies one bitmask per chunk and bit ``j`` of it
    governs column ``chunk * MASK_R2P_CHUNK_SIZE + j``. With ``rank1`` the
    tensor is indexed by column only; otherwise every row of mode 0 is
    updated for that column.
    """
    ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape))
    for chunk in cutlass.range_constexpr(cute.ceil_div(ncol, MASK_R2P_CHUNK_SIZE)):
        chunk_mask = mask_gen_fn(chunk)
        # The last chunk may be shorter than a full chunk.
        for bit in cutlass.range_constexpr(
            min(MASK_R2P_CHUNK_SIZE, ncol - chunk * MASK_R2P_CHUNK_SIZE)
        ):
            keep = cutlass.Boolean(chunk_mask & (Uint32(1) << bit))
            col = chunk * MASK_R2P_CHUNK_SIZE + bit
            if const_expr(rank1):
                X[col] = X[col] if keep else -Float32.inf
            else:
                for row in cutlass.range_constexpr(cute.size(X.shape[0])):
                    X[row, col] = X[row, col] if keep else -Float32.inf
47
+
48
+
49
@cute.jit
def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32:
    """Map a row index ``x`` to ``x // p * num_rep + min(x % p, num_rep)`` with ``p = num_rep * num_wg``."""
    period = num_rep * num_wg
    return x // period * num_rep + min(x % period, num_rep)
52
+
53
+
54
@dataclass(frozen=True)
class AttentionMask:
    """Applies sequence-length and causal masking to attention score fragments.

    Masked entries are overwritten with ``-inf``. ``tile_m``/``tile_n`` are the
    tile extents used to turn block indices into row/column offsets, and
    ``seqlen_info`` supplies ``seqlen_q``/``seqlen_k``.
    """

    tile_m: cutlass.Constexpr[int]
    tile_n: cutlass.Constexpr[int]
    seqlen_info: SeqlenInfoQK
    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
    swap_AB: cutlass.Constexpr[bool] = False

    @property
    def seqlen_q(self) -> Int32:
        """Query sequence length, taken from ``seqlen_info``."""
        return self.seqlen_info.seqlen_q

    @property
    def seqlen_k(self) -> Int32:
        """Key sequence length, taken from ``seqlen_info``."""
        return self.seqlen_info.seqlen_k

    @cute.jit
    def apply_mask_sm100(
        self,
        acc_S: cute.Tensor,
        tScS_t2r: cute.Tensor,
        m_block: Int32,
        n_block: Int32,
        mask_seqlen: cutlass.Constexpr[bool],
        mask_causal: cutlass.Constexpr[bool],
        row_idx: Optional[Int32] = None,
        kv_valid_cols: Optional[Int32] = None,
        kv_block_col_start: Optional[Int32] = None,
    ) -> None:
        """Mask columns of ``acc_S`` at or beyond a per-call column limit.

        The limit starts at ``tile_n``. With ``mask_seqlen`` it becomes
        ``kv_valid_cols`` when given, else ``seqlen_k - n_block * tile_n``.
        With ``mask_causal`` a causal limit
        ``row + seqlen_k - seqlen_q - block_col_start + 1`` is computed and
        combined (by minimum when ``mask_seqlen`` is also set). The row comes
        from ``row_idx`` when given, otherwise from the first coordinate of
        ``tScS_t2r`` offset by ``m_block * tile_m`` and divided by
        ``qhead_per_kvhead_packgqa`` when that is greater than 1.

        Nothing is done when both flags are off or the limit is not below
        ``tile_n``.
        """
        if const_expr(not mask_seqlen and not mask_causal):
            return

        col_limit = Int32(self.tile_n)
        if const_expr(mask_seqlen):
            if const_expr(kv_valid_cols is not None):
                col_limit = kv_valid_cols
            else:
                col_limit = self.seqlen_k - n_block * Int32(self.tile_n)

        if const_expr(mask_causal):
            if const_expr(row_idx is None):
                row_axis = 0 if const_expr(not self.swap_AB) else 1
                q_row = tScS_t2r[0][row_axis] + m_block * Int32(self.tile_m)
                if const_expr(self.qhead_per_kvhead_packgqa > 1):
                    q_row = q_row // Int32(self.qhead_per_kvhead_packgqa)
            else:
                q_row = row_idx
            if const_expr(kv_block_col_start is not None):
                col_origin = kv_block_col_start
            else:
                col_origin = n_block * Int32(self.tile_n)
            causal_limit = (
                q_row + self.seqlen_k - self.seqlen_q
                - col_origin + Int32(1)
            )
            col_limit = (
                cutlass.min(col_limit, causal_limit)
                if const_expr(mask_seqlen)
                else causal_limit
            )

        if col_limit < Int32(self.tile_n):
            mask_r2p_lambda(
                acc_S,
                lambda s: r2p_bitmask_below(col_limit, s),
                rank1=True,
            )

    @cute.jit
    def apply_mask_sm100_transposed(
        self,
        acc_S: cute.Tensor,
        tScS_t2r: cute.Tensor,
        t0ScS_t2r: cute.Tensor,
        m_block: cutlass.Int32,
        n_block: cutlass.Int32,
        mask_seqlen: cutlass.Constexpr,
        mask_causal: cutlass.Constexpr,
        is_full_block: bool = False,
        check_m_boundary: bool = True,
        valid_tok_count: Optional[Int32] = None,
        q_idx_tile: Optional[cute.Tensor] = None,
        masked_tok_count: Optional[Int32] = None,
    ) -> None:
        """Mask ``acc_S`` using per-element coordinates from ``tScS_t2r``.

        ``is_full_block``, ``check_m_boundary`` and ``t0ScS_t2r`` are accepted
        but unused.

        When ``valid_tok_count`` is given, each element is masked if its token
        index (row // ``qhead_per_kvhead_packgqa``) is at or beyond
        ``valid_tok_count``; with ``mask_seqlen`` also if its key index is at
        or beyond ``seqlen_k``; and with ``mask_causal`` and ``q_idx_tile``,
        for tokens below ``masked_tok_count`` (0 when not given), if the key
        index exceeds ``q_idx_tile[token] + seqlen_k - seqlen_q``.

        Otherwise limits are derived from the first coordinate only: without
        ``mask_causal`` the whole fragment is masked when ``mask_seqlen`` is
        set and the column is out of range; with ``mask_causal`` a row limit
        is computed and applied through ``mask_r2p_lambda``.
        """
        del is_full_block, check_m_boundary
        del t0ScS_t2r
        row_axis = 0 if const_expr(not self.swap_AB) else 1
        col_axis = 1 if const_expr(not self.swap_AB) else 0

        if const_expr(valid_tok_count is not None):
            col_origin = n_block * Int32(self.tile_n)
            q_to_k_offset = self.seqlen_k - self.seqlen_q
            num_frag = const_expr(cute.size(acc_S.shape))
            for i in cutlass.range(num_frag, unroll_full=True):
                frag_row = tScS_t2r[i][row_axis]
                token = frag_row // Int32(self.qhead_per_kvhead_packgqa)
                acc_S[i] = -Float32.inf if token >= valid_tok_count else acc_S[i]
                if const_expr(mask_seqlen):
                    kv_col = col_origin + tScS_t2r[i][col_axis]
                    acc_S[i] = -Float32.inf if kv_col >= self.seqlen_k else acc_S[i]
                if const_expr(mask_causal):
                    if const_expr(q_idx_tile is not None):
                        causal_tokens = (
                            masked_tok_count
                            if const_expr(masked_tok_count is not None)
                            else Int32(0)
                        )
                        if token < causal_tokens:
                            q_pos = q_idx_tile[token]
                            kv_col = col_origin + tScS_t2r[i][col_axis]
                            acc_S[i] = (
                                -Float32.inf if kv_col > q_pos + q_to_k_offset else acc_S[i]
                            )
            return

        first_col = tScS_t2r[0][col_axis]
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - first_col

        if const_expr(not mask_causal):
            if const_expr(mask_seqlen) and seqlenk_col_limit <= 0:
                for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                    acc_S[i] = -cutlass.Float32.inf
            return

        first_row = tScS_t2r[0][row_axis]
        seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - first_row
        row_limit_top = seqlenq_row_limit - seqlenk_col_limit
        if const_expr(mask_seqlen) and seqlenk_col_limit <= 0:
            # Column is out of range: use tile_m as the row limit instead.
            row_limit_top = self.tile_m
        num_rep = cute.size(tScS_t2r, mode=[0])
        row_limit = row_to_r2p_idx(row_limit_top, num_rep, 2)
        mask_r2p_lambda(
            acc_S,
            lambda s: r2p_bitmask_above(row_limit, s),
            rank1=True,
        )
build/torch211-cxx11-cu128-x86_64-linux/src/common/mma_sm100_desc.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # The bit-field encodings, enum values, and descriptor layout below mirror the
5
+ # SM100 tcgen05 MMA instruction descriptor as documented and
6
+ # implemented in NVIDIA CUTLASS (BSD-3-Clause). The numeric values MUST stay
7
+ # identical to the hardware/ISA encodings; see the "Third-party licenses"
8
+ # section of README.md at the repo root for attribution.
9
+
10
+ from enum import IntEnum
11
+
12
+ import cutlass
13
+ import cutlass.cute as cute
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # Enumerations that match the HW encodings (values MUST stay identical)
17
+ # ---------------------------------------------------------------------------
18
+
19
+
20
class Major(IntEnum):  # matrix "layout" in the ISA docs
    """Operand major-ness selector; ``make_instr_desc`` packs it as one bit per operand."""

    K = 0
    MN = 1
23
+
24
+
25
class ScaleIn(IntEnum):  # negate flags
    """Per-operand negate flag; ``make_instr_desc`` packs it as one bit for A and one for B."""

    One = 0
    Neg = 1
28
+
29
+
30
class Saturate(IntEnum):
    """Saturate flag; ``make_instr_desc`` packs it as a single bit (bit 3)."""

    # Trailing underscores avoid clashing with the ``False``/``True`` keywords.
    False_ = 0
    True_ = 1
33
+
34
+
35
class CFormat(IntEnum):  # 2-bit field (bits 4-5)
    """Accumulator (C) element-format codes, as returned by ``to_C_format``."""

    F16 = 0
    F32 = 1
    S32 = 2
39
+
40
+
41
class F16F32Format(IntEnum):  # 3-bit field (A/B element type)
    """A/B element-format codes for the 16-bit float and TF32 operand types."""

    F16 = 0
    BF16 = 1
    TF32 = 2
45
+
46
+
47
class S8Format(IntEnum):
    """A/B element-format codes for the 8-bit integer operand types."""

    UINT8 = 0
    INT8 = 1
50
+
51
+
52
class MXF8F6F4Format(IntEnum):
    """A/B element-format codes for the 8/6/4-bit float operand types.

    Note that the value 2 is not assigned to any member.
    """

    E4M3 = 0
    E5M2 = 1
    E2M3 = 3
    E3M2 = 4
    E2M1 = 5
58
+
59
+
60
class MaxShift(IntEnum):
    """Max-shift selector; ``make_instr_desc`` packs it into the top two bits (30-31)."""

    NoShift = 0
    MaxShift8 = 1
    MaxShift16 = 2
    MaxShift32 = 3
65
+
66
+
67
+ # ---------------------------------------------------------------------------
68
+ # CUTLASS-type -> encoding helpers
69
+ # ---------------------------------------------------------------------------
70
+
71
+
72
def to_UMMA_format(cutlass_type) -> int:
    """
    Return the A/B operand element-format code for a CUTLASS scalar class.

    The class is matched by identity against the supported ``cutlass`` scalar
    types, in the order listed below; the first match wins.

    Raises:
        TypeError: if ``cutlass_type`` is none of the supported classes.
    """
    # (attribute name on ``cutlass``, encoding).  Attributes are resolved one
    # at a time inside the loop, so a type that is missing from the installed
    # CUTLASS build is only looked up when no earlier entry has matched.
    supported = (
        ("Int8", S8Format.INT8),
        ("Uint8", S8Format.UINT8),  # unsigned 8-bit (if available in your CUTLASS build)
        ("Float16", F16F32Format.F16),
        ("BFloat16", F16F32Format.BF16),
        ("TFloat32", F16F32Format.TF32),
        ("Float8E4M3FN", MXF8F6F4Format.E4M3),
        ("Float8E5M2", MXF8F6F4Format.E5M2),
    )
    for type_name, encoding in supported:
        if cutlass_type is getattr(cutlass, type_name):
            return encoding
    raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}")
95
+
96
+
97
def to_C_format(cutlass_type) -> int:
    """
    Return the accumulator element-format code for a CUTLASS scalar class.

    Only ``cutlass.Float16``, ``cutlass.Float32`` and ``cutlass.Int32`` are
    accepted; the class is matched by identity.

    Raises:
        TypeError: if ``cutlass_type`` is not one of the three classes above.
    """
    # Probed in order; each ``cutlass`` attribute is resolved lazily.
    supported = (
        ("Float16", CFormat.F16),
        ("Float32", CFormat.F32),
        ("Int32", CFormat.S32),
    )
    for type_name, encoding in supported:
        if cutlass_type is getattr(cutlass, type_name):
            return encoding
    raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}")
108
+
109
+
110
+ # ---------------------------------------------------------------------------
111
+ # The constructor – accepts only CUTLASS scalar classes
112
+ # ---------------------------------------------------------------------------
113
+
114
+
115
def make_instr_desc(
    a_type,  # CUTLASS scalar class, e.g. cutlass.Int8
    b_type,
    c_type,
    M: int,  # 64, 128 or 256
    N: int,  # 8 … 256 (multiple of 8)
    a_major: Major,
    b_major: Major,
    a_neg: ScaleIn = ScaleIn.One,
    b_neg: ScaleIn = ScaleIn.One,
    c_sat: Saturate = Saturate.False_,
    is_sparse: bool = False,
    max_shift: MaxShift = MaxShift.NoShift,
) -> int:
    """
    Assemble the 32-bit SM100 MMA instruction descriptor.

    ``a_type`` / ``b_type`` / ``c_type`` must be CUTLASS scalar classes (they
    are translated with ``to_UMMA_format`` / ``to_C_format``); integers are
    rejected by those helpers.

    Returns:
        The descriptor as a Python ``int`` masked to 32 bits.

    Raises:
        TypeError: an element type is not supported.
        ValueError: ``M`` is not 64/128/256, or ``N`` is not a multiple of 8
            within 8..256.
    """
    # Element formats are resolved before the shape checks, so an unsupported
    # type is reported ahead of a bad M/N.
    a_code = int(to_UMMA_format(a_type))
    b_code = int(to_UMMA_format(b_type))
    acc_code = int(to_C_format(c_type))
    fp8_operands = a_type in (cutlass.Float8E4M3FN, cutlass.Float8E5M2)

    if M not in (64, 128, 256):
        raise ValueError("M must be 64, 128 or 256")
    if N < 8 or N > 256 or (N & 7):
        raise ValueError("N must be a multiple of 8 in the range 8…256")

    m_units = M >> 4  # M in units of 16, 5-bit field
    n_units = N >> 3  # N in units of 8, 6-bit field

    # fmt: off
    # (value, mask, shift) for every field, lowest bit first.
    fields = (
        (0,                 0x3,  0),   # sparse_id2 (always 0 here)
        (int(is_sparse),    0x1,  2),   # sparse_flag
        (int(c_sat),        0x1,  3),   # saturate
        (acc_code,          0x3,  4),   # c_format
        (a_code,            0x7,  7),   # a_format
        (b_code,            0x7,  10),  # b_format
        (int(a_neg),        0x1,  13),  # a_negate
        (int(b_neg),        0x1,  14),  # b_negate
        (int(a_major),      0x1,  15),  # a_major
        (int(b_major),      0x1,  16),  # b_major
        (n_units,           0x3F, 17),  # n_dim (6 bits)
        # CUTLASS' tcgen05 lowering sets bit 23 for dense f8f6f4 MMAs; keep this
        # descriptor aligned with generated/reference SM100 FP8 kernels.
        (int(fp8_operands), 0x1,  23),
        (m_units,           0x1F, 24),  # m_dim (5 bits)
        (int(max_shift),    0x3,  30),  # max_shift (2 bits)
    )
    # fmt: on

    packed = 0
    for value, mask, shift in fields:
        packed |= (value & mask) << shift

    return packed & 0xFFFF_FFFF  # ensure 32-bit result
171
+
172
+
173
def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp):
    """Build the instruction descriptor for ``op`` via ``make_instr_desc``.

    Operand/accumulator dtypes and the M/N extents (``shape_mnk[0]`` and
    ``shape_mnk[1]``) are read from ``op``; every optional descriptor field
    keeps the ``make_instr_desc`` default.
    """
    operand_dtypes = (op.a_dtype, op.b_dtype, op.acc_dtype)
    extents_mn = (op.shape_mnk[0], op.shape_mnk[1])

    # Anything other than K-major is encoded as MN-major.
    if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K:
        a_major = Major.K
    else:
        a_major = Major.MN
    if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K:
        b_major = Major.K
    else:
        b_major = Major.MN

    return make_instr_desc(*operand_dtypes, *extents_mn, a_major, b_major)
183
+
184
+
185
class LayoutType(IntEnum):  # occupies the top-3 bits [61:64)
    """Swizzle-family code stored in the shared-memory descriptor by ``make_smem_desc_base``."""

    SWIZZLE_NONE = 0  # (a.k.a. "INTERLEAVE" in older docs)
    SWIZZLE_128B_BASE32B = 1
    SWIZZLE_128B = 2
    SWIZZLE_64B = 4
    SWIZZLE_32B = 6
    # values 3,5,7 are reserved / illegal for UMMA
192
+
193
+
194
+ # ---------------------------------------------------------------------------
195
+ # Helpers – figure out the SWIZZLE_* family from the tensor layout
196
+ # ---------------------------------------------------------------------------
197
+
198
+
199
def _layout_type(swizzle: cute.Swizzle) -> LayoutType:
    """Classify a swizzle by its ``(num_bits, num_base, num_shift)`` triple.

    Accepted triples:
      * base 4, shift 3, bits 0/1/2/3 -> NONE / 32B / 64B / 128B
      * base 5, shift 2, bits 2       -> 128B_BASE32B

    Raises:
        ValueError: the base/shift combination is not one of the above.
        KeyError: base 4 and shift 3, but ``num_bits`` is outside 0..3.
    """
    num_bits = swizzle.num_bits
    num_base = swizzle.num_base
    num_shift = swizzle.num_shift

    if num_base == 4:  # Swizzle<*,4,3>
        if num_shift != 3:
            raise ValueError("Unexpected swizzle shift – want S==3 for M==4")
        family_by_bits = {
            0: LayoutType.SWIZZLE_NONE,
            1: LayoutType.SWIZZLE_32B,
            2: LayoutType.SWIZZLE_64B,
            3: LayoutType.SWIZZLE_128B,
        }
        # An unknown bit count surfaces as KeyError from this lookup.
        return family_by_bits[num_bits]

    if num_base != 5:
        # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout
        raise ValueError("Unsupported swizzle triple for UMMA smem descriptor")

    # Swizzle<2,5,2> is the only legal triple for base 5.
    if (num_bits, num_shift) != (2, 2):
        raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B")
    return LayoutType.SWIZZLE_128B_BASE32B
218
+
219
+
220
def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int:
    """
    Convert a 2-D *shared-memory* Cute layout into the SM100 64-bit
    smem-descriptor, without the smem start address.
    layout must correspond to layout of an uint128 tensor.

    Args:
        layout: layout to describe; strides are read in units of uint128
            (16 bytes).
        swizzle: swizzle of the shared-memory tensor; selects the
            ``LayoutType`` through ``_layout_type``.
        major: ``Major.MN`` takes the MN-major path; any other value takes
            the K-major path (the test is ``major is Major.MN``).

    Returns:
        The descriptor as a Python ``int`` masked to 64 bits, with the
        start-address field (the low bits) left at zero.

    Raises:
        ValueError: the layout does not match the canonical form checked
            below, or ``SWIZZLE_128B_BASE32B`` is used on the K-major path.
        KeyError: propagated from ``_layout_type`` for an unsupported swizzle
            bit count.
    """
    # ------------------------------------------------------------------ meta
    layout_type = _layout_type(swizzle)  # resolve SWIZZLE_* family

    VERSION = 1  # bits 46–47
    LBO_MODE = 0  # bit 52
    BASE_OFFSET = 0  # bits 49–51 (CUTLASS always 0)

    # ---------------------------------------------------------- strides (units: uint128_t = 16 B)
    # Width of one swizzle atom along MN, in uint128 elements.
    swizzle_atom_mn_size = {
        LayoutType.SWIZZLE_NONE: 1,
        LayoutType.SWIZZLE_32B: 2,
        LayoutType.SWIZZLE_64B: 4,
        LayoutType.SWIZZLE_128B: 8,
        LayoutType.SWIZZLE_128B_BASE32B: 8,
    }[layout_type]

    if major is Major.MN:
        swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8
        # Split each mode into (atom, rest) so the four strides can be read off.
        canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size))
        if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):
            raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.")
        stride_00 = canonical_layout.stride[0][0]
        if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1:
            raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.")
        stride_10 = canonical_layout.stride[1][0]
        if stride_10 != swizzle_atom_mn_size:
            raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.")
        stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1]
        # The roles of the two outer strides are swapped for the non-swizzled layout.
        if layout_type is LayoutType.SWIZZLE_NONE:
            stride_byte_offset, leading_byte_offset = stride_01, stride_11
        else:
            stride_byte_offset, leading_byte_offset = stride_11, stride_01
    else:
        if layout_type == LayoutType.SWIZZLE_128B_BASE32B:
            raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K")
        if not cute.size(layout.shape[0]) % 8 == 0:
            raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.")
        # K-major always divides by a fixed (8, 2) tile.
        canonical_layout = cute.logical_divide(layout, (8, 2))
        if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):
            raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.")
        stride_00 = canonical_layout.stride[0][0]
        if stride_00 != swizzle_atom_mn_size:
            raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.")
        stride_10 = canonical_layout.stride[1][0]
        if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1:
            raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.")
        stride_01 = canonical_layout.stride[0][1]
        stride_byte_offset, leading_byte_offset = stride_01, stride_10

    # ------------------------------------------------------------------ pack
    # Each field is masked to its width before being shifted into place.
    desc = 0
    # leading_byte_offset_ [16:30)
    desc |= (leading_byte_offset & 0x3FFF) << 16
    # stride_byte_offset_ [32:46)
    desc |= (stride_byte_offset & 0x3FFF) << 32
    # version_ [46:48)
    desc |= (VERSION & 0x3) << 46
    # base_offset_ [49:52)
    desc |= (BASE_OFFSET & 0x7) << 49
    # lbo_mode_ [52:53)
    desc |= (LBO_MODE & 0x1) << 52
    # layout_type_ [61:64)
    desc |= (int(layout_type) & 0x7) << 61

    return desc & 0xFFFF_FFFF_FFFF_FFFF  # force 64-bit width
291
+
292
+
293
def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32:
    """Return the 14-bit start-address field (descriptor bits 0-13) for ``start_addr``.

    The integer address is truncated to its low 18 bits and the 4 least
    significant bits are then dropped.
    """
    address_bits = start_addr.toint()
    low_18_bits = address_bits & 0x3FFFF
    return low_18_bits >> 4
296
+
297
+
298
def smem_desc_base_from_tensor(sA: cute.Tensor, major: Major) -> int:
    """Build the address-less smem descriptor for tensor ``sA``.

    Mode 0 of the tensor's layout is recast from the element width to 128-bit
    units and handed to ``make_smem_desc_base`` together with the swizzle
    taken from the tensor's iterator type.
    """
    swizzle = sA.iterator.type.swizzle_type
    element_bits = sA.element_type.width
    layout_u128 = cute.recast_layout(128, element_bits, sA.layout[0])
    return make_smem_desc_base(layout_u128, swizzle, major)
build/torch211-cxx11-cu128-x86_64-linux/src/common/named_barrier.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ import enum
5
+
6
+
7
class NamedBarrierFwdSm100(enum.IntEnum):
    """Named-barrier ids used by the SM100 forward kernel.

    Ids are consecutive and start from 1 as barrier 0 is reserved for
    sync_threads().
    """

    Epilogue = 1
    TmemPtr = 2
    # One softmax-statistics barrier per index W0..W7.
    SoftmaxStatsW0 = 3
    SoftmaxStatsW1 = 4
    SoftmaxStatsW2 = 5
    SoftmaxStatsW3 = 6
    SoftmaxStatsW4 = 7
    SoftmaxStatsW5 = 8
    SoftmaxStatsW6 = 9
    SoftmaxStatsW7 = 10
    LoadWG = 11
    StoreEpilogue = 12
    KvLoad = 13
    KvDequantK = 14
    KvDequantV = 15
build/torch211-cxx11-cu128-x86_64-linux/src/common/pack_gqa.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """PackGQA primitives for GQA (grouped-query attention) tile layouts.
5
+
6
+ Contains:
7
+ - ``pack_gqa_layout`` / ``unpack_gqa_layout``: fold/unfold ``qhead_per_kvhead``
8
+ into the seqlen dimension of a tensor layout (zero-copy view).
9
+ - ``PackGQA``: base class with ``compute_ptr`` / ``load_Q`` / ``store_LSE`` /
10
+ ``store_O`` helpers for kernels that treat ``(qhead_per_kvhead × seqlen_q)``
11
+ as a single packed row dimension.
12
+ - ``PackGQAComb``: subclass used by the K2 combine kernel; adds ``load_LSE``
13
+ for coalesced GMEM→SMEM async copies when LSE_partial is laid out with H_q
14
+ innermost (stride-1).
15
+ """
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Optional
19
+
20
+ import cutlass
21
+ import cutlass.cute as cute
22
+ from cutlass import Float32, Int32, const_expr
23
+ from cutlass.cute import FastDivmodDivisor
24
+
25
+ from ...quack import layout_utils
26
+
27
+ from . import utils
28
+
29
+
30
def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx):
    """Return a view of ``T`` with ``qhead_per_kvhead`` folded into mode 0.

    Mode 0 (seqlen) becomes the nested mode ``(qhead_per_kvhead, seqlen_q)``
    and the head mode at ``head_idx`` shrinks to ``nheads_kv``.  Modes between
    them (1..head_idx-1, e.g. headdim for Q/O tensors) and modes after the head
    mode (e.g. batch) are carried over unchanged.  The iterator of ``T`` is
    reused, so no data is copied.

    For Q/O tensors (head_idx=2):
        (seqlen_q, headdim, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...)
    For LSE tensors (head_idx=1):
        (seqlen_q, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...)
    """
    stride_per_head = T.stride[head_idx]
    modes_before_head = range(1, head_idx)
    modes_after_head = range(head_idx + 1, len(T.shape))

    packed_shape = (
        (qhead_per_kvhead, T.shape[0]),
        *(T.shape[mode] for mode in modes_before_head),
        nheads_kv,
        *(T.shape[mode] for mode in modes_after_head),
    )
    packed_stride = (
        # Within a packed row the query head is the first sub-mode, seqlen the second.
        (stride_per_head, T.stride[0]),
        *(T.stride[mode] for mode in modes_before_head),
        # One KV head spans qhead_per_kvhead consecutive query heads.
        stride_per_head * qhead_per_kvhead,
        *(T.stride[mode] for mode in modes_after_head),
    )
    packed_layout = cute.make_layout(packed_shape, stride=packed_stride)
    return cute.make_tensor(T.iterator, packed_layout)
56
+
57
+
58
def unpack_gqa_layout(T, qhead_per_kvhead, head_idx):
    """Inverse of ``pack_gqa_layout``: unfold ``qhead_per_kvhead`` out of mode 0.

    Mode 0 of ``T`` must be the nested mode ``(qhead_per_kvhead, seqlen_q)``.
    The result has a flat seqlen mode 0 and a head mode at ``head_idx`` that is
    ``qhead_per_kvhead`` times larger.  Modes between them (1..head_idx-1) and
    modes after the head mode are carried over unchanged.  The iterator of
    ``T`` is reused, so no data is copied.

    For Q/O tensors (head_idx=2):
        ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) -> (seqlen_q, headdim, nheads, batch, ...)
    For LSE tensors (head_idx=1):
        ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) -> (seqlen_q, nheads, batch, ...)
    """
    stride_per_seq = T.stride[0][1]
    stride_per_head = T.stride[0][0]
    modes_before_head = range(1, head_idx)
    modes_after_head = range(head_idx + 1, len(T.shape))

    flat_shape = (
        T.shape[0][1],
        *(T.shape[mode] for mode in modes_before_head),
        T.shape[head_idx] * qhead_per_kvhead,
        *(T.shape[mode] for mode in modes_after_head),
    )
    flat_stride = (
        stride_per_seq,
        *(T.stride[mode] for mode in modes_before_head),
        stride_per_head,
        *(T.stride[mode] for mode in modes_after_head),
    )
    flat_layout = cute.make_layout(flat_shape, stride=flat_stride)
    return cute.make_tensor(T.iterator, flat_layout)
85
+
86
+
87
@dataclass
class PackGQA:
    """Helpers for kernels that treat ``(qhead_per_kvhead x seqlen_q)`` as one packed row dimension.

    A packed row index ``idx`` is split as ``m_idx = idx // qhead_per_kvhead``
    (sequence position) and ``h_idx = idx - m_idx * qhead_per_kvhead`` (query
    head within the group); see ``compute_ptr``.
    """

    # Number of packed rows handled per tile.
    m_block_size: cutlass.Constexpr[int]
    # Extent of the head dimension used for the identity/coordinate tensors.
    head_dim_padded: cutlass.Constexpr[int]
    # When true, copies along the head dimension are predicated.
    check_hdim_oob: cutlass.Constexpr[bool]
    # Query heads per KV head.  Used as an integer divisor/multiplier below,
    # hence annotated as int.
    qhead_per_kvhead: cutlass.Constexpr[int]

    @cute.jit
    def compute_ptr(
        self,
        tensor: cute.Tensor,
        cRows: cute.Tensor,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        threads_per_row: cutlass.Constexpr[int],
        num_threads: cutlass.Constexpr[int],
    ):
        """Compute, per thread, the integer addresses of the packed rows it is responsible for.

        Returns a register tensor of ``ceil_div(size(cRows), threads_per_row)``
        Int64 addresses into ``tensor``, indexed at ``((h_idx, m_idx),)``.
        """
        num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row)
        tPrPtr = cute.make_rmem_tensor(num_ptr_per_thread, cutlass.Int64)
        for i in cutlass.range_constexpr(num_ptr_per_thread):
            row = i * num_threads + cRows[tidx % threads_per_row][0]
            idx = block * self.m_block_size + row
            # Split the packed row index into (sequence position, head in group).
            m_idx = idx // self.qhead_per_kvhead
            h_idx = idx - m_idx * self.qhead_per_kvhead
            tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint()
        return tPrPtr

    @cute.jit
    def load_Q(
        self,
        mQ: cute.Tensor,  # ((qhead_per_kvhead, seqlen_q), headdim)
        sQ: cute.Tensor,  # (m_block_size, head_dim_padded)
        gmem_tiled_copy: cute.TiledCopy,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        """Copy one tile of packed Q rows from ``mQ`` into ``sQ``.

        Row addresses come from ``compute_ptr`` and are exchanged between
        threads with ``utils.shuffle_sync``.  Rows at or beyond
        ``seqlen * qhead_per_kvhead`` are skipped (nothing is written for
        them), and head-dimension predication is applied only when
        ``check_hdim_oob`` is set.
        """
        gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
        cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        tQsQ = gmem_thr_copy.partition_D(sQ)
        tQcQ = gmem_thr_copy.partition_S(cQ)
        t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ)
        tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1])
        tQcQ_row = tQcQ[0, None, 0]
        threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]
        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
        num_threads = gmem_tiled_copy.size
        tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads)
        for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
            # NOTE(review): presumably fetches the address held by lane
            # ``m % threads_per_row`` of the row group; confirm against
            # utils.shuffle_sync.
            q_ptr_i64 = utils.shuffle_sync(
                tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row
            )
            q_gmem_ptr = cute.make_ptr(
                mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
            )
            # Row bound check, written with thread 0's coordinate on the left
            # and this thread's row offset folded into the limit on the right.
            if (
                t0QcQ[0, m, 0][0]
                < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0]
            ):
                mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,))
                elems_per_load = cute.size(tQsQ.shape[0][0])
                mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,))
                for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])):
                    ki = tQcQ[0, 0, k][1] // elems_per_load
                    cute.copy(
                        gmem_thr_copy,
                        mQ_cur_copy[None, ki],
                        tQsQ[None, m, k],
                        pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,
                    )

    @cute.jit
    def store_LSE(
        self,
        mLSE: cute.Tensor,  # (qhead_per_kvhead, seqlen_q)
        tLSErLSE: cute.Tensor,  # per-thread LSE values, one per accumulator row of this thread
        tiled_mma: cute.TiledMma,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        """Write this thread's LSE values to ``mLSE`` at their packed-row addresses.

        ``tLSErLSE`` must have as many elements as the thread owns accumulator
        rows, and no more than ``threads_per_row`` (both are asserted).  Rows
        at or beyond ``seqlen * qhead_per_kvhead`` are not written.
        """
        thr_mma = tiled_mma.get_slice(tidx)
        caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        taccOcO = thr_mma.partition_C(caccO)
        taccOcO_row = layout_utils.reshape_acc_to_mn(taccOcO)[None, 0]
        assert cute.size(tLSErLSE) == cute.size(taccOcO_row)
        threads_per_row = tiled_mma.tv_layout_C.shape[0][0]
        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
        assert cute.size(tLSErLSE) <= threads_per_row
        num_threads = tiled_mma.size
        tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads)
        for m in cutlass.range_constexpr(cute.size(tLSErLSE)):
            lse_ptr_i64 = utils.shuffle_sync(
                tPrLSEPtr[m // threads_per_row],
                m % threads_per_row,
                width=threads_per_row,
            )
            lse_gmem_ptr = cute.make_ptr(
                mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4
            )
            row = block * self.m_block_size + taccOcO_row[m][0]
            # Only the thread corresponding to column 0 writes out the lse to gmem
            if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead:
                mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,))
                mLSE_copy[0] = tLSErLSE[m]

    @cute.jit
    def store_O(
        self,
        mO: cute.Tensor,  # ((qhead_per_kvhead, seqlen_q), headdim)
        tOrO: cute.Tensor,  # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy
        gmem_tiled_copy: cute.TiledCopy,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        """Copy this thread's part of the output tile ``tOrO`` to ``mO``.

        Mirror image of ``load_Q``: row addresses come from ``compute_ptr``,
        rows at or beyond ``seqlen * qhead_per_kvhead`` are skipped, and
        head-dimension predication is applied only when ``check_hdim_oob`` is
        set.
        """
        gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
        cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        tOcO = gmem_thr_copy.partition_S(cO)
        t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO)
        tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
        tOcO_row = tOcO[0, None, 0]
        threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]
        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
        num_threads = gmem_tiled_copy.size
        tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads)
        for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
            o_ptr_i64 = utils.shuffle_sync(
                tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row
            )
            o_gmem_ptr = cute.make_ptr(
                mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
            )
            # Same row bound check as in load_Q.
            if (
                t0OcO[0, m, 0][0]
                < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0]
            ):
                mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,))
                elems_per_load = cute.size(tOrO.shape[0][0])
                mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,))
                for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])):
                    ki = tOcO[0, 0, k][1] // elems_per_load
                    cute.copy(
                        gmem_thr_copy,
                        tOrO[None, m, k],
                        mO_cur_copy[None, ki],
                        pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,
                    )
235
+
236
+
237
@dataclass
class PackGQAComb(PackGQA):
    """PackGQA subclass for the K2 combine kernel.

    Inherits ``compute_ptr`` / ``load_Q`` / ``store_LSE`` / ``store_O`` from
    ``PackGQA``. Adds ``load_LSE`` for coalesced GMEM→SMEM async copies when
    LSE_partial is laid out with H_q innermost.

    K2 combine treats each query head independently (no GQA grouping in combine
    itself), so ``qhead_per_kvhead`` is set to ``num_heads_q`` by the caller —
    all heads are folded into one "group" per Sq position.
    """

    @cute.jit
    def load_LSE(
        self,
        mLSE_partial: cute.Tensor,
        # Packed layout after caller-side reshape:
        #   shape  ((qhead_per_kvhead, seqlen_q), num_splits)
        #   stride ((1, qhead_per_kvhead), ...)
        # — H_q is the innermost (stride-1) element of the packed first dim.
        sLSE: cute.Tensor,
        # SMEM destination: ``(topk, m_block_size)`` fp32.
        topk: cutlass.Constexpr[int],
        # Explicit topk so the identity tensor shape is a plain int,
        # avoiding compound-shape traps from sLSE.shape[0] after tile_to_shape.
        gmem_tiled_copy: cute.TiledCopy,
        tidx: Int32,
        block: Int32,
        num_splits: Int32,
        seqlen: Int32,
        num_heads_divmod: FastDivmodDivisor,
        # Divisor for ``m_pos, h_pos = divmod(idx, num_heads_divmod)``; passed
        # explicitly so the caller controls whether the divisor is constexpr
        # or a runtime value.
        mCounter: Optional[cute.Tensor] = None,
        # Optional per-(batch, m_pos, kv-head) count; when given, only splits
        # below the count are loaded.
        batch_idx: Optional[Int32] = None,
        # First index into ``mCounter``; only read when ``mCounter`` is given.
        qhead_per_kvhead: Int32 = Int32(1),
        # Divides ``h_pos`` to form the last ``mCounter`` index; only read
        # when ``mCounter`` is given.
    ):
        """Coalesced GMEM→SMEM async load of LSE_partial for one tile.

        For each (split, row) slot this thread owns in the tile, compute the
        GMEM coordinate ``(h_pos, m_pos)`` via ``divmod(idx, num_heads_divmod)``
        and copy one fp32. Out-of-bounds rows (``m_pos >= seqlen``) and splits
        (``si >= num_splits``, or ``si >= mCounter[...]`` when ``mCounter`` is
        given) are filled with ``-inf`` so they flow cleanly through downstream
        reductions.

        Coalescing: adjacent thread rows correspond to adjacent ``h_pos`` values
        (head varies fast under ``divmod(idx, num_heads_divmod)``), which map to
        adjacent GMEM addresses when H_q is stride-1 — one sector per warp.
        """
        gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
        cLSE = cute.make_identity_tensor((topk, self.m_block_size))
        tLSEcLSE = gmem_thr_copy.partition_S(cLSE)
        tLSEsLSE = gmem_thr_copy.partition_D(sLSE)

        for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
            # Coordinate 1 of the identity tensor is the row within the tile.
            mi = tLSEcLSE[0, 0, m][1]
            idx = block * self.m_block_size + mi
            m_pos, h_pos = divmod(idx, num_heads_divmod)

            if m_pos < seqlen:
                # Number of valid splits for this row: per-row counter if
                # provided, otherwise the global num_splits.
                row_count = (
                    mCounter[batch_idx, m_pos, h_pos // qhead_per_kvhead]
                    if const_expr(mCounter is not None)
                    else num_splits
                )
                for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
                    # Coordinate 0 of the identity tensor is the split index.
                    si = tLSEcLSE[0, s, 0][0]
                    if si < num_splits and si < row_count:
                        # Build a 1-element GMEM tensor at ((h_pos, m_pos), si),
                        # matching PackGQA.store_LSE's ptr pattern so cute.copy
                        # receives a proper Tensor, not a scalar.
                        src_ptr_i64 = utils.elem_pointer(
                            mLSE_partial, ((h_pos, m_pos), si)).toint()
                        src_ptr = cute.make_ptr(
                            Float32, src_ptr_i64,
                            cute.AddressSpace.gmem, assumed_align=4,
                        )
                        src_t = cute.make_tensor(src_ptr, (1,))
                        cute.copy(gmem_thr_copy, src_t, tLSEsLSE[None, s, m])
                    else:
                        # Split beyond the valid count for this row.
                        tLSEsLSE[None, s, m].fill(-Float32.inf)
            else:
                # Whole row is past the end of the sequence.
                for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
                    tLSEsLSE[None, s, m].fill(-Float32.inf)
build/torch211-cxx11-cu128-x86_64-linux/src/common/paged_kv.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ from dataclasses import dataclass
5
+
6
+ import cutlass
7
+ import cutlass.cute as cute
8
+ from cutlass import Int32, const_expr
9
+
10
+
11
@dataclass(frozen=True)
class PagedKVManager:
    """Index helper for a paged KV cache in which one page is exactly one KV block.

    Build instances through ``create``, which enforces
    ``page_size == n_block_size``.
    """

    # Page table, indexed as ``mPageTable[batch_idx, kv_block_idx]``.
    mPageTable: cute.Tensor
    # Page size; equal to ``n_block_size`` when built via ``create``.
    page_size: cutlass.Constexpr[int]
    # Number of KV columns per block.
    n_block_size: cutlass.Constexpr[int]

    @staticmethod
    def create(
        mPageTable: cute.Tensor,
        *,
        page_size: int,
        n_block_size: int,
    ):
        """Validate the sizes and construct a ``PagedKVManager``.

        Raises:
            ValueError: if ``page_size`` differs from ``n_block_size``.
        """
        if page_size != n_block_size:
            raise ValueError(
                f"page_size ({page_size}) must equal blk_kv ({n_block_size})"
            )
        return PagedKVManager(
            mPageTable,
            page_size=page_size,
            n_block_size=n_block_size,
        )

    @cute.jit
    def logical_length(
        self,
        batch_idx: Int32,
        num_kv_blocks: Int32,
        mSeqUsedK=None,
    ) -> Int32:
        """Return the KV length for ``batch_idx``.

        Uses ``mSeqUsedK[batch_idx]`` when ``mSeqUsedK`` is provided, otherwise
        ``num_kv_blocks * n_block_size``.
        """
        if const_expr(mSeqUsedK is not None):
            return mSeqUsedK[batch_idx]
        return num_kv_blocks * Int32(self.n_block_size)

    @cute.jit
    def valid_cols_in_block(
        self,
        batch_idx: Int32,
        kv_block_idx: Int32,
        num_kv_blocks: Int32,
        mSeqUsedK=None,
    ) -> Int32:
        """Return how many columns of block ``kv_block_idx`` lie below the logical length.

        The result is clamped to the range ``[0, n_block_size]``.
        """
        seqlen_k = self.logical_length(batch_idx, num_kv_blocks, mSeqUsedK)
        block_start = kv_block_idx * Int32(self.n_block_size)
        remaining = seqlen_k - block_start
        # Blocks that start past the end contribute zero columns.
        remaining = cutlass.max(remaining, Int32(0))
        return cutlass.min(remaining, Int32(self.n_block_size))

    @cute.jit
    def physical_block_index(
        self,
        batch_idx: Int32,
        kv_block_idx: Int32,
    ) -> Int32:
        """Return the page-table entry for ``(batch_idx, kv_block_idx)``."""
        return self.mPageTable[batch_idx, kv_block_idx]
66
+
67
+ __all__ = ["PagedKVManager"]
build/torch211-cxx11-cu128-x86_64-linux/src/common/pipeline.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ # import math
5
+ from typing import Optional
6
+ from dataclasses import dataclass
7
+
8
+ import cutlass.cute as cute
9
+ from cutlass import Boolean, Int32, const_expr
10
+ from cutlass.cutlass_dsl import if_generate, dsl_user_op
11
+ from cutlass.pipeline import PipelineState
12
+ from cutlass.pipeline import PipelineUserType
13
+ from cutlass.pipeline import NamedBarrier as NamedBarrierOg
14
+ from cutlass.pipeline import PipelineAsync as PipelineAsyncOg
15
+ from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
16
+ from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
17
+ from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg
18
+ from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg
19
+ import cutlass.pipeline as cutlass_pipeline
20
+
21
+
22
def make_pipeline_state(type: PipelineUserType, stages: int):
    """Forward to ``cutlass.pipeline.make_pipeline_state`` unchanged.

    Kept as a compatibility shim for FA-style helpers vendored into src.common.
    """
    upstream_factory = cutlass_pipeline.make_pipeline_state
    return upstream_factory(type, stages)
25
+
26
@dataclass(frozen=True)
class NamedBarrier(NamedBarrierOg):
    """``cutlass.pipeline.NamedBarrier`` with variants that offset the barrier id."""

    @staticmethod
    def create(*args, **kwargs):
        """Create the upstream barrier, then re-tag the instance as this subclass."""
        barrier = NamedBarrierOg.create(*args, **kwargs)
        # The dataclass is frozen, so plain ``barrier.__class__ = ...`` is
        # rejected; go through object.__setattr__ instead.
        object.__setattr__(barrier, "__class__", NamedBarrier)
        return barrier

    @dsl_user_op
    def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
        """Arrive on barrier ``self.barrier_id + index`` for ``self.num_threads`` threads.

        See the PTX documentation for the barrier-arrive semantics.
        """
        target_id = self.barrier_id + index
        cute.arch.barrier_arrive(
            barrier_id=target_id,
            number_of_threads=self.num_threads,
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
        """Issue ``cute.arch.barrier`` on ``self.barrier_id + index``."""
        target_id = self.barrier_id + index
        cute.arch.barrier(
            barrier_id=target_id,
            number_of_threads=self.num_threads,
            loc=loc,
            ip=ip,
        )
56
+
57
+
58
@dataclass(frozen=True)
class PipelineAsync(PipelineAsyncOg):
    """``cutlass.pipeline.PipelineAsync`` plus methods taking an explicit index/phase.

    The ``*_w_index`` / ``*_w_index_phase`` methods operate on a caller-supplied
    stage index (and phase) instead of a ``PipelineState`` object.
    """

    @staticmethod
    def create(*args, **kwargs):
        """Create the upstream pipeline, then re-tag the instance as this subclass."""
        pipeline = PipelineAsyncOg.create(*args, **kwargs)
        # Frozen dataclass: ``pipeline.__class__ = PipelineAsync`` is rejected,
        # so the class is swapped through object.__setattr__.
        object.__setattr__(pipeline, "__class__", PipelineAsync)
        return pipeline

    @dsl_user_op
    def producer_acquire_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_acquire_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Wait on the empty barrier unless ``try_acquire_token`` is non-zero."""
        must_wait = try_acquire_token is None or try_acquire_token == 0

        def wait_for_empty():
            return self.sync_object_empty.wait(index, phase, loc=loc, ip=ip)

        if_generate(must_wait, wait_for_empty, loc=loc, ip=ip)

    @dsl_user_op
    def producer_try_acquire_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        *,
        loc=None,
        ip=None,
    ):
        """Return the result of ``try_wait`` on the empty barrier for this index/phase."""
        empty_barrier = self.sync_object_empty
        return empty_barrier.try_wait(index, phase, loc=loc, ip=ip)

    @dsl_user_op
    def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
        """Arrive on the full barrier at ``index`` using ``self.producer_mask``."""
        full_barrier = self.sync_object_full
        full_barrier.arrive(index, self.producer_mask, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_wait_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_wait_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Wait on the full barrier unless ``try_wait_token`` is non-zero."""
        must_wait = try_wait_token is None or try_wait_token == 0

        def wait_for_full():
            return self.sync_object_full.wait(index, phase, loc=loc, ip=ip)

        if_generate(must_wait, wait_for_full, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
        """Arrive on the empty barrier at ``index`` using ``self.consumer_mask``."""
        empty_barrier = self.sync_object_empty
        empty_barrier.arrive(index, self.consumer_mask, loc=loc, ip=ip)
120
+
121
+
122
@dataclass(frozen=True)
class PipelineTmaAsync(PipelineTmaAsyncOg):
    """``PipelineTmaAsync`` whose ``producer_acquire`` accepts ``extra_tx_count``."""

    @staticmethod
    def create(*args, **kwargs):
        """Create the upstream pipeline, then re-tag the instance as this subclass."""
        pipeline = PipelineTmaAsyncOg.create(*args, **kwargs)
        # Frozen dataclass: direct ``__class__`` assignment is rejected.
        object.__setattr__(pipeline, "__class__", PipelineTmaAsync)
        return pipeline

    @dsl_user_op
    def producer_acquire(
        self,
        state: PipelineState,
        try_acquire_token: Optional[Boolean] = None,
        extra_tx_count: int = 0,
        *,
        loc=None,
        ip=None,
    ):
        """Acquire a stage for the TMA producer.

        Waits on the empty barrier (skipped when ``try_acquire_token`` is
        non-zero), then arrives on the full barrier. With a non-zero
        ``extra_tx_count`` the arrival goes through ``arrive_and_expect_tx``
        with the barrier's own ``tx_count`` plus ``extra_tx_count``.
        """
        must_wait = try_acquire_token is None or try_acquire_token == 0

        def wait_for_empty():
            return self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip)

        if_generate(must_wait, wait_for_empty, loc=loc, ip=ip)

        full_barrier = self.sync_object_full
        if const_expr(extra_tx_count != 0):
            expected_tx = full_barrier.tx_count + extra_tx_count
            full_barrier.arrive_and_expect_tx(state.index, expected_tx, loc=loc, ip=ip)
        else:
            full_barrier.arrive(state.index, self.producer_mask, loc=loc, ip=ip)
159
+
160
+
161
@dataclass(frozen=True)
class PipelineTmaUmma(PipelineTmaUmmaOg):
    """``PipelineTmaUmma`` with ``extra_tx_count`` support and index/phase variants."""

    @staticmethod
    def create(*args, **kwargs):
        """Create the upstream pipeline, then re-tag the instance as this subclass."""
        pipeline = PipelineTmaUmmaOg.create(*args, **kwargs)
        # Frozen dataclass: ``pipeline.__class__ = PipelineTmaUmma`` is
        # rejected, so the class is swapped through object.__setattr__.
        object.__setattr__(pipeline, "__class__", PipelineTmaUmma)
        return pipeline

    @dsl_user_op
    def producer_acquire(
        self,
        state: PipelineState,
        try_acquire_token: Optional[Boolean] = None,
        extra_tx_count: int = 0,
        *,
        loc=None,
        ip=None,
    ):
        """Acquire a stage for the TMA producer.

        Waits on the empty barrier (skipped when ``try_acquire_token`` is
        non-zero). Then, only where ``self.is_leader_cta`` holds, arrives on the
        full barrier; a non-zero ``extra_tx_count`` switches the arrival to
        ``arrive_and_expect_tx`` with ``tx_count + extra_tx_count``.
        """
        must_wait = try_acquire_token is None or try_acquire_token == 0

        def wait_for_empty():
            return self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip)

        if_generate(must_wait, wait_for_empty, loc=loc, ip=ip)

        # Pick the leader-only arrival at trace time, then emit it once.
        if const_expr(extra_tx_count == 0):

            def signal_full():
                return self.sync_object_full.arrive(
                    state.index, self.producer_mask, loc=loc, ip=ip
                )

        else:
            tx_count = self.sync_object_full.tx_count + extra_tx_count

            def signal_full():
                return self.sync_object_full.arrive_and_expect_tx(
                    state.index, tx_count, loc=loc, ip=ip
                )

        if_generate(self.is_leader_cta, signal_full, loc=loc, ip=ip)

    @dsl_user_op
    def producer_acquire_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_acquire_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Index/phase form of the producer acquire (no ``extra_tx_count``).

        Waits on the empty barrier unless ``try_acquire_token`` is non-zero,
        then arrives on the full barrier where ``self.is_leader_cta`` holds.
        """
        must_wait = try_acquire_token is None or try_acquire_token == 0

        def wait_for_empty():
            return self.sync_object_empty.wait(index, phase, loc=loc, ip=ip)

        def signal_full():
            return self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip)

        if_generate(must_wait, wait_for_empty, loc=loc, ip=ip)
        if_generate(self.is_leader_cta, signal_full, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_wait_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_wait_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Wait on the full barrier unless ``try_wait_token`` is non-zero."""
        must_wait = try_wait_token is None or try_wait_token == 0

        def wait_for_full():
            return self.sync_object_full.wait(index, phase, loc=loc, ip=ip)

        if_generate(must_wait, wait_for_full, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
        """Arrive on the empty barrier at ``index``; ``self.cta_group`` is passed along."""
        empty_barrier = self.sync_object_empty
        empty_barrier.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)
263
+
264
+
265
@dataclass(frozen=True)
class PipelineUmmaAsync(PipelineUmmaAsyncOg):
    """``PipelineUmmaAsync`` plus methods taking an explicit stage index/phase."""

    @staticmethod
    def create(*args, **kwargs):
        """Create the upstream pipeline, then re-tag the instance as this subclass."""
        pipeline = PipelineUmmaAsyncOg.create(*args, **kwargs)
        # Frozen dataclass: direct ``__class__`` assignment is rejected.
        object.__setattr__(pipeline, "__class__", PipelineUmmaAsync)
        return pipeline

    @dsl_user_op
    def producer_acquire_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_acquire_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Wait on the empty barrier unless ``try_acquire_token`` is non-zero."""
        must_wait = try_acquire_token is None or try_acquire_token == 0

        def wait_for_empty():
            return self.sync_object_empty.wait(index, phase, loc=loc, ip=ip)

        if_generate(must_wait, wait_for_empty, loc=loc, ip=ip)

    @dsl_user_op
    def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
        """Arrive on the full barrier at ``index``; ``self.cta_group`` is passed along."""
        full_barrier = self.sync_object_full
        full_barrier.arrive(index, self.producer_mask, self.cta_group, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_wait_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_wait_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Wait on the full barrier unless ``try_wait_token`` is non-zero."""
        must_wait = try_wait_token is None or try_wait_token == 0

        def wait_for_full():
            return self.sync_object_full.wait(index, phase, loc=loc, ip=ip)

        if_generate(must_wait, wait_for_full, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
        """Arrive on the empty barrier at ``index`` using ``self.consumer_mask``."""
        empty_barrier = self.sync_object_empty
        empty_barrier.arrive(index, self.consumer_mask, loc=loc, ip=ip)
318
+
319
+
320
@dataclass(frozen=True)
class PipelineAsyncUmma(PipelineAsyncUmmaOg):
    """``PipelineAsyncUmma`` plus methods taking an explicit stage index/phase."""

    @staticmethod
    def create(*args, **kwargs):
        """Create the upstream pipeline, then re-tag the instance as this subclass."""
        pipeline = PipelineAsyncUmmaOg.create(*args, **kwargs)
        # Frozen dataclass: direct ``__class__`` assignment is rejected.
        object.__setattr__(pipeline, "__class__", PipelineAsyncUmma)
        return pipeline

    @dsl_user_op
    def producer_acquire_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_acquire_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Wait on the empty barrier unless ``try_acquire_token`` is non-zero."""
        must_wait = try_acquire_token is None or try_acquire_token == 0

        def wait_for_empty():
            return self.sync_object_empty.wait(index, phase, loc=loc, ip=ip)

        if_generate(must_wait, wait_for_empty, loc=loc, ip=ip)

    @dsl_user_op
    def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
        """Arrive on the full barrier at ``index`` using ``self.producer_mask``."""
        full_barrier = self.sync_object_full
        full_barrier.arrive(index, self.producer_mask, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_wait_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_wait_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Wait on the full barrier unless ``try_wait_token`` is non-zero."""
        must_wait = try_wait_token is None or try_wait_token == 0

        def wait_for_full():
            return self.sync_object_full.wait(index, phase, loc=loc, ip=ip)

        if_generate(must_wait, wait_for_full, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
        """Arrive on the empty barrier at ``index``; ``self.cta_group`` is passed along."""
        empty_barrier = self.sync_object_empty
        empty_barrier.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)
build/torch211-cxx11-cu128-x86_64-linux/src/common/seqlen_info.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ from typing import Optional
5
+ from dataclasses import dataclass
6
+
7
+ import cutlass
8
+ import cutlass.cute as cute
9
+ from cutlass import Int32, const_expr
10
+
11
+ from ...quack import copy_utils
12
+
13
+ """
14
+ This consolidates all the info related to sequence length. This is so that we can do all
15
+ the gmem reads once at the beginning of each tile, rather than having to repeat these reads
16
+ to compute various things like n_block_min, n_block_max, etc.
17
+ """
18
+
19
+
20
@dataclass(frozen=True)
class SeqlenInfo:
    """Sequence offset and length for one batch entry of a single tensor."""

    # Start of this batch's sequence: ``cu_seqlens[batch_idx]``, or 0 without cu_seqlens.
    offset: Int32
    # ``offset + batch_idx * tile`` rounded down to a multiple of ``tile`` (0 without cu_seqlens).
    offset_padded: Int32
    # Sequence length taken from seqused, else cu_seqlens, else the static length.
    seqlen: Int32
    # Compile-time flag: whether cu_seqlens was provided to ``create``.
    has_cu_seqlens: cutlass.Constexpr[bool] = False

    @staticmethod
    def create(
        batch_idx: Int32,
        seqlen_static: Int32,
        cu_seqlens: Optional[cute.Tensor] = None,
        seqused: Optional[cute.Tensor] = None,
        tile: cutlass.Constexpr[int] = 128,
    ):
        """Gather the offset and length of batch ``batch_idx``.

        Args:
            batch_idx: index of the batch entry.
            seqlen_static: length used when neither ``seqused`` nor
                ``cu_seqlens`` is given.
            cu_seqlens: optional cumulative sequence lengths; entry
                ``batch_idx`` is the offset and the difference to entry
                ``batch_idx + 1`` is the length.
            seqused: optional per-batch lengths; takes precedence over
                ``cu_seqlens`` for the length.
            tile: alignment of ``offset_padded``.
        """
        offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx]
        offset_padded = (
            0
            if const_expr(cu_seqlens is None)
            # Add divby so that the compiler knows the alignment when moving by offset_padded
            else cute.assume((offset + batch_idx * tile) // tile * tile, divby=tile)
        )
        if const_expr(seqused is not None):
            seqlen = seqused[batch_idx]
        elif const_expr(cu_seqlens is not None):
            # ``offset`` already holds cu_seqlens[batch_idx]; reuse it instead of
            # reading the same element a second time (as SeqlenInfoQK.create does).
            seqlen = cu_seqlens[batch_idx + 1] - offset
        else:
            seqlen = seqlen_static
        return SeqlenInfo(offset, offset_padded, seqlen, has_cu_seqlens=cu_seqlens is not None)

    def offset_batch(
        self,
        mT: cute.Tensor,
        batch_idx: Int32,
        dim: int,
        padded: cutlass.Constexpr[bool] = False,
        multiple: int = 1,
    ) -> cute.Tensor:
        """Offset a tensor by batch index. batch dim is at position `dim`, seqlen is at dim=0.

        Without cu_seqlens the tensor is sliced at ``batch_idx`` along ``dim``.
        With cu_seqlens the first mode is shifted by ``multiple`` times the
        (optionally padded) offset via ``cute.domain_offset``.
        """
        if const_expr(not self.has_cu_seqlens):
            idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mT) - 1 - dim)
            return mT[idx]
        else:
            off = multiple * (self.offset if const_expr(not padded) else self.offset_padded)
            # When the first mode is itself rank-2, only its second sub-mode is shifted.
            offset = off if const_expr(cute.rank(mT.shape[0]) == 1) else (0, off)
            idx = (offset,) + (None,) * (cute.rank(mT) - 1)
            return cute.domain_offset(idx, mT)
67
+
68
+
69
@dataclass(frozen=True)
class SeqlenInfoQK:
    """Sequence offsets and lengths for the Q and K sides of one batch entry."""

    # ``mCuSeqlensQ[batch_idx]`` / ``mCuSeqlensK[batch_idx]``, or 0 when not given.
    offset_q: Int32
    offset_k: Int32
    # ``offset + batch_idx * tile`` rounded down to a multiple of the tile (0 when not given).
    padded_offset_q: Int32
    padded_offset_k: Int32
    # Lengths taken from seqused, else cu_seqlens, else the static length.
    seqlen_q: Int32
    seqlen_k: Int32
    # Compile-time flags recording which optional tensors were supplied to ``create``.
    has_cu_seqlens_q: cutlass.Constexpr[bool]
    has_cu_seqlens_k: cutlass.Constexpr[bool]
    has_seqused_q: cutlass.Constexpr[bool]
    has_seqused_k: cutlass.Constexpr[bool]

    @staticmethod
    def create(
        batch_idx: Int32,
        seqlen_q_static: Int32,
        seqlen_k_static: Int32,
        mCuSeqlensQ: Optional[cute.Tensor] = None,
        mCuSeqlensK: Optional[cute.Tensor] = None,
        mSeqUsedQ: Optional[cute.Tensor] = None,
        mSeqUsedK: Optional[cute.Tensor] = None,
        tile_m: cutlass.Constexpr[Int32] = 128,
        tile_n: cutlass.Constexpr[Int32] = 128,
    ):
        """Gather Q/K offsets and lengths for batch ``batch_idx``.

        For each side the length comes from the seqused tensor if given, else
        from the cu_seqlens difference, else from the static length.
        """
        have_cu_q = mCuSeqlensQ is not None
        have_cu_k = mCuSeqlensK is not None

        offset_q = mCuSeqlensQ[batch_idx] if const_expr(have_cu_q) else 0
        offset_k = mCuSeqlensK[batch_idx] if const_expr(have_cu_k) else 0

        # cute.assume(..., divby=tile) records the tile alignment of the padded offsets.
        if const_expr(have_cu_q):
            padded_offset_q = cute.assume(
                (offset_q + batch_idx * tile_m) // tile_m * tile_m, divby=tile_m
            )
        else:
            padded_offset_q = 0
        if const_expr(have_cu_k):
            padded_offset_k = cute.assume(
                (offset_k + batch_idx * tile_n) // tile_n * tile_n, divby=tile_n
            )
        else:
            padded_offset_k = 0

        if const_expr(mSeqUsedQ is not None):
            seqlen_q = mSeqUsedQ[batch_idx]
        elif const_expr(have_cu_q):
            seqlen_q = mCuSeqlensQ[batch_idx + 1] - offset_q
        else:
            seqlen_q = seqlen_q_static

        if const_expr(mSeqUsedK is not None):
            seqlen_k = mSeqUsedK[batch_idx]
        elif const_expr(have_cu_k):
            seqlen_k = mCuSeqlensK[batch_idx + 1] - offset_k
        else:
            seqlen_k = seqlen_k_static

        return SeqlenInfoQK(
            offset_q,
            offset_k,
            padded_offset_q,
            padded_offset_k,
            seqlen_q,
            seqlen_k,
            has_cu_seqlens_q=mCuSeqlensQ is not None,
            has_cu_seqlens_k=mCuSeqlensK is not None,
            has_seqused_q=mSeqUsedQ is not None,
            has_seqused_k=mSeqUsedK is not None,
        )

    @staticmethod
    def _select_batch(tensor: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor:
        """Slice ``tensor`` at ``batch_idx`` along mode ``dim``, keeping all other modes."""
        trailing = cute.rank(tensor) - 1 - dim
        return tensor[(None,) * dim + (batch_idx,) + (None,) * trailing]

    def offset_batch_Q(
        self,
        mQ: cute.Tensor,
        batch_idx: Int32,
        dim: int,
        padded: cutlass.Constexpr[bool] = False,
        ragged: cutlass.Constexpr[bool] = False,
    ) -> cute.Tensor:
        """Return the view of ``mQ`` for this batch. Seqlen must be the first mode of ``mQ``.

        Non-ragged: slice at ``batch_idx`` (no cu_seqlens) or shift the first
        mode by the Q offset. Ragged: delegate to
        ``copy_utils.offset_ragged_tensor``; a rank-2 first mode (PackGQA) is
        unpacked before the call and regrouped afterwards.
        """
        if const_expr(not ragged):
            if const_expr(not self.has_cu_seqlens_q):
                return self._select_batch(mQ, batch_idx, dim)
            shift = self.padded_offset_q if const_expr(padded) else self.offset_q
            # A rank-2 first mode only gets its second sub-mode shifted.
            if const_expr(cute.rank(mQ.shape[0]) != 1):
                shift = (None, shift)
            return cute.domain_offset((shift,) + (None,) * (cute.rank(mQ) - 1), mQ)

        if const_expr(not self.has_cu_seqlens_q):
            start = 0
            mQ = self._select_batch(mQ, batch_idx, dim)
        else:
            start = self.padded_offset_q if const_expr(padded) else self.offset_q

        if const_expr(cute.rank(mQ.shape[0]) == 1):
            return copy_utils.offset_ragged_tensor(
                mQ, start, self.seqlen_q, ragged_dim=0, ptr_shift=True
            )
        # PackGQA
        assert cute.rank(mQ.shape[0]) == 2
        # Unpack before calling offset_ragged_tensor, then pack
        unpacked = mQ[((None, None),) + (None,) * (cute.rank(mQ) - 1)]
        shifted = copy_utils.offset_ragged_tensor(
            unpacked, start, self.seqlen_q, ragged_dim=1, ptr_shift=True
        )
        return cute.group_modes(shifted, 0, 2)

    def offset_batch_K(
        self,
        mK: cute.Tensor,
        batch_idx: Int32,
        dim: int,
        padded: cutlass.Constexpr[bool] = False,
        ragged: cutlass.Constexpr[bool] = False,
        multiple: int = 1,
    ) -> cute.Tensor:
        """Return the view of ``mK`` for this batch. Seqlen must be the first mode of ``mK``.

        Same scheme as ``offset_batch_Q`` without the PackGQA case; the K
        offset is additionally scaled by ``multiple``.
        """
        if const_expr(not ragged):
            if const_expr(not self.has_cu_seqlens_k):
                return self._select_batch(mK, batch_idx, dim)
            shift = self.padded_offset_k if const_expr(padded) else self.offset_k
            shift = shift * multiple
            return cute.domain_offset((shift,) + (None,) * (cute.rank(mK) - 1), mK)

        if const_expr(not self.has_cu_seqlens_k):
            start = 0
            mK = self._select_batch(mK, batch_idx, dim)
        else:
            start = self.padded_offset_k if const_expr(padded) else self.offset_k
            start = start * multiple
        return copy_utils.offset_ragged_tensor(
            mK, start, self.seqlen_k, ragged_dim=0, ptr_shift=True
        )
build/torch211-cxx11-cu128-x86_64-linux/src/common/softmax.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Online softmax primitives.
5
+
6
+ Contains:
7
+ - ``Softmax``: SM80/90 base class with online softmax + finalize + rescale_O.
8
+ The ``rescale_O`` path branches on ``arch >= 100`` to emit SM100 packed
9
+ ``fmul.f32x2`` (2× CUDA-core throughput) when available.
10
+ - ``SoftmaxSm100``: SM100-specific subclass exposing fused ``update_row_max``,
11
+ ``scale_apply_exp2_convert`` etc. used by the UTCMMA warp-specialized kernel.
12
+ """
13
+
14
+ import math
15
+ import operator
16
+ from dataclasses import dataclass
17
+ from typing import Tuple
18
+
19
+ import cutlass
20
+ import cutlass.cute as cute
21
+ from cutlass import Float32
22
+
23
+ from ...quack import layout_utils
24
+ from ...quack.cute_dsl_utils import ParamsBase
25
+
26
+ from . import utils
27
+
28
+
29
@dataclass
class Softmax(ParamsBase):
    """Running row-max / row-sum state for online softmax, plus output rescaling."""

    # Multiplier applied to scores before exp2.
    scale_log2: Float32
    # Compile-time number of rows tracked by row_max / row_sum.
    num_rows: cutlass.Constexpr[int]
    # Per-row running maximum.
    row_max: cute.Tensor
    # Per-row running sum; overwritten with the logsumexp by ``finalize``.
    row_sum: cute.Tensor
    # Compile-time architecture number; values >= 100 enable the packed f32x2 paths.
    arch: cutlass.Constexpr[int] = 80
    softmax_scale: Float32 | None = None

    @staticmethod
    def create(
        scale_log2: Float32,
        num_rows: cutlass.Constexpr[int],
        arch: cutlass.Constexpr[int] = 80,
        softmax_scale: Float32 | None = None,
    ):
        """Allocate ``num_rows``-element Float32 max/sum tensors and wrap them."""
        max_buf = cute.make_rmem_tensor(num_rows, Float32)
        sum_buf = cute.make_rmem_tensor(num_rows, Float32)
        return Softmax(
            scale_log2=scale_log2,
            num_rows=num_rows,
            row_max=max_buf,
            row_sum=sum_buf,
            arch=arch,
            softmax_scale=softmax_scale,
        )

    def reset(self) -> None:
        """Set every running max to -inf and every running sum to 0."""
        self.row_max.fill(-Float32.inf)
        self.row_sum.fill(0.0)

    def _compute_row_max(
        self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None
    ) -> Float32:
        """Max-reduce one row through ``utils.fmax_reduce``."""
        return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch)

    def _compute_row_sum(
        self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None
    ) -> Float32:
        """Sum-reduce one row through ``utils.fadd_reduce``."""
        return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch)

    @cute.jit
    def online_softmax(
        self,
        acc_S: cute.Tensor,
        is_first: cutlass.Constexpr[bool] = False,
        check_inf: cutlass.Constexpr[bool] = True,
    ) -> cute.Tensor:
        """Run one online-softmax step on ``acc_S`` in place.

        Each row of ``acc_S`` is replaced by ``exp2(x * scale_log2 - max * scale_log2)``
        and the running max / sum are updated. Returns the per-row factor by
        which previously accumulated output must be rescaled (1.0 when
        ``is_first``).

        When ``arch >= 100`` and the row length is even, the scale-and-subtract
        is issued as explicit ``fma_packed_f32x2`` intrinsics. The original
        author notes that the DSL compiler did not fuse the TensorSSA
        ``mul + sub`` into packed FMAs on its own.
        """
        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
        row_scale = cute.make_rmem_tensor_like(self.row_max, Float32)

        row_max = self.row_max
        row_sum = self.row_sum
        scale_log2 = self.scale_log2
        arch = self.arch

        for r in cutlass.range(cute.size(row_max), unroll_full=True):
            row_view = acc_S_mn[r, None]
            row_vals = row_view.load()

            # Fold the previous running max in, except on the first step.
            row_max_cur = utils.fmax_reduce(
                row_vals,
                init_val=row_max[r] if cutlass.const_expr(not is_first) else None,
                arch=arch,
            )

            row_max_cur = cute.arch.warp_reduction_max(row_max_cur, threads_in_group=4)
            row_max_prev = row_max[r]
            row_max[r] = row_max_cur

            if cutlass.const_expr(check_inf):
                # Replace a -inf max by 0 before it enters the exponent.
                row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur

            row_max_cur_scaled = row_max_cur * scale_log2
            neg_max_scaled = -row_max_cur_scaled
            row_len = cute.size(row_view)

            if cutlass.const_expr(arch >= 100 and row_len % 2 == 0):
                # Packed f32x2 FMA path: scale + subtract in one pass, in place.
                for i in cutlass.range(0, row_len, 2, unroll_full=True):
                    row_view[i], row_view[i + 1] = cute.arch.fma_packed_f32x2(
                        (row_view[i], row_view[i + 1]),
                        (scale_log2, scale_log2),
                        (neg_max_scaled, neg_max_scaled),
                    )
                for i in cutlass.range(row_len, unroll_full=True):
                    row_view[i] = cute.math.exp2(row_view[i], fastmath=True)
                row_exp = row_view.load()
            else:
                row_exp = cute.math.exp2(
                    row_vals * scale_log2 - row_max_cur_scaled, fastmath=True
                )
                row_view.store(row_exp)

            if cutlass.const_expr(is_first):
                row_total = utils.fadd_reduce(row_exp, init_val=None, arch=arch)
                row_scale[r] = 1.0
            else:
                row_scale[r] = cute.math.exp2(
                    (row_max_prev - row_max_cur) * scale_log2, fastmath=True
                )
                # The old running sum is rescaled to the new max before adding.
                row_total = utils.fadd_reduce(
                    row_exp, init_val=row_sum[r] * row_scale[r], arch=arch
                )

            row_sum[r] = row_total

        return row_scale

    @cute.jit
    def finalize(
        self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None
    ) -> cute.Tensor:
        """Finish the online softmax: return output scales and store the logsumexp.

        The returned tensor holds ``rcp(row_sum) * final_scale`` per row.
        ``self.row_sum`` is overwritten with
        ``(row_max * scale_log2 + log2(row_sum)) * ln(2)``. Rows whose sum is
        zero or NaN get a reciprocal of 1 and a logsumexp of -inf.

        With ``arch >= 100``, an even row count and no ``sink_val``, rows are
        processed in pairs with packed f32x2 multiply / FMA intrinsics.
        Otherwise the scalar loop runs, which also adds the optional
        ``sink_val`` term to each row sum.
        """
        if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)):
            assert cute.size(sink_val) == cute.size(self.row_sum)
        row_sum = self.row_sum
        row_max = self.row_max
        scale_log2 = self.scale_log2

        # Combine the partial sums across each group of 4 lanes.
        row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4))
        row_scale = cute.make_rmem_tensor_like(row_max, Float32)

        LN2 = math.log(2.0)
        num_rows = cute.size(row_sum)
        use_packed = cutlass.const_expr(
            self.arch >= 100 and num_rows % 2 == 0 and sink_val is None
        )

        if use_packed:
            for r in cutlass.range(0, num_rows, 2, unroll_full=True):
                s0 = row_sum[r]
                s1 = row_sum[r + 1]
                m0 = row_max[r]
                m1 = row_max[r + 1]
                # Zero or NaN (x != x) sums are treated as degenerate rows.
                bad0 = s0 == 0.0 or s0 != s0
                bad1 = s1 == 0.0 or s1 != s1

                # The reciprocal is computed per element; only the trailing
                # multiply by final_scale is packed.
                rcp0 = cute.arch.rcp_approx(1.0 if bad0 else s0)
                rcp1 = cute.arch.rcp_approx(1.0 if bad1 else s1)
                row_scale[r], row_scale[r + 1] = cute.arch.mul_packed_f32x2(
                    (rcp0, rcp1), (final_scale, final_scale)
                )

                # LSE = (row_max * scale_log2 + log2(row_sum)) * LN2:
                # packed FMA for the sum, packed MUL for the LN2 factor.
                log0 = cute.math.log2(s0, fastmath=True)
                log1 = cute.math.log2(s1, fastmath=True)
                lse_pre_0, lse_pre_1 = cute.arch.fma_packed_f32x2(
                    (m0, m1), (scale_log2, scale_log2), (log0, log1)
                )
                lse_0, lse_1 = cute.arch.mul_packed_f32x2(
                    (lse_pre_0, lse_pre_1), (LN2, LN2)
                )
                row_sum[r] = -Float32.inf if bad0 else lse_0
                row_sum[r + 1] = -Float32.inf if bad1 else lse_1
        else:
            for r in cutlass.range(num_rows, unroll_full=True):
                if cutlass.const_expr(sink_val is not None):
                    sink_val_cur = (
                        sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r]
                    )
                    LOG2_E = math.log2(math.e)
                    row_sum[r] += cute.math.exp2(
                        sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True
                    )

                degenerate = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]
                row_scale[r] = (
                    cute.arch.rcp_approx(row_sum[r] if not degenerate else 1.0)
                ) * final_scale
                sum_now = row_sum[r]
                row_sum[r] = (
                    (row_max[r] * scale_log2 + cute.math.log2(sum_now, fastmath=True)) * LN2
                    if not degenerate
                    else -Float32.inf
                )
        return row_scale

    @cute.jit
    def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None:
        """Multiply every row of ``acc_O`` in place by its entry in ``row_scale``."""
        acc_O_mn = layout_utils.reshape_acc_to_mn(acc_O)
        assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0])
        num_cols = cute.size(acc_O_mn, mode=[1])
        if cutlass.const_expr(self.arch >= 100 and num_cols % 2 == 0):
            # arch >= 100 with an even column count: scale adjacent pairs
            # with one packed f32x2 multiply each.
            for row in cutlass.range(cute.size(row_scale), unroll_full=True):
                factor = row_scale[row]
                for col in cutlass.range(0, num_cols, 2, unroll_full=True):
                    acc_O_mn[row, col], acc_O_mn[row, col + 1] = cute.arch.mul_packed_f32x2(
                        (acc_O_mn[row, col], acc_O_mn[row, col + 1]), (factor, factor)
                    )
        else:
            for row in cutlass.range(cute.size(row_scale), unroll_full=True):
                acc_O_mn[row, None].store(acc_O_mn[row, None].load() * row_scale[row])
234
+
235
+
236
+ @dataclass
237
+ class SoftmaxSm100(Softmax):
238
+ """SM100-specific softmax: single-row, explicit f32x2 pack for FMA/exp2 paths."""
239
+
240
+ rescale_threshold: cutlass.Constexpr[float] = 0.0
241
+
242
+ @staticmethod
243
+ def create(
244
+ scale_log2: Float32,
245
+ rescale_threshold: cutlass.Constexpr[float] = 0.0,
246
+ softmax_scale: Float32 | None = None,
247
+ ):
248
+ num_rows = 1
249
+ arch = 100
250
+ row_max = cute.make_rmem_tensor(num_rows, Float32)
251
+ row_sum = cute.make_rmem_tensor(num_rows, Float32)
252
+ return SoftmaxSm100(
253
+ scale_log2,
254
+ num_rows,
255
+ row_max,
256
+ row_sum,
257
+ arch,
258
+ softmax_scale,
259
+ rescale_threshold=rescale_threshold,
260
+ )
261
+
262
+ @cute.jit
263
+ def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]:
264
+ if cutlass.const_expr(is_first):
265
+ row_max_new = self._compute_row_max(acc_S_row)
266
+ row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
267
+ acc_scale = 0.0
268
+ else:
269
+ row_max_old = self.row_max[0]
270
+ row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old)
271
+ row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
272
+ acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2
273
+ acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
274
+ if cutlass.const_expr(self.rescale_threshold > 0.0):
275
+ if acc_scale_ >= -self.rescale_threshold:
276
+ row_max_new = row_max_old
277
+ row_max_safe = row_max_old
278
+ acc_scale = 1.0
279
+ self.row_max[0] = row_max_new
280
+ return row_max_safe, acc_scale
281
+
282
@cute.jit
def update_row_max_deferred_exp2(
    self,
    acc_S_row: cute.TensorSSA,
    is_first: int,
) -> Tuple[Float32, Float32]:
    """Variant of ``update_row_max`` that returns the log2-domain delta un-exponentiated.

    No ``exp2`` is evaluated here; the caller is expected to apply
    ``cute.math.exp2`` only when a non-zero delta is returned. Per the original
    author's note, this keeps MUFU.EX2 off the sm_stats publication critical
    path that gates the correction warp group's consumer wait.

    Returns:
        ``(row_max_safe, acc_scale_log2_or_zero)`` where
        - ``row_max_safe`` follows the same rules as in ``update_row_max``
          (including the ``rescale_threshold`` rollback);
        - the second element is ``0.0`` on the first tile or when the rollback
          fired, and otherwise ``(row_max_old - row_max_safe) * scale_log2``.

    Side effect: ``self.row_max[0]`` is overwritten with the maximum that was kept.
    """
    log2_delta_out = Float32(0.0)
    if cutlass.const_expr(is_first):
        cur_max = self._compute_row_max(acc_S_row)
        safe_max = cur_max if cur_max != -cutlass.Float32.inf else 0.0
    else:
        prev_max = self.row_max[0]
        cur_max = self._compute_row_max(acc_S_row, init_val=prev_max)
        safe_max = cur_max if cur_max != -cutlass.Float32.inf else 0.0
        log2_delta = (prev_max - safe_max) * self.scale_log2
        if cutlass.const_expr(self.rescale_threshold > 0.0):
            if log2_delta >= -self.rescale_threshold:
                # Rollback: keep the previous maximum; the published value
                # stays 0.0, which signals that no rescale is needed.
                cur_max = prev_max
                safe_max = prev_max
            else:
                log2_delta_out = log2_delta
        else:
            log2_delta_out = log2_delta
    self.row_max[0] = cur_max
    return safe_max, log2_delta_out
322
+
323
@cute.jit
def update_row_max_only(self, acc_S_row: cute.TensorSSA, is_first: int) -> None:
    """Update ``self.row_max[0]`` from ``acc_S_row`` without producing a scale.

    On the first tile (``is_first`` truthy at compile time) the maximum is
    computed from ``acc_S_row`` alone; otherwise the stored maximum seeds the
    reduction through ``init_val``.
    """
    if cutlass.const_expr(is_first):
        updated_max = self._compute_row_max(acc_S_row)
    else:
        updated_max = self._compute_row_max(acc_S_row, init_val=self.row_max[0])
    self.row_max[0] = updated_max
330
+
331
def update_row_sum(
    self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False
) -> None:
    """Accumulate ``acc_S_row_exp`` into ``self.row_sum[0]``.

    Args:
        acc_S_row_exp: Values passed to ``_compute_row_sum``.
        row_scale: Factor applied to the stored sum before it seeds the
            reduction; ignored on the first tile.
        is_first: Compile-time flag; when truthy no seed is supplied
            (``init_val=None``).
    """
    if cutlass.const_expr(not is_first):
        seed = self.row_sum[0] * row_scale
    else:
        seed = None
    self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=seed)
336
+
337
@cute.jit
def compute_scaled_exp2_row_sum(
    self,
    acc_S_row: cute.Tensor,
    scale: Float32,
) -> Float32:
    """Delegate to ``utils.fadd_exp2_scaled_reduce`` for this object's ``arch``.

    Returns whatever ``utils.fadd_exp2_scaled_reduce(acc_S_row, scale, arch=...)``
    returns; the reduction itself is defined in that helper.
    """
    reduced = utils.fadd_exp2_scaled_reduce(acc_S_row, scale, arch=self.arch)
    return reduced
344
+
345
@cute.jit
def scale_subtract_rowmax(
    self,
    acc_S_row: cute.Tensor,
    row_max: Float32,
):
    """Rewrite ``acc_S_row`` in place as ``x * scale_log2 - row_max * scale_log2``.

    Elements are processed two at a time with ``fma_packed_f32x2``, which is
    why the element count must be even.
    """
    num_elems = cute.size(acc_S_row.shape)
    assert num_elems % 2 == 0, "acc_S_row must have an even number of elements"
    row_max_scaled = row_max * self.scale_log2
    for idx in cutlass.range(0, num_elems, 2, unroll_full=True):
        pair = (acc_S_row[idx], acc_S_row[idx + 1])
        acc_S_row[idx], acc_S_row[idx + 1] = cute.arch.fma_packed_f32x2(
            pair,
            (self.scale_log2, self.scale_log2),
            (-row_max_scaled, -row_max_scaled),
        )
359
+
360
@cute.jit
def apply_exp2_convert(
    self,
    acc_S_row: cute.Tensor,
    acc_S_row_converted: cute.Tensor,
    ex2_emu_freq: cutlass.Constexpr[int] = 0,
    ex2_emu_res: cutlass.Constexpr[int] = 4,
    ex2_emu_start_frg: cutlass.Constexpr[int] = 0,
):
    """Apply ``exp2`` to ``acc_S_row`` in place and store a converted copy.

    The row is split into fragments of 32 elements. After each fragment has
    been exponentiated it is converted to ``acc_S_row_converted.element_type``
    and stored into the matching fragment of ``acc_S_row_converted``.

    Args:
        acc_S_row: Input/output row; its size must be a multiple of 32.
        acc_S_row_converted: Destination for the converted values.
        ex2_emu_freq: ``0`` means every pair uses ``cute.math.exp2``. When
            non-zero, the pair starting at index ``k`` of fragment ``j`` uses
            ``utils.ex2_emulation_2`` unless
            ``k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res``, ``j`` is the
            last fragment, or ``j < ex2_emu_start_frg``.
        ex2_emu_res: See ``ex2_emu_freq``.
        ex2_emu_start_frg: First fragment index eligible for emulation.
    """
    assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
    chunk = 32
    assert chunk % 2 == 0
    num_chunks = cute.size(acc_S_row) // chunk
    assert cute.size(acc_S_row) % chunk == 0
    src_chunks = cute.logical_divide(acc_S_row, cute.make_layout(chunk))
    dst_chunks = cute.logical_divide(acc_S_row_converted, cute.make_layout(chunk))
    for c in cutlass.range_constexpr(num_chunks):
        for e in cutlass.range_constexpr(0, cute.size(src_chunks, mode=[0]), 2):
            if cutlass.const_expr(ex2_emu_freq == 0):
                # Emulation disabled: hardware exp2 for both elements of the pair.
                src_chunks[e, c] = cute.math.exp2(src_chunks[e, c], fastmath=True)
                src_chunks[e + 1, c] = cute.math.exp2(src_chunks[e + 1, c], fastmath=True)
            else:
                if cutlass.const_expr(
                    e % ex2_emu_freq < ex2_emu_freq - ex2_emu_res
                    or c >= num_chunks - 1
                    or c < ex2_emu_start_frg
                ):
                    src_chunks[e, c] = cute.math.exp2(src_chunks[e, c], fastmath=True)
                    src_chunks[e + 1, c] = cute.math.exp2(src_chunks[e + 1, c], fastmath=True)
                else:
                    src_chunks[e, c], src_chunks[e + 1, c] = utils.ex2_emulation_2(
                        src_chunks[e, c], src_chunks[e + 1, c]
                    )
        # Convert and store the finished fragment.
        converted = src_chunks[None, c].load().to(acc_S_row_converted.element_type)
        dst_chunks[None, c].store(converted)
400
+
401
@cute.jit
def scale_apply_exp2_convert(
    self,
    acc_S_row: cute.Tensor,
    row_max: Float32,
    acc_S_row_converted: cute.Tensor,
):
    """Compute ``exp2(x * scale_log2 - row_max * scale_log2)`` and store a converted copy.

    The scale-and-subtract step runs over the whole row first (two elements per
    ``fma_packed_f32x2``), then ``exp2`` is applied per 32-element fragment.
    Each finished fragment is converted to ``acc_S_row_converted.element_type``
    and stored. ``acc_S_row`` itself is overwritten with the float results.
    """
    assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
    neg_max_scaled = -row_max * self.scale_log2
    for idx in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2):
        acc_S_row[idx], acc_S_row[idx + 1] = cute.arch.fma_packed_f32x2(
            (acc_S_row[idx], acc_S_row[idx + 1]),
            (self.scale_log2, self.scale_log2),
            (neg_max_scaled, neg_max_scaled),
        )

    chunk = 32
    assert chunk % 2 == 0
    num_chunks = cute.size(acc_S_row) // chunk
    assert cute.size(acc_S_row) % chunk == 0
    src_chunks = cute.logical_divide(acc_S_row, cute.make_layout(chunk))
    dst_chunks = cute.logical_divide(acc_S_row_converted, cute.make_layout(chunk))
    for c in cutlass.range_constexpr(num_chunks):
        for e in cutlass.range_constexpr(0, cute.size(src_chunks, mode=[0]), 2):
            src_chunks[e, c] = cute.math.exp2(src_chunks[e, c], fastmath=True)
            src_chunks[e + 1, c] = cute.math.exp2(src_chunks[e + 1, c], fastmath=True)
        # Convert and store the finished fragment.
        converted = src_chunks[None, c].load().to(acc_S_row_converted.element_type)
        dst_chunks[None, c].store(converted)
432
+
433
@cute.jit
def scale_apply_exp2_convert_sum(
    self,
    acc_S_row: cute.Tensor,
    row_max: Float32,
    acc_S_row_converted: cute.Tensor,
    init_sum: Float32,
    ex2_emu_freq: cutlass.Constexpr[int] = 0,
    ex2_emu_res: cutlass.Constexpr[int] = 4,
    ex2_emu_start_frg: cutlass.Constexpr[int] = 0,
) -> Float32:
    """Fused scale, ``exp2``, convert and row-sum over ``acc_S_row``.

    For every pair of elements: apply ``x * scale_log2 - row_max * scale_log2``
    with ``fma_packed_f32x2``, exponentiate, and add the pair into a packed
    two-lane accumulator seeded with ``(init_sum, 0.0)``. After each 32-element
    fragment, the fragment is converted to ``acc_S_row_converted.element_type``
    and stored.

    When ``ex2_emu_freq > 0``, the pair starting at index ``k`` of fragment
    ``j`` uses the polynomial ``utils.ex2_emulation_2`` instead of
    ``cute.math.exp2`` unless ``k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res``,
    ``j`` is the last fragment, or ``j < ex2_emu_start_frg``. This mirrors
    ``apply_exp2_convert``. Per the original author's note, the intent is to
    remove the MUFU "wait" stall that was the second-largest stall bucket in
    decode (~22% of total).

    Returns:
        The sum of the two accumulator lanes.
    """
    assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
    neg_max_scaled = -row_max * self.scale_log2
    packed_sum = (init_sum, Float32(0.0))

    chunk = 32
    assert chunk % 2 == 0
    num_chunks = cute.size(acc_S_row) // chunk
    assert cute.size(acc_S_row) % chunk == 0
    src_chunks = cute.logical_divide(acc_S_row, cute.make_layout(chunk))
    dst_chunks = cute.logical_divide(acc_S_row_converted, cute.make_layout(chunk))
    for c in cutlass.range_constexpr(num_chunks):
        for e in cutlass.range_constexpr(0, cute.size(src_chunks, mode=[0]), 2):
            # Scale and subtract the row maximum for this pair.
            src_chunks[e, c], src_chunks[e + 1, c] = cute.arch.fma_packed_f32x2(
                (src_chunks[e, c], src_chunks[e + 1, c]),
                (self.scale_log2, self.scale_log2),
                (neg_max_scaled, neg_max_scaled),
            )
            if cutlass.const_expr(ex2_emu_freq == 0):
                src_chunks[e, c] = cute.math.exp2(src_chunks[e, c], fastmath=True)
                src_chunks[e + 1, c] = cute.math.exp2(src_chunks[e + 1, c], fastmath=True)
            else:
                use_real = cutlass.const_expr(
                    e % ex2_emu_freq < ex2_emu_freq - ex2_emu_res
                    or c >= num_chunks - 1
                    or c < ex2_emu_start_frg
                )
                if cutlass.const_expr(use_real):
                    src_chunks[e, c] = cute.math.exp2(src_chunks[e, c], fastmath=True)
                    src_chunks[e + 1, c] = cute.math.exp2(src_chunks[e + 1, c], fastmath=True)
                else:
                    src_chunks[e, c], src_chunks[e + 1, c] = utils.ex2_emulation_2(
                        src_chunks[e, c],
                        src_chunks[e + 1, c],
                    )
            # Accumulate the exponentiated pair lane-wise.
            packed_sum = cute.arch.add_packed_f32x2(
                packed_sum,
                (src_chunks[e, c], src_chunks[e + 1, c]),
            )
        # Convert and store the finished fragment.
        converted = src_chunks[None, c].load().to(acc_S_row_converted.element_type)
        dst_chunks[None, c].store(converted)
    return packed_sum[0] + packed_sum[1]
build/torch211-cxx11-cu128-x86_64-linux/src/common/tile_scheduler.py ADDED
@@ -0,0 +1,967 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ from enum import IntEnum, auto
5
+ from typing import Optional, Tuple, Protocol, runtime_checkable
6
+ from dataclasses import dataclass
7
+
8
+ try:
9
+ from typing import override
10
+ except ImportError: # Python < 3.12
11
+ from typing_extensions import override
12
+
13
+ import cutlass
14
+ from cutlass.pipeline import PipelineClcFetchAsync, PipelineState
15
+ from cutlass._mlir import ir
16
+ import cutlass.cute as cute
17
+ from cutlass import Int32, const_expr
18
+ from cutlass.cute import FastDivmodDivisor
19
+ from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams
20
+
21
+ from ...quack.cute_dsl_utils import ParamsBase
22
+
23
+ from ...src.common import utils as utils
24
+ from ...src.common.fast_math import clz
25
+
26
+
27
class SchedulingMode(IntEnum):
    """How a tile scheduler obtains its work tiles.

    Values are assigned by ``auto()`` in declaration order, starting at 1.
    """

    NONE = auto()
    STATIC = auto()
    DYNAMIC = auto()
    CLC = auto()
32
+
33
+
34
@dataclass
class ClcState(ParamsBase):
    """Runtime state shared by tile schedulers that support CLC scheduling.

    `SparseAttentionForwardSm100` builds this object, since it owns the CLC
    response buffer, the mbarrier storage and the launch geometry required to
    set up the hardware scheduler and the async pipeline. A tile scheduler
    consumes the state and translates the hardware work tiles it returns into
    its own logical `WorkTileInfo` coordinates.

    Adding CLC support to a scheduler involves:
    - implementing `clc_problem_shape(params)` so the kernel can create the hardware scheduler
    - accepting `clc: ClcState | None` in `create(...)` / `__init__`
    - mapping `clc.initial_work_tile_info()` and `clc.get_current_work()` into scheduler coordinates
    """

    # Hardware scheduler that issues the CLC queries and decodes the responses.
    _hw_scheduler: ClcDynamicPersistentTileScheduler
    # Async pipeline used to hand responses from producer to consumers.
    _pipeline: PipelineClcFetchAsync
    # Pipeline cursor advanced by `consumer_release`.
    _consumer_state: PipelineState
    # Pipeline cursor advanced by `prefetch_next_work`.
    _producer_state: PipelineState

    @staticmethod
    def create(
        *,
        hw_scheduler: ClcDynamicPersistentTileScheduler,
        pipeline: PipelineClcFetchAsync,
        consumer_state: PipelineState,
        producer_state: PipelineState,
    ) -> "ClcState":
        """Keyword-only factory that bundles the four components into a `ClcState`."""
        return ClcState(
            _hw_scheduler=hw_scheduler,
            _pipeline=pipeline,
            _consumer_state=consumer_state,
            _producer_state=producer_state,
        )

    def initial_work_tile_info(self):
        """Forward to the hardware scheduler's `initial_work_tile_info()`."""
        return self._hw_scheduler.initial_work_tile_info()

    def get_current_work(self):
        """Forward to the hardware scheduler's `get_current_work()`."""
        return self._hw_scheduler.get_current_work()

    def prefetch_next_work(self, *, loc=None, ip=None):
        """Producer side: acquire a stage, issue the next CLC query, advance the cursor."""
        producer = self._producer_state
        self._pipeline.producer_acquire(producer, loc=loc, ip=ip)
        barrier = self._pipeline.producer_get_barrier(producer, loc=loc, ip=ip)
        self._hw_scheduler.advance_to_next_work(barrier, loc=loc, ip=ip)
        producer.advance(loc=loc, ip=ip)

    def consumer_wait(self, *, loc=None, ip=None):
        """Consumer side: wait on the pipeline at the current consumer cursor."""
        self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip)

    def consumer_release(self, *, loc=None, ip=None):
        """Consumer side: release the current stage, then advance the consumer cursor."""
        consumer = self._consumer_state
        self._pipeline.consumer_release(consumer, loc=loc, ip=ip)
        consumer.advance(loc=loc, ip=ip)

    def producer_tail(self, *, loc=None, ip=None):
        """Producer side: run the pipeline's `producer_tail` at the producer cursor."""
        self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip)
86
+
87
+
88
class WorkTileInfo(cutlass.utils.WorkTileInfo):
    """`cutlass.utils.WorkTileInfo` variant with four axes: (block, head, batch, split)."""

    @override
    def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo":
        """Rebuild from five MLIR values: four tile coordinates, then the validity flag."""
        assert len(values) == 5
        coord_values, valid_value = values[:-1], values[-1]
        tile_idx = cutlass.new_from_mlir_values(self._tile_idx, coord_values)
        is_valid_tile = cutlass.new_from_mlir_values(self._is_valid_tile, [valid_value])
        return WorkTileInfo(tile_idx, is_valid_tile)
97
+
98
+
99
@runtime_checkable
class TileSchedulerProtocol(Protocol):
    """Structural interface that every tile scheduler implements.

    A scheduler has two jobs:
    1. Coordinate mapping: linear tile index -> (m_block, head, batch, split)
    2. Work distribution: how the next tile is obtained (static grid-stride vs CLC dynamic)
    """

    def get_current_work(self) -> WorkTileInfo:
        """Return the coordinates of the current work tile."""
        ...

    def initial_work_tile_info(self) -> WorkTileInfo:
        """Return the first work tile assigned to this CTA."""
        ...

    def advance_to_next_work(self, *, loc=None, ip=None):
        """Consumer-side step: move to the next tile and return it.

        Static schedulers: grid-stride increment followed by get_current_work.
        CLC schedulers: consumer wait, get_current_work, consumer release, state advance.
        """
        ...

    def prefetch_next_work(self, *, loc=None, ip=None) -> None:
        """Producer-side prefetch of the next work tile (a no-op for static schedulers).

        CLC schedulers: producer acquire, issue the CLC query, advance the producer state.
        Called only by the scheduler warp.
        """
        ...

    def producer_tail(self, *, loc=None, ip=None) -> None:
        """Producer-side cleanup once the last tile has been handled.

        A no-op for static schedulers; CLC schedulers run the pipeline's producer_tail.
        """
        ...
138
+
139
+
140
@dataclass
class TileSchedulerArguments(ParamsBase):
    """Argument bundle that each scheduler's ``to_underlying_arguments`` turns into ``Params``.

    ``Int32`` fields are runtime values; ``cutlass.Constexpr`` fields are
    compile-time constants. Field order is part of the positional constructor
    and must not change.
    """

    # Problem extents.
    num_block: Int32
    num_head: Int32
    num_batch: Int32
    num_splits: Int32
    seqlen_k: Int32
    headdim: Int32
    headdim_v: Int32
    total_q: Int32
    # Compile-time tiling / launch configuration.
    tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]
    cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
    # Optional tensors; presumably variable-length metadata -- confirm with callers.
    mCuSeqlensQ: Optional[cute.Tensor] = None
    mSeqUsedQ: Optional[cute.Tensor] = None
    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
    # Used by SingleTileLPTScheduler when sizing one KV head (multiplies seqlen_k * head dims).
    element_size: cutlass.Constexpr[int] = 2
    # Compile-time feature switches.
    is_persistent: cutlass.Constexpr[bool] = False
    lpt: cutlass.Constexpr[bool] = False
    is_split_kv: cutlass.Constexpr[bool] = False
    head_swizzle: cutlass.Constexpr[bool] = False
    use_cluster_idx: cutlass.Constexpr[bool] = False
161
+
162
+
163
class SingleTileScheduler:
    """Scheduler in which each block (or cluster) coordinate maps to exactly one tile.

    The first tile reported is valid; after ``advance_to_next_work`` the
    reported tile is marked invalid.
    """

    @dataclass
    class Params(ParamsBase):
        """Parameters derived from ``TileSchedulerArguments``."""

        num_block: Int32
        num_head: Int32
        num_batch: Int32
        num_splits: Int32
        # Fast divisor used to split the packed (head, split) grid axis.
        num_splits_divmod: FastDivmodDivisor
        is_split_kv: cutlass.Constexpr[bool] = False
        cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
        use_cluster_idx: cutlass.Constexpr[bool] = False

        @staticmethod
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileScheduler.Params":
            """Copy the relevant fields of ``args`` and precompute the split divisor."""
            return SingleTileScheduler.Params(
                num_block=args.num_block,
                num_head=args.num_head,
                num_batch=args.num_batch,
                num_splits=args.num_splits,
                num_splits_divmod=FastDivmodDivisor(args.num_splits),
                is_split_kv=args.is_split_kv,
                cluster_shape_mn=args.cluster_shape_mn,
                use_cluster_idx=args.use_cluster_idx,
            )

    def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None):
        self.params = params
        self._blk_coord = blk_coord
        # Python-level flag used as the validity bit of the reported work tile.
        self._is_first_block = True
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(
        args: TileSchedulerArguments,
        *,
        scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
        loc=None,
        ip=None,
    ) -> Params:
        """Build ``Params`` from ``args``; only ``SchedulingMode.STATIC`` is accepted."""
        assert scheduling_mode == SchedulingMode.STATIC, (
            f"SingleTileScheduler only supports STATIC, got {scheduling_mode!r}"
        )
        return SingleTileScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    def create(
        params: Params, clc: ClcState | None = None, *, loc=None, ip=None
    ) -> "SingleTileScheduler":
        """Create a scheduler keyed on the block index, or on the cluster index
        when clusters are in use and ``use_cluster_idx`` is set. ``clc`` is unused."""
        if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx):
            coord = cute.arch.block_idx()
        else:
            coord = cute.arch.cluster_idx()
        return SingleTileScheduler(params, coord, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        """Return the launch grid ``(x, num_head * num_splits, num_batch)``."""
        # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1)
        assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
        cluster_m = params.cluster_shape_mn[0]
        if const_expr(params.use_cluster_idx):
            # One cluster per block index: num_block clusters need
            # num_block * cluster_m physical blocks along x.
            grid_x = params.num_block * cluster_m
        else:
            grid_x = cute.round_up(params.num_block, cluster_m)
        grid_y = params.num_head * params.num_splits
        return (grid_x, grid_y, params.num_batch)

    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Map the stored coordinate to ``(block, head, batch, split)``."""
        block_idx, head_idx, batch_idx = self._blk_coord
        if const_expr(self.params.is_split_kv):
            # The y axis packs (head, split); unpack it with the fast divisor.
            head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod)
        else:
            split_idx = Int32(0)
        coords = (block_idx, head_idx, batch_idx, split_idx)
        return WorkTileInfo(coords, self._is_first_block)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        """Same as ``get_current_work``."""
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        """No-op: there is nothing to prefetch."""
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        """Clear the first-block flag so the returned tile is reported as invalid."""
        self._is_first_block = False
        return self.get_current_work()

    def producer_tail(self, *, loc=None, ip=None):
        """No-op: there is no producer pipeline."""
        pass

    def __extract_mlir_values__(self):
        """Flatten ``params`` and the block coordinate into one list of MLIR values,
        recording how many values each contributed."""
        flat, counts = [], []
        self._values_pos = counts
        for member in [self.params, self._blk_coord]:
            member_values = cutlass.extract_mlir_values(member)
            flat += member_values
            counts.append(len(member_values))
        return flat

    def __new_from_mlir_values__(self, values):
        """Rebuild a scheduler from the flat list produced by ``__extract_mlir_values__``."""
        rebuilt = []
        remaining = values
        for template, count in zip([self.params, self._blk_coord], self._values_pos):
            rebuilt.append(cutlass.new_from_mlir_values(template, remaining[:count]))
            remaining = remaining[count:]
        return SingleTileScheduler(*rebuilt, loc=self._loc)
279
+
280
+
281
class StaticPersistentTileScheduler:
    """Persistent scheduler that walks a linear tile index in fixed strides.

    The linear index is decomposed as ``((batch, head), block)`` with ``block``
    fastest; the split coordinate is always ``0``.
    """

    @dataclass
    class Params(ParamsBase):
        """Precomputed divisors and the total number of (cluster-level) tiles."""

        num_block_cluster_divmod: FastDivmodDivisor
        num_head_divmod: FastDivmodDivisor
        total_blocks_cluster: Int32
        cluster_shape_m: cutlass.Constexpr[int] = 1

        @staticmethod
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "StaticPersistentTileScheduler.Params":
            """Derive the divisors and total tile count from ``args``."""
            cluster_size = cute.size(args.cluster_shape_mn)
            blocks_per_cluster_axis = cute.ceil_div(args.num_block, cluster_size)
            total = blocks_per_cluster_axis * args.num_head * args.num_batch
            return StaticPersistentTileScheduler.Params(
                FastDivmodDivisor(blocks_per_cluster_axis),
                FastDivmodDivisor(args.num_head),
                total,
                cluster_shape_m=args.cluster_shape_mn[0],
            )

    def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(
        args: TileSchedulerArguments,
        *,
        scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
        loc=None,
        ip=None,
    ) -> Params:
        """Build ``Params`` from ``args``; only ``SchedulingMode.STATIC`` is accepted."""
        assert scheduling_mode == SchedulingMode.STATIC, (
            f"StaticPersistentTileScheduler only supports STATIC, got {scheduling_mode!r}"
        )
        return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    def create(
        params: Params, clc: ClcState | None = None, *, loc=None, ip=None
    ) -> "StaticPersistentTileScheduler":
        """Start at the x block index, or the x cluster index when clusters are used.
        ``clc`` is unused."""
        if const_expr(cute.size(params.cluster_shape_m) == 1):
            start_idx = cute.arch.block_idx()[0]
        else:
            start_idx = cute.arch.cluster_idx()[0]
        return StaticPersistentTileScheduler(params, start_idx, loc=loc, ip=ip)

    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        usable_SM_count=0,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        """Return a 1-D grid capped by the SM count and by the amount of work.

        ``usable_SM_count > 0`` overrides the device multiprocessor count.
        """
        hardware_info = cutlass.utils.HardwareInfo()
        cluster_m = int(params.cluster_shape_m)
        if usable_SM_count > 0:
            sm_count = usable_SM_count
        else:
            sm_count = hardware_info.get_device_multiprocessor_count()
        # Round the SM count down to a whole number of clusters, but keep at
        # least one cluster's worth of CTAs.
        max_ctas = max((sm_count // cluster_m) * cluster_m, cluster_m)
        grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * cluster_m)
        return (grid_x, Int32(1), Int32(1))

    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Decompose the linear tile index; the tile is valid while the index is in range."""
        head_batch, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod)
        batch_idx, head_idx = divmod(head_batch, self.params.num_head_divmod)
        in_range = self._tile_idx < self.params.total_blocks_cluster
        coords = (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0))
        return WorkTileInfo(coords, in_range)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        """Same as ``get_current_work``."""
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        """No-op: there is nothing to prefetch."""
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        """Step the tile index by ``grid_dim()[0]`` (no clusters) or
        ``cluster_dim()[0]`` (clusters) and return the new tile."""
        if const_expr(self.params.cluster_shape_m == 1):
            self._tile_idx += cute.arch.grid_dim()[0]
        else:
            self._tile_idx += cute.arch.cluster_dim()[0]
        return self.get_current_work()

    def producer_tail(self, *, loc=None, ip=None):
        """No-op: there is no producer pipeline."""
        pass

    def __extract_mlir_values__(self):
        """Flatten ``params`` and the tile index into one list of MLIR values,
        recording how many values each contributed."""
        flat, counts = [], []
        self._values_pos = counts
        for member in [self.params, self._tile_idx]:
            member_values = cutlass.extract_mlir_values(member)
            flat += member_values
            counts.append(len(member_values))
        return flat

    def __new_from_mlir_values__(self, values):
        """Rebuild a scheduler from the flat list produced by ``__extract_mlir_values__``."""
        rebuilt = []
        remaining = values
        for template, count in zip([self.params, self._tile_idx], self._values_pos):
            rebuilt.append(cutlass.new_from_mlir_values(template, remaining[:count]))
            remaining = remaining[count:]
        return StaticPersistentTileScheduler(*rebuilt, loc=self._loc)
391
+
392
+
393
+ class SingleTileLPTScheduler:
394
@dataclass
class Params(ParamsBase):
    """Parameters for ``SingleTileLPTScheduler``: extents, L2-swizzle divisors, flags."""

    total_blocks: Int32
    num_splits: Int32
    num_block: Int32
    num_head: Int32
    num_batch: Int32
    # Swizzle width: number of (head, batch) entries grouped into one section.
    l2_minor: Int32
    num_head_divmod: FastDivmodDivisor
    l2_minor_divmod: FastDivmodDivisor
    l2_major_divmod: FastDivmodDivisor
    # Divisor used for the last (partial) section instead of the swizzle width.
    l2_minor_residual_divmod: FastDivmodDivisor
    # Number of full swizzle sections.
    num_hb_quotient: Int32
    num_splits_divmod: FastDivmodDivisor
    is_split_kv: cutlass.Constexpr[bool] = False
    cluster_shape_m: cutlass.Constexpr[int] = 1
    scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC
    lpt: cutlass.Constexpr[bool] = True
    use_cluster_idx: cutlass.Constexpr[bool] = True

    @staticmethod
    @cute.jit
    def create(
        args: TileSchedulerArguments,
        *,
        scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
        loc=None,
        ip=None,
    ) -> "SingleTileLPTScheduler.Params":
        """Derive the swizzle layout from ``args``; accepts STATIC or CLC scheduling."""
        assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), (
            f"Only STATIC and CLC are supported, got {scheduling_mode!r}"
        )
        # Bytes of K plus V for one head.
        size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
        size_one_head = size_one_kv_head
        # Budget of 50 MiB for K and V.
        size_l2 = 50 * 1024 * 1024
        # The swizzle is the size of each "section": how many heads fit in the
        # budget, rounded down to a power of two (which seems to be faster).
        # When not even one head fits, fall back to a swizzle of 1.
        log2_floor = lambda n: 31 - clz(n)
        swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
        # The last section (the residual) is divided by the remainder rather
        # than by the swizzle.
        num_head_batch = args.num_head * args.num_batch
        num_hb_quotient = num_head_batch // swizzle
        num_hb_remainder = num_head_batch % swizzle
        return SingleTileLPTScheduler.Params(
            total_blocks=args.num_block * args.num_head * args.num_batch,
            num_block=args.num_block,
            num_head=args.num_head,
            num_batch=args.num_batch,
            l2_minor=Int32(swizzle),
            num_head_divmod=FastDivmodDivisor(args.num_head),
            l2_minor_divmod=FastDivmodDivisor(swizzle),
            l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block),
            # Guard against a zero divisor when there is no residual section.
            l2_minor_residual_divmod=FastDivmodDivisor(max(num_hb_remainder, 1)),
            num_hb_quotient=Int32(num_hb_quotient),
            num_splits=args.num_splits,
            num_splits_divmod=FastDivmodDivisor(args.num_splits),
            is_split_kv=args.is_split_kv,
            cluster_shape_m=args.cluster_shape_mn[0],
            scheduling_mode=scheduling_mode,
            lpt=args.lpt,
            use_cluster_idx=args.use_cluster_idx,
        )
458
+
459
def __init__(
    self,
    params: Params,
    tile_idx: Int32,
    split_idx: Int32,
    clc: ClcState | None = None,
    *,
    loc=None,
    ip=None,
):
    """Store the parameters, the starting tile/split indices and the optional CLC state."""
    self.params = params
    self.clc = clc
    self._tile_idx, self._split_idx = tile_idx, split_idx
    self._loc, self._ip = loc, ip
475
+
476
@staticmethod
def to_underlying_arguments(
    args: TileSchedulerArguments,
    *,
    scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
    loc=None,
    ip=None,
) -> Params:
    """Build ``Params`` from ``args`` by delegating to ``Params.create``."""
    params = SingleTileLPTScheduler.Params.create(
        args, scheduling_mode=scheduling_mode, loc=loc, ip=ip
    )
    return params
487
+
488
@staticmethod
def _clc_grid_shape(params: Params):
    """Return the 3-D grid used in CLC mode.

    x is ``num_block`` rounded up to a multiple of the cluster size, y is
    ``num_head``, and z is ``num_batch`` (times ``num_splits`` when split-KV
    is enabled).
    """
    if const_expr(params.is_split_kv):
        grid_z = params.num_batch * params.num_splits
    else:
        grid_z = params.num_batch
    grid_x = cute.round_up(params.num_block, params.cluster_shape_m)
    return (grid_x, params.num_head, grid_z)
500
+
501
@staticmethod
@cute.jit
def clc_problem_shape(params: Params):
    """Describe the CLC problem: the CLC grid shape plus a ``(cluster_shape_m, 1, 1)`` cluster."""
    grid_shape = SingleTileLPTScheduler._clc_grid_shape(params)
    return ClcDynamicPersistentTileSchedulerParams(
        problem_shape_ntile_mnl=grid_shape,
        cluster_shape_mnk=(params.cluster_shape_m, 1, 1),
    )
508
+
509
@staticmethod
@cute.jit
def create(
    params: Params, clc: ClcState | None = None, *, loc=None, ip=None
) -> "SingleTileLPTScheduler":
    """Instantiate the scheduler for the current block.

    CLC mode: the tile index is the x block index, the split index is 0 and
    ``clc`` is attached. Static mode: the x and y block indices become the
    tile and split indices, and ``clc`` is not attached.
    """
    if const_expr(params.scheduling_mode == SchedulingMode.CLC):
        return SingleTileLPTScheduler(
            params, cute.arch.block_idx()[0], Int32(0), clc, loc=loc, ip=ip
        )
    bid_x, bid_y, _ = cute.arch.block_idx()
    return SingleTileLPTScheduler(params, bid_x, bid_y, loc=loc, ip=ip)
520
+
521
@staticmethod
def get_grid_shape(
    params: Params,
    *,
    loc=None,
    ip=None,
) -> Tuple[Int32, Int32, Int32]:
    """Return the launch grid: the CLC grid in CLC mode, else ``(total_blocks, num_splits, 1)``."""
    if const_expr(params.scheduling_mode == SchedulingMode.CLC):
        return SingleTileLPTScheduler._clc_grid_shape(params)
    static_grid = (params.total_blocks, params.num_splits, Int32(1))
    return static_grid
531
+
532
@cute.jit
def clc_work_to_coords(self, work) -> WorkTileInfo:
    """Translate a CLC response ``(block, head, batch_split)`` into a ``WorkTileInfo``.

    The CLC hands back raw grid coordinates, so no L2 swizzle is applied here
    (the hardware decides the order). Only cluster division, the optional LPT
    block reversal and split-KV unpacking are performed.
    """
    p = self.params
    blk = work.tile_idx[0]
    if const_expr(p.cluster_shape_m > 1):
        # Convert a block index into a cluster index.
        blk = blk // p.cluster_shape_m
    if const_expr(p.lpt):
        # Longest-processing-time-first: reverse the block order.
        if const_expr(p.cluster_shape_m > 1 and not p.use_cluster_idx):
            blk_count = p.num_block // p.cluster_shape_m
        else:
            blk_count = p.num_block
        blk = blk_count - 1 - blk
    split = Int32(0)
    if const_expr(p.is_split_kv):
        # The z axis packs (batch, split).
        batch, split = divmod(work.tile_idx[2], p.num_splits_divmod)
    else:
        batch = work.tile_idx[2]
    if const_expr(p.cluster_shape_m > 1 and not p.use_cluster_idx):
        # Expand back to a per-block index within the cluster.
        in_cluster = cute.arch.block_in_cluster_idx()
        blk = blk * p.cluster_shape_m + in_cluster[0]
    coords = (Int32(blk), Int32(work.tile_idx[1]), Int32(batch), Int32(split))
    return WorkTileInfo(coords, work.is_valid_tile)
561
+
562
@cute.jit
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
    """Return the current tile as ``(block, head, batch, split)`` plus a validity flag.

    CLC mode: read the hardware scheduler's current work, remember its x
    index, and map it with ``clc_work_to_coords``. Static mode: decode the
    linear tile index through the L2-swizzled layout.
    """
    if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
        hw_work = self.clc.get_current_work()
        self._tile_idx = hw_work.tile_idx[0]
        return self.clc_work_to_coords(hw_work)
    # Static path: L2-swizzled coordinate mapping.
    p = self.params
    section, within_section = divmod(self._tile_idx, p.l2_major_divmod)
    # Full sections are divided by the swizzle width; the last (residual)
    # section is divided by the remainder instead.
    block, hb_offset = 0, 0
    if section < p.num_hb_quotient:
        block, hb_offset = divmod(within_section, p.l2_minor_divmod)
    else:
        block, hb_offset = divmod(within_section, p.l2_minor_residual_divmod)
    head_batch = section * p.l2_minor + hb_offset
    batch_idx, head_idx = divmod(head_batch, p.num_head_divmod)
    # Longest-processing-time-first: reverse the block order.
    if const_expr(p.lpt):
        block = p.num_block - 1 - block
    in_range = self._tile_idx < p.total_blocks
    coords = (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx))
    return WorkTileInfo(coords, in_range)
588
+
589
@cute.jit
def initial_work_tile_info(self, *, loc=None, ip=None):
    """Return the first work tile for this scheduler instance.

    CLC mode: take the initial work item from ``self.clc``, store its first
    tile coordinate in ``self._tile_idx`` and map it with ``clc_work_to_coords``.
    Otherwise this is the same as ``get_current_work``.
    """
    if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
        work = self.clc.initial_work_tile_info()
        self._tile_idx = work.tile_idx[0]
        return self.clc_work_to_coords(work)
    return self.get_current_work(loc=loc, ip=ip)
596
+
597
def prefetch_next_work(self, *, loc=None, ip=None):
    """Ask the CLC state to prefetch the next work item; does nothing in other modes."""
    clc_enabled = self.params.scheduling_mode == SchedulingMode.CLC
    if const_expr(clc_enabled):
        self.clc.prefetch_next_work(loc=loc, ip=ip)
600
+
601
def advance_to_next_work(self, *, loc=None, ip=None):
    """Move on to the next work tile and return it.

    CLC mode: wait on the CLC consumer side, read the current work, then
    release. Static mode: each CTA owns a single tile, so the tile index is set
    to ``total_blocks`` and the returned tile is reported as invalid.
    """
    if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
        self.clc.consumer_wait(loc=loc, ip=ip)
        next_work = self.get_current_work()
        self.clc.consumer_release(loc=loc, ip=ip)
    else:
        # Single tile scheduler - set to invalid tile_idx to indicate no more work
        self._tile_idx = self.params.total_blocks
        next_work = self.get_current_work()
    return next_work
610
+
611
def producer_tail(self, *, loc=None, ip=None):
    """Run the CLC producer tail; does nothing in other scheduling modes."""
    clc_enabled = self.params.scheduling_mode == SchedulingMode.CLC
    if const_expr(clc_enabled):
        self.clc.producer_tail(loc=loc, ip=ip)
614
+
615
def __extract_mlir_values__(self):
    """Flatten the scheduler state into a single list of MLIR values.

    The number of values contributed by each member is recorded in
    ``self._values_pos`` so ``__new_from_mlir_values__`` can split the list again.
    """
    members = [self.params, self._tile_idx, self._split_idx]
    if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
        members.append(self.clc)
    flat_values = []
    self._values_pos = []
    for member in members:
        member_values = cutlass.extract_mlir_values(member)
        flat_values.extend(member_values)
        self._values_pos.append(len(member_values))
    return flat_values
625
+
626
def __new_from_mlir_values__(self, values):
    """Rebuild a scheduler from the flat MLIR value list.

    ``values`` is consumed in the member order used by
    ``__extract_mlir_values__``, using the per-member counts in ``self._values_pos``.
    """
    members = [self.params, self._tile_idx, self._split_idx]
    if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
        members.append(self.clc)
    rebuilt = []
    cursor = 0
    for member, count in zip(members, self._values_pos):
        chunk = values[cursor : cursor + count]
        rebuilt.append(cutlass.new_from_mlir_values(member, chunk))
        cursor += count
    return self.__class__(*rebuilt, loc=self._loc)
635
+
636
+
637
class SingleTileVarlenScheduler:
    """Tile scheduler for variable-length (varlen) batches.

    A flat tile index is mapped to (block, head, batch, split) coordinates by
    scanning per-batch M-block counts with warp-level prefix sums (see
    ``_varlen_coord_map``). Both STATIC and CLC scheduling modes are supported;
    in CLC mode the flat index comes from ``self.clc`` instead of the launch
    block index.
    """

    @dataclass
    class Params(ParamsBase):
        """Scheduler parameters; fields typed ``cutlass.Constexpr`` are compile-time values."""

        num_head: Int32
        num_batch: Int32
        total_q: Int32
        num_splits: Int32
        # Number of KV blocks budgeted for L2; computed in `create`.
        max_kvblock_in_l2: Int32
        tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]
        # At least one of the two tensors below must be provided (asserted in `create`).
        mCuSeqlensQ: Optional[cute.Tensor] = None
        mSeqUsedQ: Optional[cute.Tensor] = None
        qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
        lpt: cutlass.Constexpr[bool] = False
        is_split_kv: cutlass.Constexpr[bool] = False
        head_swizzle: cutlass.Constexpr[bool] = False
        cluster_shape_m: cutlass.Constexpr[int] = 1
        scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC

        @staticmethod
        @cute.jit
        def create(
            args: TileSchedulerArguments,
            *,
            scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
            loc=None,
            ip=None,
        ) -> "SingleTileVarlenScheduler.Params":
            """Build ``Params`` from ``TileSchedulerArguments``.

            Asserts that the scheduling mode is STATIC or CLC, that at least one
            of ``mCuSeqlensQ`` / ``mSeqUsedQ`` is given, that
            ``cluster_shape_mn[1] == 1``, and that CLC is only combined with
            ``cluster_shape_mn[0] == 1``.
            """
            assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), (
                f"Only STATIC and CLC are supported, got {scheduling_mode!r}"
            )
            size_l2 = 50 * 1024 * 1024  # 50 MB for K & V
            # Bytes of K plus V covered by one N-tile (tile_shape_mn[1] rows).
            kv_block_size = (
                (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
            )
            if args.head_swizzle:
                # NOTE(review): extra 4 bytes per (row, headdim element); presumably an
                # fp32 buffer that shares L2 when head swizzling is on — confirm.
                kv_block_size += args.headdim * 4 * args.tile_shape_mn[1]
            max_kvblock_in_l2 = size_l2 // kv_block_size
            assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
                "At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
            )
            assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
            # TODO: Support varlen CLC with cluster_shape_m > 1 by refactoring the
            # flattened-tile decode so cluster unpacking semantics are explicit.
            assert scheduling_mode != SchedulingMode.CLC or args.cluster_shape_mn[0] == 1, (
                "Varlen CLC currently requires cluster_shape_mn[0] == 1"
            )
            return SingleTileVarlenScheduler.Params(
                num_head=args.num_head,
                num_batch=args.num_batch,
                total_q=args.total_q,
                num_splits=args.num_splits,
                max_kvblock_in_l2=max_kvblock_in_l2,
                tile_shape_mn=args.tile_shape_mn,
                mCuSeqlensQ=args.mCuSeqlensQ,
                mSeqUsedQ=args.mSeqUsedQ,
                qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa,
                lpt=args.lpt,
                is_split_kv=args.is_split_kv,
                head_swizzle=args.head_swizzle,
                cluster_shape_m=args.cluster_shape_mn[0],
                scheduling_mode=scheduling_mode,
            )

    def __init__(
        self,
        params: Params,
        tile_idx: Int32,
        split_idx: Int32,
        clc: ClcState | None = None,
        *,
        loc=None,
        ip=None,
    ):
        """Store the parameters and the starting tile/split indices.

        ``_is_first_block`` starts as True; the static path of
        ``advance_to_next_work`` clears it so that later tiles are reported invalid.
        """
        self.params = params
        self._tile_idx = tile_idx
        self._split_idx = split_idx
        self._is_first_block = True
        self.clc = clc
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(
        args: TileSchedulerArguments,
        *,
        scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
        loc=None,
        ip=None,
    ) -> Params:
        """Convert ``TileSchedulerArguments`` to ``Params`` by delegating to ``Params.create``."""
        return SingleTileVarlenScheduler.Params.create(
            args, scheduling_mode=scheduling_mode, loc=loc, ip=ip
        )

    @staticmethod
    @cute.jit
    def clc_problem_shape(params: Params):
        """Return CLC scheduler params whose problem shape is this scheduler's grid shape.

        The cluster shape is fixed to (1, 1, 1).
        """
        return ClcDynamicPersistentTileSchedulerParams(
            problem_shape_ntile_mnl=SingleTileVarlenScheduler.get_grid_shape(params),
            cluster_shape_mnk=(1, 1, 1),
        )

    @staticmethod
    @cute.jit
    def create(
        params: Params, clc: ClcState | None = None, *, loc=None, ip=None
    ) -> "SingleTileVarlenScheduler":
        """Construct a scheduler seeded from the launch block index.

        Block index component 0 is the flat tile index. Component 1 is the split
        index; in CLC mode it is only used when ``params.is_split_kv`` is set
        (otherwise the split index is 0).
        """
        if const_expr(params.scheduling_mode == SchedulingMode.CLC):
            block_idx = cute.arch.block_idx()
            split_idx = Int32(0)
            if const_expr(params.is_split_kv):
                split_idx = block_idx[1]
            return SingleTileVarlenScheduler(
                params,
                block_idx[0],
                split_idx,
                clc,
                loc=loc,
                ip=ip,
            )
        tile_idx, split_idx, _ = cute.arch.block_idx()
        return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        """Return the launch grid as (tiles * heads, num_splits, 1).

        The tile count is an upper bound derived from ``total_q`` plus up to
        ``cluster_shape_m * tile_m - 1`` rows of padding per batch, rounded down
        to a multiple of ``cluster_shape_m``.
        """
        total_blocks_max = (
            params.total_q
            + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1)
        ) // params.tile_shape_mn[0]
        # Round down to nearest multiple of cluster since odd excess is always padding.
        total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m
        return (total_blocks_max * params.num_head, params.num_splits, Int32(1))

    @cute.jit
    def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32:
        """Return the number of cluster-sized M blocks for batch ``lane + bidb_start``.

        The sequence length comes from ``mSeqUsedQ`` when present, otherwise from
        the difference of neighbouring ``mCuSeqlensQ`` entries fetched from the
        next lane. The result is 0 for the last lane of the warp and for batch
        indices at or beyond ``num_batch``.
        """
        params = self.params
        batch_idx = lane + bidb_start
        if cutlass.const_expr(params.mSeqUsedQ is not None):
            seqlen = Int32(0)
            if batch_idx < params.num_batch:
                seqlen = params.mSeqUsedQ[batch_idx]
        else:
            assert params.mCuSeqlensQ is not None
            cur_cu_seqlen = Int32(0)
            # `<=` because entry `num_batch` is read as the end offset of the last batch.
            if batch_idx <= params.num_batch:
                cur_cu_seqlen = params.mCuSeqlensQ[batch_idx]
            # Lane i receives lane i+1's cumulative length, i.e. the end of batch i.
            next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
            seqlen = next_cu_seqlen - cur_cu_seqlen
        if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1):
            seqlen *= params.qhead_per_kvhead_packgqa
        return (
            cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), params.cluster_shape_m)
            if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1
            else Int32(0)
        )

    @cute.jit
    def _varlen_coord_map(self) -> WorkTileInfo:
        """Map self._tile_idx to (block, head, batch) via warp-level prefix sums."""
        params = self.params
        lane_idx = cute.arch.lane_idx()
        num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0)
        num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
        # Total number of blocks for the next 31 batches
        m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1)
        # Same for all lanes
        group_end_tile = m_blocks_in_group * params.num_head
        # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group)
        block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0)
        # Work in units of clusters: cluster_shape_m consecutive tile indices share one slot.
        next_tile_idx = self._tile_idx // params.cluster_shape_m
        # Advance one group (WARP_SIZE - 1 batches) at a time until the group
        # containing next_tile_idx is found or the batches are exhausted.
        while group_end_tile <= next_tile_idx:
            batch_idx += cute.arch.WARP_SIZE - 1
            if batch_idx >= params.num_batch:
                # Out of batches: clamp and force the loop condition to fail.
                batch_idx = Int32(params.num_batch)
                group_end_tile = next_tile_idx + 1
            else:
                num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx)
                num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
                m_blocks_in_group = cute.arch.shuffle_sync(
                    num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1
                )
                group_end_tile += m_blocks_in_group * params.num_head
        is_valid = False
        if batch_idx >= params.num_batch:
            block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch)
        else:
            group_start_tile = group_end_tile - m_blocks_in_group * params.num_head
            # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx)
            # The next problem to process is the first one that does not have ending tile position
            # that is greater than or equal to tile index.
            # Count the lanes whose batch ends at or before next_tile_idx; that
            # count is the offset of the target batch inside the group.
            batch_idx_in_group = cute.arch.popc(
                cute.arch.vote_ballot_sync(
                    group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx
                )
            )
            batch_idx += batch_idx_in_group
            num_m_blocks_prev_lane = (
                0
                if batch_idx_in_group == 0
                else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1)
            )
            # Broadcast the target batch's block count to every lane.
            num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group)
            # Position of the tile inside the target batch's (block, head) space.
            mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head
            if cutlass.const_expr(params.lpt or params.head_swizzle):
                # This is a version of the SingleTileLPTScheduler, complicated by the fact that
                # the seqlen can vary per batch.
                # TODO: is there any case where num_m_blocks is 0?
                # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here
                num_n_blocks = (
                    num_m_blocks
                    * params.tile_shape_mn[0]
                    * params.cluster_shape_m
                    // params.qhead_per_kvhead_packgqa
                    // params.tile_shape_mn[1]
                )
                # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head)
                # Seems faster to have this be a power of 2
                nheads_in_l2 = (
                    16
                    if num_n_blocks * 16 <= params.max_kvblock_in_l2
                    else (
                        8
                        if num_n_blocks * 8 <= params.max_kvblock_in_l2
                        else (
                            4
                            if num_n_blocks * 4 <= params.max_kvblock_in_l2
                            else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1)
                        )
                    )
                )
                nheads_in_l2 = min(nheads_in_l2, params.num_head)
                mh_in_l2 = nheads_in_l2 * num_m_blocks
                section_idx = mh_block // mh_in_l2
                l2_mod = mh_block - section_idx * mh_in_l2
                # Deal with tail section
                nheads_in_this_section = (
                    nheads_in_l2
                    if nheads_in_l2 * (section_idx + 1) <= params.num_head
                    else params.num_head - section_idx * nheads_in_l2
                )
                block = l2_mod // nheads_in_this_section
                head_idx_residual = l2_mod - block * nheads_in_this_section
                head_idx = section_idx * nheads_in_l2 + head_idx_residual
                if cutlass.const_expr(params.lpt):
                    # Reverse the block order within the batch.
                    block = num_m_blocks - 1 - block
            else:
                head_idx = mh_block // num_m_blocks
                block = mh_block - head_idx * num_m_blocks
            # Valid only before the static path of advance_to_next_work has cleared the flag.
            is_valid = self._is_first_block and batch_idx < params.num_batch
            if cutlass.const_expr(params.cluster_shape_m > 1):
                # Expand the cluster-level block index to this CTA's block within the cluster.
                bidx_in_cluster = cute.arch.block_in_cluster_idx()
                block = block * params.cluster_shape_m + bidx_in_cluster[0]
        # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid)
        split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0)
        return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid)

    @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Return the current work tile.

        In CLC mode the tile/split indices are refreshed from ``self.clc`` first;
        the coordinates are then produced by ``_varlen_coord_map``.
        """
        if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
            clc_work = self.clc.get_current_work()
            # Default to grid_dim (one past last valid flat index) so _varlen_coord_map
            # returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when
            # invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural
            # mismatch on self inside the runtime if.
            new_tile_idx = cute.arch.grid_dim()[0]
            new_split_idx = Int32(0)
            if clc_work.is_valid_tile:
                new_tile_idx = clc_work.tile_idx[0]
                if const_expr(self.params.is_split_kv):
                    new_split_idx = clc_work.tile_idx[1]
            self._tile_idx = new_tile_idx
            self._split_idx = new_split_idx
        return self._varlen_coord_map()

    @cute.jit
    def initial_work_tile_info(self, *, loc=None, ip=None):
        """Return the first work tile.

        Same flow as ``get_current_work``, except that in CLC mode the work item
        comes from ``self.clc.initial_work_tile_info()``.
        """
        if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
            clc_work = self.clc.initial_work_tile_info()
            # See get_current_work for why grid_dim and local-then-assign.
            new_tile_idx = cute.arch.grid_dim()[0]
            new_split_idx = Int32(0)
            if clc_work.is_valid_tile:
                new_tile_idx = clc_work.tile_idx[0]
                if const_expr(self.params.is_split_kv):
                    new_split_idx = clc_work.tile_idx[1]
            self._tile_idx = new_tile_idx
            self._split_idx = new_split_idx
        return self._varlen_coord_map()

    def prefetch_next_work(self, *, loc=None, ip=None):
        """Forward to the CLC state's prefetch; does nothing in STATIC mode."""
        if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
            self.clc.prefetch_next_work(loc=loc, ip=ip)

    def advance_to_next_work(self, *, loc=None, ip=None):
        """Advance to the next tile and return it.

        CLC mode: consumer-wait, read the current work, consumer-release.
        STATIC mode: clear ``_is_first_block`` so the returned tile is invalid.
        """
        if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
            self.clc.consumer_wait(loc=loc, ip=ip)
            work = self.get_current_work()
            self.clc.consumer_release(loc=loc, ip=ip)
            return work
        self._is_first_block = False
        return self.get_current_work()

    def producer_tail(self, *, loc=None, ip=None):
        """Forward to the CLC state's producer tail; does nothing in STATIC mode."""
        if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
            self.clc.producer_tail(loc=loc, ip=ip)

    def __extract_mlir_values__(self):
        """Flatten params, tile index, split index (and CLC state in CLC mode) into MLIR values.

        Per-member value counts are saved in ``self._values_pos`` for the inverse operation.
        """
        values, self._values_pos = [], []
        objs = [self.params, self._tile_idx, self._split_idx]
        if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
            objs += [self.clc]
        for obj in objs:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        """Rebuild a scheduler from the flat list produced by ``__extract_mlir_values__``.

        The new instance goes through ``__init__``, so ``_is_first_block`` is True on it.
        """
        obj_list = []
        objs = [self.params, self._tile_idx, self._split_idx]
        if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
            objs += [self.clc]
        for obj, n_items in zip(objs, self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return self.__class__(*obj_list, loc=self._loc)
build/torch211-cxx11-cu128-x86_64-linux/src/common/tma_utils.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Raw TMA ops and descriptor builders.
5
+
6
+ `tma_utils.py` is the canonical owner for raw TMA inline-asm helpers and TMA
7
+ descriptor construction. Non-TMA store/layout helpers are re-exported from
8
+ `copy_utils.py` for backward compatibility.
9
+ """
10
+
11
+ import ctypes
12
+
13
+ from cutlass import Int32, Int64
14
+ from cutlass.cutlass_dsl import T, dsl_user_op
15
+ from cutlass._mlir.dialects import llvm
16
+ import cutlass._mlir.dialects.cute as cute_ir
17
+ import cutlass._mlir.dialects.cute_nvgpu as cute_nvgpu_ir
18
+ from cutlass._mlir.dialects import _cute_nvgpu_ops_gen as cute_nvgpu_gen
19
+
20
+
21
+ # Raw TMA Ops
22
+
23
# 64-bit operands passed to the `.L2::cache_hint` TMA variants in this module.
# NOTE(review): the encodings presumably mirror CUTLASS's TMA cache-hint
# constants (evict-first / evict-last) — confirm against the CUTLASS headers.
TMA_CACHE_EVICT_FIRST = 0x12F0000000000000
TMA_CACHE_EVICT_LAST = 0x14F0000000000000
25
+
26
+
27
@dsl_user_op
def tma_tile_load(
    smem_ptr,
    smem_byte_offset,
    tma_desc_ptr,
    col_idx,
    row_idx,
    mbar_ptr,
    *,
    loc=None,
    ip=None,
):
    """Emit cp.async.bulk.tensor.2d.shared::cta.global.tile with mbar completion.

    Inline-asm operands: $1 = smem_ptr, $2 = smem_byte_offset, $3 = tma_desc_ptr,
    $4 = col_idx, $5 = row_idx, $6 = mbar_ptr. The destination is the 32-bit
    conversion of ``smem_ptr`` plus ``smem_byte_offset``. $0 is a dummy i32
    result that the asm sets to 0; nothing is returned to the caller.
    """
    operands = [
        smem_ptr.toint().ir_value(loc=loc, ip=ip),
        Int32(smem_byte_offset).ir_value(loc=loc, ip=ip),
        tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
        Int32(col_idx).ir_value(loc=loc, ip=ip),
        Int32(row_idx).ir_value(loc=loc, ip=ip),
        mbar_ptr.toint().ir_value(loc=loc, ip=ip),
    ]
    ptx = (
        "{\n"
        ".reg .u32 sa, ma;\n"
        "cvt.u32.u64 sa, $1;\n"
        "add.u32 sa, sa, $2;\n"
        "cvt.u32.u64 ma, $6;\n"
        "cp.async.bulk.tensor.2d.shared::cta.global.tile"
        ".mbarrier::complete_tx::bytes [sa], [$3, {$4, $5}], [ma];\n"
        "mov.u32 $0, 0;\n"
        "}\n"
    )
    constraints = "=r,l,r,l,r,r,l"
    llvm.inline_asm(
        T.i32(),
        operands,
        ptx,
        constraints,
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
66
+
67
+
68
@dsl_user_op
def tma_gather4(
    smem_ptr,
    smem_byte_offset,
    tma_desc_ptr,
    col_idx,
    row0,
    row1,
    row2,
    row3,
    mbar_ptr,
    *,
    loc=None,
    ip=None,
):
    """Emit cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4 with mbar.

    Inline-asm operands: $1 = smem_ptr, $2 = smem_byte_offset, $3 = tma_desc_ptr,
    $4 = col_idx, $5..$8 = row0..row3, $9 = mbar_ptr. $0 is a dummy i32 result
    that the asm sets to 0; nothing is returned to the caller.
    """
    row_operands = [Int32(row).ir_value(loc=loc, ip=ip) for row in (row0, row1, row2, row3)]
    operands = [
        smem_ptr.toint().ir_value(loc=loc, ip=ip),
        Int32(smem_byte_offset).ir_value(loc=loc, ip=ip),
        tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
        Int32(col_idx).ir_value(loc=loc, ip=ip),
        *row_operands,
        mbar_ptr.toint().ir_value(loc=loc, ip=ip),
    ]
    ptx = (
        "{\n"
        ".reg .u32 sa, ma;\n"
        "cvt.u32.u64 sa, $1;\n"
        "add.u32 sa, sa, $2;\n"
        "cvt.u32.u64 ma, $9;\n"
        "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4"
        ".mbarrier::complete_tx::bytes [sa], [$3, {$4, $5, $6, $7, $8}], [ma];\n"
        "mov.u32 $0, 0;\n"
        "}\n"
    )
    constraints = "=r,l,r,l,r,r,r,r,r,l"
    llvm.inline_asm(
        T.i32(),
        operands,
        ptx,
        constraints,
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
113
+
114
+
115
@dsl_user_op
def prefetch_tma_desc_raw(tma_desc_ptr, *, loc=None, ip=None):
    """Prefetch a raw TMA descriptor pointer into the descriptor cache.

    The integer address is annotated as 128-aligned, converted to a gmem pointer
    to a tiled TMA descriptor, and handed to ``arch_prefetch_tma_desc``.
    """
    raw_addr = tma_desc_ptr.toint().ir_value(loc=loc, ip=ip)
    aligned_int_ty = cute_ir.ConstrainedIntType.get(128, raw_addr.type.width)
    aligned_addr = cute_ir.assume(aligned_int_ty, raw_addr, loc=loc, ip=ip)
    gmem_desc_ptr_ty = cute_ir.PtrType.get(
        cute_nvgpu_ir.TmaDescriptorTiledType.get(),
        cute_ir.AddressSpace.gmem,
        128,
    )
    typed_desc_ptr = cute_ir.inttoptr(gmem_desc_ptr_ty, aligned_addr, loc=loc, ip=ip)
    cute_nvgpu_gen.arch_prefetch_tma_desc(typed_desc_ptr.value, loc=loc, ip=ip)
128
+
129
+
130
@dsl_user_op
def tma_tile_prefetch(
    tma_desc_ptr,
    col_idx,
    row_idx,
    cache_hint=TMA_CACHE_EVICT_FIRST,
    *,
    loc=None,
    ip=None,
):
    """Emit cp.async.bulk.prefetch.tensor.2d.L2.global.tile with cache hint.

    Inline-asm operands: $0 = tma_desc_ptr, $1 = col_idx, $2 = row_idx,
    $3 = cache_hint (64-bit). The asm declares no result.
    """
    operands = [
        tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
        Int32(col_idx).ir_value(loc=loc, ip=ip),
        Int32(row_idx).ir_value(loc=loc, ip=ip),
        Int64(cache_hint).ir_value(loc=loc, ip=ip),
    ]
    ptx = (
        "cp.async.bulk.prefetch.tensor.2d.L2.global.tile.L2::cache_hint "
        "[$0, {$1, $2}], $3;\n"
    )
    constraints = "l,r,r,l"
    llvm.inline_asm(
        None,
        operands,
        ptx,
        constraints,
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
158
+
159
+
160
@dsl_user_op
def tma_gather4_prefetch(
    tma_desc_ptr,
    col_idx,
    row0,
    row1,
    row2,
    row3,
    cache_hint=TMA_CACHE_EVICT_LAST,
    *,
    loc=None,
    ip=None,
):
    """Emit cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4 with cache hint.

    Inline-asm operands: $0 = tma_desc_ptr, $1 = col_idx, $2..$5 = row0..row3,
    $6 = cache_hint (64-bit). The asm declares no result.
    """
    row_operands = [Int32(row).ir_value(loc=loc, ip=ip) for row in (row0, row1, row2, row3)]
    operands = [
        tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
        Int32(col_idx).ir_value(loc=loc, ip=ip),
        *row_operands,
        Int64(cache_hint).ir_value(loc=loc, ip=ip),
    ]
    ptx = (
        "cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4.L2::cache_hint "
        "[$0, {$1, $2, $3, $4, $5}], $6;\n"
    )
    constraints = "l,r,r,r,r,r,l"
    llvm.inline_asm(
        None,
        operands,
        ptx,
        constraints,
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
194
+
195
+
196
@dsl_user_op
def tma_tile_load_cached(
    smem_ptr,
    smem_byte_offset,
    tma_desc_ptr,
    col_idx,
    row_idx,
    mbar_ptr,
    cache_hint=TMA_CACHE_EVICT_FIRST,
    *,
    loc=None,
    ip=None,
):
    """Emit cp.async.bulk.tensor.2d.shared::cta.global.tile with cache hint and mbar.

    Inline-asm operands: $1 = smem_ptr, $2 = smem_byte_offset, $3 = tma_desc_ptr,
    $4 = col_idx, $5 = row_idx, $6 = mbar_ptr, $7 = cache_hint (64-bit).
    $0 is a dummy i32 result that the asm sets to 0; nothing is returned.
    """
    operands = [
        smem_ptr.toint().ir_value(loc=loc, ip=ip),
        Int32(smem_byte_offset).ir_value(loc=loc, ip=ip),
        tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
        Int32(col_idx).ir_value(loc=loc, ip=ip),
        Int32(row_idx).ir_value(loc=loc, ip=ip),
        mbar_ptr.toint().ir_value(loc=loc, ip=ip),
        Int64(cache_hint).ir_value(loc=loc, ip=ip),
    ]
    ptx = (
        "{\n"
        ".reg .u32 sa, ma;\n"
        "cvt.u32.u64 sa, $1;\n"
        "add.u32 sa, sa, $2;\n"
        "cvt.u32.u64 ma, $6;\n"
        "cp.async.bulk.tensor.2d.shared::cta.global.tile"
        ".mbarrier::complete_tx::bytes.L2::cache_hint "
        "[sa], [$3, {$4, $5}], [ma], $7;\n"
        "mov.u32 $0, 0;\n"
        "}\n"
    )
    constraints = "=r,l,r,l,r,r,l,l"
    llvm.inline_asm(
        T.i32(),
        operands,
        ptx,
        constraints,
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
238
+
239
+
240
@dsl_user_op
def tma_gather4_cached(
    smem_ptr,
    smem_byte_offset,
    tma_desc_ptr,
    col_idx,
    row0,
    row1,
    row2,
    row3,
    mbar_ptr,
    cache_hint=TMA_CACHE_EVICT_LAST,
    *,
    loc=None,
    ip=None,
):
    """Emit cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4 with cache hint.

    Unlike the other load helpers in this module, the asm declares no result, so
    operand numbering starts at $0: $0 = smem_ptr, $1 = smem_byte_offset,
    $2 = tma_desc_ptr, $3 = col_idx, $4..$7 = row0..row3, $8 = mbar_ptr,
    $9 = cache_hint (64-bit). The instruction also carries ``.cta_group::1``.
    """
    row_operands = [Int32(row).ir_value(loc=loc, ip=ip) for row in (row0, row1, row2, row3)]
    operands = [
        smem_ptr.toint().ir_value(loc=loc, ip=ip),
        Int32(smem_byte_offset).ir_value(loc=loc, ip=ip),
        tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
        Int32(col_idx).ir_value(loc=loc, ip=ip),
        *row_operands,
        mbar_ptr.toint().ir_value(loc=loc, ip=ip),
        Int64(cache_hint).ir_value(loc=loc, ip=ip),
    ]
    ptx = (
        "{\n"
        ".reg .u32 sa, ma;\n"
        "cvt.u32.u64 sa, $0;\n"
        "add.u32 sa, sa, $1;\n"
        "cvt.u32.u64 ma, $8;\n"
        "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4"
        ".mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint "
        "[sa], [$2, {$3, $4, $5, $6, $7}], [ma], $9;\n"
        "}\n"
    )
    constraints = "l,r,l,r,r,r,r,r,l,l"
    llvm.inline_asm(
        None,
        operands,
        ptx,
        constraints,
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
287
+
288
+
289
@dsl_user_op
def tma_tile_store(
    tma_desc_ptr,
    col_idx,
    row_idx,
    smem_ptr,
    smem_byte_offset,
    *,
    loc=None,
    ip=None,
):
    """Emit a cp.async.bulk.tensor.2d.global.shared::cta.bulk_group store.

    Inline-asm operands: $1 = tma_desc_ptr, $2 = col_idx, $3 = row_idx,
    $4 = smem_ptr, $5 = smem_byte_offset. The source is the 32-bit conversion of
    ``smem_ptr`` plus ``smem_byte_offset``. $0 is a dummy i32 result that the asm
    sets to 0; nothing is returned to the caller.
    """
    operands = [
        tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
        Int32(col_idx).ir_value(loc=loc, ip=ip),
        Int32(row_idx).ir_value(loc=loc, ip=ip),
        smem_ptr.toint().ir_value(loc=loc, ip=ip),
        Int32(smem_byte_offset).ir_value(loc=loc, ip=ip),
    ]
    ptx = (
        "{\n"
        ".reg .u32 sa;\n"
        "cvt.u32.u64 sa, $4;\n"
        "add.u32 sa, sa, $5;\n"
        "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group"
        " [$1, {$2, $3}], [sa];\n"
        "mov.u32 $0, 0;\n"
        "}\n"
    )
    constraints = "=r,l,r,r,l,r"
    llvm.inline_asm(
        T.i32(),
        operands,
        ptx,
        constraints,
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
325
+
326
+
327
+ # Descriptor Builders
328
+
329
# Number of bytes copied out of an encoded tensor map, and the row width of the
# per-head descriptor table returned by create_flat_kv_tma_descs.
_TMA_DESC_BYTES = 128
330
+
331
+
332
+ def _encode_tma_desc_2d_bytes(tensor_2d, *, box_x, box_y, context: str) -> bytes:
333
+ import torch
334
+ import cuda.bindings.driver as cuda
335
+
336
+ if tensor_2d.ndim != 2:
337
+ raise ValueError(f"{context} tensor must be rank-2, got {tuple(tensor_2d.shape)}")
338
+ rows, cols = tensor_2d.shape
339
+ if tensor_2d.stride(-1) != 1:
340
+ raise ValueError(f"{context} tensor must be contiguous in the last dimension")
341
+ dtype_map = {
342
+ torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
343
+ torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
344
+ torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
345
+ }
346
+ if tensor_2d.dtype not in dtype_map:
347
+ raise TypeError(f"Unsupported dtype for {context} TMA descriptor: {tensor_2d.dtype}")
348
+
349
+ sizes = [cuda.cuuint64_t(cols), cuda.cuuint64_t(rows)]
350
+ strides = [cuda.cuuint64_t(tensor_2d.stride(0) * tensor_2d.element_size())]
351
+ box = [cuda.cuuint32_t(box_x), cuda.cuuint32_t(box_y)]
352
+ elem_stride = [cuda.cuuint32_t(1), cuda.cuuint32_t(1)]
353
+ err, tm = cuda.cuTensorMapEncodeTiled(
354
+ dtype_map[tensor_2d.dtype],
355
+ 2,
356
+ tensor_2d.data_ptr(),
357
+ sizes,
358
+ strides,
359
+ box,
360
+ elem_stride,
361
+ cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
362
+ cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
363
+ cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
364
+ cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
365
+ )
366
+ assert err == cuda.CUresult.CUDA_SUCCESS, f"TMA encode failed: {err}"
367
+ buf = (ctypes.c_uint8 * _TMA_DESC_BYTES).from_address(tm.getPtr())
368
+ return bytes(buf)
369
+
370
+
371
+ def _desc_bytes_to_device_tensor(desc_bytes: bytes | bytearray, *, device):
372
+ import torch
373
+
374
+ desc_bytes = bytes(desc_bytes)
375
+ device = torch.device(device)
376
+ if device.type != "cuda":
377
+ raise ValueError(f"TMA descriptors require a CUDA device, got {device}")
378
+
379
+ host_desc = torch.empty((len(desc_bytes),), dtype=torch.uint8, pin_memory=True)
380
+ host_desc.copy_(torch.frombuffer(bytearray(desc_bytes), dtype=torch.uint8))
381
+ device_desc = torch.empty((len(desc_bytes),), dtype=torch.uint8, device=device)
382
+ stream = torch.cuda.current_stream(device)
383
+ with torch.cuda.stream(stream):
384
+ device_desc.copy_(host_desc, non_blocking=True)
385
+ device_desc.record_stream(stream)
386
+ # Keep the staging buffer alive for the async copy without caching descriptors.
387
+ device_desc._tma_host_desc = host_desc
388
+ return device_desc
389
+
390
+
391
def create_flat_gather4_tma_desc(tensor_2d, box_x=64):
    """Create a gather4 CUtensorMap descriptor for a flat 2D row-major tensor.

    The descriptor uses a box of ``box_x`` columns by one row and is returned as
    a uint8 tensor on ``tensor_2d.device``.

    Raises:
        ValueError: if ``tensor_2d`` is not rank-2.
    """
    is_rank2 = tensor_2d.ndim == 2
    if not is_rank2:
        raise ValueError(
            f"tensor_2d must be rank-2 [rows, dim], got {tuple(tensor_2d.shape)}"
        )
    encoded = _encode_tma_desc_2d_bytes(tensor_2d, box_x=box_x, box_y=1, context="gather4")
    return _desc_bytes_to_device_tensor(encoded, device=tensor_2d.device)
404
+
405
+
406
def create_q_gather4_tma_desc(q_flat, box_x=64):
    """Alias of ``create_flat_gather4_tma_desc`` for the flat Q tensor."""
    return create_flat_gather4_tma_desc(q_flat, box_x=box_x)
408
+
409
+
410
def create_strided_2d_tma_desc(tensor_2d, *, box_x, box_y):
    """Create a CUtensorMap descriptor for a rank-2 tensor with arbitrary row stride.

    The encoded descriptor is returned as a uint8 tensor on ``tensor_2d.device``.
    """
    encoded = _encode_tma_desc_2d_bytes(
        tensor_2d, box_x=box_x, box_y=box_y, context="strided 2D"
    )
    target_device = tensor_2d.device
    return _desc_bytes_to_device_tensor(encoded, device=target_device)
419
+
420
+
421
def create_flat_kv_tma_descs(kv_flat, *, box_x=64, box_y=128):
    """Create per-KV-head token-major TMA descriptors for flat [total_k, H, D] storage.

    Args:
        kv_flat: rank-3 tensor laid out as [total_k, H, D]; the last dimension
            must have stride 1 (enforced by the descriptor encoder).
        box_x: box extent along D.
        box_y: box extent along total_k.

    Returns:
        A uint8 tensor of shape ``(H, _TMA_DESC_BYTES)`` on ``kv_flat.device``
        holding one encoded descriptor per head.

    Raises:
        ValueError: if ``kv_flat`` is not rank-3.
    """
    if kv_flat.ndim != 3:
        raise ValueError(
            f"kv_flat must be rank-3 [total_k, H, D], got {tuple(kv_flat.shape)}"
        )
    head_kv = kv_flat.shape[1]
    desc_table = bytearray()
    for h in range(head_kv):
        # Plain indexing keeps kv_flat's own strides and storage offset. The
        # previous torch.as_strided(..., storage_offset=h * dim) call used an
        # offset that is absolute within the underlying storage and assumed a
        # contiguous layout, so a sliced or padded kv_flat produced descriptors
        # that pointed at the wrong memory.
        head_view = kv_flat[:, h, :]
        desc_table.extend(
            _encode_tma_desc_2d_bytes(
                head_view,
                box_x=box_x,
                box_y=box_y,
                context="flat KV",
            )
        )
    return _desc_bytes_to_device_tensor(desc_table, device=kv_flat.device).reshape(
        head_kv, _TMA_DESC_BYTES
    )
450
+
451
+
452
+ # Compatibility Re-exports
453
+
454
+ from .copy_utils import (
455
+ atomic_add_broadcast_i32,
456
+ atomic_add_i32,
457
+ convert_layout_acc_mn,
458
+ convert_layout_from_tmem16x256b_to_acc_sm90,
459
+ make_16x256b_tensor_mn_view,
460
+ real_col_to_stg128_fake_col,
461
+ real_col_to_stg128_fp8_fake_col,
462
+ real_col_to_stg128_half_fake_col,
463
+ stg128_fake_col_to_real_col,
464
+ stg128_fp8_fake_col_to_real_col,
465
+ stg128_half_fake_col_to_real_col,
466
+ stg_128,
467
+ stg_128_cs,
468
+ stg_128_bf16,
469
+ stg_128_bf16_cs,
470
+ stg_128_f16,
471
+ stg_128_f16_cs,
472
+ stg_128_fp8_e4m3_cs,
473
+ stg_32_fp8_e4m3,
474
+ stg_64_bf16,
475
+ stg_64_f16,
476
+ )
477
+
478
+
479
# Public API: the TMA helpers defined in this module plus the names re-exported
# from `copy_utils` for backward compatibility.
__all__ = [
    "TMA_CACHE_EVICT_FIRST",
    "TMA_CACHE_EVICT_LAST",
    "atomic_add_broadcast_i32",
    "atomic_add_i32",
    "convert_layout_acc_mn",
    "convert_layout_from_tmem16x256b_to_acc_sm90",
    "create_flat_gather4_tma_desc",
    "create_flat_kv_tma_descs",
    "create_q_gather4_tma_desc",
    "create_strided_2d_tma_desc",
    "make_16x256b_tensor_mn_view",
    "prefetch_tma_desc_raw",
    "real_col_to_stg128_fake_col",
    "real_col_to_stg128_fp8_fake_col",
    "real_col_to_stg128_half_fake_col",
    "stg128_fake_col_to_real_col",
    "stg128_fp8_fake_col_to_real_col",
    "stg128_half_fake_col_to_real_col",
    "stg_128",
    "stg_128_cs",
    "stg_128_bf16",
    "stg_128_bf16_cs",
    "stg_128_f16",
    "stg_128_f16_cs",
    "stg_128_fp8_e4m3_cs",
    "stg_32_fp8_e4m3",
    "stg_64_bf16",
    "stg_64_f16",
    "tma_gather4",
    "tma_gather4_cached",
    "tma_gather4_prefetch",
    "tma_tile_load",
    "tma_tile_load_cached",
    "tma_tile_prefetch",
    "tma_tile_store",
]
build/torch211-cxx11-cu128-x86_64-linux/src/common/utils.py ADDED
@@ -0,0 +1,1088 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ import math
5
+ import hashlib
6
+ import inspect
7
+ from typing import Type, Callable, Optional, Tuple, overload
8
+
9
+ import cutlass
10
+ import cutlass.cute as cute
11
+
12
+ from cutlass import Float32, const_expr
13
+ from cutlass.cutlass_dsl import T, dsl_user_op
14
+ from cutlass._mlir.dialects import nvvm, llvm
15
+ from cutlass.cute.runtime import from_dlpack
16
+
17
+
18
+ from ...quack import activation
19
# Default attribute names whose values ``hash_callable`` mixes into the base
# hash of a callable (used as the default of its ``mixer_attrs`` parameter).
_MIXER_ATTRS = ("__vec_size__",)
20
+
21
# Polynomial coefficients approximating 2**x on [0, 1], keyed by polynomial
# degree.  Each value is a tuple of ``degree + 1`` coefficients, constant term
# first.
# Obtained from sollya:
# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative);
POLY_EX2 = {
    # NOTE: must be a 1-tuple; a bare ``(1.0)`` is just the float 1.0 and
    # cannot be indexed or measured like the other entries.
    0: (1.0,),
    1: (
        1.0,
        0.922497093677520751953125,
    ),
    2: (
        1.0,
        0.6657850742340087890625,
        0.330107033252716064453125,
    ),
    3: (
        1.0,
        0.695146143436431884765625,
        0.227564394474029541015625,
        0.077119089663028717041015625,
    ),
    4: (
        1.0,
        0.693042695522308349609375,
        0.2412912547588348388671875,
        5.2225358784198760986328125e-2,
        1.3434938155114650726318359375e-2,
    ),
    5: (
        1.0,
        0.693151414394378662109375,
        0.24016360938549041748046875,
        5.5802188813686370849609375e-2,
        9.01452265679836273193359375e-3,
        1.86810153536498546600341796875e-3,
    ),
}
56
+
57
+
58
def _hash_code_object(hasher, code) -> None:
    """Feed the bytecode, referenced names and constants of *code* into *hasher*.

    Nested code objects (inner functions, lambdas, comprehensions) are hashed
    recursively rather than via ``repr``, because their ``repr`` embeds a
    memory address and would make the digest unstable.
    """
    hasher.update(code.co_code)
    hasher.update(repr(code.co_names).encode())
    for const in code.co_consts:
        if isinstance(const, type(code)):
            _hash_code_object(hasher, const)
        elif isinstance(const, frozenset):
            # Set iteration order is not stable across interpreter runs.
            hasher.update(repr(sorted(map(repr, const))).encode())
        else:
            hasher.update(repr(const).encode())


def _compute_base_hash(func: Callable) -> str:
    """Compute a SHA-256 hex digest identifying *func*.

    The digest is built from, in order of preference:

    1. the source text returned by ``inspect.getsource``;
    2. if no source is available, the code object (bytecode plus the names
       and constants it references, so that functions differing only in a
       literal do not collide);
    3. otherwise ``repr(func)``.

    The ``repr`` of every closure cell value is mixed in afterwards, so two
    closures sharing the same code but capturing different values hash
    differently.

    Args:
        func: The callable to fingerprint.

    Returns:
        The hexadecimal SHA-256 digest.
    """
    hasher = hashlib.sha256()
    try:
        hasher.update(inspect.getsource(func).encode())
    except (OSError, TypeError):
        code = getattr(func, "__code__", None)
        if code is not None:
            _hash_code_object(hasher, code)
        else:
            hasher.update(repr(func).encode())

    closure = getattr(func, "__closure__", None)
    if closure is not None:
        for cell in closure:
            hasher.update(repr(cell.cell_contents).encode())

    return hasher.hexdigest()
75
+
76
+
77
def hash_callable(
    func: Callable, mixer_attrs: Tuple[str, ...] = _MIXER_ATTRS, set_cute_hash: bool = True
) -> str:
    """Hash a callable from its source/bytecode, closure values and metadata.

    Fast path: if the callable (or its ``__wrapped__`` base) already carries a
    ``__cute_hash__`` attribute, that value is used as the base hash without
    recomputation.  The values of ``mixer_attrs`` found on *func* are then
    mixed in to produce the final key.

    Args:
        func: The callable to hash.
        mixer_attrs: Names of attributes on *func* whose values are mixed into
            the base hash.  If none of them is set, the base hash is returned
            unchanged.
        set_cute_hash: Whether to cache a freshly computed base hash on the
            unwrapped callable as ``__cute_hash__``.

    Returns:
        A hexadecimal hash string.
    """
    # Resolve the base hash.
    if hasattr(func, "__cute_hash__"):
        base_hash = func.__cute_hash__
    else:
        # Unwrap decorated functions (e.g., cute.jit wrappers).
        base_func = getattr(func, "__wrapped__", func)

        if hasattr(base_func, "__cute_hash__"):
            base_hash = base_func.__cute_hash__
        else:
            base_hash = _compute_base_hash(base_func)

            if set_cute_hash:
                try:
                    base_func.__cute_hash__ = base_hash
                except (AttributeError, TypeError):
                    # Caching is only an optimisation: builtins, bound
                    # methods and slotted objects reject new attributes, and
                    # the hash is simply recomputed next time.
                    pass

    # Mix in mutable metadata attributes.
    mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs)

    if all(value is None for value in mixer_values):
        return base_hash

    hasher = hashlib.sha256(base_hash.encode())

    # Pair each value with the attribute name it was actually read from
    # (the caller-supplied ``mixer_attrs``, not the module default).
    for attr, value in zip(mixer_attrs, mixer_values):
        hasher.update(f"{attr}={value!r}".encode())

    return hasher.hexdigest()
113
+
114
+
115
# log2(e): multiplying by it converts a natural-exponent scale to a base-2 one.
LOG2_E = math.log2(math.e)


def compute_softmax_scale_log2(softmax_scale):
    """Return ``softmax_scale`` expressed for base-2 exponentials.

    Returns:
        A 2-tuple ``(softmax_scale * LOG2_E, None)``; the second slot is
        always ``None``.
    """
    scale_log2 = softmax_scale * LOG2_E
    return scale_log2, None
124
+
125
+
126
def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
    """Wrap a DLPack-compatible tensor as a ``cute.Tensor`` with dynamic layout.

    Args:
        x: Source tensor; must provide ``dim_order()`` and be accepted by
            ``from_dlpack``.
        leading_dim: Mode passed as the leading dimension / dynamic mode.
        alignment: Assumed pointer alignment forwarded as ``assumed_align``.
        divisibility: Divisibility hint for the dynamic shape of ``leading_dim``.
    """
    tensor = from_dlpack(x, assumed_align=alignment)
    tensor = tensor.mark_layout_dynamic(leading_dim=leading_dim)
    return tensor.mark_compact_shape_dynamic(
        mode=leading_dim,
        stride_order=x.dim_order(),
        divisibility=divisibility,
    )
134
+
135
+
136
def make_tiled_copy_A(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Build the tiled copy for operand A, or for operand B when ``swapAB`` is set."""
    builder = cute.make_tiled_copy_B if const_expr(swapAB) else cute.make_tiled_copy_A
    return builder(copy_atom, tiled_mma)
143
+
144
+
145
def make_tiled_copy_B(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Build the tiled copy for operand B, or for operand A when ``swapAB`` is set."""
    builder = cute.make_tiled_copy_A if const_expr(swapAB) else cute.make_tiled_copy_B
    return builder(copy_atom, tiled_mma)
152
+
153
+
154
def mma_make_fragment_A(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition ``smem`` as operand A and make its MMA fragment.

    With ``swapAB`` set, the operand-B variant is used instead.
    """
    if const_expr(swapAB):
        return mma_make_fragment_B(smem, thr_mma)
    partitioned = thr_mma.partition_A(smem)
    return thr_mma.make_fragment_A(partitioned)
161
+
162
+
163
def mma_make_fragment_B(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition ``smem`` as operand B and make its MMA fragment.

    With ``swapAB`` set, the operand-A variant is used instead.
    """
    if const_expr(swapAB):
        return mma_make_fragment_A(smem, thr_mma)
    partitioned = thr_mma.partition_B(smem)
    return thr_mma.make_fragment_B(partitioned)
170
+
171
+
172
def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Pick the copy atom for this architecture and element width.

    ``StMatrix8x8x16bOp`` (4 matrices) is selected when ``arch >= 90`` and the
    element type is 16 bits wide; every other combination gets a universal
    copy moving two elements per copy.
    """
    if const_expr(arch >= 90 and element_type.width == 16):
        stmatrix_op = cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4)
        return cute.make_copy_atom(stmatrix_op, element_type)
    else:
        return cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            element_type,
            num_bits_per_copy=2 * element_type.width,
        )
186
+
187
+
188
@cute.jit
def warp_reduce(
    val: cute.TensorSSA | cute.Numeric,
    op: Callable,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.TensorSSA | cute.Numeric:
    """Reduce ``val`` across lanes with butterfly shuffles, combining with ``op``.

    A ``TensorSSA`` input is reduced element by element; a scalar is reduced
    in ``log2(width)`` shuffle steps with offsets 1, 2, 4, ...
    """
    if const_expr(isinstance(val, cute.TensorSSA)):
        # Spill to a register tensor so each element can be reduced on its own.
        elems = cute.make_rmem_tensor(val.shape, val.dtype)
        elems.store(val)
        for idx in cutlass.range_constexpr(cute.size(val.shape)):
            elems[idx] = warp_reduce(elems[idx], op, width)
        return elems.load()
    else:
        for step in cutlass.range_constexpr(int(math.log2(width))):
            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << step))
        return val
204
+
205
+
206
@dsl_user_op
def fmax(
    a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
) -> Float32:
    """Return the maximum of two (or, when ``c`` is given, three) Float32 values.

    Emits ``nvvm.fmax``.  The op's Python signature depends on the CUDA
    version reported by ``cutlass``: for 12.9 the result type is passed
    explicitly as the first positional argument, otherwise it is inferred.
    """
    from cutlass import CUDA_VERSION

    lhs = Float32(a).ir_value(loc=loc, ip=ip)
    rhs = Float32(b).ir_value(loc=loc, ip=ip)
    third = Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None

    # * NVVM call based on nvvm version
    explicit_result_type = CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9
    if explicit_result_type:
        # Old API: requires explicit result type as first positional argument
        positional = (T.f32(), lhs, rhs)
    else:
        # New API: infers result type automatically
        positional = (lhs, rhs)
    return Float32(nvvm.fmax(*positional, c=third, loc=loc, ip=ip))
236
+
237
+
238
@cute.jit
def fmax_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Return the maximum element of ``x``, optionally folded with ``init_val``.

    Two unrolled strategies are used:

    * ``arch < 100`` or a size not divisible by 8: four running maxima with
      two-operand ``fmax``, ``init_val`` applied last.
    * otherwise: four running maxima with three-operand ``fmax`` over groups
      of eight elements, ``init_val`` folded into the first group.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        vals = cute.make_rmem_tensor(x.shape, Float32)
        vals.store(x)
        # NOTE(review): this path reads vals[0..3] and then strides by 4 using
        # i + 3, so it presumably expects the element count to be a non-zero
        # multiple of 4 -- confirm against callers.
        running = [vals[0], vals[1], vals[2], vals[3]]
        for base in cutlass.range_constexpr(4, cute.size(x.shape), 4):
            running[0] = fmax(running[0], vals[base + 0])
            running[1] = fmax(running[1], vals[base + 1])
            running[2] = fmax(running[2], vals[base + 2])
            running[3] = fmax(running[3], vals[base + 3])
        # Tree-combine the four partial maxima.
        running[0] = fmax(running[0], running[1])
        running[2] = fmax(running[2], running[3])
        running[0] = fmax(running[0], running[2])
        return running[0] if const_expr(init_val is None) else fmax(running[0], init_val)
    else:
        vals = cute.make_rmem_tensor(x.shape, Float32)
        vals.store(x)
        first = (
            fmax(init_val, vals[0], vals[1])
            if const_expr(init_val is not None)
            else fmax(vals[0], vals[1])
        )
        running = [
            first,
            fmax(vals[2], vals[3]),
            fmax(vals[4], vals[5]),
            fmax(vals[6], vals[7]),
        ]
        for base in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            running[0] = fmax(running[0], vals[base], vals[base + 1])
            running[1] = fmax(running[1], vals[base + 2], vals[base + 3])
            running[2] = fmax(running[2], vals[base + 4], vals[base + 5])
            running[3] = fmax(running[3], vals[base + 6], vals[base + 7])
        running[0] = fmax(running[0], running[1])
        return fmax(running[0], running[2], running[3])
276
+
277
+
278
@cute.jit
def fadd_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Return the sum of the elements of ``x``, optionally seeded with ``init_val``.

    For ``arch < 100`` or sizes not divisible by 8 this defers to
    ``TensorSSA.reduce``.  Otherwise four packed ``f32x2`` accumulators are
    kept over groups of eight elements and combined at the end.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        if const_expr(init_val is None):
            init_val = Float32.zero
        return x.reduce(cute.ReductionOp.ADD, init_val, 0)
    else:
        vals = cute.make_rmem_tensor(x.shape, Float32)
        vals.store(x)
        # The seed goes into the first lane of the first packed pair only.
        first_pair = (
            cute.arch.add_packed_f32x2((init_val, 0.0), (vals[0], vals[1]))
            if const_expr(init_val is not None)
            else (vals[0], vals[1])
        )
        partial = [first_pair, (vals[2], vals[3]), (vals[4], vals[5]), (vals[6], vals[7])]
        for base in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            partial[0] = cute.arch.add_packed_f32x2(partial[0], (vals[base + 0], vals[base + 1]))
            partial[1] = cute.arch.add_packed_f32x2(partial[1], (vals[base + 2], vals[base + 3]))
            partial[2] = cute.arch.add_packed_f32x2(partial[2], (vals[base + 4], vals[base + 5]))
            partial[3] = cute.arch.add_packed_f32x2(partial[3], (vals[base + 6], vals[base + 7]))
        # Tree-combine the four packed accumulators, then add the two lanes.
        partial[0] = cute.arch.add_packed_f32x2(partial[0], partial[1])
        partial[2] = cute.arch.add_packed_f32x2(partial[2], partial[3])
        partial[0] = cute.arch.add_packed_f32x2(partial[0], partial[2])
        return partial[0][0] + partial[0][1]
304
+
305
+
306
@cute.jit
def fadd_exp2_scaled_reduce(
    x: cute.Tensor, scale: Float32, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Return ``sum(exp2(x[i] * scale))`` over all elements of ``x``.

    ``x`` must hold an even number of elements.  For ``arch < 100`` the work
    is delegated to ``fadd_reduce``; otherwise packed ``f32x2`` multiplies and
    adds are used, eight elements per step when the size allows it and two
    per step otherwise.
    """
    assert cute.size(x.shape) % 2 == 0, "x must have an even number of elements"
    if const_expr(arch < 100):
        return fadd_reduce(cute.math.exp2(x.load() * scale, fastmath=True), arch=arch)
    elif const_expr(cute.size(x.shape) % 8 == 0):
        partial = [
            (Float32(0.0), Float32(0.0)),
            (Float32(0.0), Float32(0.0)),
            (Float32(0.0), Float32(0.0)),
            (Float32(0.0), Float32(0.0)),
        ]
        for base in cutlass.range_constexpr(0, cute.size(x.shape), 8):
            # Scale four pairs, exponentiate all eight, then accumulate.
            e0, e1 = cute.arch.mul_packed_f32x2((x[base + 0], x[base + 1]), (scale, scale))
            e2, e3 = cute.arch.mul_packed_f32x2((x[base + 2], x[base + 3]), (scale, scale))
            e4, e5 = cute.arch.mul_packed_f32x2((x[base + 4], x[base + 5]), (scale, scale))
            e6, e7 = cute.arch.mul_packed_f32x2((x[base + 6], x[base + 7]), (scale, scale))
            e0 = cute.math.exp2(e0, fastmath=True)
            e1 = cute.math.exp2(e1, fastmath=True)
            e2 = cute.math.exp2(e2, fastmath=True)
            e3 = cute.math.exp2(e3, fastmath=True)
            e4 = cute.math.exp2(e4, fastmath=True)
            e5 = cute.math.exp2(e5, fastmath=True)
            e6 = cute.math.exp2(e6, fastmath=True)
            e7 = cute.math.exp2(e7, fastmath=True)
            partial[0] = cute.arch.add_packed_f32x2(partial[0], (e0, e1))
            partial[1] = cute.arch.add_packed_f32x2(partial[1], (e2, e3))
            partial[2] = cute.arch.add_packed_f32x2(partial[2], (e4, e5))
            partial[3] = cute.arch.add_packed_f32x2(partial[3], (e6, e7))
        partial[0] = cute.arch.add_packed_f32x2(partial[0], partial[1])
        partial[2] = cute.arch.add_packed_f32x2(partial[2], partial[3])
        partial[0] = cute.arch.add_packed_f32x2(partial[0], partial[2])
        return partial[0][0] + partial[0][1]
    else:
        total = Float32(0.0)
        for base in cutlass.range_constexpr(0, cute.size(x.shape), 2):
            e0, e1 = cute.arch.mul_packed_f32x2((x[base], x[base + 1]), (scale, scale))
            e0 = cute.math.exp2(e0, fastmath=True)
            e1 = cute.math.exp2(e1, fastmath=True)
            total += e0 + e1
        return total
359
+
360
+
361
@dsl_user_op
def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:
    """Atomically add the Float32 value ``a`` to the location ``gmem_ptr``.

    Emits an ``nvvm.atomicrmw`` with the ``FADD`` kind; the op's result is
    discarded.

    Args:
        a: Value to add; converted to ``Float32``.
        gmem_ptr: Destination pointer (its ``llvm_ptr`` is used).
        loc: Optional source location, forwarded to the emitted IR like in
            the other ``dsl_user_op`` helpers of this module.
        ip: Optional insertion point, forwarded likewise.
    """
    nvvm.atomicrmw(
        res=T.f32(),
        op=nvvm.AtomicOpKind.FADD,
        ptr=gmem_ptr.llvm_ptr,
        a=Float32(a).ir_value(loc=loc, ip=ip),
        loc=loc,
        ip=ip,
    )
366
+
367
+
368
@dsl_user_op
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
    """Return a pointer to the element of ``x`` at ``coord``.

    The coordinate is mapped through ``x.layout`` to a linear offset which is
    added to the tensor's iterator.
    """
    linear_offset = cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
    return x.iterator + linear_offset
371
+
372
+
373
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
    """Build a Boolean predicate tensor marking which k-coordinates are ``< limit``.

    Only the "k" dimension is predicated here; the mn dimension is expected
    to be guarded with an ``if`` by the caller.  The middle mode has stride 0,
    so a single predicate is shared across it.
    """
    pred_layout = cute.make_layout(
        (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
        stride=(cute.size(tAcA, mode=[2]), 0, 1),
    )
    pred = cute.make_rmem_tensor(pred_layout, cutlass.Boolean)
    for v_idx in cutlass.range_constexpr(pred.shape[0]):
        for k_idx in cutlass.range_constexpr(pred.shape[2]):
            # Component [1] of the coordinate is compared against the limit.
            pred[v_idx, 0, k_idx] = cute.elem_less(tAcA[(0, v_idx), 0, k_idx][1], limit)
    return pred
387
+
388
+
389
def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:
    """Return the index of the calling thread's 128-thread group.

    Computed as ``thread_idx.x // 128``.  When ``sync`` is true the value is
    additionally passed through ``cute.arch.make_warp_uniform``.
    """
    tid_x = cute.arch.thread_idx()[0]
    group_idx = tid_x // 128
    if const_expr(sync):
        group_idx = cute.arch.make_warp_uniform(group_idx)
    return group_idx
394
+
395
+
396
@cute.jit
def shuffle_sync(
    value: cute.Numeric,
    offset: cute.typing.Int,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.Numeric:
    """Warp shuffle for values whose width is a multiple of 32 bits.

    The value is placed in a one-element register tensor, viewed as 32-bit
    words, and each word is shuffled separately with
    ``cute.arch.shuffle_sync``.
    """
    assert value.width % 32 == 0, "value type must be a multiple of 32 bits"
    # Segment mask by width:
    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    mask_and_clamp = ((cute.arch.WARP_SIZE - width) << 8) | (cute.arch.WARP_SIZE - 1)
    # important: need stride 1 and not 0 for recast_tensor to work
    slot = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value))
    slot[0] = value
    words = cute.recast_tensor(slot, cutlass.Int32)
    for w in cutlass.range_constexpr(cute.size(words)):
        words[w] = cute.arch.shuffle_sync(words[w], offset, mask_and_clamp=mask_and_clamp)
    return slot[0]
414
+
415
+
416
@dsl_user_op
def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
    """
    Shift ``val`` left by ``shift`` bits with the PTX ``shl.b32`` instruction.

    The instruction itself is sign-agnostic; the helper is called ``shl_u32``
    rather than ``shl_b32`` because Python annotations have to pick a signed
    or unsigned type.

    Why inline PTX: per PTX section 9.7.8.8, "Shift amounts greater than the
    register width N are clamped to N", so ``shl.b32 d, a, 32`` is well
    defined and produces 0.  In C/C++ and LLVM IR a shift by >= the type
    width is undefined behavior.  Since CuTeDSL lowers through MLIR to LLVM
    IR, a plain ``Uint32(x) << Uint32(n)`` inherits that UB and the optimizer
    may treat the result as poison and drop dependent code.  Emitting the
    instruction verbatim sidesteps the LLVM shift entirely and is safe for
    every shift amount.
    """
    operands = [
        cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
        cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
    ]
    shifted = llvm.inline_asm(
        T.i32(),
        operands,
        "shl.b32 $0, $1, $2;",
        "=r,r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Uint32(shifted)
448
+
449
+
450
@dsl_user_op
def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
    """
    Shift ``val`` right by ``shift`` bits with PTX ``shr.u32`` (zero-filling).

    Inline PTX is used instead of the plain CuTeDSL shift operator for the
    reason given in the ``shl_u32`` docstring: shifting by the full type
    width is undefined behavior in LLVM IR.
    """
    operands = [
        cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
        cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
    ]
    shifted = llvm.inline_asm(
        T.i32(),
        operands,
        "shr.u32 $0, $1, $2;",
        "=r,r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Uint32(shifted)
472
+
473
+
474
@cute.jit
def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
    """Compute an inclusive prefix sum of ``val`` across the lanes of a warp.

    Runs ``log2(WARP_SIZE)`` shuffle-up steps with doubling distance; a lane
    only accumulates when a source lane ``distance`` below it exists.

    Args:
        val: This lane's contribution.
        lane: This lane's index; looked up with ``cute.arch.lane_idx()`` when
            omitted.
    """
    if const_expr(lane is None):
        lane = cute.arch.lane_idx()
    for step in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
        distance = 1 << step
        # Very important that we set mask_and_clamp to 0
        from_below = cute.arch.shuffle_sync_up(val, offset=distance, mask_and_clamp=0)
        if lane >= distance:
            val += from_below
    return val
485
+
486
+
487
@dsl_user_op
def cvt_f16x2_f32(
    a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
) -> cutlass.Int32:
    """Convert two Float32 values into one packed 16-bit-pair register.

    Args:
        a: First value.
        b: Second value.
        to_dtype: ``cutlass.BFloat16`` or ``cutlass.Float16``; selects the
            ``bf16x2`` or ``f16x2`` form of the PTX ``cvt.rn`` instruction.

    Returns:
        The packed pair as an ``Int32`` bit pattern.
    """
    assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
    packed_kind = "bf16x2" if to_dtype is cutlass.BFloat16 else "f16x2"
    # Note the operand order in the template: b ($2) is passed before a ($1).
    ptx = f"cvt.rn.{packed_kind}.f32 $0, $2, $1;"
    operands = [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)]
    packed = llvm.inline_asm(
        T.i32(),
        operands,
        ptx,
        "=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(packed)
503
+
504
+
505
@dsl_user_op
def cvt_fp8x4_e4m3_f32(
    a: float | Float32,
    b: float | Float32,
    c: float | Float32,
    d: float | Float32,
    *,
    loc=None,
    ip=None,
) -> cutlass.Int32:
    """Convert four Float32 values into one register of four packed E4M3 bytes.

    Two ``cvt.rn.satfinite.e4m3x2.f32`` instructions produce a 16-bit pair
    each from ``(a, b)`` and ``(c, d)``; the pairs are then packed into the
    32-bit result with the ``(a, b)`` pair first.
    """
    ptx = (
        "{\n"
        ".reg .b16 h0, h1;\n"
        "cvt.rn.satfinite.e4m3x2.f32 h0, $2, $1;\n"
        "cvt.rn.satfinite.e4m3x2.f32 h1, $4, $3;\n"
        "mov.b32 $0, {h0, h1};\n"
        "}\n"
    )
    operands = [
        Float32(a).ir_value(loc=loc, ip=ip),
        Float32(b).ir_value(loc=loc, ip=ip),
        Float32(c).ir_value(loc=loc, ip=ip),
        Float32(d).ir_value(loc=loc, ip=ip),
    ]
    packed = llvm.inline_asm(
        T.i32(),
        operands,
        ptx,
        "=r,f,f,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(packed)
536
+
537
+
538
@dsl_user_op
def cvt_fp8x4_e4m3_bf16x4(
    src: cutlass.Int32,
    *,
    loc=None,
    ip=None,
) -> Tuple[cutlass.Int32, cutlass.Int32]:
    """Convert packed e4m3x4 bits into two packed bf16x2 registers.

    Both outputs come from the same bit manipulation: permute the source
    bytes, keep the sign bit (mask ``0x80008000``), shift the remaining seven
    bits (mask ``0x7f007f00``) right by 4, and multiply by the bf16x2 constant
    ``0x7b807b80`` through an fma with a zero addend.  The second output
    additionally shifts the permuted word left by 8 first.

    NOTE(review): given the ``prmt`` selector ``0x1302``, the first register
    appears to hold source bytes 0 and 1 and the second bytes 2 and 3 --
    confirm against the PTX ``prmt`` definition.
    """
    ptx_first = (
        "{\n\t"
        ".reg .b32 q, mant, out, bias, zero;\n\t"
        "prmt.b32 q, $1, $1, 0x1302;\n\t"
        "and.b32 out, q, 0x80008000;\n\t"
        "and.b32 mant, q, 0x7f007f00;\n\t"
        "shr.u32 mant, mant, 4;\n\t"
        "or.b32 out, out, mant;\n\t"
        "mov.b32 bias, 0x7b807b80;\n\t"
        "mov.b32 zero, 0;\n\t"
        "fma.rn.bf16x2 $0, out, bias, zero;\n\t"
        "}\n"
    )
    ptx_second = (
        "{\n\t"
        ".reg .b32 q, qs, mant, out, bias, zero;\n\t"
        "prmt.b32 q, $1, $1, 0x1302;\n\t"
        "shl.b32 qs, q, 8;\n\t"
        "and.b32 out, qs, 0x80008000;\n\t"
        "and.b32 mant, qs, 0x7f007f00;\n\t"
        "shr.u32 mant, mant, 4;\n\t"
        "or.b32 out, out, mant;\n\t"
        "mov.b32 bias, 0x7b807b80;\n\t"
        "mov.b32 zero, 0;\n\t"
        "fma.rn.bf16x2 $0, out, bias, zero;\n\t"
        "}\n"
    )

    def _emit(ptx):
        # One independent inline-asm op per output register.
        return cutlass.Int32(
            llvm.inline_asm(
                T.i32(),
                [cutlass.Int32(src).ir_value(loc=loc, ip=ip)],
                ptx,
                "=r,r",
                has_side_effects=False,
                is_align_stack=False,
                asm_dialect=llvm.AsmDialect.AD_ATT,
            )
        )

    return _emit(ptx_first), _emit(ptx_second)
590
+
591
+
592
@dsl_user_op
def cvt_fp4x2_e2m1_f16x2(
    src: cutlass.Int32,
    *,
    loc=None,
    ip=None,
) -> cutlass.Int32:
    """Convert one packed E2M1 byte into one packed f16x2 register.

    Only the lowest byte of ``src`` is used; the other three are discarded
    when the register is unpacked.
    """
    ptx = (
        "{\n\t"
        ".reg .b8 byte0;\n\t"
        "mov.b32 {byte0, _, _, _}, $1;\n\t"
        "cvt.rn.f16x2.e2m1x2 $0, byte0;\n\t"
        "}\n"
    )
    source_bits = cutlass.Int32(src).ir_value(loc=loc, ip=ip)
    converted = llvm.inline_asm(
        T.i32(),
        [source_bits],
        ptx,
        "=r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(converted)
616
+
617
+
618
@dsl_user_op
def cvt_fp4x8_e2m1_f16x8(
    src: cutlass.Int32,
    *,
    loc=None,
    ip=None,
) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32]:
    """Convert four packed E2M1 bytes into four packed f16x2 registers.

    ``src`` is split into its four bytes and each byte is converted with
    ``cvt.rn.f16x2.e2m1x2``; result ``k`` corresponds to byte ``k``.
    """
    ptx = (
        "{\n\t"
        ".reg .b8 byte0, byte1, byte2, byte3;\n\t"
        "mov.b32 {byte0, byte1, byte2, byte3}, $4;\n\t"
        "cvt.rn.f16x2.e2m1x2 $0, byte0;\n\t"
        "cvt.rn.f16x2.e2m1x2 $1, byte1;\n\t"
        "cvt.rn.f16x2.e2m1x2 $2, byte2;\n\t"
        "cvt.rn.f16x2.e2m1x2 $3, byte3;\n\t"
        "}\n"
    )
    result_struct = llvm.inline_asm(
        llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]),
        [cutlass.Int32(src).ir_value(loc=loc, ip=ip)],
        ptx,
        "=r,=r,=r,=r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    # Unpack the four struct members in order.
    return tuple(
        cutlass.Int32(llvm.extractvalue(T.i32(), result_struct, [k], loc=loc, ip=ip))
        for k in (0, 1, 2, 3)
    )
648
+
649
+
650
@dsl_user_op
def cvt_fp4x8_e2m1_bf16x8(
    src: cutlass.Int32,
    *,
    loc=None,
    ip=None,
) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32]:
    """Convert four packed E2M1 bytes into four packed bf16x2 registers.

    When ``cutlass`` reports CUDA 13.2 or newer, the direct
    ``cvt.rn.bf16x2.e2m1x2`` instruction is emitted for each byte.  Older
    versions go through f16x2 (``cvt_fp4x8_e2m1_f16x8``) and then convert each
    pair with ``cvt_f16x2_to_bf16x2``.
    """
    from cutlass import CUDA_VERSION

    has_direct_cvt = CUDA_VERSION.major > 13 or (
        CUDA_VERSION.major == 13 and CUDA_VERSION.minor >= 2
    )
    if has_direct_cvt:
        ptx = (
            "{\n\t"
            ".reg .b8 byte0, byte1, byte2, byte3;\n\t"
            "mov.b32 {byte0, byte1, byte2, byte3}, $4;\n\t"
            "cvt.rn.bf16x2.e2m1x2 $0, byte0;\n\t"
            "cvt.rn.bf16x2.e2m1x2 $1, byte1;\n\t"
            "cvt.rn.bf16x2.e2m1x2 $2, byte2;\n\t"
            "cvt.rn.bf16x2.e2m1x2 $3, byte3;\n\t"
            "}\n"
        )
        result_struct = llvm.inline_asm(
            llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]),
            [cutlass.Int32(src).ir_value(loc=loc, ip=ip)],
            ptx,
            "=r,=r,=r,=r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
        bf0 = cutlass.Int32(llvm.extractvalue(T.i32(), result_struct, [0], loc=loc, ip=ip))
        bf1 = cutlass.Int32(llvm.extractvalue(T.i32(), result_struct, [1], loc=loc, ip=ip))
        bf2 = cutlass.Int32(llvm.extractvalue(T.i32(), result_struct, [2], loc=loc, ip=ip))
        bf3 = cutlass.Int32(llvm.extractvalue(T.i32(), result_struct, [3], loc=loc, ip=ip))
        return bf0, bf1, bf2, bf3

    # Fallback: E2M1 -> f16x2 first, then f16x2 -> bf16x2 per register.
    half0, half1, half2, half3 = cvt_fp4x8_e2m1_f16x8(src, loc=loc, ip=ip)
    return (
        cvt_f16x2_to_bf16x2(half0, loc=loc, ip=ip),
        cvt_f16x2_to_bf16x2(half1, loc=loc, ip=ip),
        cvt_f16x2_to_bf16x2(half2, loc=loc, ip=ip),
        cvt_f16x2_to_bf16x2(half3, loc=loc, ip=ip),
    )
693
+
694
+
695
@dsl_user_op
def cvt_fp4x8_e2m1_scaled_e4m3x8(
    src: cutlass.Int32,
    scale_e4m3: cutlass.Int32,
    *,
    loc=None,
    ip=None,
) -> Tuple[cutlass.Int32, cutlass.Int32]:
    """Scale eight packed E2M1 values by one E4M3 byte and convert to E4M3.

    The low byte of ``scale_e4m3`` is replicated with ``prmt`` and applied to
    every value.  Two PTX sequences are available and share operands and
    constraints; the one used depends on the CUDA version reported by
    ``cutlass``:

    * 13.2 or newer: two ``mul.e4m3x4.e2m1x4.e4m3x4.satfinite`` instructions.
    * older: convert scale and values to f16x2, multiply, then convert back
      with ``cvt.rn.satfinite.e4m3x2.f16x2``.

    Returns:
        Two ``Int32`` registers with four E4M3 bytes each.
    """
    from cutlass import CUDA_VERSION

    has_packed_mul = CUDA_VERSION.major > 13 or (
        CUDA_VERSION.major == 13 and CUDA_VERSION.minor >= 2
    )
    if has_packed_mul:
        ptx = (
            "{\n\t"
            ".reg .b32 tmp, ra;\n\t"
            ".reg .b8 byte0, byte1, byte2, byte3;\n\t"
            "prmt.b32 tmp, $3, 0, 0;\n\t"
            "mov.b32 {byte0, byte1, byte2, byte3}, $2;\n\t"
            "mov.b32 ra, {byte0, byte1, _, _};\n\t"
            "mul.e4m3x4.e2m1x4.e4m3x4.satfinite $0, ra, tmp;\n\t"
            "mov.b32 ra, {_, _, byte2, byte3};\n\t"
            "mul.e4m3x4.e2m1x4.e4m3x4.satfinite $1, ra, tmp;\n\t"
            "}\n"
        )
    else:
        ptx = (
            "{\n\t"
            ".reg .b32 sf_bytes, sf_f16x2;\n\t"
            ".reg .b16 sf_pair, e0, e1, e2, e3;\n\t"
            ".reg .b8 byte0, byte1, byte2, byte3;\n\t"
            ".reg .b32 h0, h1, h2, h3;\n\t"
            "prmt.b32 sf_bytes, $3, 0, 0;\n\t"
            "mov.b32 {sf_pair, _}, sf_bytes;\n\t"
            "cvt.rn.f16x2.e4m3x2 sf_f16x2, sf_pair;\n\t"
            "mov.b32 {byte0, byte1, byte2, byte3}, $2;\n\t"
            "cvt.rn.f16x2.e2m1x2 h0, byte0;\n\t"
            "cvt.rn.f16x2.e2m1x2 h1, byte1;\n\t"
            "cvt.rn.f16x2.e2m1x2 h2, byte2;\n\t"
            "cvt.rn.f16x2.e2m1x2 h3, byte3;\n\t"
            "mul.rn.f16x2 h0, h0, sf_f16x2;\n\t"
            "mul.rn.f16x2 h1, h1, sf_f16x2;\n\t"
            "mul.rn.f16x2 h2, h2, sf_f16x2;\n\t"
            "mul.rn.f16x2 h3, h3, sf_f16x2;\n\t"
            "cvt.rn.satfinite.e4m3x2.f16x2 e0, h0;\n\t"
            "cvt.rn.satfinite.e4m3x2.f16x2 e1, h1;\n\t"
            "cvt.rn.satfinite.e4m3x2.f16x2 e2, h2;\n\t"
            "cvt.rn.satfinite.e4m3x2.f16x2 e3, h3;\n\t"
            "mov.b32 $0, {e0, e1};\n\t"
            "mov.b32 $1, {e2, e3};\n\t"
            "}\n"
        )

    result_struct = llvm.inline_asm(
        llvm.StructType.get_literal([T.i32(), T.i32()]),
        [
            cutlass.Int32(src).ir_value(loc=loc, ip=ip),
            cutlass.Int32(scale_e4m3).ir_value(loc=loc, ip=ip),
        ],
        ptx,
        "=r,=r,r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    lower = cutlass.Int32(llvm.extractvalue(T.i32(), result_struct, [0], loc=loc, ip=ip))
    upper = cutlass.Int32(llvm.extractvalue(T.i32(), result_struct, [1], loc=loc, ip=ip))
    return lower, upper
771
+
772
+
773
@dsl_user_op
def cvt_f16x2_to_bf16x2(
    src: cutlass.Int32,
    *,
    loc=None,
    ip=None,
) -> cutlass.Int32:
    """Convert a packed f16x2 register into a packed bf16x2 register.

    Each half is widened to f32 and the two f32 values are then converted
    together with ``cvt.rn.bf16x2.f32``.
    """
    ptx = (
        "{\n\t"
        ".reg .b16 h0, h1;\n\t"
        ".reg .f32 f0, f1;\n\t"
        "mov.b32 {h0, h1}, $1;\n\t"
        "cvt.f32.f16 f0, h0;\n\t"
        "cvt.f32.f16 f1, h1;\n\t"
        "cvt.rn.bf16x2.f32 $0, f1, f0;\n\t"
        "}\n"
    )
    source_bits = cutlass.Int32(src).ir_value(loc=loc, ip=ip)
    converted = llvm.inline_asm(
        T.i32(),
        [source_bits],
        ptx,
        "=r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(converted)
800
+
801
+
802
@dsl_user_op
def mul_bf16x2(
    a: cutlass.Int32,
    b: cutlass.Int32,
    *,
    loc=None,
    ip=None,
) -> cutlass.Int32:
    """Multiply two packed bf16x2 registers element-wise (``mul.rn.bf16x2``)."""
    operands = [
        cutlass.Int32(a).ir_value(loc=loc, ip=ip),
        cutlass.Int32(b).ir_value(loc=loc, ip=ip),
    ]
    product = llvm.inline_asm(
        T.i32(),
        operands,
        "mul.rn.bf16x2 $0, $1, $2;",
        "=r,r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(product)
826
+
827
+
828
@cute.jit
def cvt_fp8_e4m3_to_bf16x2_replicated(src: cutlass.Int32) -> cutlass.Int32:
    """Decode one E4M3 byte and replicate it into a packed bf16x2 register.

    The low byte of ``src`` is copied into all four byte positions of a
    32-bit word, which is decoded with ``cvt_fp8x4_e4m3_bf16x4``; only the
    first of the two resulting registers is returned.
    """
    low_byte = src & cutlass.Int32(0xFF)
    # Multiplying by 0x01010101 broadcasts the byte to every byte lane.
    broadcast = low_byte * cutlass.Int32(0x01010101)
    first_pair, _ = cvt_fp8x4_e4m3_bf16x4(broadcast)
    return first_pair
836
+
837
+
838
@overload
def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...


@overload
def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...


@cute.jit
def cvt_f16(src: cute.Tensor, dst_or_dtype):
    """Convert a Float32 tensor to Float16 or BFloat16.

    Two call forms are supported:

    * ``cvt_f16(src, dst)`` writes into the existing tensor ``dst`` and
      returns ``None``.
    * ``cvt_f16(src, dtype)`` allocates a register tensor of ``dtype`` with
      the shape of ``src``, fills it, and returns it.

    Args:
        src: Source tensor with Float32 elements and an even element count.
        dst_or_dtype: Destination tensor, or the target dtype
            (Float16/BFloat16).
    """
    if const_expr(isinstance(dst_or_dtype, type)):
        # dtype form: allocate the output, then reuse the tensor form.
        out = cute.make_rmem_tensor(src.shape, dst_or_dtype)
        cvt_f16(src, out)
        return out
    else:
        # tensor form: convert in pairs straight into dst.
        dst = dst_or_dtype
        assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
        assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
        assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (
            "dst must be BFloat16 or Float16"
        )
        assert src.element_type is Float32, "src must be Float32"
        # View dst as 32-bit words so each word receives one converted pair.
        dst_words = cute.recast_tensor(dst, cutlass.Int32)
        assert cute.size(dst_words.shape) * 2 == cute.size(src.shape)
        for w in cutlass.range_constexpr(cute.size(dst_words)):
            dst_words[w] = cvt_f16x2_f32(src[2 * w], src[2 * w + 1], dst.element_type)
876
+
877
+
878
@cute.jit
def cvt_f32(src: cute.Tensor, dst: cute.Tensor) -> None:
    """Convert a Float32 tensor into ``dst``, using ``dst``'s element type.

    Three paths are selected at trace time:

    * Float16/BFloat16 destinations are delegated to ``cvt_f16``.
    * Float8E4M3FN destinations follow the reference fp8 quantize pattern:
      fragment-by-fragment ``.store(.load().to(fp8))`` over groups of four
      elements. This lets the DSL emit ``cvt.rn.satfinite.e4m3x2.f32`` pairs
      and pack the fp8 bytes within a 32-bit register cell in the order the
      DSL chooses, which is expected to match the K-adjacency that SM100 fp8
      UMMA fragment_A reads.
    * Any other destination type is converted in one vectorised load/store
      through a view of ``dst`` that reuses ``src``'s layout.
    """
    target = dst.element_type
    if const_expr(target in [cutlass.BFloat16, cutlass.Float16]):
        cvt_f16(src, dst)
    elif const_expr(target is cutlass.Float8E4M3FN):
        n_src = cute.size(src.shape)
        assert src.element_type is Float32, "src must be Float32"
        assert n_src == cute.size(dst.shape), "dst and src must have the same size"
        assert n_src % 4 == 0, "src must have a multiple of 4 elements"
        group = 4
        src_groups = cute.logical_divide(src, cute.make_layout(group))
        dst_groups = cute.logical_divide(dst, cute.make_layout(group))
        for g in cutlass.range_constexpr(cute.size(src_groups, mode=[1])):
            converted = src_groups[None, g].load().to(dst.element_type)
            dst_groups[None, g].store(converted)
    else:
        assert src.element_type is Float32, "src must be Float32"
        same_layout_dst = cute.make_tensor(dst.iterator, src.layout)
        same_layout_dst.store(src.load().to(dst.element_type))
903
+
904
+
905
@dsl_user_op
@cute.jit
def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32:
    """Evaluate ``poly[0] + poly[1]*x + ... + poly[n]*x**n`` with Horner's rule.

    Coefficients are ordered from the constant term upwards. ``loc`` and
    ``ip`` are accepted for the ``dsl_user_op`` calling convention and are
    not used directly here.
    """
    top = len(poly) - 1
    acc = poly[top]
    # Fold in the remaining coefficients from the highest degree downwards.
    for step in cutlass.range_constexpr(top):
        acc = acc * x + poly[top - 1 - step]
    return acc
913
+
914
+
915
@dsl_user_op
@cute.jit
def evaluate_polynomial_2(
    x: Float32, y: Float32, *, loc=None, ip=None, poly: Tuple[Float32, ...] = None
) -> Tuple[Float32, Float32]:
    """Placeholder signature guard; see the real definition below."""
    raise NotImplementedError


@dsl_user_op
@cute.jit
def evaluate_polynomial_2(
    x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Evaluate the same polynomial at ``x`` and ``y`` using packed f32x2 FMAs.

    Coefficients are ordered from the constant term upwards and each Horner
    step handles both points with one ``fma_packed_f32x2`` call. ``loc`` and
    ``ip`` are accepted for the ``dsl_user_op`` calling convention and are
    not used directly here.
    """
    top = len(poly) - 1
    acc = (poly[top], poly[top])
    for step in cutlass.range_constexpr(top):
        coeff = poly[top - 1 - step]
        acc = cute.arch.fma_packed_f32x2(acc, (x, y), (coeff, coeff))
    return acc
925
+
926
+
927
@dsl_user_op
def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return ``x + y`` computed with the PTX ``.rm`` (round-down) rounding mode.

    The addition uses ``add.rm.ftz.f32``, so subnormals are flushed to zero.
    """
    # There's probably a way to call llvm or nvvm to do this instead of ptx
    lhs = Float32(x).ir_value(loc=loc, ip=ip)
    rhs = Float32(y).ir_value(loc=loc, ip=ip)
    total = llvm.inline_asm(
        T.f32(),
        [lhs, rhs],
        "add.rm.ftz.f32 $0, $1, $2;",
        "=f,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Float32(total)
941
+
942
+
943
@dsl_user_op
def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32:
    """Merge the integer and fractional parts of an emulated ``ex2``.

    The bit pattern of ``x_rounded`` is shifted left by 23 and added, as an
    integer, to the bit pattern of ``frac_ex2``; the sum is returned
    reinterpreted as Float32. Bit 23 is where the fp32 exponent field starts,
    so this bumps the exponent of ``frac_ex2`` by the integer held in the low
    bits of ``x_rounded``.
    """
    ptx_lines = (
        "{",
        ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;",
        "mov.b32 x_rounded_i, $1;",
        "mov.b32 frac_ex_i, $2;",
        "shl.b32 x_rounded_e, x_rounded_i, 23;",
        # add.u32 generates IMAD instruction and add.s32 generates LEA instruction
        # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik
        "add.s32 out_i, x_rounded_e, frac_ex_i;",
        "mov.b32 $0, out_i;",
        "}",
    )
    ptx = "\n\t".join(ptx_lines) + "\n"
    operands = [
        Float32(x_rounded).ir_value(loc=loc, ip=ip),
        Float32(frac_ex2).ir_value(loc=loc, ip=ip),
    ]
    merged = llvm.inline_asm(
        T.f32(),
        operands,
        ptx,
        "=f,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Float32(merged)
968
+
969
+
970
@dsl_user_op
def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32:
    """Emulate ``2**x`` with a polynomial for the fractional part.

    ``x`` is clamped from below at -127, split into floor and fraction with
    the ``2**23 + 2**22`` magic-constant trick, the fraction is run through
    ``POLY_EX2[poly_degree]``, and the two parts are merged by
    ``combine_int_frac_ex2``.

    Args:
        x: Exponent; the caller must guarantee ``x <= 127.0``.
        poly_degree: Key into ``POLY_EX2`` selecting the polynomial.
        loc: Optional MLIR source location.
        ip: Optional MLIR insertion point.
    """
    assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported"
    # We assume x <= 127.0
    magic = float(2**23 + 2**22)
    clamped = cute.arch.fmax(x, -127.0)
    # We want to round down here, so that the fractional part is in [0, 1)
    shifted = add_round_down(clamped, magic, loc=loc, ip=ip)
    # The integer floor of x is now in the last 8 bits of `shifted`
    # We assume the next 2 ops round to nearest even. The rounding mode is important.
    floor_part = shifted - magic
    frac_part = clamped - floor_part
    frac_pow2 = evaluate_polynomial(frac_part, POLY_EX2[poly_degree], loc=loc, ip=ip)
    return combine_int_frac_ex2(shifted, frac_pow2, loc=loc, ip=ip)
984
+
985
+
986
@dsl_user_op
def ex2_emulation_2(
    x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Emulate ``2**x`` and ``2**y`` together using packed f32x2 arithmetic.

    This is the two-value counterpart of ``ex2_emulation``: both inputs are
    clamped from below at -127, split into floor and fraction, the fractions
    go through ``POLY_EX2[poly_degree]``, and each result is merged by
    ``combine_int_frac_ex2``.

    Args:
        x: First exponent; the caller must guarantee ``x <= 127.0``.
        y: Second exponent; the caller must guarantee ``y <= 127.0``.
        poly_degree: Key into ``POLY_EX2`` selecting the polynomial.
        loc: Optional MLIR source location.
        ip: Optional MLIR insertion point.
    """
    # We assume x <= 127.0 and y <= 127.0
    magic = float(2**23 + 2**22)
    magic_pair = (magic, magic)
    clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
    # We want to round down here, so that the fractional part is in [0, 1)
    shifted = cute.arch.add_packed_f32x2(clamped, magic_pair, rnd="rm")
    # The integer floor of x & y are now in the last 8 bits of `shifted`
    # We want the next 2 ops to round to nearest even. The rounding mode is important.
    floor_part = activation.sub_packed_f32x2(shifted, magic_pair)
    frac_part = activation.sub_packed_f32x2(clamped, floor_part)
    frac_pow2 = evaluate_polynomial_2(*frac_part, POLY_EX2[poly_degree], loc=loc, ip=ip)
    first = combine_int_frac_ex2(shifted[0], frac_pow2[0], loc=loc, ip=ip)
    second = combine_int_frac_ex2(shifted[1], frac_pow2[1], loc=loc, ip=ip)
    return first, second
1005
+
1006
+
1007
@dsl_user_op
def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """Emulate ``2**x`` and ``2**y`` in one hand-written PTX block.

    The PTX performs the same steps as ``ex2_emulation_2`` with a fixed
    degree-3 polynomial: clamp at -127, split into floor and fraction with
    the ``2**23 + 2**22`` magic constant, evaluate the polynomial on the
    fraction with packed f32x2 FMAs, then add the floor (shifted left by 23)
    to the result bits.

    Args:
        x: First exponent (asm operand ``$2``).
        y: Second exponent (asm operand ``$3``).
        loc: Optional MLIR source location.
        ip: Optional MLIR insertion point.

    Returns:
        The pair of results extracted from the two-field asm output struct.
    """
    out_f32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.f32(), T.f32()]),
        [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()],
        "{\n\t"
        ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
        ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
        ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
        # Clamp both inputs from below: 0fC2FE0000 is -127.0.
        "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t"
        "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t"
        "mov.b64 l1, {f1, f2};\n\t"
        # 0f4B400000 is 12582912.0 == 2**23 + 2**22, the rounding magic constant.
        "mov.f32 f3, 0f4B400000;\n\t"
        "mov.b64 l2, {f3, f3};\n\t"
        # l7 = clamped + magic (round down), l8 = floor, l9 = fractional part.
        "add.rm.ftz.f32x2 l7, l1, l2;\n\t"
        "sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
        "sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
        # Polynomial coefficients, highest degree first:
        # ~0.0771 (x^3), ~0.2276 (x^2), ~0.6952 (x^1), 1.0 (x^0).
        "mov.f32 f7, 0f3D9DF09D;\n\t"
        "mov.b64 l6, {f7, f7};\n\t"
        "mov.f32 f6, 0f3E6906A4;\n\t"
        "mov.b64 l5, {f6, f6};\n\t"
        "mov.f32 f5, 0f3F31F519;\n\t"
        "mov.b64 l4, {f5, f5};\n\t"
        "mov.f32 f4, 0f3F800000;\n\t"
        "mov.b64 l3, {f4, f4};\n\t"
        # Horner evaluation on the fractional parts.
        "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
        # Unpack, shift the rounded values left by 23 and add them to the
        # polynomial result bits (same merge as combine_int_frac_ex2).
        "mov.b64 {r1, r2}, l7;\n\t"
        "mov.b64 {r3, r4}, l10;\n\t"
        "shl.b32 r5, r1, 23;\n\t"
        "add.s32 r7, r5, r3;\n\t"
        "shl.b32 r6, r2, 23;\n\t"
        "add.s32 r8, r6, r4;\n\t"
        "mov.b32 $0, r7;\n\t"
        "mov.b32 $1, r8;\n\t"
        "}\n",
        # NOTE(review): outputs use "=r" constraints while the result struct is
        # (f32, f32); this matches the original code — confirm it is intended.
        "=r,=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
    out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
    return out0, out1
1052
+
1053
+
1054
@dsl_user_op
def domain_offset_aligned(
    coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
) -> cute.Tensor:
    """Return ``tensor`` re-based at ``coord`` while keeping its alignment hint.

    A new pointer is built at the address of the element at ``coord`` and is
    given the original iterator's alignment; the layout is reused unchanged.
    The tensor's iterator must be a ``cute.Pointer``.
    """
    assert isinstance(tensor.iterator, cute.Pointer)
    # We assume that applying the offset does not change the pointer alignment
    base_alignment = tensor.iterator.alignment
    shifted_address = elem_pointer(tensor, coord).toint()
    shifted_ptr = cute.make_ptr(
        tensor.element_type,
        shifted_address,
        tensor.memspace,
        assumed_align=base_alignment,
    )
    return cute.make_tensor(shifted_ptr, tensor.layout)
1067
+
1068
+
1069
@cute.jit
def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
    """Wrap scalar ``a`` in a one-element tensor of ``dtype`` and load it as a TensorSSA."""
    holder = cute.make_rmem_tensor(1, dtype)
    holder[0] = a
    return holder.load()
1075
+
1076
+
1077
def ssa_to_scalar(val):
    """Return element 0 of ``val``; the inverse counterpart of ``scalar_to_ssa``."""
    first = val[0]
    return first
1080
+
1081
+
1082
+ # ------------------------------------------------------------------
1083
+ # Host-side Python helpers (not @cute.jit — called from PyTorch host code)
1084
+ # ------------------------------------------------------------------
1085
+
1086
def default_softmax_scale(dim: int) -> float:
    """Return the default softmax scaling factor ``1 / sqrt(dim)``."""
    root = math.sqrt(dim)
    return 1.0 / root
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """SM100 sparse attention kernels."""
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/build_k2q_csr/__init__.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """q2k -> k2q CSR builder backed by the precompiled Torch ops.
5
+
6
+ The CUDA implementation lives in ``csrc/build_k2q_csr.cu`` and is built
7
+ ahead of time by kernel-builder; it is reached through the ``_ops``
8
+ namespace instead of being JIT-compiled at import time.
9
+
10
+ The kernel pipeline is tuned and verified for SM100; other
11
+ architectures are not supported.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import torch
17
+
18
+ from ...._ops import ops
19
+
20
+
21
def run_build_k2q_csr(
    q2k: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    row_ptr: torch.Tensor,
    q_idx: torch.Tensor,
    topk: int,
    blk_kv: int,
    total_rows: int,
    max_kv_blocks: int,
) -> None:
    """Fill ``row_ptr`` and ``q_idx`` in place via the precompiled op.

    Args:
        q2k: int32 [H, total_q, topK] contiguous (CUDA).
        cu_seqlens_q: int32 [B+1] contiguous (CUDA).
        cu_seqlens_k: int32 [B+1] contiguous (CUDA).
        row_ptr: int32 [H, total_rows + 1] CUDA, written in place.
        q_idx: int32 [H, total_q * topK] CUDA, written in place
            (trailing slots set to -1).
        topk: must be in {4, 8, 16, 32}.
        blk_kv: must equal 128.
        total_rows: sum over batches of ceil(seqlen_k / blk_kv).
        max_kv_blocks: max over batches of ceil(seqlen_k / blk_kv); upper bound
            used to size the row_map workspace and clamp valid kv ids.
    """
    # The op takes the tensors as-is and the scalar sizes as plain Python ints.
    scalar_args = (topk, blk_kv, total_rows, max_kv_blocks)
    ops.run_build_k2q_csr(
        q2k,
        cu_seqlens_q,
        cu_seqlens_k,
        row_ptr,
        q_idx,
        *(int(value) for value in scalar_args),
    )
58
+
59
+
60
def run_build_k2q_csr_with_schedule(
    q2k: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    row_ptr: torch.Tensor,
    q_idx: torch.Tensor,
    scheduler_metadata: torch.Tensor,
    work_count: torch.Tensor,
    qsplit_idx: torch.Tensor,
    split_counts: torch.Tensor,
    topk: int,
    blk_kv: int,
    total_rows: int,
    max_kv_blocks: int,
    target_q_per_cta: int,
    work_capacity: int,
    max_seqlen_q: int,
) -> None:
    """Fill the CSR buffers and the fused sparse-attention schedule metadata in place.

    All tensor arguments are forwarded unchanged to the precompiled
    ``run_build_k2q_csr_with_schedule`` op; the seven scalar arguments are
    coerced to plain Python ints first.
    """
    tensor_args = (
        q2k,
        cu_seqlens_q,
        cu_seqlens_k,
        row_ptr,
        q_idx,
        scheduler_metadata,
        work_count,
        qsplit_idx,
        split_counts,
    )
    scalar_args = (
        topk,
        blk_kv,
        total_rows,
        max_kv_blocks,
        target_q_per_cta,
        work_capacity,
        max_seqlen_q,
    )
    ops.run_build_k2q_csr_with_schedule(
        *tensor_args,
        *(int(value) for value in scalar_args),
    )
97
+
98
+
99
def is_supported(topk: int, blk_kv: int) -> bool:
    """Return True when ``topk`` is 4, 8, 16 or 32 and ``blk_kv`` is 128."""
    if int(topk) not in (4, 8, 16, 32):
        return False
    return int(blk_kv) == 128
101
+
102
+
103
+ __all__ = ["run_build_k2q_csr", "run_build_k2q_csr_with_schedule", "is_supported"]
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/decode_schedule.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Split-KV schedule for paged fp8 decode attention.
5
+
6
+ The public PageKV representation remains this repo's rectangular page table:
7
+ ``page_table [B, max_pages]`` plus ``seqused_k [B]``. The schedule only
8
+ describes how query tiles and KV chunks are split into work items.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from dataclasses import dataclass
14
+ from typing import Optional
15
+
16
+ import torch
17
+
18
+
19
@dataclass
class DecodeAttentionSchedule:
    """Split-KV decode schedule returned by ``prepare_decode_schedule``.

    The fields mirror the keys of the dictionary produced by
    ``build_decode_schedule``: a set of scalar summary values followed by the
    schedule tensors.
    """

    # Scalar summary values; prepare_decode_schedule converts them with
    # bool()/int() before storing them here.
    split_kv: bool
    cta_tile_q: int
    num_q_tiles: int
    kv_chunk_size_pages: int
    kv_chunk_size_tokens: int
    work_count: int
    padded_work_count: int
    partial_rows: int
    max_split_count: int
    max_grid_size: int
    active_blocks_per_sm: int
    num_sms: int
    base_cta: int
    # Schedule tensors, stored exactly as returned by build_decode_schedule.
    # NOTE(review): shapes and dtypes are not visible from this module —
    # confirm against the schedule builder.
    request_indices: torch.Tensor
    qo_tile_indices: torch.Tensor
    kv_tile_indices: torch.Tensor
    merge_indptr: torch.Tensor
    o_indptr: torch.Tensor
    block_valid_mask: torch.Tensor
    kv_pages: torch.Tensor
    split_counts: torch.Tensor
42
+
43
+
44
+ def _require_i32_cuda_1d(tensor: torch.Tensor, *, name: str) -> None:
45
+ if tensor.dtype != torch.int32:
46
+ raise TypeError(f"{name} must be torch.int32")
47
+ if tensor.ndim != 1:
48
+ raise ValueError(f"{name} must be rank-1")
49
+ if not tensor.is_cuda:
50
+ raise ValueError(f"{name} must be a CUDA tensor")
51
+ if not tensor.is_contiguous():
52
+ raise ValueError(f"{name} must be contiguous")
53
+
54
+
55
def prepare_decode_schedule(
    *,
    seqused_k: torch.Tensor,
    page_size: int,
    seqlen_q: int,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    max_seqlen_k: int,
    enable_cuda_graph: bool = False,
    max_grid_size: Optional[int] = None,
    fixed_split_size: Optional[int] = None,
    disable_split_kv: bool = False,
) -> DecodeAttentionSchedule:
    """Build paged decode split-KV schedule on the GPU.

    A single CUDA kernel reads ``seqused_k`` on device and writes all
    schedule index arrays. Only a small summary tensor is D2H-synced so
    the wrapper can size O_partial / pick the kernel grid / choose the
    split-vs-non-split compile path.

    ``max_seqlen_k`` is the host-side worst-case bound used to pad the
    work-tile arrays. It must satisfy ``max(seqused_k) <= max_seqlen_k``.

    Args:
        seqused_k: Contiguous rank-1 int32 CUDA tensor, one entry per batch.
        page_size: Tokens per KV page; must be positive.
        seqlen_q: Query tokens per request; must be positive.
        num_qo_heads: Query/output head count; must be 16x ``num_kv_heads``.
        num_kv_heads: KV head count.
        head_dim: Head dimension; only 128 is supported.
        max_seqlen_k: Host-side upper bound on ``seqused_k``; must be positive.
        enable_cuda_graph: Forwarded to the schedule builder.
        max_grid_size: Optional grid cap; ``None`` is forwarded as 0.
        fixed_split_size: Optional split size; ``None`` is forwarded as -1.
        disable_split_kv: Forwarded to the schedule builder.

    Returns:
        A ``DecodeAttentionSchedule`` built from the builder's result.

    Raises:
        TypeError: If ``seqused_k`` is not int32.
        ValueError: If a scalar argument is out of range, or ``seqused_k``
            violates one of the kernel's layout requirements.
        NotImplementedError: For batch > 1024, ``qhead_per_kv != 16`` or
            ``head_dim != 128``.
    """
    _require_i32_cuda_1d(seqused_k, name="seqused_k")
    # Hard cap: current single-CTA schedule kernel stores per-batch state
    # in shared memory. Larger batches require a multi-CTA cooperative
    # scheduler (unimplemented). Fail fast at the Python boundary so the
    # error doesn't surface from inside the CUDA extension.
    batch = int(seqused_k.shape[0])
    if batch > 1024:
        raise NotImplementedError(
            "decode schedule currently supports batch <= 1024 "
            f"(got batch={batch}). Larger batches need "
            "the multi-CTA scheduler — not yet implemented."
        )

    # Validate the scalar arguments BEFORE they are used in device-side
    # arithmetic: page_size is a modulo divisor below and seqlen_q is a
    # divisor in the error message, so a zero value must be rejected first.
    # Doing this up front also avoids D2H syncs for plainly invalid calls.
    page_size_i = int(page_size)
    seqlen_q_i = int(seqlen_q)
    num_qo_heads_i = int(num_qo_heads)
    num_kv_heads_i = int(num_kv_heads)
    head_dim_i = int(head_dim)
    max_seqlen_k_i = int(max_seqlen_k)
    if page_size_i <= 0:
        raise ValueError("page_size must be positive")
    if seqlen_q_i <= 0:
        raise ValueError("seqlen_q must be positive")
    if num_qo_heads_i <= 0 or num_kv_heads_i <= 0:
        raise ValueError("head counts must be positive")
    if num_qo_heads_i % num_kv_heads_i != 0:
        raise ValueError("num_qo_heads must be divisible by num_kv_heads")
    if num_qo_heads_i // num_kv_heads_i != 16:
        raise NotImplementedError("decode schedule currently supports only qhead_per_kv=16")
    if head_dim_i != 128:
        raise NotImplementedError("decode schedule currently supports only head_dim=128")
    if max_seqlen_k_i <= 0:
        raise ValueError("max_seqlen_k must be positive")

    # Two API-boundary checks tied to the kernel's packed-GQA layout
    # (q_tokens_per_group = m_block_size / qhead_per_kv = 128/16 = 8):
    #
    # (1) seqused_k[b] >= seqlen_q. The kernel computes the causal mask as
    #     col_limit = row_idx + seqlen_k - seqlen_q + 1. For row 0 (first
    #     q-token in the packed group) this is col_limit = seqlen_k - seqlen_q
    #     + 1, which goes <= 0 whenever seqlen_k < seqlen_q. That all-masked
    #     row then enters a mask-codegen path with PTX-undefined shift counts
    #     and the kernel hangs. The condition is also semantically invalid
    #     in batched-decode: you can't emit seqlen_q new tokens with fewer
    #     than seqlen_q total context tokens (seqlen_k includes them).
    #
    # (2) seqused_k[b] % page_size ∈ {0, 8, 16, ..., 120}. Same hang fires
    #     when the LAST partial page has < q_tokens_per_group=8 valid
    #     columns, because then the *last MMA tile* hits the same all-masked
    #     row case for the trailing q-tokens.
    #
    # Both are tracked as a separate kernel-level TODO (un-pack the
    # all-masked row → skip mask call, or saturate causal_col_limit at >= 1
    # in mask.py). Until then, fail fast at the Python boundary with a
    # clear message rather than letting the kernel timeout.
    bad_q = seqused_k < seqlen_q_i
    if bool(bad_q.any().item()):
        bad_idx = int(torch.nonzero(bad_q, as_tuple=True)[0][0].item())
        bad_val = int(seqused_k[bad_idx].item())
        raise ValueError(
            f"decode kernel requires seqused_k[b] >= seqlen_q (= {seqlen_q_i}) "
            f"for every batch. Got seqused_k[{bad_idx}]={bad_val}. "
            f"This is also a batched-decode invariant: seqlen_k must include "
            f"the seqlen_q new tokens being emitted."
        )
    # NOTE(review): only remainders in (0, seqlen_q) are rejected here, which
    # is weaker than the "multiple of seqlen_q" wording of the message below.
    # Confirm which of the two is the intended contract.
    rem = seqused_k % page_size_i
    bad_rem = (rem > 0) & (rem < seqlen_q_i)
    if bool(bad_rem.any().item()):
        bad_idx = int(torch.nonzero(bad_rem, as_tuple=True)[0][0].item())
        bad_val = int(seqused_k[bad_idx].item())
        raise ValueError(
            f"decode kernel requires seqused_k[b] % page_size ∈ "
            f"{{0, {seqlen_q_i}, {seqlen_q_i*2}, ..., {(page_size_i//seqlen_q_i)*seqlen_q_i}}}. "
            f"Got seqused_k[{bad_idx}]={bad_val}, last partial page has "
            f"{bad_val % page_size_i} valid columns (< seqlen_q={seqlen_q_i}). "
            f"Round seqused_k up to the next multiple of {seqlen_q_i} OR to "
            f"a multiple of {page_size_i}."
        )

    # Imported lazily so that importing this module does not pull in the
    # schedule builder.
    from ...src.sm100.fwd_decode.build_decode_schedule import build_decode_schedule

    raw = build_decode_schedule(
        seqused_k,
        page_size=page_size_i,
        seqlen_q=seqlen_q_i,
        num_qo_heads=num_qo_heads_i,
        num_kv_heads=num_kv_heads_i,
        head_dim=head_dim_i,
        max_seqlen_k=max_seqlen_k_i,
        enable_cuda_graph=bool(enable_cuda_graph),
        max_grid_size=0 if max_grid_size is None else int(max_grid_size),
        fixed_split_size=-1 if fixed_split_size is None else int(fixed_split_size),
        disable_split_kv=bool(disable_split_kv),
    )
    return DecodeAttentionSchedule(
        split_kv=bool(raw["split_kv"]),
        cta_tile_q=int(raw["cta_tile_q"]),
        num_q_tiles=int(raw["num_q_tiles"]),
        kv_chunk_size_pages=int(raw["kv_chunk_size_pages"]),
        kv_chunk_size_tokens=int(raw["kv_chunk_size_tokens"]),
        work_count=int(raw["work_count"]),
        padded_work_count=int(raw["padded_work_count"]),
        partial_rows=int(raw["partial_rows"]),
        max_split_count=int(raw["max_split_count"]),
        max_grid_size=int(raw["max_grid_size"]),
        active_blocks_per_sm=int(raw["active_blocks_per_sm"]),
        num_sms=int(raw["num_sms"]),
        base_cta=int(raw["base_cta"]),
        request_indices=raw["request_indices"],
        qo_tile_indices=raw["qo_tile_indices"],
        kv_tile_indices=raw["kv_tile_indices"],
        merge_indptr=raw["merge_indptr"],
        o_indptr=raw["o_indptr"],
        block_valid_mask=raw["block_valid_mask"],
        kv_pages=raw["kv_pages"],
        split_counts=raw["split_counts"],
    )
188
+
189
+
190
+ __all__ = [
191
+ "DecodeAttentionSchedule",
192
+ "prepare_decode_schedule",
193
+ ]
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fp4_indexer.py ADDED
@@ -0,0 +1,1956 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """SM100 FP4 sparse-attention indexer kernels."""
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Literal
10
+
11
+ import cuda.bindings.driver as cuda
12
+ import cutlass
13
+ import cutlass.cute as cute
14
+ import cutlass.pipeline as pipeline
15
+ import cutlass.utils as utils
16
+ import cutlass.utils.blackwell_helpers as sm100_utils
17
+ import cutlass.utils.blockscaled_layout as blockscaled_utils
18
+ import torch
19
+ from cutlass import Float32, Int32, const_expr
20
+ from cutlass.cute.nvgpu import cpasync, tcgen05
21
+
22
+ from ...src.common import pipeline as common_pipeline
23
+
24
+
25
# Accepted FP4 format names; these are the keys of _FORMAT_SPECS.
FP4_FORMAT = Literal["mxfp4", "nvfp4"]
# presumably bytes per packed FP4 row (_HEAD_DIM 4-bit values) — TODO confirm
_FP4_PACKED_D_BYTES = 64
# Head dimension used when building the Q/K TMA tensor layouts.
_HEAD_DIM = 128
# NOTE(review): not referenced in the visible part of this file — confirm usage.
_BLOCK_K = 128
# Rows per K page; used in the K and K-scale layouts.
_PAGE_SIZE = 128
# (M, N) extents of the MMA tile; the K extent is 2 * _MMA_INST_SHAPE_K.
_MMA_TILER_MN = (128, 128)
_MMA_INST_SHAPE_K = 64
# K tiles handled per CTA, selected by k_tiles_per_cta_for(causal).
_NON_CAUSAL_K_TILES_PER_CTA = 16
_CAUSAL_K_TILES_PER_CTA = 16
# NOTE(review): decode-path parameters; their consumers are outside the
# visible part of this file — confirm meaning against the decode kernel.
_DECODE_PACK_Q_LEN = 8
_DECODE_QHEAD_PER_KV = 16
_DECODE_K_TILES_PER_CTA = 16
# Element type the Q/K pointers are recast to before building TMA tensors.
_AB_DTYPE = cutlass.Float4E2M1FN
38
+
39
+
40
@dataclass(frozen=True)
class Fp4FormatSpec:
    """Immutable description of one supported FP4 block-scaled format."""

    # Canonical lowercase format name ("mxfp4" or "nvfp4").
    name: FP4_FORMAT
    # Scale-factor vector size passed to the blockscaled layout helpers;
    # presumably the number of FP4 elements sharing one scale — TODO confirm.
    sf_vec_size: int
    # Number of scale values stored per row (last axis of the scale tensors).
    scale_groups: int
    # Scale element type on the PyTorch side.
    torch_scale_dtype: torch.dtype
    # Matching scale element type on the CUTLASS side.
    cutlass_scale_dtype: type
47
+
48
+
49
# Registry of supported FP4 formats, keyed by lowercase format name.
# In both entries sf_vec_size * scale_groups == 128, which equals _HEAD_DIM.
_FORMAT_SPECS: dict[str, Fp4FormatSpec] = {
    "mxfp4": Fp4FormatSpec(
        name="mxfp4",
        sf_vec_size=32,
        scale_groups=4,
        torch_scale_dtype=torch.float8_e8m0fnu,
        cutlass_scale_dtype=cutlass.Float8E8M0FNU,
    ),
    "nvfp4": Fp4FormatSpec(
        name="nvfp4",
        sf_vec_size=16,
        scale_groups=8,
        torch_scale_dtype=torch.float8_e4m3fn,
        cutlass_scale_dtype=cutlass.Float8E4M3FN,
    ),
}
65
+
66
+
67
def normalize_fp4_format(fmt: str) -> Fp4FormatSpec:
    """Look up the ``Fp4FormatSpec`` for ``fmt``, ignoring case.

    Raises:
        ValueError: If ``fmt`` does not name a registered format. The
            original ``KeyError`` is attached as the cause.
    """
    lookup_name = str(fmt).lower()
    try:
        spec = _FORMAT_SPECS[lookup_name]
    except KeyError as missing:
        raise ValueError(f"format must be one of {sorted(_FORMAT_SPECS)}, got {fmt!r}") from missing
    return spec
73
+
74
+
75
def ceil_div(x: int, y: int) -> int:
    """Return ``x / y`` rounded up, after coercing both arguments to ``int``."""
    numerator = int(x)
    divisor = int(y)
    return (numerator + divisor - 1) // divisor
77
+
78
+
79
def k_tiles_per_cta_for(causal: bool) -> int:
    """Return the K-tiles-per-CTA constant for the causal or non-causal path."""
    if bool(causal):
        return _CAUSAL_K_TILES_PER_CTA
    return _NON_CAUSAL_K_TILES_PER_CTA
81
+
82
+
83
class Fp4IndexerScaleReorderSm100:
    """Reorder public FP4 indexer scales to the 1CTA blockscaled MMA layout.

    ``__call__`` wraps the raw pointers in tensors (row-major source views and
    MMA-ordered destination views) and launches ``kernel``, which copies every
    Q and K scale value from its source position to its MMA position.
    """

    def __init__(self, *, fmt: str):
        """Resolve ``fmt`` and cache the scale dtype and per-row scale count."""
        spec = normalize_fp4_format(fmt)
        self.fmt = spec.name
        self.sf_dtype = spec.cutlass_scale_dtype
        self.scale_groups = spec.scale_groups
        # Launch block size; also the per-CTA chunk of the flat scale index space.
        self.threads_per_cta = 256

    @cute.jit
    def __call__(
        self,
        q_scale_ptr: cute.Pointer,
        k_scale_ptr: cute.Pointer,
        q_scale_mma_ptr: cute.Pointer,
        k_scale_mma_ptr: cute.Pointer,
        problem_size: tuple,
        stream: cuda.CUstream,
    ):
        """Build the source/destination scale views and launch the reorder kernel.

        Args:
            q_scale_ptr: Source Q scales, viewed as (total_q, heads_q, scale_groups).
            k_scale_ptr: Source K scales, viewed as
                (page_count, heads_k, _PAGE_SIZE, scale_groups).
            q_scale_mma_ptr: Destination buffer for the reordered Q scales.
            k_scale_mma_ptr: Destination buffer for the reordered K scales.
            problem_size: ``(total_q, heads_q, page_count, heads_k)``.
            stream: CUDA stream the kernel is launched on.
        """
        total_q, heads_q, page_count, heads_k = problem_size
        # Number of 128-row blocks covering the Q rows, and of 4-wide blocks
        # covering the scale groups.
        rest_q_m = cute.ceil_div(total_q, 128)
        rest_g = cute.ceil_div(self.scale_groups, 4)
        # K scales use one "L" slot per (page, head) pair.
        k_l = page_count * heads_k

        # Source views: densely packed, last axis (scale group) fastest.
        q_scale = cute.make_tensor(
            q_scale_ptr,
            cute.make_layout(
                (total_q, heads_q, self.scale_groups),
                stride=(heads_q * self.scale_groups, self.scale_groups, 1),
            ),
        )
        k_scale = cute.make_tensor(
            k_scale_ptr,
            cute.make_layout(
                (page_count, heads_k, _PAGE_SIZE, self.scale_groups),
                stride=(
                    heads_k * _PAGE_SIZE * self.scale_groups,
                    _PAGE_SIZE * self.scale_groups,
                    self.scale_groups,
                    1,
                ),
            ),
        )

        # Destination views: six-mode layouts whose memory order is given by
        # `order`. K has a single row block because each page holds
        # _PAGE_SIZE rows.
        q_mma_layout = cute.make_ordered_layout(
            (32, 4, rest_q_m, 4, rest_g, heads_q),
            order=(2, 1, 4, 0, 3, 5),
        )
        k_mma_layout = cute.make_ordered_layout(
            (32, 4, 1, 4, rest_g, k_l),
            order=(2, 1, 4, 0, 3, 5),
        )
        q_scale_mma = cute.make_tensor(q_scale_mma_ptr, q_mma_layout)
        k_scale_mma = cute.make_tensor(k_scale_mma_ptr, k_mma_layout)
        # Collapse the six modes into three — (row, group, L) — so the kernel
        # can index the destination with logical coordinates.
        q_scale_mma = cute.group_modes(q_scale_mma, 0, 3)
        q_scale_mma = cute.group_modes(q_scale_mma, 1, 3)
        k_scale_mma = cute.group_modes(k_scale_mma, 0, 3)
        k_scale_mma = cute.group_modes(k_scale_mma, 1, 3)

        # Q and K scales share one flat index space: Q entries first, then K.
        q_scale_count = total_q * heads_q * Int32(self.scale_groups)
        k_scale_count = page_count * heads_k * Int32(_PAGE_SIZE * self.scale_groups)
        total_scale_count = q_scale_count + k_scale_count
        grid_ctas = cute.ceil_div(total_scale_count, self.threads_per_cta)
        self.kernel(
            q_scale,
            k_scale,
            q_scale_mma,
            k_scale_mma,
            heads_q,
            heads_k,
            q_scale_count,
            total_scale_count,
        ).launch(
            grid=(grid_ctas, 1, 1),
            block=[self.threads_per_cta, 1, 1],
            cluster=(1, 1, 1),
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        q_scale: cute.Tensor,
        k_scale: cute.Tensor,
        q_scale_mma: cute.Tensor,
        k_scale_mma: cute.Tensor,
        heads_q: Int32,
        heads_k: Int32,
        q_scale_count: Int32,
        total_scale_count: Int32,
    ):
        """Copy each scale value from its source position to its MMA position.

        Threads walk the flat index range ``[0, total_scale_count)`` in a
        grid-stride loop. Indices below ``q_scale_count`` address Q scales;
        the rest address K scales.
        """
        tidx, _, _ = cute.arch.thread_idx()
        block_idx, _, _ = cute.arch.block_idx()
        grid_dim, _, _ = cute.arch.grid_dim()
        linear = block_idx * Int32(self.threads_per_cta) + tidx
        stride = grid_dim * Int32(self.threads_per_cta)

        while linear < total_scale_count:
            if linear < q_scale_count:
                # Decompose the flat index as (row, head, group), group fastest.
                group = linear % Int32(self.scale_groups)
                tmp = linear // Int32(self.scale_groups)
                head = tmp % heads_q
                row = tmp // heads_q
                q_scale_mma[row, group, head] = q_scale[row, head, group]
            else:
                # Decompose the K-relative index as (page, head, row, group).
                k_linear = linear - q_scale_count
                group = k_linear % Int32(self.scale_groups)
                tmp = k_linear // Int32(self.scale_groups)
                row = tmp % Int32(_PAGE_SIZE)
                tmp = tmp // Int32(_PAGE_SIZE)
                head = tmp % heads_k
                page = tmp // heads_k
                # Destination "L" coordinate flattens (page, head).
                scale_l = page * heads_k + head
                k_scale_mma[row, group, scale_l] = k_scale[page, head, row, group]
            linear += stride
199
+
200
+
201
+ class Fp4IndexerStagedMmaSm100:
202
+ """Single-kernel FP4 indexer for preordered MMA scale storage."""
203
+
204
def __init__(
    self,
    *,
    fmt: str,
    causal: bool,
    preordered_q_scale_tma: bool = False,
    compact_schedule: bool = False,
    use_tmem_load_red: bool = False,
):
    """Record the FP4 format parameters, feature switches and fixed kernel geometry.

    Args:
        fmt: FP4 format name; resolved through ``normalize_fp4_format``.
        causal: Whether the kernel applies causal visibility checks.
        preordered_q_scale_tma: Build a dedicated TMA atom for the Q scales.
        compact_schedule: Launch ``compact_task_count`` tasks instead of the
            dense ``q_tiles * k_groups`` grid.
        use_tmem_load_red: Use the reducing TMEM load op in the epilogue.
    """
    fmt_spec = normalize_fp4_format(fmt)

    # Format-derived settings.
    self.fmt = fmt_spec.name

    # Feature switches, normalised to plain bools.
    self.is_causal = bool(causal)
    self.preordered_q_scale_tma = bool(preordered_q_scale_tma)
    self.compact_schedule = bool(compact_schedule)
    self.use_tmem_load_red = bool(use_tmem_load_red)

    self.sf_vec_size = fmt_spec.sf_vec_size
    self.sf_dtype = fmt_spec.cutlass_scale_dtype
    self.scale_groups = fmt_spec.scale_groups
    self.use_nvfp4 = fmt_spec.name == "nvfp4"

    # Warp layout: the epilogue warpgroups come first, then the MMA warp,
    # then the load warp.
    self.epi_threads_per_cta = 128
    self.epi_warps_per_group = 4
    self.num_epi_warpgroups = 2
    first_non_epilogue_warp = self.epi_warps_per_group * self.num_epi_warpgroups
    self.mma_warp_id = first_non_epilogue_warp
    self.load_warp_id = first_non_epilogue_warp + 1
    self.threads_per_cta = 384

    # TMEM columns requested and pipeline depths.
    self.num_tmem_alloc_cols = 512
    self.num_q_stage = 1
    self.num_acc_stage = 3
    self.num_ab_stage = 3

    self.k_tiles_per_cta = k_tiles_per_cta_for(self.is_causal)
234
+
235
@cute.jit
def __call__(
    self,
    q_ptr: cute.Pointer,
    k_ptr: cute.Pointer,
    q_scale_ptr: cute.Pointer,
    k_scale_ptr: cute.Pointer,
    scores_ptr: cute.Pointer,
    kv_indices_ptr: cute.Pointer,
    cu_seqlens_q_ptr: cute.Pointer,
    cu_seqlens_k_ptr: cute.Pointer,
    cu_page_offsets_ptr: cute.Pointer,
    qo_offset_ptr: cute.Pointer,
    problem_size: tuple,
    stream: cuda.CUstream,
):
    """Wrap the raw pointers in tensors, build the MMA/TMA objects and launch ``self.kernel``.

    ``problem_size`` is a 12-tuple read positionally as ``(m, _, k, _, lk,
    heads_q, heads_k, batch, max_k_tiles, total_q, has_qo_offset,
    compact_task_count)``. Slots 1 and 3 are ignored and ``k`` is not used
    in this method.
    """
    (
        m, _, _k, _, lk,
        heads_q, heads_k, batch,
        max_k_tiles, total_q,
        has_qo_offset, compact_task_count,
    ) = problem_size
    num_pages = lk // heads_k
    mma_inst_shape = (*_MMA_TILER_MN, _MMA_INST_SHAPE_K)
    # The CTA tile covers two MMA instruction steps along K.
    self.mma_tiler = (_MMA_TILER_MN[0], _MMA_TILER_MN[1], _MMA_INST_SHAPE_K * 2)
    self.cta_tile_shape_mnk = self.mma_tiler

    # ---- Global-memory views -------------------------------------------
    q_typed_ptr = cute.recast_ptr(q_ptr, dtype=_AB_DTYPE)
    q_layout = cute.make_layout(
        (total_q, _HEAD_DIM, heads_q),
        stride=(heads_q * _HEAD_DIM, 1, _HEAD_DIM),
    )
    gmem_q = cute.make_tensor(q_typed_ptr, q_layout)

    k_typed_ptr = cute.recast_ptr(k_ptr, dtype=_AB_DTYPE)
    k_layout = cute.make_layout(
        (_PAGE_SIZE, _HEAD_DIM, heads_k, num_pages),
        stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM, heads_k * _PAGE_SIZE * _HEAD_DIM),
    )
    gmem_k = cute.make_tensor(k_typed_ptr, k_layout)

    q_sf_layout = blockscaled_utils.tile_atom_to_shape_SF(
        (total_q, _HEAD_DIM, heads_q),
        self.sf_vec_size,
    )
    gmem_q_scale = cute.make_tensor(q_scale_ptr, q_sf_layout)

    # K scales fold (page, head) into a single trailing mode.
    k_sf_layout = blockscaled_utils.tile_atom_to_shape_SF(
        (_PAGE_SIZE, _HEAD_DIM, num_pages * heads_k),
        self.sf_vec_size,
    )
    gmem_k_scale = cute.make_tensor(k_scale_ptr, k_sf_layout)

    scores_layout = cute.make_layout(
        (heads_q, max_k_tiles, total_q),
        stride=(max_k_tiles * total_q, total_q, 1),
    )
    gmem_scores = cute.make_tensor(scores_ptr, scores_layout)

    gmem_kv_indices = cute.make_tensor(
        kv_indices_ptr,
        cute.make_layout((num_pages,), stride=(1,)),
    )

    # The three cumulative arrays share one (batch + 1)-element layout.
    cumulative_layout = cute.make_layout((batch + 1,), stride=(1,))
    gmem_cu_q = cute.make_tensor(cu_seqlens_q_ptr, cumulative_layout)
    gmem_cu_k = cute.make_tensor(cu_seqlens_k_ptr, cumulative_layout)
    gmem_cu_pages = cute.make_tensor(cu_page_offsets_ptr, cumulative_layout)
    gmem_qo_offset = cute.make_tensor(
        qo_offset_ptr,
        cute.make_layout((batch,), stride=(1,)),
    )

    # ---- MMA op and shared-memory layouts --------------------------------
    if const_expr(self.use_nvfp4):
        block_scaled_op = tcgen05.MmaMXF4NVF4Op(
            self.sf_dtype,
            mma_inst_shape,
            tcgen05.CtaGroup.ONE,
            tcgen05.OperandSource.SMEM,
        )
    else:
        block_scaled_op = tcgen05.MmaMXF4Op(
            mma_inst_shape,
            tcgen05.CtaGroup.ONE,
            tcgen05.OperandSource.SMEM,
        )
    tiled_mma = cute.make_tiled_mma(block_scaled_op)

    smem_q_layout = sm100_utils.make_smem_layout_a(
        tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_q_stage
    )
    smem_k_layout = sm100_utils.make_smem_layout_b(
        tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_ab_stage
    )
    smem_q_scale_layout = blockscaled_utils.make_smem_layout_sfa(
        tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_q_stage
    )
    smem_k_scale_layout = blockscaled_utils.make_smem_layout_sfb(
        tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage
    )

    # ---- TMA atoms -------------------------------------------------------
    cluster_layout_vmnk = cute.make_layout((1, 1, 1, 1))
    g2s_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)
    smem_q_one_stage = cute.slice_(smem_q_layout, (None, None, None, 0))
    smem_k_one_stage = cute.slice_(smem_k_layout, (None, None, None, 0))
    tma_q = cute.nvgpu.make_tiled_tma_atom_A(
        g2s_op,
        gmem_q,
        smem_q_one_stage,
        self.mma_tiler,
        tiled_mma,
        cluster_layout_vmnk.shape,
    )
    tma_k = cute.nvgpu.make_tiled_tma_atom_B(
        g2s_op,
        gmem_k,
        smem_k_one_stage,
        self.mma_tiler,
        tiled_mma,
        cluster_layout_vmnk.shape,
    )
    if const_expr(self.preordered_q_scale_tma):
        tma_qs = cute.nvgpu.make_tiled_tma_atom_A(
            g2s_op,
            gmem_q_scale,
            smem_q_scale_layout,
            self.mma_tiler,
            tiled_mma,
            cluster_layout_vmnk.shape,
            internal_type=cutlass.Int16,
        )
    else:
        # No dedicated Q-scale atom in this mode; the Q atom fills the slot.
        tma_qs = tma_q
    tma_ks = cute.nvgpu.make_tiled_tma_atom_B(
        g2s_op,
        gmem_k_scale,
        smem_k_scale_layout,
        self.mma_tiler,
        tiled_mma,
        cluster_layout_vmnk.shape,
        internal_type=cutlass.Int16,
    )

    # ---- Grid shape and launch -------------------------------------------
    grid_q_tiles = cute.ceil_div(m, self.cta_tile_shape_mnk[0])
    grid_k_groups = cute.ceil_div(max_k_tiles, self.k_tiles_per_cta)
    if const_expr(self.compact_schedule):
        grid_x = compact_task_count
    else:
        grid_x = grid_q_tiles * grid_k_groups

    self.kernel(
        tiled_mma,
        tma_q,
        tma_qs,
        tma_k,
        tma_ks,
        gmem_q_scale,
        gmem_k_scale,
        gmem_scores,
        gmem_kv_indices,
        gmem_cu_q,
        gmem_cu_k,
        gmem_cu_pages,
        gmem_qo_offset,
        smem_q_layout,
        smem_k_layout,
        smem_q_scale_layout,
        smem_k_scale_layout,
        heads_q,
        heads_k,
        has_qo_offset,
        max_k_tiles,
        grid_k_groups,
    ).launch(
        grid=(grid_x, batch * heads_q, 1),
        block=[self.threads_per_cta, 1, 1],
        cluster=(1, 1, 1),
        stream=stream,
    )
415
+
416
@cute.jit
def _group_has_visible(
    self,
    q_tile_start: Int32,
    q_tile_last: Int32,
    q_len: Int32,
    group_first_ktile: Int32,
    batch_k_tiles: Int32,
    causal_offset: Int32,
):
    """Return whether a K-tile group can contain any visible entry for a Q tile.

    The Q tile must start inside ``q_len`` and the group's first K tile must
    be below ``batch_k_tiles``. In causal mode the first key of that tile
    must additionally not lie past ``q_tile_last + causal_offset``.
    """
    in_bounds = q_tile_start < q_len and group_first_ktile < batch_k_tiles
    if const_expr(self.is_causal):
        first_key = group_first_ktile * Int32(_BLOCK_K)
        in_bounds = in_bounds and first_key <= q_tile_last + causal_offset
    return in_bounds
430
+
431
@cute.jit
def _tile_has_visible(
    self,
    q_tile_start: Int32,
    q_tile_last: Int32,
    q_len: Int32,
    ktile: Int32,
    batch_k_tiles: Int32,
    causal_offset: Int32,
):
    """Return whether K tile ``ktile`` can contain any visible entry for a Q tile.

    The Q tile must start inside ``q_len`` and ``ktile`` must be below
    ``batch_k_tiles``. In causal mode the tile's first key must additionally
    not lie past ``q_tile_last + causal_offset``.
    """
    in_bounds = q_tile_start < q_len and ktile < batch_k_tiles
    if const_expr(self.is_causal):
        first_key = ktile * Int32(_BLOCK_K)
        in_bounds = in_bounds and first_key <= q_tile_last + causal_offset
    return in_bounds
445
+
446
@cute.jit
def _tile_mask_free(self, q_tile_start: Int32, ktile: Int32, causal_offset: Int32):
    """Return whether K tile ``ktile`` needs no causal masking for this Q tile.

    Non-causal instances always return ``True``. In causal mode the tile is
    mask-free when its last key does not exceed ``q_tile_start + causal_offset``.
    """
    if const_expr(not self.is_causal):
        return True
    last_key = ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1)
    return last_key <= q_tile_start + causal_offset
451
+
452
@cute.jit
def _full_tile_coord_visible(
    self,
    coord_m: Int32,
    target_m: Int32,
    q_local: Int32,
    k_local: Int32,
    causal_offset: Int32,
):
    """Return whether an accumulator element counts toward row ``target_m``.

    Performs no length checks. The element must sit on row ``target_m`` and,
    in causal mode, satisfy ``k_local <= q_local + causal_offset``.
    """
    on_target_row = coord_m == target_m
    if const_expr(self.is_causal):
        on_target_row = on_target_row and k_local <= q_local + causal_offset
    return on_target_row
465
+
466
@cute.jit
def _partial_tile_coord_visible(
    self,
    coord_m: Int32,
    target_m: Int32,
    q_local: Int32,
    k_local: Int32,
    q_len: Int32,
    k_len: Int32,
    causal_offset: Int32,
):
    """Return whether an accumulator element counts toward row ``target_m``.

    The element must sit on row ``target_m`` with ``q_local < q_len`` and
    ``k_local < k_len``. In causal mode it must also satisfy
    ``k_local <= q_local + causal_offset``.
    """
    counted = coord_m == target_m and q_local < q_len and k_local < k_len
    if const_expr(self.is_causal):
        counted = counted and k_local <= q_local + causal_offset
    return counted
481
+
482
+ @cute.kernel
483
+ def kernel(
484
+ self,
485
+ tiled_mma: cute.TiledMma,
486
+ tma_q: cpasync.TmaInfo,
487
+ tma_qs: cpasync.TmaInfo,
488
+ tma_k: cpasync.TmaInfo,
489
+ tma_ks: cpasync.TmaInfo,
490
+ mQS: cute.Tensor,
491
+ mKS: cute.Tensor,
492
+ mScores: cute.Tensor,
493
+ mKvIndices: cute.Tensor,
494
+ mCuQ: cute.Tensor,
495
+ mCuK: cute.Tensor,
496
+ mCuPages: cute.Tensor,
497
+ mQoOffset: cute.Tensor,
498
+ q_smem_layout: cute.ComposedLayout,
499
+ k_smem_layout: cute.ComposedLayout,
500
+ q_scale_smem_layout: cute.Layout,
501
+ k_scale_smem_layout: cute.Layout,
502
+ heads_q: Int32,
503
+ heads_k: Int32,
504
+ has_qo_offset: Int32,
505
+ max_k_tiles: Int32,
506
+ k_group_count: Int32,
507
+ ):
# Device entry point. The x block index is the task index; the y block index
# packs (batch, query head) as batch_idx * heads_q + hq.
# Work is split by warp index: warps below self.mma_warp_id run the epilogue
# (per-row max and score store), warp self.mma_warp_id issues the MMAs, and
# warp self.load_warp_id issues the Q / K loads.
508
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
509
+ tidx, _, _ = cute.arch.thread_idx()
510
+ lane_idx = cute.arch.lane_idx()
511
+ epi_tidx = tidx % Int32(self.epi_threads_per_cta)
512
+ epi_warpgroup_idx = warp_idx // Int32(self.epi_warps_per_group)
513
+ task_idx, q_l, _ = cute.arch.block_idx()
514
+ batch_idx = q_l // heads_q
515
+ hq = q_l - batch_idx * heads_q
# Query heads map onto KV heads in groups of heads_q // heads_k
# (assumes heads_q is a multiple of heads_k -- TODO confirm).
516
+ hk = hq // (heads_q // heads_k)
517
+ q_begin = mCuQ[batch_idx]
518
+ q_end = mCuQ[batch_idx + 1]
519
+ k_begin = mCuK[batch_idx]
520
+ k_end = mCuK[batch_idx + 1]
521
+ q_len = q_end - q_begin
522
+ k_len = k_end - k_begin
523
+ page_begin = mCuPages[batch_idx]
# K tile count for this batch entry: ceil(k_len / _PAGE_SIZE).
524
+ batch_k_tiles = (k_len + Int32(_PAGE_SIZE - 1)) // Int32(_PAGE_SIZE)
# Causal offset defaults to k_len - q_len and is replaced by
# mQoOffset[batch_idx] when has_qo_offset is non-zero.
525
+ causal_offset = Int32(0)
526
+ if const_expr(self.is_causal):
527
+ causal_offset = k_len - q_len
528
+ if has_qo_offset != 0:
529
+ causal_offset = mQoOffset[batch_idx]
530
+ task_valid = True
531
+ q_tile_idx = Int32(0)
532
+ ktile_group = Int32(0)
# Compact schedule: walk the Q tiles, subtracting each tile's visible K-group
# count from task_idx until the remainder falls inside one tile.
# NOTE(review): the visible count is derived from causal_offset, which stays 0
# when is_causal is False -- confirm compact_schedule is only used with causal.
533
+ if const_expr(self.compact_schedule):
534
+ remaining = task_idx
535
+ q_tile_count = (q_len + Int32(self.cta_tile_shape_mnk[0] - 1)) // Int32(self.cta_tile_shape_mnk[0])
536
+ batch_k_group_count = (batch_k_tiles + Int32(self.k_tiles_per_cta - 1)) // Int32(self.k_tiles_per_cta)
537
+ q_scan = Int32(0)
538
+ task_valid = False
539
+ while q_scan < q_tile_count and not task_valid:
540
+ q_scan_start = q_scan * Int32(self.cta_tile_shape_mnk[0])
541
+ q_scan_last = q_scan_start + Int32(self.cta_tile_shape_mnk[0] - 1)
542
+ if q_scan_last >= q_len:
543
+ q_scan_last = q_len - Int32(1)
544
+ visible_limit = q_scan_last + causal_offset
545
+ visible_group_count = Int32(0)
546
+ if visible_limit >= Int32(0):
547
+ visible_group_count = visible_limit // Int32(self.k_tiles_per_cta * _BLOCK_K) + Int32(1)
548
+ if visible_group_count > batch_k_group_count:
549
+ visible_group_count = batch_k_group_count
550
+ task_valid = remaining < visible_group_count
551
+ if not task_valid:
552
+ remaining -= visible_group_count
553
+ q_scan += Int32(1)
554
+ if task_valid:
555
+ q_tile_idx = q_scan
556
+ ktile_group = remaining
# No task maps to this CTA: zero both lengths so that group_has_visible
# (computed below from q_tile_start < q_len) evaluates to False.
557
+ else:
558
+ q_len = Int32(0)
559
+ k_len = Int32(0)
# Dense schedule: task_idx = q_tile_idx * k_group_count + ktile_group.
560
+ else:
561
+ q_tile_idx = task_idx // k_group_count
562
+ ktile_group = task_idx - q_tile_idx * k_group_count
563
+ q_tile_start = q_tile_idx * Int32(self.cta_tile_shape_mnk[0])
564
+ q_tile_last = q_tile_start + Int32(self.cta_tile_shape_mnk[0] - 1)
565
+ if q_tile_last >= q_len:
566
+ q_tile_last = q_len - Int32(1)
567
+ q_tile_full = q_tile_start + Int32(self.cta_tile_shape_mnk[0] - 1) < q_len
568
+ q_tile_global_start = q_begin + q_tile_start
# True when the tile's global Q row is a multiple of 128; this gates the TMA
# path for Q scales below.
569
+ q_scale_tma_safe = q_tile_global_start == (q_tile_global_start // Int32(128)) * Int32(128)
570
+ group_first_ktile = ktile_group * Int32(self.k_tiles_per_cta)
571
+ group_has_visible = self._group_has_visible(
572
+ q_tile_start,
573
+ q_tile_last,
574
+ q_len,
575
+ group_first_ktile,
576
+ batch_k_tiles,
577
+ causal_offset,
578
+ )
579
+
# Barrier storage for the pipelines created below, plus the word handed to
# utils.TmemAllocator.
580
+ @cute.struct
581
+ class SharedStorage:
582
+ acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
583
+ q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
584
+ qs_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
585
+ k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
586
+ tmem_holding_buf: cutlass.Int32
587
+
588
+ smem = utils.SmemAllocator()
589
+ storage = smem.allocate(SharedStorage)
590
+ sQ_public = smem.allocate_tensor(_AB_DTYPE, q_smem_layout.outer, 128, swizzle=q_smem_layout.inner)
591
+ sK_public = smem.allocate_tensor(_AB_DTYPE, k_smem_layout.outer, 128, swizzle=k_smem_layout.inner)
592
+ sQS_public = smem.allocate_tensor(self.sf_dtype, q_scale_smem_layout, 128)
593
+ sKS_public = smem.allocate_tensor(self.sf_dtype, k_scale_smem_layout, 128)
594
+ mQ_tma = tma_q.tma_tensor
595
+ mQS_tma = tma_qs.tma_tensor
596
+ mK_tma = tma_k.tma_tensor
597
+ mKS_tma = tma_ks.tma_tensor
598
+ thr_mma = tiled_mma.get_slice(0)
599
+ tCsQ = thr_mma.partition_A(sQ_public)
600
+ tCsK = thr_mma.partition_B(sK_public)
# Q views are shifted by q_begin so tile coordinates are relative to this
# batch entry's first row.
601
+ mQ_tma_cur = cute.domain_offset((q_begin, 0, 0), mQ_tma)
602
+ gQ_tma = cute.local_tile(
603
+ mQ_tma_cur,
604
+ cute.slice_(self.mma_tiler, (None, 0, None)),
605
+ (None, None, None),
606
+ )
607
+ tCgQ_tma = thr_mma.partition_A(gQ_tma)
608
+ tQsQ_tma, tQgQ_tma = cpasync.tma_partition(
609
+ tma_q.atom,
610
+ 0,
611
+ cute.make_layout(1),
612
+ cute.group_modes(sQ_public, 0, 3),
613
+ cute.group_modes(tCgQ_tma, 0, 3),
614
+ )
615
+ if const_expr(self.preordered_q_scale_tma):
616
+ mQS_tma_cur = cute.domain_offset((q_begin, 0, 0), mQS_tma)
617
+ gQS_tma = cute.local_tile(
618
+ mQS_tma_cur,
619
+ cute.slice_(self.mma_tiler, (None, 0, None)),
620
+ (None, None, None),
621
+ )
622
+ tCgQS_tma = thr_mma.partition_A(gQS_tma)
623
+ tQsQS_tma, tQgQS_tma = cpasync.tma_partition(
624
+ tma_qs.atom,
625
+ 0,
626
+ cute.make_layout(1),
627
+ cute.group_modes(sQS_public, 0, 3),
628
+ cute.group_modes(tCgQS_tma, 0, 3),
629
+ )
630
+ tQsQS_tma = cute.filter_zeros(tQsQS_tma)
631
+ tQgQS_tma = cute.filter_zeros(tQgQS_tma)
632
+ gK_tma = cute.local_tile(
633
+ mK_tma,
634
+ cute.slice_(self.mma_tiler, (0, None, None)),
635
+ (None, None, None, None),
636
+ )
637
+ tCgK_tma = thr_mma.partition_B(gK_tma)
638
+ tKsK_tma, tKgK_tma = cpasync.tma_partition(
639
+ tma_k.atom,
640
+ 0,
641
+ cute.make_layout(1),
642
+ cute.group_modes(sK_public, 0, 3),
643
+ cute.group_modes(tCgK_tma, 0, 3),
644
+ )
645
+ gKS_tma = cute.local_tile(
646
+ mKS_tma,
647
+ cute.slice_(self.mma_tiler, (0, None, None)),
648
+ (None, None, None),
649
+ )
650
+ tCgKS_tma = thr_mma.partition_B(gKS_tma)
651
+ tKsKS_tma, tKgKS_tma = cpasync.tma_partition(
652
+ tma_ks.atom,
653
+ 0,
654
+ cute.make_layout(1),
655
+ cute.group_modes(sKS_public, 0, 3),
656
+ cute.group_modes(tCgKS_tma, 0, 3),
657
+ )
658
+ tKsKS_tma = cute.filter_zeros(tKsKS_tma)
659
+ tKgKS_tma = cute.filter_zeros(tKgKS_tma)
660
+ sQS = sQS_public
661
+ sKS = sKS_public
662
+
663
+ tCrQ = tiled_mma.make_fragment_A(sQ_public)
664
+ tCrK = tiled_mma.make_fragment_B(sK_public)
665
+ tCcC = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler[:2]))
666
+ acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
667
+ tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
668
+
# The retrieve barrier spans warps 0..mma_warp_id, i.e. the epilogue warps
# plus the MMA warp.
669
+ tmem = utils.TmemAllocator(
670
+ storage.tmem_holding_buf.ptr,
671
+ barrier_for_retrieve=pipeline.NamedBarrier(
672
+ barrier_id=1,
673
+ num_threads=32 * (self.mma_warp_id + 1),
674
+ ),
675
+ )
676
+
677
+ acc_pipeline = common_pipeline.PipelineUmmaAsync.create(
678
+ barrier_storage=storage.acc_mbar_ptr.data_ptr(),
679
+ num_stages=self.num_acc_stage,
680
+ producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
681
+ consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, self.epi_threads_per_cta),
682
+ defer_sync=True,
683
+ )
684
+ acc_producer, _ = acc_pipeline.make_participants()
685
+ q_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_q.smem_layout)
686
+ k_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_k.smem_layout)
687
+ if const_expr(self.preordered_q_scale_tma):
688
+ qs_tma_copy_bytes = cute.size_in_bytes(
689
+ self.sf_dtype,
690
+ cute.select(tma_qs.smem_layout, mode=[0, 1, 2]),
691
+ )
692
+ ks_tma_copy_bytes = cute.size_in_bytes(
693
+ self.sf_dtype,
694
+ cute.select(tma_ks.smem_layout, mode=[0, 1, 2]),
695
+ )
# K data and K scales of one tile are issued against the same barrier below,
# so the K pipeline's tx_count is the sum of both byte counts.
696
+ k_pair_tma_copy_bytes = k_tma_copy_bytes + ks_tma_copy_bytes
697
+ q_producer, q_consumer = pipeline.PipelineTmaAsync.create(
698
+ barrier_storage=storage.q_mbar_ptr.data_ptr(),
699
+ num_stages=1,
700
+ producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
701
+ consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
702
+ tx_count=q_tma_copy_bytes,
703
+ defer_sync=True,
704
+ ).make_participants()
705
+ if const_expr(self.preordered_q_scale_tma):
706
+ qs_producer, qs_consumer = pipeline.PipelineTmaAsync.create(
707
+ barrier_storage=storage.qs_mbar_ptr.data_ptr(),
708
+ num_stages=1,
709
+ producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
710
+ consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
711
+ tx_count=qs_tma_copy_bytes,
712
+ defer_sync=True,
713
+ ).make_participants()
714
+ k_producer, k_consumer = pipeline.PipelineTmaAsync.create(
715
+ barrier_storage=storage.k_mbar_ptr.data_ptr(),
716
+ num_stages=self.num_ab_stage,
717
+ producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
718
+ consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
719
+ tx_count=k_pair_tma_copy_bytes,
720
+ defer_sync=True,
721
+ ).make_participants()
# All pipelines were created with defer_sync=True; the fence and barrier here
# follow their creation and precede any use.
722
+ cute.arch.mbarrier_init_fence()
723
+ cute.arch.barrier()
# Load warp, first part: Q scales and the Q tile.
724
+ if warp_idx == self.load_warp_id:
725
+ if group_has_visible:
726
+ q_empty = q_producer.acquire_and_advance()
727
+ if const_expr(self.preordered_q_scale_tma):
728
+ if q_scale_tma_safe:
729
+ qs_empty = qs_producer.acquire_and_advance()
730
+ cute.copy(
731
+ tma_qs.atom,
732
+ tQgQS_tma[(None, q_tile_idx, 0, hq)],
733
+ tQsQS_tma[(None, qs_empty.index)],
734
+ tma_bar_ptr=qs_empty.barrier,
735
+ )
736
+ qs_empty.commit()
# Fallback when the tile start is not 128-aligned: lanes copy scale entries
# from mQS into sQS element by element. Rows at or past q_len read row
# q_begin instead (presumably to keep the read in bounds -- TODO confirm).
737
+ else:
738
+ for row_base in cutlass.range(0, Int32(self.cta_tile_shape_mnk[0]), 32):
739
+ row = row_base + lane_idx
740
+ q_local = q_tile_start + row
741
+ row_major = row // Int32(32)
742
+ row_atom = row - row_major * Int32(32)
743
+ for group in cutlass.range_constexpr(self.scale_groups):
744
+ group_i = Int32(group)
745
+ mma_k = group_i // Int32(_MMA_INST_SHAPE_K // self.sf_vec_size)
746
+ group_in_mma_k = group_i - mma_k * Int32(_MMA_INST_SHAPE_K // self.sf_vec_size)
747
+ sf_coord = ((((row_atom, row_major), Int32(0)), (Int32(0), group_in_mma_k)), Int32(0), mma_k, Int32(0))
748
+ q_scale_row = q_begin + q_local
749
+ if q_local >= q_len:
750
+ q_scale_row = q_begin
751
+ sQS[sf_coord] = mQS[q_scale_row, group_i * Int32(self.sf_vec_size), hq]
# Without preordered_q_scale_tma the same element-wise copy is always used.
752
+ else:
753
+ for row_base in cutlass.range(0, Int32(self.cta_tile_shape_mnk[0]), 32):
754
+ row = row_base + lane_idx
755
+ q_local = q_tile_start + row
756
+ row_major = row // Int32(32)
757
+ row_atom = row - row_major * Int32(32)
758
+ for group in cutlass.range_constexpr(self.scale_groups):
759
+ group_i = Int32(group)
760
+ mma_k = group_i // Int32(_MMA_INST_SHAPE_K // self.sf_vec_size)
761
+ group_in_mma_k = group_i - mma_k * Int32(_MMA_INST_SHAPE_K // self.sf_vec_size)
762
+ sf_coord = ((((row_atom, row_major), Int32(0)), (Int32(0), group_in_mma_k)), Int32(0), mma_k, Int32(0))
763
+ q_scale_row = q_begin + q_local
764
+ if q_local >= q_len:
765
+ q_scale_row = q_begin
766
+ sQS[sf_coord] = mQS[q_scale_row, group_i * Int32(self.sf_vec_size), hq]
767
+ cute.copy(
768
+ tma_q.atom,
769
+ tQgQ_tma[(None, q_tile_idx, 0, hq)],
770
+ tQsQ_tma[(None, q_empty.index)],
771
+ tma_bar_ptr=q_empty.barrier,
772
+ )
773
+ q_empty.commit()
774
+
775
+ if warp_idx == self.mma_warp_id:
776
+ tmem_pool = tmem.reserve(self.num_tmem_alloc_cols)
777
+ tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32)
778
+ # Move block scales into TMEM and issue one FP4 GEMM per visible K tile.
779
+ tCtQS_layout = blockscaled_utils.make_tmem_layout_sfa(
780
+ tiled_mma,
781
+ self.mma_tiler,
782
+ self.sf_vec_size,
783
+ cute.slice_(q_scale_smem_layout, (None, None, None, 0)),
784
+ )
785
+ tCtKS_layout = blockscaled_utils.make_tmem_layout_sfb(
786
+ tiled_mma,
787
+ self.mma_tiler,
788
+ self.sf_vec_size,
789
+ cute.slice_(k_scale_smem_layout, (None, None, None, 0)),
790
+ )
791
+ tCtQS = tmem_pool.allocate_tensor(tCtQS_layout, self.sf_dtype)
792
+ tCtKS = tmem_pool.allocate_tensor(tCtKS_layout, self.sf_dtype)
793
+ copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), self.sf_dtype)
794
+ tCsQS_compact = cute.filter_zeros(sQS)
795
+ tCtQS_compact = cute.filter_zeros(tCtQS)
796
+ tiled_copy_s2t_qs = tcgen05.make_s2t_copy(copy_atom_s2t, tCtQS_compact)
797
+ thr_copy_s2t_qs = tiled_copy_s2t_qs.get_slice(0)
798
+ tCsQS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
799
+ tiled_copy_s2t_qs,
800
+ thr_copy_s2t_qs.partition_S(tCsQS_compact),
801
+ )
802
+ tCtQS_compact_s2t = thr_copy_s2t_qs.partition_D(tCtQS_compact)
803
+ tCsKS_compact = cute.filter_zeros(sKS)
804
+ tCtKS_compact = cute.filter_zeros(tCtKS)
805
+ tiled_copy_s2t_ks = tcgen05.make_s2t_copy(copy_atom_s2t, tCtKS_compact)
806
+ thr_copy_s2t_ks = tiled_copy_s2t_ks.get_slice(0)
807
+ tCsKS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
808
+ tiled_copy_s2t_ks,
809
+ thr_copy_s2t_ks.partition_S(tCsKS_compact),
810
+ )
811
+ tCtKS_compact_s2t = thr_copy_s2t_ks.partition_D(tCtKS_compact)
# Wait for the Q tile (and for the Q-scale TMA when that path was taken),
# then copy the Q scales from sQS into TMEM once for the whole K group.
812
+ if group_has_visible:
813
+ q_full = q_consumer.wait_and_advance()
814
+ if const_expr(self.preordered_q_scale_tma):
815
+ if q_scale_tma_safe:
816
+ qs_full = qs_consumer.wait_and_advance()
817
+ qs_full.release()
818
+ q_full.release()
819
+ cute.copy(tiled_copy_s2t_qs, tCsQS_compact_s2t[(None, None, None, None, 0)], tCtQS_compact_s2t)
820
+ tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
821
+ q_tile_crd = (None, None, None, 0)
# Causal: when every tile of the group is in range and its last tile passes
# the causal start check, skip the per-tile visibility test.
822
+ if const_expr(self.is_causal):
823
+ causal_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles
824
+ causal_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1)
825
+ causal_group_full = causal_group_full and causal_group_last_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset
826
+ ktile = Int32(0)
827
+ if causal_group_full:
828
+ for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
829
+ k_pair_full = k_consumer.wait_and_advance()
830
+ acc_empty = acc_producer.acquire_and_advance()
831
+ cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t)
832
+ k_tile_crd = (None, None, None, k_pair_full.index)
833
+ tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)]
834
+ cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
835
+ acc_empty.commit()
836
+ k_pair_full.release()
# Partial causal group: test each tile individually.
837
+ else:
838
+ for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
839
+ ktile = group_first_ktile + Int32(ktile_inner)
840
+ if ktile < max_k_tiles:
841
+ tile_has_visible = self._tile_has_visible(
842
+ q_tile_start,
843
+ q_tile_last,
844
+ q_len,
845
+ ktile,
846
+ batch_k_tiles,
847
+ causal_offset,
848
+ )
849
+ if tile_has_visible:
850
+ k_pair_full = k_consumer.wait_and_advance()
851
+ acc_empty = acc_producer.acquire_and_advance()
852
+ cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t)
853
+ k_tile_crd = (None, None, None, k_pair_full.index)
854
+ tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)]
855
+ cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
856
+ acc_empty.commit()
857
+ k_pair_full.release()
# Non-causal: only the batch's K tile count bounds the group.
858
+ else:
859
+ k_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles
860
+ ktile = Int32(0)
861
+ if k_group_full:
862
+ for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
863
+ k_pair_full = k_consumer.wait_and_advance()
864
+ acc_empty = acc_producer.acquire_and_advance()
865
+ cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t)
866
+ k_tile_crd = (None, None, None, k_pair_full.index)
867
+ tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)]
868
+ cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
869
+ acc_empty.commit()
870
+ k_pair_full.release()
871
+ else:
872
+ for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
873
+ ktile = group_first_ktile + Int32(ktile_inner)
874
+ if ktile < batch_k_tiles:
875
+ k_pair_full = k_consumer.wait_and_advance()
876
+ acc_empty = acc_producer.acquire_and_advance()
877
+ cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t)
878
+ k_tile_crd = (None, None, None, k_pair_full.index)
879
+ tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)]
880
+ cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
881
+ acc_empty.commit()
882
+ k_pair_full.release()
883
+ acc_producer.tail()
884
+
# Load warp, second part: K tiles and their scales. The full-group predicate
# is recomputed here with the same expressions the MMA warp uses above.
885
+ if warp_idx == self.load_warp_id:
886
+ if group_has_visible:
887
+ load_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles
888
+ if const_expr(self.is_causal):
889
+ load_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1)
890
+ load_group_full = load_group_full and load_group_last_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset
891
+ ktile = Int32(0)
892
+ if load_group_full:
893
+ for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
894
+ ktile = group_first_ktile + Int32(ktile_inner)
895
+ k_pair_empty = k_producer.acquire_and_advance()
# Logical K tile -> physical page through mKvIndices, offset by page_begin.
896
+ physical_page = mKvIndices[page_begin + ktile]
897
+ cute.copy(
898
+ tma_k.atom,
899
+ tKgK_tma[(None, 0, 0, hk, physical_page)],
900
+ tKsK_tma[(None, k_pair_empty.index)],
901
+ tma_bar_ptr=k_pair_empty.barrier,
902
+ )
# K scales index (page, head) as one folded coordinate.
903
+ scale_l = physical_page * heads_k + hk
904
+ cute.copy(
905
+ tma_ks.atom,
906
+ tKgKS_tma[(None, 0, 0, scale_l)],
907
+ tKsKS_tma[(None, k_pair_empty.index)],
908
+ tma_bar_ptr=k_pair_empty.barrier,
909
+ )
910
+ k_pair_empty.commit()
911
+ else:
912
+ for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
913
+ ktile = group_first_ktile + Int32(ktile_inner)
914
+ if ktile < max_k_tiles:
915
+ tile_has_visible = self._tile_has_visible(
916
+ q_tile_start,
917
+ q_tile_last,
918
+ q_len,
919
+ ktile,
920
+ batch_k_tiles,
921
+ causal_offset,
922
+ )
923
+ if tile_has_visible:
924
+ k_pair_empty = k_producer.acquire_and_advance()
925
+ physical_page = mKvIndices[page_begin + ktile]
926
+ cute.copy(
927
+ tma_k.atom,
928
+ tKgK_tma[(None, 0, 0, hk, physical_page)],
929
+ tKsK_tma[(None, k_pair_empty.index)],
930
+ tma_bar_ptr=k_pair_empty.barrier,
931
+ )
932
+ scale_l = physical_page * heads_k + hk
933
+ cute.copy(
934
+ tma_ks.atom,
935
+ tKgKS_tma[(None, 0, 0, scale_l)],
936
+ tKsKS_tma[(None, k_pair_empty.index)],
937
+ tma_bar_ptr=k_pair_empty.barrier,
938
+ )
939
+ k_pair_empty.commit()
940
+ k_producer.tail()
941
+ q_producer.tail()
942
+ if const_expr(self.preordered_q_scale_tma):
943
+ if q_scale_tma_safe:
944
+ qs_producer.tail()
945
+
946
+ if warp_idx < self.mma_warp_id:
947
+ tmem_pool = tmem.reserve(self.num_tmem_alloc_cols)
948
+ tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32)
949
+ # Load accumulators from TMEM, reduce per-row max, and store scores.
950
+ if const_expr(self.use_tmem_load_red):
951
+ copy_atom_t2r = cute.make_copy_atom(
952
+ tcgen05.LdRed32x32bOp(
953
+ tcgen05.Repetition.x128,
954
+ tcgen05.Pack.NONE,
955
+ tcgen05.TmemLoadRedOp.MAX,
956
+ ),
957
+ Float32,
958
+ )
959
+ else:
960
+ copy_atom_t2r = cute.make_copy_atom(
961
+ tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE),
962
+ Float32,
963
+ )
964
+ tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None, None, 0)])
965
+ thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx)
966
+ tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc)
967
+ tTR_cC = thr_copy_t2r.partition_D(tCcC)
968
+ tTR_rAcc = cute.make_rmem_tensor(tTR_cC.shape, Float32)
969
+ if const_expr(self.use_tmem_load_red):
970
+ tTR_rRed = cute.make_rmem_tensor((1,), Float32)
# Each epilogue thread stores the score for row epi_tidx of the tile, plus a
# second row when the tile has more rows than epi_threads_per_cta.
971
+ q_local_store0 = q_tile_start + epi_tidx
972
+ q_global_store0 = q_begin + q_local_store0
973
+ if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
974
+ q_local_store1 = q_tile_start + epi_tidx + Int32(self.epi_threads_per_cta)
975
+ q_global_store1 = q_begin + q_local_store1
976
+ if group_has_visible:
977
+ visible_tile_count = Int32(0)
978
+ for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
979
+ ktile = group_first_ktile + Int32(ktile_inner)
980
+ if ktile < max_k_tiles:
981
+ tile_has_visible = self._tile_has_visible(
982
+ q_tile_start,
983
+ q_tile_last,
984
+ q_len,
985
+ ktile,
986
+ batch_k_tiles,
987
+ causal_offset,
988
+ )
989
+ if tile_has_visible:
# Tiles alternate between the epilogue warpgroups by ktile_inner.
990
+ epilogue_owns_tile = epi_warpgroup_idx == Int32(
991
+ ktile_inner % self.num_epi_warpgroups
992
+ )
993
+ if epilogue_owns_tile:
# The accumulator stage and barrier phase are derived from the running count
# of visible tiles.
994
+ acc_stage_index = visible_tile_count % Int32(self.num_acc_stage)
995
+ acc_stage_phase = (visible_tile_count // Int32(self.num_acc_stage)) % Int32(2)
996
+ tile_mask_free = self._tile_mask_free(q_tile_start, ktile, causal_offset)
997
+ k_tile_full = ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) < k_len
998
+ tile_full = q_tile_full and k_tile_full
999
+ acc_pipeline.consumer_wait_w_index_phase(acc_stage_index, acc_stage_phase)
1000
+ tTR_tAcc_stage = tTR_tAcc[(None, None, None, None, acc_stage_index)]
1001
+ if const_expr(self.use_tmem_load_red):
1002
+ cute.copy(tiled_copy_t2r, tTR_tAcc_stage, [tTR_rAcc, tTR_rRed])
1003
+ else:
1004
+ cute.copy(tiled_copy_t2r, tTR_tAcc_stage, tTR_rAcc)
1005
+ row_max0 = -Float32.inf
1006
+ row_max1 = -Float32.inf
# Four reduction variants follow, chosen by tile_mask_free (no causal test
# per element) and tile_full (no q_len / k_len test per element).
1007
+ if tile_mask_free:
1008
+ if tile_full:
1009
+ if const_expr(not self.use_tmem_load_red or self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
1010
+ for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True):
1011
+ coord_m, _ = tTR_cC[i]
1012
+ if coord_m == epi_tidx:
1013
+ row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i])
1014
+ if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
1015
+ if coord_m == epi_tidx + Int32(self.epi_threads_per_cta):
1016
+ row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i])
# With the reducing load, the row max is taken from tTR_rRed directly.
1017
+ else:
1018
+ row_max0 = tTR_rRed[0]
1019
+ else:
1020
+ for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True):
1021
+ coord_m, coord_n = tTR_cC[i]
1022
+ q_local = q_tile_start + coord_m
1023
+ k_local = ktile * Int32(_BLOCK_K) + coord_n
1024
+ if coord_m == epi_tidx and q_local < q_len and k_local < k_len:
1025
+ row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i])
1026
+ if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
1027
+ if coord_m == epi_tidx + Int32(self.epi_threads_per_cta) and q_local < q_len and k_local < k_len:
1028
+ row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i])
1029
+ else:
1030
+ if tile_full:
1031
+ for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True):
1032
+ coord_m, coord_n = tTR_cC[i]
1033
+ q_local = q_tile_start + coord_m
1034
+ k_local = ktile * Int32(_BLOCK_K) + coord_n
1035
+ if self._full_tile_coord_visible(
1036
+ coord_m,
1037
+ epi_tidx,
1038
+ q_local,
1039
+ k_local,
1040
+ causal_offset,
1041
+ ):
1042
+ row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i])
1043
+ if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
1044
+ if self._full_tile_coord_visible(
1045
+ coord_m,
1046
+ epi_tidx + Int32(self.epi_threads_per_cta),
1047
+ q_local,
1048
+ k_local,
1049
+ causal_offset,
1050
+ ):
1051
+ row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i])
1052
+ else:
1053
+ for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True):
1054
+ coord_m, coord_n = tTR_cC[i]
1055
+ q_local = q_tile_start + coord_m
1056
+ k_local = ktile * Int32(_BLOCK_K) + coord_n
1057
+ if self._partial_tile_coord_visible(
1058
+ coord_m,
1059
+ epi_tidx,
1060
+ q_local,
1061
+ k_local,
1062
+ q_len,
1063
+ k_len,
1064
+ causal_offset,
1065
+ ):
1066
+ row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i])
1067
+ if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
1068
+ if self._partial_tile_coord_visible(
1069
+ coord_m,
1070
+ epi_tidx + Int32(self.epi_threads_per_cta),
1071
+ q_local,
1072
+ k_local,
1073
+ q_len,
1074
+ k_len,
1075
+ causal_offset,
1076
+ ):
1077
+ row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i])
# Store the row max; for a partial Q tile only rows below q_len are written.
1078
+ if q_tile_full:
1079
+ mScores[hq, ktile, q_global_store0] = row_max0
1080
+ elif q_local_store0 < q_len:
1081
+ mScores[hq, ktile, q_global_store0] = row_max0
1082
+ if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
1083
+ if q_tile_full:
1084
+ mScores[hq, ktile, q_global_store1] = row_max1
1085
+ elif q_local_store1 < q_len:
1086
+ mScores[hq, ktile, q_global_store1] = row_max1
1087
+ cute.arch.fence_view_async_tmem_load()
1088
+ acc_pipeline.consumer_release_w_index(acc_stage_index)
1089
+ visible_tile_count += Int32(1)
# Tile not visible: warpgroup 0 writes -inf for its row. This fill is compiled
# out under compact_schedule.
1090
+ else:
1091
+ if const_expr(not self.compact_schedule):
1092
+ if epi_warpgroup_idx == Int32(0):
1093
+ if q_tile_full:
1094
+ mScores[hq, ktile, q_global_store0] = -Float32.inf
1095
+ elif q_local_store0 < q_len:
1096
+ mScores[hq, ktile, q_global_store0] = -Float32.inf
# Whole group not visible: warpgroup 0 writes -inf for every K tile of the
# group below max_k_tiles (again only when compact_schedule is off).
1097
+ else:
1098
+ if const_expr(not self.compact_schedule):
1099
+ if epi_warpgroup_idx == Int32(0):
1100
+ for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
1101
+ ktile = ktile_group * Int32(self.k_tiles_per_cta) + Int32(ktile_inner)
1102
+ if ktile < max_k_tiles:
1103
+ if q_tile_full:
1104
+ mScores[hq, ktile, q_global_store0] = -Float32.inf
1105
+ elif q_local_store0 < q_len:
1106
+ mScores[hq, ktile, q_global_store0] = -Float32.inf
1107
+ cute.arch.barrier()
1108
+ tmem.free(tmem_pool.base_ptr)
1109
+
1110
class Fp4IndexerDecodeQPackSm100:
    """Pack decode Q rows as ``[B * Hk, 128, 64]`` and pack Q scales to MMA storage.

    One CTA handles one ``(batch, kv-head)`` pair.  Packed row ``r`` of that
    pair corresponds to query head ``r // _DECODE_PACK_Q_LEN`` inside the KV
    group and query position ``r % _DECODE_PACK_Q_LEN`` inside the batch
    entry.  Rows that fall outside the real query length or outside the head
    group are written as zero bytes.
    """

    def __init__(self, *, fmt: str):
        """Resolve the FP4 format and record the launch configuration.

        Args:
            fmt: FP4 format name, resolved through ``normalize_fp4_format``.
        """
        fp4_spec = normalize_fp4_format(fmt)
        self.fmt = fp4_spec.name
        self.sf_dtype = fp4_spec.cutlass_scale_dtype
        self.scale_groups = fp4_spec.scale_groups
        self.threads_per_cta = 256

    @cute.jit
    def __call__(
        self,
        q_ptr: cute.Pointer,
        q_scale_ptr: cute.Pointer,
        q_pack_ptr: cute.Pointer,
        q_scale_pack_ptr: cute.Pointer,
        cu_seqlens_q_ptr: cute.Pointer,
        problem_size: tuple,
        stream: cuda.CUstream,
    ):
        """Wrap the raw pointers in tensors and launch the packing kernel.

        Args:
            q_ptr: Source Q bytes, viewed as ``(total_q, heads_q, _FP4_PACKED_D_BYTES)``.
            q_scale_ptr: Source Q scales, viewed as
                ``(heads_q, ceil(total_q / 128), ceil(scale_groups / 4), 32, 4, 4)``.
            q_pack_ptr: Destination Q bytes, viewed as
                ``(batch * heads_k, _PAGE_SIZE, _FP4_PACKED_D_BYTES)``.
            q_scale_pack_ptr: Destination Q scales, one 128-row scale tile per
                ``(batch, kv-head)`` pair.
            cu_seqlens_q_ptr: ``batch + 1`` cumulative query offsets.
            problem_size: ``(total_q, heads_q, heads_k, batch)``.
            stream: CUDA stream the kernel is launched on.
        """
        total_q, heads_q, heads_k, batch = problem_size
        q_pack_l = batch * heads_k
        rest_q_m = cute.ceil_div(total_q, 128)
        rest_g = ceil_div(self.scale_groups, 4)

        # Source Q: row-major over (token, head, byte).
        src_q_layout = cute.make_layout(
            (total_q, heads_q, _FP4_PACKED_D_BYTES),
            stride=(heads_q * _FP4_PACKED_D_BYTES, _FP4_PACKED_D_BYTES, 1),
        )
        # Source scales: 512-element tiles indexed by (head, 128-row block, 4-group block).
        src_scale_layout = cute.make_layout(
            (heads_q, rest_q_m, rest_g, 32, 4, 4),
            stride=(512 * rest_q_m * rest_g, 512 * rest_g, 512, 16, 4, 1),
        )
        # Destination Q: one _PAGE_SIZE-row tile per (batch, kv-head).
        dst_q_layout = cute.make_layout(
            (q_pack_l, _PAGE_SIZE, _FP4_PACKED_D_BYTES),
            stride=(_PAGE_SIZE * _FP4_PACKED_D_BYTES, _FP4_PACKED_D_BYTES, 1),
        )
        # Destination scales: same tile shape as the source, a single row block per pair.
        dst_scale_layout = cute.make_layout(
            (q_pack_l, 1, rest_g, 32, 4, 4),
            stride=(512 * rest_g, 512 * rest_g, 512, 16, 4, 1),
        )

        q = cute.make_tensor(q_ptr, src_q_layout)
        q_scale = cute.make_tensor(q_scale_ptr, src_scale_layout)
        q_pack = cute.make_tensor(q_pack_ptr, dst_q_layout)
        q_scale_pack = cute.make_tensor(q_scale_pack_ptr, dst_scale_layout)
        cu_q = cute.make_tensor(cu_seqlens_q_ptr, cute.make_layout((batch + 1,), stride=(1,)))

        self.kernel(q, q_scale, q_pack, q_scale_pack, cu_q, heads_q, heads_k).launch(
            grid=(q_pack_l, 1, 1),
            block=[self.threads_per_cta, 1, 1],
            cluster=(1, 1, 1),
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        mQ: cute.Tensor,
        mQS: cute.Tensor,
        mQPack: cute.Tensor,
        mQSPack: cute.Tensor,
        mCuQ: cute.Tensor,
        heads_q: Int32,
        heads_k: Int32,
    ):
        """Copy one ``(batch, kv-head)`` pair into its packed tile.

        The CTA index selects the pair; the threads of the CTA stride over the
        tile's bytes first and over its scale entries second.
        """
        tidx, _, _ = cute.arch.thread_idx()
        q_pack_l, _, _ = cute.arch.block_idx()
        batch_idx = q_pack_l // heads_k
        hk = q_pack_l - batch_idx * heads_k
        q_begin = mCuQ[batch_idx]
        q_len = mCuQ[batch_idx + 1] - q_begin
        qhead_per_kv = heads_q // heads_k
        # First query head that shares KV head ``hk``.
        first_hq = hk * qhead_per_kv

        # Pass 1: packed Q bytes, one byte per thread per iteration.
        flat = tidx
        while flat < Int32(_PAGE_SIZE * _FP4_PACKED_D_BYTES):
            row = flat // Int32(_FP4_PACKED_D_BYTES)
            col = flat - row * Int32(_FP4_PACKED_D_BYTES)
            h_in_group = row // Int32(_DECODE_PACK_Q_LEN)
            q_local = row - h_in_group * Int32(_DECODE_PACK_Q_LEN)
            if q_local < q_len and h_in_group < qhead_per_kv:
                mQPack[q_pack_l, row, col] = mQ[q_begin + q_local, first_hq + h_in_group, col]
            else:
                # Padding row: no real (query, head) maps here.
                mQPack[q_pack_l, row, col] = cutlass.Uint8(0)
            flat += Int32(self.threads_per_cta)

        # Pass 2: scale entries, one (row, group) per thread per iteration.
        sf_flat = tidx
        while sf_flat < Int32(_PAGE_SIZE * self.scale_groups):
            row = sf_flat // Int32(self.scale_groups)
            group = sf_flat - row * Int32(self.scale_groups)
            h_in_group = row // Int32(_DECODE_PACK_Q_LEN)
            q_local = row - h_in_group * Int32(_DECODE_PACK_Q_LEN)
            src_hq = first_hq + h_in_group
            src_q = q_begin + q_local
            if q_local >= q_len or h_in_group >= qhead_per_kv:
                # Padding rows reuse the scale of the pair's first row/head.
                # NOTE(review): for an empty batch entry q_begin may equal
                # total_q, which could index past the last scale row block --
                # confirm callers never pack empty entries.
                src_q = q_begin
                src_hq = first_hq
            # Split the source row into (128-row block, 32-row atom, 4-way major).
            src_rest_m = src_q // Int32(128)
            src_row = src_q - src_rest_m * Int32(128)
            src_row_atom = src_row % Int32(32)
            src_row_major = src_row // Int32(32)
            # The destination tile has a single 128-row block, so only the
            # in-tile coordinates are needed.
            dst_row_atom = row % Int32(32)
            dst_row_major = row // Int32(32)
            rest_g = group // Int32(4)
            group_in_rest = group - rest_g * Int32(4)
            mQSPack[q_pack_l, Int32(0), rest_g, dst_row_atom, dst_row_major, group_in_rest] = mQS[
                src_hq, src_rest_m, rest_g, src_row_atom, src_row_major, group_in_rest
            ]
            sf_flat += Int32(self.threads_per_cta)
1227
+
1228
+
1229
class Fp4IndexerDecodePackedQSm100:
    """Decode score kernel with M packed as ``qhead_per_kv * q_len == 128``.

    For every ``(batch, kv-head)`` pair the kernel multiplies one packed Q
    tile against the paged K tiles of that batch entry with a block-scaled
    FP4 tcgen05 MMA and stores, per packed row and K tile, the maximum of the
    visible products into the scores tensor.

    Warp roles inside a CTA:
      * warps ``0 .. mma_warp_id - 1`` form two epilogue warpgroups that load
        accumulators from tensor memory and write the row maxima;
      * warp ``mma_warp_id`` copies scale factors to tensor memory and issues
        the MMAs;
      * warp ``load_warp_id`` issues the TMA loads for Q, K and their scales.
    """

    def __init__(self, *, fmt: str, causal: bool, compact_schedule: bool, use_tmem_load_red: bool = False):
        """Resolve the FP4 format and fix the compile-time kernel configuration.

        Args:
            fmt: FP4 format name, resolved through ``normalize_fp4_format``.
            causal: Apply the causal visibility rule ``k <= q + causal_offset``.
            compact_schedule: Launch only as many K-tile groups as the page
                table needs and locate the batch entry inside the kernel,
                instead of launching a dense ``(k_group, batch * heads_k)`` grid.
            use_tmem_load_red: Use the tensor-memory load that also returns a
                max reduction for tiles that need no masking.
        """
        fp4_spec = normalize_fp4_format(fmt)

        # Data format.
        self.fmt = fp4_spec.name
        self.sf_vec_size = fp4_spec.sf_vec_size
        self.sf_dtype = fp4_spec.cutlass_scale_dtype
        self.use_nvfp4 = fp4_spec.name == "nvfp4"

        # Kernel variant switches (all evaluated with const_expr at trace time).
        self.is_causal = bool(causal)
        self.compact_schedule = bool(compact_schedule)
        self.use_tmem_load_red = bool(use_tmem_load_red)

        # Warp specialization: epilogue warpgroups first, then MMA, then load.
        self.epi_threads_per_cta = 128
        self.epi_warps_per_group = 4
        self.num_epi_warpgroups = 2
        self.mma_warp_id = self.epi_warps_per_group * self.num_epi_warpgroups
        self.load_warp_id = self.mma_warp_id + 1
        self.threads_per_cta = 384

        # Tensor-memory budget and pipeline depths.
        self.num_tmem_alloc_cols = 512
        self.num_q_stage = 1
        self.num_acc_stage = 3
        self.num_ab_stage = 3

        # Tiling: each CTA walks k_tiles_per_cta K tiles; one MMA tile spans
        # two instruction-K steps.
        self.k_tiles_per_cta = _DECODE_K_TILES_PER_CTA
        self.mma_tiler = (_MMA_TILER_MN[0], _MMA_TILER_MN[1], _MMA_INST_SHAPE_K * 2)
        self.cta_tile_shape_mnk = self.mma_tiler

    @cute.jit
    def __call__(
        self,
        q_pack_ptr: cute.Pointer,
        k_ptr: cute.Pointer,
        q_scale_pack_ptr: cute.Pointer,
        k_scale_ptr: cute.Pointer,
        scores_ptr: cute.Pointer,
        kv_indices_ptr: cute.Pointer,
        cu_seqlens_q_ptr: cute.Pointer,
        cu_seqlens_k_ptr: cute.Pointer,
        cu_page_offsets_ptr: cute.Pointer,
        qo_offset_ptr: cute.Pointer,
        problem_size: tuple,
        stream: cuda.CUstream,
    ):
        """Build tensors, MMA/TMA descriptors and the grid, then launch.

        Args:
            q_pack_ptr: Packed Q, ``batch * heads_k`` tiles of ``_PAGE_SIZE x _HEAD_DIM``.
            k_ptr: Paged K, ``(_PAGE_SIZE, _HEAD_DIM, heads_k, page_count)``.
            q_scale_pack_ptr: Packed Q scale factors.
            k_scale_ptr: K scale factors, one tile per ``(page, kv-head)``.
            scores_ptr: Output, ``(heads_q, max_k_tiles, total_q)``.
            kv_indices_ptr: ``page_count`` logical-to-physical page indices.
            cu_seqlens_q_ptr: ``batch + 1`` cumulative query offsets.
            cu_seqlens_k_ptr: ``batch + 1`` cumulative key offsets.
            cu_page_offsets_ptr: ``batch + 1`` cumulative page offsets.
            qo_offset_ptr: ``batch`` per-entry causal offsets; only read when
                ``has_qo_offset`` is non-zero.
            problem_size: 11-tuple; the first four entries are ignored here,
                followed by ``lk, heads_q, heads_k, batch, max_k_tiles,
                total_q, has_qo_offset``.
            stream: CUDA stream the kernel is launched on.
        """
        (
            _,
            _,
            _,
            _,
            lk,
            heads_q,
            heads_k,
            batch,
            max_k_tiles,
            total_q,
            has_qo_offset,
        ) = problem_size
        page_count = lk // heads_k
        q_pack_l = batch * heads_k

        # ---- Global-memory views -------------------------------------------
        q_gmem = cute.make_tensor(
            cute.recast_ptr(q_pack_ptr, dtype=_AB_DTYPE),
            cute.make_layout(
                (_PAGE_SIZE, _HEAD_DIM, q_pack_l),
                stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM),
            ),
        )
        k_gmem = cute.make_tensor(
            cute.recast_ptr(k_ptr, dtype=_AB_DTYPE),
            cute.make_layout(
                (_PAGE_SIZE, _HEAD_DIM, heads_k, page_count),
                stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM, heads_k * _PAGE_SIZE * _HEAD_DIM),
            ),
        )
        q_sf_gmem = cute.make_tensor(
            q_scale_pack_ptr,
            blockscaled_utils.tile_atom_to_shape_SF(
                (_PAGE_SIZE, _HEAD_DIM, q_pack_l),
                self.sf_vec_size,
            ),
        )
        k_sf_gmem = cute.make_tensor(
            k_scale_ptr,
            blockscaled_utils.tile_atom_to_shape_SF(
                (_PAGE_SIZE, _HEAD_DIM, page_count * heads_k),
                self.sf_vec_size,
            ),
        )
        scores = cute.make_tensor(
            scores_ptr,
            cute.make_layout((heads_q, max_k_tiles, total_q), stride=(max_k_tiles * total_q, total_q, 1)),
        )
        kv_indices = cute.make_tensor(kv_indices_ptr, cute.make_layout((page_count,), stride=(1,)))
        cumulative_layout = cute.make_layout((batch + 1,), stride=(1,))
        cu_q = cute.make_tensor(cu_seqlens_q_ptr, cumulative_layout)
        cu_k = cute.make_tensor(cu_seqlens_k_ptr, cumulative_layout)
        cu_pages = cute.make_tensor(cu_page_offsets_ptr, cumulative_layout)
        qo_offset = cute.make_tensor(qo_offset_ptr, cute.make_layout((batch,), stride=(1,)))

        # ---- MMA atom and shared-memory layouts ----------------------------
        if const_expr(self.use_nvfp4):
            mma_op = tcgen05.MmaMXF4NVF4Op(
                self.sf_dtype,
                (*_MMA_TILER_MN, _MMA_INST_SHAPE_K),
                tcgen05.CtaGroup.ONE,
                tcgen05.OperandSource.SMEM,
            )
        else:
            mma_op = tcgen05.MmaMXF4Op(
                (*_MMA_TILER_MN, _MMA_INST_SHAPE_K),
                tcgen05.CtaGroup.ONE,
                tcgen05.OperandSource.SMEM,
            )
        tiled_mma = cute.make_tiled_mma(mma_op)
        q_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_q_stage)
        k_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_ab_stage)
        q_scale_smem_layout = blockscaled_utils.make_smem_layout_sfa(
            tiled_mma,
            self.mma_tiler,
            self.sf_vec_size,
            self.num_q_stage,
        )
        k_scale_smem_layout = blockscaled_utils.make_smem_layout_sfb(
            tiled_mma,
            self.mma_tiler,
            self.sf_vec_size,
            self.num_ab_stage,
        )

        # ---- TMA descriptors (single-CTA cluster) ---------------------------
        cluster_layout_vmnk = cute.make_layout((1, 1, 1, 1))
        tma_load_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)
        q_smem_layout_stage = cute.slice_(q_smem_layout, (None, None, None, 0))
        k_smem_layout_stage = cute.slice_(k_smem_layout, (None, None, None, 0))
        tma_q = cute.nvgpu.make_tiled_tma_atom_A(
            tma_load_op,
            q_gmem,
            q_smem_layout_stage,
            self.mma_tiler,
            tiled_mma,
            cluster_layout_vmnk.shape,
        )
        tma_k = cute.nvgpu.make_tiled_tma_atom_B(
            tma_load_op,
            k_gmem,
            k_smem_layout_stage,
            self.mma_tiler,
            tiled_mma,
            cluster_layout_vmnk.shape,
        )
        tma_qs = cute.nvgpu.make_tiled_tma_atom_A(
            tma_load_op,
            q_sf_gmem,
            q_scale_smem_layout,
            self.mma_tiler,
            tiled_mma,
            cluster_layout_vmnk.shape,
            internal_type=cutlass.Int16,
        )
        tma_ks = cute.nvgpu.make_tiled_tma_atom_B(
            tma_load_op,
            k_sf_gmem,
            k_scale_smem_layout,
            self.mma_tiler,
            tiled_mma,
            cluster_layout_vmnk.shape,
            internal_type=cutlass.Int16,
        )

        # ---- Grid -----------------------------------------------------------
        if const_expr(self.compact_schedule):
            # Upper bound on sum(ceil(pages_b / k_tiles_per_cta)) over the batch.
            grid = (
                cute.ceil_div(page_count + batch * (self.k_tiles_per_cta - 1), self.k_tiles_per_cta),
                heads_k,
                1,
            )
        else:
            grid = (cute.ceil_div(max_k_tiles, self.k_tiles_per_cta), batch * heads_k, 1)

        self.kernel(
            tiled_mma,
            tma_q,
            tma_qs,
            tma_k,
            tma_ks,
            scores,
            kv_indices,
            cu_q,
            cu_k,
            cu_pages,
            qo_offset,
            q_smem_layout,
            k_smem_layout,
            q_scale_smem_layout,
            k_scale_smem_layout,
            heads_q,
            heads_k,
            batch,
            has_qo_offset,
            max_k_tiles,
        ).launch(
            grid=grid,
            block=[self.threads_per_cta, 1, 1],
            cluster=(1, 1, 1),
            stream=stream,
        )

    @cute.jit
    def _group_has_visible(
        self,
        q_len: Int32,
        group_first_ktile: Int32,
        batch_k_tiles: Int32,
        causal_offset: Int32,
    ):
        """Return whether the CTA's K-tile group contains any visible tile.

        The group is visible when the batch entry has queries, the group's
        first tile exists, and (causal only) that tile starts at or before the
        last key the last query may see.
        """
        has_work = q_len > Int32(0) and group_first_ktile < batch_k_tiles
        if const_expr(self.is_causal):
            last_visible_key = (q_len - Int32(1)) + causal_offset
            has_work = has_work and group_first_ktile * Int32(_BLOCK_K) <= last_visible_key
        return has_work

    @cute.jit
    def _tile_has_visible(
        self,
        q_len: Int32,
        ktile: Int32,
        batch_k_tiles: Int32,
        causal_offset: Int32,
    ):
        """Return whether K tile ``ktile`` exists and is not fully masked."""
        has_work = ktile < batch_k_tiles
        if const_expr(self.is_causal):
            last_visible_key = (q_len - Int32(1)) + causal_offset
            has_work = has_work and ktile * Int32(_BLOCK_K) <= last_visible_key
        return has_work

    @cute.jit
    def _tile_mask_free(self, ktile: Int32, causal_offset: Int32):
        """Return whether no causal masking is needed anywhere in the tile.

        In causal mode this holds when the tile's last key is visible to the
        first query (``q_local == 0``); without causality it is always true.
        """
        if const_expr(not self.is_causal):
            return True
        return ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) <= causal_offset

    @cute.jit
    def _packed_coord_visible(
        self,
        coord_m: Int32,
        target_m: Int32,
        h_in_group: Int32,
        qhead_per_kv: Int32,
        q_local: Int32,
        q_len: Int32,
        k_local: Int32,
        k_len: Int32,
        causal_offset: Int32,
    ):
        """Return whether one accumulator element contributes to a row max.

        The element must lie on the thread's own packed row, map to a real
        (head, query) pair, address a real key, and satisfy the causal rule
        when it is enabled.
        """
        row_ok = coord_m == target_m and h_in_group < qhead_per_kv and q_local < q_len
        visible = row_ok and k_local < k_len
        if const_expr(self.is_causal):
            visible = visible and k_local <= q_local + causal_offset
        return visible

    @cute.kernel
    def kernel(
        self,
        tiled_mma: cute.TiledMma,
        tma_q: cpasync.TmaInfo,
        tma_qs: cpasync.TmaInfo,
        tma_k: cpasync.TmaInfo,
        tma_ks: cpasync.TmaInfo,
        mScores: cute.Tensor,
        mKvIndices: cute.Tensor,
        mCuQ: cute.Tensor,
        mCuK: cute.Tensor,
        mCuPages: cute.Tensor,
        mQoOffset: cute.Tensor,
        q_smem_layout: cute.ComposedLayout,
        k_smem_layout: cute.ComposedLayout,
        q_scale_smem_layout: cute.Layout,
        k_scale_smem_layout: cute.Layout,
        heads_q: Int32,
        heads_k: Int32,
        batch: Int32,
        has_qo_offset: Int32,
        max_k_tiles: Int32,
    ):
        """Warp-specialized device kernel: load, MMA, and row-max epilogue.

        Each CTA owns one ``(batch, kv-head)`` pair and one group of
        ``k_tiles_per_cta`` consecutive K tiles.  Tiles that exist but hold
        nothing visible are filled with ``-inf`` in the dense schedule and are
        left untouched in the compact schedule.
        """
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        tidx, _, _ = cute.arch.thread_idx()
        epi_tidx = tidx % Int32(self.epi_threads_per_cta)
        epi_warpgroup_idx = warp_idx // Int32(self.epi_warps_per_group)
        cta_x, cta_y, _ = cute.arch.block_idx()

        # ---- Map the CTA to (batch entry, kv-head, K-tile group) -----------
        task_valid = True
        batch_idx = Int32(0)
        hk = Int32(0)
        ktile_group = Int32(0)
        q_l = Int32(0)
        if const_expr(self.compact_schedule):
            # Grid x enumerates K-tile groups of all batch entries back to
            # back; scan the page offsets to find the entry owning cta_x.
            hk = cta_y
            groups_before = Int32(0)
            probe_batch = Int32(0)
            task_valid = False
            while probe_batch < batch and not task_valid:
                pages_in_batch = mCuPages[probe_batch + Int32(1)] - mCuPages[probe_batch]
                groups_in_batch = (pages_in_batch + Int32(self.k_tiles_per_cta - 1)) // Int32(self.k_tiles_per_cta)
                task_valid = cta_x < groups_before + groups_in_batch
                if not task_valid:
                    groups_before += groups_in_batch
                    probe_batch += Int32(1)
            if task_valid:
                batch_idx = probe_batch
                ktile_group = cta_x - groups_before
                q_l = batch_idx * heads_k + hk
        else:
            ktile_group = cta_x
            q_l = cta_y
            batch_idx = q_l // heads_k
            hk = q_l - batch_idx * heads_k

        # ---- Per-entry extents ----------------------------------------------
        qhead_per_kv = heads_q // heads_k
        q_begin = mCuQ[batch_idx]
        q_len = mCuQ[batch_idx + 1] - q_begin
        k_begin = mCuK[batch_idx]
        k_len = mCuK[batch_idx + 1] - k_begin
        if const_expr(self.compact_schedule):
            if not task_valid:
                # Surplus CTA of the over-provisioned compact grid: make every
                # visibility test below fail.
                q_len = Int32(0)
                k_len = Int32(0)
        page_begin = mCuPages[batch_idx]
        batch_k_tiles = (k_len + Int32(_PAGE_SIZE - 1)) // Int32(_PAGE_SIZE)
        causal_offset = Int32(0)
        if const_expr(self.is_causal):
            causal_offset = k_len - q_len
            if has_qo_offset != 0:
                causal_offset = mQoOffset[batch_idx]
        group_first_ktile = ktile_group * Int32(self.k_tiles_per_cta)
        group_has_visible = self._group_has_visible(
            q_len,
            group_first_ktile,
            batch_k_tiles,
            causal_offset,
        )

        # ---- Shared memory ----------------------------------------------------
        @cute.struct
        class SharedStorage:
            acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
            q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
            k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
            tmem_holding_buf: cutlass.Int32

        smem = utils.SmemAllocator()
        storage = smem.allocate(SharedStorage)
        sQ_public = smem.allocate_tensor(_AB_DTYPE, q_smem_layout.outer, 128, swizzle=q_smem_layout.inner)
        sK_public = smem.allocate_tensor(_AB_DTYPE, k_smem_layout.outer, 128, swizzle=k_smem_layout.inner)
        sQS_public = smem.allocate_tensor(self.sf_dtype, q_scale_smem_layout, 128)
        sKS_public = smem.allocate_tensor(self.sf_dtype, k_scale_smem_layout, 128)

        # ---- MMA fragments and coordinate tensor -------------------------------
        mQ_tma = tma_q.tma_tensor
        mQS_tma = tma_qs.tma_tensor
        mK_tma = tma_k.tma_tensor
        mKS_tma = tma_ks.tma_tensor
        thr_mma = tiled_mma.get_slice(0)
        tCrQ = tiled_mma.make_fragment_A(sQ_public)
        tCrK = tiled_mma.make_fragment_B(sK_public)
        # Identity tensor: lets the epilogue recover (m, n) of each accumulator element.
        tCcC = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler[:2]))
        acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
        tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))

        # ---- TMA partitions: Q, Q scales, K, K scales ---------------------------
        gQ_tma = cute.local_tile(
            mQ_tma,
            cute.slice_(self.mma_tiler, (None, 0, None)),
            (None, None, None),
        )
        tCgQ_tma = thr_mma.partition_A(gQ_tma)
        tQsQ_tma, tQgQ_tma = cpasync.tma_partition(
            tma_q.atom,
            0,
            cute.make_layout(1),
            cute.group_modes(sQ_public, 0, 3),
            cute.group_modes(tCgQ_tma, 0, 3),
        )
        gQS_tma = cute.local_tile(
            mQS_tma,
            cute.slice_(self.mma_tiler, (None, 0, None)),
            (None, None, None),
        )
        tCgQS_tma = thr_mma.partition_A(gQS_tma)
        tQsQS_tma, tQgQS_tma = cpasync.tma_partition(
            tma_qs.atom,
            0,
            cute.make_layout(1),
            cute.group_modes(sQS_public, 0, 3),
            cute.group_modes(tCgQS_tma, 0, 3),
        )
        tQsQS_tma = cute.filter_zeros(tQsQS_tma)
        tQgQS_tma = cute.filter_zeros(tQgQS_tma)
        gK_tma = cute.local_tile(
            mK_tma,
            cute.slice_(self.mma_tiler, (0, None, None)),
            (None, None, None, None),
        )
        tCgK_tma = thr_mma.partition_B(gK_tma)
        tKsK_tma, tKgK_tma = cpasync.tma_partition(
            tma_k.atom,
            0,
            cute.make_layout(1),
            cute.group_modes(sK_public, 0, 3),
            cute.group_modes(tCgK_tma, 0, 3),
        )
        gKS_tma = cute.local_tile(
            mKS_tma,
            cute.slice_(self.mma_tiler, (0, None, None)),
            (None, None, None),
        )
        tCgKS_tma = thr_mma.partition_B(gKS_tma)
        tKsKS_tma, tKgKS_tma = cpasync.tma_partition(
            tma_ks.atom,
            0,
            cute.make_layout(1),
            cute.group_modes(sKS_public, 0, 3),
            cute.group_modes(tCgKS_tma, 0, 3),
        )
        tKsKS_tma = cute.filter_zeros(tKsKS_tma)
        tKgKS_tma = cute.filter_zeros(tKgKS_tma)

        # ---- Tensor-memory allocator and pipelines ------------------------------
        # The retrieve barrier spans the epilogue warps plus the MMA warp.
        tmem = utils.TmemAllocator(
            storage.tmem_holding_buf.ptr,
            barrier_for_retrieve=pipeline.NamedBarrier(
                barrier_id=1,
                num_threads=32 * (self.mma_warp_id + 1),
            ),
        )
        # MMA warp -> epilogue warpgroup hand-off of finished accumulators.
        acc_pipeline = common_pipeline.PipelineUmmaAsync.create(
            barrier_storage=storage.acc_mbar_ptr.data_ptr(),
            num_stages=self.num_acc_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, self.epi_threads_per_cta),
            defer_sync=True,
        )
        acc_producer, _ = acc_pipeline.make_participants()
        # Each TMA transaction delivers an operand tile together with its scales.
        q_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_q.smem_layout)
        qs_tma_copy_bytes = cute.size_in_bytes(
            self.sf_dtype,
            cute.select(tma_qs.smem_layout, mode=[0, 1, 2]),
        )
        k_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_k.smem_layout)
        ks_tma_copy_bytes = cute.size_in_bytes(
            self.sf_dtype,
            cute.select(tma_ks.smem_layout, mode=[0, 1, 2]),
        )
        q_pair_tma_copy_bytes = q_tma_copy_bytes + qs_tma_copy_bytes
        k_pair_tma_copy_bytes = k_tma_copy_bytes + ks_tma_copy_bytes
        q_producer, q_consumer = pipeline.PipelineTmaAsync.create(
            barrier_storage=storage.q_mbar_ptr.data_ptr(),
            num_stages=1,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            tx_count=q_pair_tma_copy_bytes,
            defer_sync=True,
        ).make_participants()
        k_producer, k_consumer = pipeline.PipelineTmaAsync.create(
            barrier_storage=storage.k_mbar_ptr.data_ptr(),
            num_stages=self.num_ab_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            tx_count=k_pair_tma_copy_bytes,
            defer_sync=True,
        ).make_participants()
        # Pipelines were created with defer_sync=True: publish the barrier
        # initialisation to the whole CTA once, here.
        cute.arch.mbarrier_init_fence()
        cute.arch.barrier()

        # =====================================================================
        # Load warp: TMA producer for Q (once) and for every visible K tile.
        # =====================================================================
        if warp_idx == self.load_warp_id:
            if group_has_visible:
                q_slot = q_producer.acquire_and_advance()
                cute.copy(
                    tma_q.atom,
                    tQgQ_tma[(None, 0, 0, q_l)],
                    tQsQ_tma[(None, q_slot.index)],
                    tma_bar_ptr=q_slot.barrier,
                )
                cute.copy(
                    tma_qs.atom,
                    tQgQS_tma[(None, 0, 0, q_l)],
                    tQsQS_tma[(None, q_slot.index)],
                    tma_bar_ptr=q_slot.barrier,
                )
                q_slot.commit()

                # Fast path: every tile of the group exists and is visible, so
                # the per-tile checks can be skipped.
                whole_group = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles
                if const_expr(self.is_causal):
                    group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1)
                    whole_group = whole_group and group_last_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset
                if whole_group:
                    for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
                        ktile = group_first_ktile + Int32(ktile_inner)
                        k_slot = k_producer.acquire_and_advance()
                        # Translate the logical page to its physical slot.
                        physical_page = mKvIndices[page_begin + ktile]
                        cute.copy(
                            tma_k.atom,
                            tKgK_tma[(None, 0, 0, hk, physical_page)],
                            tKsK_tma[(None, k_slot.index)],
                            tma_bar_ptr=k_slot.barrier,
                        )
                        scale_l = physical_page * heads_k + hk
                        cute.copy(
                            tma_ks.atom,
                            tKgKS_tma[(None, 0, 0, scale_l)],
                            tKsKS_tma[(None, k_slot.index)],
                            tma_bar_ptr=k_slot.barrier,
                        )
                        k_slot.commit()
                else:
                    for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
                        ktile = group_first_ktile + Int32(ktile_inner)
                        if ktile < max_k_tiles:
                            tile_has_visible = self._tile_has_visible(
                                q_len,
                                ktile,
                                batch_k_tiles,
                                causal_offset,
                            )
                            if tile_has_visible:
                                k_slot = k_producer.acquire_and_advance()
                                physical_page = mKvIndices[page_begin + ktile]
                                cute.copy(
                                    tma_k.atom,
                                    tKgK_tma[(None, 0, 0, hk, physical_page)],
                                    tKsK_tma[(None, k_slot.index)],
                                    tma_bar_ptr=k_slot.barrier,
                                )
                                scale_l = physical_page * heads_k + hk
                                cute.copy(
                                    tma_ks.atom,
                                    tKgKS_tma[(None, 0, 0, scale_l)],
                                    tKsKS_tma[(None, k_slot.index)],
                                    tma_bar_ptr=k_slot.barrier,
                                )
                                k_slot.commit()
                k_producer.tail()
                q_producer.tail()

        # =====================================================================
        # MMA warp: scale factors smem -> tmem, then one MMA per visible tile.
        # =====================================================================
        if warp_idx == self.mma_warp_id:
            tmem_pool = tmem.reserve(self.num_tmem_alloc_cols)
            tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32)
            tCtQS_layout = blockscaled_utils.make_tmem_layout_sfa(
                tiled_mma,
                self.mma_tiler,
                self.sf_vec_size,
                cute.slice_(q_scale_smem_layout, (None, None, None, 0)),
            )
            tCtKS_layout = blockscaled_utils.make_tmem_layout_sfb(
                tiled_mma,
                self.mma_tiler,
                self.sf_vec_size,
                cute.slice_(k_scale_smem_layout, (None, None, None, 0)),
            )
            tCtQS = tmem_pool.allocate_tensor(tCtQS_layout, self.sf_dtype)
            tCtKS = tmem_pool.allocate_tensor(tCtKS_layout, self.sf_dtype)
            copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), self.sf_dtype)

            # Shared -> tensor-memory copy for the Q scale factors.
            tCsQS_compact = cute.filter_zeros(sQS_public)
            tCtQS_compact = cute.filter_zeros(tCtQS)
            tiled_copy_s2t_qs = tcgen05.make_s2t_copy(copy_atom_s2t, tCtQS_compact)
            thr_copy_s2t_qs = tiled_copy_s2t_qs.get_slice(0)
            tCsQS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
                tiled_copy_s2t_qs,
                thr_copy_s2t_qs.partition_S(tCsQS_compact),
            )
            tCtQS_compact_s2t = thr_copy_s2t_qs.partition_D(tCtQS_compact)

            # Shared -> tensor-memory copy for the K scale factors.
            tCsKS_compact = cute.filter_zeros(sKS_public)
            tCtKS_compact = cute.filter_zeros(tCtKS)
            tiled_copy_s2t_ks = tcgen05.make_s2t_copy(copy_atom_s2t, tCtKS_compact)
            thr_copy_s2t_ks = tiled_copy_s2t_ks.get_slice(0)
            tCsKS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
                tiled_copy_s2t_ks,
                thr_copy_s2t_ks.partition_S(tCsKS_compact),
            )
            tCtKS_compact_s2t = thr_copy_s2t_ks.partition_D(tCtKS_compact)

            if group_has_visible:
                # Q is loaded exactly once per CTA; its stage is released right
                # after the wait and the tile is then read for every K tile.
                q_ready = q_consumer.wait_and_advance()
                q_ready.release()
                cute.copy(tiled_copy_s2t_qs, tCsQS_compact_s2t[(None, None, None, None, 0)], tCtQS_compact_s2t)
                # Every tile starts from a fresh accumulator.
                tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
                q_tile_crd = (None, None, None, 0)
                if const_expr(self.is_causal):
                    # Must mirror the load warp's tile selection exactly.
                    causal_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles
                    causal_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1)
                    causal_group_full = causal_group_full and causal_group_last_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset
                    if causal_group_full:
                        for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
                            k_ready = k_consumer.wait_and_advance()
                            acc_slot = acc_producer.acquire_and_advance()
                            cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_ready.index)], tCtKS_compact_s2t)
                            k_tile_crd = (None, None, None, k_ready.index)
                            tCtAcc_stage = tCtAcc[(None, None, None, acc_slot.index)]
                            cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
                            acc_slot.commit()
                            k_ready.release()
                    else:
                        for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
                            ktile = group_first_ktile + Int32(ktile_inner)
                            if ktile < max_k_tiles:
                                tile_has_visible = self._tile_has_visible(
                                    q_len,
                                    ktile,
                                    batch_k_tiles,
                                    causal_offset,
                                )
                                if tile_has_visible:
                                    k_ready = k_consumer.wait_and_advance()
                                    acc_slot = acc_producer.acquire_and_advance()
                                    cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_ready.index)], tCtKS_compact_s2t)
                                    k_tile_crd = (None, None, None, k_ready.index)
                                    tCtAcc_stage = tCtAcc[(None, None, None, acc_slot.index)]
                                    cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
                                    acc_slot.commit()
                                    k_ready.release()
                else:
                    k_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles
                    if k_group_full:
                        for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
                            k_ready = k_consumer.wait_and_advance()
                            acc_slot = acc_producer.acquire_and_advance()
                            cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_ready.index)], tCtKS_compact_s2t)
                            k_tile_crd = (None, None, None, k_ready.index)
                            tCtAcc_stage = tCtAcc[(None, None, None, acc_slot.index)]
                            cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
                            acc_slot.commit()
                            k_ready.release()
                    else:
                        for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
                            ktile = group_first_ktile + Int32(ktile_inner)
                            if ktile < batch_k_tiles:
                                k_ready = k_consumer.wait_and_advance()
                                acc_slot = acc_producer.acquire_and_advance()
                                cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_ready.index)], tCtKS_compact_s2t)
                                k_tile_crd = (None, None, None, k_ready.index)
                                tCtAcc_stage = tCtAcc[(None, None, None, acc_slot.index)]
                                cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
                                acc_slot.commit()
                                k_ready.release()
                acc_producer.tail()

        # =====================================================================
        # Epilogue warpgroups: accumulator -> registers -> per-row maximum.
        # =====================================================================
        if warp_idx < self.mma_warp_id:
            tmem_pool = tmem.reserve(self.num_tmem_alloc_cols)
            tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32)
            if const_expr(self.use_tmem_load_red):
                # Load variant that additionally returns a MAX reduction.
                copy_atom_t2r = cute.make_copy_atom(
                    tcgen05.LdRed32x32bOp(
                        tcgen05.Repetition.x128,
                        tcgen05.Pack.NONE,
                        tcgen05.TmemLoadRedOp.MAX,
                    ),
                    Float32,
                )
            else:
                copy_atom_t2r = cute.make_copy_atom(
                    tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE),
                    Float32,
                )
            tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None, None, 0)])
            thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx)
            tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc)
            tTR_cC = thr_copy_t2r.partition_D(tCcC)
            tTR_rAcc = cute.make_rmem_tensor(tTR_cC.shape, Float32)
            if const_expr(self.use_tmem_load_red):
                tTR_rRed = cute.make_rmem_tensor((1,), Float32)

            # Packed row owned by this thread -> (query head, query position).
            h_store = epi_tidx // Int32(_DECODE_PACK_Q_LEN)
            q_local_store = epi_tidx - h_store * Int32(_DECODE_PACK_Q_LEN)
            h_global_store = hk * qhead_per_kv + h_store
            q_global_store = q_begin + q_local_store

            if group_has_visible:
                # Counts every visible tile (owned or not) so that the stage
                # index and phase stay aligned with the MMA warp's producer.
                visible_tile_count = Int32(0)
                for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
                    ktile = group_first_ktile + Int32(ktile_inner)
                    if ktile < max_k_tiles:
                        tile_has_visible = self._tile_has_visible(
                            q_len,
                            ktile,
                            batch_k_tiles,
                            causal_offset,
                        )
                        if tile_has_visible:
                            # Tiles alternate between the epilogue warpgroups.
                            epilogue_owns_tile = epi_warpgroup_idx == Int32(
                                ktile_inner % self.num_epi_warpgroups
                            )
                            if epilogue_owns_tile:
                                acc_stage_index = visible_tile_count % Int32(self.num_acc_stage)
                                acc_stage_phase = (visible_tile_count // Int32(self.num_acc_stage)) % Int32(2)
                                tile_mask_free = self._tile_mask_free(ktile, causal_offset)
                                k_tile_full = ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) < k_len
                                q_pack_full = q_len == Int32(_DECODE_PACK_Q_LEN)
                                tile_full = q_pack_full and k_tile_full
                                acc_pipeline.consumer_wait_w_index_phase(acc_stage_index, acc_stage_phase)
                                tTR_tAcc_stage = tTR_tAcc[(None, None, None, None, acc_stage_index)]
                                if const_expr(self.use_tmem_load_red):
                                    cute.copy(tiled_copy_t2r, tTR_tAcc_stage, [tTR_rAcc, tTR_rRed])
                                else:
                                    cute.copy(tiled_copy_t2r, tTR_tAcc_stage, tTR_rAcc)
                                row_max = -Float32.inf
                                if tile_mask_free and tile_full:
                                    # No element needs masking: reduce the whole row.
                                    if const_expr(self.use_tmem_load_red):
                                        row_max = tTR_rRed[0]
                                    else:
                                        for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True):
                                            coord_m, _ = tTR_cC[i]
                                            if coord_m == epi_tidx:
                                                row_max = cute.arch.fmax(row_max, tTR_rAcc[i])
                                else:
                                    # Boundary or causal tile: test each element.
                                    for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True):
                                        coord_m, coord_n = tTR_cC[i]
                                        h_in_group = coord_m // Int32(_DECODE_PACK_Q_LEN)
                                        q_local = coord_m - h_in_group * Int32(_DECODE_PACK_Q_LEN)
                                        k_local = ktile * Int32(_BLOCK_K) + coord_n
                                        coord_visible = self._packed_coord_visible(
                                            coord_m,
                                            epi_tidx,
                                            h_in_group,
                                            qhead_per_kv,
                                            q_local,
                                            q_len,
                                            k_local,
                                            k_len,
                                            causal_offset,
                                        )
                                        if coord_visible:
                                            row_max = cute.arch.fmax(row_max, tTR_rAcc[i])
                                # Padding rows of the packed tile are never stored.
                                if h_store < qhead_per_kv and q_local_store < q_len:
                                    mScores[h_global_store, ktile, q_global_store] = row_max
                                # Make sure the tmem reads are complete before
                                # handing the stage back to the MMA warp.
                                cute.arch.fence_view_async_tmem_load()
                                acc_pipeline.consumer_release_w_index(acc_stage_index)
                            visible_tile_count += Int32(1)
                        else:
                            # Tile inside the score tensor but with nothing
                            # visible: the dense schedule fills it with -inf.
                            if const_expr(not self.compact_schedule):
                                if epi_warpgroup_idx == Int32(0):
                                    if h_store < qhead_per_kv and q_local_store < q_len:
                                        mScores[h_global_store, ktile, q_global_store] = -Float32.inf
            else:
                # Whole group invisible: the dense schedule still owns these
                # score slots and fills them with -inf.
                if const_expr(not self.compact_schedule):
                    if epi_warpgroup_idx == Int32(0):
                        for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
                            ktile = group_first_ktile + Int32(ktile_inner)
                            if ktile < max_k_tiles:
                                if h_store < qhead_per_kv and q_local_store < q_len:
                                    mScores[h_global_store, ktile, q_global_store] = -Float32.inf
            cute.arch.barrier()
            tmem.free(tmem_pool.base_ptr)
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """SM100 forward kernels and combine paths."""
5
+
6
+ from .atten_fwd_nvfp4_kv import SparseAttentionForwardNvfp4KvSm100
7
+
8
+ __all__ = ["SparseAttentionForwardNvfp4KvSm100"]
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd.py ADDED
The diff for this file is too large to render. See raw diff
 
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py ADDED
The diff for this file is too large to render. See raw diff
 
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/combine.py ADDED
@@ -0,0 +1,1498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Sparse forward combine kernel and public launcher.
5
+
6
+ This keeps the local fake-layout -> real-layout epilogue needed by the lean
7
+ sparse forward path.
8
+ """
9
+
10
+ # Modified Step 7: O_out write with SMEM fake->real column permutation.
11
+ # O_partial dim is in STG.128 fake layout; O_out dim is real layout.
12
+ import math
13
+ from typing import Type, Optional
14
+ from functools import partial
15
+
16
+ import cuda.bindings.driver as cuda
17
+
18
+ import cutlass
19
+ import cutlass.cute as cute
20
+ import torch
21
+ from cutlass.cute.nvgpu import cpasync
22
+ from cutlass import Float32, Int32, Int64, Boolean, const_expr
23
+
24
+ from ....src.common import utils
25
+ from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map
26
+ from ....src.common.seqlen_info import SeqlenInfo
27
+ from cutlass.cute import FastDivmodDivisor
28
+
29
+ from ....src.common.pack_gqa import PackGQAComb
30
+ from ....src.common.tma_utils import (
31
+ stg128_fake_col_to_real_col,
32
+ stg128_fp8_fake_col_to_real_col,
33
+ stg128_half_fake_col_to_real_col,
34
+ )
35
+
36
+
37
+ class SparseAttentionForwardCombine:
38
def __init__(
    self,
    dtype: Type[cutlass.Numeric],
    dtype_partial: Type[cutlass.Numeric],
    head_dim: int,
    tile_m: int = 8,
    k_block_size: int = 64,
    topk: int = 16,
    num_threads: int = 256,
    stages: int = 4,
    use_pdl: bool = False,
    min_blocks_per_mp: int = 0,
):
    """
    Store the static configuration of the split-attention combine kernel.

    Nothing is validated here; use ``can_implement`` to pre-check a
    configuration.

    :param dtype: output data type
    :param dtype_partial: partial accumulation data type
    :param head_dim: head dimension
    :param tile_m: m block size
    :param k_block_size: k block size
    :param topk: exact number of split partials
    :param num_threads: number of threads
    :param stages: number of pipeline stages
    :param use_pdl: stored as-is and handed to the kernel launch
    :param min_blocks_per_mp: stored as-is and handed to the kernel launch
    """
    # Element types of the final output and of the per-split partials.
    self.dtype, self.dtype_partial = dtype, dtype_partial

    # Problem / tile geometry.
    self.head_dim, self.tile_m, self.k_block_size = head_dim, tile_m, k_block_size
    self.topk, self.num_threads = topk, num_threads

    # True when the head dimension is an exact multiple of the k block.
    self.is_even_k = (head_dim % k_block_size) == 0

    # Pipelining and launch options.
    self.stages = stages
    self.use_pdl = use_pdl
    self.min_blocks_per_mp = min_blocks_per_mp

    # Layout flags derived purely from the partial dtype.
    half_partial_types = (cutlass.BFloat16, cutlass.Float16)
    self.use_stg128_half_layout = dtype_partial in half_partial_types
    self.use_stg128_fp8_layout = dtype_partial is cutlass.Float8E4M3FN
77
+
78
@staticmethod
def can_implement(
    dtype,
    dtype_partial,
    head_dim,
    tile_m,
    k_block_size,
    topk,
    num_threads,
) -> bool:
    """Return True when the given configuration passes every support check.

    ``k_block_size`` is accepted for signature compatibility but is not
    inspected by any of the checks below.
    """
    # Supported element types for the final output and the split partials.
    output_types = (cutlass.Float16, cutlass.BFloat16, cutlass.Float32)
    partial_types = (
        cutlass.Float16,
        cutlass.BFloat16,
        cutlass.Float8E4M3FN,
        Float32,
    )
    if dtype not in output_types or dtype_partial not in partial_types:
        return False

    # Divisibility requirements, evaluated in the original order:
    # head_dim by 8, thread count by 32, row tile by 8.
    if head_dim % 8 != 0 or num_threads % 32 != 0 or tile_m % 8 != 0:
        return False

    # At most 256 splits, and the (row, split) tile must divide evenly
    # over the threads of the block.
    if topk > 256:
        return False
    return (tile_m * topk) % num_threads == 0
109
+
110
def _setup_attributes(self):
    """Build the tiled copies and shared-memory layouts used by the kernel.

    Populates ``gmem_tiled_copy_O_partial``, ``gmem_tiled_copy_O``,
    ``gmem_tiled_copy_LSE``, ``s2r_tiled_copy_LSE``,
    ``smem_threads_per_col_lse``, ``smem_layout_lse`` and ``smem_layout_o``.
    Raises ``AssertionError`` when the configured tile sizes / thread count
    do not divide as required.
    """
    # ------------------------------------------------------------------
    # Global -> shared async copy of O_partial (128-bit transactions).
    # ------------------------------------------------------------------
    copy_bits = 128
    partial_elems_per_copy = copy_bits // self.dtype_partial.width
    assert self.k_block_size % partial_elems_per_copy == 0

    # Widest of 128 / 64 that divides the k block; falls back to 32.
    k_block_gmem = next(
        (width for width in (128, 64) if self.k_block_size % width == 0), 32
    )
    partial_threads_per_row = k_block_gmem // partial_elems_per_copy
    assert self.num_threads % partial_threads_per_row == 0

    partial_copy_atom = cute.make_copy_atom(
        cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
        self.dtype_partial,
        num_bits_per_copy=copy_bits,
    )
    partial_thread_layout = cute.make_ordered_layout(
        (self.num_threads // partial_threads_per_row, partial_threads_per_row),
        order=(1, 0),
    )
    partial_value_layout = cute.make_layout((1, partial_elems_per_copy))
    self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv(
        partial_copy_atom, partial_thread_layout, partial_value_layout
    )

    # ------------------------------------------------------------------
    # Universal copy for the final O store. The vector width is derived
    # from the output dtype, independently of the partial dtype (e.g. fp8
    # partials move 16 elements per 128b while bf16/fp16 output moves 8).
    # ------------------------------------------------------------------
    out_elems_per_copy = copy_bits // self.dtype.width
    assert self.k_block_size % out_elems_per_copy == 0
    out_threads_per_row = k_block_gmem // out_elems_per_copy
    assert self.num_threads % out_threads_per_row == 0
    out_copy_atom = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        self.dtype,
        num_bits_per_copy=copy_bits,
    )
    out_thread_layout = cute.make_ordered_layout(
        (self.num_threads // out_threads_per_row, out_threads_per_row),
        order=(1, 0),
    )
    out_value_layout = cute.make_layout((1, out_elems_per_copy))
    self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
        out_copy_atom,
        out_thread_layout,
        out_value_layout,
    )

    # ------------------------------------------------------------------
    # Global -> shared async copy of LSE, one Float32 per copy.
    # ------------------------------------------------------------------
    lse_bits_per_copy = Float32.width  # width is expressed in bits
    # Largest of 128 / 64 / 32 / 16 that divides tile_m; falls back to 8.
    m_block_smem = next(
        (width for width in (128, 64, 32, 16) if self.tile_m % width == 0), 8
    )
    lse_threads_per_row = m_block_smem
    assert self.num_threads % lse_threads_per_row == 0

    lse_copy_atom = cute.make_copy_atom(
        cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS),
        Float32,
        num_bits_per_copy=lse_bits_per_copy,
    )
    lse_thread_layout = cute.make_ordered_layout(
        (self.num_threads // lse_threads_per_row, lse_threads_per_row),
        order=(1, 0),
    )
    lse_value_layout = cute.make_layout(1)
    self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(
        lse_copy_atom, lse_thread_layout, lse_value_layout
    )

    # ------------------------------------------------------------------
    # Shared memory
    # ------------------------------------------------------------------

    # Shared -> register copy for LSE. The per-column thread count must
    # divide the warp size.
    self.smem_threads_per_col_lse = self.num_threads // m_block_smem
    assert 32 % self.smem_threads_per_col_lse == 0

    lse_s2r_thread_layout = cute.make_ordered_layout(
        (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse),
        order=(0, 1),
    )
    self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv(
        cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32),
        lse_s2r_thread_layout,
        cute.make_layout(1),
    )

    # Swizzled LSE layout; the swizzle parameters depend on the row-tile
    # width (8, 16, or anything wider). The original author notes these
    # choices are meant to avoid bank conflicts.
    swizzle_args_by_width = {8: (5, 0, 5), 16: (4, 0, 4)}
    lse_swizzle = cute.make_swizzle(*swizzle_args_by_width.get(m_block_smem, (3, 2, 3)))
    splits_per_atom = min(self.topk, 8)
    lse_layout_atom = cute.make_composed_layout(
        lse_swizzle,
        0,
        cute.make_ordered_layout((splits_per_atom, m_block_smem), order=(1, 0)),
    )
    self.smem_layout_lse = cute.tile_to_shape(
        lse_layout_atom, (self.topk, self.tile_m), (0, 1)
    )

    # O_partial staging buffer: (tile_m, k_block_size, stages).
    # 16-bit and fp8 partials use the cp.async layout atom helper; every
    # other partial dtype uses a plain ordered layout.
    narrow_partial_types = [cutlass.Float16, cutlass.BFloat16, cutlass.Float8E4M3FN]
    staging_shape = (self.tile_m, self.k_block_size, self.stages)
    if const_expr(self.dtype_partial in narrow_partial_types):
        o_layout_atom = _get_cpasync_smem_layout_atom(
            self.dtype_partial, self.k_block_size
        )
        self.smem_layout_o = cute.tile_to_shape(
            o_layout_atom,
            staging_shape,
            (0, 1, 2),
        )
    else:
        self.smem_layout_o = cute.make_ordered_layout(staging_shape, order=(1, 0, 2))
241
+
242
@cute.jit
def __call__(
    self,
    mO_partial: cute.Tensor,
    mLSE_partial: cute.Tensor,
    mO: cute.Tensor,
    mLSE: Optional[cute.Tensor] = None,
    mLSE_temperature_partial: Optional[cute.Tensor] = None,
    mLSE_temperature: Optional[cute.Tensor] = None,
    cu_seqlens: Optional[cute.Tensor] = None,
    seqused: Optional[cute.Tensor] = None,
    num_splits_dynamic_ptr: Optional[cute.Tensor] = None,
    varlen_batch_idx: Optional[cute.Tensor] = None,
    semaphore_to_reset: Optional[cute.Tensor] = None,
    mSplitCounts: Optional[cute.Tensor] = None,
    mOutputScale: Optional[cute.Tensor] = None,
    qhead_per_kvhead: Int32 = Int32(1),
    # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
    stream: cuda.CUstream = None,
):
    """Validate the operands, re-view them in kernel layout, and launch.

    Steps: check element types and ranks, permute the tensor modes into
    the order the kernel indexes them, build copies/layouts via
    ``_setup_attributes``, declare the shared-memory struct, then launch
    ``self.kernel``.

    Raises ``TypeError`` for element-type mismatches and ``ValueError``
    for rank mismatches or a half-specified temperature-LSE pair.
    """
    # ---------------- Element-type validation ----------------
    if const_expr(not (mO_partial.element_type == self.dtype_partial)):
        raise TypeError("O partial tensor must match dtype_partial")
    if const_expr(not (mO.element_type == self.dtype)):
        raise TypeError("O tensor must match dtype")
    if const_expr(mLSE_partial.element_type not in (Float32,)):
        raise TypeError("LSE partial tensor must be Float32")
    if const_expr(mLSE is not None and mLSE.element_type not in (Float32,)):
        raise TypeError("LSE tensor must be Float32")
    if const_expr(
        mLSE_temperature_partial is not None
        and mLSE_temperature_partial.element_type not in (Float32,)
    ):
        raise TypeError("temperature LSE partial tensor must be Float32")
    if const_expr(
        mLSE_temperature is not None and mLSE_temperature.element_type not in (Float32,)
    ):
        raise TypeError("temperature LSE tensor must be Float32")
    # The temperature partial/output tensors come as a pair or not at all.
    if const_expr((mLSE_temperature_partial is None) != (mLSE_temperature is None)):
        raise ValueError(
            "temperature LSE partial and output tensors must either both be provided or both be None"
        )

    # ---------------- Rank validation (user-facing layouts) ----------------
    if const_expr(len(mO_partial.shape) not in (4, 5)):
        raise ValueError(
            "O partial tensor must have 4 or 5 dimensions: "
            "(num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)"
        )
    if const_expr(len(mLSE_partial.shape) not in (3, 4)):
        raise ValueError(
            "LSE partial tensor must have 3 or 4 dimensions: "
            "(num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)"
        )
    if const_expr(len(mO.shape) not in (3, 4)):
        raise ValueError(
            "O tensor must have 3 or 4 dimensions: "
            "(batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)"
        )
    if const_expr(mLSE is not None and len(mLSE.shape) not in (2, 3)):
        raise ValueError(
            "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)"
        )
    if const_expr(
        mLSE_temperature_partial is not None
        and len(mLSE_temperature_partial.shape) not in (3, 4)
    ):
        raise ValueError(
            "temperature LSE partial tensor must have 3 or 4 dimensions: "
            "(num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)"
        )
    if const_expr(mLSE_temperature is not None and len(mLSE_temperature.shape) not in (2, 3)):
        raise ValueError(
            "temperature LSE tensor must have 2 or 3 dimensions: "
            "(batch, seqlen, nheads) or (total_q, nheads)"
        )
    if const_expr(mSplitCounts is not None):
        if const_expr(mSplitCounts.element_type not in (Int32,)):
            raise TypeError("split_counts tensor must be Int32")
        # Rank 2 when cu_seqlens is given, rank 3 otherwise.
        if const_expr(cu_seqlens is not None):
            if const_expr(len(mSplitCounts.shape) != 2):
                raise ValueError("varlen split_counts tensor must have shape (total_q, nheads_kv)")
        elif const_expr(len(mSplitCounts.shape) != 3):
            raise ValueError("batched split_counts tensor must have shape (batch, seqlen, nheads_kv)")
    if const_expr(mOutputScale is not None and mOutputScale.element_type not in (Float32,)):
        raise TypeError("output_scale tensor must be Float32")

    # ---------------- Re-view tensors in kernel mode order ----------------
    mO_partial = assume_tensor_aligned(mO_partial)
    mO = assume_tensor_aligned(mO)

    # O_partial:
    #   (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
    #   (num_splits, total_q, h, d)   -> (total_q, d, num_splits, h)
    o_partial_modes = [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
    mO_partial = cute.make_tensor(
        mO_partial.iterator, cute.select(mO_partial.layout, mode=o_partial_modes)
    )

    # O:
    #   (b, seqlen, h, d) -> (seqlen, d, h, b)
    #   (total_q, h, d)   -> (total_q, d, h)
    o_modes = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1]
    mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=o_modes))

    # LSE_partial: the batched input is expected pre-transposed as
    # [topK, B, Hq, Sq] (Sq innermost), giving
    #   (num_splits, b, h, seqlen) -> (seqlen, num_splits, h, b)
    #   (num_splits, total_q, h)   -> (total_q, num_splits, h)
    lse_partial_modes = [3, 0, 2, 1] if const_expr(cu_seqlens is None) else [1, 0, 2]
    mLSE_partial = cute.make_tensor(
        mLSE_partial.iterator,
        cute.select(mLSE_partial.layout, mode=lse_partial_modes),
    )

    # LSE: (b, seqlen, h) -> (seqlen, h, b); (total_q, h) is left in order.
    lse_modes = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1]
    mLSE = (
        cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=lse_modes))
        if mLSE is not None
        else None
    )
    # The temperature tensors reuse the LSE permutations.
    mLSE_temperature_partial = (
        cute.make_tensor(
            mLSE_temperature_partial.iterator,
            cute.select(mLSE_temperature_partial.layout, mode=lse_partial_modes),
        )
        if mLSE_temperature_partial is not None
        else None
    )
    mLSE_temperature = (
        cute.make_tensor(
            mLSE_temperature.iterator,
            cute.select(mLSE_temperature.layout, mode=lse_modes),
        )
        if mLSE_temperature is not None
        else None
    )

    # Variable-length mode is on when either length tensor is supplied.
    is_varlen = const_expr(cu_seqlens is not None or seqused is not None)

    self._setup_attributes()

    # Output-dtype buffer (tile_m x k_block_size) used for the fake->real
    # column permutation. Accumulation stays fp32; values are converted
    # to the output dtype before being scattered into this buffer.
    perm_shape = (self.tile_m, self.k_block_size)
    if const_expr(self.dtype in [cutlass.Float16, cutlass.BFloat16]):
        # 16-bit outputs: row stride padded by 16 elements.
        smem_layout_perm = cute.make_layout(
            perm_shape,
            stride=(self.k_block_size + 16, 1),
        )
    else:
        smem_layout_perm = cute.make_ordered_layout(perm_shape, order=(1, 0))

    @cute.struct
    class SharedStorage:
        # LSE tile and its temperature counterpart share one layout.
        sLSE: cute.struct.Align[
            cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
        ]
        sLSETemperature: cute.struct.Align[
            cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
        ]
        # One Int32 per row of the tile.
        sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.tile_m], 128]
        # Multi-stage O_partial staging buffer.
        sO: cute.struct.Align[
            cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
        ]
        # Output-dtype permutation buffer.
        sO_perm: cute.struct.Align[
            cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_perm)], 128
        ]

    shared_bytes = SharedStorage.size_in_bytes()

    # Grid: (ceil(seqlen * num_head / tile_m), ceil(head_dim / k_block_size), batch).
    # Rows of a tile are flattened (q, head) pairs.
    total_rows = mO_partial.shape[0]
    head_count = mO_partial.shape[3]
    batch_count = (
        mO_partial.shape[4]
        if const_expr(cu_seqlens is None)
        else Int32(cu_seqlens.shape[0] - 1)
    )

    seqlen_divmod = FastDivmodDivisor(total_rows)
    head_divmod = FastDivmodDivisor(head_count)

    launch_grid = (
        cute.ceil_div(total_rows * head_count, self.tile_m),
        cute.ceil_div(self.head_dim, self.k_block_size),
        batch_count,
    )

    self.kernel(
        mO_partial,
        mLSE_partial,
        mO,
        mLSE,
        mLSE_temperature_partial,
        mLSE_temperature,
        cu_seqlens,
        seqused,
        num_splits_dynamic_ptr,
        varlen_batch_idx,
        semaphore_to_reset,
        mSplitCounts,
        mOutputScale,
        qhead_per_kvhead,
        SharedStorage,
        self.smem_layout_lse,
        self.smem_layout_o,
        smem_layout_perm,
        self.gmem_tiled_copy_O_partial,
        self.gmem_tiled_copy_O,
        self.gmem_tiled_copy_LSE,
        self.s2r_tiled_copy_LSE,
        seqlen_divmod,
        head_divmod,
        self.use_pdl,
        is_varlen,
    ).launch(
        grid=launch_grid,
        block=[self.num_threads, 1, 1],
        smem=shared_bytes,
        stream=stream,
        min_blocks_per_mp=self.min_blocks_per_mp,
        use_pdl=self.use_pdl,
    )
455
+
456
@cute.jit
def decode_flat_row_idx(
    self,
    idx: Int32,
    head_divmod: FastDivmodDivisor,
):
    """Split a flattened tile-row index into ``(q_idx_local, head_idx)``.

    Rows follow the H_q-innermost contract: dividing ``idx`` by
    ``head_divmod`` yields the local query index as the quotient and the
    head index as the remainder.
    """
    query_row, head = divmod(idx, head_divmod)
    return query_row, head
465
+
466
+ @cute.kernel
467
+ def kernel(
468
+ self,
469
+ mO_partial: cute.Tensor,
470
+ mLSE_partial: cute.Tensor,
471
+ mO: cute.Tensor,
472
+ mLSE: Optional[cute.Tensor],
473
+ mLSE_temperature_partial: Optional[cute.Tensor],
474
+ mLSE_temperature: Optional[cute.Tensor],
475
+ cu_seqlens: Optional[cute.Tensor],
476
+ seqused: Optional[cute.Tensor],
477
+ num_splits_dynamic_ptr: Optional[cute.Tensor],
478
+ varlen_batch_idx: Optional[cute.Tensor],
479
+ semaphore_to_reset: Optional[cute.Tensor],
480
+ mSplitCounts: Optional[cute.Tensor],
481
+ mOutputScale: Optional[cute.Tensor],
482
+ qhead_per_kvhead: Int32,
483
+ SharedStorage: cutlass.Constexpr,
484
+ smem_layout_lse: cute.Layout | cute.ComposedLayout,
485
+ smem_layout_o: cute.Layout | cute.ComposedLayout,
486
+ smem_layout_perm: cute.Layout,
487
+ gmem_tiled_copy_O_partial: cute.TiledCopy,
488
+ gmem_tiled_copy_O: cute.TiledCopy,
489
+ gmem_tiled_copy_LSE: cute.TiledCopy,
490
+ s2r_tiled_copy_LSE: cute.TiledCopy,
491
+ seqlen_divmod: FastDivmodDivisor,
492
+ head_divmod: FastDivmodDivisor,
493
+ use_pdl: cutlass.Constexpr[bool],
494
+ varlen: cutlass.Constexpr[bool],
495
+ ):
496
+ # Thread and block indices
497
+ tidx, _, _ = cute.arch.thread_idx()
498
+ m_block, k_block, maybe_virtual_batch = cute.arch.block_idx()
499
+
500
+ batch_idx = (
501
+ varlen_batch_idx[maybe_virtual_batch]
502
+ if const_expr(varlen_batch_idx is not None)
503
+ else maybe_virtual_batch
504
+ )
505
+
506
+ # ///////////////////////////////////////////////////////////////////////////////
507
+ # Get shared memory buffer
508
+ # ///////////////////////////////////////////////////////////////////////////////
509
+ smem = cutlass.utils.SmemAllocator()
510
+ storage = smem.allocate(SharedStorage)
511
+ sLSE = storage.sLSE.get_tensor(smem_layout_lse)
512
+ sLSE_temperature = storage.sLSETemperature.get_tensor(smem_layout_lse)
513
+ sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,))
514
+ sO = storage.sO.get_tensor(smem_layout_o)
515
+ sO_perm_buf = storage.sO_perm.get_tensor(smem_layout_perm)
516
+
517
+ # Handle semaphore reset — wait for dependent grids first
518
+ if const_expr(use_pdl and semaphore_to_reset is not None):
519
+ if (
520
+ tidx == 0
521
+ and m_block == cute.arch.grid_dim()[0] - 1
522
+ and k_block == cute.arch.grid_dim()[1] - 1
523
+ and maybe_virtual_batch == cute.arch.grid_dim()[2] - 1
524
+ ):
525
+ cute.arch.griddepcontrol_wait()
526
+ semaphore_to_reset[0] = 0
527
+
528
+ if const_expr(num_splits_dynamic_ptr is not None):
529
+ raise ValueError("K2 combine requires compile-time exact topK")
530
+ num_splits = Int32(self.topk)
531
+ # Handle variable length sequences using SeqlenInfo
532
+ seqlen_info = SeqlenInfo.create(
533
+ batch_idx=batch_idx,
534
+ seqlen_static=mO_partial.shape[0],
535
+ cu_seqlens=cu_seqlens,
536
+ seqused=seqused,
537
+ # Don't need to pass in tile size since we won't use offset_padded
538
+ )
539
+ seqlen, offset = seqlen_info.seqlen, seqlen_info.offset
540
+
541
+ num_head = mO_partial.shape[3]
542
+ max_idx = seqlen * num_head
543
+ output_scale = Float32(1.0)
544
+ if const_expr(mOutputScale is not None):
545
+ output_scale = mOutputScale[0]
546
+
547
+ if const_expr(not varlen) or m_block * self.tile_m < max_idx:
548
+ # Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial)
549
+ if const_expr(use_pdl):
550
+ cute.arch.griddepcontrol_wait()
551
+
552
+ # ===============================
553
+ # Step 1: Load LSE_partial from gmem to shared memory
554
+ # ===============================
555
+ # `cLSE` (identity tensor for row/split coord tracking) is reused
556
+ # later in steps 4-5, so it must be defined on both branches.
557
+ cLSE = cute.make_identity_tensor((self.topk, self.tile_m))
558
+ # Reshape mLSE_partial to PackGQA packed layout and delegate the
559
+ # tile load to PackGQAComb.load_LSE. The packed form folds (H_q, Sq)
560
+ # into one compound dim with H_q innermost (stride 1), so thread
561
+ # rows that vary along h_pos produce one-sector coalesced reads.
562
+ # Non-varlen path only — varlen keeps the original inline loop.
563
+ if const_expr(not varlen):
564
+ mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3)
565
+ # mLSE_partial_cur: (H_q, topK, Sq) — after initial transpose
566
+ # [3,0,2,1] on [topK,B,Sq,H_q] and dropping B.
567
+ # Reorder to (H_q, Sq, topK) then group modes 0..1 for packed dim:
568
+ mLSE_partial_reord = cute.make_tensor(
569
+ mLSE_partial_cur.iterator,
570
+ cute.select(mLSE_partial_cur.layout, mode=[0, 2, 1]),
571
+ )
572
+ mLSE_partial_packed = cute.group_modes(mLSE_partial_reord, 0, 2)
573
+ # shape ((H_q, Sq), topK) with H_q innermost.
574
+ packgqa = PackGQAComb(
575
+ m_block_size=self.tile_m,
576
+ head_dim_padded=0, # unused for LSE load
577
+ check_hdim_oob=False, # unused for LSE load
578
+ qhead_per_kvhead=1, # unused; num_heads_divmod is passed explicitly
579
+ )
580
+ packgqa.load_LSE(
581
+ mLSE_partial_packed,
582
+ sLSE,
583
+ self.topk,
584
+ gmem_tiled_copy_LSE,
585
+ tidx,
586
+ m_block,
587
+ num_splits,
588
+ seqlen,
589
+ head_divmod,
590
+ mSplitCounts,
591
+ batch_idx,
592
+ qhead_per_kvhead,
593
+ )
594
+ if const_expr(mLSE_temperature_partial is not None):
595
+ mLSE_temperature_partial_cur = seqlen_info.offset_batch(
596
+ mLSE_temperature_partial, batch_idx, dim=3)
597
+ mLSE_temperature_partial_reord = cute.make_tensor(
598
+ mLSE_temperature_partial_cur.iterator,
599
+ cute.select(mLSE_temperature_partial_cur.layout, mode=[0, 2, 1]),
600
+ )
601
+ mLSE_temperature_partial_packed = cute.group_modes(
602
+ mLSE_temperature_partial_reord, 0, 2)
603
+ packgqa.load_LSE(
604
+ mLSE_temperature_partial_packed,
605
+ sLSE_temperature,
606
+ self.topk,
607
+ gmem_tiled_copy_LSE,
608
+ tidx,
609
+ m_block,
610
+ num_splits,
611
+ seqlen,
612
+ head_divmod,
613
+ mSplitCounts,
614
+ batch_idx,
615
+ qhead_per_kvhead,
616
+ )
617
+ else:
618
+ # Varlen path keeps the same H_q-innermost flat-row contract:
619
+ # after transpose [1, 0, 2], mLSE_partial_cur is
620
+ # (q_local, split, head).
621
+ # mSplitCounts is the authoritative valid-split count per
622
+ # packed (q_abs, kv_head); masked splits stay at -inf and
623
+ # therefore drop out of the final kernel LSE_out reduction.
624
+ mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3)
625
+ mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,))
626
+ gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx)
627
+ tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE)
628
+ tLSEsLSE_temperature = gmem_thr_copy_LSE.partition_D(sLSE_temperature)
629
+ tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE)
630
+ if const_expr(mLSE_temperature_partial is not None):
631
+ mLSE_temperature_partial_cur = seqlen_info.offset_batch(
632
+ mLSE_temperature_partial, batch_idx, dim=3)
633
+ mLSE_temperature_partial_copy = cute.tiled_divide(
634
+ mLSE_temperature_partial_cur, (1,))
635
+
636
+ for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
637
+ mi = tLSEcLSE[0, 0, m][1]
638
+ idx = m_block * self.tile_m + mi
639
+ if idx < max_idx:
640
+ m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod)
641
+ row_count = (
642
+ mSplitCounts[offset + m_idx, head_idx // qhead_per_kvhead]
643
+ if const_expr(mSplitCounts is not None)
644
+ else num_splits
645
+ )
646
+ mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx]
647
+ if const_expr(mLSE_temperature_partial is not None):
648
+ mLSE_temperature_partial_cur_copy = (
649
+ mLSE_temperature_partial_copy[None, m_idx, None, head_idx])
650
+ for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
651
+ si = tLSEcLSE[0, s, 0][0]
652
+ if si < num_splits and si < row_count:
653
+ cute.copy(
654
+ gmem_thr_copy_LSE,
655
+ mLSE_partial_cur_copy[None, si],
656
+ tLSEsLSE[None, s, m],
657
+ )
658
+ if const_expr(mLSE_temperature_partial is not None):
659
+ cute.copy(
660
+ gmem_thr_copy_LSE,
661
+ mLSE_temperature_partial_cur_copy[None, si],
662
+ tLSEsLSE_temperature[None, s, m],
663
+ )
664
+ else:
665
+ tLSEsLSE[None, s, m].fill(-Float32.inf)
666
+ if const_expr(mLSE_temperature_partial is not None):
667
+ tLSEsLSE_temperature[None, s, m].fill(-Float32.inf)
668
+ else:
669
+ for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
670
+ tLSEsLSE[None, s, m].fill(-Float32.inf)
671
+ if const_expr(mLSE_temperature_partial is not None):
672
+ tLSEsLSE_temperature[None, s, m].fill(-Float32.inf)
673
+ cute.arch.cp_async_commit_group()
674
+
675
+ # ===============================
676
+ # Step 2: Load O_partial for pipeline stages
677
+ # ===============================
678
+
679
+ gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx)
680
+ cO = cute.make_identity_tensor((self.tile_m, self.k_block_size))
681
+ tOcO = gmem_thr_copy_O_partial.partition_D(cO)
682
+ tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO)
683
+ mO_partial_cur = seqlen_info.offset_batch(mO_partial, batch_idx, dim=4)
684
+
685
+ # Precompute per-row values for flattened (q_local, head) tiles.
686
+ num_rows = const_expr(cute.size(tOcO, mode=[1]))
687
+ tOmidx = cute.make_rmem_tensor(num_rows, cutlass.Int32)
688
+ tOhidx = cute.make_rmem_tensor(num_rows, cutlass.Int32)
689
+ tOSplitCount = cute.make_rmem_tensor(num_rows, cutlass.Int32)
690
+ tOrOptr = cute.make_rmem_tensor(num_rows, cutlass.Int64)
691
+ for m in cutlass.range(num_rows, unroll_full=True):
692
+ mi = tOcO[0, m, 0][0] # m coordinate in tile
693
+ idx = m_block * self.tile_m + mi
694
+ if idx >= max_idx:
695
+ tOhidx[m] = -1
696
+ tOmidx[m] = 0
697
+ tOSplitCount[m] = 0
698
+ tOrOptr[m] = cutlass.Int64(0)
699
+ else:
700
+ tOmidx[m], tOhidx[m] = self.decode_flat_row_idx(idx, head_divmod)
701
+ if const_expr(mSplitCounts is None):
702
+ tOSplitCount[m] = num_splits
703
+ elif const_expr(cu_seqlens is None):
704
+ tOSplitCount[m] = mSplitCounts[
705
+ batch_idx, tOmidx[m], tOhidx[m] // qhead_per_kvhead
706
+ ]
707
+ else:
708
+ tOSplitCount[m] = mSplitCounts[
709
+ offset + tOmidx[m], tOhidx[m] // qhead_per_kvhead
710
+ ]
711
+ tOrOptr[m] = utils.elem_pointer(
712
+ mO_partial_cur,
713
+ (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m]),
714
+ ).toint()
715
+
716
+ tOpO = None
717
+ if const_expr(not self.is_even_k):
718
+ tOpO = cute.make_rmem_tensor(cute.size(tOcO, mode=[2]), Boolean)
719
+ for k in cutlass.range(cute.size(tOpO), unroll_full=True):
720
+ tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size
721
+ # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO)
722
+
723
+ load_O_partial = partial(
724
+ self.load_O_partial,
725
+ gmem_tiled_copy_O_partial,
726
+ tOrOptr,
727
+ tOsO_partial,
728
+ tOhidx,
729
+ tOSplitCount,
730
+ tOpO,
731
+ tOcO,
732
+ mO_partial_cur.layout,
733
+ )
734
+
735
+ # Load first few stages of O_partial
736
+ for stage in cutlass.range(self.stages - 1, unroll_full=True):
737
+ if stage < num_splits:
738
+ load_O_partial(stage, stage)
739
+ cute.arch.cp_async_commit_group()
740
+
741
+ # ===============================
742
+ # Step 3: Load and transpose LSE from smem to registers
743
+ # ===============================
744
+
745
+ # Wait for LSE and initial O partial stages to complete
746
+ cute.arch.cp_async_wait_group(self.stages - 1)
747
+ cute.arch.sync_threads()
748
+ # if cute.arch.thread_idx()[0] == 0:
749
+ # # cute.print_tensor(sLSE)
750
+ # for i in range(64):
751
+ # cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0])
752
+ # cute.arch.sync_threads()
753
+
754
+ s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx)
755
+ ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE)
756
+ ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE)
757
+ cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE)
758
+ if const_expr(mLSE_temperature_partial is not None):
759
+ ts2rsLSE_temperature = s2r_thr_copy_LSE.partition_S(sLSE_temperature)
760
+ ts2rrLSE_temperature = cute.make_rmem_tensor_like(ts2rsLSE_temperature)
761
+ cute.copy(
762
+ s2r_tiled_copy_LSE,
763
+ ts2rsLSE_temperature,
764
+ ts2rrLSE_temperature,
765
+ )
766
+
767
+ # ===============================
768
+ # Step 4: Compute final LSE along split dimension
769
+ # ===============================
770
+
771
+ final_lse = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32)
772
+ ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE)
773
+ # We compute the max valid split for each row to short-circuit the computation later
774
+ max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32)
775
+ assert cute.size(ts2rrLSE, mode=[0]) == 1
776
+ # Compute max, scales, and final LSE for each row. Invalid splits
777
+ # have already been filled with -inf, so Step 5 can write the
778
+ # kernel-native LSE_out directly.
779
+ for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
780
+ # Find max LSE value across splits
781
+ threads_per_col = const_expr(self.smem_threads_per_col_lse)
782
+ lse_max = cute.arch.warp_reduction_max(
783
+ ts2rrLSE[None, None, m]
784
+ .load()
785
+ .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
786
+ threads_in_group=threads_per_col,
787
+ )
788
+ # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max)
789
+ # Find max valid split index
790
+ max_valid_idx = -1
791
+ for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
792
+ if ts2rrLSE[0, s, m] != -Float32.inf:
793
+ max_valid_idx = ts2rcLSE[0, s, 0][0] # Get split coordinate
794
+ # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx)
795
+ max_valid_split[m] = cute.arch.warp_reduction_max(
796
+ max_valid_idx, threads_in_group=threads_per_col
797
+ )
798
+ # Compute exp scales and sum
799
+ lse_max_cur = (
800
+ 0.0 if lse_max == -Float32.inf else lse_max
801
+ ) # In case all local LSEs are -inf
802
+ LOG2_E = math.log2(math.e)
803
+ lse_sum_cur = 0.0
804
+ for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
805
+ scale = cute.math.exp2(
806
+ ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E), fastmath=True
807
+ )
808
+ lse_sum_cur += scale
809
+ ts2rrLSE[0, s, m] = scale # Store scale for later use
810
+ lse_sum_cur = cute.arch.warp_reduction_sum(
811
+ lse_sum_cur, threads_in_group=threads_per_col
812
+ )
813
+ # Normalize scales
814
+ inv_sum = 0.0
815
+ if max_valid_split[m] < 0 or lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur:
816
+ final_lse[m] = -Float32.inf
817
+ else:
818
+ final_lse[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max
819
+ inv_sum = 1.0 / lse_sum_cur
820
+ ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)
821
+ # Store the scales exp(lse - lse_logsum) back to smem
822
+ cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)
823
+
824
+ if const_expr(mLSE_temperature_partial is not None):
825
+ final_lse_temperature = cute.make_rmem_tensor(
826
+ cute.size(ts2rrLSE_temperature, mode=[2]), Float32)
827
+ for m in cutlass.range(cute.size(ts2rrLSE_temperature, mode=[2]), unroll_full=True):
828
+ threads_per_col = const_expr(self.smem_threads_per_col_lse)
829
+ lse_temperature_max = cute.arch.warp_reduction_max(
830
+ ts2rrLSE_temperature[None, None, m]
831
+ .load()
832
+ .reduce(
833
+ cute.ReductionOp.MAX,
834
+ init_val=-Float32.inf,
835
+ reduction_profile=0,
836
+ ),
837
+ threads_in_group=threads_per_col,
838
+ )
839
+ lse_temperature_max_cur = (
840
+ 0.0 if lse_temperature_max == -Float32.inf else lse_temperature_max
841
+ )
842
+ LOG2_E = math.log2(math.e)
843
+ lse_temperature_sum_cur = 0.0
844
+ for s in cutlass.range(
845
+ cute.size(ts2rrLSE_temperature, mode=[1]), unroll_full=True):
846
+ scale = cute.math.exp2(
847
+ ts2rrLSE_temperature[0, s, m] * LOG2_E
848
+ - (lse_temperature_max_cur * LOG2_E),
849
+ fastmath=True,
850
+ )
851
+ lse_temperature_sum_cur += scale
852
+ lse_temperature_sum_cur = cute.arch.warp_reduction_sum(
853
+ lse_temperature_sum_cur, threads_in_group=threads_per_col
854
+ )
855
+ if (
856
+ max_valid_split[m] < 0
857
+ or lse_temperature_sum_cur == 0.0
858
+ or lse_temperature_sum_cur != lse_temperature_sum_cur
859
+ ):
860
+ final_lse_temperature[m] = -Float32.inf
861
+ else:
862
+ final_lse_temperature[m] = (
863
+ cute.math.log(lse_temperature_sum_cur, fastmath=True)
864
+ + lse_temperature_max
865
+ )
866
+
867
+ # Store max valid split to smem
868
+ for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
869
+ if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
870
+ mi = ts2rcLSE[0, 0, m][1]
871
+ if mi < self.tile_m:
872
+ sMaxValidSplit[mi] = max_valid_split[m]
873
+
874
+ # ===============================
875
+ # Step 5: Store final LSE to gmem
876
+ # This writeback is the authoritative LSE_out returned by the
877
+ # public Sparse Attention / Sparse Page Attention interface.
878
+ # ===============================
879
+
880
+ if const_expr(mLSE is not None):
881
+ if const_expr(cu_seqlens is None):
882
+ mLSE_cur = mLSE[None, None, batch_idx]
883
+ else:
884
+ mLSE_cur = cute.domain_offset((offset, 0), mLSE)
885
+ if const_expr(mLSE_temperature is not None):
886
+ if const_expr(cu_seqlens is None):
887
+ mLSE_temperature_cur = mLSE_temperature[None, None, batch_idx]
888
+ else:
889
+ mLSE_temperature_cur = cute.domain_offset(
890
+ (offset, 0), mLSE_temperature)
891
+ if k_block == 0: # Only first k_block writes LSE when mLSE is provided
892
+ for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
893
+ if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
894
+ mi = ts2rcLSE[0, 0, m][1]
895
+ idx = m_block * self.tile_m + mi
896
+ if idx < max_idx:
897
+ m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod)
898
+ mLSE_cur[m_idx, head_idx] = final_lse[m]
899
+ if const_expr(mLSE_temperature is not None):
900
+ mLSE_temperature_cur[m_idx, head_idx] = (
901
+ final_lse_temperature[m])
902
+
903
+ # ===============================
904
+ # Step 6: Read O_partial and accumulate final O
905
+ # ===============================
906
+
907
+ cute.arch.sync_threads()
908
+
909
+ # Get max valid split for this thread
910
+ thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]]
911
+ for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True):
912
+ thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]])
913
+
914
+ tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0])
915
+ tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32)
916
+ tOrO.fill(0.0)
917
+
918
+ stage_load = self.stages - 1
919
+ stage_compute = 0
920
+
921
+ # Main accumulation loop
922
+ for s in cutlass.range(thr_max_valid_split + 1, unroll=4):
923
+ # Get scales for this split
924
+ scale = cute.make_rmem_tensor(num_rows, Float32)
925
+ for m in cutlass.range(num_rows, unroll_full=True):
926
+ scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem
927
+
928
+ # Load next stage if needed
929
+ split_to_load = s + self.stages - 1
930
+ if split_to_load <= thr_max_valid_split:
931
+ load_O_partial(split_to_load, stage_load)
932
+ cute.arch.cp_async_commit_group()
933
+ stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1
934
+
935
+ # Wait for the current stage to be ready
936
+ cute.arch.cp_async_wait_group(self.stages - 1)
937
+ # We don't need __syncthreads() because each thread is just reading its own data from smem
938
+ # Copy from smem to registers
939
+ cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial)
940
+ stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1
941
+
942
+ # Accumulate scaled partial results
943
+ for m in cutlass.range(num_rows, unroll_full=True):
944
+ if tOhidx[m] >= 0 and scale[m] > 0.0:
945
+ tOrO[None, m, None].store(
946
+ tOrO[None, m, None].load()
947
+ + scale[m] * tOrO_partial[None, m, None].load().to(Float32)
948
+ )
949
+
950
+ # Flush any outstanding async-copy groups before the local Step-7
951
+ # permutation buffer is read on the tail of the kernel.
952
+ cute.arch.cp_async_wait_group(0)
953
+ cute.arch.sync_threads()
954
+
955
+ # ===============================
956
+ # Step 7: Write final O to gmem (fake→real via SMEM)
957
+ # ===============================
958
+
959
+ mO_cur = seqlen_info.offset_batch(mO, batch_idx, dim=3)
960
+ if const_expr(cu_seqlens is None):
961
+ mO_cur = mO[None, None, None, batch_idx]
962
+ else:
963
+ mO_cur = cute.domain_offset((offset, 0, 0), mO)
964
+ mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur)
965
+ num_vals = const_expr(cute.size(tOcO, mode=[0]))
966
+ if const_expr(not use_pdl):
967
+ # Direct / standalone calls don't participate in the K1->K2
968
+ # dependency chain. Use a simple per-element real-column store
969
+ # path here to keep mixed-shape launches stable.
970
+ for m in cutlass.range(num_rows, unroll_full=True):
971
+ if tOhidx[m] >= 0:
972
+ for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
973
+ if const_expr(self.is_even_k) or tOpO[k]:
974
+ for v in cutlass.range(num_vals, unroll_full=True):
975
+ fake_col = tOcO[v, 0, k][1]
976
+ if const_expr(self.use_stg128_fp8_layout):
977
+ real_col = stg128_fp8_fake_col_to_real_col(fake_col)
978
+ elif const_expr(self.use_stg128_half_layout):
979
+ real_col = stg128_half_fake_col_to_real_col(fake_col)
980
+ else:
981
+ real_col = stg128_fake_col_to_real_col(fake_col)
982
+ o_val = tOrO[v, m, k]
983
+ if const_expr(mOutputScale is not None):
984
+ o_val = o_val * output_scale
985
+ mO_cur[tOmidx[m], real_col, tOhidx[m]] = o_val.to(self.dtype)
986
+ else:
987
+ # 7a: fp32 accumulator -> output dtype SMEM with fake→real
988
+ # permutation. The dedicated permutation buffer stays separate
989
+ # from the O_partial pipeline staging buffer.
990
+ sO_perm = sO_perm_buf
991
+
992
+ if const_expr(self.dtype in [cutlass.BFloat16, cutlass.Float16]):
993
+ # O_partial uses a dtype-specific STG.128 fake layout, but
994
+ # sO_perm is in the final O dtype. For all supported fake
995
+ # layouts, adjacent fake pairs map to adjacent real columns,
996
+ # so write the final BF16/F16 O pair as one 32-bit SMEM store.
997
+ assert num_vals % 2 == 0
998
+ r2s_o_pair_atom = cute.make_copy_atom(
999
+ cute.nvgpu.CopyUniversalOp(),
1000
+ cutlass.Int32,
1001
+ num_bits_per_copy=32,
1002
+ )
1003
+ rO_pair_word = cute.make_rmem_tensor((1,), cutlass.Int32)
1004
+ sO_perm_i32_base = cute.make_ptr(
1005
+ dtype=cutlass.Int32,
1006
+ value=sO_perm.iterator.toint(),
1007
+ mem_space=sO_perm.iterator.memspace,
1008
+ assumed_align=4,
1009
+ )
1010
+ sO_perm_i32_row_stride = Int32((self.k_block_size + 16) // 2)
1011
+ for m in cutlass.range(num_rows, unroll_full=True):
1012
+ row_local = tOcO[0, m, 0][0]
1013
+ if tOhidx[m] >= 0:
1014
+ for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
1015
+ for v_pair in cutlass.range(num_vals // 2, unroll_full=True):
1016
+ v = v_pair * 2
1017
+ fake_col = tOcO[v, 0, k][1]
1018
+ if const_expr(self.use_stg128_fp8_layout):
1019
+ real_col = stg128_fp8_fake_col_to_real_col(fake_col)
1020
+ elif const_expr(self.use_stg128_half_layout):
1021
+ real_col = stg128_half_fake_col_to_real_col(fake_col)
1022
+ else:
1023
+ real_col = stg128_fake_col_to_real_col(fake_col)
1024
+ o0 = tOrO[v, m, k]
1025
+ o1 = tOrO[v + 1, m, k]
1026
+ if const_expr(mOutputScale is not None):
1027
+ o0, o1 = cute.arch.mul_packed_f32x2(
1028
+ (o0, o1),
1029
+ (output_scale, output_scale),
1030
+ )
1031
+ rO_pair_word[0] = utils.cvt_f16x2_f32(o0, o1, self.dtype)
1032
+ smem_pair_ptr = cute.make_ptr(
1033
+ dtype=cutlass.Int32,
1034
+ value=(
1035
+ sO_perm_i32_base.toint()
1036
+ + Int64(
1037
+ row_local * sO_perm_i32_row_stride
1038
+ + real_col // Int32(2)
1039
+ )
1040
+ * Int64(4)
1041
+ ),
1042
+ mem_space=sO_perm.iterator.memspace,
1043
+ assumed_align=4,
1044
+ )
1045
+ sO_pair = cute.make_tensor(
1046
+ smem_pair_ptr,
1047
+ cute.make_layout((1,), stride=(1,)),
1048
+ )
1049
+ cute.copy(r2s_o_pair_atom, rO_pair_word, sO_pair)
1050
+ else:
1051
+ # 7a: iterate over ALL val elements in mode[0].
1052
+ # tOcO[v, m, k][1] gives different fake_col for each v.
1053
+ r2s_o_scalar_atom = cute.make_copy_atom(
1054
+ cute.nvgpu.CopyUniversalOp(),
1055
+ self.dtype,
1056
+ num_bits_per_copy=self.dtype.width,
1057
+ )
1058
+ rO_scalar = cute.make_rmem_tensor((1,), self.dtype)
1059
+ for m in cutlass.range(num_rows, unroll_full=True):
1060
+ row_local = tOcO[0, m, 0][0]
1061
+ if tOhidx[m] >= 0:
1062
+ for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
1063
+ for v in cutlass.range(num_vals, unroll_full=True):
1064
+ fake_col = tOcO[v, 0, k][1]
1065
+ if const_expr(self.use_stg128_fp8_layout):
1066
+ real_col = stg128_fp8_fake_col_to_real_col(fake_col)
1067
+ elif const_expr(self.use_stg128_half_layout):
1068
+ real_col = stg128_half_fake_col_to_real_col(fake_col)
1069
+ else:
1070
+ real_col = stg128_fake_col_to_real_col(fake_col)
1071
+ o_val = tOrO[v, m, k]
1072
+ if const_expr(mOutputScale is not None):
1073
+ o_val = o_val * output_scale
1074
+ rO_scalar[0] = o_val.to(self.dtype)
1075
+ smem_ptr = utils.elem_pointer(sO_perm, (row_local, real_col))
1076
+ smem_scalar_ptr = cute.make_ptr(
1077
+ dtype=self.dtype,
1078
+ value=smem_ptr.toint(),
1079
+ mem_space=sO_perm.iterator.memspace,
1080
+ assumed_align=self.dtype.width // 8,
1081
+ )
1082
+ sO_scalar = cute.make_tensor(
1083
+ smem_scalar_ptr,
1084
+ cute.make_layout((1,), stride=(1,)),
1085
+ )
1086
+ cute.copy(r2s_o_scalar_atom, rO_scalar, sO_scalar)
1087
+
1088
+ cute.arch.sync_threads()
1089
+
1090
+ # 7b: SMEM (real order, output dtype) → GMEM
1091
+ gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
1092
+ tOcO_store = gmem_thr_copy_O.partition_D(cO)
1093
+ tOsO_store = gmem_thr_copy_O.partition_D(sO_perm)
1094
+ rO = cute.make_rmem_tensor(tOcO_store.shape, self.dtype)
1095
+ elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1]))
1096
+ num_store_rows = const_expr(cute.size(tOcO_store, mode=[1]))
1097
+ num_store_vals = const_expr(cute.size(tOcO_store, mode=[0]))
1098
+ tOpO_store = None
1099
+ if const_expr(not self.is_even_k):
1100
+ tOpO_store = cute.make_rmem_tensor(cute.size(tOcO_store, mode=[2]), Boolean)
1101
+ for k in cutlass.range(cute.size(tOpO_store), unroll_full=True):
1102
+ tOpO_store[k] = (
1103
+ tOcO_store[0, 0, k][1]
1104
+ < mO_partial.shape[1] - k_block * self.k_block_size
1105
+ )
1106
+
1107
+ # Read output dtype from SMEM (now in real column order).
1108
+ for m in cutlass.range(num_store_rows, unroll_full=True):
1109
+ for k in cutlass.range(cute.size(tOcO_store, mode=[2]), unroll_full=True):
1110
+ if const_expr(self.is_even_k) or tOpO_store[k]:
1111
+ cute.autovec_copy(tOsO_store[None, m, k], rO[None, m, k])
1112
+
1113
+ # Write bf16 to GMEM using gmem_tiled_copy_O (same as original FA Step 7)
1114
+ for m in cutlass.range(num_store_rows, unroll_full=True):
1115
+ row_local = tOcO_store[0, m, 0][0]
1116
+ idx = m_block * self.tile_m + row_local
1117
+ if idx < max_idx:
1118
+ m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod)
1119
+ mO_cur_copy = cute.tiled_divide(
1120
+ mO_cur[m_idx, None, head_idx], (elems_per_store,)
1121
+ )
1122
+ for k in cutlass.range(cute.size(tOcO_store, mode=[2]), unroll_full=True):
1123
+ k_idx = tOcO_store[0, 0, k][1] // elems_per_store
1124
+ if const_expr(self.is_even_k) or tOpO_store[k]:
1125
+ cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx])
1126
+
1127
@cute.jit
def load_O_partial(
    self,
    gmem_tiled_copy_O_partial: cute.TiledCopy,
    tOrOptr: cute.Tensor,
    tOsO_partial: cute.Tensor,
    tOhidx: cute.Tensor,
    tOSplitCount: cute.Tensor,
    tOpO: Optional[cute.Tensor],
    tOcO: cute.Tensor,
    mO_cur_partial_layout: cute.Layout,
    split: Int32,
    stage: Int32,
) -> None:
    """Issue this thread's copies of one O_partial split into stage ``stage``.

    For each row the thread owns (mode 1 of ``tOcO``) whose ``tOhidx`` entry
    is non-negative, a global-memory tensor is rebuilt from the address
    stored in ``tOrOptr`` and every column chunk (mode 2 of ``tOcO``) is
    either copied with ``gmem_tiled_copy_O_partial`` or zero-filled in the
    destination stage buffer.

    A chunk is copied only when ``split < tOSplitCount[m]`` and, if a column
    predicate ``tOpO`` was supplied, ``tOpO[k]`` is true; otherwise the
    destination chunk is filled with zeros. Rows with a negative ``tOhidx``
    are skipped entirely (neither copied nor zeroed).

    Args:
        gmem_tiled_copy_O_partial: Tiled copy used for the transfer.
        tOrOptr: Per-row integer addresses, reinterpreted as gmem pointers.
        tOsO_partial: Destination partition; its last mode is the stage.
        tOhidx: Per-row head index; negative marks a row to skip.
        tOSplitCount: Per-row number of valid splits.
        tOpO: Optional per-chunk column predicate (``None`` disables it).
        tOcO: Identity-coordinate partition giving row/column coordinates.
        mO_cur_partial_layout: Layout sliced to rebuild the per-row tensor.
        split: Split index to load.
        stage: Pipeline stage to write into.
    """
    vec_width = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1]))
    row_count = const_expr(cute.size(tOcO, mode=[1]))
    stage_dst = tOsO_partial[None, None, None, stage]
    for m in cutlass.range(row_count, unroll_full=True):
        if tOhidx[m] >= 0:
            # Rebuild a (column, split) view rooted at this row's precomputed address.
            row_ptr = cute.make_ptr(
                dtype=tOsO_partial.element_type,
                value=tOrOptr[m],
                mem_space=cute.AddressSpace.gmem,
                assumed_align=16,
            )
            row_view = cute.make_tensor(
                row_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))
            )
            row_vectors = cute.tiled_divide(row_view, (vec_width,))
            for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
                # Column coordinate of this chunk, expressed in copy-vector units.
                vec_idx = tOcO[0, 0, k][1] // vec_width
                if split < tOSplitCount[m] and (const_expr(tOpO is None) or tOpO[k]):
                    cute.copy(
                        gmem_tiled_copy_O_partial,
                        row_vectors[None, vec_idx, split],
                        stage_dst[None, m, k],
                    )
                else:
                    # Past this row's split count, or a masked-out column chunk.
                    stage_dst[None, m, k].fill(0)
1162
+
1163
+
1164
def _get_cutlass_dtype(torch_dtype: torch.dtype):
    """Return the ``torch2cute_dtype_map`` entry for ``torch_dtype``.

    Raises:
        TypeError: If ``torch_dtype`` has no entry in ``torch2cute_dtype_map``.
    """
    if torch_dtype in torch2cute_dtype_map:
        return torch2cute_dtype_map[torch_dtype]
    raise TypeError(f"Unsupported dtype: {torch_dtype}")
1168
+
1169
+
1170
# Module-level memo of compiled/AOT-loaded combine kernels; combine() indexes it with its specialization key tuple.
+ _combine_compile_cache = {}
1171
+
1172
+
1173
def _get_cpasync_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout:
    """Build the swizzled layout atom for rows of ``k_dim`` elements of ``dtype``.

    The atom width is the widest of 128/64/32 bytes that evenly divides one
    row (16 bytes otherwise), converted to elements. The swizzle bit count is
    then selected from that element count, and the swizzle base/shift from
    the element size in bytes. The atom has 8 rows when ``k_dim`` is a
    multiple of 32 and 16 rows otherwise.

    Args:
        dtype: Element type; only ``dtype.width`` (bits) is read.
        k_dim: Number of elements per row.

    Returns:
        A composed layout: swizzle, zero offset, row-major ordered atom.
    """
    elem_bytes = const_expr(dtype.width // 8)
    row_bytes = const_expr(k_dim * elem_bytes)
    # Widest candidate byte width dividing a row; falls back to 16 bytes.
    atom_bytes = const_expr(
        next((width for width in (128, 64, 32) if row_bytes % width == 0), 16)
    )
    smem_k_block_size = atom_bytes // elem_bytes
    # Both lookups mirror the original selection chains, including their fallbacks.
    swizzle_bits = {128: 4, 64: 3, 32: 2}.get(smem_k_block_size, 1)
    swizzle_base = {4: 2, 2: 3}.get(elem_bytes, 4)
    atom_rows = 8 if const_expr(k_dim % 32 == 0) else 16
    swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base)
    atom_layout = cute.make_ordered_layout((atom_rows, smem_k_block_size), order=(1, 0))
    return cute.make_composed_layout(swizzle, 0, atom_layout)
1198
+
1199
+
1200
def combine(
    o_partial_fake,
    lse_partial,
    o_out,
    lse_out,
    *,
    lse_temperature_partial=None,
    lse_temperature_out=None,
    cu_seqlens=None,
    seqused=None,
    split_counts=None,
    output_scale=None,
    use_pdl=False,
):
    """K2: merge sparse forward split partials into the final output.

    STG.128 fake-layout handling remains an internal implementation detail.
    When lse_out is provided, the kernel writes the final authoritative
    log-sum-exp for each query row/head directly into that tensor.

    The kernel is compiled (or loaded from the AOT cache) once per
    specialization key and memoized in ``_combine_compile_cache``. All
    host-side argument validation, including the check that tensor ranks
    agree with the layout selected by ``cu_seqlens``, runs before dtype
    resolution, compilation, or launch.

    Args:
        o_partial_fake:
            Batched: [num_splits, batch, Sq, head_q, dim]
            Varlen: [num_splits, total_q, head_q, dim]
        lse_partial:
            Batched: [num_splits, batch, Sq, head_q]
            Varlen: [num_splits, total_q, head_q]
        o_out:
            Batched: [batch, Sq, head_q, dim]
            Varlen: [total_q, head_q, dim]
        lse_out:
            Batched: [batch, Sq, head_q]
            Varlen: [total_q, head_q]
        lse_temperature_partial:
            Optional temperature-scaled LSE partial with the same shape as
            lse_partial.
        lse_temperature_out:
            Optional temperature-scaled final LSE with the same shape as
            lse_out.
        cu_seqlens: Optional [batch + 1] int32 for varlen-Q combine. Its
            presence selects the varlen layout for every tensor above.
        seqused: Optional [batch] int32 effective lengths for combine.
        split_counts: Optional int32 rowwise valid split counts prepared from
            q2k metadata. Batched: [batch, seqlen, head_kv]. Varlen:
            [total_q, head_kv].
        output_scale: Optional fp32 tensor with at least one element. When
            provided, the final O accumulator is multiplied once before store.
        use_pdl: When True, wait on PDL dependencies from the producer K1
            kernel. When False, launch without PDL waits.

    Raises:
        ValueError: If the temperature tensors are not supplied as a pair or
            have mismatched shapes, if ``split_counts`` / ``cu_seqlens`` /
            ``seqused`` / ``output_scale`` fail their shape, device or
            contiguity checks, or if the ranks of ``o_out`` and
            ``o_partial_fake`` do not match the layout selected by
            ``cu_seqlens``.
        TypeError: If a tensor has an unsupported or unexpected dtype.
    """
    D = o_partial_fake.shape[-1]
    num_splits = o_partial_fake.shape[0]
    return_temperature_lse = (
        lse_temperature_partial is not None or lse_temperature_out is not None
    )
    if (lse_temperature_partial is None) != (lse_temperature_out is None):
        raise ValueError(
            "lse_temperature_partial and lse_temperature_out must either both be provided or both be None"
        )
    if lse_temperature_partial is not None and lse_temperature_partial.shape != lse_partial.shape:
        raise ValueError(
            "lse_temperature_partial must have the same shape as lse_partial, "
            f"got {lse_temperature_partial.shape} vs {lse_partial.shape}"
        )
    if lse_temperature_out is not None:
        if lse_out is None:
            raise ValueError("lse_temperature_out requires lse_out")
        if lse_temperature_out.shape != lse_out.shape:
            raise ValueError(
                "lse_temperature_out must have the same shape as lse_out, "
                f"got {lse_temperature_out.shape} vs {lse_out.shape}"
            )
        # The pairing check above guarantees lse_temperature_partial is set here.
        if lse_temperature_out.dtype != torch.float32 or lse_temperature_partial.dtype != torch.float32:
            raise TypeError("temperature LSE tensors must be torch.float32")

    if output_scale is not None:
        if output_scale.dtype != torch.float32:
            raise TypeError(f"output_scale must be torch.float32, got {output_scale.dtype}")
        if output_scale.numel() < 1:
            raise ValueError("output_scale must contain at least one element")
        if output_scale.device != o_out.device:
            raise ValueError("output_scale must be on the same device as o_out")
        output_scale = output_scale.contiguous()
    if split_counts is not None:
        if split_counts.dtype != torch.int32:
            raise TypeError(f"split_counts must be torch.int32, got {split_counts.dtype}")
        if o_out.ndim == 4:
            if split_counts.ndim != 3:
                raise ValueError(
                    f"batched split_counts must have shape [batch, seqlen, head_kv], got {split_counts.shape}"
                )
            if split_counts.shape[:2] != o_out.shape[:2]:
                raise ValueError(
                    f"split_counts shape {split_counts.shape} must match batch/seqlen of o_out {o_out.shape}"
                )
        else:
            if cu_seqlens is None:
                raise ValueError("split_counts with varlen output requires cu_seqlens")
            if split_counts.ndim != 2:
                raise ValueError(
                    f"varlen split_counts must have shape [total_q, head_kv], got {split_counts.shape}"
                )
            if split_counts.shape[0] != o_out.shape[0]:
                raise ValueError(
                    f"split_counts total_q ({split_counts.shape[0]}) must match o_out total_q "
                    f"({o_out.shape[0]})"
                )
        if o_out.shape[-2] % split_counts.shape[-1] != 0:
            raise ValueError(
                f"o_out heads ({o_out.shape[-2]}) must be divisible by split_counts heads ({split_counts.shape[-1]})"
            )
        qheadperkv = o_out.shape[-2] // split_counts.shape[-1]
    else:
        qheadperkv = 1
    if cu_seqlens is not None:
        if cu_seqlens.dtype != torch.int32:
            raise TypeError(f"cu_seqlens must be torch.int32, got {cu_seqlens.dtype}")
        if cu_seqlens.ndim != 1:
            raise ValueError(f"cu_seqlens must be rank-1, got {cu_seqlens.shape}")
        if not cu_seqlens.is_contiguous():
            raise ValueError("cu_seqlens must be contiguous")
    if seqused is not None:
        if seqused.dtype != torch.int32:
            raise TypeError(f"seqused must be torch.int32, got {seqused.dtype}")
        if seqused.ndim != 1:
            raise ValueError(f"seqused must be rank-1, got {seqused.shape}")
        if not seqused.is_contiguous():
            raise ValueError("seqused must be contiguous")

    # The kernel below is specialized purely on whether cu_seqlens is given:
    # rank-3 o_out / rank-4 o_partial_fake for varlen, rank-4 / rank-5 for
    # batched. Reject tensors whose rank disagrees with that choice here,
    # before a kernel is compiled for a layout the inputs cannot match.
    out_rank = 3 if cu_seqlens is not None else 4
    if o_out.ndim != out_rank or o_partial_fake.ndim != out_rank + 1:
        layout_name = (
            "varlen (cu_seqlens given)" if cu_seqlens is not None else "batched (cu_seqlens=None)"
        )
        raise ValueError(
            f"{layout_name} combine expects o_out rank {out_rank} and o_partial_fake rank "
            f"{out_rank + 1}, got o_out {tuple(o_out.shape)} and o_partial_fake "
            f"{tuple(o_partial_fake.shape)}"
        )

    partial_dtype = _get_cutlass_dtype(o_partial_fake.dtype)
    out_dtype = _get_cutlass_dtype(o_out.dtype)

    k_block_size = 128 if D > 64 else 64
    tile_m = 64
    has_cu_seqlens = cu_seqlens is not None
    has_seqused = seqused is not None
    has_lse = lse_out is not None
    has_split_counts = split_counts is not None
    has_output_scale = output_scale is not None
    min_blocks_per_mp = 3 if has_output_scale and use_pdl else 0

    # Everything that changes the generated kernel must appear in the key;
    # qheadperkv is a runtime argument and is deliberately not part of it.
    key = (
        "combine",
        D,
        k_block_size,
        tile_m,
        num_splits,
        partial_dtype,
        out_dtype,
        has_cu_seqlens,
        has_seqused,
        has_lse,
        bool(return_temperature_lse),
        has_split_counts,
        has_output_scale,
        use_pdl,
        min_blocks_per_mp,
    )
    if key not in _combine_compile_cache:
        from ....src.common.aot_cache import try_load_aot, save_aot

        loaded = try_load_aot(key)
        if loaded is not None:
            _combine_compile_cache[key] = loaded
        else:
            from ....quack.compile_utils import make_fake_tensor

            kernel = SparseAttentionForwardCombine(
                dtype=out_dtype,
                dtype_partial=partial_dtype,
                head_dim=D,
                tile_m=tile_m,
                k_block_size=k_block_size,
                topk=num_splits,
                use_pdl=use_pdl,
                min_blocks_per_mp=min_blocks_per_mp,
                # stages=2 halves per-block SMEM (168 KB -> 103 KB) -> 2 blocks/SM,
                # theoretical occupancy 12.5% -> 25%. NCU DRAM throughput 76.35%
                # -> 88.64%. Runtime latency within noise (kernel already at HBM
                # bandwidth ceiling in practice) but the cleaner SOL profile
                # matters for downstream NCU comparison.
                stages=2,
            )
            div = 128 // partial_dtype.width
            if has_cu_seqlens:
                total_q, nheads = (cute.sym_int64() for _ in range(2))
                mO_partial = make_fake_tensor(
                    partial_dtype, (num_splits, total_q, nheads, D), divisibility=div
                )
                mLSE_partial = make_fake_tensor(
                    Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=2
                )
                mO = make_fake_tensor(
                    out_dtype, (total_q, nheads, D), divisibility=128 // out_dtype.width
                )
                mLSE = (
                    make_fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=1)
                    if has_lse
                    else None
                )
                mLSE_temperature_partial = (
                    make_fake_tensor(
                        Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=2
                    )
                    if return_temperature_lse
                    else None
                )
                mLSE_temperature = (
                    make_fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=1)
                    if return_temperature_lse
                    else None
                )
            else:
                batch, sq, nheads = (cute.sym_int64() for _ in range(3))
                mO_partial = make_fake_tensor(
                    partial_dtype, (num_splits, batch, sq, nheads, D), divisibility=div
                )
                mLSE_partial = make_fake_tensor(
                    Float32, (num_splits, batch, sq, nheads), divisibility=1, leading_dim=3
                )
                mO = make_fake_tensor(
                    out_dtype, (batch, sq, nheads, D), divisibility=128 // out_dtype.width
                )
                mLSE = (
                    make_fake_tensor(Float32, (batch, sq, nheads), divisibility=1, leading_dim=2)
                    if has_lse
                    else None
                )
                mLSE_temperature_partial = (
                    make_fake_tensor(
                        Float32, (num_splits, batch, sq, nheads), divisibility=1, leading_dim=3
                    )
                    if return_temperature_lse
                    else None
                )
                mLSE_temperature = (
                    make_fake_tensor(Float32, (batch, sq, nheads), divisibility=1, leading_dim=2)
                    if return_temperature_lse
                    else None
                )
            if not has_split_counts:
                mSplitCounts = None
            elif has_cu_seqlens:
                total_q_ctr, nheads_kv = (cute.sym_int64() for _ in range(2))
                mSplitCounts = make_fake_tensor(
                    Int32, (total_q_ctr, nheads_kv), divisibility=1, leading_dim=1
                )
            else:
                nheads_kv = cute.sym_int64()
                mSplitCounts = make_fake_tensor(
                    Int32, (batch, sq, nheads_kv), divisibility=1, leading_dim=2
                )
            mOutputScale = (
                make_fake_tensor(Float32, (cute.sym_int64(),), divisibility=1, leading_dim=0)
                if has_output_scale
                else None
            )
            stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)

            _combine_compile_cache[key] = cute.compile(
                kernel,
                mO_partial,
                mLSE_partial,
                mO,
                mLSE,
                mLSE_temperature_partial,
                mLSE_temperature,
                None
                if cu_seqlens is None
                else make_fake_tensor(Int32, (cute.sym_int64(),), divisibility=1, leading_dim=0),
                None
                if seqused is None
                else make_fake_tensor(Int32, (cute.sym_int64(),), divisibility=1, leading_dim=0),
                None,
                None,
                None,
                mSplitCounts,
                mOutputScale,
                Int32(qheadperkv),
                stream,
                options="--enable-tvm-ffi",
            )
            # Persist only freshly compiled kernels; AOT-loaded ones are already on disk.
            save_aot(key, _combine_compile_cache[key])

    with torch.cuda.nvtx.range("K2_Combine"):
        _combine_compile_cache[key](
            o_partial_fake,
            lse_partial,
            o_out,
            lse_out,
            lse_temperature_partial,
            lse_temperature_out,
            cu_seqlens,
            seqused,
            None,
            None,
            None,
            split_counts,
            output_scale,
            qheadperkv,
        )
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/__init__.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """CUTE DSL launchers for paged fp8 decode forward."""
5
+
6
+ from __future__ import annotations
7
+
8
+ import torch
9
+
10
+ from .atten_fwd import run_decode_attention
11
+ from .combine import run_decode_combine
12
+
13
+
14
def decode_forward_paged_fp8(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    page_table: torch.Tensor,
    seqused_k: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    request_indices: torch.Tensor,
    qo_tile_indices: torch.Tensor,
    kv_tile_indices: torch.Tensor,
    block_valid_mask: torch.Tensor,
    split_counts: torch.Tensor,
    o_indptr: torch.Tensor,
    merge_indptr: torch.Tensor,
    O_partial: torch.Tensor | None,
    LSE_partial: torch.Tensor | None,
    *,
    softmax_scale: float,
    seqlen_q: int,
    page_size: int,
    kv_chunk_size_pages: int,
    max_split_count: int,
    split_kv: bool,
    causal: bool,
    return_lse: bool = True,
    O_partial_dummy: torch.Tensor | None = None,
    LSE_partial_dummy: torch.Tensor | None = None,
) -> None:
    """Launch dense paged fp8 decode forward and optional compressed combine.

    ``run_decode_attention`` is always launched. When ``split_kv`` is true,
    ``run_decode_combine`` is launched afterwards to merge ``O_partial`` /
    ``LSE_partial`` into ``out`` / ``lse``. The split-path buffer requirement
    is checked before anything is launched, so an invalid call does not
    enqueue the attention kernel first.

    ``O_partial_dummy`` / ``LSE_partial_dummy`` are caller-provided pre-allocated
    placeholder buffers for the non-split path. When supplied, ``run_decode_attention``
    skips the per-call ``torch.empty`` it would otherwise need to satisfy the
    kernel's positional arg signature, saving ~5us on small-kv calls.

    ``merge_indptr`` is accepted for interface compatibility but is not
    forwarded to either launcher by this function.

    Raises:
        ValueError: If ``split_kv`` is true and ``O_partial`` or
            ``LSE_partial`` is ``None``.
    """
    # Validate up front: the original check ran only after the attention
    # kernel had already been launched on the split path.
    if split_kv and (O_partial is None or LSE_partial is None):
        raise ValueError("split decode requires O_partial and LSE_partial")

    run_decode_attention(
        q,
        k,
        v,
        page_table,
        seqused_k,
        request_indices,
        qo_tile_indices,
        kv_tile_indices,
        block_valid_mask,
        split_counts,
        o_indptr,
        out,
        lse,
        O_partial,
        LSE_partial,
        softmax_scale=float(softmax_scale),
        seqlen_q=int(seqlen_q),
        page_size=int(page_size),
        kv_chunk_size_pages=int(kv_chunk_size_pages),
        split_kv=bool(split_kv),
        causal=bool(causal),
        return_lse=bool(return_lse),
        O_partial_dummy=O_partial_dummy,
        LSE_partial_dummy=LSE_partial_dummy,
    )
    if split_kv:
        # NOTE(review): assumes dim 1 of q and k is the head count, so this is
        # the number of query heads per KV head — confirm against callers.
        qhead_per_kv = q.shape[1] // k.shape[1]
        q_tokens_per_group = 128 // int(qhead_per_kv)
        run_decode_combine(
            O_partial,
            LSE_partial,
            split_counts,
            o_indptr,
            out,
            lse,
            seqlen_q=int(seqlen_q),
            q_tokens_per_group=q_tokens_per_group,
            max_split_count=int(max_split_count),
        )
93
+
94
+
95
+ __all__ = ["decode_forward_paged_fp8", "run_decode_attention", "run_decode_combine"]
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py ADDED
The diff for this file is too large to render. See raw diff
 
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Paged decode split-KV scheduling backed by the precompiled Torch op.
5
+
6
+ The CUDA implementation lives in ``csrc/build_decode_schedule.cu`` and is
7
+ built ahead of time by kernel-builder. The op returns the schedule arrays
8
+ plus a fixed-order scalar summary, which is reassembled into the schedule
9
+ dict here.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import torch
15
+
16
+ from ....._ops import ops
17
+
18
+ # Order of the scalar summary returned by the op; must match
19
+ # csrc/build_decode_schedule.cu.
20
+ _SCALAR_KEYS = (
21
+ "split_kv",
22
+ "cta_tile_q",
23
+ "num_q_tiles",
24
+ "kv_chunk_size_pages",
25
+ "kv_chunk_size_tokens",
26
+ "work_count",
27
+ "padded_work_count",
28
+ "partial_rows",
29
+ "max_split_count",
30
+ "max_grid_size",
31
+ "active_blocks_per_sm",
32
+ "num_sms",
33
+ "base_cta",
34
+ )
35
+
36
+
37
def build_decode_schedule(
    seqused_k: torch.Tensor,
    *,
    page_size: int,
    seqlen_q: int,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    max_seqlen_k: int,
    enable_cuda_graph: bool = False,
    max_grid_size: int = 0,
    fixed_split_size: int = -1,
    disable_split_kv: bool = False,
) -> dict[str, object]:
    """Build the paged-decode split-KV schedule with the precompiled op.

    GPU-only schedule build: a single CUDA kernel produces all schedule
    index arrays on device. Only a small summary is brought to the host at
    the end so the wrapper can size O_partial, pick the kernel grid, and
    choose the split/non-split compile path.

    Args:
        seqused_k: Tensor forwarded unchanged to the op (presumably the
            per-request used KV lengths -- confirm against the CUDA source).
        page_size, seqlen_q, num_qo_heads, num_kv_heads, head_dim: Problem
            dimensions, coerced to ``int`` before the call.
        max_seqlen_k: Required host-side worst-case bound used for padding
            the work-tile arrays.
        enable_cuda_graph, max_grid_size, fixed_split_size, disable_split_kv:
            Scheduling knobs forwarded to the op after ``bool``/``int``
            coercion.

    Returns:
        A dict holding every name in ``_SCALAR_KEYS`` as a Python ``int``
        (``split_kv`` as ``bool``) followed by the schedule tensors
        ``request_indices``, ``qo_tile_indices``, ``kv_tile_indices``,
        ``block_valid_mask`` (each narrowed to ``padded_work_count`` rows),
        ``split_counts``, ``kv_pages``, ``merge_indptr`` and ``o_indptr``.

    Raises:
        RuntimeError: If the op returns fewer summary scalars than
            ``_SCALAR_KEYS`` names, i.e. the Python key order and the
            compiled op are out of sync.
    """

    (
        request_indices,
        qo_tile_indices,
        kv_tile_indices,
        block_valid_mask,
        split_counts,
        kv_pages,
        merge_indptr,
        o_indptr,
        scalars,
    ) = ops.build_decode_schedule(
        seqused_k,
        int(page_size),
        int(seqlen_q),
        int(num_qo_heads),
        int(num_kv_heads),
        int(head_dim),
        int(max_seqlen_k),
        bool(enable_cuda_graph),
        int(max_grid_size),
        int(fixed_split_size),
        bool(disable_split_kv),
    )

    # Pull the whole summary to the host in one go. Converting element by
    # element would (if the summary lives on the device) synchronise once per
    # scalar instead of once per call.
    if isinstance(scalars, torch.Tensor):
        scalar_values = scalars.tolist()
    else:
        scalar_values = list(scalars)
    if len(scalar_values) < len(_SCALAR_KEYS):
        # zip() would silently drop the missing keys and fail later with an
        # unrelated KeyError; report the real cause instead.
        raise RuntimeError(
            f"build_decode_schedule op returned {len(scalar_values)} summary scalars, "
            f"expected at least {len(_SCALAR_KEYS)}"
        )

    raw: dict[str, object] = {
        key: int(value) for key, value in zip(_SCALAR_KEYS, scalar_values)
    }
    raw["split_kv"] = bool(raw["split_kv"])

    # The CUDA kernel writes into worst-case-padded buffers (size =
    # batch * num_q_tiles * max_pages_global) but only the first
    # ``padded_work_count`` entries are valid. Downstream consumers
    # (tile_scheduler) take grid size from ``request_indices.shape[0]``
    # so we narrow the views to that count; the underlying allocation
    # is unchanged so this is a view, no copy.
    pad = int(raw["padded_work_count"])
    raw["request_indices"] = request_indices.narrow(0, 0, pad)
    raw["qo_tile_indices"] = qo_tile_indices.narrow(0, 0, pad)
    raw["kv_tile_indices"] = kv_tile_indices.narrow(0, 0, pad)
    raw["block_valid_mask"] = block_valid_mask.narrow(0, 0, pad)
    raw["split_counts"] = split_counts
    raw["kv_pages"] = kv_pages
    raw["merge_indptr"] = merge_indptr
    raw["o_indptr"] = o_indptr
    return raw
110
+
111
+
112
+ __all__ = ["build_decode_schedule"]
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/combine.py ADDED
@@ -0,0 +1,680 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """LDGSTS split-KV combine for paged decode attention."""
5
+
6
+ import math
7
+ from functools import partial
8
+ from typing import Type
9
+
10
+ import cuda.bindings.driver as cuda
11
+ import cutlass
12
+ import cutlass.cute as cute
13
+ import torch
14
+ from cutlass import Float32, Int32, Int64, const_expr
15
+ from cutlass.cute import FastDivmodDivisor
16
+ from cutlass.cute.nvgpu import cpasync
17
+
18
+ from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map
19
+
20
+
21
class SparseDecodeForwardCombine:
    """Combine split-KV decode partials with FA-style LDGSTS staging.

    ``mO_partial`` and ``mLSE_partial`` use the split-major padded layout:
    ``partial_row = o_indptr[b] + split_idx * q_stride + q_token`` where
    ``q_stride = ceil_div(seqlen_q, q_tokens_per_group) * q_tokens_per_group``.
    A CTA covers ``tile_m`` flattened ``(q_token, q_head)`` rows and one
    ``k_block_size`` slice of D. O_partial and LSE_partial are loaded to SMEM
    via ``cpasync.CopyG2SOp`` before the split reduction.
    """

    def __init__(
        self,
        dtype: Type[cutlass.Numeric],
        dtype_partial: Type[cutlass.Numeric],
        head_dim: int,
        *,
        tile_m: int = 64,
        k_block_size: int = 128,
        max_splits: int = 4,
        num_threads: int = 256,
        stages: int = 2,
    ):
        """Validate and record the compile-time configuration.

        Args:
            dtype: Element type of the final output; one of BFloat16,
                Float16 or Float32.
            dtype_partial: Element type of ``O_partial``; must be Float32.
            head_dim: Head dimension D; only 128 is accepted.
            tile_m: Flattened ``(q_token, q_head)`` rows per CTA; must be a
                multiple of 8.
            k_block_size: Slice of D handled per CTA; must equal ``head_dim``.
            max_splits: Upper bound on splits staged in shared memory, in
                ``[1, 256]``.
            num_threads: Threads per CTA.
            stages: Number of ``O_partial`` shared-memory stages.

        Raises:
            NotImplementedError: For ``head_dim != 128`` or
                ``k_block_size != head_dim``.
            TypeError: For an unsupported ``dtype`` or non-Float32
                ``dtype_partial``.
            ValueError: For an invalid ``tile_m`` or ``max_splits``.
        """
        if head_dim != 128:
            raise NotImplementedError(
                f"SparseDecodeForwardCombine currently supports only D=128, got D={head_dim}"
            )
        if dtype not in [cutlass.BFloat16, cutlass.Float16, cutlass.Float32]:
            raise TypeError(f"Unsupported output dtype: {dtype}")
        if dtype_partial is not Float32:
            raise TypeError("decode O_partial must be Float32")
        if k_block_size != head_dim:
            raise NotImplementedError("decode combine currently uses one D=128 k block")
        if tile_m % 8 != 0:
            raise ValueError("decode combine tile_m must be divisible by 8")
        if max_splits < 1 or max_splits > 256:
            raise ValueError("decode combine max_splits must be in [1, 256]")

        self.dtype = dtype
        self.dtype_partial = dtype_partial
        self.head_dim = head_dim
        self.tile_m = tile_m
        self.k_block_size = k_block_size
        self.max_splits = max_splits
        self.num_threads = num_threads
        self.stages = stages
        # Always True under the k_block_size == head_dim check above.
        self.is_even_k = head_dim % k_block_size == 0

    def _setup_attributes(self) -> None:
        """Build the tiled copies and shared-memory layouts used by the kernel.

        Populates ``gmem_tiled_copy_O_partial``, ``gmem_tiled_copy_O``,
        ``gmem_tiled_copy_LSE``, ``s2r_tiled_copy_LSE``,
        ``smem_threads_per_col_lse``, ``smem_layout_lse`` and
        ``smem_layout_o`` on ``self``.
        """
        # Async global->shared copies move 128 bits per instruction.
        universal_copy_bits = 128
        async_copy_elems = universal_copy_bits // self.dtype_partial.width
        assert self.k_block_size % async_copy_elems == 0

        # Widest of 128/64/32 columns that evenly divides the D slice.
        k_block_gmem = (
            128
            if self.k_block_size % 128 == 0
            else (64 if self.k_block_size % 64 == 0 else 32)
        )
        gmem_threads_per_row = k_block_gmem // async_copy_elems
        assert self.num_threads % gmem_threads_per_row == 0

        # O_partial: cp.async global -> shared, one 128-bit vector per thread.
        atom_async_copy_partial = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            self.dtype_partial,
            num_bits_per_copy=universal_copy_bits,
        )
        tOpartial_layout = cute.make_ordered_layout(
            (self.num_threads // gmem_threads_per_row, gmem_threads_per_row),
            order=(1, 0),
        )
        vOpartial_layout = cute.make_layout((1, async_copy_elems))
        self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv(
            atom_async_copy_partial, tOpartial_layout, vOpartial_layout
        )

        # Final O store reuses the same thread/value layout, so each thread
        # writes the same elements it accumulated, in the output dtype.
        atom_universal_copy = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            self.dtype,
            num_bits_per_copy=async_copy_elems * self.dtype.width,
        )
        self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
            atom_universal_copy, tOpartial_layout, vOpartial_layout
        )

        # LSE is copied one Float32 at a time.
        lse_copy_bits = Float32.width
        # Largest of 128/64/32/16/8 that divides tile_m (tile_m % 8 == 0 is
        # enforced in __init__, so 8 is always a valid fallback).
        m_block_smem = (
            128
            if self.tile_m % 128 == 0
            else (
                64
                if self.tile_m % 64 == 0
                else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8))
            )
        )
        gmem_threads_per_row_lse = m_block_smem
        assert self.num_threads % gmem_threads_per_row_lse == 0

        atom_async_copy_lse = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS),
            Float32,
            num_bits_per_copy=lse_copy_bits,
        )
        tLSE_layout = cute.make_ordered_layout(
            (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse),
            order=(1, 0),
        )
        self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(
            atom_async_copy_lse, tLSE_layout, cute.make_layout(1)
        )

        # Threads cooperating on one row's splits during the reduction; the
        # kernel passes this as ``threads_in_group`` to the warp reductions,
        # hence the requirement that it divides the warp size.
        self.smem_threads_per_col_lse = self.num_threads // m_block_smem
        assert 32 % self.smem_threads_per_col_lse == 0
        s2r_layout_atom_lse = cute.make_ordered_layout(
            (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse),
            order=(0, 1),
        )
        self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv(
            cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32),
            s2r_layout_atom_lse,
            cute.make_layout(1),
        )

        # Swizzle parameters depend on the row-block width.
        # NOTE(review): presumably chosen to avoid shared-memory bank
        # conflicts on the transposed read -- confirm.
        if const_expr(m_block_smem == 8):
            smem_lse_swizzle = cute.make_swizzle(5, 0, 5)
        elif const_expr(m_block_smem == 16):
            smem_lse_swizzle = cute.make_swizzle(4, 0, 4)
        else:
            smem_lse_swizzle = cute.make_swizzle(3, 2, 3)
        lse_atom_splits = min(self.max_splits, 8)
        smem_layout_atom_lse = cute.make_composed_layout(
            smem_lse_swizzle,
            0,
            cute.make_ordered_layout((lse_atom_splits, m_block_smem), order=(1, 0)),
        )
        # sLSE is indexed as (split, row); sO as (row, column, stage).
        self.smem_layout_lse = cute.tile_to_shape(
            smem_layout_atom_lse, (self.max_splits, self.tile_m), (0, 1)
        )
        self.smem_layout_o = cute.make_ordered_layout(
            (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2)
        )

    @cute.jit
    def __call__(
        self,
        mO_partial: cute.Tensor,  # [partial_rows, Hq, D] fp32
        mLSE_partial: cute.Tensor,  # [partial_rows, Hq] fp32
        mSplitCounts: cute.Tensor,  # [B] int32
        mOIndptr: cute.Tensor,  # [B + 1] int32
        mO: cute.Tensor,  # [total_q, Hq, D]
        mLSE: cute.Tensor,  # [total_q, Hq] fp32
        seqlen_q: Int32,
        q_tokens_per_group: Int32,
        stream: cuda.CUstream = None,
    ):
        """Check element types, build layouts and launch the combine kernel.

        The grid is ``(ceil(seqlen_q * Hq / tile_m), ceil(D / k_block_size),
        B)`` with ``num_threads`` threads per CTA.

        Raises:
            TypeError: If any tensor has an unexpected element type.
        """
        if const_expr(mO_partial.element_type is not Float32):
            raise TypeError("decode O_partial tensor must be Float32")
        if const_expr(mLSE_partial.element_type is not Float32):
            raise TypeError("decode LSE_partial tensor must be Float32")
        if const_expr(mLSE.element_type is not Float32):
            raise TypeError("decode LSE tensor must be Float32")
        if const_expr(mO.element_type != self.dtype):
            raise TypeError("decode O tensor dtype must match kernel dtype")
        if const_expr(mSplitCounts.element_type is not Int32):
            raise TypeError("decode split_counts tensor must be Int32")
        if const_expr(mOIndptr.element_type is not Int32):
            raise TypeError("decode o_indptr tensor must be Int32")

        mO_partial, mLSE_partial, mSplitCounts, mOIndptr, mO, mLSE = [
            assume_tensor_aligned(t)
            for t in (mO_partial, mLSE_partial, mSplitCounts, mOIndptr, mO, mLSE)
        ]
        self._setup_attributes()

        # Shared-memory footprint: per-split LSE/scale table, per-row index of
        # the last valid split, and the staged O_partial tiles.
        @cute.struct
        class SharedStorage:
            sLSE: cute.struct.Align[
                cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
            ]
            sMaxValidSplit: cute.struct.Align[
                cute.struct.MemRange[Int32, self.tile_m], 128
            ]
            sO: cute.struct.Align[
                cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
            ]

        total_q = mO.shape[0]
        head_q = mO.shape[1]
        batch = mSplitCounts.shape[0]
        head_divmod = FastDivmodDivisor(head_q)
        # x: row tiles of one request, y: D slices, z: request index.
        grid = (
            cute.ceil_div(seqlen_q * head_q, self.tile_m),
            cute.ceil_div(self.head_dim, self.k_block_size),
            batch,
        )

        self.kernel(
            mO_partial,
            mLSE_partial,
            mSplitCounts,
            mOIndptr,
            mO,
            mLSE,
            SharedStorage,
            self.smem_layout_lse,
            self.smem_layout_o,
            self.gmem_tiled_copy_O_partial,
            self.gmem_tiled_copy_O,
            self.gmem_tiled_copy_LSE,
            self.s2r_tiled_copy_LSE,
            head_divmod,
            Int32(total_q),
            Int32(head_q),
            seqlen_q,
            q_tokens_per_group,
        ).launch(
            grid=grid,
            block=[self.num_threads, 1, 1],
            smem=SharedStorage.size_in_bytes(),
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        mO_partial: cute.Tensor,
        mLSE_partial: cute.Tensor,
        mSplitCounts: cute.Tensor,
        mOIndptr: cute.Tensor,
        mO: cute.Tensor,
        mLSE: cute.Tensor,
        SharedStorage: cutlass.Constexpr,
        smem_layout_lse: cute.Layout | cute.ComposedLayout,
        smem_layout_o: cute.Layout,
        gmem_tiled_copy_O_partial: cute.TiledCopy,
        gmem_tiled_copy_O: cute.TiledCopy,
        gmem_tiled_copy_LSE: cute.TiledCopy,
        s2r_tiled_copy_LSE: cute.TiledCopy,
        head_divmod: FastDivmodDivisor,
        total_q: Int32,
        head_q: Int32,
        seqlen_q: Int32,
        q_tokens_per_group: Int32,
    ):
        """Device kernel: reduce the per-split partials of one row tile.

        Outline: (1) stage ``LSE_partial`` in shared memory, (2) prefetch the
        first ``O_partial`` stage(s), (3) turn per-split LSE values into
        normalised weights and the combined LSE, (4) write the combined LSE,
        (5) accumulate the weighted ``O_partial`` rows through the staged
        pipeline, (6) store the result in the output dtype.

        Rows are addressed with raw pointer arithmetic from the tensors' base
        iterators, i.e. ``O_partial``, ``LSE_partial`` and ``O`` are treated
        as densely packed row-major buffers.
        """
        tidx, _, _ = cute.arch.thread_idx()
        m_block, k_block, batch_idx = cute.arch.block_idx()

        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(SharedStorage)
        sLSE = storage.sLSE.get_tensor(smem_layout_lse)
        sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,))
        sO = storage.sO.get_tensor(smem_layout_o)

        split_count = mSplitCounts[batch_idx]
        # Rows between consecutive splits of one request: seqlen_q rounded up
        # to a multiple of q_tokens_per_group.
        q_stride = (
            (seqlen_q + q_tokens_per_group - Int32(1))
            // q_tokens_per_group
        ) * q_tokens_per_group
        # Number of flattened (q_token, q_head) rows per request.
        max_idx = seqlen_q * head_q

        if m_block * Int32(self.tile_m) < max_idx:
            # ---- Step 1: LSE_partial global -> shared (async) ----
            gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx)
            tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE)
            # Identity tensor yields the (split, row) coordinate of each slot.
            cLSE = cute.make_identity_tensor((self.max_splits, self.tile_m))
            tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE)

            for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
                mi = tLSEcLSE[0, 0, m][1]
                idx = m_block * Int32(self.tile_m) + mi
                if idx < max_idx:
                    q_idx, q_head = divmod(idx, head_divmod)
                    partial_base = mOIndptr[batch_idx] + q_idx
                    for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
                        si = tLSEcLSE[0, s, 0][0]
                        if si < split_count:
                            partial_row = partial_base + si * q_stride
                            # 64-bit element offset into [partial_rows, Hq].
                            lse_ptr = (
                                mLSE_partial.iterator
                                + Int64(partial_row) * Int64(head_q)
                                + Int64(q_head)
                            )
                            lse_gmem_ptr = cute.make_ptr(
                                Float32,
                                lse_ptr.toint(),
                                cute.AddressSpace.gmem,
                                assumed_align=4,
                            )
                            lse_src = cute.make_tensor(lse_gmem_ptr, (1,))
                            cute.copy(
                                gmem_thr_copy_LSE,
                                lse_src,
                                tLSEsLSE[None, s, m],
                            )
                        else:
                            # Splits this request does not have get -inf, so
                            # they contribute exp(-inf) = 0 to the reduction.
                            tLSEsLSE[None, s, m].fill(-Float32.inf)
                else:
                    # Rows past the end of the request: mark every split -inf.
                    for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
                        tLSEsLSE[None, s, m].fill(-Float32.inf)
            cute.arch.cp_async_commit_group()

            # ---- Step 2: per-thread row bookkeeping and O_partial prefetch ----
            gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx)
            cO = cute.make_identity_tensor((self.tile_m, self.k_block_size))
            tOcO = gmem_thr_copy_O_partial.partition_D(cO)
            tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO)

            num_rows = const_expr(cute.size(tOcO, mode=[1]))
            tOqidx = cute.make_rmem_tensor(num_rows, Int32)
            tOhidx = cute.make_rmem_tensor(num_rows, Int32)
            for m in cutlass.range(num_rows, unroll_full=True):
                mi = tOcO[0, m, 0][0]
                idx = m_block * Int32(self.tile_m) + mi
                if idx >= max_idx:
                    # Head index -1 marks a row outside the request; every
                    # later load/accumulate/store tests tOhidx[m] >= 0.
                    tOqidx[m] = Int32(0)
                    tOhidx[m] = -Int32(1)
                else:
                    tOqidx[m], tOhidx[m] = divmod(idx, head_divmod)

            # Bind everything except (split, stage).
            load_O_partial = partial(
                self.load_O_partial,
                mO_partial,
                mOIndptr,
                gmem_tiled_copy_O_partial,
                tOsO_partial,
                tOqidx,
                tOhidx,
                tOcO,
                batch_idx,
                q_stride,
                split_count,
                head_q,
                k_block,
            )

            # Prologue: issue the first stages-1 splits. A group is committed
            # even when nothing was loaded so the group count stays fixed.
            for stage in cutlass.range(self.stages - 1, unroll_full=True):
                if stage < split_count:
                    load_O_partial(stage, stage)
                cute.arch.cp_async_commit_group()

            # Leave at most stages-1 groups in flight.
            # NOTE(review): since the LSE group was committed first, this
            # presumably guarantees it has landed -- confirm against cp.async
            # group semantics.
            cute.arch.cp_async_wait_group(self.stages - 1)
            cute.arch.sync_threads()

            # ---- Step 3: per-split weights and combined LSE ----
            s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx)
            ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE)
            ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE)
            cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE)

            lse_sum = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32)
            ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE)
            max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32)
            assert cute.size(ts2rrLSE, mode=[0]) == 1
            for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
                threads_per_col = const_expr(self.smem_threads_per_col_lse)
                # Max LSE over this thread's splits, then across the group.
                lse_max = cute.arch.warp_reduction_max(
                    ts2rrLSE[None, None, m]
                    .load()
                    .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
                    threads_in_group=threads_per_col,
                )
                # Last split index with a finite LSE (-1 if none).
                max_valid_idx = -Int32(1)
                for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
                    if ts2rrLSE[0, s, m] != -Float32.inf:
                        max_valid_idx = ts2rcLSE[0, s, 0][0]
                max_valid_split[m] = cute.arch.warp_reduction_max(
                    max_valid_idx, threads_in_group=threads_per_col
                )

                # Shift by 0 when every split is -inf; subtracting -inf from
                # -inf would give NaN.
                lse_max_cur = Float32(0.0) if lse_max == -Float32.inf else lse_max
                LOG2_E = Float32(math.log2(math.e))
                lse_sum_cur = Float32(0.0)
                for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
                    # exp(lse - max) computed as exp2((lse - max) * log2(e)).
                    scale = cute.math.exp2(
                        (ts2rrLSE[0, s, m] - lse_max_cur) * LOG2_E,
                        fastmath=True,
                    )
                    lse_sum_cur += scale
                    ts2rrLSE[0, s, m] = scale
                lse_sum_cur = cute.arch.warp_reduction_sum(
                    lse_sum_cur, threads_in_group=threads_per_col
                )
                lse_sum[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max
                # Zero weight when the sum is 0 or NaN (x != x) to keep
                # Inf/NaN out of the output.
                inv_sum = (
                    Float32(0.0)
                    if (lse_sum_cur == Float32(0.0) or lse_sum_cur != lse_sum_cur)
                    else cute.arch.rcp_approx(lse_sum_cur)
                )
                ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)
            # sLSE now holds normalised per-split weights instead of raw LSE.
            cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)

            # One thread per row (the one holding split coordinate 0)
            # publishes the row's last valid split index.
            for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
                if ts2rcLSE[0, 0, m][0] == Int32(0):
                    mi = ts2rcLSE[0, 0, m][1]
                    if mi < Int32(self.tile_m):
                        sMaxValidSplit[mi] = max_valid_split[m]

            # ---- Step 4: combined LSE -> global (only by the k_block 0 CTA) ----
            if k_block == Int32(0):
                for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
                    if ts2rcLSE[0, 0, m][0] == Int32(0):
                        mi = ts2rcLSE[0, 0, m][1]
                        idx = m_block * Int32(self.tile_m) + mi
                        if idx < max_idx:
                            q_idx, q_head = divmod(idx, head_divmod)
                            # Output rows are laid out request-major.
                            q_abs = batch_idx * seqlen_q + q_idx
                            mLSE[q_abs, q_head] = lse_sum[m]

            # Make the weights and sMaxValidSplit visible to all threads.
            cute.arch.sync_threads()

            # ---- Step 5: weighted accumulation over splits ----
            # Loop bound for this thread: largest valid split over its rows.
            thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]]
            for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True):
                thr_max_valid_split = max(
                    thr_max_valid_split,
                    sMaxValidSplit[tOcO[0, m, 0][0]],
                )

            tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0])
            tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32)
            tOrO.fill(Float32(0.0))

            # The prologue filled stages 0..stages-2, so the next load goes to
            # the last stage while compute starts at stage 0.
            stage_load = self.stages - 1
            stage_compute = 0
            for s in cutlass.range(thr_max_valid_split + Int32(1), unroll=4):
                scale = cute.make_rmem_tensor(num_rows, Float32)
                for m in cutlass.range(num_rows, unroll_full=True):
                    scale[m] = sLSE[s, tOcO[0, m, 0][0]]

                # Prefetch the split stages-1 ahead; commit unconditionally so
                # the wait below sees a constant number of groups.
                split_to_load = s + Int32(self.stages - 1)
                if split_to_load <= thr_max_valid_split:
                    load_O_partial(split_to_load, stage_load)
                cute.arch.cp_async_commit_group()
                stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1

                # NOTE(review): no block barrier here; assumes each thread
                # reads back only shared-memory slots it wrote itself (same
                # partition for load and read) -- confirm.
                cute.arch.cp_async_wait_group(self.stages - 1)
                cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial)
                stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1

                for m in cutlass.range(num_rows, unroll_full=True):
                    # Skip rows outside the request and zero-weight splits.
                    if tOhidx[m] >= Int32(0) and scale[m] > Float32(0.0):
                        tOrO[None, m, None].store(
                            tOrO[None, m, None].load()
                            + scale[m] * tOrO_partial[None, m, None].load().to(Float32)
                        )

            # Drain outstanding async copies before finishing.
            cute.arch.cp_async_wait_group(0)
            cute.arch.sync_threads()

            # ---- Step 6: convert and store the combined O ----
            rO = cute.make_rmem_tensor_like(tOrO, self.dtype)
            rO.store(tOrO.load().to(self.dtype))
            elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1]))
            gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
            for m in cutlass.range(num_rows, unroll_full=True):
                if tOhidx[m] >= Int32(0):
                    q_abs = batch_idx * seqlen_q + tOqidx[m]
                    # 64-bit element offset of this row's D slice in
                    # [total_q, Hq, D].
                    row_ptr = (
                        mO.iterator
                        + (
                            (Int64(q_abs) * Int64(head_q) + Int64(tOhidx[m]))
                            * Int64(self.head_dim)
                            + Int64(k_block * Int32(self.k_block_size))
                        )
                    )
                    row_gmem_ptr = cute.make_ptr(
                        mO.element_type,
                        row_ptr.toint(),
                        cute.AddressSpace.gmem,
                        assumed_align=16,
                    )
                    mO_row = cute.make_tensor(
                        row_gmem_ptr,
                        cute.make_layout((self.k_block_size,)),
                    )
                    mO_row_copy = cute.tiled_divide(mO_row, (elems_per_store,))
                    for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
                        k_idx = tOcO[0, 0, k][1] // elems_per_store
                        cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_row_copy[None, k_idx])

    @cute.jit
    def load_O_partial(
        self,
        mO_partial: cute.Tensor,
        mOIndptr: cute.Tensor,
        gmem_tiled_copy_O_partial: cute.TiledCopy,
        tOsO_partial: cute.Tensor,
        tOqidx: cute.Tensor,
        tOhidx: cute.Tensor,
        tOcO: cute.Tensor,
        batch_idx: Int32,
        q_stride: Int32,
        split_count: Int32,
        head_q: Int32,
        k_block: Int32,
        split: Int32,
        stage: Int32,
    ) -> None:
        """Issue async copies of one split's ``O_partial`` rows into ``stage``.

        For each row owned by this thread with a valid head index, the row
        ``o_indptr[batch_idx] + split * q_stride + q_token`` is copied into
        the thread's slice of shared-memory stage ``stage``. The caller is
        responsible for committing and waiting on the async-copy group.
        """
        elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1]))
        tOsO_partial_cur = tOsO_partial[None, None, None, stage]
        for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True):
            if tOhidx[m] >= Int32(0):
                if split < split_count:
                    partial_row = mOIndptr[batch_idx] + split * q_stride + tOqidx[m]
                    # 64-bit element offset of this row's D slice in
                    # [partial_rows, Hq, D].
                    row_ptr = (
                        mO_partial.iterator
                        + (
                            (Int64(partial_row) * Int64(head_q) + Int64(tOhidx[m]))
                            * Int64(self.head_dim)
                            + Int64(k_block * Int32(self.k_block_size))
                        )
                    )
                    row_gmem_ptr = cute.make_ptr(
                        mO_partial.element_type,
                        row_ptr.toint(),
                        cute.AddressSpace.gmem,
                        assumed_align=16,
                    )
                    mO_partial_row = cute.make_tensor(
                        row_gmem_ptr,
                        cute.make_layout((self.k_block_size,)),
                    )
                    mO_partial_row_copy = cute.tiled_divide(
                        mO_partial_row, (elems_per_load,))
                    for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
                        k_idx = tOcO[0, 0, k][1] // elems_per_load
                        cute.copy(
                            gmem_tiled_copy_O_partial,
                            mO_partial_row_copy[None, k_idx],
                            tOsO_partial_cur[None, m, k],
                        )
                else:
                    # NOTE(review): this view of the file has lost its
                    # indentation, so it is not certain whether this
                    # zero-fill pairs with the split check (as written here)
                    # or with the head-index check above. Either way the
                    # zeroed slots are never accumulated by the kernel
                    # (weight 0 / row skipped) -- confirm against upstream.
                    tOsO_partial_cur[None, m, None].fill(Float32(0.0))
547
+
548
+
549
+ _combine_compile_cache: dict[tuple[object, ...], object] = {}
550
+
551
+
552
def _next_power_of_2(x: int) -> int:
    """Return the smallest power of two that is >= ``x`` (at least 1)."""
    value = int(x)
    if value <= 1:
        # Zero and negative inputs are clamped to 1.
        return 1
    # (value - 1).bit_length() is the exponent of the next power of two.
    exponent = (value - 1).bit_length()
    return 1 << exponent
554
+
555
+
556
def run_decode_combine(
    O_partial: torch.Tensor,
    LSE_partial: torch.Tensor,
    split_counts: torch.Tensor,
    o_indptr: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    *,
    seqlen_q: int,
    q_tokens_per_group: int,
    max_split_count: int,
) -> None:
    """Launch LDGSTS decode split-KV combine.

    Validates the arguments, compiles (and caches) a
    ``SparseDecodeForwardCombine`` kernel for the configuration, and launches
    it so that ``out`` and ``lse`` receive the combined result.

    Args:
        O_partial: float32 ``[partial_rows, heads, D]`` per-split outputs.
        LSE_partial: float32 ``[partial_rows, heads]`` per-split LSE values.
        split_counts: int32 ``[batch]`` number of splits per request.
        o_indptr: int32 ``[batch + 1]`` row offsets into the partial buffers.
        out: ``[batch * seqlen_q, heads, D]`` destination for the combined O.
        lse: float32 ``[batch * seqlen_q, heads]`` destination for the
            combined LSE.
        seqlen_q: Query tokens per request; must be positive.
        q_tokens_per_group: Padding granularity of the per-split row stride;
            must be positive.
        max_split_count: Upper bound on ``split_counts``; must be in
            ``[1, 256]``.

    Raises:
        TypeError: If a tensor has the wrong dtype.
        ValueError: If shapes are inconsistent, a scalar argument is not
            positive, or ``O_partial`` / ``LSE_partial`` / ``out`` is not
            contiguous.
        NotImplementedError: If ``max_split_count`` exceeds 256.
    """

    if O_partial.dtype != torch.float32:
        raise TypeError(f"O_partial must be torch.float32, got {O_partial.dtype}")
    if LSE_partial.dtype != torch.float32:
        raise TypeError(f"LSE_partial must be torch.float32, got {LSE_partial.dtype}")
    if lse.dtype != torch.float32:
        raise TypeError(f"lse must be torch.float32, got {lse.dtype}")
    if split_counts.dtype != torch.int32:
        raise TypeError(f"split_counts must be torch.int32, got {split_counts.dtype}")
    if o_indptr.dtype != torch.int32:
        raise TypeError(f"o_indptr must be torch.int32, got {o_indptr.dtype}")
    if out.ndim != 3 or O_partial.ndim != 3:
        raise ValueError("decode combine expects O tensors with shape [rows, heads, D]")
    if LSE_partial.ndim != 2 or lse.ndim != 2:
        raise ValueError("decode combine expects LSE tensors with shape [rows, heads]")
    if out.shape[1:] != O_partial.shape[1:]:
        raise ValueError(f"O shape mismatch: out={out.shape}, O_partial={O_partial.shape}")
    if lse.shape != out.shape[:2]:
        raise ValueError(f"lse shape {lse.shape} must match out[:2] {out.shape[:2]}")
    if LSE_partial.shape != O_partial.shape[:2]:
        raise ValueError(
            f"LSE_partial shape {LSE_partial.shape} must match O_partial[:2] {O_partial.shape[:2]}"
        )
    if split_counts.ndim != 1 or o_indptr.ndim != 1:
        raise ValueError("split_counts and o_indptr must be rank-1 tensors")
    if o_indptr.shape != (split_counts.shape[0] + 1,):
        raise ValueError(
            f"o_indptr shape {o_indptr.shape} must be ({split_counts.shape[0] + 1},)"
        )
    seqlen_q = int(seqlen_q)
    q_tokens_per_group = int(q_tokens_per_group)
    if seqlen_q <= 0:
        raise ValueError("seqlen_q must be positive")
    if q_tokens_per_group <= 0:
        raise ValueError("q_tokens_per_group must be positive")
    if out.shape[0] != split_counts.shape[0] * seqlen_q:
        raise ValueError(
            f"out rows {out.shape[0]} must equal batch*seqlen_q "
            f"{split_counts.shape[0]}*{seqlen_q}"
        )

    # The kernel computes row addresses for these three tensors directly from
    # the base pointer as ``(row * heads + head) * D``, ignoring strides. A
    # strided view would therefore be read/written at the wrong addresses
    # without any error, so reject it up front.
    for tensor_name, tensor in (
        ("O_partial", O_partial),
        ("LSE_partial", LSE_partial),
        ("out", out),
    ):
        if not tensor.is_contiguous():
            raise ValueError(
                f"{tensor_name} must be contiguous, got strides {tuple(tensor.stride())}"
            )

    max_split_count = int(max_split_count)
    if max_split_count <= 0:
        raise ValueError("max_split_count must be positive")
    if max_split_count > 256:
        raise NotImplementedError(
            f"LDGSTS decode combine supports at most 256 splits, got {max_split_count}"
        )
    # Round the split bound up to a power of two (minimum 4) so a handful of
    # compiled variants covers every split count.
    max_splits = max(4, _next_power_of_2(max_split_count))
    tile_m = 64
    k_block_size = int(out.shape[-1])
    stages = 2

    dtype = torch2cute_dtype_map[out.dtype]
    key = (
        "decode_combine_ldgsts",
        out.shape[-1],
        dtype,
        O_partial.dtype,
        seqlen_q,
        q_tokens_per_group,
        tile_m,
        k_block_size,
        max_splits,
        stages,
    )
    if key not in _combine_compile_cache:
        from ....quack.compile_utils import make_fake_tensor

        # Symbolic extents: one compiled kernel serves any batch/row/head count.
        total_q = cute.sym_int64()
        batch = cute.sym_int64()
        batch_plus_one = cute.sym_int64()
        partial_rows = cute.sym_int64()
        head_q = cute.sym_int64()
        head_dim = int(out.shape[-1])
        kernel = SparseDecodeForwardCombine(
            dtype=dtype,
            dtype_partial=Float32,
            head_dim=head_dim,
            tile_m=tile_m,
            k_block_size=k_block_size,
            max_splits=max_splits,
            stages=stages,
        )
        _combine_compile_cache[key] = cute.compile(
            kernel,
            make_fake_tensor(Float32, (partial_rows, head_q, head_dim), divisibility=4),
            make_fake_tensor(Float32, (partial_rows, head_q), divisibility=1, leading_dim=1),
            make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0),
            make_fake_tensor(Int32, (batch_plus_one,), divisibility=1, leading_dim=0),
            make_fake_tensor(dtype, (total_q, head_q, head_dim), divisibility=128 // dtype.width),
            make_fake_tensor(Float32, (total_q, head_q), divisibility=1, leading_dim=1),
            Int32(seqlen_q),
            Int32(q_tokens_per_group),
            cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
            options="--enable-tvm-ffi",
        )

    with torch.cuda.nvtx.range("Decode_Combine_LDGSTS"):
        _combine_compile_cache[key](
            O_partial,
            LSE_partial,
            split_counts,
            o_indptr,
            out,
            lse,
            seqlen_q,
            q_tokens_per_group,
        )
678
+
679
+
680
+ __all__ = ["SparseDecodeForwardCombine", "run_decode_combine"]
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Decode-specific tile scheduler for paged fp8 attention.
5
+
6
+ The pre-schedule step builds a dense worklist over decode KV chunks. Static
7
+ persistent scheduling walks a flattened ``(work_idx, head_kv_idx)`` task id.
8
+ CLC scheduling keeps BSA's hardware grid shape, ``(work_idx, head_kv_idx, 1)``,
9
+ and maps the canceled CTA coordinate back to the same logical task space.
10
+ """
11
+
12
+ from dataclasses import dataclass
13
+ from typing import Tuple
14
+
15
+ import cutlass
16
+ import cutlass.cute as cute
17
+ from cutlass import Int32, const_expr
18
+ from cutlass.cute import FastDivmodDivisor
19
+
20
+ from ....quack.cute_dsl_utils import ParamsBase
21
+
22
+ from ....src.common.tile_scheduler import SchedulingMode, WorkTileInfo
23
+
24
+
25
@dataclass
class DecodeTileSchedulerArguments(ParamsBase):
    """Inputs from which ``DecodeTileScheduler.Params`` is derived.

    Consumed by ``DecodeTileScheduler.to_underlying_arguments``, which
    multiplies ``work_capacity`` by ``num_heads_kv`` to obtain the total
    number of schedulable tasks.
    """

    # Size of the work-index dimension of the task space.
    work_capacity: Int32
    # Size of the KV-head dimension of the task space.
    num_heads_kv: Int32
    # (M, N) cluster shape; ``to_underlying_arguments`` asserts N == 1 and
    # forwards M as ``cluster_shape_m``.
    cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
30
+
31
+
32
class DecodeTileScheduler:
    """Persistent scheduler over decode ``(work_idx, head_kv_idx)`` tasks.

    Two modes are selected at compile time through
    ``Params.scheduling_mode``:

    * ``SchedulingMode.STATIC`` -- each CTA (or cluster) starts at its own
      index and strides through a flattened task id until it reaches
      ``total_tasks``.
    * ``SchedulingMode.CLC`` -- work is obtained through CLC queries; the
      returned coordinate is mapped back to ``(work_idx, head_kv_idx)``.
    """

    @dataclass
    class Params(ParamsBase):
        # Size of the work-index dimension (copied from the arguments).
        work_capacity: Int32
        # Size of the KV-head dimension (copied from the arguments).
        num_heads_kv: Int32
        # Divisor object used by ``_task_to_work`` to split a flat task id
        # into ``(work_idx, head_kv_idx)``.
        num_heads_kv_divmod: FastDivmodDivisor
        # ``work_capacity * num_heads_kv``; upper bound for valid task ids
        # in static mode.
        total_tasks: Int32
        # Cluster extent along M (compile-time constant).
        cluster_shape_m: cutlass.Constexpr[int] = 1
        # Compile-time selection between static and CLC scheduling.
        scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC

    def __init__(
        self,
        params: Params,
        task_idx: Int32,
        clc_scheduler=None,
        clc_pipeline=None,
        clc_consumer_state=None,
        clc_response_ptr=None,
        *,
        loc=None,
        ip=None,
    ):
        """Store the scheduler state; performs no computation.

        The positional order of the first six parameters is relied upon by
        ``__new_from_mlir_values__``, which rebuilds the object with
        ``DecodeTileScheduler(*obj_list, ...)``.
        """
        self.params = params
        self._task_idx = task_idx
        self._clc_scheduler = clc_scheduler
        self._clc_pipeline = clc_pipeline
        self._clc_consumer_state = clc_consumer_state
        self._clc_response_ptr = clc_response_ptr
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(
        args: DecodeTileSchedulerArguments,
        *,
        scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
        loc=None,
        ip=None,
    ) -> Params:
        """Convert user-facing arguments into ``Params``.

        Asserts that the cluster N extent is 1, precomputes the fast divmod
        divisor for ``num_heads_kv`` and the total task count. ``loc`` and
        ``ip`` are accepted but not used here.
        """
        assert args.cluster_shape_mn[1] == 1, "Decode scheduler requires cluster N == 1"
        total_tasks = args.work_capacity * args.num_heads_kv
        return DecodeTileScheduler.Params(
            args.work_capacity,
            args.num_heads_kv,
            FastDivmodDivisor(args.num_heads_kv),
            total_tasks,
            cluster_shape_m=args.cluster_shape_mn[0],
            scheduling_mode=scheduling_mode,
        )

    @staticmethod
    def _clc_grid_shape(params: Params):
        """Return the CLC-mode grid ``(work, head_kv, 1)``.

        The work extent is rounded up to a multiple of ``cluster_shape_m``.
        """
        return (
            cute.round_up(params.work_capacity, params.cluster_shape_m),
            params.num_heads_kv,
            Int32(1),
        )

    # Device-side factory.
    # CLC mode: wraps a ``ClcDynamicPersistentTileScheduler`` built from the
    # CLC grid shape and seeds ``_task_idx`` with the CTA's x block index.
    # The CLC pipeline / consumer state are not set here; they are attached
    # later through ``set_clc_pipeline``.
    # Static mode: the starting task id is the x block index when
    # ``cluster_shape_m == 1`` and the x cluster index otherwise.
    @staticmethod
    @cute.jit
    def create(
        params: Params,
        clc_response_ptr=None,
        *,
        loc=None,
        ip=None,
    ) -> "DecodeTileScheduler":
        if const_expr(params.scheduling_mode == SchedulingMode.CLC):
            from cutlass.utils import (
                ClcDynamicPersistentTileScheduler,
                ClcDynamicPersistentTileSchedulerParams,
            )

            cutlass_params = ClcDynamicPersistentTileSchedulerParams(
                problem_shape_ntile_mnl=DecodeTileScheduler._clc_grid_shape(params),
                cluster_shape_mnk=(params.cluster_shape_m, 1, 1),
            )
            block_idx = cute.arch.block_idx()
            grid_dim = cute.arch.grid_dim()
            clc_scheduler = ClcDynamicPersistentTileScheduler.create(
                cutlass_params,
                block_idx,
                grid_dim,
                clc_response_ptr,
            )
            return DecodeTileScheduler(
                params,
                block_idx[0],
                clc_scheduler,
                clc_response_ptr=clc_response_ptr,
                loc=loc,
                ip=ip,
            )

        if const_expr(params.cluster_shape_m == 1):
            task_idx = cute.arch.block_idx()[0]
        else:
            task_idx = cute.arch.cluster_idx()[0]
        return DecodeTileScheduler(params, task_idx, loc=loc, ip=ip)

    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        """Return the launch grid for the selected scheduling mode.

        CLC mode uses ``_clc_grid_shape``. Static mode launches a 1-D grid
        whose x extent is the smaller of (a) the multiprocessor count
        rounded down to a multiple of ``cluster_shape_m`` and (b)
        ``total_tasks * cluster_shape_m``.
        """
        if const_expr(params.scheduling_mode == SchedulingMode.CLC):
            return DecodeTileScheduler._clc_grid_shape(params)
        hardware_info = cutlass.utils.HardwareInfo()
        sm_count = hardware_info.get_device_multiprocessor_count()
        # Round down so the grid holds a whole number of clusters.
        max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m
        grid_x = cutlass.min(max_ctas, params.total_tasks * params.cluster_shape_m)
        return (grid_x, Int32(1), Int32(1))

    # Static mode: split a flat task id into (work_idx, head_kv_idx) with
    # the precomputed divisor; the two trailing coordinates are always 0.
    @cute.jit
    def _task_to_work(self, task_idx: Int32, is_valid) -> WorkTileInfo:
        work_idx, head_kv_idx = divmod(task_idx, self.params.num_heads_kv_divmod)
        return WorkTileInfo(
            (Int32(work_idx), Int32(head_kv_idx), Int32(0), Int32(0)),
            is_valid,
        )

    # CLC mode: convert a CLC tile coordinate into the logical
    # (work_idx, head_kv_idx, 0, 0) tuple. With clusters wider than one CTA
    # the x coordinate is divided by ``cluster_shape_m``.
    @cute.jit
    def _clc_work_to_coords(self, work) -> WorkTileInfo:
        work_idx = work.tile_idx[0]
        if const_expr(self.params.cluster_shape_m > 1):
            work_idx = work_idx // self.params.cluster_shape_m
        return WorkTileInfo(
            (
                Int32(work_idx),
                Int32(work.tile_idx[1]),
                Int32(0),
                Int32(0),
            ),
            work.is_valid_tile,
        )

    # CLC mode: decode the response stored in the slot for ``response_stage``
    # into a raw tile coordinate. The CTA's position inside its cluster
    # (x block index modulo ``cluster_shape_m``) is added to the returned
    # m index.
    @cute.jit
    def _clc_response_to_work(
        self,
        response_stage: Int32,
        *,
        loc=None,
        ip=None,
    ) -> WorkTileInfo:
        # CLC responses are 16B opaque records. The scheduler warp can query
        # the next stage before all consumer warps have read the current one,
        # so each pipeline stage needs its own response slot.
        # NOTE(review): the stride of 4 presumably means four 32-bit words
        # per 16B record -- confirm the element type of _clc_response_ptr.
        response_ptr = self._clc_response_ptr + response_stage * Int32(4)
        m_idx, n_idx, l_idx, is_valid = cute.arch.clc_response(
            response_ptr, loc=loc, ip=ip)
        # NOTE(review): presumably orders this read against a later
        # async-proxy write into the same slot -- confirm.
        cute.arch.fence_proxy("async.shared", space="cta")
        cta_idx_in_cluster = cute.arch.block_idx()[0] % Int32(
            self.params.cluster_shape_m)
        return WorkTileInfo(
            (
                Int32(m_idx) + cta_idx_in_cluster,
                Int32(n_idx),
                Int32(l_idx),
                Int32(0),
            ),
            is_valid,
        )

    # Return the work tile the scheduler currently points at.
    # CLC mode: decode the response slot for ``response_stage``, refresh
    # ``_task_idx`` from the raw coordinate, and return the logical coords.
    # Static mode: the tile is valid while ``_task_idx < total_tasks``;
    # ``response_stage`` is unused.
    @cute.jit
    def get_current_work(
        self,
        response_stage: Int32 = Int32(0),
        *,
        loc=None,
        ip=None,
    ) -> WorkTileInfo:
        if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
            work = self._clc_response_to_work(
                response_stage, loc=loc, ip=ip)
            self._task_idx = (
                work.tile_idx[0] * self.params.num_heads_kv
                + work.tile_idx[1]
            )
            return self._clc_work_to_coords(work)
        is_valid = self._task_idx < self.params.total_tasks
        return self._task_to_work(self._task_idx, is_valid)

    # First work tile for this CTA. CLC mode delegates to the wrapped CLC
    # scheduler; static mode is identical to ``get_current_work``.
    @cute.jit
    def initial_work_tile_info(self, *, loc=None, ip=None):
        if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
            work = self._clc_scheduler.initial_work_tile_info()
            self._task_idx = (
                work.tile_idx[0] * self.params.num_heads_kv
                + work.tile_idx[1]
            )
            return self._clc_work_to_coords(work)
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        """No-op in this scheduler."""
        pass

    def advance_to_next_work(
        self,
        *,
        loc=None,
        ip=None,
        mbarrier_addr=None,
        response_stage: Int32 = Int32(0),
    ):
        """Move to the next task.

        CLC mode requires ``mbarrier_addr`` and has one elected thread issue
        a CLC query that targets the response slot for ``response_stage``.
        Static mode requires ``mbarrier_addr`` to be ``None`` and bumps
        ``_task_idx`` by the grid x extent (``cluster_shape_m == 1``) or by
        ``cute.arch.cluster_dim()[0]`` otherwise.
        """
        if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
            assert mbarrier_addr is not None
            response_ptr = self._clc_response_ptr + response_stage * Int32(4)
            with cute.arch.elect_one():
                cute.arch.issue_clc_query(
                    mbarrier_addr, response_ptr, loc=loc, ip=ip)
        else:
            assert mbarrier_addr is None
            if const_expr(self.params.cluster_shape_m == 1):
                self._task_idx += cute.arch.grid_dim()[0]
            else:
                # NOTE(review): presumably the number of clusters along x,
                # matching the cluster_idx() start in ``create`` -- confirm.
                self._task_idx += cute.arch.cluster_dim()[0]

    def consumer_advance(self, *, loc=None, ip=None):
        """Advance a consumer and return its next work tile.

        CLC mode: wait on the CLC pipeline, decode the response of the
        current consumer stage, release the stage, then advance the consumer
        state. Static mode: stride ``_task_idx`` and return the resulting
        tile. ``loc`` and ``ip`` are accepted but not forwarded.
        """
        if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
            response_stage = self._clc_consumer_state.index
            self._clc_pipeline.consumer_wait(self._clc_consumer_state)
            work_tile = self.get_current_work(response_stage=response_stage)
            self._clc_pipeline.consumer_release(self._clc_consumer_state)
            self._clc_consumer_state.advance()
            return work_tile
        self.advance_to_next_work()
        return self.get_current_work()

    def set_clc_pipeline(self, clc_pipeline, clc_consumer_state):
        """Attach the CLC pipeline and consumer state used by ``consumer_advance``."""
        self._clc_pipeline = clc_pipeline
        self._clc_consumer_state = clc_consumer_state

    def producer_tail(self, *, loc=None, ip=None):
        """No-op in this scheduler."""
        pass

    def __extract_mlir_values__(self):
        """Flatten the scheduler state into a list of MLIR values.

        The number of values contributed by each member is recorded in
        ``self._values_pos`` so ``__new_from_mlir_values__`` can split the
        flat list again. The CLC members are included only in CLC mode.
        """
        values, self._values_pos = [], []
        objs = [self.params, self._task_idx]
        if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
            objs += [
                self._clc_scheduler,
                self._clc_pipeline,
                self._clc_consumer_state,
                self._clc_response_ptr,
            ]
        for obj in objs:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        """Rebuild a scheduler from a flat list of MLIR values.

        Members are listed in the same order as in
        ``__extract_mlir_values__`` and consume the counts stored in
        ``self._values_pos``. Only ``loc`` is forwarded to the new object;
        ``ip`` falls back to its default of ``None``.
        """
        obj_list = []
        objs = [self.params, self._task_idx]
        if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
            objs += [
                self._clc_scheduler,
                self._clc_pipeline,
                self._clc_consumer_state,
                self._clc_response_ptr,
            ]
        for obj, n_items in zip(objs, self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return DecodeTileScheduler(*obj_list, loc=self._loc)
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/prepare_k2q_csr.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """Sparse k2q CSR builder for SM100.
5
+
6
+ Thin dispatcher that calls the CUDA C++ kernel pipeline in
7
+ ``src.sm100.build_k2q_csr``. Supports ``topK in {4, 8, 16, 32}`` and
8
+ ``blk_kv == 128`` only — other shapes raise ``ValueError`` rather than
9
+ silently falling back to a torch-reference path.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import Optional
15
+
16
+ import torch
17
+
18
+ from ...src.sm100.prepare_scheduler import SparseAttentionSchedule, SPARSE_SCHEDULE_MODEL
19
+
20
+
21
+ _SUPPORTED_TOPK = (4, 8, 16, 32)
22
+ _SUPPORTED_BLK_KV = 128
23
+
24
+
25
def _ceil_div(x: int, y: int) -> int:
    """Return ``x`` divided by ``y``, rounded up for positive ``y``.

    Implemented as a floor division of ``x + y - 1`` so that only integer
    arithmetic is involved.
    """
    quotient, _remainder = divmod(x + y - 1, y)
    return quotient
27
+
28
+
29
class SparseK2qCsrBuilderSm100:
    """Callable that builds the k2q CSR reverse index for SM100 sparse attention.

    The call signature mirrors the historical CUTE DSL builder, so existing
    callers (``sparse_index_utils.build_k2q_csr``, attention kernels) keep
    working unchanged. The work itself is done by a CUDA C++ pipeline of
    five kernels plus two ``cudaMemsetAsync`` calls:
    ``build_row_map`` -> ``hist`` -> ``row_prefix`` -> ``tile_prefix_smem``
    -> ``scatter``.
    """

    def __init__(self) -> None:
        # Launchers of the compiled extension. They stay ``None`` until
        # ``_ensure_loaded`` imports ``src.sm100.build_k2q_csr`` on the
        # first non-trivial call.
        self._run = self._run_with_schedule = None

    def _ensure_loaded(self) -> None:
        """Bind the extension launchers on first use; later calls do nothing."""
        if self._run is not None:
            return
        from ...src.sm100.build_k2q_csr import (
            run_build_k2q_csr as launch_plain,
            run_build_k2q_csr_with_schedule as launch_scheduled,
        )
        self._run, self._run_with_schedule = launch_plain, launch_scheduled

    @staticmethod
    def _check_tensor_inputs(
        q2k_indices: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
    ) -> None:
        """Validate dtype, rank, layout and placement of the three input tensors.

        Raises:
            TypeError: if any tensor is not ``torch.int32``.
            ValueError: for a wrong rank, mismatched ``cu_seqlens`` shapes,
                non-contiguous tensors, non-CUDA tensors or mixed devices.
        """
        seqlens = (cu_seqlens_q, cu_seqlens_k)

        if q2k_indices.dtype != torch.int32:
            raise TypeError(
                f"q2k_indices must be torch.int32, got {q2k_indices.dtype}"
            )
        if q2k_indices.ndim != 3:
            raise ValueError(
                f"q2k_indices must be rank-3 [head_kv, total_q, topK], "
                f"got shape {tuple(q2k_indices.shape)}"
            )
        if not q2k_indices.is_contiguous():
            raise ValueError("q2k_indices must be contiguous")
        if any(t.dtype != torch.int32 for t in seqlens):
            raise TypeError("cu_seqlens_q and cu_seqlens_k must be torch.int32")
        if any(t.ndim != 1 for t in seqlens):
            raise ValueError("cu_seqlens_q and cu_seqlens_k must be rank-1")
        if cu_seqlens_q.shape != cu_seqlens_k.shape:
            raise ValueError(
                "cu_seqlens_q and cu_seqlens_k must share shape [B + 1]"
            )
        if not all(t.is_cuda for t in (q2k_indices, *seqlens)):
            raise ValueError("all inputs must be CUDA tensors")
        if any(q2k_indices.device != t.device for t in seqlens):
            raise ValueError("all inputs must share a device")
        if not all(t.is_contiguous() for t in seqlens):
            raise ValueError("cu_seqlens_q and cu_seqlens_k must be contiguous")

    def __call__(
        self,
        q2k_indices: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        *,
        total_k: int,
        blk_kv: int = 128,
        max_seqlen_k: Optional[int] = None,
        max_seqlen_q: Optional[int] = None,
        total_rows: Optional[int] = None,
        qhead_per_kv: int = 1,
        return_schedule: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, SparseAttentionSchedule]:
        """Build the k2q CSR arrays, optionally together with a schedule.

        Args:
            q2k_indices: contiguous int32 CUDA tensor of rank 3,
                ``[head_kv, total_q, topK]``; ``topK`` must be one of the
                supported values.
            cu_seqlens_q: contiguous rank-1 int32 CUDA tensor, ``[B + 1]``.
            cu_seqlens_k: same shape, dtype and device as ``cu_seqlens_q``.
            total_k: non-negative total number of key tokens.
            blk_kv: KV block size; only the supported value is accepted.
            max_seqlen_k: longest key sequence; falls back to ``total_k``
                when omitted. Mandatory with ``return_schedule=True``.
            max_seqlen_q: longest query sequence. Mandatory with
                ``return_schedule=True``.
            total_rows: number of CSR rows; derived from ``total_k``,
                the batch size and ``blk_kv`` when omitted.
            qhead_per_kv: positive number of query heads per KV head.
            return_schedule: also allocate and return a
                ``SparseAttentionSchedule``.

        Returns:
            ``(k2q_row_ptr, k2q_q_indices)`` with shapes
            ``(head_kv, total_rows + 1)`` and ``(head_kv, total_q * topK)``,
            followed by the schedule when ``return_schedule`` is true.

        Raises:
            TypeError: for tensors that are not ``torch.int32``.
            ValueError: for any other unsupported or inconsistent argument.
        """
        # ---- Argument checks (order determines which error surfaces) -------
        if blk_kv != _SUPPORTED_BLK_KV:
            raise ValueError(
                f"SparseK2qCsrBuilderSm100 only supports blk_kv == "
                f"{_SUPPORTED_BLK_KV}, got {blk_kv}"
            )
        self._check_tensor_inputs(q2k_indices, cu_seqlens_q, cu_seqlens_k)

        total_k = int(total_k)
        if total_k < 0:
            raise ValueError(f"total_k must be non-negative, got {total_k}")

        head_kv, total_q, topk = map(int, q2k_indices.shape)
        if topk not in _SUPPORTED_TOPK:
            raise ValueError(
                f"SparseK2qCsrBuilderSm100 only supports topK in "
                f"{_SUPPORTED_TOPK}, got {topk}"
            )

        batch = int(cu_seqlens_q.shape[0]) - 1
        if batch < 0:
            raise ValueError("cu_seqlens tensors must have shape [B + 1]")

        if max_seqlen_k is None:
            if return_schedule:
                raise ValueError("build_k2q_csr requires max_seqlen_k when return_schedule=True")
            longest_k = total_k
        else:
            longest_k = int(max_seqlen_k)
        # At least one KV block is always accounted for.
        max_kv_blocks = _ceil_div(max(longest_k, blk_kv), blk_kv)

        if total_rows is None:
            if total_k % blk_kv:
                # Ragged case: budget up to (blk_kv - 1) padding tokens per
                # batch entry before converting tokens to blocks.
                total_rows = _ceil_div(total_k + batch * (blk_kv - 1), blk_kv)
            else:
                total_rows = total_k // blk_kv
        else:
            total_rows = int(total_rows)
        if total_rows < 0:
            raise ValueError(f"total_rows must be non-negative, got {total_rows}")

        nnz_capacity = total_q * topk
        qhead_per_kv = int(qhead_per_kv)
        if qhead_per_kv <= 0:
            raise ValueError(f"qhead_per_kv must be positive, got {qhead_per_kv}")
        if return_schedule:
            if max_seqlen_q is None:
                raise ValueError("build_k2q_csr requires max_seqlen_q when return_schedule=True")
            max_seqlen_q = int(max_seqlen_q)

        # ---- Output buffers -------------------------------------------------
        dev = q2k_indices.device
        int32_on_dev = {"dtype": torch.int32, "device": dev}
        k2q_row_ptr = torch.empty((head_kv, total_rows + 1), **int32_on_dev)
        k2q_q_indices = torch.empty((head_kv, nnz_capacity), **int32_on_dev)

        schedule = None
        if return_schedule:
            target_q_per_cta = SPARSE_SCHEDULE_MODEL.balanced_target_q_per_cta(
                total_q=total_q,
                topk=topk,
                blk_kv=blk_kv,
                head_kv=head_kv,
                qhead_per_kv=qhead_per_kv,
                device=dev,
            )
            metadata_rows = SPARSE_SCHEDULE_MODEL.flat_schedule_capacity(
                total_rows=total_rows,
                total_q=total_q,
                topk=topk,
                head_kv=head_kv,
                target_q_per_cta=target_q_per_cta,
            )
            schedule = SparseAttentionSchedule(
                enabled=True,
                scheduler_metadata=torch.empty((metadata_rows, 6), **int32_on_dev),
                work_count=torch.empty((1,), **int32_on_dev),
                qsplit_indices=torch.empty_like(k2q_q_indices),
                split_counts=torch.empty((total_q, head_kv), **int32_on_dev),
                target_q_per_cta=target_q_per_cta,
            )

        # The returned tuple is the same for the trivial and the full path.
        outputs = (k2q_row_ptr, k2q_q_indices)
        if schedule is not None:
            outputs = (*outputs, schedule)

        # Trivial workloads are answered on the host side. The CUDA path
        # copes with them as well, but skipping it avoids loading the
        # extension for nothing.
        if 0 in (total_rows, total_q, head_kv, topk):
            k2q_row_ptr.zero_()
            k2q_q_indices.fill_(-1)
            if schedule is not None:
                schedule.work_count.zero_()
                schedule.split_counts.zero_()
            return outputs

        # ---- Kernel pipeline ------------------------------------------------
        self._ensure_loaded()
        with torch.cuda.nvtx.range("SparseK2qCsr_Pipeline"):
            if schedule is not None:
                self._run_with_schedule(
                    q2k_indices,
                    cu_seqlens_q,
                    cu_seqlens_k,
                    k2q_row_ptr,
                    k2q_q_indices,
                    schedule.scheduler_metadata,
                    schedule.work_count,
                    schedule.qsplit_indices,
                    schedule.split_counts,
                    topk,
                    blk_kv,
                    total_rows,
                    max_kv_blocks,
                    schedule.target_q_per_cta,
                    schedule.work_capacity,
                    max_seqlen_q,
                )
            else:
                self._run(
                    q2k_indices,
                    cu_seqlens_q,
                    cu_seqlens_k,
                    k2q_row_ptr,
                    k2q_q_indices,
                    topk,
                    blk_kv,
                    total_rows,
                    max_kv_blocks,
                )
        return outputs