Kernels:
Trusted publisher
Uploaded using `kernel-builder`.
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- build/torch211-cxx11-cu128-x86_64-linux/__init__.py +60 -0
- build/torch211-cxx11-cu128-x86_64-linux/_msa_cuda_09d7851.abi3.so +3 -0
- build/torch211-cxx11-cu128-x86_64-linux/_ops.py +9 -0
- build/torch211-cxx11-cu128-x86_64-linux/fp4_indexer_interface.py +1061 -0
- build/torch211-cxx11-cu128-x86_64-linux/interface.py +2011 -0
- build/torch211-cxx11-cu128-x86_64-linux/metadata.json +71 -0
- build/torch211-cxx11-cu128-x86_64-linux/metadata.json.sigstore +1 -0
- build/torch211-cxx11-cu128-x86_64-linux/msa/__init__.py +26 -0
- build/torch211-cxx11-cu128-x86_64-linux/quack/__init__.py +0 -0
- build/torch211-cxx11-cu128-x86_64-linux/quack/activation.py +532 -0
- build/torch211-cxx11-cu128-x86_64-linux/quack/compile_utils.py +19 -0
- build/torch211-cxx11-cu128-x86_64-linux/quack/copy_utils.py +890 -0
- build/torch211-cxx11-cu128-x86_64-linux/quack/cute_dsl_utils.py +104 -0
- build/torch211-cxx11-cu128-x86_64-linux/quack/layout_utils.py +295 -0
- build/torch211-cxx11-cu128-x86_64-linux/quantize.py +362 -0
- build/torch211-cxx11-cu128-x86_64-linux/sparse_index_utils.py +411 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/__init__.py +3 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/__init__.py +3 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/aot_cache.py +72 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/barrier.py +74 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/blackwell_helpers.py +1093 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/block_info.py +61 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/copy_utils.py +1179 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/cute_dsl_utils.py +190 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/fast_math.py +22 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/mask.py +189 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/mma_sm100_desc.py +304 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/named_barrier.py +22 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/pack_gqa.py +320 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/paged_kv.py +67 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/pipeline.py +372 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/seqlen_info.py +203 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/softmax.py +498 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/tile_scheduler.py +967 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/tma_utils.py +515 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/common/utils.py +1088 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/__init__.py +4 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/build_k2q_csr/__init__.py +103 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/decode_schedule.py +193 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fp4_indexer.py +1956 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/__init__.py +8 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd.py +0 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py +0 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/combine.py +1498 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/__init__.py +95 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py +0 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py +112 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/combine.py +680 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py +300 -0
- build/torch211-cxx11-cu128-x86_64-linux/src/sm100/prepare_k2q_csr.py +227 -0
build/torch211-cxx11-cu128-x86_64-linux/__init__.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""MiniMax Sparse Attention (MSA) CuTe-DSL kernels for NVIDIA SM100.
|
| 5 |
+
|
| 6 |
+
Hub-kernel packaging of the CuTe-DSL sparse attention stack from
|
| 7 |
+
https://github.com/MiniMax-AI/MSA (``python/fmha_sm100/cute``). The
|
| 8 |
+
host-side helper kernels (CSR builder, decode scheduler) are precompiled
|
| 9 |
+
Torch ops; the attention kernels are compiled at runtime through
|
| 10 |
+
nvidia-cutlass-dsl.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
# Sparse attention forward / decode.
|
| 14 |
+
from .interface import (
|
| 15 |
+
SparseDecodePagedAttentionWrapper,
|
| 16 |
+
sparse_atten_func,
|
| 17 |
+
sparse_atten_nvfp4_kv_func,
|
| 18 |
+
sparse_decode_atten_func,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# CSR + schedule construction.
|
| 22 |
+
from .sparse_index_utils import build_k2q_csr
|
| 23 |
+
|
| 24 |
+
# SM100 fused CSR builder.
|
| 25 |
+
from .src.sm100.prepare_k2q_csr import SparseK2qCsrBuilderSm100
|
| 26 |
+
|
| 27 |
+
# FP4 block-score indexer. Returns per-(Hq, kv_block, q) max scores; topK
|
| 28 |
+
# selection + q2k construction remain caller-owned downstream steps.
|
| 29 |
+
from .fp4_indexer_interface import fp4_indexer_block_scores
|
| 30 |
+
|
| 31 |
+
# NVFP4 quantization helpers used to feed the FP4 indexer / NVFP4 attention.
|
| 32 |
+
from .quantize import (
|
| 33 |
+
Nvfp4QuantizedTensor,
|
| 34 |
+
dequantize_nvfp4_128x4_to_bf16,
|
| 35 |
+
nvfp4_global_scale_from_amax,
|
| 36 |
+
quantize_bf16_to_nvfp4_128x4,
|
| 37 |
+
quantize_kv_bf16_to_nvfp4_128x4,
|
| 38 |
+
swizzle_nvfp4_scale_to_128x4,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
__version__ = "0.1.1"
|
| 42 |
+
|
| 43 |
+
__all__ = [
|
| 44 |
+
# attention
|
| 45 |
+
"sparse_atten_func",
|
| 46 |
+
"sparse_atten_nvfp4_kv_func",
|
| 47 |
+
"sparse_decode_atten_func",
|
| 48 |
+
"SparseDecodePagedAttentionWrapper",
|
| 49 |
+
# indexing / CSR
|
| 50 |
+
"fp4_indexer_block_scores",
|
| 51 |
+
"build_k2q_csr",
|
| 52 |
+
"SparseK2qCsrBuilderSm100",
|
| 53 |
+
# nvfp4 quantization helpers
|
| 54 |
+
"Nvfp4QuantizedTensor",
|
| 55 |
+
"quantize_bf16_to_nvfp4_128x4",
|
| 56 |
+
"quantize_kv_bf16_to_nvfp4_128x4",
|
| 57 |
+
"dequantize_nvfp4_128x4_to_bf16",
|
| 58 |
+
"swizzle_nvfp4_scale_to_128x4",
|
| 59 |
+
"nvfp4_global_scale_from_amax",
|
| 60 |
+
]
|
build/torch211-cxx11-cu128-x86_64-linux/_msa_cuda_09d7851.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8dcd8c86e512f3ddd5acb95f6fdcad3cfaa1579bb6f874a714fba066e6877161
|
| 3 |
+
size 1169368
|
build/torch211-cxx11-cu128-x86_64-linux/_ops.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from . import _msa_cuda_09d7851
|
| 3 |
+
ops = torch.ops._msa_cuda_09d7851
|
| 4 |
+
|
| 5 |
+
def add_op_namespace_prefix(op_name: str):
|
| 6 |
+
"""
|
| 7 |
+
Prefix op by namespace.
|
| 8 |
+
"""
|
| 9 |
+
return f"_msa_cuda_09d7851::{op_name}"
|
build/torch211-cxx11-cu128-x86_64-linux/fp4_indexer_interface.py
ADDED
|
@@ -0,0 +1,1061 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""Public FP4 sparse-attention indexer block-score interface."""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
import cuda.bindings.driver as cuda
|
| 11 |
+
import cutlass
|
| 12 |
+
import cutlass.cute as cute
|
| 13 |
+
import torch
|
| 14 |
+
from cutlass import Int32
|
| 15 |
+
from cutlass.cute.runtime import make_ptr
|
| 16 |
+
|
| 17 |
+
from .src.sm100.fp4_indexer import (
|
| 18 |
+
Fp4FormatSpec,
|
| 19 |
+
Fp4IndexerDecodePackedQSm100,
|
| 20 |
+
Fp4IndexerDecodeQPackSm100,
|
| 21 |
+
Fp4IndexerScaleReorderSm100,
|
| 22 |
+
Fp4IndexerStagedMmaSm100,
|
| 23 |
+
_BLOCK_K,
|
| 24 |
+
_DECODE_K_TILES_PER_CTA,
|
| 25 |
+
_DECODE_PACK_Q_LEN,
|
| 26 |
+
_DECODE_QHEAD_PER_KV,
|
| 27 |
+
_FP4_PACKED_D_BYTES,
|
| 28 |
+
_HEAD_DIM,
|
| 29 |
+
_MMA_TILER_MN,
|
| 30 |
+
_PAGE_SIZE,
|
| 31 |
+
ceil_div,
|
| 32 |
+
k_tiles_per_cta_for,
|
| 33 |
+
normalize_fp4_format,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
_PUBLIC_SCALE_LAYOUT = "public"
|
| 38 |
+
_PREORDERED_MMA_SCALE_LAYOUT = "preordered_mma"
|
| 39 |
+
_FP4_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _device_arch(device: torch.device) -> tuple[int, int]:
|
| 43 |
+
major, minor = torch.cuda.get_device_capability(device)
|
| 44 |
+
return int(major), int(minor)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _supports_tmem_load_red(device_arch: tuple[int, int]) -> bool:
|
| 48 |
+
return device_arch >= (10, 3)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def normalize_scale_layout(scale_layout: str) -> str:
|
| 52 |
+
"""Normalize and validate FP4 indexer scale layout mode.
|
| 53 |
+
|
| 54 |
+
Parameters
|
| 55 |
+
----------
|
| 56 |
+
scale_layout : str
|
| 57 |
+
Either ``"public"`` for logical scale tensors or ``"preordered_mma"``
|
| 58 |
+
for tensors already laid out with ``fp4_indexer_mma_scale_storage_*``.
|
| 59 |
+
|
| 60 |
+
Returns
|
| 61 |
+
-------
|
| 62 |
+
str
|
| 63 |
+
The normalized scale layout string.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
scale_layout = str(scale_layout)
|
| 67 |
+
if scale_layout not in (_PUBLIC_SCALE_LAYOUT, _PREORDERED_MMA_SCALE_LAYOUT):
|
| 68 |
+
raise ValueError(f"scale_layout must be 'public' or 'preordered_mma', got {scale_layout!r}")
|
| 69 |
+
return scale_layout
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _causal_compact_task_count(q_len: int, k_len: int, k_tiles_per_cta: int) -> int:
|
| 73 |
+
if q_len <= 0 or k_len <= 0:
|
| 74 |
+
return 0
|
| 75 |
+
q_tile_count = ceil_div(q_len, _MMA_TILER_MN[0])
|
| 76 |
+
k_group_count = ceil_div(ceil_div(k_len, _PAGE_SIZE), k_tiles_per_cta)
|
| 77 |
+
group_tokens = k_tiles_per_cta * _BLOCK_K
|
| 78 |
+
causal_offset = int(k_len) - int(q_len)
|
| 79 |
+
tasks = 0
|
| 80 |
+
for q_tile_idx in range(q_tile_count):
|
| 81 |
+
q_tile_start = q_tile_idx * _MMA_TILER_MN[0]
|
| 82 |
+
q_tile_last = min(q_tile_start + _MMA_TILER_MN[0] - 1, int(q_len) - 1)
|
| 83 |
+
visible_limit = q_tile_last + causal_offset
|
| 84 |
+
if visible_limit >= 0:
|
| 85 |
+
tasks += min(k_group_count, visible_limit // group_tokens + 1)
|
| 86 |
+
return tasks
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _causal_compact_task_bound(max_q_len: int, max_k_len: int, k_tiles_per_cta: int) -> int:
|
| 90 |
+
"""Conservative X-grid bound for per-batch causal prefill compact mapping."""
|
| 91 |
+
|
| 92 |
+
if max_q_len <= 0 or max_k_len <= 0:
|
| 93 |
+
return 0
|
| 94 |
+
q_tile_count = ceil_div(max_q_len, _MMA_TILER_MN[0])
|
| 95 |
+
candidates = {int(max_q_len)}
|
| 96 |
+
for q_tile_idx in range(q_tile_count):
|
| 97 |
+
q_len = q_tile_idx * _MMA_TILER_MN[0] + 1
|
| 98 |
+
if q_len <= max_q_len:
|
| 99 |
+
candidates.add(q_len)
|
| 100 |
+
return max(_causal_compact_task_count(q_len, max_k_len, k_tiles_per_cta) for q_len in candidates)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _require_cuda_tensor(tensor: torch.Tensor, *, name: str) -> None:
|
| 104 |
+
if not tensor.is_cuda:
|
| 105 |
+
raise ValueError(f"{name} must be a CUDA tensor")
|
| 106 |
+
if not tensor.is_contiguous():
|
| 107 |
+
raise ValueError(f"{name} must be contiguous")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _require_int32_vector(tensor: torch.Tensor, *, name: str, device: torch.device) -> None:
|
| 111 |
+
if tensor.device != device:
|
| 112 |
+
raise ValueError(f"{name} must be on the same CUDA device")
|
| 113 |
+
if tensor.dtype != torch.int32:
|
| 114 |
+
raise TypeError(f"{name} must be torch.int32")
|
| 115 |
+
if tensor.ndim != 1:
|
| 116 |
+
raise ValueError(f"{name} must be rank-1")
|
| 117 |
+
if not tensor.is_contiguous():
|
| 118 |
+
raise ValueError(f"{name} must be contiguous")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _require_fp4_packed_dtype(tensor: torch.Tensor, *, name: str) -> None:
|
| 122 |
+
fp4_x2_dtype = getattr(torch, "float4_e2m1fn_x2", None)
|
| 123 |
+
allowed = {torch.uint8, torch.int8}
|
| 124 |
+
if fp4_x2_dtype is not None:
|
| 125 |
+
allowed.add(fp4_x2_dtype)
|
| 126 |
+
if tensor.dtype not in allowed:
|
| 127 |
+
raise TypeError(f"{name} must use packed FP4 storage dtype, got {tensor.dtype}")
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _as_fp4_thd_bytes(tensor: torch.Tensor, *, name: str) -> torch.Tensor:
|
| 131 |
+
if tensor.ndim != 3:
|
| 132 |
+
raise ValueError(f"{name} must have shape [total_q, Hq, 64]")
|
| 133 |
+
if int(tensor.shape[-1]) != _FP4_PACKED_D_BYTES:
|
| 134 |
+
raise ValueError(f"{name}.shape[-1] must be 64 packed bytes for D=128")
|
| 135 |
+
_require_fp4_packed_dtype(tensor, name=name)
|
| 136 |
+
if tensor.dtype == torch.uint8:
|
| 137 |
+
return tensor
|
| 138 |
+
return tensor.view(torch.uint8)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _as_fp4_paged_hnd_bytes(tensor: torch.Tensor, *, name: str) -> torch.Tensor:
|
| 142 |
+
if tensor.ndim != 4:
|
| 143 |
+
raise ValueError(f"{name} must have shape [total_pages, Hk, 128, 64]")
|
| 144 |
+
if int(tensor.shape[-2]) != _PAGE_SIZE:
|
| 145 |
+
raise ValueError(f"{name}.shape[-2] must be 128")
|
| 146 |
+
if int(tensor.shape[-1]) != _FP4_PACKED_D_BYTES:
|
| 147 |
+
raise ValueError(f"{name}.shape[-1] must be 64 packed bytes for D=128")
|
| 148 |
+
_require_fp4_packed_dtype(tensor, name=name)
|
| 149 |
+
if tensor.dtype == torch.uint8:
|
| 150 |
+
return tensor
|
| 151 |
+
return tensor.view(torch.uint8)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def validate_q_scale_thg(
|
| 155 |
+
scale: torch.Tensor,
|
| 156 |
+
*,
|
| 157 |
+
name: str,
|
| 158 |
+
fmt: Fp4FormatSpec,
|
| 159 |
+
total_q: int,
|
| 160 |
+
heads: int,
|
| 161 |
+
) -> None:
|
| 162 |
+
"""Validate public Q FP4 scale layout ``[total_q, Hq, G]``.
|
| 163 |
+
|
| 164 |
+
Parameters
|
| 165 |
+
----------
|
| 166 |
+
scale : torch.Tensor
|
| 167 |
+
Logical Q scale tensor.
|
| 168 |
+
name : str
|
| 169 |
+
Name used in validation error messages.
|
| 170 |
+
fmt : Fp4FormatSpec
|
| 171 |
+
FP4 format specification from ``normalize_fp4_format``.
|
| 172 |
+
total_q : int
|
| 173 |
+
Total query token count.
|
| 174 |
+
heads : int
|
| 175 |
+
Number of Q heads.
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
expected = (int(total_q), int(heads), fmt.scale_groups)
|
| 179 |
+
if tuple(scale.shape) != expected:
|
| 180 |
+
raise ValueError(f"{name} must have shape {expected}, got {tuple(scale.shape)}")
|
| 181 |
+
if scale.dtype != fmt.torch_scale_dtype:
|
| 182 |
+
raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}")
|
| 183 |
+
if not scale.is_contiguous():
|
| 184 |
+
raise ValueError(f"{name} must be contiguous")
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def validate_k_scale_phsg(
|
| 188 |
+
scale: torch.Tensor,
|
| 189 |
+
*,
|
| 190 |
+
name: str,
|
| 191 |
+
fmt: Fp4FormatSpec,
|
| 192 |
+
page_count: int,
|
| 193 |
+
heads: int,
|
| 194 |
+
) -> None:
|
| 195 |
+
"""Validate public K FP4 scale layout ``[page_count, Hk, 128, G]``.
|
| 196 |
+
|
| 197 |
+
Parameters
|
| 198 |
+
----------
|
| 199 |
+
scale : torch.Tensor
|
| 200 |
+
Logical K scale tensor.
|
| 201 |
+
name : str
|
| 202 |
+
Name used in validation error messages.
|
| 203 |
+
fmt : Fp4FormatSpec
|
| 204 |
+
FP4 format specification from ``normalize_fp4_format``.
|
| 205 |
+
page_count : int
|
| 206 |
+
Number of physical KV pages.
|
| 207 |
+
heads : int
|
| 208 |
+
Number of KV heads.
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
expected = (int(page_count), int(heads), _PAGE_SIZE, fmt.scale_groups)
|
| 212 |
+
if tuple(scale.shape) != expected:
|
| 213 |
+
raise ValueError(f"{name} must have shape {expected}, got {tuple(scale.shape)}")
|
| 214 |
+
if scale.dtype != fmt.torch_scale_dtype:
|
| 215 |
+
raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}")
|
| 216 |
+
if not scale.is_contiguous():
|
| 217 |
+
raise ValueError(f"{name} must be contiguous")
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def fp4_indexer_mma_scale_shape(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]:
|
| 221 |
+
"""Return semantic MMA scale view shape ``(32,4,restM,4,restG,L)``."""
|
| 222 |
+
|
| 223 |
+
spec = normalize_fp4_format(fp4_format)
|
| 224 |
+
return (32, 4, ceil_div(mn, 128), 4, ceil_div(spec.scale_groups, 4), int(l))
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def fp4_indexer_mma_scale_stride(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]:
|
| 228 |
+
"""Return element strides for ``fp4_indexer_mma_scale_shape``."""
|
| 229 |
+
|
| 230 |
+
spec = normalize_fp4_format(fp4_format)
|
| 231 |
+
rest_m = ceil_div(mn, 128)
|
| 232 |
+
rest_g = ceil_div(spec.scale_groups, 4)
|
| 233 |
+
return (16, 4, 512 * rest_g, 1, 512, 512 * rest_m * rest_g)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def fp4_indexer_mma_scale_storage_shape(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]:
|
| 237 |
+
"""Return contiguous storage shape for preordered MMA scales."""
|
| 238 |
+
|
| 239 |
+
spec = normalize_fp4_format(fp4_format)
|
| 240 |
+
return (int(l), ceil_div(mn, 128), ceil_div(spec.scale_groups, 4), 32, 4, 4)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def fp4_indexer_mma_scale_storage_stride(mn: int, l: int, *, fp4_format: str) -> tuple[int, int, int, int, int, int]:
|
| 244 |
+
"""Return element strides for ``fp4_indexer_mma_scale_storage_shape``."""
|
| 245 |
+
|
| 246 |
+
spec = normalize_fp4_format(fp4_format)
|
| 247 |
+
rest_m = ceil_div(mn, 128)
|
| 248 |
+
rest_g = ceil_div(spec.scale_groups, 4)
|
| 249 |
+
return (512 * rest_m * rest_g, 512 * rest_g, 512, 16, 4, 1)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def validate_mma_scale_storage(
|
| 253 |
+
scale: torch.Tensor,
|
| 254 |
+
*,
|
| 255 |
+
name: str,
|
| 256 |
+
fmt: Fp4FormatSpec,
|
| 257 |
+
mn: int,
|
| 258 |
+
l: int,
|
| 259 |
+
) -> None:
|
| 260 |
+
"""Validate preordered MMA scale storage expected by the FP4 indexer.
|
| 261 |
+
|
| 262 |
+
Parameters
|
| 263 |
+
----------
|
| 264 |
+
scale : torch.Tensor
|
| 265 |
+
Tensor view whose shape/stride should match
|
| 266 |
+
``fp4_indexer_mma_scale_storage_shape`` and
|
| 267 |
+
``fp4_indexer_mma_scale_storage_stride``.
|
| 268 |
+
name : str
|
| 269 |
+
Name used in validation error messages.
|
| 270 |
+
fmt : Fp4FormatSpec
|
| 271 |
+
FP4 format specification from ``normalize_fp4_format``.
|
| 272 |
+
mn : int
|
| 273 |
+
Logical M/N extent of the scale domain.
|
| 274 |
+
l : int
|
| 275 |
+
Logical batch/head extent folded into the final layout dimension.
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
expected_shape = fp4_indexer_mma_scale_storage_shape(mn, l, fp4_format=fmt.name)
|
| 279 |
+
expected_stride = fp4_indexer_mma_scale_storage_stride(mn, l, fp4_format=fmt.name)
|
| 280 |
+
if tuple(scale.shape) != expected_shape:
|
| 281 |
+
raise ValueError(f"{name} must have MMA storage shape {expected_shape}, got {tuple(scale.shape)}")
|
| 282 |
+
if tuple(scale.stride()) != expected_stride:
|
| 283 |
+
raise ValueError(f"{name} must have MMA storage stride {expected_stride}, got {tuple(scale.stride())}")
|
| 284 |
+
if scale.dtype != fmt.torch_scale_dtype:
|
| 285 |
+
raise TypeError(f"{name} must have dtype {fmt.torch_scale_dtype}, got {scale.dtype}")
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _empty_mma_scale_tensor(
|
| 289 |
+
*,
|
| 290 |
+
mn: int,
|
| 291 |
+
l: int,
|
| 292 |
+
spec: Fp4FormatSpec,
|
| 293 |
+
device: torch.device,
|
| 294 |
+
) -> torch.Tensor:
|
| 295 |
+
rest_m = ceil_div(mn, 128)
|
| 296 |
+
rest_g = ceil_div(spec.scale_groups, 4)
|
| 297 |
+
storage = torch.empty(
|
| 298 |
+
(int(l), rest_m, rest_g, 32, 4, 4),
|
| 299 |
+
dtype=spec.torch_scale_dtype,
|
| 300 |
+
device=device,
|
| 301 |
+
)
|
| 302 |
+
return storage.permute(3, 4, 1, 5, 2, 0)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _compile_fp4_scale_reorder_kernel(
|
| 306 |
+
*,
|
| 307 |
+
fmt: Fp4FormatSpec,
|
| 308 |
+
q_scale_ptr: cute.Pointer,
|
| 309 |
+
k_scale_ptr: cute.Pointer,
|
| 310 |
+
q_scale_mma_ptr: cute.Pointer,
|
| 311 |
+
k_scale_mma_ptr: cute.Pointer,
|
| 312 |
+
problem_size: tuple,
|
| 313 |
+
stream: cuda.CUstream,
|
| 314 |
+
):
|
| 315 |
+
key = (
|
| 316 |
+
"fp4_indexer_scale_reorder_sm100_1cta",
|
| 317 |
+
fmt.name,
|
| 318 |
+
)
|
| 319 |
+
if key not in _FP4_COMPILE_CACHE:
|
| 320 |
+
kernel = Fp4IndexerScaleReorderSm100(fmt=fmt.name)
|
| 321 |
+
_FP4_COMPILE_CACHE[key] = cute.compile(
|
| 322 |
+
kernel,
|
| 323 |
+
q_scale_ptr,
|
| 324 |
+
k_scale_ptr,
|
| 325 |
+
q_scale_mma_ptr,
|
| 326 |
+
k_scale_mma_ptr,
|
| 327 |
+
problem_size,
|
| 328 |
+
stream,
|
| 329 |
+
)
|
| 330 |
+
return _FP4_COMPILE_CACHE[key]
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def fp4_indexer_reorder_scales_for_mma_cute(
|
| 334 |
+
q_scale: torch.Tensor,
|
| 335 |
+
k_scale: torch.Tensor,
|
| 336 |
+
*,
|
| 337 |
+
fp4_format: str,
|
| 338 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 339 |
+
"""Reorder public Q/K FP4 scales to MMA-friendly storage.
|
| 340 |
+
|
| 341 |
+
Parameters
|
| 342 |
+
----------
|
| 343 |
+
q_scale : torch.Tensor
|
| 344 |
+
Public Q scale tensor with shape ``[total_q, Hq, G]``.
|
| 345 |
+
k_scale : torch.Tensor
|
| 346 |
+
Public K scale tensor with shape ``[page_count, Hk, 128, G]``.
|
| 347 |
+
fp4_format : str
|
| 348 |
+
``"mxfp4"`` or ``"nvfp4"``.
|
| 349 |
+
|
| 350 |
+
Returns
|
| 351 |
+
-------
|
| 352 |
+
tuple[torch.Tensor, torch.Tensor]
|
| 353 |
+
``(q_scale_mma, k_scale_mma)`` views in the storage layout validated by
|
| 354 |
+
``validate_mma_scale_storage``. These tensors can be passed to
|
| 355 |
+
``fp4_indexer_block_scores`` with ``scale_layout="preordered_mma"``.
|
| 356 |
+
"""
|
| 357 |
+
|
| 358 |
+
spec = normalize_fp4_format(fp4_format)
|
| 359 |
+
if q_scale.device != k_scale.device:
|
| 360 |
+
raise ValueError("q_scale and k_scale must be on the same CUDA device")
|
| 361 |
+
_require_cuda_tensor(q_scale, name="q_scale")
|
| 362 |
+
_require_cuda_tensor(k_scale, name="k_scale")
|
| 363 |
+
if q_scale.ndim != 3:
|
| 364 |
+
raise ValueError(f"q_scale must have shape [total_q, Hq, G], got {tuple(q_scale.shape)}")
|
| 365 |
+
if k_scale.ndim != 4:
|
| 366 |
+
raise ValueError(f"k_scale must have shape [page_count, Hk, 128, G], got {tuple(k_scale.shape)}")
|
| 367 |
+
total_q, heads_q, _ = (int(v) for v in q_scale.shape)
|
| 368 |
+
page_count, heads_k, _, _ = (int(v) for v in k_scale.shape)
|
| 369 |
+
validate_q_scale_thg(q_scale, name="q_scale", fmt=spec, total_q=total_q, heads=heads_q)
|
| 370 |
+
validate_k_scale_phsg(k_scale, name="k_scale", fmt=spec, page_count=page_count, heads=heads_k)
|
| 371 |
+
|
| 372 |
+
q_scale_mma = _empty_mma_scale_tensor(
|
| 373 |
+
mn=total_q,
|
| 374 |
+
l=heads_q,
|
| 375 |
+
spec=spec,
|
| 376 |
+
device=q_scale.device,
|
| 377 |
+
)
|
| 378 |
+
k_scale_mma = _empty_mma_scale_tensor(
|
| 379 |
+
mn=_PAGE_SIZE,
|
| 380 |
+
l=page_count * heads_k,
|
| 381 |
+
spec=spec,
|
| 382 |
+
device=k_scale.device,
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
q_scale_ptr = make_ptr(
|
| 386 |
+
spec.cutlass_scale_dtype,
|
| 387 |
+
q_scale.data_ptr(),
|
| 388 |
+
cute.AddressSpace.gmem,
|
| 389 |
+
assumed_align=16,
|
| 390 |
+
)
|
| 391 |
+
k_scale_ptr = make_ptr(
|
| 392 |
+
spec.cutlass_scale_dtype,
|
| 393 |
+
k_scale.data_ptr(),
|
| 394 |
+
cute.AddressSpace.gmem,
|
| 395 |
+
assumed_align=16,
|
| 396 |
+
)
|
| 397 |
+
q_scale_mma_ptr = make_ptr(
|
| 398 |
+
spec.cutlass_scale_dtype,
|
| 399 |
+
q_scale_mma.data_ptr(),
|
| 400 |
+
cute.AddressSpace.gmem,
|
| 401 |
+
assumed_align=32,
|
| 402 |
+
)
|
| 403 |
+
k_scale_mma_ptr = make_ptr(
|
| 404 |
+
spec.cutlass_scale_dtype,
|
| 405 |
+
k_scale_mma.data_ptr(),
|
| 406 |
+
cute.AddressSpace.gmem,
|
| 407 |
+
assumed_align=32,
|
| 408 |
+
)
|
| 409 |
+
problem_size = (
|
| 410 |
+
Int32(total_q),
|
| 411 |
+
Int32(heads_q),
|
| 412 |
+
Int32(page_count),
|
| 413 |
+
Int32(heads_k),
|
| 414 |
+
)
|
| 415 |
+
stream = cuda.CUstream(torch.cuda.current_stream(q_scale.device).cuda_stream)
|
| 416 |
+
compiled = _compile_fp4_scale_reorder_kernel(
|
| 417 |
+
fmt=spec,
|
| 418 |
+
q_scale_ptr=q_scale_ptr,
|
| 419 |
+
k_scale_ptr=k_scale_ptr,
|
| 420 |
+
q_scale_mma_ptr=q_scale_mma_ptr,
|
| 421 |
+
k_scale_mma_ptr=k_scale_mma_ptr,
|
| 422 |
+
problem_size=problem_size,
|
| 423 |
+
stream=stream,
|
| 424 |
+
)
|
| 425 |
+
compiled(
|
| 426 |
+
q_scale_ptr,
|
| 427 |
+
k_scale_ptr,
|
| 428 |
+
q_scale_mma_ptr,
|
| 429 |
+
k_scale_mma_ptr,
|
| 430 |
+
problem_size,
|
| 431 |
+
stream,
|
| 432 |
+
)
|
| 433 |
+
return q_scale_mma, k_scale_mma
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def _compile_fp4_decode_q_pack_kernel(
|
| 437 |
+
*,
|
| 438 |
+
fmt: Fp4FormatSpec,
|
| 439 |
+
q_ptr: cute.Pointer,
|
| 440 |
+
q_scale_ptr: cute.Pointer,
|
| 441 |
+
q_pack_ptr: cute.Pointer,
|
| 442 |
+
q_scale_pack_ptr: cute.Pointer,
|
| 443 |
+
cu_seqlens_q_ptr: cute.Pointer,
|
| 444 |
+
problem_size: tuple,
|
| 445 |
+
stream: cuda.CUstream,
|
| 446 |
+
):
|
| 447 |
+
key = (
|
| 448 |
+
"fp4_indexer_decode_q_pack_sm100",
|
| 449 |
+
fmt.name,
|
| 450 |
+
)
|
| 451 |
+
if key not in _FP4_COMPILE_CACHE:
|
| 452 |
+
kernel = Fp4IndexerDecodeQPackSm100(fmt=fmt.name)
|
| 453 |
+
_FP4_COMPILE_CACHE[key] = cute.compile(
|
| 454 |
+
kernel,
|
| 455 |
+
q_ptr,
|
| 456 |
+
q_scale_ptr,
|
| 457 |
+
q_pack_ptr,
|
| 458 |
+
q_scale_pack_ptr,
|
| 459 |
+
cu_seqlens_q_ptr,
|
| 460 |
+
problem_size,
|
| 461 |
+
stream,
|
| 462 |
+
)
|
| 463 |
+
return _FP4_COMPILE_CACHE[key]
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def _pack_decode_q_for_mma(
|
| 467 |
+
q_bytes: torch.Tensor,
|
| 468 |
+
q_scale_storage: torch.Tensor,
|
| 469 |
+
cu_seqlens_q: torch.Tensor,
|
| 470 |
+
*,
|
| 471 |
+
fmt: Fp4FormatSpec,
|
| 472 |
+
heads_q: int,
|
| 473 |
+
heads_k: int,
|
| 474 |
+
batch: int,
|
| 475 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 476 |
+
q_pack = torch.empty(
|
| 477 |
+
(batch * heads_k, _PAGE_SIZE, _FP4_PACKED_D_BYTES),
|
| 478 |
+
dtype=torch.uint8,
|
| 479 |
+
device=q_bytes.device,
|
| 480 |
+
)
|
| 481 |
+
q_scale_pack = torch.empty(
|
| 482 |
+
fp4_indexer_mma_scale_storage_shape(_PAGE_SIZE, batch * heads_k, fp4_format=fmt.name),
|
| 483 |
+
dtype=fmt.torch_scale_dtype,
|
| 484 |
+
device=q_bytes.device,
|
| 485 |
+
)
|
| 486 |
+
if q_pack.data_ptr() % 128 != 0:
|
| 487 |
+
raise ValueError("internal decode q_pack data pointer must be 128B aligned for TMA")
|
| 488 |
+
if q_scale_pack.data_ptr() % 32 != 0:
|
| 489 |
+
raise ValueError("internal decode q_scale_pack data pointer must be 32B aligned")
|
| 490 |
+
q_ptr = make_ptr(cutlass.Uint8, q_bytes.data_ptr(), cute.AddressSpace.gmem, assumed_align=128)
|
| 491 |
+
q_scale_ptr = make_ptr(
|
| 492 |
+
fmt.cutlass_scale_dtype,
|
| 493 |
+
q_scale_storage.data_ptr(),
|
| 494 |
+
cute.AddressSpace.gmem,
|
| 495 |
+
assumed_align=32,
|
| 496 |
+
)
|
| 497 |
+
q_pack_ptr = make_ptr(cutlass.Uint8, q_pack.data_ptr(), cute.AddressSpace.gmem, assumed_align=128)
|
| 498 |
+
q_scale_pack_ptr = make_ptr(
|
| 499 |
+
fmt.cutlass_scale_dtype,
|
| 500 |
+
q_scale_pack.data_ptr(),
|
| 501 |
+
cute.AddressSpace.gmem,
|
| 502 |
+
assumed_align=32,
|
| 503 |
+
)
|
| 504 |
+
cu_seqlens_q_ptr = make_ptr(
|
| 505 |
+
cutlass.Int32,
|
| 506 |
+
cu_seqlens_q.data_ptr(),
|
| 507 |
+
cute.AddressSpace.gmem,
|
| 508 |
+
assumed_align=4,
|
| 509 |
+
)
|
| 510 |
+
problem_size = (
|
| 511 |
+
Int32(q_bytes.shape[0]),
|
| 512 |
+
Int32(heads_q),
|
| 513 |
+
Int32(heads_k),
|
| 514 |
+
Int32(batch),
|
| 515 |
+
)
|
| 516 |
+
stream = cuda.CUstream(torch.cuda.current_stream(q_bytes.device).cuda_stream)
|
| 517 |
+
compiled = _compile_fp4_decode_q_pack_kernel(
|
| 518 |
+
fmt=fmt,
|
| 519 |
+
q_ptr=q_ptr,
|
| 520 |
+
q_scale_ptr=q_scale_ptr,
|
| 521 |
+
q_pack_ptr=q_pack_ptr,
|
| 522 |
+
q_scale_pack_ptr=q_scale_pack_ptr,
|
| 523 |
+
cu_seqlens_q_ptr=cu_seqlens_q_ptr,
|
| 524 |
+
problem_size=problem_size,
|
| 525 |
+
stream=stream,
|
| 526 |
+
)
|
| 527 |
+
compiled(
|
| 528 |
+
q_ptr,
|
| 529 |
+
q_scale_ptr,
|
| 530 |
+
q_pack_ptr,
|
| 531 |
+
q_scale_pack_ptr,
|
| 532 |
+
cu_seqlens_q_ptr,
|
| 533 |
+
problem_size,
|
| 534 |
+
stream,
|
| 535 |
+
)
|
| 536 |
+
return q_pack, q_scale_pack
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def _compile_fp4_decode_packed_q_kernel(
|
| 540 |
+
*,
|
| 541 |
+
fmt: Fp4FormatSpec,
|
| 542 |
+
causal: bool,
|
| 543 |
+
compact_schedule: bool,
|
| 544 |
+
device_arch: tuple[int, int],
|
| 545 |
+
use_tmem_load_red: bool,
|
| 546 |
+
q_pack_ptr: cute.Pointer,
|
| 547 |
+
k_ptr: cute.Pointer,
|
| 548 |
+
q_scale_pack_ptr: cute.Pointer,
|
| 549 |
+
k_scale_ptr: cute.Pointer,
|
| 550 |
+
scores_ptr: cute.Pointer,
|
| 551 |
+
kv_indices_ptr: cute.Pointer,
|
| 552 |
+
cu_seqlens_q_ptr: cute.Pointer,
|
| 553 |
+
cu_seqlens_k_ptr: cute.Pointer,
|
| 554 |
+
cu_page_offsets_ptr: cute.Pointer,
|
| 555 |
+
qo_offset_ptr: cute.Pointer,
|
| 556 |
+
problem_size: tuple,
|
| 557 |
+
stream: cuda.CUstream,
|
| 558 |
+
):
|
| 559 |
+
key = (
|
| 560 |
+
"fp4_indexer_decode_packed_q_sm100",
|
| 561 |
+
fmt.name,
|
| 562 |
+
bool(causal),
|
| 563 |
+
bool(compact_schedule),
|
| 564 |
+
device_arch,
|
| 565 |
+
)
|
| 566 |
+
if key not in _FP4_COMPILE_CACHE:
|
| 567 |
+
kernel = Fp4IndexerDecodePackedQSm100(
|
| 568 |
+
fmt=fmt.name,
|
| 569 |
+
causal=causal,
|
| 570 |
+
compact_schedule=compact_schedule,
|
| 571 |
+
use_tmem_load_red=use_tmem_load_red,
|
| 572 |
+
)
|
| 573 |
+
_FP4_COMPILE_CACHE[key] = cute.compile(
|
| 574 |
+
kernel,
|
| 575 |
+
q_pack_ptr,
|
| 576 |
+
k_ptr,
|
| 577 |
+
q_scale_pack_ptr,
|
| 578 |
+
k_scale_ptr,
|
| 579 |
+
scores_ptr,
|
| 580 |
+
kv_indices_ptr,
|
| 581 |
+
cu_seqlens_q_ptr,
|
| 582 |
+
cu_seqlens_k_ptr,
|
| 583 |
+
cu_page_offsets_ptr,
|
| 584 |
+
qo_offset_ptr,
|
| 585 |
+
problem_size,
|
| 586 |
+
stream,
|
| 587 |
+
)
|
| 588 |
+
return _FP4_COMPILE_CACHE[key]
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
def _run_fp4_decode_packed_q_scores(
|
| 592 |
+
q_pack: torch.Tensor,
|
| 593 |
+
k_bytes: torch.Tensor,
|
| 594 |
+
q_scale_pack: torch.Tensor,
|
| 595 |
+
k_scale_storage: torch.Tensor,
|
| 596 |
+
scores: torch.Tensor,
|
| 597 |
+
kv_indices: torch.Tensor,
|
| 598 |
+
cu_seqlens_q: torch.Tensor,
|
| 599 |
+
cu_seqlens_k: torch.Tensor,
|
| 600 |
+
cu_page_offsets: torch.Tensor,
|
| 601 |
+
qo_offset_arg: torch.Tensor,
|
| 602 |
+
*,
|
| 603 |
+
fmt: Fp4FormatSpec,
|
| 604 |
+
causal: bool,
|
| 605 |
+
has_qo_offset: int,
|
| 606 |
+
heads_q: int,
|
| 607 |
+
heads_k: int,
|
| 608 |
+
batch: int,
|
| 609 |
+
max_k_tiles: int,
|
| 610 |
+
total_q: int,
|
| 611 |
+
device_arch: tuple[int, int],
|
| 612 |
+
use_tmem_load_red: bool,
|
| 613 |
+
) -> None:
|
| 614 |
+
page_count = int(k_bytes.shape[0])
|
| 615 |
+
rectangular_groups = batch * ceil_div(max_k_tiles, _DECODE_K_TILES_PER_CTA)
|
| 616 |
+
compact_groups = ceil_div(page_count + batch * (_DECODE_K_TILES_PER_CTA - 1), _DECODE_K_TILES_PER_CTA)
|
| 617 |
+
compact_schedule = compact_groups < rectangular_groups
|
| 618 |
+
if compact_schedule:
|
| 619 |
+
scores.fill_(float("-inf"))
|
| 620 |
+
|
| 621 |
+
q_pack_ptr = make_ptr(cutlass.Uint8, q_pack.data_ptr(), cute.AddressSpace.gmem, assumed_align=128)
|
| 622 |
+
k_ptr = make_ptr(cutlass.Uint8, k_bytes.data_ptr(), cute.AddressSpace.gmem, assumed_align=128)
|
| 623 |
+
q_scale_pack_ptr = make_ptr(
|
| 624 |
+
fmt.cutlass_scale_dtype,
|
| 625 |
+
q_scale_pack.data_ptr(),
|
| 626 |
+
cute.AddressSpace.gmem,
|
| 627 |
+
assumed_align=32,
|
| 628 |
+
)
|
| 629 |
+
k_scale_ptr = make_ptr(
|
| 630 |
+
fmt.cutlass_scale_dtype,
|
| 631 |
+
k_scale_storage.data_ptr(),
|
| 632 |
+
cute.AddressSpace.gmem,
|
| 633 |
+
assumed_align=32,
|
| 634 |
+
)
|
| 635 |
+
scores_ptr = make_ptr(cutlass.Float32, scores.data_ptr(), cute.AddressSpace.gmem, assumed_align=16)
|
| 636 |
+
kv_indices_ptr = make_ptr(cutlass.Int32, kv_indices.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
|
| 637 |
+
cu_seqlens_q_ptr = make_ptr(cutlass.Int32, cu_seqlens_q.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
|
| 638 |
+
cu_seqlens_k_ptr = make_ptr(cutlass.Int32, cu_seqlens_k.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
|
| 639 |
+
cu_page_offsets_ptr = make_ptr(cutlass.Int32, cu_page_offsets.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
|
| 640 |
+
qo_offset_ptr = make_ptr(cutlass.Int32, qo_offset_arg.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
|
| 641 |
+
problem_size = (
|
| 642 |
+
Int32(_PAGE_SIZE),
|
| 643 |
+
Int32(max_k_tiles * _PAGE_SIZE),
|
| 644 |
+
Int32(_HEAD_DIM),
|
| 645 |
+
Int32(batch * heads_k),
|
| 646 |
+
Int32(page_count * heads_k),
|
| 647 |
+
Int32(heads_q),
|
| 648 |
+
Int32(heads_k),
|
| 649 |
+
Int32(batch),
|
| 650 |
+
Int32(max_k_tiles),
|
| 651 |
+
Int32(total_q),
|
| 652 |
+
Int32(has_qo_offset),
|
| 653 |
+
)
|
| 654 |
+
stream = cuda.CUstream(torch.cuda.current_stream(q_pack.device).cuda_stream)
|
| 655 |
+
compiled = _compile_fp4_decode_packed_q_kernel(
|
| 656 |
+
fmt=fmt,
|
| 657 |
+
causal=causal,
|
| 658 |
+
compact_schedule=compact_schedule,
|
| 659 |
+
device_arch=device_arch,
|
| 660 |
+
use_tmem_load_red=use_tmem_load_red,
|
| 661 |
+
q_pack_ptr=q_pack_ptr,
|
| 662 |
+
k_ptr=k_ptr,
|
| 663 |
+
q_scale_pack_ptr=q_scale_pack_ptr,
|
| 664 |
+
k_scale_ptr=k_scale_ptr,
|
| 665 |
+
scores_ptr=scores_ptr,
|
| 666 |
+
kv_indices_ptr=kv_indices_ptr,
|
| 667 |
+
cu_seqlens_q_ptr=cu_seqlens_q_ptr,
|
| 668 |
+
cu_seqlens_k_ptr=cu_seqlens_k_ptr,
|
| 669 |
+
cu_page_offsets_ptr=cu_page_offsets_ptr,
|
| 670 |
+
qo_offset_ptr=qo_offset_ptr,
|
| 671 |
+
problem_size=problem_size,
|
| 672 |
+
stream=stream,
|
| 673 |
+
)
|
| 674 |
+
compiled(
|
| 675 |
+
q_pack_ptr,
|
| 676 |
+
k_ptr,
|
| 677 |
+
q_scale_pack_ptr,
|
| 678 |
+
k_scale_ptr,
|
| 679 |
+
scores_ptr,
|
| 680 |
+
kv_indices_ptr,
|
| 681 |
+
cu_seqlens_q_ptr,
|
| 682 |
+
cu_seqlens_k_ptr,
|
| 683 |
+
cu_page_offsets_ptr,
|
| 684 |
+
qo_offset_ptr,
|
| 685 |
+
problem_size,
|
| 686 |
+
stream,
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
def _compile_fp4_qk_kernel(
|
| 691 |
+
*,
|
| 692 |
+
fmt: Fp4FormatSpec,
|
| 693 |
+
causal: bool,
|
| 694 |
+
preordered_q_scale_tma: bool,
|
| 695 |
+
compact_schedule: bool,
|
| 696 |
+
device_arch: tuple[int, int],
|
| 697 |
+
use_tmem_load_red: bool,
|
| 698 |
+
q_ptr: cute.Pointer,
|
| 699 |
+
k_ptr: cute.Pointer,
|
| 700 |
+
q_scale_ptr: cute.Pointer,
|
| 701 |
+
k_scale_ptr: cute.Pointer,
|
| 702 |
+
scores_ptr: cute.Pointer,
|
| 703 |
+
kv_indices_ptr: cute.Pointer,
|
| 704 |
+
cu_seqlens_q_ptr: cute.Pointer,
|
| 705 |
+
cu_seqlens_k_ptr: cute.Pointer,
|
| 706 |
+
cu_page_offsets_ptr: cute.Pointer,
|
| 707 |
+
qo_offset_ptr: cute.Pointer,
|
| 708 |
+
problem_size: tuple,
|
| 709 |
+
stream: cuda.CUstream,
|
| 710 |
+
):
|
| 711 |
+
key = (
|
| 712 |
+
"fp4_indexer_staged_mma_sm100",
|
| 713 |
+
fmt.name,
|
| 714 |
+
bool(causal),
|
| 715 |
+
bool(preordered_q_scale_tma),
|
| 716 |
+
bool(compact_schedule),
|
| 717 |
+
device_arch,
|
| 718 |
+
)
|
| 719 |
+
if key not in _FP4_COMPILE_CACHE:
|
| 720 |
+
kernel = Fp4IndexerStagedMmaSm100(
|
| 721 |
+
fmt=fmt.name,
|
| 722 |
+
causal=causal,
|
| 723 |
+
preordered_q_scale_tma=preordered_q_scale_tma,
|
| 724 |
+
compact_schedule=compact_schedule,
|
| 725 |
+
use_tmem_load_red=use_tmem_load_red,
|
| 726 |
+
)
|
| 727 |
+
_FP4_COMPILE_CACHE[key] = cute.compile(
|
| 728 |
+
kernel,
|
| 729 |
+
q_ptr,
|
| 730 |
+
k_ptr,
|
| 731 |
+
q_scale_ptr,
|
| 732 |
+
k_scale_ptr,
|
| 733 |
+
scores_ptr,
|
| 734 |
+
kv_indices_ptr,
|
| 735 |
+
cu_seqlens_q_ptr,
|
| 736 |
+
cu_seqlens_k_ptr,
|
| 737 |
+
cu_page_offsets_ptr,
|
| 738 |
+
qo_offset_ptr,
|
| 739 |
+
problem_size,
|
| 740 |
+
stream,
|
| 741 |
+
)
|
| 742 |
+
return _FP4_COMPILE_CACHE[key]
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
def fp4_indexer_block_scores(
|
| 746 |
+
q_fp4: torch.Tensor,
|
| 747 |
+
k_fp4: torch.Tensor,
|
| 748 |
+
q_scale: torch.Tensor,
|
| 749 |
+
k_scale: torch.Tensor,
|
| 750 |
+
cu_seqlens_q: torch.Tensor,
|
| 751 |
+
cu_seqlens_k: torch.Tensor,
|
| 752 |
+
cu_page_offsets: torch.Tensor,
|
| 753 |
+
*,
|
| 754 |
+
max_seqlen_q: int,
|
| 755 |
+
max_seqlen_k: int,
|
| 756 |
+
kv_indices: torch.Tensor,
|
| 757 |
+
fp4_format: str,
|
| 758 |
+
causal: bool = False,
|
| 759 |
+
qo_offset: Optional[torch.Tensor] = None,
|
| 760 |
+
scale_layout: str = _PREORDERED_MMA_SCALE_LAYOUT,
|
| 761 |
+
) -> torch.Tensor:
|
| 762 |
+
"""Return FP4 QK max scores per 128-token KV page.
|
| 763 |
+
|
| 764 |
+
Parameters
|
| 765 |
+
----------
|
| 766 |
+
q_fp4 : torch.Tensor
|
| 767 |
+
Packed FP4 Q tensor with shape ``[total_qo_len, Hq, 64]``. The last
|
| 768 |
+
dimension stores two FP4 values per byte for logical head dimension
|
| 769 |
+
128.
|
| 770 |
+
k_fp4 : torch.Tensor
|
| 771 |
+
Packed paged FP4 K tensor with shape ``[total_pages, Hk, 128, 64]``.
|
| 772 |
+
q_scale : torch.Tensor
|
| 773 |
+
Q scale tensor. With ``scale_layout="public"``, shape is
|
| 774 |
+
``[total_qo_len, Hq, G]``. With ``"preordered_mma"``, use
|
| 775 |
+
``fp4_indexer_reorder_scales_for_mma_cute`` output layout.
|
| 776 |
+
k_scale : torch.Tensor
|
| 777 |
+
K scale tensor. With ``scale_layout="public"``, shape is
|
| 778 |
+
``[total_pages, Hk, 128, G]``. With ``"preordered_mma"``, use the
|
| 779 |
+
preordered MMA scale layout.
|
| 780 |
+
cu_seqlens_q : torch.Tensor
|
| 781 |
+
Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths.
|
| 782 |
+
cu_seqlens_k : torch.Tensor
|
| 783 |
+
Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths.
|
| 784 |
+
cu_page_offsets : torch.Tensor
|
| 785 |
+
Shape ``[batch_size + 1]``, dtype int32. Prefix sums of per-request
|
| 786 |
+
page counts.
|
| 787 |
+
max_seqlen_q : int
|
| 788 |
+
Maximum Q sequence length.
|
| 789 |
+
max_seqlen_k : int
|
| 790 |
+
Maximum KV sequence length.
|
| 791 |
+
kv_indices : torch.Tensor
|
| 792 |
+
Flattened physical page indices with shape ``[sum_pages]`` and dtype
|
| 793 |
+
int32.
|
| 794 |
+
fp4_format : str
|
| 795 |
+
``"mxfp4"`` or ``"nvfp4"``.
|
| 796 |
+
causal : bool, optional
|
| 797 |
+
Whether to apply causal masking.
|
| 798 |
+
qo_offset : torch.Tensor, optional
|
| 799 |
+
Shape ``[batch_size]``, dtype int32. Per-request causal offset. Valid
|
| 800 |
+
only when ``causal=True``.
|
| 801 |
+
scale_layout : str, optional
|
| 802 |
+
``"public"`` accepts logical public scale tensors and launches a scale
|
| 803 |
+
reorder kernel. ``"preordered_mma"`` expects preordered MMA scale
|
| 804 |
+
tensors and skips the reorder.
|
| 805 |
+
|
| 806 |
+
Returns
|
| 807 |
+
-------
|
| 808 |
+
torch.Tensor
|
| 809 |
+
Shape ``[Hq, ceil(max_seqlen_k / 128), total_qo_len]``, dtype float32.
|
| 810 |
+
Entries beyond the valid KV page range are ``-inf``.
|
| 811 |
+
"""
|
| 812 |
+
|
| 813 |
+
spec = normalize_fp4_format(fp4_format)
|
| 814 |
+
causal = bool(causal)
|
| 815 |
+
scale_layout = normalize_scale_layout(scale_layout)
|
| 816 |
+
use_preordered_q_scale_tma = int(max_seqlen_q) >= _PAGE_SIZE
|
| 817 |
+
q_bytes = _as_fp4_thd_bytes(q_fp4, name="q_fp4")
|
| 818 |
+
k_bytes = _as_fp4_paged_hnd_bytes(k_fp4, name="k_fp4")
|
| 819 |
+
total_q, heads_q, _ = (int(v) for v in q_bytes.shape)
|
| 820 |
+
page_count, heads_k, page_size, _ = (int(v) for v in k_bytes.shape)
|
| 821 |
+
if page_size != _PAGE_SIZE:
|
| 822 |
+
raise ValueError(f"k_fp4 page_size must be 128, got {page_size}")
|
| 823 |
+
if heads_q % heads_k != 0:
|
| 824 |
+
raise ValueError("num_qo_heads must be divisible by num_kv_heads")
|
| 825 |
+
_require_cuda_tensor(q_fp4, name="q_fp4")
|
| 826 |
+
_require_cuda_tensor(k_fp4, name="k_fp4")
|
| 827 |
+
device_arch = _device_arch(q_fp4.device)
|
| 828 |
+
use_tmem_load_red = _supports_tmem_load_red(device_arch)
|
| 829 |
+
_require_int32_vector(cu_seqlens_q, name="cu_seqlens_q", device=q_fp4.device)
|
| 830 |
+
_require_int32_vector(cu_seqlens_k, name="cu_seqlens_k", device=q_fp4.device)
|
| 831 |
+
_require_int32_vector(cu_page_offsets, name="cu_page_offsets", device=q_fp4.device)
|
| 832 |
+
if q_scale.device != q_fp4.device or k_scale.device != q_fp4.device:
|
| 833 |
+
raise ValueError("q_scale and k_scale must be on the same CUDA device as q_fp4")
|
| 834 |
+
if scale_layout == _PUBLIC_SCALE_LAYOUT:
|
| 835 |
+
validate_q_scale_thg(q_scale, name="q_scale", fmt=spec, total_q=total_q, heads=heads_q)
|
| 836 |
+
validate_k_scale_phsg(k_scale, name="k_scale", fmt=spec, page_count=page_count, heads=heads_k)
|
| 837 |
+
else:
|
| 838 |
+
validate_mma_scale_storage(q_scale, name="q_scale", fmt=spec, mn=total_q, l=heads_q)
|
| 839 |
+
validate_mma_scale_storage(k_scale, name="k_scale", fmt=spec, mn=_PAGE_SIZE, l=page_count * heads_k)
|
| 840 |
+
batch = int(cu_seqlens_q.shape[0]) - 1
|
| 841 |
+
if batch < 0:
|
| 842 |
+
raise ValueError("cu_seqlens_q must have shape [B + 1]")
|
| 843 |
+
if cu_seqlens_q.shape != cu_seqlens_k.shape or cu_seqlens_q.shape != cu_page_offsets.shape:
|
| 844 |
+
raise ValueError("cu_seqlens_q, cu_seqlens_k, and cu_page_offsets must have shape [B + 1]")
|
| 845 |
+
if q_bytes.data_ptr() % 128 != 0:
|
| 846 |
+
raise ValueError("q_fp4 data pointer must be 128B aligned for TMA")
|
| 847 |
+
if k_bytes.data_ptr() % 128 != 0:
|
| 848 |
+
raise ValueError("k_fp4 data pointer must be 128B aligned for TMA")
|
| 849 |
+
if kv_indices is None:
|
| 850 |
+
raise ValueError("kv_indices is required")
|
| 851 |
+
if kv_indices.device != q_fp4.device or kv_indices.dtype != torch.int32 or kv_indices.ndim != 1:
|
| 852 |
+
raise ValueError("kv_indices must have shape [sum_pages], dtype torch.int32, and match q_fp4.device")
|
| 853 |
+
if not kv_indices.is_contiguous():
|
| 854 |
+
raise ValueError("kv_indices must be contiguous")
|
| 855 |
+
if qo_offset is not None:
|
| 856 |
+
if not causal:
|
| 857 |
+
raise ValueError("qo_offset is only valid when causal=True")
|
| 858 |
+
if qo_offset.device != q_fp4.device or qo_offset.dtype != torch.int32 or qo_offset.shape != (batch,):
|
| 859 |
+
raise ValueError("qo_offset must have shape [B], dtype torch.int32, and match q_fp4.device")
|
| 860 |
+
if not qo_offset.is_contiguous():
|
| 861 |
+
raise ValueError("qo_offset must be contiguous")
|
| 862 |
+
|
| 863 |
+
m_extent = int(max_seqlen_q)
|
| 864 |
+
max_k_tiles = ceil_div(int(max_seqlen_k), _PAGE_SIZE)
|
| 865 |
+
n_aligned = max_k_tiles * _PAGE_SIZE
|
| 866 |
+
if max_k_tiles == 0:
|
| 867 |
+
return torch.full(
|
| 868 |
+
(heads_q, 0, total_q),
|
| 869 |
+
float("-inf"),
|
| 870 |
+
dtype=torch.float32,
|
| 871 |
+
device=q_fp4.device,
|
| 872 |
+
)
|
| 873 |
+
|
| 874 |
+
scores = torch.empty(
|
| 875 |
+
(heads_q, max_k_tiles, total_q),
|
| 876 |
+
dtype=torch.float32,
|
| 877 |
+
device=q_fp4.device,
|
| 878 |
+
)
|
| 879 |
+
if qo_offset is None:
|
| 880 |
+
qo_offset_arg = torch.empty((batch,), dtype=torch.int32, device=q_fp4.device)
|
| 881 |
+
has_qo_offset = 0
|
| 882 |
+
else:
|
| 883 |
+
qo_offset_arg = qo_offset
|
| 884 |
+
has_qo_offset = 1
|
| 885 |
+
if scale_layout == _PUBLIC_SCALE_LAYOUT:
|
| 886 |
+
q_scale_arg, k_scale_arg = fp4_indexer_reorder_scales_for_mma_cute(
|
| 887 |
+
q_scale,
|
| 888 |
+
k_scale,
|
| 889 |
+
fp4_format=spec.name,
|
| 890 |
+
)
|
| 891 |
+
else:
|
| 892 |
+
q_scale_arg = q_scale
|
| 893 |
+
k_scale_arg = k_scale
|
| 894 |
+
scale_assumed_align = 32
|
| 895 |
+
if q_scale_arg.data_ptr() % scale_assumed_align != 0:
|
| 896 |
+
raise ValueError(f"q_scale data pointer must be {scale_assumed_align}B aligned for MMA storage scale")
|
| 897 |
+
if k_scale_arg.data_ptr() % scale_assumed_align != 0:
|
| 898 |
+
raise ValueError(f"k_scale data pointer must be {scale_assumed_align}B aligned for MMA storage scale")
|
| 899 |
+
use_decode_packed_q = int(max_seqlen_q) <= _DECODE_PACK_Q_LEN and heads_q // heads_k == _DECODE_QHEAD_PER_KV
|
| 900 |
+
if use_decode_packed_q:
|
| 901 |
+
q_pack, q_scale_pack = _pack_decode_q_for_mma(
|
| 902 |
+
q_bytes,
|
| 903 |
+
q_scale_arg,
|
| 904 |
+
cu_seqlens_q,
|
| 905 |
+
fmt=spec,
|
| 906 |
+
heads_q=heads_q,
|
| 907 |
+
heads_k=heads_k,
|
| 908 |
+
batch=batch,
|
| 909 |
+
)
|
| 910 |
+
_run_fp4_decode_packed_q_scores(
|
| 911 |
+
q_pack,
|
| 912 |
+
k_bytes,
|
| 913 |
+
q_scale_pack,
|
| 914 |
+
k_scale_arg,
|
| 915 |
+
scores,
|
| 916 |
+
kv_indices,
|
| 917 |
+
cu_seqlens_q,
|
| 918 |
+
cu_seqlens_k,
|
| 919 |
+
cu_page_offsets,
|
| 920 |
+
qo_offset_arg,
|
| 921 |
+
fmt=spec,
|
| 922 |
+
causal=causal,
|
| 923 |
+
has_qo_offset=has_qo_offset,
|
| 924 |
+
heads_q=heads_q,
|
| 925 |
+
heads_k=heads_k,
|
| 926 |
+
batch=batch,
|
| 927 |
+
max_k_tiles=max_k_tiles,
|
| 928 |
+
total_q=total_q,
|
| 929 |
+
device_arch=device_arch,
|
| 930 |
+
use_tmem_load_red=use_tmem_load_red,
|
| 931 |
+
)
|
| 932 |
+
return scores
|
| 933 |
+
prefill_compact_task_count = 0
|
| 934 |
+
prefill_compact_schedule = False
|
| 935 |
+
if causal and has_qo_offset == 0:
|
| 936 |
+
k_tiles_per_cta = k_tiles_per_cta_for(causal)
|
| 937 |
+
q_tile_count = ceil_div(m_extent, _MMA_TILER_MN[0])
|
| 938 |
+
k_group_count = ceil_div(max_k_tiles, k_tiles_per_cta)
|
| 939 |
+
rectangular_task_count = q_tile_count * k_group_count
|
| 940 |
+
prefill_compact_task_count = min(
|
| 941 |
+
rectangular_task_count,
|
| 942 |
+
_causal_compact_task_bound(m_extent, int(max_seqlen_k), k_tiles_per_cta),
|
| 943 |
+
)
|
| 944 |
+
prefill_compact_schedule = prefill_compact_task_count * 20 <= rectangular_task_count * 19
|
| 945 |
+
if prefill_compact_schedule:
|
| 946 |
+
scores.fill_(float("-inf"))
|
| 947 |
+
q_ptr = make_ptr(
|
| 948 |
+
cutlass.Uint8,
|
| 949 |
+
q_bytes.data_ptr(),
|
| 950 |
+
cute.AddressSpace.gmem,
|
| 951 |
+
assumed_align=128,
|
| 952 |
+
)
|
| 953 |
+
k_ptr = make_ptr(
|
| 954 |
+
cutlass.Uint8,
|
| 955 |
+
k_bytes.data_ptr(),
|
| 956 |
+
cute.AddressSpace.gmem,
|
| 957 |
+
assumed_align=128,
|
| 958 |
+
)
|
| 959 |
+
q_scale_ptr = make_ptr(
|
| 960 |
+
spec.cutlass_scale_dtype,
|
| 961 |
+
q_scale_arg.data_ptr(),
|
| 962 |
+
cute.AddressSpace.gmem,
|
| 963 |
+
assumed_align=scale_assumed_align,
|
| 964 |
+
)
|
| 965 |
+
k_scale_ptr = make_ptr(
|
| 966 |
+
spec.cutlass_scale_dtype,
|
| 967 |
+
k_scale_arg.data_ptr(),
|
| 968 |
+
cute.AddressSpace.gmem,
|
| 969 |
+
assumed_align=scale_assumed_align,
|
| 970 |
+
)
|
| 971 |
+
scores_ptr = make_ptr(
|
| 972 |
+
cutlass.Float32,
|
| 973 |
+
scores.data_ptr(),
|
| 974 |
+
cute.AddressSpace.gmem,
|
| 975 |
+
assumed_align=16,
|
| 976 |
+
)
|
| 977 |
+
kv_indices_ptr = make_ptr(
|
| 978 |
+
cutlass.Int32,
|
| 979 |
+
kv_indices.data_ptr(),
|
| 980 |
+
cute.AddressSpace.gmem,
|
| 981 |
+
assumed_align=4,
|
| 982 |
+
)
|
| 983 |
+
cu_seqlens_q_ptr = make_ptr(
|
| 984 |
+
cutlass.Int32,
|
| 985 |
+
cu_seqlens_q.data_ptr(),
|
| 986 |
+
cute.AddressSpace.gmem,
|
| 987 |
+
assumed_align=4,
|
| 988 |
+
)
|
| 989 |
+
cu_seqlens_k_ptr = make_ptr(
|
| 990 |
+
cutlass.Int32,
|
| 991 |
+
cu_seqlens_k.data_ptr(),
|
| 992 |
+
cute.AddressSpace.gmem,
|
| 993 |
+
assumed_align=4,
|
| 994 |
+
)
|
| 995 |
+
cu_page_offsets_ptr = make_ptr(
|
| 996 |
+
cutlass.Int32,
|
| 997 |
+
cu_page_offsets.data_ptr(),
|
| 998 |
+
cute.AddressSpace.gmem,
|
| 999 |
+
assumed_align=4,
|
| 1000 |
+
)
|
| 1001 |
+
qo_offset_ptr = make_ptr(
|
| 1002 |
+
cutlass.Int32,
|
| 1003 |
+
qo_offset_arg.data_ptr(),
|
| 1004 |
+
cute.AddressSpace.gmem,
|
| 1005 |
+
assumed_align=4,
|
| 1006 |
+
)
|
| 1007 |
+
problem_size = (
|
| 1008 |
+
Int32(m_extent),
|
| 1009 |
+
Int32(n_aligned),
|
| 1010 |
+
Int32(_HEAD_DIM),
|
| 1011 |
+
Int32(batch * heads_q),
|
| 1012 |
+
Int32(page_count * heads_k),
|
| 1013 |
+
Int32(heads_q),
|
| 1014 |
+
Int32(heads_k),
|
| 1015 |
+
Int32(batch),
|
| 1016 |
+
Int32(max_k_tiles),
|
| 1017 |
+
Int32(total_q),
|
| 1018 |
+
Int32(has_qo_offset),
|
| 1019 |
+
Int32(prefill_compact_task_count),
|
| 1020 |
+
)
|
| 1021 |
+
stream = cuda.CUstream(torch.cuda.current_stream(q_fp4.device).cuda_stream)
|
| 1022 |
+
compiled = _compile_fp4_qk_kernel(
|
| 1023 |
+
fmt=spec,
|
| 1024 |
+
causal=causal,
|
| 1025 |
+
preordered_q_scale_tma=use_preordered_q_scale_tma,
|
| 1026 |
+
compact_schedule=prefill_compact_schedule,
|
| 1027 |
+
device_arch=device_arch,
|
| 1028 |
+
use_tmem_load_red=use_tmem_load_red,
|
| 1029 |
+
q_ptr=q_ptr,
|
| 1030 |
+
k_ptr=k_ptr,
|
| 1031 |
+
q_scale_ptr=q_scale_ptr,
|
| 1032 |
+
k_scale_ptr=k_scale_ptr,
|
| 1033 |
+
scores_ptr=scores_ptr,
|
| 1034 |
+
kv_indices_ptr=kv_indices_ptr,
|
| 1035 |
+
cu_seqlens_q_ptr=cu_seqlens_q_ptr,
|
| 1036 |
+
cu_seqlens_k_ptr=cu_seqlens_k_ptr,
|
| 1037 |
+
cu_page_offsets_ptr=cu_page_offsets_ptr,
|
| 1038 |
+
qo_offset_ptr=qo_offset_ptr,
|
| 1039 |
+
problem_size=problem_size,
|
| 1040 |
+
stream=stream,
|
| 1041 |
+
)
|
| 1042 |
+
compiled(
|
| 1043 |
+
q_ptr,
|
| 1044 |
+
k_ptr,
|
| 1045 |
+
q_scale_ptr,
|
| 1046 |
+
k_scale_ptr,
|
| 1047 |
+
scores_ptr,
|
| 1048 |
+
kv_indices_ptr,
|
| 1049 |
+
cu_seqlens_q_ptr,
|
| 1050 |
+
cu_seqlens_k_ptr,
|
| 1051 |
+
cu_page_offsets_ptr,
|
| 1052 |
+
qo_offset_ptr,
|
| 1053 |
+
problem_size,
|
| 1054 |
+
stream,
|
| 1055 |
+
)
|
| 1056 |
+
return scores
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
__all__ = [
|
| 1060 |
+
"fp4_indexer_block_scores",
|
| 1061 |
+
]
|
build/torch211-cxx11-cu128-x86_64-linux/interface.py
ADDED
|
@@ -0,0 +1,2011 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""Sparse attention interface.
|
| 5 |
+
|
| 6 |
+
Current delivery scope:
|
| 7 |
+
- head dimension is supported only for D=128
|
| 8 |
+
|
| 9 |
+
Public API:
|
| 10 |
+
sparse_atten_func(...)
|
| 11 |
+
sparse_decode_atten_func(...)
|
| 12 |
+
SparseDecodePagedAttentionWrapper
|
| 13 |
+
|
| 14 |
+
Internal forward core:
|
| 15 |
+
_sparse_atten_csr_varlen_forward(...)
|
| 16 |
+
|
| 17 |
+
Preprocessing (external, done once):
|
| 18 |
+
q2k_indices [head_kv, total_q, topK] -> sparse_index_utils.build_k2q_csr()
|
| 19 |
+
-> k2q_row_ptr [head_kv, total_rows + 1] int32
|
| 20 |
+
-> k2q_q_indices [head_kv, total_q * topK] int32
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import math
|
| 24 |
+
import os
|
| 25 |
+
from typing import Optional
|
| 26 |
+
|
| 27 |
+
import cutlass
|
| 28 |
+
import cutlass.cute as cute
|
| 29 |
+
import torch
|
| 30 |
+
from cutlass import Float32, Int32
|
| 31 |
+
from cutlass.cute.runtime import from_dlpack
|
| 32 |
+
|
| 33 |
+
from .src.sm100.fwd.combine import combine
|
| 34 |
+
from .src.sm100.fwd.atten_fwd import SparseAttentionForwardSm100
|
| 35 |
+
from .src.sm100.fwd.atten_fwd_nvfp4_kv import SparseAttentionForwardNvfp4KvSm100
|
| 36 |
+
from .src.sm100.prepare_scheduler import (
|
| 37 |
+
SparseAttentionSchedule,
|
| 38 |
+
prepare_sparse_fwd_schedule_and_split,
|
| 39 |
+
)
|
| 40 |
+
from .src.sm100.decode_schedule import (
|
| 41 |
+
DecodeAttentionSchedule,
|
| 42 |
+
prepare_decode_schedule,
|
| 43 |
+
)
|
| 44 |
+
from .src.common.cute_dsl_utils import to_cute_tensor as to_cute_tensor_kvouter
|
| 45 |
+
from .src.common.tma_utils import (
|
| 46 |
+
create_q_gather4_tma_desc,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
_compile_cache: dict = {}
|
| 50 |
+
_TEMPERATURE_LSE_FAST_PATH_ABS_TOL = 1e-12
|
| 51 |
+
_SUPPORTED_SPARSE_TOPK = (4, 8, 16, 32)
|
| 52 |
+
_SUPPORTED_FWD_DTYPES = (torch.bfloat16, torch.float8_e4m3fn)
|
| 53 |
+
_SUPPORTED_FWD_MMA_DTYPES = (torch.bfloat16, torch.float8_e4m3fn)
|
| 54 |
+
_SUPPORTED_DECODE_QHEAD_PER_KV = 16
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _normalize_partial_dtype(partial_dtype: torch.dtype) -> torch.dtype:
|
| 58 |
+
supported = {torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn}
|
| 59 |
+
if partial_dtype not in supported:
|
| 60 |
+
raise TypeError(
|
| 61 |
+
"partial_dtype must be one of torch.float32 / torch.bfloat16 / "
|
| 62 |
+
"torch.float16 / torch.float8_e4m3fn, "
|
| 63 |
+
f"got {partial_dtype}"
|
| 64 |
+
)
|
| 65 |
+
return partial_dtype
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _normalize_forward_mma_dtype(dtype: Optional[torch.dtype], fallback: torch.dtype, name: str) -> torch.dtype:
|
| 69 |
+
dtype = fallback if dtype is None else dtype
|
| 70 |
+
if dtype not in _SUPPORTED_FWD_MMA_DTYPES:
|
| 71 |
+
raise TypeError(
|
| 72 |
+
f"{name} must be one of torch.bfloat16 / torch.float8_e4m3fn, got {dtype}"
|
| 73 |
+
)
|
| 74 |
+
return dtype
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _resolve_forward_mma_dtypes(
|
| 78 |
+
q: torch.Tensor,
|
| 79 |
+
k: torch.Tensor,
|
| 80 |
+
v: torch.Tensor,
|
| 81 |
+
qk_dtype: Optional[torch.dtype],
|
| 82 |
+
pv_dtype: Optional[torch.dtype],
|
| 83 |
+
) -> tuple[torch.dtype, torch.dtype]:
|
| 84 |
+
qk_dtype = _normalize_forward_mma_dtype(qk_dtype, q.dtype, "qk_dtype")
|
| 85 |
+
if pv_dtype is None:
|
| 86 |
+
# Preserve the historical fp8 KV-cache path: BF16 Q with FP8 K/V
|
| 87 |
+
# stages both K and V as BF16 compute operands.
|
| 88 |
+
if (
|
| 89 |
+
q.dtype == torch.bfloat16
|
| 90 |
+
and k.dtype == torch.float8_e4m3fn
|
| 91 |
+
and v.dtype == torch.float8_e4m3fn
|
| 92 |
+
):
|
| 93 |
+
pv_dtype = torch.bfloat16
|
| 94 |
+
else:
|
| 95 |
+
pv_dtype = v.dtype
|
| 96 |
+
pv_dtype = _normalize_forward_mma_dtype(pv_dtype, pv_dtype, "pv_dtype")
|
| 97 |
+
|
| 98 |
+
if q.dtype != qk_dtype:
|
| 99 |
+
raise ValueError(
|
| 100 |
+
"qk_dtype must match q storage dtype; Q fp8->bf16 staging is not supported"
|
| 101 |
+
)
|
| 102 |
+
if k.dtype != qk_dtype:
|
| 103 |
+
if not (k.dtype == torch.float8_e4m3fn and qk_dtype == torch.bfloat16):
|
| 104 |
+
raise ValueError(
|
| 105 |
+
"unsupported K storage/qk_dtype combination; only fp8 K -> bf16 QK staging is supported"
|
| 106 |
+
)
|
| 107 |
+
if v.dtype != pv_dtype:
|
| 108 |
+
if not (v.dtype == torch.float8_e4m3fn and pv_dtype == torch.bfloat16):
|
| 109 |
+
raise ValueError(
|
| 110 |
+
"unsupported V storage/pv_dtype combination; only fp8 V -> bf16 PV staging is supported"
|
| 111 |
+
)
|
| 112 |
+
return qk_dtype, pv_dtype
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _to_cute_tensor_meta(t: torch.Tensor, assumed_align: int = 4):
|
| 116 |
+
tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=True)
|
| 117 |
+
return tensor.mark_layout_dynamic(leading_dim=0)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _torch_dtype_to_cutlass_dtype(dtype: torch.dtype):
|
| 121 |
+
if dtype == torch.bfloat16:
|
| 122 |
+
return cutlass.BFloat16
|
| 123 |
+
if dtype == torch.float16:
|
| 124 |
+
return cutlass.Float16
|
| 125 |
+
if dtype == torch.float8_e4m3fn:
|
| 126 |
+
return cutlass.Float8E4M3FN
|
| 127 |
+
raise TypeError(
|
| 128 |
+
f"Only torch.bfloat16, torch.float16, torch.float8_e4m3fn supported, got {dtype}"
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _prepare_paged_kv_for_tma(k, v, blk_kv: int):
|
| 133 |
+
page_size = int(k.shape[2])
|
| 134 |
+
if page_size != blk_kv:
|
| 135 |
+
raise ValueError(f"Sparse Page Attention requires page_size == blk_kv, got {page_size} vs {blk_kv}")
|
| 136 |
+
return k, v
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _validate_cu_seqlens(
|
| 140 |
+
cu_seqlens: torch.Tensor,
|
| 141 |
+
*,
|
| 142 |
+
name: str,
|
| 143 |
+
device: torch.device,
|
| 144 |
+
) -> None:
|
| 145 |
+
if cu_seqlens.device != device:
|
| 146 |
+
raise ValueError(f"{name} must be on the same device as q")
|
| 147 |
+
if cu_seqlens.dtype != torch.int32:
|
| 148 |
+
raise TypeError(f"{name} must be torch.int32")
|
| 149 |
+
if cu_seqlens.ndim != 1:
|
| 150 |
+
raise ValueError(f"{name} must have shape [B + 1]")
|
| 151 |
+
if cu_seqlens.shape[0] < 1:
|
| 152 |
+
raise ValueError(f"{name} must have at least one element")
|
| 153 |
+
if not cu_seqlens.is_contiguous():
|
| 154 |
+
raise ValueError(f"{name} must be contiguous")
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _csr_row_capacity(k2q_row_ptr: torch.Tensor) -> int:
|
| 158 |
+
return int(k2q_row_ptr.shape[1] - 1)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _validate_csr_varlen_inputs(
|
| 162 |
+
q: torch.Tensor,
|
| 163 |
+
k: torch.Tensor,
|
| 164 |
+
v: torch.Tensor,
|
| 165 |
+
k2q_row_ptr: torch.Tensor,
|
| 166 |
+
k2q_q_indices: torch.Tensor,
|
| 167 |
+
topK: int,
|
| 168 |
+
blk_kv: int,
|
| 169 |
+
page_table: Optional[torch.Tensor],
|
| 170 |
+
cu_seqlens_q: torch.Tensor,
|
| 171 |
+
cu_seqlens_k: torch.Tensor,
|
| 172 |
+
seqused_k: Optional[torch.Tensor],
|
| 173 |
+
) -> tuple[int, int]:
|
| 174 |
+
if q.ndim != 3:
|
| 175 |
+
raise ValueError("CSR sparse forward requires q to have shape [total_q, Hq, D]")
|
| 176 |
+
if q.dtype not in _SUPPORTED_FWD_DTYPES:
|
| 177 |
+
raise TypeError(
|
| 178 |
+
"CSR sparse forward supports only torch.bfloat16 and "
|
| 179 |
+
f"torch.float8_e4m3fn Q/K/V, got {q.dtype}"
|
| 180 |
+
)
|
| 181 |
+
if q.device != k.device or q.device != v.device:
|
| 182 |
+
raise ValueError("q, k, v must be on the same device")
|
| 183 |
+
mixed_fp8_kv_bf16_q = (
|
| 184 |
+
q.dtype == torch.bfloat16
|
| 185 |
+
and k.dtype == torch.float8_e4m3fn
|
| 186 |
+
and v.dtype == torch.float8_e4m3fn
|
| 187 |
+
)
|
| 188 |
+
if not mixed_fp8_kv_bf16_q and (q.dtype != k.dtype or q.dtype != v.dtype):
|
| 189 |
+
raise ValueError(
|
| 190 |
+
"q, k, v must have the same dtype, except q=bf16 with fp8_e4m3 K/V cache"
|
| 191 |
+
)
|
| 192 |
+
if q.shape[-1] != k.shape[-1] or q.shape[-1] != v.shape[-1]:
|
| 193 |
+
raise ValueError("q, k, v must have the same head dimension")
|
| 194 |
+
dim = q.shape[-1]
|
| 195 |
+
if dim != 128:
|
| 196 |
+
raise NotImplementedError(
|
| 197 |
+
f"CSR sparse forward currently supports only D=128, got D={dim}"
|
| 198 |
+
)
|
| 199 |
+
if page_table is None:
|
| 200 |
+
if k.shape[-2] != v.shape[-2] or k.shape[-1] != v.shape[-1]:
|
| 201 |
+
raise ValueError("k and v must have the same [Hkv, D] tail dimensions")
|
| 202 |
+
head_kv = k.shape[-2]
|
| 203 |
+
else:
|
| 204 |
+
if k.ndim != 4 or v.ndim != 4:
|
| 205 |
+
raise ValueError(
|
| 206 |
+
"Sparse Page Attention requires k and v to have shape "
|
| 207 |
+
"[num_pages, Hkv, page_size, D]"
|
| 208 |
+
)
|
| 209 |
+
if k.shape[1] != v.shape[1] or k.shape[-1] != v.shape[-1]:
|
| 210 |
+
raise ValueError(
|
| 211 |
+
"Sparse Page Attention k and v must have the same Hkv and D"
|
| 212 |
+
)
|
| 213 |
+
head_kv = k.shape[1]
|
| 214 |
+
if (
|
| 215 |
+
q.device != k2q_row_ptr.device
|
| 216 |
+
or q.device != k2q_q_indices.device
|
| 217 |
+
):
|
| 218 |
+
raise ValueError("CSR metadata must be on the same device as q")
|
| 219 |
+
if (
|
| 220 |
+
k2q_row_ptr.dtype != torch.int32
|
| 221 |
+
or k2q_q_indices.dtype != torch.int32
|
| 222 |
+
):
|
| 223 |
+
raise TypeError("CSR metadata tensors must be torch.int32")
|
| 224 |
+
if k2q_row_ptr.ndim != 2 or k2q_q_indices.ndim != 2:
|
| 225 |
+
raise ValueError("k2q_row_ptr and k2q_q_indices must be rank-2")
|
| 226 |
+
|
| 227 |
+
_validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q", device=q.device)
|
| 228 |
+
_validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k", device=q.device)
|
| 229 |
+
if cu_seqlens_k.shape != cu_seqlens_q.shape:
|
| 230 |
+
raise ValueError("cu_seqlens_k must have shape [B + 1] matching cu_seqlens_q")
|
| 231 |
+
batch = int(cu_seqlens_q.shape[0] - 1)
|
| 232 |
+
total_q = q.shape[0]
|
| 233 |
+
|
| 234 |
+
head_q = q.shape[1]
|
| 235 |
+
if head_q % head_kv != 0:
|
| 236 |
+
raise ValueError("q.shape[1] must be divisible by Hkv")
|
| 237 |
+
qhead_per_kv = head_q // head_kv
|
| 238 |
+
if qhead_per_kv not in (1, 2, 4, 8, 16):
|
| 239 |
+
raise NotImplementedError(
|
| 240 |
+
"CSR forward is currently supported only for qhead_per_kv in {1, 2, 4, 8, 16}"
|
| 241 |
+
)
|
| 242 |
+
if k2q_row_ptr.shape[0] != head_kv or k2q_q_indices.shape[0] != head_kv:
|
| 243 |
+
raise ValueError("CSR metadata head dimension must match KV head count")
|
| 244 |
+
if k2q_q_indices.shape[1] < total_q * topK:
|
| 245 |
+
raise ValueError(
|
| 246 |
+
f"k2q_q_indices.shape[1] ({k2q_q_indices.shape[1]}) must be >= total_q * topK ({total_q * topK})"
|
| 247 |
+
)
|
| 248 |
+
if k2q_row_ptr.shape[1] < 1:
|
| 249 |
+
raise ValueError("k2q_row_ptr must contain at least one row pointer column")
|
| 250 |
+
|
| 251 |
+
if page_table is None:
|
| 252 |
+
if seqused_k is not None:
|
| 253 |
+
raise ValueError("seqused_k is only supported together with page_table")
|
| 254 |
+
total_k = k.shape[0]
|
| 255 |
+
if k.ndim != 3 or v.ndim != 3:
|
| 256 |
+
raise ValueError("Sparse Attention requires k and v to have shape [total_k, Hkv, D]")
|
| 257 |
+
if k.shape != (total_k, head_kv, q.shape[-1]) or v.shape != (total_k, head_kv, q.shape[-1]):
|
| 258 |
+
raise ValueError("Sparse Attention k and v must match [total_k, Hkv, D]")
|
| 259 |
+
else:
|
| 260 |
+
if page_table.device != q.device:
|
| 261 |
+
raise ValueError("page_table must be on the same device as q")
|
| 262 |
+
if page_table.dtype != torch.int32:
|
| 263 |
+
raise TypeError("page_table must be torch.int32")
|
| 264 |
+
if page_table.ndim != 2 or page_table.shape[0] != batch:
|
| 265 |
+
raise ValueError("page_table must have shape [B, max_num_pages_per_seq]")
|
| 266 |
+
if page_table.stride(-1) != 1:
|
| 267 |
+
raise ValueError("page_table must be contiguous in the last dimension")
|
| 268 |
+
if k.ndim != 4 or v.ndim != 4:
|
| 269 |
+
raise ValueError(
|
| 270 |
+
"Sparse Page Attention requires k and v to have shape "
|
| 271 |
+
"[num_pages, Hkv, page_size, D]"
|
| 272 |
+
)
|
| 273 |
+
if k.shape != v.shape:
|
| 274 |
+
raise ValueError(f"k and v must have the same shape, got {k.shape} and {v.shape}")
|
| 275 |
+
if k.shape[1] != head_kv or k.shape[3] != q.shape[-1]:
|
| 276 |
+
raise ValueError(
|
| 277 |
+
"Sparse Page Attention k and v must match "
|
| 278 |
+
"[num_pages, Hkv, page_size, D]"
|
| 279 |
+
)
|
| 280 |
+
page_size = int(k.shape[2])
|
| 281 |
+
if page_size != blk_kv:
|
| 282 |
+
raise ValueError(
|
| 283 |
+
f"Unsupported Sparse Page Attention page_size={page_size} for blk_kv={blk_kv}; "
|
| 284 |
+
"require page_size == blk_kv"
|
| 285 |
+
)
|
| 286 |
+
if seqused_k is not None:
|
| 287 |
+
if seqused_k.device != q.device:
|
| 288 |
+
raise ValueError("seqused_k must be on the same device as q")
|
| 289 |
+
if seqused_k.dtype != torch.int32:
|
| 290 |
+
raise TypeError("seqused_k must be torch.int32")
|
| 291 |
+
if seqused_k.shape != (batch,):
|
| 292 |
+
raise ValueError("seqused_k must have shape [B]")
|
| 293 |
+
if not seqused_k.is_contiguous():
|
| 294 |
+
raise ValueError("seqused_k must be contiguous")
|
| 295 |
+
if topK not in _SUPPORTED_SPARSE_TOPK:
|
| 296 |
+
raise ValueError(
|
| 297 |
+
f"CSR sparse forward supports topK in {_SUPPORTED_SPARSE_TOPK}, got {topK}"
|
| 298 |
+
)
|
| 299 |
+
return batch, head_kv
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def _validate_csr_varlen_nvfp4_kv_inputs(
|
| 303 |
+
q: torch.Tensor,
|
| 304 |
+
k: torch.Tensor,
|
| 305 |
+
v: torch.Tensor,
|
| 306 |
+
k_scale_128x4: torch.Tensor,
|
| 307 |
+
v_scale_128x4: torch.Tensor,
|
| 308 |
+
k_global_scale: Optional[torch.Tensor],
|
| 309 |
+
v_global_scale: Optional[torch.Tensor],
|
| 310 |
+
k2q_row_ptr: torch.Tensor,
|
| 311 |
+
k2q_q_indices: torch.Tensor,
|
| 312 |
+
topK: int,
|
| 313 |
+
blk_kv: int,
|
| 314 |
+
page_table: Optional[torch.Tensor],
|
| 315 |
+
cu_seqlens_q: torch.Tensor,
|
| 316 |
+
cu_seqlens_k: torch.Tensor,
|
| 317 |
+
seqused_k: Optional[torch.Tensor],
|
| 318 |
+
) -> tuple[int, int]:
|
| 319 |
+
if q.ndim != 3:
|
| 320 |
+
raise ValueError("KVFP4 CSR sparse forward requires q to have shape [total_q, Hq, D]")
|
| 321 |
+
if q.dtype not in (torch.bfloat16, torch.float8_e4m3fn):
|
| 322 |
+
raise TypeError(f"KVFP4 CSR sparse forward requires BF16 or FP8 E4M3 q, got {q.dtype}")
|
| 323 |
+
if q.shape[-1] != 128:
|
| 324 |
+
raise NotImplementedError(
|
| 325 |
+
f"KVFP4 CSR sparse forward currently supports only D=128, got {q.shape[-1]}"
|
| 326 |
+
)
|
| 327 |
+
if k.dtype != torch.uint8 or v.dtype != torch.uint8:
|
| 328 |
+
raise TypeError(f"KVFP4 k/v must be torch.uint8, got {k.dtype} and {v.dtype}")
|
| 329 |
+
if k_scale_128x4.dtype != torch.uint8 or v_scale_128x4.dtype != torch.uint8:
|
| 330 |
+
raise TypeError(
|
| 331 |
+
"KVFP4 block scales must be torch.uint8 E4M3 tensors, got "
|
| 332 |
+
f"{k_scale_128x4.dtype} and {v_scale_128x4.dtype}"
|
| 333 |
+
)
|
| 334 |
+
if k_global_scale is not None and k_global_scale.dtype != torch.float32:
|
| 335 |
+
raise TypeError("KVFP4 K global scale must be a torch.float32 tensor or None")
|
| 336 |
+
if v_global_scale is not None and v_global_scale.dtype != torch.float32:
|
| 337 |
+
raise TypeError("KVFP4 V global scale must be a torch.float32 tensor or None")
|
| 338 |
+
tensors = (
|
| 339 |
+
k,
|
| 340 |
+
v,
|
| 341 |
+
k_scale_128x4,
|
| 342 |
+
v_scale_128x4,
|
| 343 |
+
k2q_row_ptr,
|
| 344 |
+
k2q_q_indices,
|
| 345 |
+
cu_seqlens_q,
|
| 346 |
+
cu_seqlens_k,
|
| 347 |
+
)
|
| 348 |
+
optional_tensors = tuple(t for t in (k_global_scale, v_global_scale) if t is not None)
|
| 349 |
+
if any(t.device != q.device for t in tensors + optional_tensors):
|
| 350 |
+
raise ValueError("KVFP4 inputs and metadata must be on the same device as q")
|
| 351 |
+
if k.shape != v.shape:
|
| 352 |
+
raise ValueError(f"KVFP4 k and v must have the same shape, got {k.shape} and {v.shape}")
|
| 353 |
+
packed_dim = q.shape[-1] // 2
|
| 354 |
+
scale_cols = q.shape[-1] // 16
|
| 355 |
+
if k_scale_128x4.ndim != 2 or v_scale_128x4.ndim != 2:
|
| 356 |
+
raise ValueError("KVFP4 block scales must be rank-2 128x4 tiled tensors")
|
| 357 |
+
if k_scale_128x4.shape[1] < scale_cols or v_scale_128x4.shape[1] < scale_cols:
|
| 358 |
+
raise ValueError(
|
| 359 |
+
"KVFP4 block scales must have at least D/16 columns; "
|
| 360 |
+
f"need {scale_cols}, got {k_scale_128x4.shape[1]} and {v_scale_128x4.shape[1]}"
|
| 361 |
+
)
|
| 362 |
+
if k_global_scale is not None and k_global_scale.numel() < 1:
|
| 363 |
+
raise ValueError("KVFP4 K global scale must contain at least one element")
|
| 364 |
+
if v_global_scale is not None and v_global_scale.numel() < 1:
|
| 365 |
+
raise ValueError("KVFP4 V global scale must contain at least one element")
|
| 366 |
+
|
| 367 |
+
if page_table is None:
|
| 368 |
+
if seqused_k is not None:
|
| 369 |
+
raise ValueError("seqused_k is only supported together with page_table")
|
| 370 |
+
if k.ndim != 3:
|
| 371 |
+
raise ValueError("KVFP4 Sparse Attention requires k/v shape [total_k, Hkv, D/2]")
|
| 372 |
+
if k.shape[-1] != packed_dim:
|
| 373 |
+
raise ValueError(f"KVFP4 packed K/V last dimension must be D/2={packed_dim}")
|
| 374 |
+
total_k = int(k.shape[0])
|
| 375 |
+
head_kv = int(k.shape[1])
|
| 376 |
+
required_scale_rows = total_k * head_kv
|
| 377 |
+
else:
|
| 378 |
+
if k.ndim != 4:
|
| 379 |
+
raise ValueError(
|
| 380 |
+
"KVFP4 Sparse Page Attention requires k/v shape "
|
| 381 |
+
"[num_pages, Hkv, page_size, D/2]"
|
| 382 |
+
)
|
| 383 |
+
if k.shape[-1] != packed_dim:
|
| 384 |
+
raise ValueError(f"KVFP4 packed K/V last dimension must be D/2={packed_dim}")
|
| 385 |
+
page_size = int(k.shape[2])
|
| 386 |
+
if page_size != int(blk_kv):
|
| 387 |
+
raise ValueError(
|
| 388 |
+
f"KVFP4 Sparse Page Attention requires page_size == blk_kv, got {page_size} vs {blk_kv}"
|
| 389 |
+
)
|
| 390 |
+
head_kv = int(k.shape[1])
|
| 391 |
+
required_scale_rows = int(k.shape[0]) * head_kv * page_size
|
| 392 |
+
if page_table.device != q.device:
|
| 393 |
+
raise ValueError("page_table must be on the same device as q")
|
| 394 |
+
if page_table.dtype != torch.int32:
|
| 395 |
+
raise TypeError("page_table must be torch.int32")
|
| 396 |
+
if page_table.ndim != 2:
|
| 397 |
+
raise ValueError("page_table must have shape [B, max_num_pages_per_seq]")
|
| 398 |
+
if page_table.stride(-1) != 1:
|
| 399 |
+
raise ValueError("page_table must be contiguous in the last dimension")
|
| 400 |
+
if seqused_k is not None:
|
| 401 |
+
if seqused_k.device != q.device:
|
| 402 |
+
raise ValueError("seqused_k must be on the same device as q")
|
| 403 |
+
if seqused_k.dtype != torch.int32:
|
| 404 |
+
raise TypeError("seqused_k must be torch.int32")
|
| 405 |
+
if not seqused_k.is_contiguous():
|
| 406 |
+
raise ValueError("seqused_k must be contiguous")
|
| 407 |
+
|
| 408 |
+
padded_scale_rows = ((required_scale_rows + 127) // 128) * 128
|
| 409 |
+
padded_scale_cols = ((scale_cols + 3) // 4) * 4
|
| 410 |
+
for name, scale in (("k_scale_128x4", k_scale_128x4), ("v_scale_128x4", v_scale_128x4)):
|
| 411 |
+
if scale.shape[0] < padded_scale_rows or scale.shape[1] < padded_scale_cols:
|
| 412 |
+
raise ValueError(
|
| 413 |
+
f"{name} is too small for 128x4 layout: got {tuple(scale.shape)}, "
|
| 414 |
+
f"need at least {(padded_scale_rows, padded_scale_cols)}"
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
if k2q_row_ptr.device != q.device or k2q_q_indices.device != q.device:
|
| 418 |
+
raise ValueError("CSR metadata must be on the same device as q")
|
| 419 |
+
if k2q_row_ptr.dtype != torch.int32 or k2q_q_indices.dtype != torch.int32:
|
| 420 |
+
raise TypeError("CSR metadata tensors must be torch.int32")
|
| 421 |
+
if k2q_row_ptr.ndim != 2 or k2q_q_indices.ndim != 2:
|
| 422 |
+
raise ValueError("k2q_row_ptr and k2q_q_indices must be rank-2")
|
| 423 |
+
_validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q", device=q.device)
|
| 424 |
+
_validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k", device=q.device)
|
| 425 |
+
if cu_seqlens_k.shape != cu_seqlens_q.shape:
|
| 426 |
+
raise ValueError("cu_seqlens_k must have shape [B + 1] matching cu_seqlens_q")
|
| 427 |
+
batch = int(cu_seqlens_q.shape[0] - 1)
|
| 428 |
+
if page_table is not None and page_table.shape[0] != batch:
|
| 429 |
+
raise ValueError("page_table must have shape [B, max_num_pages_per_seq]")
|
| 430 |
+
if seqused_k is not None and seqused_k.shape != (batch,):
|
| 431 |
+
raise ValueError("seqused_k must have shape [B]")
|
| 432 |
+
head_q = int(q.shape[1])
|
| 433 |
+
if head_q % head_kv != 0:
|
| 434 |
+
raise ValueError("q.shape[1] must be divisible by Hkv")
|
| 435 |
+
qhead_per_kv = head_q // head_kv
|
| 436 |
+
if qhead_per_kv not in (1, 2, 4, 8, 16):
|
| 437 |
+
raise NotImplementedError(
|
| 438 |
+
"KVFP4 CSR forward is currently supported only for qhead_per_kv in {1, 2, 4, 8, 16}"
|
| 439 |
+
)
|
| 440 |
+
if k2q_row_ptr.shape[0] != head_kv or k2q_q_indices.shape[0] != head_kv:
|
| 441 |
+
raise ValueError("CSR metadata head dimension must match KV head count")
|
| 442 |
+
if k2q_q_indices.shape[1] < q.shape[0] * topK:
|
| 443 |
+
raise ValueError(
|
| 444 |
+
f"k2q_q_indices.shape[1] ({k2q_q_indices.shape[1]}) must be >= total_q * topK ({q.shape[0] * topK})"
|
| 445 |
+
)
|
| 446 |
+
if k2q_row_ptr.shape[1] < 1:
|
| 447 |
+
raise ValueError("k2q_row_ptr must contain at least one row pointer column")
|
| 448 |
+
if topK not in _SUPPORTED_SPARSE_TOPK:
|
| 449 |
+
raise ValueError(
|
| 450 |
+
f"KVFP4 CSR sparse forward supports topK in {_SUPPORTED_SPARSE_TOPK}, got {topK}"
|
| 451 |
+
)
|
| 452 |
+
return batch, head_kv
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def _validate_sparse_decode_inputs(
|
| 456 |
+
q: torch.Tensor,
|
| 457 |
+
k: torch.Tensor,
|
| 458 |
+
v: torch.Tensor,
|
| 459 |
+
q2k_indices: Optional[torch.Tensor],
|
| 460 |
+
*,
|
| 461 |
+
page_table: torch.Tensor,
|
| 462 |
+
seqused_k: torch.Tensor,
|
| 463 |
+
seqlen_q: int,
|
| 464 |
+
max_seqlen_k: int,
|
| 465 |
+
blk_kv: int,
|
| 466 |
+
causal: bool,
|
| 467 |
+
) -> tuple[int, int]:
|
| 468 |
+
if q.ndim != 3:
|
| 469 |
+
raise ValueError("decode attention requires q to have shape [total_q, Hq, D]")
|
| 470 |
+
if k.ndim != 4 or v.ndim != 4:
|
| 471 |
+
raise ValueError(
|
| 472 |
+
"decode attention requires paged k/v with shape [num_pages, Hkv, page_size, D]"
|
| 473 |
+
)
|
| 474 |
+
if q.device != k.device or q.device != v.device:
|
| 475 |
+
raise ValueError("decode q, k, and v must be on the same device")
|
| 476 |
+
if q.dtype != torch.float8_e4m3fn or k.dtype != q.dtype or v.dtype != q.dtype:
|
| 477 |
+
raise TypeError(
|
| 478 |
+
"decode attention currently supports only torch.float8_e4m3fn Q/K/V"
|
| 479 |
+
)
|
| 480 |
+
if k.shape != v.shape:
|
| 481 |
+
raise ValueError(f"decode k and v must have the same shape, got {k.shape} and {v.shape}")
|
| 482 |
+
if q.shape[-1] != 128 or k.shape[-1] != 128:
|
| 483 |
+
raise NotImplementedError(
|
| 484 |
+
f"decode attention currently supports only D=128, got q={q.shape[-1]} k={k.shape[-1]}"
|
| 485 |
+
)
|
| 486 |
+
if not bool(causal):
|
| 487 |
+
raise NotImplementedError("decode attention currently supports only causal=True")
|
| 488 |
+
page_size = int(k.shape[2])
|
| 489 |
+
if page_size != int(blk_kv):
|
| 490 |
+
raise ValueError(f"decode attention requires page_size == blk_kv, got {page_size} vs {blk_kv}")
|
| 491 |
+
|
| 492 |
+
head_kv = int(k.shape[1])
|
| 493 |
+
head_q = int(q.shape[1])
|
| 494 |
+
if head_q % head_kv != 0:
|
| 495 |
+
raise ValueError("decode q.shape[1] must be divisible by Hkv")
|
| 496 |
+
qhead_per_kv = head_q // head_kv
|
| 497 |
+
if qhead_per_kv != _SUPPORTED_DECODE_QHEAD_PER_KV:
|
| 498 |
+
raise NotImplementedError(
|
| 499 |
+
"decode attention currently supports only "
|
| 500 |
+
f"qhead_per_kv={_SUPPORTED_DECODE_QHEAD_PER_KV}, got {qhead_per_kv}"
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
if page_table is None:
|
| 504 |
+
raise ValueError("decode attention requires page_table")
|
| 505 |
+
if page_table.device != q.device:
|
| 506 |
+
raise ValueError("decode page_table must be on the same device as q")
|
| 507 |
+
if page_table.dtype != torch.int32:
|
| 508 |
+
raise TypeError("decode page_table must be torch.int32")
|
| 509 |
+
if page_table.ndim != 2:
|
| 510 |
+
raise ValueError("decode page_table must have shape [B, max_num_pages_per_seq]")
|
| 511 |
+
batch = int(page_table.shape[0])
|
| 512 |
+
if page_table.stride(-1) != 1:
|
| 513 |
+
raise ValueError("decode page_table must be contiguous in the last dimension")
|
| 514 |
+
|
| 515 |
+
if seqused_k is None:
|
| 516 |
+
raise ValueError("decode attention requires seqused_k")
|
| 517 |
+
if seqused_k.device != q.device:
|
| 518 |
+
raise ValueError("decode seqused_k must be on the same device as q")
|
| 519 |
+
if seqused_k.dtype != torch.int32:
|
| 520 |
+
raise TypeError("decode seqused_k must be torch.int32")
|
| 521 |
+
if seqused_k.shape != (batch,):
|
| 522 |
+
raise ValueError("decode seqused_k must have shape [B]")
|
| 523 |
+
if not seqused_k.is_contiguous():
|
| 524 |
+
raise ValueError("decode seqused_k must be contiguous")
|
| 525 |
+
|
| 526 |
+
seqlen_q = int(seqlen_q)
|
| 527 |
+
max_seqlen_k = int(max_seqlen_k)
|
| 528 |
+
if seqlen_q <= 0 or max_seqlen_k <= 0:
|
| 529 |
+
raise ValueError("decode seqlen_q and max_seqlen_k must be positive")
|
| 530 |
+
if int(q.shape[0]) != batch * seqlen_q:
|
| 531 |
+
raise ValueError("decode q.shape[0] must equal batch * seqlen_q")
|
| 532 |
+
|
| 533 |
+
if q2k_indices is not None:
|
| 534 |
+
if q2k_indices.device != q.device:
|
| 535 |
+
raise ValueError("decode q2k_indices must be on the same device as q")
|
| 536 |
+
if q2k_indices.dtype != torch.int32:
|
| 537 |
+
raise TypeError("decode q2k_indices must be torch.int32")
|
| 538 |
+
if q2k_indices.ndim != 3:
|
| 539 |
+
raise ValueError("decode q2k_indices must have shape [Hkv, total_q, topK]")
|
| 540 |
+
if q2k_indices.shape[0] != head_kv or q2k_indices.shape[1] != q.shape[0]:
|
| 541 |
+
raise ValueError("decode q2k_indices must match [Hkv, total_q, topK]")
|
| 542 |
+
if not q2k_indices.is_contiguous():
|
| 543 |
+
raise ValueError("decode q2k_indices must be contiguous")
|
| 544 |
+
return batch, head_kv
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def _validate_schedule_common(
|
| 548 |
+
schedule: SparseAttentionSchedule,
|
| 549 |
+
*,
|
| 550 |
+
device: torch.device,
|
| 551 |
+
) -> None:
|
| 552 |
+
if schedule.scheduler_metadata is None:
|
| 553 |
+
raise ValueError("schedule.scheduler_metadata is required")
|
| 554 |
+
if schedule.work_count is None:
|
| 555 |
+
raise ValueError("schedule.work_count is required")
|
| 556 |
+
metadata = schedule.scheduler_metadata
|
| 557 |
+
work_count = schedule.work_count
|
| 558 |
+
if metadata.device != device or work_count.device != device:
|
| 559 |
+
raise ValueError("schedule tensors must be on the same device as q")
|
| 560 |
+
if metadata.dtype != torch.int32 or work_count.dtype != torch.int32:
|
| 561 |
+
raise TypeError("schedule.scheduler_metadata and schedule.work_count must be torch.int32")
|
| 562 |
+
if metadata.ndim != 2 or metadata.shape[1] != 6:
|
| 563 |
+
raise ValueError("schedule.scheduler_metadata must have shape [capacity, 6]")
|
| 564 |
+
if work_count.shape != (1,):
|
| 565 |
+
raise ValueError("schedule.work_count must have shape [1]")
|
| 566 |
+
if not metadata.is_contiguous() or not work_count.is_contiguous():
|
| 567 |
+
raise ValueError("schedule.scheduler_metadata and schedule.work_count must be contiguous")
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def _validate_fwd_schedule(
|
| 571 |
+
schedule: SparseAttentionSchedule,
|
| 572 |
+
*,
|
| 573 |
+
q: torch.Tensor,
|
| 574 |
+
k2q_q_indices: torch.Tensor,
|
| 575 |
+
head_kv: int,
|
| 576 |
+
) -> None:
|
| 577 |
+
_validate_schedule_common(schedule, device=q.device)
|
| 578 |
+
if schedule.qsplit_indices is None:
|
| 579 |
+
raise ValueError("schedule.qsplit_indices is required for forward")
|
| 580 |
+
if schedule.split_counts is None:
|
| 581 |
+
raise ValueError("schedule.split_counts is required for forward")
|
| 582 |
+
qsplit = schedule.qsplit_indices
|
| 583 |
+
split_counts = schedule.split_counts
|
| 584 |
+
if qsplit.device != q.device or split_counts.device != q.device:
|
| 585 |
+
raise ValueError("forward schedule tensors must be on the same device as q")
|
| 586 |
+
if qsplit.dtype != torch.int32 or split_counts.dtype != torch.int32:
|
| 587 |
+
raise TypeError("schedule.qsplit_indices and schedule.split_counts must be torch.int32")
|
| 588 |
+
if qsplit.shape != k2q_q_indices.shape:
|
| 589 |
+
raise ValueError("schedule.qsplit_indices shape must match k2q_q_indices")
|
| 590 |
+
total_q = q.shape[0]
|
| 591 |
+
if split_counts.shape != (total_q, head_kv):
|
| 592 |
+
raise ValueError(
|
| 593 |
+
"schedule.split_counts must have shape "
|
| 594 |
+
f"({total_q}, {head_kv}), got {tuple(split_counts.shape)}"
|
| 595 |
+
)
|
| 596 |
+
if not qsplit.is_contiguous() or not split_counts.is_contiguous():
|
| 597 |
+
raise ValueError("schedule.qsplit_indices and schedule.split_counts must be contiguous")
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def sparse_atten_func(
|
| 601 |
+
q: torch.Tensor,
|
| 602 |
+
k: torch.Tensor,
|
| 603 |
+
v: torch.Tensor,
|
| 604 |
+
k2q_row_ptr: torch.Tensor,
|
| 605 |
+
k2q_q_indices: torch.Tensor,
|
| 606 |
+
topK: int,
|
| 607 |
+
*,
|
| 608 |
+
cu_seqlens_q: torch.Tensor,
|
| 609 |
+
cu_seqlens_k: torch.Tensor,
|
| 610 |
+
max_seqlen_q: int,
|
| 611 |
+
max_seqlen_k: int,
|
| 612 |
+
blk_kv: int = 128,
|
| 613 |
+
causal: bool = False,
|
| 614 |
+
softmax_scale: Optional[float] = None,
|
| 615 |
+
lse_temperature_scale: float = 1.0,
|
| 616 |
+
return_temperature_lse: bool = False,
|
| 617 |
+
partial_dtype: torch.dtype = torch.bfloat16,
|
| 618 |
+
return_softmax_lse: bool = False,
|
| 619 |
+
page_table: Optional[torch.Tensor] = None,
|
| 620 |
+
seqused_k: Optional[torch.Tensor] = None,
|
| 621 |
+
schedule: Optional[SparseAttentionSchedule] = None,
|
| 622 |
+
usable_SM_count: int = -1,
|
| 623 |
+
qk_dtype: Optional[torch.dtype] = None,
|
| 624 |
+
pv_dtype: Optional[torch.dtype] = None,
|
| 625 |
+
):
|
| 626 |
+
"""Run SM100 CSR block-sparse varlen attention.
|
| 627 |
+
|
| 628 |
+
This is the public forward-only sparse attention API. It consumes
|
| 629 |
+
query-to-key block selections converted to CSR metadata by
|
| 630 |
+
``build_k2q_csr`` and supports both dense KV layout and paged KV layout.
|
| 631 |
+
|
| 632 |
+
Parameters
|
| 633 |
+
----------
|
| 634 |
+
q : torch.Tensor
|
| 635 |
+
Shape ``[total_q, Hq, 128]`` on CUDA. Supported dtypes are BF16 and
|
| 636 |
+
FP8 E4M3.
|
| 637 |
+
k : torch.Tensor
|
| 638 |
+
Dense layout ``[total_k, Hkv, 128]`` or paged layout
|
| 639 |
+
``[num_pages, Hkv, blk_kv, 128]``. For BF16 Q with FP8 K/V cache, K
|
| 640 |
+
may be FP8 E4M3 while QK compute uses BF16 staging.
|
| 641 |
+
v : torch.Tensor
|
| 642 |
+
Same layout and head count as ``k``.
|
| 643 |
+
k2q_row_ptr : torch.Tensor
|
| 644 |
+
CSR row pointers with shape ``[Hkv, total_rows + 1]`` and dtype int32.
|
| 645 |
+
k2q_q_indices : torch.Tensor
|
| 646 |
+
CSR query indices with shape ``[Hkv, >= total_q * topK]`` and dtype
|
| 647 |
+
int32.
|
| 648 |
+
topK : int
|
| 649 |
+
Number of selected KV blocks per query. Supported values are
|
| 650 |
+
``4, 8, 16, 32``.
|
| 651 |
+
cu_seqlens_q : torch.Tensor
|
| 652 |
+
Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths.
|
| 653 |
+
cu_seqlens_k : torch.Tensor
|
| 654 |
+
Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths.
|
| 655 |
+
max_seqlen_q : int
|
| 656 |
+
Maximum Q sequence length in the batch.
|
| 657 |
+
max_seqlen_k : int
|
| 658 |
+
Maximum KV sequence length in the batch.
|
| 659 |
+
blk_kv : int, optional
|
| 660 |
+
KV block size. Paged KV requires ``k.shape[2] == blk_kv``.
|
| 661 |
+
causal : bool, optional
|
| 662 |
+
Whether to apply causal masking.
|
| 663 |
+
softmax_scale : float, optional
|
| 664 |
+
Softmax scale. Defaults to ``1 / sqrt(128)``.
|
| 665 |
+
lse_temperature_scale : float, optional
|
| 666 |
+
Extra divisor used only for temperature-scaled LSE output.
|
| 667 |
+
return_temperature_lse : bool, optional
|
| 668 |
+
If True, also return LSE computed with logits scaled by
|
| 669 |
+
``softmax_scale / lse_temperature_scale``. Requires
|
| 670 |
+
``return_softmax_lse=True``.
|
| 671 |
+
partial_dtype : torch.dtype, optional
|
| 672 |
+
Accumulation dtype for per-block partial O. Supported values are
|
| 673 |
+
FP32, BF16, FP16, and FP8 E4M3.
|
| 674 |
+
return_softmax_lse : bool, optional
|
| 675 |
+
If True, return ``(out, softmax_lse)`` or
|
| 676 |
+
``(out, softmax_lse, temperature_lse)``.
|
| 677 |
+
page_table : torch.Tensor, optional
|
| 678 |
+
Paged-KV physical page table with shape
|
| 679 |
+
``[batch_size, max_num_pages_per_seq]`` and dtype int32.
|
| 680 |
+
seqused_k : torch.Tensor, optional
|
| 681 |
+
Shape ``[batch_size]``, dtype int32. Effective KV length per request
|
| 682 |
+
for paged causal attention.
|
| 683 |
+
schedule : SparseAttentionSchedule, optional
|
| 684 |
+
Prebuilt sparse forward schedule. If omitted, the schedule is built
|
| 685 |
+
during the call.
|
| 686 |
+
usable_SM_count : int, optional
|
| 687 |
+
Maximum number of SMs used by the scheduler. ``-1`` uses all SMs.
|
| 688 |
+
qk_dtype : torch.dtype, optional
|
| 689 |
+
Compile-time MMA operand dtype for QK. Defaults to Q storage dtype,
|
| 690 |
+
except supported FP8 K/V cache staging modes.
|
| 691 |
+
pv_dtype : torch.dtype, optional
|
| 692 |
+
Compile-time MMA operand dtype for PV. Defaults to V storage dtype,
|
| 693 |
+
except supported FP8 K/V cache staging modes.
|
| 694 |
+
|
| 695 |
+
Returns
|
| 696 |
+
-------
|
| 697 |
+
torch.Tensor or tuple[torch.Tensor, torch.Tensor]
|
| 698 |
+
Output shape ``[total_q, Hq, 128]`` with BF16 dtype. Optional LSE
|
| 699 |
+
outputs have shape ``[total_q, Hq]`` and dtype float32.
|
| 700 |
+
|
| 701 |
+
Notes
|
| 702 |
+
-----
|
| 703 |
+
``Hq / Hkv`` must be one of ``1, 2, 4, 8, 16``. Current kernels support
|
| 704 |
+
head dimension 128 only.
|
| 705 |
+
"""
|
| 706 |
+
if softmax_scale is None:
|
| 707 |
+
softmax_scale = q.shape[-1] ** -0.5
|
| 708 |
+
lse_temperature_scale = float(lse_temperature_scale)
|
| 709 |
+
if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0:
|
| 710 |
+
raise ValueError(
|
| 711 |
+
f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}"
|
| 712 |
+
)
|
| 713 |
+
return_temperature_lse = bool(return_temperature_lse)
|
| 714 |
+
if return_temperature_lse and not return_softmax_lse:
|
| 715 |
+
raise ValueError("return_temperature_lse=True requires return_softmax_lse=True")
|
| 716 |
+
partial_dtype = _normalize_partial_dtype(partial_dtype)
|
| 717 |
+
qk_dtype, pv_dtype = _resolve_forward_mma_dtypes(q, k, v, qk_dtype, pv_dtype)
|
| 718 |
+
|
| 719 |
+
if cu_seqlens_q is None or cu_seqlens_k is None:
|
| 720 |
+
raise ValueError(
|
| 721 |
+
"sparse_atten_func requires CSR varlen metadata: pass cu_seqlens_q and cu_seqlens_k"
|
| 722 |
+
)
|
| 723 |
+
batch, head_kv = _validate_csr_varlen_inputs(
|
| 724 |
+
q,
|
| 725 |
+
k,
|
| 726 |
+
v,
|
| 727 |
+
k2q_row_ptr,
|
| 728 |
+
k2q_q_indices,
|
| 729 |
+
topK,
|
| 730 |
+
blk_kv,
|
| 731 |
+
page_table,
|
| 732 |
+
cu_seqlens_q,
|
| 733 |
+
cu_seqlens_k,
|
| 734 |
+
seqused_k,
|
| 735 |
+
)
|
| 736 |
+
max_seqlen_q = int(max_seqlen_q)
|
| 737 |
+
max_seqlen_k = int(max_seqlen_k)
|
| 738 |
+
|
| 739 |
+
return _sparse_atten_csr_varlen_forward(
|
| 740 |
+
q.contiguous(),
|
| 741 |
+
k.contiguous(),
|
| 742 |
+
v.contiguous(),
|
| 743 |
+
k2q_row_ptr.contiguous(),
|
| 744 |
+
k2q_q_indices.contiguous(),
|
| 745 |
+
int(topK),
|
| 746 |
+
int(blk_kv),
|
| 747 |
+
bool(causal),
|
| 748 |
+
float(softmax_scale),
|
| 749 |
+
lse_temperature_scale,
|
| 750 |
+
return_temperature_lse,
|
| 751 |
+
partial_dtype,
|
| 752 |
+
bool(return_softmax_lse),
|
| 753 |
+
cu_seqlens_q.contiguous(),
|
| 754 |
+
cu_seqlens_k.contiguous(),
|
| 755 |
+
None if page_table is None else page_table.contiguous(),
|
| 756 |
+
None if seqused_k is None else seqused_k.contiguous(),
|
| 757 |
+
schedule,
|
| 758 |
+
int(usable_SM_count),
|
| 759 |
+
int(batch),
|
| 760 |
+
int(head_kv),
|
| 761 |
+
int(max_seqlen_q),
|
| 762 |
+
int(max_seqlen_k),
|
| 763 |
+
qk_dtype,
|
| 764 |
+
pv_dtype,
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
def sparse_atten_nvfp4_kv_func(
|
| 769 |
+
q: torch.Tensor,
|
| 770 |
+
k: torch.Tensor,
|
| 771 |
+
v: torch.Tensor,
|
| 772 |
+
k_scale_128x4: torch.Tensor,
|
| 773 |
+
v_scale_128x4: torch.Tensor,
|
| 774 |
+
k_global_scale: Optional[torch.Tensor],
|
| 775 |
+
v_global_scale: Optional[torch.Tensor],
|
| 776 |
+
k2q_row_ptr: torch.Tensor,
|
| 777 |
+
k2q_q_indices: torch.Tensor,
|
| 778 |
+
topK: int,
|
| 779 |
+
*,
|
| 780 |
+
cu_seqlens_q: torch.Tensor,
|
| 781 |
+
cu_seqlens_k: torch.Tensor,
|
| 782 |
+
max_seqlen_q: int,
|
| 783 |
+
max_seqlen_k: int,
|
| 784 |
+
blk_kv: int = 128,
|
| 785 |
+
causal: bool = False,
|
| 786 |
+
softmax_scale: Optional[float] = None,
|
| 787 |
+
lse_temperature_scale: float = 1.0,
|
| 788 |
+
return_temperature_lse: bool = False,
|
| 789 |
+
partial_dtype: torch.dtype = torch.bfloat16,
|
| 790 |
+
return_softmax_lse: bool = False,
|
| 791 |
+
page_table: Optional[torch.Tensor] = None,
|
| 792 |
+
seqused_k: Optional[torch.Tensor] = None,
|
| 793 |
+
schedule: Optional[SparseAttentionSchedule] = None,
|
| 794 |
+
):
|
| 795 |
+
"""Run SM100 CSR sparse attention with packed NVFP4 K/V.
|
| 796 |
+
|
| 797 |
+
Parameters
|
| 798 |
+
----------
|
| 799 |
+
q : torch.Tensor
|
| 800 |
+
Shape ``[total_q, Hq, 128]`` on CUDA. Supported dtypes are BF16 and
|
| 801 |
+
FP8 E4M3.
|
| 802 |
+
k : torch.Tensor
|
| 803 |
+
Packed NVFP4 K data. Dense layout is ``[total_k, Hkv, 64]``; paged
|
| 804 |
+
layout is ``[num_pages, Hkv, blk_kv, 64]``. Dtype must be uint8
|
| 805 |
+
because each byte packs two FP4 values.
|
| 806 |
+
v : torch.Tensor
|
| 807 |
+
Packed NVFP4 V data with the same shape as ``k``.
|
| 808 |
+
k_scale_128x4 : torch.Tensor
|
| 809 |
+
K block scales in cuBLAS/cuDNN 128x4 tiled storage. Dtype uint8
|
| 810 |
+
containing FP8 E4M3 scale values.
|
| 811 |
+
v_scale_128x4 : torch.Tensor
|
| 812 |
+
V block scales in the same 128x4 tiled storage.
|
| 813 |
+
k_global_scale : torch.Tensor, optional
|
| 814 |
+
FP32 tensor/global dequant scale for K. May be ``None``.
|
| 815 |
+
v_global_scale : torch.Tensor, optional
|
| 816 |
+
FP32 tensor/global dequant scale for V. May be ``None``. The V global
|
| 817 |
+
scale is applied in the combine stage.
|
| 818 |
+
k2q_row_ptr : torch.Tensor
|
| 819 |
+
CSR row pointers with shape ``[Hkv, total_rows + 1]`` and dtype int32.
|
| 820 |
+
k2q_q_indices : torch.Tensor
|
| 821 |
+
CSR query indices with shape ``[Hkv, >= total_q * topK]`` and dtype
|
| 822 |
+
int32.
|
| 823 |
+
topK : int
|
| 824 |
+
Number of selected KV blocks per query. Supported values are
|
| 825 |
+
``4, 8, 16, 32``.
|
| 826 |
+
cu_seqlens_q, cu_seqlens_k : torch.Tensor
|
| 827 |
+
Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q and KV
|
| 828 |
+
lengths.
|
| 829 |
+
max_seqlen_q, max_seqlen_k : int
|
| 830 |
+
Maximum Q and KV sequence lengths in the batch.
|
| 831 |
+
blk_kv : int, optional
|
| 832 |
+
KV block/page size. Paged KV requires ``k.shape[2] == blk_kv``.
|
| 833 |
+
causal : bool, optional
|
| 834 |
+
Whether to apply causal masking.
|
| 835 |
+
softmax_scale : float, optional
|
| 836 |
+
Softmax scale. Defaults to ``1 / sqrt(128)``.
|
| 837 |
+
lse_temperature_scale : float, optional
|
| 838 |
+
Extra divisor used only for temperature-scaled LSE output.
|
| 839 |
+
return_temperature_lse : bool, optional
|
| 840 |
+
If True, also return temperature-scaled LSE. Requires
|
| 841 |
+
``return_softmax_lse=True``.
|
| 842 |
+
partial_dtype : torch.dtype, optional
|
| 843 |
+
Accumulation dtype for per-block partial O.
|
| 844 |
+
return_softmax_lse : bool, optional
|
| 845 |
+
If True, return LSE together with the output.
|
| 846 |
+
page_table : torch.Tensor, optional
|
| 847 |
+
Paged-KV physical page table with shape
|
| 848 |
+
``[batch_size, max_num_pages_per_seq]`` and dtype int32.
|
| 849 |
+
seqused_k : torch.Tensor, optional
|
| 850 |
+
Effective KV length per request for paged causal attention.
|
| 851 |
+
schedule : SparseAttentionSchedule, optional
|
| 852 |
+
Prebuilt sparse forward schedule.
|
| 853 |
+
|
| 854 |
+
Returns
|
| 855 |
+
-------
|
| 856 |
+
torch.Tensor or tuple[torch.Tensor, torch.Tensor]
|
| 857 |
+
Output shape ``[total_q, Hq, 128]`` with BF16 dtype. Optional LSE
|
| 858 |
+
outputs have shape ``[total_q, Hq]`` and dtype float32.
|
| 859 |
+
"""
|
| 860 |
+
|
| 861 |
+
if softmax_scale is None:
|
| 862 |
+
softmax_scale = q.shape[-1] ** -0.5
|
| 863 |
+
lse_temperature_scale = float(lse_temperature_scale)
|
| 864 |
+
if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0:
|
| 865 |
+
raise ValueError(
|
| 866 |
+
f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}"
|
| 867 |
+
)
|
| 868 |
+
return_temperature_lse = bool(return_temperature_lse)
|
| 869 |
+
if return_temperature_lse and not return_softmax_lse:
|
| 870 |
+
raise ValueError("return_temperature_lse=True requires return_softmax_lse=True")
|
| 871 |
+
partial_dtype = _normalize_partial_dtype(partial_dtype)
|
| 872 |
+
|
| 873 |
+
if cu_seqlens_q is None or cu_seqlens_k is None:
|
| 874 |
+
raise ValueError(
|
| 875 |
+
"sparse_atten_nvfp4_kv_func requires CSR varlen metadata: pass cu_seqlens_q and cu_seqlens_k"
|
| 876 |
+
)
|
| 877 |
+
batch, head_kv = _validate_csr_varlen_nvfp4_kv_inputs(
|
| 878 |
+
q,
|
| 879 |
+
k,
|
| 880 |
+
v,
|
| 881 |
+
k_scale_128x4,
|
| 882 |
+
v_scale_128x4,
|
| 883 |
+
k_global_scale,
|
| 884 |
+
v_global_scale,
|
| 885 |
+
k2q_row_ptr,
|
| 886 |
+
k2q_q_indices,
|
| 887 |
+
topK,
|
| 888 |
+
blk_kv,
|
| 889 |
+
page_table,
|
| 890 |
+
cu_seqlens_q,
|
| 891 |
+
cu_seqlens_k,
|
| 892 |
+
seqused_k,
|
| 893 |
+
)
|
| 894 |
+
total_q, head_q, dim = q.shape
|
| 895 |
+
max_num_kv_blocks = _csr_row_capacity(k2q_row_ptr)
|
| 896 |
+
temperature_lse_fast_path = (
|
| 897 |
+
return_temperature_lse
|
| 898 |
+
and math.isclose(
|
| 899 |
+
float(lse_temperature_scale),
|
| 900 |
+
1.0,
|
| 901 |
+
rel_tol=0.0,
|
| 902 |
+
abs_tol=_TEMPERATURE_LSE_FAST_PATH_ABS_TOL,
|
| 903 |
+
)
|
| 904 |
+
)
|
| 905 |
+
kernel_return_temperature_lse = (
|
| 906 |
+
return_temperature_lse and not temperature_lse_fast_path
|
| 907 |
+
)
|
| 908 |
+
|
| 909 |
+
O_partial = torch.empty(
|
| 910 |
+
topK, total_q, head_q, dim, dtype=partial_dtype, device=q.device
|
| 911 |
+
)
|
| 912 |
+
LSE_partial = torch.empty(
|
| 913 |
+
topK, total_q, head_q, dtype=torch.float32, device=q.device
|
| 914 |
+
)
|
| 915 |
+
LSE_temperature_partial = (
|
| 916 |
+
torch.empty(topK, total_q, head_q, dtype=torch.float32, device=q.device)
|
| 917 |
+
if kernel_return_temperature_lse
|
| 918 |
+
else None
|
| 919 |
+
)
|
| 920 |
+
O_out = torch.empty(total_q, head_q, dim, dtype=torch.bfloat16, device=q.device)
|
| 921 |
+
LSE_out = torch.empty(total_q, head_q, dtype=torch.float32, device=q.device)
|
| 922 |
+
LSE_temperature_out = (
|
| 923 |
+
torch.empty_like(LSE_out) if kernel_return_temperature_lse else None
|
| 924 |
+
)
|
| 925 |
+
if schedule is None:
|
| 926 |
+
k2q_qsplit_indices = torch.empty_like(k2q_q_indices)
|
| 927 |
+
split_counts = torch.zeros(
|
| 928 |
+
(total_q, head_kv),
|
| 929 |
+
dtype=torch.int32,
|
| 930 |
+
device=q.device,
|
| 931 |
+
)
|
| 932 |
+
else:
|
| 933 |
+
_validate_fwd_schedule(
|
| 934 |
+
schedule,
|
| 935 |
+
q=q,
|
| 936 |
+
k2q_q_indices=k2q_q_indices,
|
| 937 |
+
head_kv=head_kv,
|
| 938 |
+
)
|
| 939 |
+
k2q_qsplit_indices = schedule.qsplit_indices
|
| 940 |
+
split_counts = schedule.split_counts
|
| 941 |
+
|
| 942 |
+
schedule = _call_sparse_forward_sm100_csr_varlen_nvfp4_kv(
|
| 943 |
+
q.contiguous(),
|
| 944 |
+
k.contiguous(),
|
| 945 |
+
v.contiguous(),
|
| 946 |
+
k_scale_128x4.contiguous(),
|
| 947 |
+
v_scale_128x4.contiguous(),
|
| 948 |
+
None if k_global_scale is None else k_global_scale.contiguous(),
|
| 949 |
+
None if v_global_scale is None else v_global_scale.contiguous(),
|
| 950 |
+
k2q_row_ptr.contiguous(),
|
| 951 |
+
k2q_q_indices.contiguous(),
|
| 952 |
+
k2q_qsplit_indices.contiguous(),
|
| 953 |
+
split_counts.contiguous(),
|
| 954 |
+
cu_seqlens_q.contiguous(),
|
| 955 |
+
cu_seqlens_k.contiguous(),
|
| 956 |
+
None if page_table is None else page_table.contiguous(),
|
| 957 |
+
None if seqused_k is None else seqused_k.contiguous(),
|
| 958 |
+
O_partial,
|
| 959 |
+
LSE_partial,
|
| 960 |
+
LSE_temperature_partial,
|
| 961 |
+
float(softmax_scale),
|
| 962 |
+
lse_temperature_scale,
|
| 963 |
+
kernel_return_temperature_lse,
|
| 964 |
+
max_num_kv_blocks,
|
| 965 |
+
int(blk_kv),
|
| 966 |
+
head_kv,
|
| 967 |
+
int(max_seqlen_q),
|
| 968 |
+
causal=bool(causal),
|
| 969 |
+
schedule=schedule,
|
| 970 |
+
)
|
| 971 |
+
|
| 972 |
+
combine(
|
| 973 |
+
O_partial,
|
| 974 |
+
LSE_partial,
|
| 975 |
+
O_out,
|
| 976 |
+
LSE_out,
|
| 977 |
+
lse_temperature_partial=LSE_temperature_partial,
|
| 978 |
+
lse_temperature_out=LSE_temperature_out,
|
| 979 |
+
cu_seqlens=cu_seqlens_q,
|
| 980 |
+
split_counts=split_counts,
|
| 981 |
+
output_scale=v_global_scale,
|
| 982 |
+
use_pdl=True,
|
| 983 |
+
)
|
| 984 |
+
if temperature_lse_fast_path:
|
| 985 |
+
LSE_temperature_out = LSE_out
|
| 986 |
+
|
| 987 |
+
if return_softmax_lse:
|
| 988 |
+
if return_temperature_lse:
|
| 989 |
+
return O_out, LSE_out, LSE_temperature_out
|
| 990 |
+
return O_out, LSE_out
|
| 991 |
+
return O_out
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
def sparse_decode_atten_func(
|
| 995 |
+
q: torch.Tensor,
|
| 996 |
+
k: torch.Tensor,
|
| 997 |
+
v: torch.Tensor,
|
| 998 |
+
q2k_indices: Optional[torch.Tensor] = None,
|
| 999 |
+
*,
|
| 1000 |
+
page_table: torch.Tensor,
|
| 1001 |
+
seqused_k: torch.Tensor,
|
| 1002 |
+
seqlen_q: int,
|
| 1003 |
+
max_seqlen_k: int,
|
| 1004 |
+
blk_kv: int = 128,
|
| 1005 |
+
causal: bool = True,
|
| 1006 |
+
softmax_scale: Optional[float] = None,
|
| 1007 |
+
return_softmax_lse: bool = False,
|
| 1008 |
+
schedule: Optional[DecodeAttentionSchedule] = None,
|
| 1009 |
+
O_partial: Optional[torch.Tensor] = None,
|
| 1010 |
+
LSE_partial: Optional[torch.Tensor] = None,
|
| 1011 |
+
):
|
| 1012 |
+
"""Run forward-only paged FP8 decode attention.
|
| 1013 |
+
|
| 1014 |
+
Parameters
|
| 1015 |
+
----------
|
| 1016 |
+
q : torch.Tensor
|
| 1017 |
+
Shape ``[batch_size * seqlen_q, Hq, 128]``. Dtype must be FP8 E4M3.
|
| 1018 |
+
k : torch.Tensor
|
| 1019 |
+
Paged K cache with shape ``[num_pages, Hkv, blk_kv, 128]`` and FP8
|
| 1020 |
+
E4M3 dtype.
|
| 1021 |
+
v : torch.Tensor
|
| 1022 |
+
Paged V cache with the same shape and dtype as ``k``.
|
| 1023 |
+
q2k_indices : torch.Tensor, optional
|
| 1024 |
+
Sparse selected KV blocks with shape ``[Hkv, total_q, topK]`` and dtype
|
| 1025 |
+
int32. ``None`` selects the dense all-KV decode path.
|
| 1026 |
+
page_table : torch.Tensor
|
| 1027 |
+
Physical page table with shape ``[batch_size, max_num_pages_per_seq]``
|
| 1028 |
+
and dtype int32.
|
| 1029 |
+
seqused_k : torch.Tensor
|
| 1030 |
+
Shape ``[batch_size]``, dtype int32. Effective KV length per request.
|
| 1031 |
+
seqlen_q : int
|
| 1032 |
+
Uniform query length per request. Ragged Q lengths should use prefill
|
| 1033 |
+
or append paths instead.
|
| 1034 |
+
max_seqlen_k : int
|
| 1035 |
+
Maximum KV sequence length in the batch.
|
| 1036 |
+
blk_kv : int, optional
|
| 1037 |
+
Page size. Must match ``k.shape[2]``.
|
| 1038 |
+
causal : bool, optional
|
| 1039 |
+
Whether to apply causal masking. Current decode kernel requires True.
|
| 1040 |
+
softmax_scale : float, optional
|
| 1041 |
+
Softmax scale. Defaults to ``1 / sqrt(128)``.
|
| 1042 |
+
return_softmax_lse : bool, optional
|
| 1043 |
+
If True, return ``(out, lse)``.
|
| 1044 |
+
schedule : DecodeAttentionSchedule, optional
|
| 1045 |
+
Prebuilt decode schedule.
|
| 1046 |
+
O_partial, LSE_partial : torch.Tensor, optional
|
| 1047 |
+
Optional split-KV partial workspaces. Normally owned by
|
| 1048 |
+
``SparseDecodePagedAttentionWrapper``.
|
| 1049 |
+
|
| 1050 |
+
Returns
|
| 1051 |
+
-------
|
| 1052 |
+
torch.Tensor or tuple[torch.Tensor, torch.Tensor]
|
| 1053 |
+
BF16 output with shape ``q.shape``. Optional LSE has shape
|
| 1054 |
+
``[batch_size * seqlen_q, Hq]`` and dtype float32.
|
| 1055 |
+
"""
|
| 1056 |
+
if softmax_scale is None:
|
| 1057 |
+
softmax_scale = q.shape[-1] ** -0.5
|
| 1058 |
+
batch, head_kv = _validate_sparse_decode_inputs(
|
| 1059 |
+
q,
|
| 1060 |
+
k,
|
| 1061 |
+
v,
|
| 1062 |
+
q2k_indices,
|
| 1063 |
+
page_table=page_table,
|
| 1064 |
+
seqused_k=seqused_k,
|
| 1065 |
+
seqlen_q=seqlen_q,
|
| 1066 |
+
max_seqlen_k=max_seqlen_k,
|
| 1067 |
+
blk_kv=blk_kv,
|
| 1068 |
+
causal=causal,
|
| 1069 |
+
)
|
| 1070 |
+
head_q = int(q.shape[1])
|
| 1071 |
+
head_dim = int(q.shape[2])
|
| 1072 |
+
if schedule is None:
|
| 1073 |
+
schedule = prepare_decode_schedule(
|
| 1074 |
+
seqused_k=seqused_k.contiguous(),
|
| 1075 |
+
page_size=int(blk_kv),
|
| 1076 |
+
seqlen_q=int(seqlen_q),
|
| 1077 |
+
num_qo_heads=head_q,
|
| 1078 |
+
num_kv_heads=head_kv,
|
| 1079 |
+
head_dim=head_dim,
|
| 1080 |
+
max_seqlen_k=int(max_seqlen_k),
|
| 1081 |
+
)
|
| 1082 |
+
if schedule.split_kv:
|
| 1083 |
+
if O_partial is None:
|
| 1084 |
+
O_partial = torch.empty(
|
| 1085 |
+
(schedule.partial_rows, head_q, head_dim),
|
| 1086 |
+
dtype=torch.float32,
|
| 1087 |
+
device=q.device,
|
| 1088 |
+
)
|
| 1089 |
+
if LSE_partial is None:
|
| 1090 |
+
LSE_partial = torch.empty(
|
| 1091 |
+
(schedule.partial_rows, head_q),
|
| 1092 |
+
dtype=torch.float32,
|
| 1093 |
+
device=q.device,
|
| 1094 |
+
)
|
| 1095 |
+
out = torch.empty(q.shape, dtype=torch.bfloat16, device=q.device)
|
| 1096 |
+
lse = torch.empty(
|
| 1097 |
+
q.shape[:2] if (return_softmax_lse or schedule.split_kv) else (1, head_q),
|
| 1098 |
+
dtype=torch.float32,
|
| 1099 |
+
device=q.device,
|
| 1100 |
+
)
|
| 1101 |
+
_call_sparse_decode_forward_sm100_paged_fp8(
|
| 1102 |
+
q.contiguous(),
|
| 1103 |
+
k.contiguous(),
|
| 1104 |
+
v.contiguous(),
|
| 1105 |
+
None if q2k_indices is None else q2k_indices.contiguous(),
|
| 1106 |
+
page_table.contiguous(),
|
| 1107 |
+
seqused_k.contiguous(),
|
| 1108 |
+
out,
|
| 1109 |
+
lse,
|
| 1110 |
+
schedule,
|
| 1111 |
+
O_partial,
|
| 1112 |
+
LSE_partial,
|
| 1113 |
+
softmax_scale=float(softmax_scale),
|
| 1114 |
+
seqlen_q=int(seqlen_q),
|
| 1115 |
+
max_seqlen_k=int(max_seqlen_k),
|
| 1116 |
+
blk_kv=int(blk_kv),
|
| 1117 |
+
causal=bool(causal),
|
| 1118 |
+
return_lse=bool(return_softmax_lse),
|
| 1119 |
+
)
|
| 1120 |
+
if return_softmax_lse:
|
| 1121 |
+
return out, lse
|
| 1122 |
+
return out
|
| 1123 |
+
|
| 1124 |
+
|
| 1125 |
+
class SparseDecodePagedAttentionWrapper:
|
| 1126 |
+
"""Plan/run helper for paged FP8 decode attention.
|
| 1127 |
+
|
| 1128 |
+
Use this wrapper when the same page table shape and sequence metadata are
|
| 1129 |
+
reused across multiple decode layers. ``plan`` validates metadata and
|
| 1130 |
+
allocates persistent schedules/workspaces; ``run`` then launches the decode
|
| 1131 |
+
kernel with lower per-call overhead than ``sparse_decode_atten_func``.
|
| 1132 |
+
"""
|
| 1133 |
+
|
| 1134 |
+
def __init__(self, *, blk_kv: int = 128, causal: bool = True):
|
| 1135 |
+
self.blk_kv = int(blk_kv)
|
| 1136 |
+
self.causal = bool(causal)
|
| 1137 |
+
self.batch: Optional[int] = None
|
| 1138 |
+
self.num_qo_heads: Optional[int] = None
|
| 1139 |
+
self.num_kv_heads: Optional[int] = None
|
| 1140 |
+
self.head_dim: Optional[int] = None
|
| 1141 |
+
self.page_table: Optional[torch.Tensor] = None
|
| 1142 |
+
self.seqused_k: Optional[torch.Tensor] = None
|
| 1143 |
+
self.q2k_indices: Optional[torch.Tensor] = None
|
| 1144 |
+
self.seqlen_q: Optional[int] = None
|
| 1145 |
+
self.max_seqlen_k: Optional[int] = None
|
| 1146 |
+
self.is_sparse: bool = False
|
| 1147 |
+
self.decode_schedule: Optional[DecodeAttentionSchedule] = None
|
| 1148 |
+
self.request_indices: Optional[torch.Tensor] = None
|
| 1149 |
+
self.qo_tile_indices: Optional[torch.Tensor] = None
|
| 1150 |
+
self.kv_tile_indices: Optional[torch.Tensor] = None
|
| 1151 |
+
self.merge_indptr: Optional[torch.Tensor] = None
|
| 1152 |
+
self.o_indptr: Optional[torch.Tensor] = None
|
| 1153 |
+
self.block_valid_mask: Optional[torch.Tensor] = None
|
| 1154 |
+
self.kv_pages: Optional[torch.Tensor] = None
|
| 1155 |
+
self.split_counts: Optional[torch.Tensor] = None
|
| 1156 |
+
self.split_kv: bool = False
|
| 1157 |
+
self.cta_tile_q: int = 0
|
| 1158 |
+
self.num_q_tiles: int = 0
|
| 1159 |
+
self.kv_chunk_size_pages: int = 0
|
| 1160 |
+
self.kv_chunk_size_tokens: int = 0
|
| 1161 |
+
self.work_count: int = 0
|
| 1162 |
+
self.padded_work_count: int = 0
|
| 1163 |
+
self.O_partial: Optional[torch.Tensor] = None
|
| 1164 |
+
self.LSE_partial: Optional[torch.Tensor] = None
|
| 1165 |
+
# Cached dummy buffers used in non-split path to satisfy the kernel's
|
| 1166 |
+
# positional arg signature without per-call torch.empty (saves ~5us
|
| 1167 |
+
# on every run() for small kv).
|
| 1168 |
+
self._O_partial_dummy: Optional[torch.Tensor] = None
|
| 1169 |
+
self._LSE_partial_dummy: Optional[torch.Tensor] = None
|
| 1170 |
+
# When the caller doesn't ask for LSE, the kernel still needs a valid
|
| 1171 |
+
# tensor pointer to write to. Cache a small placeholder so run() can
|
| 1172 |
+
# skip the per-call torch.empty for it as well.
|
| 1173 |
+
self._lse_dummy: Optional[torch.Tensor] = None
|
| 1174 |
+
|
| 1175 |
+
def plan(
|
| 1176 |
+
self,
|
| 1177 |
+
*,
|
| 1178 |
+
page_table: torch.Tensor,
|
| 1179 |
+
seqused_k: torch.Tensor,
|
| 1180 |
+
seqlen_q: int,
|
| 1181 |
+
max_seqlen_k: int,
|
| 1182 |
+
q2k_indices: Optional[torch.Tensor] = None,
|
| 1183 |
+
num_qo_heads: Optional[int] = None,
|
| 1184 |
+
num_kv_heads: Optional[int] = None,
|
| 1185 |
+
head_dim: Optional[int] = 128,
|
| 1186 |
+
enable_cuda_graph: bool = False,
|
| 1187 |
+
max_grid_size: Optional[int] = None,
|
| 1188 |
+
fixed_split_size: Optional[int] = None,
|
| 1189 |
+
disable_split_kv: bool = False,
|
| 1190 |
+
) -> "SparseDecodePagedAttentionWrapper":
|
| 1191 |
+
"""Prepare decode scheduling metadata and reusable workspaces.
|
| 1192 |
+
|
| 1193 |
+
Parameters
|
| 1194 |
+
----------
|
| 1195 |
+
page_table : torch.Tensor
|
| 1196 |
+
Shape ``[batch_size, max_num_pages_per_seq]``, dtype int32. Maps
|
| 1197 |
+
logical pages to physical KV-cache pages.
|
| 1198 |
+
seqused_k : torch.Tensor
|
| 1199 |
+
Shape ``[batch_size]``, dtype int32. Effective KV length per
|
| 1200 |
+
request.
|
| 1201 |
+
seqlen_q : int
|
| 1202 |
+
Uniform query length per request.
|
| 1203 |
+
max_seqlen_k : int
|
| 1204 |
+
Maximum KV sequence length in the batch.
|
| 1205 |
+
q2k_indices : torch.Tensor, optional
|
| 1206 |
+
Sparse selected KV blocks with shape ``[Hkv, total_q, topK]`` and
|
| 1207 |
+
dtype int32. ``None`` selects the dense all-KV path.
|
| 1208 |
+
num_qo_heads : int
|
| 1209 |
+
Number of Q/O heads.
|
| 1210 |
+
num_kv_heads : int
|
| 1211 |
+
Number of KV heads. Current decode kernel requires
|
| 1212 |
+
``num_qo_heads / num_kv_heads == 16`` at run time.
|
| 1213 |
+
head_dim : int, optional
|
| 1214 |
+
Head dimension. Must be 128.
|
| 1215 |
+
enable_cuda_graph : bool, optional
|
| 1216 |
+
Build schedule metadata compatible with CUDA graph capture.
|
| 1217 |
+
max_grid_size : int, optional
|
| 1218 |
+
Override maximum CTA count used by the scheduler.
|
| 1219 |
+
fixed_split_size : int, optional
|
| 1220 |
+
Force a fixed split-KV chunk size in pages.
|
| 1221 |
+
disable_split_kv : bool, optional
|
| 1222 |
+
Disable split-KV even for long KV sequences.
|
| 1223 |
+
|
| 1224 |
+
Returns
|
| 1225 |
+
-------
|
| 1226 |
+
SparseDecodePagedAttentionWrapper
|
| 1227 |
+
``self``, planned and ready for ``run``.
|
| 1228 |
+
"""
|
| 1229 |
+
if page_table.ndim != 2:
|
| 1230 |
+
raise ValueError("decode plan requires page_table with shape [B, max_num_pages_per_seq]")
|
| 1231 |
+
if page_table.dtype != torch.int32:
|
| 1232 |
+
raise TypeError("decode plan requires page_table to be torch.int32")
|
| 1233 |
+
if seqused_k.dtype != torch.int32:
|
| 1234 |
+
raise TypeError("decode plan requires seqused_k to be torch.int32")
|
| 1235 |
+
if not page_table.is_cuda or not seqused_k.is_cuda:
|
| 1236 |
+
raise ValueError("decode plan requires page_table and seqused_k to be CUDA tensors")
|
| 1237 |
+
if page_table.device != seqused_k.device:
|
| 1238 |
+
raise ValueError("decode plan requires page_table and seqused_k on the same device")
|
| 1239 |
+
if page_table.stride(-1) != 1:
|
| 1240 |
+
raise ValueError("decode plan requires page_table contiguous in the last dimension")
|
| 1241 |
+
if seqused_k.shape != (int(page_table.shape[0]),):
|
| 1242 |
+
raise ValueError("decode plan requires seqused_k with shape [B]")
|
| 1243 |
+
if q2k_indices is not None and q2k_indices.dtype != torch.int32:
|
| 1244 |
+
raise TypeError("decode plan requires q2k_indices to be torch.int32")
|
| 1245 |
+
if int(seqlen_q) <= 0 or int(max_seqlen_k) <= 0:
|
| 1246 |
+
raise ValueError("decode plan requires positive seqlen_q and max_seqlen_k")
|
| 1247 |
+
if num_qo_heads is None or num_kv_heads is None or head_dim is None:
|
| 1248 |
+
raise ValueError("decode plan requires num_qo_heads, num_kv_heads, and head_dim")
|
| 1249 |
+
if head_dim is not None and int(head_dim) != 128:
|
| 1250 |
+
raise NotImplementedError("decode plan currently supports only head_dim=128")
|
| 1251 |
+
if int(num_qo_heads) % int(num_kv_heads) != 0:
|
| 1252 |
+
raise ValueError("decode plan requires num_qo_heads divisible by num_kv_heads")
|
| 1253 |
+
|
| 1254 |
+
self.batch = int(page_table.shape[0])
|
| 1255 |
+
self.num_qo_heads = None if num_qo_heads is None else int(num_qo_heads)
|
| 1256 |
+
self.num_kv_heads = None if num_kv_heads is None else int(num_kv_heads)
|
| 1257 |
+
self.head_dim = None if head_dim is None else int(head_dim)
|
| 1258 |
+
self.page_table = page_table.contiguous()
|
| 1259 |
+
self.seqused_k = seqused_k.contiguous()
|
| 1260 |
+
self.q2k_indices = None if q2k_indices is None else q2k_indices.contiguous()
|
| 1261 |
+
self.seqlen_q = int(seqlen_q)
|
| 1262 |
+
self.max_seqlen_k = int(max_seqlen_k)
|
| 1263 |
+
self.is_sparse = q2k_indices is not None
|
| 1264 |
+
|
| 1265 |
+
# max_grid_size is hardcoded to num_sms (1 CTA/SM) inside the C++
|
| 1266 |
+
# schedule launcher because the decode attn kernel always runs at
|
| 1267 |
+
# 1 CTA/SM (its register/smem budget saturates the SM). Callers
|
| 1268 |
+
# can still override via the explicit max_grid_size kwarg.
|
| 1269 |
+
schedule = prepare_decode_schedule(
|
| 1270 |
+
seqused_k=self.seqused_k,
|
| 1271 |
+
page_size=self.blk_kv,
|
| 1272 |
+
seqlen_q=self.seqlen_q,
|
| 1273 |
+
num_qo_heads=self.num_qo_heads,
|
| 1274 |
+
num_kv_heads=self.num_kv_heads,
|
| 1275 |
+
head_dim=self.head_dim,
|
| 1276 |
+
max_seqlen_k=self.max_seqlen_k,
|
| 1277 |
+
enable_cuda_graph=bool(enable_cuda_graph),
|
| 1278 |
+
max_grid_size=max_grid_size,
|
| 1279 |
+
fixed_split_size=fixed_split_size,
|
| 1280 |
+
disable_split_kv=bool(disable_split_kv),
|
| 1281 |
+
)
|
| 1282 |
+
self.decode_schedule = schedule
|
| 1283 |
+
self.request_indices = schedule.request_indices
|
| 1284 |
+
self.qo_tile_indices = schedule.qo_tile_indices
|
| 1285 |
+
self.kv_tile_indices = schedule.kv_tile_indices
|
| 1286 |
+
self.merge_indptr = schedule.merge_indptr
|
| 1287 |
+
self.o_indptr = schedule.o_indptr
|
| 1288 |
+
self.block_valid_mask = schedule.block_valid_mask
|
| 1289 |
+
self.kv_pages = schedule.kv_pages
|
| 1290 |
+
self.split_counts = schedule.split_counts
|
| 1291 |
+
self.split_kv = schedule.split_kv
|
| 1292 |
+
self.cta_tile_q = schedule.cta_tile_q
|
| 1293 |
+
self.num_q_tiles = schedule.num_q_tiles
|
| 1294 |
+
self.kv_chunk_size_pages = schedule.kv_chunk_size_pages
|
| 1295 |
+
self.kv_chunk_size_tokens = schedule.kv_chunk_size_tokens
|
| 1296 |
+
self.work_count = schedule.work_count
|
| 1297 |
+
self.padded_work_count = schedule.padded_work_count
|
| 1298 |
+
if schedule.split_kv:
|
| 1299 |
+
self.O_partial = torch.empty(
|
| 1300 |
+
(schedule.partial_rows, self.num_qo_heads, self.head_dim),
|
| 1301 |
+
dtype=torch.float32,
|
| 1302 |
+
device=page_table.device,
|
| 1303 |
+
)
|
| 1304 |
+
self.LSE_partial = torch.empty(
|
| 1305 |
+
(schedule.partial_rows, self.num_qo_heads),
|
| 1306 |
+
dtype=torch.float32,
|
| 1307 |
+
device=page_table.device,
|
| 1308 |
+
)
|
| 1309 |
+
self._O_partial_dummy = None
|
| 1310 |
+
self._LSE_partial_dummy = None
|
| 1311 |
+
else:
|
| 1312 |
+
self.O_partial = None
|
| 1313 |
+
self.LSE_partial = None
|
| 1314 |
+
# decode_forward_paged_fp8 always wants non-None partial buffers
|
| 1315 |
+
# for the kernel's positional arg layout (compile keeps the slot
|
| 1316 |
+
# alive even when split_kv=False). Allocate once here and reuse.
|
| 1317 |
+
self._O_partial_dummy = torch.empty(
|
| 1318 |
+
(1, self.head_dim),
|
| 1319 |
+
dtype=torch.float32,
|
| 1320 |
+
device=page_table.device,
|
| 1321 |
+
)
|
| 1322 |
+
self._LSE_partial_dummy = torch.empty(
|
| 1323 |
+
(1, self.num_qo_heads),
|
| 1324 |
+
dtype=torch.float32,
|
| 1325 |
+
device=page_table.device,
|
| 1326 |
+
)
|
| 1327 |
+
# LSE dummy is shape (1, head_q) — used when caller doesn't request
|
| 1328 |
+
# LSE and the schedule isn't split-KV (split-KV always writes LSE).
|
| 1329 |
+
self._lse_dummy = torch.empty(
|
| 1330 |
+
(1, self.num_qo_heads),
|
| 1331 |
+
dtype=torch.float32,
|
| 1332 |
+
device=page_table.device,
|
| 1333 |
+
)
|
| 1334 |
+
return self
|
| 1335 |
+
|
| 1336 |
+
def run(
|
| 1337 |
+
self,
|
| 1338 |
+
q: torch.Tensor,
|
| 1339 |
+
k: torch.Tensor,
|
| 1340 |
+
v: torch.Tensor,
|
| 1341 |
+
*,
|
| 1342 |
+
softmax_scale: Optional[float] = None,
|
| 1343 |
+
return_softmax_lse: bool = False,
|
| 1344 |
+
out: Optional[torch.Tensor] = None,
|
| 1345 |
+
lse: Optional[torch.Tensor] = None,
|
| 1346 |
+
):
|
| 1347 |
+
"""Launch decode using metadata cached by ``plan``.
|
| 1348 |
+
|
| 1349 |
+
Parameters
|
| 1350 |
+
----------
|
| 1351 |
+
q : torch.Tensor
|
| 1352 |
+
Shape ``[batch_size * seqlen_q, Hq, 128]`` and dtype FP8 E4M3.
|
| 1353 |
+
k : torch.Tensor
|
| 1354 |
+
Paged K cache with shape ``[num_pages, Hkv, blk_kv, 128]``.
|
| 1355 |
+
v : torch.Tensor
|
| 1356 |
+
Paged V cache with the same shape as ``k``.
|
| 1357 |
+
softmax_scale : float, optional
|
| 1358 |
+
Softmax scale. Defaults to ``1 / sqrt(128)``.
|
| 1359 |
+
return_softmax_lse : bool, optional
|
| 1360 |
+
If True, return ``(out, lse)``.
|
| 1361 |
+
out : torch.Tensor, optional
|
| 1362 |
+
Preallocated BF16 output buffer with shape ``q.shape``.
|
| 1363 |
+
lse : torch.Tensor, optional
|
| 1364 |
+
Preallocated float32 LSE buffer with shape ``[total_q, Hq]``.
|
| 1365 |
+
|
| 1366 |
+
Returns
|
| 1367 |
+
-------
|
| 1368 |
+
torch.Tensor or tuple[torch.Tensor, torch.Tensor]
|
| 1369 |
+
BF16 output, optionally with float32 LSE.
|
| 1370 |
+
"""
|
| 1371 |
+
if self.decode_schedule is None:
|
| 1372 |
+
raise RuntimeError("decode wrapper must be planned before run")
|
| 1373 |
+
if self.is_sparse:
|
| 1374 |
+
# Sparse path still goes through the validating wrapper for now;
|
| 1375 |
+
# only the dense fast path is collapsed.
|
| 1376 |
+
return sparse_decode_atten_func(
|
| 1377 |
+
q, k, v, self.q2k_indices,
|
| 1378 |
+
page_table=self.page_table, seqused_k=self.seqused_k,
|
| 1379 |
+
seqlen_q=self.seqlen_q, max_seqlen_k=self.max_seqlen_k,
|
| 1380 |
+
blk_kv=self.blk_kv, causal=self.causal,
|
| 1381 |
+
softmax_scale=softmax_scale, return_softmax_lse=return_softmax_lse,
|
| 1382 |
+
schedule=self.decode_schedule,
|
| 1383 |
+
O_partial=self.O_partial, LSE_partial=self.LSE_partial,
|
| 1384 |
+
)
|
| 1385 |
+
|
| 1386 |
+
if softmax_scale is None:
|
| 1387 |
+
softmax_scale = q.shape[-1] ** -0.5
|
| 1388 |
+
if out is None:
|
| 1389 |
+
out = torch.empty(q.shape, dtype=torch.bfloat16, device=q.device)
|
| 1390 |
+
if lse is None:
|
| 1391 |
+
if return_softmax_lse or self.split_kv:
|
| 1392 |
+
# Real LSE needed — must allocate per-call (shape depends on q).
|
| 1393 |
+
lse = torch.empty(
|
| 1394 |
+
q.shape[:2], dtype=torch.float32, device=q.device,
|
| 1395 |
+
)
|
| 1396 |
+
else:
|
| 1397 |
+
# Kernel only needs a valid pointer; reuse cached dummy.
|
| 1398 |
+
lse = self._lse_dummy
|
| 1399 |
+
from .src.sm100.fwd_decode import decode_forward_paged_fp8
|
| 1400 |
+
schedule = self.decode_schedule
|
| 1401 |
+
decode_forward_paged_fp8(
|
| 1402 |
+
q, k, v,
|
| 1403 |
+
self.page_table, self.seqused_k,
|
| 1404 |
+
out, lse,
|
| 1405 |
+
schedule.request_indices, schedule.qo_tile_indices,
|
| 1406 |
+
schedule.kv_tile_indices, schedule.block_valid_mask,
|
| 1407 |
+
schedule.split_counts, schedule.o_indptr, schedule.merge_indptr,
|
| 1408 |
+
self.O_partial, self.LSE_partial,
|
| 1409 |
+
softmax_scale=float(softmax_scale),
|
| 1410 |
+
seqlen_q=self.seqlen_q,
|
| 1411 |
+
page_size=self.blk_kv,
|
| 1412 |
+
kv_chunk_size_pages=int(schedule.kv_chunk_size_pages),
|
| 1413 |
+
max_split_count=int(schedule.max_split_count),
|
| 1414 |
+
split_kv=bool(schedule.split_kv),
|
| 1415 |
+
causal=self.causal,
|
| 1416 |
+
return_lse=bool(return_softmax_lse),
|
| 1417 |
+
# cached dummies — avoid per-call torch.empty inside run_decode_attention
|
| 1418 |
+
O_partial_dummy=self._O_partial_dummy,
|
| 1419 |
+
LSE_partial_dummy=self._LSE_partial_dummy,
|
| 1420 |
+
)
|
| 1421 |
+
if return_softmax_lse:
|
| 1422 |
+
return out, lse
|
| 1423 |
+
return out
|
| 1424 |
+
|
| 1425 |
+
|
| 1426 |
+
def _sparse_atten_csr_varlen_forward(
|
| 1427 |
+
q: torch.Tensor,
|
| 1428 |
+
k: torch.Tensor,
|
| 1429 |
+
v: torch.Tensor,
|
| 1430 |
+
k2q_row_ptr: torch.Tensor,
|
| 1431 |
+
k2q_q_indices: torch.Tensor,
|
| 1432 |
+
topK: int,
|
| 1433 |
+
blk_kv: int,
|
| 1434 |
+
causal: bool,
|
| 1435 |
+
softmax_scale: float,
|
| 1436 |
+
lse_temperature_scale: float,
|
| 1437 |
+
return_temperature_lse: bool,
|
| 1438 |
+
partial_dtype: torch.dtype,
|
| 1439 |
+
return_softmax_lse: bool,
|
| 1440 |
+
cu_seqlens_q: torch.Tensor,
|
| 1441 |
+
cu_seqlens_k: torch.Tensor,
|
| 1442 |
+
page_table: Optional[torch.Tensor],
|
| 1443 |
+
seqused_k: Optional[torch.Tensor],
|
| 1444 |
+
schedule: Optional[SparseAttentionSchedule],
|
| 1445 |
+
usable_SM_count: int,
|
| 1446 |
+
batch: int,
|
| 1447 |
+
head_kv: int,
|
| 1448 |
+
max_seqlen_q: int,
|
| 1449 |
+
max_seqlen_k: int,
|
| 1450 |
+
qk_dtype: torch.dtype,
|
| 1451 |
+
pv_dtype: torch.dtype,
|
| 1452 |
+
):
|
| 1453 |
+
total_q, head_q, dim = q.shape
|
| 1454 |
+
if head_q % head_kv != 0:
|
| 1455 |
+
raise ValueError("q.shape[1] must be divisible by head_kv")
|
| 1456 |
+
max_num_kv_blocks = _csr_row_capacity(k2q_row_ptr)
|
| 1457 |
+
temperature_lse_fast_path = (
|
| 1458 |
+
return_temperature_lse
|
| 1459 |
+
and math.isclose(
|
| 1460 |
+
float(lse_temperature_scale),
|
| 1461 |
+
1.0,
|
| 1462 |
+
rel_tol=0.0,
|
| 1463 |
+
abs_tol=_TEMPERATURE_LSE_FAST_PATH_ABS_TOL,
|
| 1464 |
+
)
|
| 1465 |
+
)
|
| 1466 |
+
kernel_return_temperature_lse = (
|
| 1467 |
+
return_temperature_lse and not temperature_lse_fast_path
|
| 1468 |
+
)
|
| 1469 |
+
|
| 1470 |
+
O_partial = torch.empty(
|
| 1471 |
+
topK, total_q, head_q, dim, dtype=partial_dtype, device=q.device
|
| 1472 |
+
)
|
| 1473 |
+
LSE_partial = torch.empty(
|
| 1474 |
+
topK, total_q, head_q, dtype=torch.float32, device=q.device
|
| 1475 |
+
)
|
| 1476 |
+
LSE_temperature_partial = (
|
| 1477 |
+
torch.empty(topK, total_q, head_q, dtype=torch.float32, device=q.device)
|
| 1478 |
+
if kernel_return_temperature_lse
|
| 1479 |
+
else None
|
| 1480 |
+
)
|
| 1481 |
+
O_out = torch.empty(total_q, head_q, dim, dtype=torch.bfloat16, device=q.device)
|
| 1482 |
+
LSE_out = torch.empty(total_q, head_q, dtype=torch.float32, device=q.device)
|
| 1483 |
+
LSE_temperature_out = (
|
| 1484 |
+
torch.empty_like(LSE_out) if kernel_return_temperature_lse else None
|
| 1485 |
+
)
|
| 1486 |
+
if schedule is None:
|
| 1487 |
+
k2q_qsplit_indices = torch.empty_like(k2q_q_indices)
|
| 1488 |
+
split_counts = torch.zeros(
|
| 1489 |
+
(total_q, head_kv),
|
| 1490 |
+
dtype=torch.int32,
|
| 1491 |
+
device=q.device,
|
| 1492 |
+
)
|
| 1493 |
+
else:
|
| 1494 |
+
_validate_fwd_schedule(
|
| 1495 |
+
schedule,
|
| 1496 |
+
q=q,
|
| 1497 |
+
k2q_q_indices=k2q_q_indices,
|
| 1498 |
+
head_kv=head_kv,
|
| 1499 |
+
)
|
| 1500 |
+
k2q_qsplit_indices = schedule.qsplit_indices
|
| 1501 |
+
split_counts = schedule.split_counts
|
| 1502 |
+
schedule = _call_sparse_forward_sm100_csr_varlen(
|
| 1503 |
+
q,
|
| 1504 |
+
k,
|
| 1505 |
+
v,
|
| 1506 |
+
k2q_row_ptr,
|
| 1507 |
+
k2q_q_indices,
|
| 1508 |
+
k2q_qsplit_indices,
|
| 1509 |
+
split_counts,
|
| 1510 |
+
cu_seqlens_q,
|
| 1511 |
+
cu_seqlens_k,
|
| 1512 |
+
page_table,
|
| 1513 |
+
seqused_k,
|
| 1514 |
+
O_partial,
|
| 1515 |
+
LSE_partial,
|
| 1516 |
+
LSE_temperature_partial,
|
| 1517 |
+
softmax_scale,
|
| 1518 |
+
lse_temperature_scale,
|
| 1519 |
+
kernel_return_temperature_lse,
|
| 1520 |
+
max_num_kv_blocks,
|
| 1521 |
+
blk_kv,
|
| 1522 |
+
head_kv,
|
| 1523 |
+
max_seqlen_q,
|
| 1524 |
+
usable_SM_count,
|
| 1525 |
+
causal=causal,
|
| 1526 |
+
schedule=schedule,
|
| 1527 |
+
qk_dtype=qk_dtype,
|
| 1528 |
+
pv_dtype=pv_dtype,
|
| 1529 |
+
)
|
| 1530 |
+
# Sparse Attention and Sparse Page Attention both use the varlen-Q
|
| 1531 |
+
# combine path; the kernel-written LSE_out is the final contract.
|
| 1532 |
+
combine(
|
| 1533 |
+
O_partial,
|
| 1534 |
+
LSE_partial,
|
| 1535 |
+
O_out,
|
| 1536 |
+
LSE_out,
|
| 1537 |
+
lse_temperature_partial=LSE_temperature_partial,
|
| 1538 |
+
lse_temperature_out=LSE_temperature_out,
|
| 1539 |
+
cu_seqlens=cu_seqlens_q,
|
| 1540 |
+
split_counts=split_counts,
|
| 1541 |
+
use_pdl=True,
|
| 1542 |
+
)
|
| 1543 |
+
if temperature_lse_fast_path:
|
| 1544 |
+
LSE_temperature_out = LSE_out
|
| 1545 |
+
|
| 1546 |
+
if return_softmax_lse:
|
| 1547 |
+
if return_temperature_lse:
|
| 1548 |
+
return O_out, LSE_out, LSE_temperature_out
|
| 1549 |
+
return O_out, LSE_out
|
| 1550 |
+
return O_out
|
| 1551 |
+
|
| 1552 |
+
|
| 1553 |
+
def _call_sparse_decode_forward_sm100_paged_fp8(
|
| 1554 |
+
q: torch.Tensor,
|
| 1555 |
+
k: torch.Tensor,
|
| 1556 |
+
v: torch.Tensor,
|
| 1557 |
+
q2k_indices: Optional[torch.Tensor],
|
| 1558 |
+
page_table: torch.Tensor,
|
| 1559 |
+
seqused_k: torch.Tensor,
|
| 1560 |
+
out: torch.Tensor,
|
| 1561 |
+
lse: torch.Tensor,
|
| 1562 |
+
schedule: DecodeAttentionSchedule,
|
| 1563 |
+
O_partial: Optional[torch.Tensor],
|
| 1564 |
+
LSE_partial: Optional[torch.Tensor],
|
| 1565 |
+
*,
|
| 1566 |
+
softmax_scale: float,
|
| 1567 |
+
seqlen_q: int,
|
| 1568 |
+
max_seqlen_k: int,
|
| 1569 |
+
blk_kv: int,
|
| 1570 |
+
causal: bool,
|
| 1571 |
+
return_lse: bool = True,
|
| 1572 |
+
) -> None:
|
| 1573 |
+
"""Compile and launch the SM100 paged fp8 decode forward kernel.
|
| 1574 |
+
|
| 1575 |
+
Dense decode is selected by ``q2k_indices=None``. Sparse decode will reuse
|
| 1576 |
+
the same schedule wrapper but needs a separate q2k gather path.
|
| 1577 |
+
"""
|
| 1578 |
+
if q2k_indices is not None:
|
| 1579 |
+
raise NotImplementedError("SM100 paged fp8 sparse decode forward is not implemented yet")
|
| 1580 |
+
if schedule.cta_tile_q != 128:
|
| 1581 |
+
raise NotImplementedError(f"decode forward requires cta_tile_q=128, got {schedule.cta_tile_q}")
|
| 1582 |
+
if schedule.split_kv:
|
| 1583 |
+
if O_partial is None or LSE_partial is None:
|
| 1584 |
+
raise ValueError("split decode forward requires O_partial and LSE_partial")
|
| 1585 |
+
if O_partial.dtype != torch.float32:
|
| 1586 |
+
raise TypeError(f"O_partial must be torch.float32, got {O_partial.dtype}")
|
| 1587 |
+
if LSE_partial.dtype != torch.float32:
|
| 1588 |
+
raise TypeError(f"LSE_partial must be torch.float32, got {LSE_partial.dtype}")
|
| 1589 |
+
|
| 1590 |
+
from .src.sm100.fwd_decode import decode_forward_paged_fp8
|
| 1591 |
+
|
| 1592 |
+
decode_forward_paged_fp8(
|
| 1593 |
+
q,
|
| 1594 |
+
k,
|
| 1595 |
+
v,
|
| 1596 |
+
page_table,
|
| 1597 |
+
seqused_k,
|
| 1598 |
+
out,
|
| 1599 |
+
lse,
|
| 1600 |
+
schedule.request_indices,
|
| 1601 |
+
schedule.qo_tile_indices,
|
| 1602 |
+
schedule.kv_tile_indices,
|
| 1603 |
+
schedule.block_valid_mask,
|
| 1604 |
+
schedule.split_counts,
|
| 1605 |
+
schedule.o_indptr,
|
| 1606 |
+
schedule.merge_indptr,
|
| 1607 |
+
O_partial,
|
| 1608 |
+
LSE_partial,
|
| 1609 |
+
softmax_scale=float(softmax_scale),
|
| 1610 |
+
seqlen_q=int(seqlen_q),
|
| 1611 |
+
page_size=int(blk_kv),
|
| 1612 |
+
kv_chunk_size_pages=int(schedule.kv_chunk_size_pages),
|
| 1613 |
+
max_split_count=int(schedule.max_split_count),
|
| 1614 |
+
split_kv=bool(schedule.split_kv),
|
| 1615 |
+
causal=bool(causal),
|
| 1616 |
+
return_lse=bool(return_lse),
|
| 1617 |
+
)
|
| 1618 |
+
|
| 1619 |
+
|
| 1620 |
+
def _call_sparse_forward_sm100_csr_varlen(
|
| 1621 |
+
q,
|
| 1622 |
+
k,
|
| 1623 |
+
v,
|
| 1624 |
+
k2q_row_ptr,
|
| 1625 |
+
k2q_q_indices,
|
| 1626 |
+
k2q_qsplit_indices,
|
| 1627 |
+
split_counts,
|
| 1628 |
+
cu_seqlens_q,
|
| 1629 |
+
cu_seqlens_k,
|
| 1630 |
+
page_table,
|
| 1631 |
+
seqused_k,
|
| 1632 |
+
O_partial,
|
| 1633 |
+
LSE_partial,
|
| 1634 |
+
LSE_temperature_partial,
|
| 1635 |
+
softmax_scale,
|
| 1636 |
+
lse_temperature_scale,
|
| 1637 |
+
return_temperature_lse,
|
| 1638 |
+
max_num_kv_blocks,
|
| 1639 |
+
blk_kv,
|
| 1640 |
+
head_kv,
|
| 1641 |
+
max_seqlen_q,
|
| 1642 |
+
usable_SM_count=-1,
|
| 1643 |
+
*,
|
| 1644 |
+
causal=False,
|
| 1645 |
+
use_prepare_scheduler=True,
|
| 1646 |
+
schedule: Optional[SparseAttentionSchedule] = None,
|
| 1647 |
+
qk_dtype: torch.dtype,
|
| 1648 |
+
pv_dtype: torch.dtype,
|
| 1649 |
+
):
|
| 1650 |
+
"""Compile and launch the SM100 sparse forward K1 kernel on CSR metadata."""
|
| 1651 |
+
head_dim = q.shape[-1]
|
| 1652 |
+
dtype = q.dtype
|
| 1653 |
+
qk_dtype = _normalize_forward_mma_dtype(qk_dtype, q.dtype, "qk_dtype")
|
| 1654 |
+
pv_dtype = _normalize_forward_mma_dtype(pv_dtype, v.dtype, "pv_dtype")
|
| 1655 |
+
partial_dtype = O_partial.dtype
|
| 1656 |
+
return_temperature_lse = bool(return_temperature_lse)
|
| 1657 |
+
if return_temperature_lse != (LSE_temperature_partial is not None):
|
| 1658 |
+
raise ValueError(
|
| 1659 |
+
"return_temperature_lse must match LSE_temperature_partial presence"
|
| 1660 |
+
)
|
| 1661 |
+
lse_temperature_scale = float(lse_temperature_scale)
|
| 1662 |
+
if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0:
|
| 1663 |
+
raise ValueError(
|
| 1664 |
+
f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}"
|
| 1665 |
+
)
|
| 1666 |
+
lse_temperature_inv_scale = 1.0 / lse_temperature_scale
|
| 1667 |
+
n_block_size = int(blk_kv)
|
| 1668 |
+
head_q = q.shape[1]
|
| 1669 |
+
qhead_per_kv = head_q // head_kv
|
| 1670 |
+
paged_kv = page_table is not None
|
| 1671 |
+
if not bool(use_prepare_scheduler):
|
| 1672 |
+
raise RuntimeError("sparse forward requires prepare scheduler")
|
| 1673 |
+
schedule_enabled = k2q_row_ptr.shape[1] > 1
|
| 1674 |
+
page_size = int(k.shape[2]) if paged_kv else None
|
| 1675 |
+
if paged_kv:
|
| 1676 |
+
k_kernel, v_kernel = _prepare_paged_kv_for_tma(k, v, n_block_size)
|
| 1677 |
+
else:
|
| 1678 |
+
k_kernel = k
|
| 1679 |
+
v_kernel = v
|
| 1680 |
+
O_partial_flat = O_partial.reshape(-1, head_dim).contiguous()
|
| 1681 |
+
Q_flat = q.reshape(-1, head_dim).contiguous()
|
| 1682 |
+
Q_gather4_desc = (
|
| 1683 |
+
create_q_gather4_tma_desc(
|
| 1684 |
+
Q_flat,
|
| 1685 |
+
box_x=128 if q.dtype == torch.float8_e4m3fn else 64,
|
| 1686 |
+
)
|
| 1687 |
+
if qhead_per_kv in (1, 2, 4)
|
| 1688 |
+
else None
|
| 1689 |
+
)
|
| 1690 |
+
if schedule is None:
|
| 1691 |
+
schedule = prepare_sparse_fwd_schedule_and_split(
|
| 1692 |
+
k2q_row_ptr=k2q_row_ptr,
|
| 1693 |
+
k2q_q_indices=k2q_q_indices,
|
| 1694 |
+
k2q_qsplit_indices=k2q_qsplit_indices,
|
| 1695 |
+
split_counts=split_counts,
|
| 1696 |
+
cu_seqlens_q=cu_seqlens_q,
|
| 1697 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 1698 |
+
total_q=int(q.shape[0]),
|
| 1699 |
+
max_seqlen_q=max_seqlen_q,
|
| 1700 |
+
topk=int(O_partial.shape[0]),
|
| 1701 |
+
head_kv=head_kv,
|
| 1702 |
+
qhead_per_kv=qhead_per_kv,
|
| 1703 |
+
blk_kv=n_block_size,
|
| 1704 |
+
device=q.device,
|
| 1705 |
+
enabled=schedule_enabled,
|
| 1706 |
+
)
|
| 1707 |
+
use_prepare_scheduler = schedule.enabled
|
| 1708 |
+
scheduler_metadata = schedule.scheduler_metadata
|
| 1709 |
+
work_count = schedule.work_count
|
| 1710 |
+
work_capacity = schedule.work_capacity
|
| 1711 |
+
if not use_prepare_scheduler or scheduler_metadata is None or work_count is None or work_capacity <= 0:
|
| 1712 |
+
raise RuntimeError("sparse forward requires a non-empty prepared schedule")
|
| 1713 |
+
|
| 1714 |
+
key = (
|
| 1715 |
+
"sparse_forward_sm100_csr_varlen",
|
| 1716 |
+
head_dim,
|
| 1717 |
+
n_block_size,
|
| 1718 |
+
qhead_per_kv,
|
| 1719 |
+
dtype,
|
| 1720 |
+
k.dtype,
|
| 1721 |
+
v.dtype,
|
| 1722 |
+
qk_dtype,
|
| 1723 |
+
pv_dtype,
|
| 1724 |
+
partial_dtype,
|
| 1725 |
+
bool(causal),
|
| 1726 |
+
bool(paged_kv),
|
| 1727 |
+
bool(use_prepare_scheduler),
|
| 1728 |
+
page_size,
|
| 1729 |
+
bool(seqused_k is not None),
|
| 1730 |
+
bool(return_temperature_lse),
|
| 1731 |
+
)
|
| 1732 |
+
if key not in _compile_cache:
|
| 1733 |
+
from .src.common.aot_cache import try_load_aot, save_aot
|
| 1734 |
+
|
| 1735 |
+
loaded = try_load_aot(key)
|
| 1736 |
+
if loaded is not None:
|
| 1737 |
+
_compile_cache[key] = loaded
|
| 1738 |
+
else:
|
| 1739 |
+
kernel = SparseAttentionForwardSm100(
|
| 1740 |
+
head_dim=head_dim,
|
| 1741 |
+
qheadperkv=qhead_per_kv,
|
| 1742 |
+
n_block_size=n_block_size,
|
| 1743 |
+
paged_kv=paged_kv,
|
| 1744 |
+
page_size=page_size,
|
| 1745 |
+
has_seqused_k=seqused_k is not None,
|
| 1746 |
+
causal=bool(causal),
|
| 1747 |
+
use_prepare_scheduler=use_prepare_scheduler,
|
| 1748 |
+
qk_dtype=_torch_dtype_to_cutlass_dtype(qk_dtype),
|
| 1749 |
+
pv_dtype=_torch_dtype_to_cutlass_dtype(pv_dtype),
|
| 1750 |
+
)
|
| 1751 |
+
_compile_cache[key] = cute.compile(
|
| 1752 |
+
kernel,
|
| 1753 |
+
to_cute_tensor_kvouter(k_kernel),
|
| 1754 |
+
to_cute_tensor_kvouter(v_kernel),
|
| 1755 |
+
to_cute_tensor_kvouter(k2q_q_indices),
|
| 1756 |
+
to_cute_tensor_kvouter(k2q_qsplit_indices),
|
| 1757 |
+
to_cute_tensor_kvouter(k2q_row_ptr),
|
| 1758 |
+
None if scheduler_metadata is None else to_cute_tensor_kvouter(scheduler_metadata),
|
| 1759 |
+
None if work_count is None else to_cute_tensor_kvouter(work_count),
|
| 1760 |
+
to_cute_tensor_kvouter(O_partial_flat),
|
| 1761 |
+
to_cute_tensor_kvouter(LSE_partial),
|
| 1762 |
+
None
|
| 1763 |
+
if LSE_temperature_partial is None
|
| 1764 |
+
else to_cute_tensor_kvouter(LSE_temperature_partial),
|
| 1765 |
+
to_cute_tensor_kvouter(Q_flat),
|
| 1766 |
+
None if Q_gather4_desc is None else to_cute_tensor_kvouter(Q_gather4_desc),
|
| 1767 |
+
None if page_table is None else to_cute_tensor_kvouter(page_table),
|
| 1768 |
+
None if seqused_k is None else to_cute_tensor_kvouter(seqused_k),
|
| 1769 |
+
to_cute_tensor_kvouter(cu_seqlens_q),
|
| 1770 |
+
to_cute_tensor_kvouter(cu_seqlens_k),
|
| 1771 |
+
Float32(softmax_scale),
|
| 1772 |
+
Float32(lse_temperature_inv_scale),
|
| 1773 |
+
Int32(max_num_kv_blocks),
|
| 1774 |
+
Int32(head_kv),
|
| 1775 |
+
Int32(max_seqlen_q),
|
| 1776 |
+
Int32(work_capacity),
|
| 1777 |
+
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
|
| 1778 |
+
options="--enable-tvm-ffi",
|
| 1779 |
+
)
|
| 1780 |
+
save_aot(key, _compile_cache[key])
|
| 1781 |
+
|
| 1782 |
+
with torch.cuda.nvtx.range("Fwd_SparseAttn_Sm100_CsrVarlen"):
|
| 1783 |
+
_compile_cache[key](
|
| 1784 |
+
k_kernel,
|
| 1785 |
+
v_kernel,
|
| 1786 |
+
k2q_q_indices,
|
| 1787 |
+
k2q_qsplit_indices,
|
| 1788 |
+
k2q_row_ptr,
|
| 1789 |
+
scheduler_metadata,
|
| 1790 |
+
work_count,
|
| 1791 |
+
O_partial_flat,
|
| 1792 |
+
LSE_partial,
|
| 1793 |
+
LSE_temperature_partial,
|
| 1794 |
+
Q_flat,
|
| 1795 |
+
Q_gather4_desc,
|
| 1796 |
+
page_table,
|
| 1797 |
+
seqused_k,
|
| 1798 |
+
cu_seqlens_q,
|
| 1799 |
+
cu_seqlens_k,
|
| 1800 |
+
softmax_scale,
|
| 1801 |
+
lse_temperature_inv_scale,
|
| 1802 |
+
max_num_kv_blocks,
|
| 1803 |
+
head_kv,
|
| 1804 |
+
max_seqlen_q,
|
| 1805 |
+
work_capacity,
|
| 1806 |
+
)
|
| 1807 |
+
return schedule
|
| 1808 |
+
|
| 1809 |
+
|
| 1810 |
+
def _call_sparse_forward_sm100_csr_varlen_nvfp4_kv(
|
| 1811 |
+
q,
|
| 1812 |
+
k,
|
| 1813 |
+
v,
|
| 1814 |
+
k_scale_128x4,
|
| 1815 |
+
v_scale_128x4,
|
| 1816 |
+
k_global_scale,
|
| 1817 |
+
v_global_scale,
|
| 1818 |
+
k2q_row_ptr,
|
| 1819 |
+
k2q_q_indices,
|
| 1820 |
+
k2q_qsplit_indices,
|
| 1821 |
+
split_counts,
|
| 1822 |
+
cu_seqlens_q,
|
| 1823 |
+
cu_seqlens_k,
|
| 1824 |
+
page_table,
|
| 1825 |
+
seqused_k,
|
| 1826 |
+
O_partial,
|
| 1827 |
+
LSE_partial,
|
| 1828 |
+
LSE_temperature_partial,
|
| 1829 |
+
softmax_scale,
|
| 1830 |
+
lse_temperature_scale,
|
| 1831 |
+
return_temperature_lse,
|
| 1832 |
+
max_num_kv_blocks,
|
| 1833 |
+
blk_kv,
|
| 1834 |
+
head_kv,
|
| 1835 |
+
max_seqlen_q,
|
| 1836 |
+
*,
|
| 1837 |
+
causal=False,
|
| 1838 |
+
use_prepare_scheduler=True,
|
| 1839 |
+
schedule: Optional[SparseAttentionSchedule] = None,
|
| 1840 |
+
):
|
| 1841 |
+
"""Compile and launch the SM100 sparse forward K1 kernel with NVFP4 K/V."""
|
| 1842 |
+
|
| 1843 |
+
head_dim = q.shape[-1]
|
| 1844 |
+
dtype = q.dtype
|
| 1845 |
+
partial_dtype = O_partial.dtype
|
| 1846 |
+
return_temperature_lse = bool(return_temperature_lse)
|
| 1847 |
+
if return_temperature_lse != (LSE_temperature_partial is not None):
|
| 1848 |
+
raise ValueError(
|
| 1849 |
+
"return_temperature_lse must match LSE_temperature_partial presence"
|
| 1850 |
+
)
|
| 1851 |
+
lse_temperature_scale = float(lse_temperature_scale)
|
| 1852 |
+
if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0:
|
| 1853 |
+
raise ValueError(
|
| 1854 |
+
f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}"
|
| 1855 |
+
)
|
| 1856 |
+
lse_temperature_inv_scale = 1.0 / lse_temperature_scale
|
| 1857 |
+
n_block_size = int(blk_kv)
|
| 1858 |
+
head_q = q.shape[1]
|
| 1859 |
+
qhead_per_kv = head_q // head_kv
|
| 1860 |
+
fp8_pair_dequant = os.environ.get("MINIMAX_KVFP4_FP8_PAIR_DEQUANT", "1") != "0"
|
| 1861 |
+
k_global_scale_kernel = k_global_scale
|
| 1862 |
+
# V global scale is linear in the final output. Keep K1 on block-scale-only V
|
| 1863 |
+
# and apply the tensor scale once in K2 combine.
|
| 1864 |
+
v_global_scale_kernel = None
|
| 1865 |
+
has_k_global_scale = k_global_scale_kernel is not None
|
| 1866 |
+
has_v_global_scale = v_global_scale_kernel is not None
|
| 1867 |
+
paged_kv = page_table is not None
|
| 1868 |
+
if not bool(use_prepare_scheduler):
|
| 1869 |
+
raise RuntimeError("KVFP4 sparse forward requires prepare scheduler")
|
| 1870 |
+
schedule_enabled = k2q_row_ptr.shape[1] > 1
|
| 1871 |
+
page_size = int(k.shape[2]) if paged_kv else None
|
| 1872 |
+
if paged_kv:
|
| 1873 |
+
_prepare_paged_kv_for_tma(k, v, n_block_size)
|
| 1874 |
+
k_kernel = k
|
| 1875 |
+
v_kernel = v
|
| 1876 |
+
O_partial_flat = O_partial.reshape(-1, head_dim).contiguous()
|
| 1877 |
+
Q_flat = q.reshape(-1, head_dim).contiguous()
|
| 1878 |
+
Q_gather4_desc = (
|
| 1879 |
+
create_q_gather4_tma_desc(
|
| 1880 |
+
Q_flat,
|
| 1881 |
+
box_x=128 if q.dtype == torch.float8_e4m3fn else 64,
|
| 1882 |
+
)
|
| 1883 |
+
if qhead_per_kv in (1, 2, 4)
|
| 1884 |
+
else None
|
| 1885 |
+
)
|
| 1886 |
+
if schedule is None:
|
| 1887 |
+
schedule = prepare_sparse_fwd_schedule_and_split(
|
| 1888 |
+
k2q_row_ptr=k2q_row_ptr,
|
| 1889 |
+
k2q_q_indices=k2q_q_indices,
|
| 1890 |
+
k2q_qsplit_indices=k2q_qsplit_indices,
|
| 1891 |
+
split_counts=split_counts,
|
| 1892 |
+
cu_seqlens_q=cu_seqlens_q,
|
| 1893 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 1894 |
+
total_q=int(q.shape[0]),
|
| 1895 |
+
max_seqlen_q=max_seqlen_q,
|
| 1896 |
+
topk=int(O_partial.shape[0]),
|
| 1897 |
+
head_kv=head_kv,
|
| 1898 |
+
qhead_per_kv=qhead_per_kv,
|
| 1899 |
+
blk_kv=n_block_size,
|
| 1900 |
+
device=q.device,
|
| 1901 |
+
enabled=schedule_enabled,
|
| 1902 |
+
)
|
| 1903 |
+
use_prepare_scheduler = schedule.enabled
|
| 1904 |
+
scheduler_metadata = schedule.scheduler_metadata
|
| 1905 |
+
work_count = schedule.work_count
|
| 1906 |
+
work_capacity = schedule.work_capacity
|
| 1907 |
+
if not use_prepare_scheduler or scheduler_metadata is None or work_count is None or work_capacity <= 0:
|
| 1908 |
+
raise RuntimeError("KVFP4 sparse forward requires a non-empty prepared schedule")
|
| 1909 |
+
|
| 1910 |
+
key = (
|
| 1911 |
+
"sparse_forward_sm100_csr_varlen_nvfp4_kv",
|
| 1912 |
+
head_dim,
|
| 1913 |
+
n_block_size,
|
| 1914 |
+
qhead_per_kv,
|
| 1915 |
+
dtype,
|
| 1916 |
+
partial_dtype,
|
| 1917 |
+
bool(causal),
|
| 1918 |
+
bool(paged_kv),
|
| 1919 |
+
bool(use_prepare_scheduler),
|
| 1920 |
+
page_size,
|
| 1921 |
+
bool(seqused_k is not None),
|
| 1922 |
+
bool(return_temperature_lse),
|
| 1923 |
+
bool(fp8_pair_dequant),
|
| 1924 |
+
bool(has_k_global_scale),
|
| 1925 |
+
bool(has_v_global_scale),
|
| 1926 |
+
)
|
| 1927 |
+
if key not in _compile_cache:
|
| 1928 |
+
from .src.common.aot_cache import try_load_aot, save_aot
|
| 1929 |
+
|
| 1930 |
+
loaded = try_load_aot(key)
|
| 1931 |
+
if loaded is not None:
|
| 1932 |
+
_compile_cache[key] = loaded
|
| 1933 |
+
else:
|
| 1934 |
+
kernel = SparseAttentionForwardNvfp4KvSm100(
|
| 1935 |
+
head_dim=head_dim,
|
| 1936 |
+
qheadperkv=qhead_per_kv,
|
| 1937 |
+
n_block_size=n_block_size,
|
| 1938 |
+
paged_kv=paged_kv,
|
| 1939 |
+
page_size=page_size,
|
| 1940 |
+
has_seqused_k=seqused_k is not None,
|
| 1941 |
+
causal=bool(causal),
|
| 1942 |
+
use_prepare_scheduler=use_prepare_scheduler,
|
| 1943 |
+
fp8_pair_dequant=bool(fp8_pair_dequant),
|
| 1944 |
+
has_k_global_scale=bool(has_k_global_scale),
|
| 1945 |
+
has_v_global_scale=bool(has_v_global_scale),
|
| 1946 |
+
)
|
| 1947 |
+
_compile_cache[key] = cute.compile(
|
| 1948 |
+
kernel,
|
| 1949 |
+
to_cute_tensor_kvouter(k_kernel),
|
| 1950 |
+
to_cute_tensor_kvouter(v_kernel),
|
| 1951 |
+
to_cute_tensor_kvouter(k_scale_128x4),
|
| 1952 |
+
to_cute_tensor_kvouter(v_scale_128x4),
|
| 1953 |
+
None if k_global_scale_kernel is None else to_cute_tensor_kvouter(k_global_scale_kernel),
|
| 1954 |
+
None if v_global_scale_kernel is None else to_cute_tensor_kvouter(v_global_scale_kernel),
|
| 1955 |
+
to_cute_tensor_kvouter(k2q_q_indices),
|
| 1956 |
+
to_cute_tensor_kvouter(k2q_qsplit_indices),
|
| 1957 |
+
to_cute_tensor_kvouter(k2q_row_ptr),
|
| 1958 |
+
None if scheduler_metadata is None else to_cute_tensor_kvouter(scheduler_metadata),
|
| 1959 |
+
None if work_count is None else to_cute_tensor_kvouter(work_count),
|
| 1960 |
+
to_cute_tensor_kvouter(O_partial_flat),
|
| 1961 |
+
to_cute_tensor_kvouter(LSE_partial),
|
| 1962 |
+
None
|
| 1963 |
+
if LSE_temperature_partial is None
|
| 1964 |
+
else to_cute_tensor_kvouter(LSE_temperature_partial),
|
| 1965 |
+
to_cute_tensor_kvouter(Q_flat),
|
| 1966 |
+
None if Q_gather4_desc is None else to_cute_tensor_kvouter(Q_gather4_desc),
|
| 1967 |
+
None if page_table is None else to_cute_tensor_kvouter(page_table),
|
| 1968 |
+
None if seqused_k is None else to_cute_tensor_kvouter(seqused_k),
|
| 1969 |
+
to_cute_tensor_kvouter(cu_seqlens_q),
|
| 1970 |
+
to_cute_tensor_kvouter(cu_seqlens_k),
|
| 1971 |
+
Float32(softmax_scale),
|
| 1972 |
+
Float32(lse_temperature_inv_scale),
|
| 1973 |
+
Int32(max_num_kv_blocks),
|
| 1974 |
+
Int32(head_kv),
|
| 1975 |
+
Int32(max_seqlen_q),
|
| 1976 |
+
Int32(work_capacity),
|
| 1977 |
+
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
|
| 1978 |
+
options="--enable-tvm-ffi",
|
| 1979 |
+
)
|
| 1980 |
+
save_aot(key, _compile_cache[key])
|
| 1981 |
+
|
| 1982 |
+
with torch.cuda.nvtx.range("Fwd_SparseAttn_Sm100_CsrVarlen_KVFP4"):
|
| 1983 |
+
_compile_cache[key](
|
| 1984 |
+
k_kernel,
|
| 1985 |
+
v_kernel,
|
| 1986 |
+
k_scale_128x4,
|
| 1987 |
+
v_scale_128x4,
|
| 1988 |
+
k_global_scale_kernel,
|
| 1989 |
+
v_global_scale_kernel,
|
| 1990 |
+
k2q_q_indices,
|
| 1991 |
+
k2q_qsplit_indices,
|
| 1992 |
+
k2q_row_ptr,
|
| 1993 |
+
scheduler_metadata,
|
| 1994 |
+
work_count,
|
| 1995 |
+
O_partial_flat,
|
| 1996 |
+
LSE_partial,
|
| 1997 |
+
LSE_temperature_partial,
|
| 1998 |
+
Q_flat,
|
| 1999 |
+
Q_gather4_desc,
|
| 2000 |
+
page_table,
|
| 2001 |
+
seqused_k,
|
| 2002 |
+
cu_seqlens_q,
|
| 2003 |
+
cu_seqlens_k,
|
| 2004 |
+
softmax_scale,
|
| 2005 |
+
lse_temperature_inv_scale,
|
| 2006 |
+
max_num_kv_blocks,
|
| 2007 |
+
head_kv,
|
| 2008 |
+
max_seqlen_q,
|
| 2009 |
+
work_capacity,
|
| 2010 |
+
)
|
| 2011 |
+
return schedule
|
build/torch211-cxx11-cu128-x86_64-linux/metadata.json
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "msa",
|
| 3 |
+
"id": "_msa_cuda_09d7851",
|
| 4 |
+
"version": 0,
|
| 5 |
+
"license": "other",
|
| 6 |
+
"upstream": "https://github.com/MiniMax-AI/MSA",
|
| 7 |
+
"python-depends": [
|
| 8 |
+
"tvm-ffi",
|
| 9 |
+
"nvidia-cutlass-dsl"
|
| 10 |
+
],
|
| 11 |
+
"backend": {
|
| 12 |
+
"type": "cuda",
|
| 13 |
+
"archs": [
|
| 14 |
+
"10.0"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
"digest": {
|
| 18 |
+
"algorithm": "sha256",
|
| 19 |
+
"files": {
|
| 20 |
+
"__init__.py": "+W+3U1Z5ZKc/dTA+JUG+6dMjfe9H/d9J+8fN+936wbI=",
|
| 21 |
+
"_msa_cuda_09d7851.abi3.so": "jc2MhuUS893VrLlfb9ytPPqhV5u2+HSnFPugZuaHcWE=",
|
| 22 |
+
"_ops.py": "o9RBC1FB95LP9Sp+GkBILumbSek9oEtxb8F7XXO0F0g=",
|
| 23 |
+
"fp4_indexer_interface.py": "M+0e93gWG8CGOrhY5bm1hEQJU+TT5PrCmwJzTofaDAg=",
|
| 24 |
+
"interface.py": "B4AHQfNyO+vl6MdyMAHW0GhArl7HGufAEa0ATxsWorY=",
|
| 25 |
+
"msa/__init__.py": "DFYPlrhXwYjEqCl/8n0SmWGZV8NFml5DPhMjKfv98GY=",
|
| 26 |
+
"quack/__init__.py": "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=",
|
| 27 |
+
"quack/activation.py": "T/ypcXoz6a4wPPNZW2gKZuEj8JeucaKtKxQiQl5XrXc=",
|
| 28 |
+
"quack/compile_utils.py": "qJ3oTsDlbAiddrJHtEO7LPYVqn/s+neNfiw+/KvfXZU=",
|
| 29 |
+
"quack/copy_utils.py": "rdohXm9bKXqDHkMHf8lWQJQnCb0hMLvhzIudkj0Bxeg=",
|
| 30 |
+
"quack/cute_dsl_utils.py": "4uQx5aYDG9UvVzbWwJTjjJLrnoympz70/CD8b37FQWo=",
|
| 31 |
+
"quack/layout_utils.py": "69N1aTy+840X3seMuLfLxiV3BW8SaVsM3Tf0Vf4NCSI=",
|
| 32 |
+
"quantize.py": "1jePLbJngji8ANfnDK6PCG829AMSd+XOMqYVuJ5pXyY=",
|
| 33 |
+
"sparse_index_utils.py": "kzYMdtFPRBfaL6Vfw9xLLre7ph8svtEQrB/txC+52Fc=",
|
| 34 |
+
"src/__init__.py": "ZdCpznblq9UgGSNgI0hJoDXpk+evcvGjv3GGthxD/nM=",
|
| 35 |
+
"src/common/__init__.py": "ZdCpznblq9UgGSNgI0hJoDXpk+evcvGjv3GGthxD/nM=",
|
| 36 |
+
"src/common/aot_cache.py": "ya1OHE6Lqx/pb9UhH++Bu8a98Huhmdl084C6cgWdH1s=",
|
| 37 |
+
"src/common/barrier.py": "Godvhwwaf9iyDA/A78VoQMMRRn6ZSnq2YPosr7K2SVE=",
|
| 38 |
+
"src/common/blackwell_helpers.py": "BYJYCeNQ9cYVhWZlfjv0IgNaNqlnoD21nX3gAA5pRB4=",
|
| 39 |
+
"src/common/block_info.py": "U7qL3AZ5ROkNZdL6RTPlLlnLJ6tZ4b2VFVufZLyuuq8=",
|
| 40 |
+
"src/common/copy_utils.py": "bEtyb8O7Z7jIKNjN5ESlnh4WVvdf8vr5ZfQxA6vS6zA=",
|
| 41 |
+
"src/common/cute_dsl_utils.py": "nd8vII+r49Kk185ja3+VM6dwJlvMqCkjMBRh0WEHakw=",
|
| 42 |
+
"src/common/fast_math.py": "nqt6MtzAt7uplC4+kczgBfin4gHNs+QSoufR1TuMZ88=",
|
| 43 |
+
"src/common/mask.py": "l9v4End+9k3ZHRO6DCnuOD9K9iOCiN81osRATKvK41k=",
|
| 44 |
+
"src/common/mma_sm100_desc.py": "C1PqBdp6CNPA9xadQ2xBnf4wvQlE93SS/7CU+LZBQkA=",
|
| 45 |
+
"src/common/named_barrier.py": "5ktJiO+hP80fjTR797CslUGfm2PyhpcW6WJZrNyI5bQ=",
|
| 46 |
+
"src/common/pack_gqa.py": "UrAAIge5XLmilqXWGtCZJobgpuA6B0N1Vw3tDhyUi7s=",
|
| 47 |
+
"src/common/paged_kv.py": "j0/6stT1A5uEVALEX/GaQhYWIie+6LpGseAW8aQiHbk=",
|
| 48 |
+
"src/common/pipeline.py": "MIFfoDDD8Fs//SQSR+JzI/0MJ1qPGml297RtbC2qPRU=",
|
| 49 |
+
"src/common/seqlen_info.py": "EX2W8MTGcnAZ+J60tGG9D7IzvdLeIVQshztntGDkPMQ=",
|
| 50 |
+
"src/common/softmax.py": "ePjb2TUcr4fHLmw0zx9Lt+vvR6hSm2mQwiENf2J/AoQ=",
|
| 51 |
+
"src/common/tile_scheduler.py": "f8UknoE0j9BfPomRI/QCsDJoRk+1IpJrLfBOAh2mlls=",
|
| 52 |
+
"src/common/tma_utils.py": "gpAmBh58VOfHRghZTCbQ5SQpbAYy0lFnpvIcFSLBNb8=",
|
| 53 |
+
"src/common/utils.py": "eGGo5Ul+0XpKtiw6JLofVdFDj6s2xe4LWqDmlqp9AKk=",
|
| 54 |
+
"src/sm100/__init__.py": "JQpQtL58fso8B2Xwvn0XVevVqIjnk15wVQE0UUGGLCs=",
|
| 55 |
+
"src/sm100/build_k2q_csr/__init__.py": "75ICu6BIZir0OeyEgZ1TEYNY7pn+lA4P6McCSSC20rI=",
|
| 56 |
+
"src/sm100/decode_schedule.py": "/VRAmvrMX+oYLzWK1sqve86tprXkqX0/f4o5WMVeU4I=",
|
| 57 |
+
"src/sm100/fp4_indexer.py": "1lc9/rgU09wwF08WBRaXIE0CE2b19pBRwXekDduFs0o=",
|
| 58 |
+
"src/sm100/fwd/__init__.py": "A0uq2t4n5Y34mEgxb9Nzxk9sKsYr2FZ4sF+RoEilOmo=",
|
| 59 |
+
"src/sm100/fwd/atten_fwd.py": "4LJaUh2pn3QiwcMr+8QOVUJjNIAQqYal1xFJ/1takQY=",
|
| 60 |
+
"src/sm100/fwd/atten_fwd_nvfp4_kv.py": "EqU+ehJasAa9NvpDWipMPxaptOw+vcojprVas+b+x18=",
|
| 61 |
+
"src/sm100/fwd/combine.py": "7rQW4rUpzy0M19u+/iLfHHGMbAIQhi4HEnYeLu/qmi4=",
|
| 62 |
+
"src/sm100/fwd_decode/__init__.py": "XQJdwvLQm29RwVqVZvCstEnTx+dhUrwmH6RcW675pR8=",
|
| 63 |
+
"src/sm100/fwd_decode/atten_fwd.py": "3S4iE9h6fXUBjas51fRbakqnOzN79f0QUJ/EBRm+Ckg=",
|
| 64 |
+
"src/sm100/fwd_decode/build_decode_schedule/__init__.py": "qUElKK/HC03N9ntOA0sc8LB08jF5MWd7wq3MUnu4wgM=",
|
| 65 |
+
"src/sm100/fwd_decode/combine.py": "wIvKZzHissMLe83PUbybUoM39HTMIAexHw5I1yfJH94=",
|
| 66 |
+
"src/sm100/fwd_decode/tile_scheduler.py": "OWdID5fCFmwXqz6RtseFphfJtezOOQ091K+bJFcD6bc=",
|
| 67 |
+
"src/sm100/prepare_k2q_csr.py": "nCeG6m24dLNwJeQDFppjqR3wVCDxMY0we+20zEEeMy8=",
|
| 68 |
+
"src/sm100/prepare_scheduler.py": "CQuJI6Fn0uR0oMcfzmlIH+bjg+2uKTzqCXbw5H0YgSw="
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
}
|
build/torch211-cxx11-cu128-x86_64-linux/metadata.json.sigstore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"mediaType":"application/vnd.dev.sigstore.bundle.v0.3+json","verificationMaterial":{"certificate":{"rawBytes":"MIIHTDCCBtGgAwIBAgIUXQHYSDFOSO1tjFUUICxJvOGeZcMwCgYIKoZIzj0EAwMwNzEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MR4wHAYDVQQDExVzaWdzdG9yZS1pbnRlcm1lZGlhdGUwHhcNMjYwNjMwMTc0NDA4WhcNMjYwNjMwMTc1NDA4WjAAMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEPXM0K6Fgcg5CUSklxxl2csu3F3KVSv8zPaW2wSeCwTB487WjsTVM+EqcLz/LSKUD5XL4tCAc1+gFBa30H4iDgKOCBfAwggXsMA4GA1UdDwEB/wQEAwIHgDATBgNVHSUEDDAKBggrBgEFBQcDAzAdBgNVHQ4EFgQUfsSvN2oaJ+OmV0cSOHDNe9Nc/qUwHwYDVR0jBBgwFoAU39Ppz1YkEZb5qNjpKFWixi4YZD8wawYDVR0RAQH/BGEwX4ZdaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5Ly5naXRodWIvd29ya2Zsb3dzL2J1aWxkLnlhbWxAcmVmcy9oZWFkcy9tYWluMDkGCisGAQQBg78wAQEEK2h0dHBzOi8vdG9rZW4uYWN0aW9ucy5naXRodWJ1c2VyY29udGVudC5jb20wHwYKKwYBBAGDvzABAgQRd29ya2Zsb3dfZGlzcGF0Y2gwNgYKKwYBBAGDvzABAwQoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTATBgorBgEEAYO/MAEEBAVCdWlsZDArBgorBgEEAYO/MAEFBB1odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eTAdBgorBgEEAYO/MAEGBA9yZWZzL2hlYWRzL21haW4wOwYKKwYBBAGDvzABCAQtDCtodHRwczovL3Rva2VuLmFjdGlvbnMuZ2l0aHVidXNlcmNvbnRlbnQuY29tMG0GCisGAQQBg78wAQkEXwxdaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5Ly5naXRodWIvd29ya2Zsb3dzL2J1aWxkLnlhbWxAcmVmcy9oZWFkcy9tYWluMDgGCisGAQQBg78wAQoEKgwoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTAbBgorBgEEAYO/MAELBA0MC3NlbGYtaG9zdGVkMEAGCisGAQQBg78wAQwEMgwwaHR0cHM6Ly9naXRodWIuY29tL2h1Z2dpbmdmYWNlL2tlcm5lbHMtY29tbXVuaXR5MDgGCisGAQQBg78wAQ0EKgwoMDlkNzg1MTVjNTUzMmU3MDAyNzBlOWUxMzU1NmEyYWQwMmU5ZjVmOTAfBgorBgEEAYO/MAEOBBEMD3JlZnMvaGVhZHMvbWFpbjAaBgorBgEEAYO/MAEPBAwMCjEwNzE0NzU1MjkwLgYKKwYBBAGDvzABEAQgDB5odHRwczovL2dpdGh1Yi5jb20vaHVnZ2luZ2ZhY2UwGAYKKwYBBAGDvzABEQQKDAgyNTcyMDc0MzBtBgorBgEEAYO/MAESBF8MXWh0dHBzOi8vZ2l0aHViLmNvbS9odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eS8uZ2l0aHViL3dvcmtmbG93cy9idWlsZC55YW1sQHJlZnMvaGVhZHMvbWFpbjA4BgorBgEEAYO/MAETBCoMKDA5ZDc4NTE1YzU1MzJlNzAwMjcwZTllMTM1NTZhMmFkMDJlOWY1ZjkwIQYKKwYBBAGDvzABFAQTDBF3b3JrZmxvd19kaXNwYXRjaDBkBgorBgEEAYO/MAEVBFYMVGh0dHBzOi8vZ2l0aHViLmNvbS9odWdnaW5nZmFjZS9rZXJuZWxzLWNvbW11bml0eS9hY3Rpb25zL3J1bnMvMjg0NjM5NjE5NTUvYXR0ZW1wdHMvMTAWBgorBgEEAYO/MAEWBAgMBnB1YmxpYzBGBgorBgEEAYO/MAEYBDgMNnJlcG86aHVnZ2luZ2ZhY2Uva2VybmVscy1jb21tdW5pdHk6cmVmOnJlZnMvaGVhZHMvbWFpbjCBigYKKwYBBAHWeQIEAgR8BHoAeAB2AN09MGrGxxEyYxkeHJlnNwKiSl643jyt/4eKcoAvKe6OAAABnxmhlrEAAAQDAEcwRQIhAN6iYC5242Rjj5dTsIgyISVMIPYWL2i81TwWknEvZur+AiAt30f5Wif9ZHR/wsWh+ve5O9GtVpL2jPTURJTl0u2xMjAKBggqhkjOPQQDAwNpADBmAjEA4i2QuFAcvw5KQAQADHbn8kVwmCTVfjK5xdQ1bJEu5eVu4PY4Br1zC9GVk7p6opFmAjEAm7jnPQ2jC5BL90FIlwMdeEVPgNmR7svFEElrkQme43Rqt6pvdGksMAzAqaWXQFqT"},"tlogEntries":[{"logIndex":"2024793345","logId":{"keyId":"wNI9atQGlz+VWfO6LRygH4QUfY/8W4RFwiT5i5WRgB0="},"kindVersion":{"kind":"hashedrekord","version":"0.0.1"},"integratedTime":"1782841448","inclusionPromise":{"signedEntryTimestamp":"MEUCIQDoWovnRcuj8EsCnxn/h18ObLX1W2EowGsjOnjj31tjKgIgE1bqiVYG2avTTL3CutjFGVSxSQtlXFYWVfl+DRCyVUk="},"inclusionProof":{"logIndex":"1902889083","rootHash":"rTzAPs80Dh6PVJ0tfFBFa06/Bp0jBkLOYrqCKGcj2Jw=","treeSize":"1902889093","hashes":["o6DK+OhTtiUAKd3yIcR79MoEH+e/lGDEz7/klBOgQgQ=","QFE69AbxzyZT6lYixktLCZ3SnTobLI2F6l/FFy7U7bE=","euXxtVgM7AeowPy83tQZihH1C4RDec9dw20k4Rjy7X8=","mCF45aBQkD6Ga0kRgUZm/6GIWnlvuDEwC1rsiDj7r9A=","wCaOWjILsSS/Bc8GMCLLwZ/lR4z6kHhhDwjBR489Drg=","oREPAC441YAiXLkRB+S3slZaG/rywypoRAOWh9Onh28=","tdRUnZp2XzgIgMBhnUUzZKRYmgMR9VRE4EFRMnBcvN4=","SRE7OpzsmEEBrnt2NvwSO2YvAQJHxIzVKMjw7ssvt3A=","5DB/VRMbICRg24kfvBoq+aFOMwCKvhr1zQj5SpDh5Ck=","NRxwUF55kxkZUtVui8nzfzj4LLT960XpxpXnY6C7pqs=","KTak07KIu/wsxelNu7DaqjZg2G0WnevWjQkjflcCfjI=","o03232Stm2HWKs2uG6lq2NP4O1Zym1pjI+LbQCbPISY=","nGtXNKgDUZj+ZjPgQKuKFp9orlBq81iSk8yjysQUTIU=","+/rlNRIrSvbSLthLGxHY8saYzo8HTl12uoWcFuXbbE0=","tC4XX6tUr8g/3yF+0T8f2DfrTWQmbDBfMxTOmNuWyzI=","E8u2TYaBleTNUd9vupjpxhOMu+bExC1kpTjfOk2GAUA=","cJbCQtmuzzN6T9df9SuhiY4cyCN7ezf1n+yFrgRkcgE=","+/VZ56MsIPxMiyLAodzKXo5TEWdQp36z89qLhpzloAo=","daxmZaajRpZV+JxHiOYZhJBiSKN5ucqjh2WnGbHhirw=","DOCeoSMovIvLExkhIvisow9AuNXgeWs4ECkyR6EcqYU="],"checkpoint":{"envelope":"rekor.sigstore.dev - 1193050959916656506\n1902889093\nrTzAPs80Dh6PVJ0tfFBFa06/Bp0jBkLOYrqCKGcj2Jw=\n\n— rekor.sigstore.dev wNI9ajBFAiBuldB8XClfqbEMlZnWsMAPF1CWf+PfKW6kiBU0RaE3YwIhAKQGXPHErozLpsxzvdgVeeJVRUx9RGAtRP5qoXqfKhJm\n"}},"canonicalizedBody":"eyJhcGlWZXJzaW9uIjoiMC4wLjEiLCJraW5kIjoiaGFzaGVkcmVrb3JkIiwic3BlYyI6eyJkYXRhIjp7Imhhc2giOnsiYWxnb3JpdGhtIjoic2hhMjU2IiwidmFsdWUiOiIyMzVlYjhiNGYxZmIyOWIzZWU4OTNlNzI4ODU1NDc3N2E3YzE3ZTVhNzNkNDM3YTc0M2JlNzAxOGYyOWQ5OGI4In19LCJzaWduYXR1cmUiOnsiY29udGVudCI6Ik1FVUNJQ1dkOUxlZ3ZSb0oxWDZIQUwway9SV1BvTG1sbS9YU3c3VXhOWmNpSFMwc0FpRUE3U1phSlJXVGlHdlJIWWh2d0pLS0RwRDVnRUNZT25GMGMzRURMT0VTOWNNPSIsInB1YmxpY0tleSI6eyJjb250ZW50IjoiTFMwdExTMUNSVWRKVGlCRFJWSlVTVVpKUTBGVVJTMHRMUzB0Q2sxSlNVaFVSRU5EUW5SSFowRjNTVUpCWjBsVldGRklXVk5FUms5VFR6RjBha1pWVlVsRGVFcDJUMGRsV21OTmQwTm5XVWxMYjFwSmVtb3dSVUYzVFhjS1RucEZWazFDVFVkQk1WVkZRMmhOVFdNeWJHNWpNMUoyWTIxVmRWcEhWakpOVWpSM1NFRlpSRlpSVVVSRmVGWjZZVmRrZW1SSE9YbGFVekZ3WW01U2JBcGpiVEZzV2tkc2FHUkhWWGRJYUdOT1RXcFpkMDVxVFhkTlZHTXdUa1JCTkZkb1kwNU5hbGwzVG1wTmQwMVVZekZPUkVFMFYycEJRVTFHYTNkRmQxbElDa3R2V2tsNmFqQkRRVkZaU1V0dldrbDZhakJFUVZGalJGRm5RVVZRV0Uwd1N6WkdaMk5uTlVOVlUydHNlSGhzTW1OemRUTkdNMHRXVTNZNGVsQmhWeklLZDFObFEzZFVRalE0TjFkcWMxUldUU3RGY1dOTWVpOU1VMHRWUkRWWVREUjBRMEZqTVN0blJrSmhNekJJTkdsRVowdFBRMEptUVhkbloxaHpUVUUwUndwQk1WVmtSSGRGUWk5M1VVVkJkMGxJWjBSQlZFSm5UbFpJVTFWRlJFUkJTMEpuWjNKQ1owVkdRbEZqUkVGNlFXUkNaMDVXU0ZFMFJVWm5VVlZtYzFOMkNrNHliMkZLSzA5dFZqQmpVMDlJUkU1bE9VNWpMM0ZWZDBoM1dVUldVakJxUWtKbmQwWnZRVlV6T1ZCd2VqRlphMFZhWWpWeFRtcHdTMFpYYVhocE5Ga0tXa1E0ZDJGM1dVUldVakJTUVZGSUwwSkhSWGRZTkZwa1lVaFNNR05JVFRaTWVUbHVZVmhTYjJSWFNYVlpNamwwVERKb01Wb3laSEJpYldSdFdWZE9iQXBNTW5Sc1kyMDFiR0pJVFhSWk1qbDBZbGhXZFdGWVVqVk1lVFZ1WVZoU2IyUlhTWFprTWpsNVlUSmFjMkl6WkhwTU1rb3hZVmQ0YTB4dWJHaGlWM2hCQ21OdFZtMWplVGx2V2xkR2EyTjVPWFJaVjJ4MVRVUnJSME5wYzBkQlVWRkNaemM0ZDBGUlJVVkxNbWd3WkVoQ2VrOXBPSFprUnpseVdsYzBkVmxYVGpBS1lWYzVkV041Tlc1aFdGSnZaRmRLTVdNeVZubFpNamwxWkVkV2RXUkROV3BpTWpCM1NIZFpTMHQzV1VKQ1FVZEVkbnBCUWtGblVWSmtNamw1WVRKYWN3cGlNMlJtV2tkc2VtTkhSakJaTW1kM1RtZFpTMHQzV1VKQ1FVZEVkbnBCUWtGM1VXOU5SR3hyVG5wbk1VMVVWbXBPVkZWNlRXMVZNMDFFUVhsT2VrSnNDazlYVlhoTmVsVXhUbTFGZVZsWFVYZE5iVlUxV21wV2JVOVVRVlJDWjI5eVFtZEZSVUZaVHk5TlFVVkZRa0ZXUTJSWGJITmFSRUZ5UW1kdmNrSm5SVVVLUVZsUEwwMUJSVVpDUWpGdlpGZGtibUZYTlc1YWJVWnFXbE01Y2xwWVNuVmFWM2g2VEZkT2RtSlhNVEZpYld3d1pWUkJaRUpuYjNKQ1owVkZRVmxQTHdwTlFVVkhRa0U1ZVZwWFducE1NbWhzV1ZkU2Vrd3lNV2hoVnpSM1QzZFpTMHQzV1VKQ1FVZEVkbnBCUWtOQlVYUkVRM1J2WkVoU2QyTjZiM1pNTTFKMkNtRXlWblZNYlVacVpFZHNkbUp1VFhWYU1td3dZVWhXYVdSWVRteGpiVTUyWW01U2JHSnVVWFZaTWpsMFRVY3dSME5wYzBkQlVWRkNaemM0ZDBGUmEwVUtXSGQ0WkdGSVVqQmpTRTAyVEhrNWJtRllVbTlrVjBsMVdUSTVkRXd5YURGYU1tUndZbTFrYlZsWFRteE1NblJzWTIwMWJHSklUWFJaTWpsMFlsaFdkUXBoV0ZJMVRIazFibUZZVW05a1YwbDJaREk1ZVdFeVduTmlNMlI2VERKS01XRlhlR3RNYm14b1lsZDRRV050Vm0xamVUbHZXbGRHYTJONU9YUlpWMngxQ2sxRVowZERhWE5IUVZGUlFtYzNPSGRCVVc5RlMyZDNiMDFFYkd0T2VtY3hUVlJXYWs1VVZYcE5iVlV6VFVSQmVVNTZRbXhQVjFWNFRYcFZNVTV0UlhrS1dWZFJkMDF0VlRWYWFsWnRUMVJCWWtKbmIzSkNaMFZGUVZsUEwwMUJSVXhDUVRCTlF6Tk9iR0pIV1hSaFJ6bDZaRWRXYTAxRlFVZERhWE5IUVZGUlFncG5OemgzUVZGM1JVMW5kM2RoU0ZJd1kwaE5Oa3g1T1c1aFdGSnZaRmRKZFZreU9YUk1NbWd4V2pKa2NHSnRaRzFaVjA1c1RESjBiR050Tld4aVNFMTBDbGt5T1hSaVdGWjFZVmhTTlUxRVowZERhWE5IUVZGUlFtYzNPSGRCVVRCRlMyZDNiMDFFYkd0T2VtY3hUVlJXYWs1VVZYcE5iVlV6VFVSQmVVNTZRbXdLVDFkVmVFMTZWVEZPYlVWNVdWZFJkMDF0VlRWYWFsWnRUMVJCWmtKbmIzSkNaMFZGUVZsUEwwMUJSVTlDUWtWTlJETktiRnB1VFhaaFIxWm9Xa2hOZGdwaVYwWndZbXBCWVVKbmIzSkNaMFZGUVZsUEwwMUJSVkJDUVhkTlEycEZkMDU2UlRCT2VsVXhUV3ByZDB4bldVdExkMWxDUWtGSFJIWjZRVUpGUVZGbkNrUkNOVzlrU0ZKM1kzcHZka3d5WkhCa1IyZ3hXV2sxYW1JeU1IWmhTRlp1V2pKc2RWb3lXbWhaTWxWM1IwRlpTMHQzV1VKQ1FVZEVkbnBCUWtWUlVVc0tSRUZuZVU1VVkzbE5SR013VFhwQ2RFSm5iM0pDWjBWRlFWbFBMMDFCUlZOQ1JqaE5XRmRvTUdSSVFucFBhVGgyV2pKc01HRklWbWxNYlU1MllsTTVid3BrVjJSdVlWYzFibHB0Um1wYVV6bHlXbGhLZFZwWGVIcE1WMDUyWWxjeE1XSnRiREJsVXpoMVdqSnNNR0ZJVm1sTU0yUjJZMjEwYldKSE9UTmplVGxwQ21SWGJITmFRelUxV1ZjeGMxRklTbXhhYmsxMllVZFdhRnBJVFhaaVYwWndZbXBCTkVKbmIzSkNaMFZGUVZsUEwwMUJSVlJDUTI5TlMwUkJOVnBFWXpRS1RsUkZNVmw2VlRGTmVrcHNUbnBCZDAxcVkzZGFWR3hzVFZSTk1VNVVXbWhOYlVaclRVUktiRTlYV1RGYWFtdDNTVkZaUzB0M1dVSkNRVWRFZG5wQlFncEdRVkZVUkVKR00ySXpTbkphYlhoMlpERTVhMkZZVG5kWldGSnFZVVJDYTBKbmIzSkNaMFZGUVZsUEwwMUJSVlpDUmxsTlZrZG9NR1JJUW5wUGFUaDJDbG95YkRCaFNGWnBURzFPZG1KVE9XOWtWMlJ1WVZjMWJscHRSbXBhVXpseVdsaEtkVnBYZUhwTVYwNTJZbGN4TVdKdGJEQmxVemxvV1ROU2NHSXlOWG9LVEROS01XSnVUWFpOYW1jd1RtcE5OVTVxUlRWT1ZGVjJXVmhTTUZwWE1YZGtTRTEyVFZSQlYwSm5iM0pDWjBWRlFWbFBMMDFCUlZkQ1FXZE5RbTVDTVFwWmJYaHdXWHBDUjBKbmIzSkNaMFZGUVZsUEwwMUJSVmxDUkdkTlRtNUtiR05IT0RaaFNGWnVXakpzZFZveVdtaFpNbFYyWVRKV2VXSnRWbk5qZVRGcUNtSXlNWFJrVnpWd1pFaHJObU50Vm0xUGJrcHNXbTVOZG1GSFZtaGFTRTEyWWxkR2NHSnFRMEpwWjFsTFMzZFpRa0pCU0ZkbFVVbEZRV2RTT0VKSWIwRUtaVUZDTWtGT01EbE5SM0pIZUhoRmVWbDRhMlZJU214dVRuZExhVk5zTmpRemFubDBMelJsUzJOdlFYWkxaVFpQUVVGQlFtNTRiV2hzY2tWQlFVRlJSQXBCUldOM1VsRkphRUZPTm1sWlF6VXlOREpTYW1vMVpGUnpTV2Q1U1ZOV1RVbFFXVmRNTW1rNE1WUjNWMnR1UlhaYWRYSXJRV2xCZERNd1pqVlhhV1k1Q2xwSVVpOTNjMWRvSzNabE5VODVSM1JXY0V3eWFsQlVWVkpLVkd3d2RUSjRUV3BCUzBKblozRm9hMnBQVUZGUlJFRjNUbkJCUkVKdFFXcEZRVFJwTWxFS2RVWkJZM1ozTlV0UlFWRkJSRWhpYmpoclZuZHRRMVJXWm1wTE5YaGtVVEZpU2tWMU5XVldkVFJRV1RSQ2NqRjZRemxIVm1zM2NEWnZjRVp0UVdwRlFRcHROMnB1VUZFeWFrTTFRa3c1TUVaSmJIZE5aR1ZGVmxCblRtMVNOM04yUmtWRmJISnJVVzFsTkROU2NYUTJjSFprUjJ0elRVRjZRWEZoVjFoUlJuRlVDaTB0TFMwdFJVNUVJRU5GVWxSSlJrbERRVlJGTFMwdExTMEsifX19fQ=="}],"timestampVerificationData":{"rfc3161Timestamps":[{"signedTimestamp":"MIICyDADAgEAMIICvwYJKoZIhvcNAQcCoIICsDCCAqwCAQMxDTALBglghkgBZQMEAgEwgbcGCyqGSIb3DQEJEAEEoIGnBIGkMIGhAgEBBgkrBgEEAYO/MAIwMTANBglghkgBZQMEAgEFAAQghcKBnsFCpVtXbanqDCSR8zDubO5wb4xvtguYuZJRTKMCFGXfBMQDzomI8IngRpeuarmPZQoDGA8yMDI2MDYzMDE3NDQwOFowAwIBAaAypDAwLjEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MRUwEwYDVQQDEwxzaWdzdG9yZS10c2GgADGCAdowggHWAgEBMFEwOTEVMBMGA1UEChMMc2lnc3RvcmUuZGV2MSAwHgYDVQQDExdzaWdzdG9yZS10c2Etc2VsZnNpZ25lZAIUOhNULwyQYe68wUMvy4qOiyojiwwwCwYJYIZIAWUDBAIBoIH8MBoGCSqGSIb3DQEJAzENBgsqhkiG9w0BCRABBDAcBgkqhkiG9w0BCQUxDxcNMjYwNjMwMTc0NDA4WjAvBgkqhkiG9w0BCQQxIgQgczwr9pKyxDMc0eur+DGt9Mdetezf8UQKp2Sn3wspffwwgY4GCyqGSIb3DQEJEAIvMX8wfTB7MHkEIIX5J7wHq2LKw7RDVsEO/IGyxog/2nq55thw2dE6zQW3MFUwPaQ7MDkxFTATBgNVBAoTDHNpZ3N0b3JlLmRldjEgMB4GA1UEAxMXc2lnc3RvcmUtdHNhLXNlbGZzaWduZWQCFDoTVC8MkGHuvMFDL8uKjosqI4sMMAoGCCqGSM49BAMCBGYwZAIwJmfpM3hVIBsGwNTieyT54BZfQTwFye2f0/les1QzRFpXz5nu59C0tKLFYqcNPDdQAjBI9y5eNjjl9yo9BtpcZmIjURLuYioqzrjahNDmiThJZgRNROaVkPWrE5dlDJoFe58="}]}},"messageSignature":{"messageDigest":{"algorithm":"SHA2_256","digest":"I164tPH7KbPuiT5yiFVHd6fBflpz1DenQ75wGPKdmLg="},"signature":"MEUCICWd9LegvRoJ1X6HAL0k/RWPoLmlm/XSw7UxNZciHS0sAiEA7SZaJRWTiGvRHYhvwJKKDpD5gECYOnF0c3EDLOES9cM="}}
|
build/torch211-cxx11-cu128-x86_64-linux/msa/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import importlib.util
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from types import ModuleType
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _import_from_path(file_path: Path) -> ModuleType:
|
| 9 |
+
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
| 10 |
+
# it would also be used for other imports. So, we make a module name that
|
| 11 |
+
# depends on the path for it to be unique using the hex-encoded hash of
|
| 12 |
+
# the path.
|
| 13 |
+
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
|
| 14 |
+
module_name = path_hash
|
| 15 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
| 16 |
+
if spec is None:
|
| 17 |
+
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
|
| 18 |
+
module = importlib.util.module_from_spec(spec)
|
| 19 |
+
if module is None:
|
| 20 |
+
raise ImportError(f"Cannot load module {module_name} from spec")
|
| 21 |
+
sys.modules[module_name] = module
|
| 22 |
+
spec.loader.exec_module(module) # type: ignore
|
| 23 |
+
return module
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
|
build/torch211-cxx11-cu128-x86_64-linux/quack/__init__.py
ADDED
|
File without changes
|
build/torch211-cxx11-cu128-x86_64-linux/quack/activation.py
ADDED
|
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
from cutlass import Float32, Boolean, const_expr
|
| 9 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 10 |
+
from cutlass._mlir.dialects import llvm, nvvm
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
sub_packed_f32x2 = partial(
|
| 17 |
+
cute.arch.calc_packed_f32x2_op,
|
| 18 |
+
src_c=None,
|
| 19 |
+
calc_func=nvvm.sub_packed_f32x2,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dsl_user_op
|
| 24 |
+
def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
| 25 |
+
return Float32(
|
| 26 |
+
llvm.inline_asm(
|
| 27 |
+
T.f32(),
|
| 28 |
+
[Float32(a).ir_value(loc=loc, ip=ip)],
|
| 29 |
+
"tanh.approx.f32 $0, $1;",
|
| 30 |
+
"=f,f",
|
| 31 |
+
has_side_effects=False,
|
| 32 |
+
is_align_stack=False,
|
| 33 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 34 |
+
)
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dsl_user_op
|
| 39 |
+
def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
| 40 |
+
if const_expr(not isinstance(x, tuple)):
|
| 41 |
+
# return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
|
| 42 |
+
return 0.5 + 0.5 * tanh(0.5 * x)
|
| 43 |
+
else:
|
| 44 |
+
x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
|
| 45 |
+
tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
|
| 46 |
+
return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dsl_user_op
|
| 50 |
+
def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
|
| 51 |
+
# return dout * out * (1.0 - out)
|
| 52 |
+
return dout * (out - out * out)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dsl_user_op
|
| 56 |
+
def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
| 57 |
+
if const_expr(not isinstance(x, tuple)):
|
| 58 |
+
return cute.arch.fmax(x, Float32(0.0))
|
| 59 |
+
else:
|
| 60 |
+
return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dsl_user_op
|
| 64 |
+
@cute.jit
|
| 65 |
+
def drelu(
|
| 66 |
+
x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
|
| 67 |
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
|
| 68 |
+
if const_expr(not isinstance(x, tuple)):
|
| 69 |
+
x_pos = Boolean(x > 0)
|
| 70 |
+
return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
|
| 71 |
+
else:
|
| 72 |
+
x0_pos = Boolean(x[0] > 0)
|
| 73 |
+
x1_pos = Boolean(x[1] > 0)
|
| 74 |
+
dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0))
|
| 75 |
+
return dx, relu(x)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dsl_user_op
|
| 79 |
+
def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
| 80 |
+
if const_expr(not isinstance(x, tuple)):
|
| 81 |
+
return cute.arch.fmax(x, Float32(0.0)) * x
|
| 82 |
+
else:
|
| 83 |
+
relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
|
| 84 |
+
return cute.arch.mul_packed_f32x2(relu_x, x)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@dsl_user_op
|
| 88 |
+
@cute.jit
|
| 89 |
+
def drelu_sq(
|
| 90 |
+
x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
|
| 91 |
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
|
| 92 |
+
"""
|
| 93 |
+
ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
|
| 94 |
+
Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
|
| 95 |
+
Returns: (dx, relu_sq_out) where:
|
| 96 |
+
- dx = dout * 2 * x if x > 0, else 0
|
| 97 |
+
- relu_sq_out = max(x, 0) * x
|
| 98 |
+
"""
|
| 99 |
+
if const_expr(not isinstance(x, tuple)):
|
| 100 |
+
relu_x = relu(x)
|
| 101 |
+
relu_sq_out = relu_x * x
|
| 102 |
+
# Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
|
| 103 |
+
dx = 2.0 * (dout * relu_x)
|
| 104 |
+
return dx, relu_sq_out
|
| 105 |
+
else:
|
| 106 |
+
relu_x = relu(x)
|
| 107 |
+
relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x)
|
| 108 |
+
dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x))
|
| 109 |
+
return dx, relu_sq_out
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@dsl_user_op
|
| 113 |
+
def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
| 114 |
+
"""
|
| 115 |
+
gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
| 116 |
+
= 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
|
| 117 |
+
"""
|
| 118 |
+
sqrt_2_over_pi = math.sqrt(2 / math.pi) # ~0.797885
|
| 119 |
+
sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # ~0.0356774
|
| 120 |
+
if const_expr(not isinstance(x, tuple)):
|
| 121 |
+
return 0.5 * (
|
| 122 |
+
x
|
| 123 |
+
# Currently cute.math.tanh(x, fastmath=True) generates very slow code
|
| 124 |
+
# * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
|
| 125 |
+
* (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
|
| 126 |
+
)
|
| 127 |
+
else:
|
| 128 |
+
x_sq = cute.arch.mul_packed_f32x2(x, x)
|
| 129 |
+
x_sq_scaled = cute.arch.fma_packed_f32x2(
|
| 130 |
+
x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
|
| 131 |
+
)
|
| 132 |
+
z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
|
| 133 |
+
tanh_z = (tanh(z[0]), tanh(z[1]))
|
| 134 |
+
x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x)
|
| 135 |
+
return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dsl_user_op
|
| 139 |
+
def dgelu_tanh_approx(
|
| 140 |
+
x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
|
| 141 |
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
|
| 142 |
+
"""
|
| 143 |
+
GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
|
| 144 |
+
Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
|
| 145 |
+
Returns: (dx, gelu_out)
|
| 146 |
+
|
| 147 |
+
Derivative uses the chain rule:
|
| 148 |
+
d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
|
| 149 |
+
where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
|
| 150 |
+
and sech^2(z) = 1 - tanh^2(z)
|
| 151 |
+
"""
|
| 152 |
+
sqrt_2_over_pi = math.sqrt(2 / math.pi) # c1 ~0.797885
|
| 153 |
+
sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # c2 ~0.0356774
|
| 154 |
+
sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff # c3 ~0.01070322
|
| 155 |
+
|
| 156 |
+
if const_expr(not isinstance(x, tuple)):
|
| 157 |
+
# Compute z = x * (c1 + c2 * x^2)
|
| 158 |
+
x_sq = x * x
|
| 159 |
+
# tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
|
| 160 |
+
tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
|
| 161 |
+
half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
|
| 162 |
+
gelu_out = x * half_tanh_z_plus_one
|
| 163 |
+
|
| 164 |
+
# Compute gradient
|
| 165 |
+
# sech^2(z) = 1 - tanh^2(z)
|
| 166 |
+
sech2_z = 1 - tanh_z * tanh_z
|
| 167 |
+
# dz/dx = c1 + 3 * c2 * x^2
|
| 168 |
+
dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
|
| 169 |
+
# d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
|
| 170 |
+
dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))
|
| 171 |
+
|
| 172 |
+
dx = dout * dgelu
|
| 173 |
+
return dx, gelu_out
|
| 174 |
+
else:
|
| 175 |
+
# Compute z = x * (c1 + c2 * x^2)
|
| 176 |
+
x_sq = cute.arch.mul_packed_f32x2(x, x)
|
| 177 |
+
x_sq_scaled = cute.arch.fma_packed_f32x2(
|
| 178 |
+
x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
|
| 179 |
+
)
|
| 180 |
+
z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
|
| 181 |
+
tanh_z = (tanh(z[0]), tanh(z[1]))
|
| 182 |
+
half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
|
| 183 |
+
gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one)
|
| 184 |
+
|
| 185 |
+
# Compute gradient
|
| 186 |
+
# sech^2(z) = 1 - tanh^2(z)
|
| 187 |
+
sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
|
| 188 |
+
# dz/dx = c1 + 3 * c2 * x^2
|
| 189 |
+
dz_dx = cute.arch.fma_packed_f32x2(
|
| 190 |
+
x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
|
| 191 |
+
)
|
| 192 |
+
# d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
|
| 193 |
+
sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx)
|
| 194 |
+
x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx)
|
| 195 |
+
dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
|
| 196 |
+
|
| 197 |
+
dx = cute.arch.mul_packed_f32x2(dout, dgelu)
|
| 198 |
+
return dx, gelu_out
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@dsl_user_op
|
| 202 |
+
@cute.jit
|
| 203 |
+
def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
| 204 |
+
if const_expr(not isinstance(x, tuple)):
|
| 205 |
+
use_linear = Boolean(x > 20.0)
|
| 206 |
+
return (
|
| 207 |
+
cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True)
|
| 208 |
+
if not use_linear
|
| 209 |
+
else x
|
| 210 |
+
)
|
| 211 |
+
else:
|
| 212 |
+
log2_e = math.log2(math.e)
|
| 213 |
+
x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e))
|
| 214 |
+
x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
|
| 215 |
+
x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0))
|
| 216 |
+
log_x_exp_p1 = (
|
| 217 |
+
cute.math.log2(x_exp_p1[0], fastmath=True),
|
| 218 |
+
cute.math.log2(x_exp_p1[1], fastmath=True),
|
| 219 |
+
)
|
| 220 |
+
ln2 = math.log(2.0)
|
| 221 |
+
softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
|
| 222 |
+
use_linear_0 = Boolean(x[0] > 20.0)
|
| 223 |
+
use_linear_1 = Boolean(x[1] > 20.0)
|
| 224 |
+
return (
|
| 225 |
+
softplus_x[0] if not use_linear_0 else x[0],
|
| 226 |
+
softplus_x[1] if not use_linear_1 else x[1],
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@dsl_user_op
|
| 231 |
+
@cute.jit
|
| 232 |
+
def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
|
| 233 |
+
use_linear = Boolean(out > 20.0)
|
| 234 |
+
# dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout
|
| 235 |
+
dx = dout - dout * cute.math.exp(-out, fastmath=True)
|
| 236 |
+
return dx if not use_linear else dout
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@dsl_user_op
|
| 240 |
+
def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2:
|
| 241 |
+
"""
|
| 242 |
+
silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
|
| 243 |
+
This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
|
| 244 |
+
"""
|
| 245 |
+
if const_expr(not isinstance(x, tuple)):
|
| 246 |
+
x_half = 0.5 * x if const_expr(not already_halved) else x
|
| 247 |
+
# return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
|
| 248 |
+
return x_half * tanh(x_half) + x_half
|
| 249 |
+
else:
|
| 250 |
+
x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
|
| 251 |
+
tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
|
| 252 |
+
return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
@dsl_user_op
|
| 256 |
+
def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
| 257 |
+
if const_expr(not isinstance(x, tuple)):
|
| 258 |
+
return silu(x) * y
|
| 259 |
+
else:
|
| 260 |
+
return cute.arch.mul_packed_f32x2(silu(x), y)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@dsl_user_op
|
| 264 |
+
def dswiglu(
|
| 265 |
+
x: F32_or_F32x2,
|
| 266 |
+
y: F32_or_F32x2,
|
| 267 |
+
dout: F32_or_F32x2,
|
| 268 |
+
*,
|
| 269 |
+
already_halved: bool = False,
|
| 270 |
+
loc=None,
|
| 271 |
+
ip=None,
|
| 272 |
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
|
| 273 |
+
"""
|
| 274 |
+
SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
|
| 275 |
+
Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
|
| 276 |
+
Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x)
|
| 277 |
+
|
| 278 |
+
d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
|
| 279 |
+
|
| 280 |
+
This has been optimized to use fewer instructions (i.e. we expand things out
|
| 281 |
+
to use FFMA instead of FADD and FMUL).
|
| 282 |
+
"""
|
| 283 |
+
if const_expr(not isinstance(x, tuple)):
|
| 284 |
+
# Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
|
| 285 |
+
# FMUL, MUFU.TANH, then FFMA
|
| 286 |
+
if const_expr(not already_halved):
|
| 287 |
+
sigmoid_x = sigmoid(x)
|
| 288 |
+
silu_x = x * sigmoid_x # FMUL
|
| 289 |
+
else:
|
| 290 |
+
tanh_x = tanh(x) # MUFU.TANH
|
| 291 |
+
sigmoid_x = 0.5 * tanh_x + 0.5 # FFMA
|
| 292 |
+
silu_x = x * tanh_x + x # FFMA
|
| 293 |
+
silu_x_dout = silu_x * dout # FMUL
|
| 294 |
+
# d_silu(x) * dout
|
| 295 |
+
# = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
|
| 296 |
+
# = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
|
| 297 |
+
# = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
|
| 298 |
+
# = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
|
| 299 |
+
# = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
|
| 300 |
+
d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout # FFMA, FFMA
|
| 301 |
+
dx = d_silu_x_dout * y # FMUL
|
| 302 |
+
dy = silu_x_dout
|
| 303 |
+
swiglu_out = silu_x * y # FMUL
|
| 304 |
+
# Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
|
| 305 |
+
return dx, dy, swiglu_out
|
| 306 |
+
else:
|
| 307 |
+
# Compute sigmoid(x) and silu(x)
|
| 308 |
+
if const_expr(not already_halved):
|
| 309 |
+
sigmoid_x = sigmoid(x)
|
| 310 |
+
silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x)
|
| 311 |
+
else:
|
| 312 |
+
tanh_x = (tanh(x[0]), tanh(x[1]))
|
| 313 |
+
sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
|
| 314 |
+
silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x)
|
| 315 |
+
silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
|
| 316 |
+
# d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
|
| 317 |
+
sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2(
|
| 318 |
+
sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
|
| 319 |
+
)
|
| 320 |
+
d_silu_x_dout = cute.arch.fma_packed_f32x2(
|
| 321 |
+
sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout
|
| 322 |
+
)
|
| 323 |
+
dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y)
|
| 324 |
+
dy = silu_x_dout
|
| 325 |
+
swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y)
|
| 326 |
+
return dx, dy, swiglu_out
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
@dsl_user_op
|
| 330 |
+
def swiglu_oai(
|
| 331 |
+
x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
|
| 332 |
+
) -> F32_or_F32x2:
|
| 333 |
+
"""The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y.
|
| 334 |
+
https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
|
| 335 |
+
x * sigmoid(alpha * x) * (y + 1)
|
| 336 |
+
Compile down to FMUL, FMUL, TANH, FFMA, FFMA
|
| 337 |
+
"""
|
| 338 |
+
# Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
|
| 339 |
+
if const_expr(not isinstance(x, tuple)):
|
| 340 |
+
x_half = 0.5 * x
|
| 341 |
+
# silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
|
| 342 |
+
silu_x = x_half * tanh(alpha * x_half) + x_half
|
| 343 |
+
return silu_x * y + silu_x
|
| 344 |
+
else:
|
| 345 |
+
x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
|
| 346 |
+
alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half)
|
| 347 |
+
tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
|
| 348 |
+
silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
|
| 349 |
+
return cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
@dsl_user_op
|
| 353 |
+
def dswiglu_oai(
|
| 354 |
+
x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
|
| 355 |
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
|
| 356 |
+
"""
|
| 357 |
+
Swiglu OAI backward pass: computes gradients w.r.t. x and y
|
| 358 |
+
Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
|
| 359 |
+
Returns: (dx, dy, swiglu_oai_out)
|
| 360 |
+
|
| 361 |
+
Derivative of x * sigmoid(alpha * x) w.r.t. x:
|
| 362 |
+
d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
|
| 363 |
+
"""
|
| 364 |
+
if const_expr(not isinstance(x, tuple)):
|
| 365 |
+
# Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
|
| 366 |
+
alpha_x_half = (0.5 * alpha) * x # FMUL
|
| 367 |
+
# MUFU.TANH, then FFMA
|
| 368 |
+
# sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
|
| 369 |
+
sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)
|
| 370 |
+
silu_x = x * sigmoid_alpha_x # FMUL
|
| 371 |
+
silu_x_dout = silu_x * dout # FMUL
|
| 372 |
+
# FFMA, FFMA, FMUL
|
| 373 |
+
d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
|
| 374 |
+
dx = d_silu_x_dout * y + d_silu_x_dout # FFMA, instead of multiply by y + 1
|
| 375 |
+
dy = silu_x_dout
|
| 376 |
+
swiglu_out = silu_x * y + silu_x # FFMA, instead of multiply by y + 1
|
| 377 |
+
# Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
|
| 378 |
+
return dx, dy, swiglu_out
|
| 379 |
+
else:
|
| 380 |
+
# Compute sigmoid(alpha * x)
|
| 381 |
+
alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
|
| 382 |
+
tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
|
| 383 |
+
sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
|
| 384 |
+
silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x)
|
| 385 |
+
silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
|
| 386 |
+
# d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
|
| 387 |
+
silu_x_minus_product = cute.arch.fma_packed_f32x2(
|
| 388 |
+
silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
|
| 389 |
+
)
|
| 390 |
+
sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2(
|
| 391 |
+
(alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
|
| 392 |
+
)
|
| 393 |
+
d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
|
| 394 |
+
dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
|
| 395 |
+
dy = silu_x_dout
|
| 396 |
+
swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
|
| 397 |
+
return dx, dy, swiglu_out
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
@dsl_user_op
|
| 401 |
+
def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
| 402 |
+
"""GLU: Gated Linear Unit
|
| 403 |
+
glu(x, y) = sigmoid(x) * y
|
| 404 |
+
Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
|
| 405 |
+
"""
|
| 406 |
+
if const_expr(not isinstance(x, tuple)):
|
| 407 |
+
sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA
|
| 408 |
+
return sigmoid_x * y # FMUL
|
| 409 |
+
else:
|
| 410 |
+
sigmoid_x = sigmoid(x)
|
| 411 |
+
return cute.arch.mul_packed_f32x2(sigmoid_x, y)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
@dsl_user_op
|
| 415 |
+
def dglu(
|
| 416 |
+
x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
|
| 417 |
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
|
| 418 |
+
"""
|
| 419 |
+
GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
|
| 420 |
+
Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
|
| 421 |
+
Returns: (dx, dy, glu_out) where:
|
| 422 |
+
- dx = dout * y * sigmoid(x) * (1 - sigmoid(x))
|
| 423 |
+
- dy = dout * sigmoid(x)
|
| 424 |
+
- glu_out = sigmoid(x) * y
|
| 425 |
+
"""
|
| 426 |
+
if const_expr(not isinstance(x, tuple)):
|
| 427 |
+
# Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
|
| 428 |
+
sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA
|
| 429 |
+
sigmoid_x_dout = sigmoid_x * dout # FMUL
|
| 430 |
+
glu_out = sigmoid_x * y # FMUL
|
| 431 |
+
# dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
|
| 432 |
+
# = y * (1 - sigmoid(x)) * sigmoid_x_dout
|
| 433 |
+
# = (y - y * sigmoid(x)) * sigmoid_x_dout
|
| 434 |
+
# = (y - glu_out) * sigmoid_x_dout
|
| 435 |
+
dx = (y - glu_out) * sigmoid_x_dout # FADD, FMUL
|
| 436 |
+
dy = sigmoid_x_dout
|
| 437 |
+
# Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
|
| 438 |
+
return dx, dy, glu_out
|
| 439 |
+
else:
|
| 440 |
+
sigmoid_x = sigmoid(x)
|
| 441 |
+
sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout)
|
| 442 |
+
glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y)
|
| 443 |
+
# dx = (y - glu_out) * sigmoid_x_dout
|
| 444 |
+
y_minus_glu_out = sub_packed_f32x2(y, glu_out)
|
| 445 |
+
dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
|
| 446 |
+
dy = sigmoid_x_dout
|
| 447 |
+
return dx, dy, glu_out
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
@dsl_user_op
|
| 451 |
+
def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
| 452 |
+
"""ReGLU: ReLU Gated Linear Unit
|
| 453 |
+
reglu(x, y) = relu(x) * y = max(x, 0) * y
|
| 454 |
+
"""
|
| 455 |
+
if const_expr(not isinstance(x, tuple)):
|
| 456 |
+
return cute.arch.fmax(x, Float32(0.0)) * y
|
| 457 |
+
else:
|
| 458 |
+
relu_x = relu(x)
|
| 459 |
+
return cute.arch.mul_packed_f32x2(relu_x, y)
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
@dsl_user_op
|
| 463 |
+
@cute.jit
|
| 464 |
+
def dreglu(
|
| 465 |
+
x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
|
| 466 |
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
|
| 467 |
+
"""
|
| 468 |
+
ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
|
| 469 |
+
Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
|
| 470 |
+
Returns: (dx, dy, reglu_out) where:
|
| 471 |
+
- dx = dout * y if x > 0, else 0
|
| 472 |
+
- dy = dout * relu(x)
|
| 473 |
+
- reglu_out = relu(x) * y
|
| 474 |
+
"""
|
| 475 |
+
if const_expr(not isinstance(x, tuple)):
|
| 476 |
+
x_pos = Boolean(x > 0)
|
| 477 |
+
relu_x = cute.arch.fmax(x, Float32(0.0))
|
| 478 |
+
dx = (dout * y) if x_pos else Float32(0.0)
|
| 479 |
+
dy = dout * relu_x
|
| 480 |
+
reglu_out = relu_x * y
|
| 481 |
+
return dx, dy, reglu_out
|
| 482 |
+
else:
|
| 483 |
+
x0_pos = Boolean(x[0] > 0)
|
| 484 |
+
x1_pos = Boolean(x[1] > 0)
|
| 485 |
+
relu_x = relu(x)
|
| 486 |
+
dout_y = cute.arch.mul_packed_f32x2(dout, y)
|
| 487 |
+
dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
|
| 488 |
+
dy = cute.arch.mul_packed_f32x2(dout, relu_x)
|
| 489 |
+
reglu_out = cute.arch.mul_packed_f32x2(relu_x, y)
|
| 490 |
+
return dx, dy, reglu_out
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
@dsl_user_op
|
| 494 |
+
def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
| 495 |
+
"""GeGLU: GELU Gated Linear Unit
|
| 496 |
+
geglu(x, y) = gelu(x) * y
|
| 497 |
+
Uses the tanh approximation of GELU
|
| 498 |
+
"""
|
| 499 |
+
if const_expr(not isinstance(x, tuple)):
|
| 500 |
+
return gelu_tanh_approx(x) * y
|
| 501 |
+
else:
|
| 502 |
+
return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y)
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
@dsl_user_op
|
| 506 |
+
def dgeglu(
|
| 507 |
+
x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
|
| 508 |
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
|
| 509 |
+
"""
|
| 510 |
+
GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
|
| 511 |
+
Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
|
| 512 |
+
Returns: (dx, dy, geglu_out) where:
|
| 513 |
+
- dx = dout * y * d_gelu(x)
|
| 514 |
+
- dy = dout * gelu(x)
|
| 515 |
+
- geglu_out = gelu(x) * y
|
| 516 |
+
"""
|
| 517 |
+
if const_expr(not isinstance(x, tuple)):
|
| 518 |
+
# Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
|
| 519 |
+
dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
|
| 520 |
+
# Compute gradients for geglu
|
| 521 |
+
dx = dgelu_x_dout * y
|
| 522 |
+
dy = gelu_x * dout
|
| 523 |
+
geglu_out = gelu_x * y
|
| 524 |
+
return dx, dy, geglu_out
|
| 525 |
+
else:
|
| 526 |
+
# Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
|
| 527 |
+
dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
|
| 528 |
+
# Compute gradients for geglu
|
| 529 |
+
dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y)
|
| 530 |
+
dy = cute.arch.mul_packed_f32x2(gelu_x, dout)
|
| 531 |
+
geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y)
|
| 532 |
+
return dx, dy, geglu_out
|
build/torch211-cxx11-cu128-x86_64-linux/quack/compile_utils.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]:
|
| 9 |
+
if leading_dim < 0:
|
| 10 |
+
leading_dim = len(shape) + leading_dim
|
| 11 |
+
if dtype is None:
|
| 12 |
+
return None
|
| 13 |
+
stride = tuple(
|
| 14 |
+
cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1
|
| 15 |
+
for i in range(len(shape))
|
| 16 |
+
)
|
| 17 |
+
return cute.runtime.make_fake_tensor(
|
| 18 |
+
dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8
|
| 19 |
+
)
|
build/torch211-cxx11-cu128-x86_64-linux/quack/copy_utils.py
ADDED
|
@@ -0,0 +1,890 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from typing import Optional, Type, Tuple, Callable, Sequence
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
import cutlass
|
| 8 |
+
import cutlass.cute as cute
|
| 9 |
+
|
| 10 |
+
from cutlass import Int32, Int16, Boolean, const_expr
|
| 11 |
+
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
|
| 12 |
+
from cutlass.cute.nvgpu.tcgen05.mma import CtaGroup # noqa
|
| 13 |
+
from cutlass.cutlass_dsl import dsl_user_op
|
| 14 |
+
import cutlass.pipeline
|
| 15 |
+
from cutlass._mlir.dialects import llvm
|
| 16 |
+
from cutlass._mlir import ir
|
| 17 |
+
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
Sm100MmaPeerBitMask = 0xFEFFFFFF
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dsl_user_op
|
| 24 |
+
def cvt_copy(
|
| 25 |
+
tiled_copy: cute.TiledCopy,
|
| 26 |
+
src: cute.Tensor,
|
| 27 |
+
dst: cute.Tensor,
|
| 28 |
+
*,
|
| 29 |
+
pred: Optional[cute.Tensor] = None,
|
| 30 |
+
retile: bool = False,
|
| 31 |
+
loc=None,
|
| 32 |
+
ip=None,
|
| 33 |
+
**kwargs,
|
| 34 |
+
) -> None:
|
| 35 |
+
assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
|
| 36 |
+
if const_expr(src.element_type != dst.element_type):
|
| 37 |
+
src_cvt = cute.make_fragment_like(src, dst.element_type)
|
| 38 |
+
src_cvt.store(src.load().to(dst.element_type))
|
| 39 |
+
src = src_cvt
|
| 40 |
+
if const_expr(retile):
|
| 41 |
+
src = tiled_copy.retile(src)
|
| 42 |
+
cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dsl_user_op
|
| 46 |
+
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
|
| 47 |
+
dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip)
|
| 48 |
+
cute.autovec_copy(src, dst, loc=loc, ip=ip)
|
| 49 |
+
return dst
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dsl_user_op
|
| 53 |
+
def load_s2r_retile(
|
| 54 |
+
tiled_copy: cute.TiledCopy,
|
| 55 |
+
src: cute.Tensor,
|
| 56 |
+
dst_shape: cute.Tensor | cute.Shape,
|
| 57 |
+
*,
|
| 58 |
+
loc=None,
|
| 59 |
+
ip=None,
|
| 60 |
+
) -> cute.Tensor:
|
| 61 |
+
# Will also accept dst_shape being a tensor, in which case we write into that tensor
|
| 62 |
+
if const_expr(not isinstance(dst_shape, cute.Tensor)):
|
| 63 |
+
dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip)
|
| 64 |
+
else:
|
| 65 |
+
dst = dst_shape
|
| 66 |
+
cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
|
| 67 |
+
return dst
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@dsl_user_op
|
| 71 |
+
def get_copy_atom(
|
| 72 |
+
dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
|
| 73 |
+
) -> cute.CopyAtom:
|
| 74 |
+
num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
|
| 75 |
+
copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
|
| 76 |
+
return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dsl_user_op
|
| 80 |
+
def copy(
|
| 81 |
+
src: cute.Tensor,
|
| 82 |
+
dst: cute.Tensor,
|
| 83 |
+
*,
|
| 84 |
+
pred: Optional[cute.Tensor] = None,
|
| 85 |
+
is_async: bool = False,
|
| 86 |
+
loc=None,
|
| 87 |
+
ip=None,
|
| 88 |
+
**kwargs,
|
| 89 |
+
) -> None:
|
| 90 |
+
num_copy_elems = src.shape[0][0]
|
| 91 |
+
copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
|
| 92 |
+
cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def tiled_copy_1d(
|
| 96 |
+
dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
|
| 97 |
+
) -> cute.TiledCopy:
|
| 98 |
+
num_copy_bits = num_copy_elems * dtype.width
|
| 99 |
+
copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
|
| 100 |
+
copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
|
| 101 |
+
thr_layout = cute.make_layout(num_threads)
|
| 102 |
+
val_layout = cute.make_layout(num_copy_elems)
|
| 103 |
+
return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def tiled_copy_2d(
|
| 107 |
+
dtype: Type[cutlass.Numeric],
|
| 108 |
+
threads_per_row: int,
|
| 109 |
+
num_threads: int,
|
| 110 |
+
num_copy_elems: int = 1,
|
| 111 |
+
is_async: bool = False,
|
| 112 |
+
) -> cute.TiledCopy:
|
| 113 |
+
num_copy_bits = num_copy_elems * dtype.width
|
| 114 |
+
copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
|
| 115 |
+
copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
|
| 116 |
+
assert num_threads % threads_per_row == 0
|
| 117 |
+
thr_layout = cute.make_ordered_layout(
|
| 118 |
+
(num_threads // threads_per_row, threads_per_row),
|
| 119 |
+
order=(1, 0),
|
| 120 |
+
)
|
| 121 |
+
val_layout = cute.make_layout((1, num_copy_elems))
|
| 122 |
+
return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@cute.jit
|
| 126 |
+
def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
|
| 127 |
+
# Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
|
| 128 |
+
tApA = cute.make_rmem_tensor(
|
| 129 |
+
cute.make_layout(
|
| 130 |
+
(cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
|
| 131 |
+
stride=(cute.size(tAcA, mode=[2]), 0, 1),
|
| 132 |
+
),
|
| 133 |
+
Boolean,
|
| 134 |
+
)
|
| 135 |
+
for rest_v in cutlass.range_constexpr(tApA.shape[0]):
|
| 136 |
+
for rest_k in cutlass.range_constexpr(tApA.shape[2]):
|
| 137 |
+
tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
|
| 138 |
+
return tApA
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# def tiled_copy_2d(
|
| 142 |
+
# dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
|
| 143 |
+
# ) -> cute.TiledCopy:
|
| 144 |
+
# num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
|
| 145 |
+
# copy_elems = num_copy_bits // dtype.width
|
| 146 |
+
# copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
|
| 147 |
+
# copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
|
| 148 |
+
# gmem_threads_per_row = major_mode_size // copy_elems
|
| 149 |
+
# assert num_threads % gmem_threads_per_row == 0
|
| 150 |
+
# thr_layout = cute.make_ordered_layout(
|
| 151 |
+
# (num_threads // gmem_threads_per_row, gmem_threads_per_row),
|
| 152 |
+
# order=(1, 0),
|
| 153 |
+
# )
|
| 154 |
+
# val_layout = cute.make_layout((1, copy_elems))
|
| 155 |
+
# return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]:
|
| 159 |
+
"""Extract swizzle parameters from a pointer's swizzle_type.
|
| 160 |
+
|
| 161 |
+
The swizzle_type string has the form '!cute.swizzle<"S<b,m,s>">' where
|
| 162 |
+
b, m, s are the swizzle parameters (bits, base, shift).
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
A cute.Swizzle object constructed from the extracted parameters
|
| 166 |
+
|
| 167 |
+
Raises:
|
| 168 |
+
ValueError: If the swizzle_type string cannot be parsed
|
| 169 |
+
"""
|
| 170 |
+
# Ideally there should be a better API to get swizzle parameters, but we'll just parse
|
| 171 |
+
# the string here.
|
| 172 |
+
swizzle_str = str(ptr.type.swizzle_type)
|
| 173 |
+
# Extract the inner part "S<b,m,s>"
|
| 174 |
+
match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
|
| 175 |
+
if match:
|
| 176 |
+
b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3))
|
| 177 |
+
return b, m, s
|
| 178 |
+
else:
|
| 179 |
+
raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
|
| 183 |
+
bit_msk = (1 << b) - 1
|
| 184 |
+
yyy_msk = bit_msk << (m + s)
|
| 185 |
+
return ptr_int ^ ((ptr_int & yyy_msk) >> s)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def swizzle_ptr(ptr: cute.Pointer):
|
| 189 |
+
b, m, s = parse_swizzle_from_pointer(ptr)
|
| 190 |
+
ptr_int = swizzle_int(ptr.toint(), b, m, s)
|
| 191 |
+
return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
|
| 195 |
+
outer = tensor.layout
|
| 196 |
+
width = tensor.element_type.width
|
| 197 |
+
inner = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator))
|
| 198 |
+
# Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for
|
| 199 |
+
# for 16 bits and <3, 2, 3> for 32 bits)
|
| 200 |
+
new_layout = cute.recast_layout(
|
| 201 |
+
width, 8, cute.make_composed_layout(inner, 0, cute.recast_layout(8, width, outer))
|
| 202 |
+
)
|
| 203 |
+
# recast_ptr to remove the pointer swizzle
|
| 204 |
+
return cute.make_tensor(cute.recast_ptr(tensor.iterator, dtype=tensor.element_type), new_layout)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def partition_D_position_independent(
|
| 208 |
+
thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
|
| 209 |
+
) -> cute.Tensor:
|
| 210 |
+
return cute.make_tensor(
|
| 211 |
+
swizzle_ptr(thr_copy.partition_D(tensor).iterator),
|
| 212 |
+
thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def partition_S_position_independent(
|
| 217 |
+
thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
|
| 218 |
+
) -> cute.Tensor:
|
| 219 |
+
return cute.make_tensor(
|
| 220 |
+
swizzle_ptr(thr_copy.partition_S(tensor).iterator),
|
| 221 |
+
thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
@dsl_user_op
|
| 226 |
+
def sm90_get_smem_load_op(
|
| 227 |
+
layout_c: cutlass.utils.LayoutEnum,
|
| 228 |
+
elem_ty_c: Type[cutlass.Numeric],
|
| 229 |
+
*,
|
| 230 |
+
loc=None,
|
| 231 |
+
ip=None,
|
| 232 |
+
) -> cute.CopyAtom:
|
| 233 |
+
"""
|
| 234 |
+
Selects the largest vectorized smem load atom available subject to constraint of gmem layout.
|
| 235 |
+
|
| 236 |
+
Parameters:
|
| 237 |
+
-----------
|
| 238 |
+
layout_c : LayoutEnum
|
| 239 |
+
The layout enum of the output tensor D.
|
| 240 |
+
|
| 241 |
+
elem_ty_c : Type[Numeric]
|
| 242 |
+
The element type for output tensor D.
|
| 243 |
+
|
| 244 |
+
Returns:
|
| 245 |
+
--------
|
| 246 |
+
Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters.
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
|
| 250 |
+
raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
|
| 251 |
+
is_m_major = layout_c.is_m_major_c()
|
| 252 |
+
if elem_ty_c.width == 16:
|
| 253 |
+
return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip)
|
| 254 |
+
else:
|
| 255 |
+
return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def get_smem_store_atom(
|
| 259 |
+
arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
|
| 260 |
+
) -> cute.CopyAtom:
|
| 261 |
+
if const_expr(arch < 90 or element_type.width != 16):
|
| 262 |
+
return cute.make_copy_atom(
|
| 263 |
+
cute.nvgpu.CopyUniversalOp(),
|
| 264 |
+
element_type,
|
| 265 |
+
num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
|
| 266 |
+
)
|
| 267 |
+
else:
|
| 268 |
+
return cute.make_copy_atom(
|
| 269 |
+
warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
|
| 270 |
+
element_type,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def get_smem_load_atom(
|
| 275 |
+
arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
|
| 276 |
+
) -> cute.CopyAtom:
|
| 277 |
+
if const_expr(arch < 90 or element_type.width != 16):
|
| 278 |
+
return cute.make_copy_atom(
|
| 279 |
+
cute.nvgpu.CopyUniversalOp(),
|
| 280 |
+
element_type,
|
| 281 |
+
num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
|
| 282 |
+
)
|
| 283 |
+
else:
|
| 284 |
+
return cute.make_copy_atom(
|
| 285 |
+
warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
|
| 286 |
+
element_type,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def get_smem_store_C(
|
| 291 |
+
tiled_mma: cute.TiledMma,
|
| 292 |
+
sC: cute.Tensor,
|
| 293 |
+
tidx: Int32,
|
| 294 |
+
arch: int,
|
| 295 |
+
transpose: bool = False,
|
| 296 |
+
position_independent=False,
|
| 297 |
+
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
|
| 298 |
+
dtype = sC.element_type
|
| 299 |
+
copy_atom = get_smem_store_atom(arch, dtype, transpose)
|
| 300 |
+
tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
|
| 301 |
+
thr_copy = tiled_copy.get_slice(tidx)
|
| 302 |
+
if const_expr(not position_independent):
|
| 303 |
+
tRS_sC = thr_copy.partition_D(sC)
|
| 304 |
+
else:
|
| 305 |
+
tRS_sC = partition_D_position_independent(thr_copy, sC)
|
| 306 |
+
|
| 307 |
+
def copy_fn(src: cute.Tensor, dst_idx: Optional[Int32] = None, **new_kwargs):
|
| 308 |
+
dst_tensor = tRS_sC if const_expr(dst_idx is None) else tRS_sC[None, None, None, dst_idx]
|
| 309 |
+
cvt_copy(tiled_copy, src, dst_tensor, retile=True, **new_kwargs)
|
| 310 |
+
|
| 311 |
+
return copy_fn, thr_copy, tRS_sC
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def get_smem_load_C(
|
| 315 |
+
tiled_mma: cute.TiledMma,
|
| 316 |
+
sC: cute.Tensor,
|
| 317 |
+
tidx: Int32,
|
| 318 |
+
arch: int,
|
| 319 |
+
transpose: bool = False,
|
| 320 |
+
position_independent=False,
|
| 321 |
+
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
|
| 322 |
+
dtype = sC.element_type
|
| 323 |
+
copy_atom = get_smem_load_atom(arch, dtype, transpose)
|
| 324 |
+
tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
|
| 325 |
+
thr_copy = tiled_copy.get_slice(tidx)
|
| 326 |
+
if const_expr(not position_independent):
|
| 327 |
+
tSR_sC = thr_copy.partition_S(sC)
|
| 328 |
+
else:
|
| 329 |
+
tSR_sC = partition_S_position_independent(thr_copy, sC)
|
| 330 |
+
copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
|
| 331 |
+
thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
|
| 332 |
+
tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape
|
| 333 |
+
|
| 334 |
+
def copy_fn(src_idx: Optional[Int32] = None, **new_kwargs):
|
| 335 |
+
src_tensor = tSR_sC if const_expr(src_idx is None) else tSR_sC[None, None, None, src_idx]
|
| 336 |
+
return load_s2r_retile(tiled_copy, src_tensor, dst_shape=tRS_shape, **new_kwargs)
|
| 337 |
+
|
| 338 |
+
return copy_fn, thr_copy, tSR_sC
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def epilog_smem_copy_atom(
|
| 342 |
+
tiled_mma: cute.TiledMma, epi_tile: cute.Shape, transpose: bool = False
|
| 343 |
+
) -> cute.TiledCopy:
|
| 344 |
+
copy_atom_C = cute.make_copy_atom(
|
| 345 |
+
warp.StMatrix8x8x16bOp(transpose, num_matrices=4 if epi_tile[1] % 16 == 0 else 2),
|
| 346 |
+
cutlass.Float16, # this is just to get the right source layout
|
| 347 |
+
)
|
| 348 |
+
tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
|
| 349 |
+
return tiled_copy_C_atom
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def get_smem_store_epi(
|
| 353 |
+
tiled_mma: cute.TiledMma,
|
| 354 |
+
epi_tile: cute.Shape,
|
| 355 |
+
sC: Optional[cute.Tensor],
|
| 356 |
+
tidx: Int32,
|
| 357 |
+
arch: int,
|
| 358 |
+
transpose: bool = False,
|
| 359 |
+
position_independent=False,
|
| 360 |
+
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
| 361 |
+
dtype = sC.element_type if const_expr(sC is not None) else cutlass.Float16
|
| 362 |
+
tiled_copy_C_atom = epilog_smem_copy_atom(tiled_mma, epi_tile)
|
| 363 |
+
copy_atom = get_smem_store_atom(arch, dtype, transpose)
|
| 364 |
+
tiled_copy = cute.make_tiled_copy_S(copy_atom, tiled_copy_C_atom)
|
| 365 |
+
thr_copy = tiled_copy.get_slice(tidx)
|
| 366 |
+
tRS_sC = None
|
| 367 |
+
if const_expr(sC is not None):
|
| 368 |
+
if const_expr(not position_independent):
|
| 369 |
+
tRS_sC = thr_copy.partition_D(sC)
|
| 370 |
+
else:
|
| 371 |
+
tRS_sC = partition_D_position_independent(thr_copy, sC)
|
| 372 |
+
sC_shape = sC.shape[:2] if sC is not None else epi_tile
|
| 373 |
+
# (R2S, R2S_M, R2S_N, PIPE_C)
|
| 374 |
+
tRS_rC_shape = thr_copy.partition_S(cute.make_identity_tensor(sC_shape)).shape
|
| 375 |
+
tRS_rC = cute.make_rmem_tensor(tRS_rC_shape, tiled_mma.op.acc_dtype)
|
| 376 |
+
|
| 377 |
+
def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
|
| 378 |
+
cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], **new_kwargs)
|
| 379 |
+
|
| 380 |
+
return copy_fn if const_expr(sC is not None) else None, thr_copy, tRS_sC, tRS_rC
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def get_smem_store_A(
|
| 384 |
+
tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
|
| 385 |
+
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
|
| 386 |
+
dtype = sA.element_type
|
| 387 |
+
transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
|
| 388 |
+
copy_atom = get_smem_store_atom(arch, dtype, transpose)
|
| 389 |
+
tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
|
| 390 |
+
thr_copy = tiled_copy.get_slice(tidx)
|
| 391 |
+
if const_expr(not position_independent):
|
| 392 |
+
tRS_sA = thr_copy.partition_D(sA)
|
| 393 |
+
else:
|
| 394 |
+
tRS_sA = partition_D_position_independent(thr_copy, sA)
|
| 395 |
+
|
| 396 |
+
def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
|
| 397 |
+
cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs)
|
| 398 |
+
|
| 399 |
+
return copy_fn, thr_copy, tRS_sA
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def get_smem_load_A(
|
| 403 |
+
tiled_mma: cute.TiledMma,
|
| 404 |
+
sA: cute.Tensor,
|
| 405 |
+
tidx: Int32,
|
| 406 |
+
arch: int,
|
| 407 |
+
with_dst_tensor: bool = False,
|
| 408 |
+
position_independent=False,
|
| 409 |
+
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
|
| 410 |
+
dtype = sA.element_type
|
| 411 |
+
transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
|
| 412 |
+
copy_atom = get_smem_load_atom(arch, dtype, transpose)
|
| 413 |
+
tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
|
| 414 |
+
thr_copy = tiled_copy.get_slice(tidx)
|
| 415 |
+
if const_expr(not position_independent):
|
| 416 |
+
tSR_sA = thr_copy.partition_S(sA)
|
| 417 |
+
else:
|
| 418 |
+
tSR_sA = partition_S_position_independent(thr_copy, sA)
|
| 419 |
+
tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])
|
| 420 |
+
|
| 421 |
+
def copy_fn(src_idx: Int32, **new_kwargs):
|
| 422 |
+
return load_s2r_retile(
|
| 423 |
+
tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs):
|
| 427 |
+
return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs)
|
| 428 |
+
|
| 429 |
+
return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
@dsl_user_op
|
| 433 |
+
def cpasync_reduce_bulk_add_f32(
|
| 434 |
+
smem_ptr: cute.Pointer,
|
| 435 |
+
gmem_ptr: cute.Pointer,
|
| 436 |
+
store_bytes: int | Int32,
|
| 437 |
+
*,
|
| 438 |
+
loc=None,
|
| 439 |
+
ip=None,
|
| 440 |
+
):
|
| 441 |
+
smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 442 |
+
# cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST
|
| 443 |
+
llvm.inline_asm(
|
| 444 |
+
None,
|
| 445 |
+
[gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()],
|
| 446 |
+
"cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;",
|
| 447 |
+
"l,r,r",
|
| 448 |
+
# [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()],
|
| 449 |
+
# "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;",
|
| 450 |
+
# "l,r,r,l",
|
| 451 |
+
has_side_effects=True,
|
| 452 |
+
is_align_stack=False,
|
| 453 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
@dsl_user_op
|
| 458 |
+
def get_tma_desc_addr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer:
|
| 459 |
+
"""
|
| 460 |
+
Get the address of the TMA descriptor embedded in a TMA Copy Atom.
|
| 461 |
+
|
| 462 |
+
Extracts the constant memory address of the TMA descriptor for use with
|
| 463 |
+
custom PTX instructions.
|
| 464 |
+
|
| 465 |
+
:param tma_atom: TMA Copy Atom from make_tiled_tma_atom
|
| 466 |
+
:return: Pointer to TMA descriptor in constant memory
|
| 467 |
+
|
| 468 |
+
Example:
|
| 469 |
+
>>> desc_ptr = get_tma_descriptor_address(tma_atom)
|
| 470 |
+
"""
|
| 471 |
+
exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip)
|
| 472 |
+
tma_desc_ptr_type = ir.Type.parse(
|
| 473 |
+
"!cute.ptr<!cute_nvgpu.tma_descriptor_tiled, generic, align<128>>"
|
| 474 |
+
)
|
| 475 |
+
return _cute_nvgpu_ir.get_tma_desc_addr(tma_desc_ptr_type, exec_atom, loc=loc, ip=ip)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
@dsl_user_op
|
| 479 |
+
def tma_gather4_load(
|
| 480 |
+
tma_desc_ptr: cute.Pointer,
|
| 481 |
+
dst_smem_ptr: cute.Pointer,
|
| 482 |
+
mbarrier_ptr: cute.Pointer,
|
| 483 |
+
col_idx: Int32,
|
| 484 |
+
row_indices: Sequence[Int32],
|
| 485 |
+
*,
|
| 486 |
+
num_cta: int = 1,
|
| 487 |
+
multicast_mask=None,
|
| 488 |
+
loc=None,
|
| 489 |
+
ip=None,
|
| 490 |
+
) -> None:
|
| 491 |
+
"""
|
| 492 |
+
Perform TMA gather4 load from global memory to shared memory.
|
| 493 |
+
|
| 494 |
+
Issues PTX instruction:
|
| 495 |
+
cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes
|
| 496 |
+
[dstMem], [tensorMap, {col_idx, row0, row1, row2, row3}], [smem_bar];
|
| 497 |
+
|
| 498 |
+
This loads 4 rows (specified by row_indices) from a 2D tensor at the given
|
| 499 |
+
column index into shared memory, using the TMA descriptor.
|
| 500 |
+
|
| 501 |
+
:param tma_desc_ptr: Pointer to TMA descriptor in constant memory (128-byte aligned)
|
| 502 |
+
:type tma_desc_ptr: Pointer
|
| 503 |
+
:param dst_smem_ptr: Destination address in shared memory
|
| 504 |
+
:type dst_smem_ptr: Pointer
|
| 505 |
+
:param mbarrier_ptr: Pointer to mbarrier in shared memory for completion tracking
|
| 506 |
+
:type mbarrier_ptr: Pointer
|
| 507 |
+
:param col_idx: Column index
|
| 508 |
+
:type col_idx: Int32
|
| 509 |
+
:param row_indices: Sequence of exactly 4 row indices
|
| 510 |
+
:type row_indices: Sequence[Int32]
|
| 511 |
+
:param num_cta: Number of CTAs participating (default: 1)
|
| 512 |
+
:type num_cta: int
|
| 513 |
+
:param multicast_mask: Optional multicast mask
|
| 514 |
+
:type multicast_mask: Int16
|
| 515 |
+
|
| 516 |
+
Requirements:
|
| 517 |
+
- row_indices must contain exactly 4 elements
|
| 518 |
+
- Compute capability >= SM_100 (Blackwell)
|
| 519 |
+
- TMA descriptor must be properly initialized for 2D tensor
|
| 520 |
+
|
| 521 |
+
Example:
|
| 522 |
+
>>> from cutlass.cute.nvgpu import cpasync
|
| 523 |
+
>>> from cutlass.cute import core
|
| 524 |
+
>>>
|
| 525 |
+
>>> # Create TMA descriptor
|
| 526 |
+
>>> tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(...)
|
| 527 |
+
>>> tma_desc_ptr = get_tma_descriptor_address(tma_atom)
|
| 528 |
+
>>>
|
| 529 |
+
>>> # Compute indices (typically from kernel logic)
|
| 530 |
+
>>> col_idx = core.get(...) or 5 # Int32 value
|
| 531 |
+
>>> row_indices = [core.get(...) for _ in range(4)] # 4 Int32 values
|
| 532 |
+
>>>
|
| 533 |
+
>>> # Gather 4 rows at computed column
|
| 534 |
+
>>> tma_gather4_load(
|
| 535 |
+
... tma_desc_ptr=tma_desc_ptr,
|
| 536 |
+
... dst_smem_ptr=smem_ptr,
|
| 537 |
+
... mbarrier_ptr=barrier_ptr,
|
| 538 |
+
... col_idx=col_idx,
|
| 539 |
+
... row_indices=row_indices
|
| 540 |
+
... )
|
| 541 |
+
"""
|
| 542 |
+
if len(row_indices) != 4:
|
| 543 |
+
raise ValueError(f"gather4 requires exactly 4 row indices, got {len(row_indices)}")
|
| 544 |
+
col_val = Int32(col_idx).ir_value()
|
| 545 |
+
row_vals = [Int32(row_idx).ir_value() for row_idx in row_indices]
|
| 546 |
+
# Convert pointers to integer addresses
|
| 547 |
+
desc_addr = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 548 |
+
dst_addr = dst_smem_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 549 |
+
mbar_addr = mbarrier_ptr.toint(loc=loc, ip=ip)
|
| 550 |
+
if num_cta > 1:
|
| 551 |
+
# Executed by both CTAs. Set peer bit to 0 so that the
|
| 552 |
+
# transaction bytes will update CTA0's barrier.
|
| 553 |
+
mbar_addr = mbar_addr & Sm100MmaPeerBitMask
|
| 554 |
+
mbar_addr = mbar_addr.ir_value()
|
| 555 |
+
# Handle multicast_mask - may already be ir.Value or Python int
|
| 556 |
+
multicast_mask_val = None
|
| 557 |
+
if multicast_mask is not None:
|
| 558 |
+
multicast_mask_val = Int16(multicast_mask).ir_value()
|
| 559 |
+
assert multicast_mask_val is None, "multicast is not supported yet"
|
| 560 |
+
# Emit inline PTX for TMA gather4
|
| 561 |
+
# PTX: cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes
|
| 562 |
+
# [dstMem], [tensorMap, {col, row0, row1, row2, row3}], [smem_bar];
|
| 563 |
+
ptx = (
|
| 564 |
+
f"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::{num_cta} "
|
| 565 |
+
"[$0], [$1, {$2, $3, $4, $5, $6}], [$7];"
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
llvm.inline_asm(
|
| 569 |
+
None,
|
| 570 |
+
[
|
| 571 |
+
dst_addr,
|
| 572 |
+
desc_addr,
|
| 573 |
+
col_val,
|
| 574 |
+
row_vals[0],
|
| 575 |
+
row_vals[1],
|
| 576 |
+
row_vals[2],
|
| 577 |
+
row_vals[3],
|
| 578 |
+
mbar_addr,
|
| 579 |
+
],
|
| 580 |
+
ptx,
|
| 581 |
+
"r,l,r,r,r,r,r,r", # constraints: register, long, 6x register
|
| 582 |
+
has_side_effects=True,
|
| 583 |
+
is_align_stack=False,
|
| 584 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 585 |
+
loc=loc,
|
| 586 |
+
ip=ip,
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def cpasync_bulk_get_copy_fn(
|
| 591 |
+
src_tensor: cute.Tensor,
|
| 592 |
+
dst_tensor: cute.Tensor,
|
| 593 |
+
single_stage: bool = False,
|
| 594 |
+
**kwargs,
|
| 595 |
+
) -> Callable:
|
| 596 |
+
group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0))
|
| 597 |
+
group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0))
|
| 598 |
+
# ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
|
| 599 |
+
src = cute.group_modes(src_tensor, 0, group_rank_src)
|
| 600 |
+
dst = cute.group_modes(dst_tensor, 0, group_rank_dst)
|
| 601 |
+
|
| 602 |
+
def copy_bulk(src_idx, dst_idx, tma_bar_ptr: cute.Pointer, **new_kwargs):
|
| 603 |
+
atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type)
|
| 604 |
+
with cute.arch.elect_one():
|
| 605 |
+
cute.copy(
|
| 606 |
+
atom,
|
| 607 |
+
src[None, src_idx],
|
| 608 |
+
dst[None, dst_idx],
|
| 609 |
+
mbar_ptr=tma_bar_ptr,
|
| 610 |
+
**new_kwargs,
|
| 611 |
+
**kwargs,
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
def copy_bulk_single_stage(tma_bar_ptr: cute.Pointer, **new_kwargs):
|
| 615 |
+
atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type)
|
| 616 |
+
with cute.arch.elect_one():
|
| 617 |
+
cute.copy(atom, src, dst, mbar_ptr=tma_bar_ptr, **new_kwargs, **kwargs)
|
| 618 |
+
|
| 619 |
+
return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
def tma_get_copy_fn(
|
| 623 |
+
atom: cute.CopyAtom,
|
| 624 |
+
cta_coord: cute.Coord,
|
| 625 |
+
cta_layout: cute.Layout,
|
| 626 |
+
src_tensor: cute.Tensor,
|
| 627 |
+
dst_tensor: cute.Tensor,
|
| 628 |
+
filter_zeros: bool = False,
|
| 629 |
+
single_stage: bool = False,
|
| 630 |
+
**kwargs,
|
| 631 |
+
) -> Callable:
|
| 632 |
+
src_is_smem = const_expr(
|
| 633 |
+
isinstance(src_tensor.iterator, cute.Pointer)
|
| 634 |
+
and src_tensor.memspace == cute.AddressSpace.smem
|
| 635 |
+
)
|
| 636 |
+
smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
|
| 637 |
+
group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0))
|
| 638 |
+
group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0))
|
| 639 |
+
# ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
|
| 640 |
+
s, g = cpasync.tma_partition(
|
| 641 |
+
atom,
|
| 642 |
+
cta_coord,
|
| 643 |
+
cta_layout,
|
| 644 |
+
cute.group_modes(smem_tensor, 0, group_rank_smem),
|
| 645 |
+
cute.group_modes(gmem_tensor, 0, group_rank_gmem),
|
| 646 |
+
)
|
| 647 |
+
if const_expr(filter_zeros):
|
| 648 |
+
s = cute.filter_zeros(s)
|
| 649 |
+
g = cute.filter_zeros(g)
|
| 650 |
+
src, dst = (s, g) if src_is_smem else (g, s)
|
| 651 |
+
|
| 652 |
+
def copy_tma(src_idx, dst_idx, **new_kwargs):
|
| 653 |
+
cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)
|
| 654 |
+
|
| 655 |
+
def copy_tma_single_stage(**new_kwargs):
|
| 656 |
+
cute.copy(atom, src, dst, **new_kwargs, **kwargs)
|
| 657 |
+
|
| 658 |
+
return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
|
| 662 |
+
def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
|
| 663 |
+
copy(
|
| 664 |
+
src_idx=src_idx,
|
| 665 |
+
dst_idx=producer_state.index,
|
| 666 |
+
tma_bar_ptr=pipeline.producer_get_barrier(producer_state),
|
| 667 |
+
**new_kwargs,
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
return copy_fn
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
@cute.jit
|
| 674 |
+
def gather_m_get_copy_fn(
|
| 675 |
+
thr_copy_A: cute.ThrCopy,
|
| 676 |
+
mA: cute.Tensor, # (whatever, K)
|
| 677 |
+
sA: cute.Tensor, # (tile_M, tile_K, STAGE)
|
| 678 |
+
gsAIdx: cute.Tensor, # (tile_M), either gmem or smem
|
| 679 |
+
limit_m: Int32,
|
| 680 |
+
limit_k: Int32,
|
| 681 |
+
) -> Callable:
|
| 682 |
+
tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
|
| 683 |
+
tAsA = thr_copy_A.partition_D(sA)
|
| 684 |
+
# k-major
|
| 685 |
+
assert tAsA.shape[2] == 1
|
| 686 |
+
tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
|
| 687 |
+
|
| 688 |
+
is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
|
| 689 |
+
if const_expr(not is_even_m_smem):
|
| 690 |
+
limit_m = min(limit_m, tile_shape_mk[0])
|
| 691 |
+
elems_per_load = cute.size(tAsA.shape[0][0])
|
| 692 |
+
cA = cute.make_identity_tensor(tile_shape_mk)
|
| 693 |
+
tAcA = thr_copy_A.partition_S(cA)
|
| 694 |
+
t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
|
| 695 |
+
# Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
|
| 696 |
+
# since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
|
| 697 |
+
# This is so that when we do the comparison, t0AcA is known at compile time.
|
| 698 |
+
limit_m = limit_m - tAcA[0][0]
|
| 699 |
+
limit_k = limit_k - tAcA[0][1]
|
| 700 |
+
# Read and cache indices for A
|
| 701 |
+
rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
|
| 702 |
+
cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
|
| 703 |
+
tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
|
| 704 |
+
for m in cutlass.range(rows_per_thread, unroll_full=True):
|
| 705 |
+
tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
|
| 706 |
+
m_idx = cute.make_rmem_tensor(rows_per_thread, Int32)
|
| 707 |
+
for m in cutlass.range(rows_per_thread, unroll_full=True):
|
| 708 |
+
row_idx = tAcA[0, m, 0][0]
|
| 709 |
+
if tApA_m[m]:
|
| 710 |
+
m_idx[m] = gsAIdx[row_idx]
|
| 711 |
+
else:
|
| 712 |
+
m_idx[m] = 0 # It's ok to load row 0 in the case of OOB
|
| 713 |
+
|
| 714 |
+
mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1]))
|
| 715 |
+
|
| 716 |
+
def copy_fn(src_idx, dst_idx, pred: bool = False):
|
| 717 |
+
tApA_k = None
|
| 718 |
+
if const_expr(pred):
|
| 719 |
+
tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
|
| 720 |
+
limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
|
| 721 |
+
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
| 722 |
+
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
| 723 |
+
mA_cur = mA_k[None, (None, src_idx)]
|
| 724 |
+
for m in cutlass.range_constexpr(tAcA.shape[1]):
|
| 725 |
+
# cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape
|
| 726 |
+
# ((elems_per_load), thread_per_row)
|
| 727 |
+
# But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA
|
| 728 |
+
# So we append 1s to the last dimension and then do tiled_divide, then slice.
|
| 729 |
+
mA_row = cute.tiled_divide(
|
| 730 |
+
cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
|
| 731 |
+
)[None, None, 0]
|
| 732 |
+
if const_expr(is_even_m_smem) or tApA_m[m]:
|
| 733 |
+
# There's only 1 load per row
|
| 734 |
+
assert cute.size(tAcA.shape, mode=[2]) == 1
|
| 735 |
+
ki = tAcA[0, 0, 0][1] // elems_per_load
|
| 736 |
+
cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k)
|
| 737 |
+
|
| 738 |
+
return copy_fn
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
@cute.jit
|
| 742 |
+
def gather_k_get_copy_fn(
|
| 743 |
+
thr_copy_A: cute.ThrCopy,
|
| 744 |
+
mA: cute.Tensor, # (tile_M, whatever)
|
| 745 |
+
sA: cute.Tensor, # (tile_M, tile_K, STAGE)
|
| 746 |
+
gsAIdx: cute.Tensor, # (tile_K, RestK), either gmem or smem
|
| 747 |
+
limit_m: Int32,
|
| 748 |
+
limit_k: Int32,
|
| 749 |
+
) -> Callable:
|
| 750 |
+
gAIdx, sAIdx = None, None
|
| 751 |
+
if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem):
|
| 752 |
+
gAIdx = gsAIdx
|
| 753 |
+
else:
|
| 754 |
+
assert gsAIdx.memspace == cute.AddressSpace.smem
|
| 755 |
+
sAIdx = gsAIdx
|
| 756 |
+
tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
|
| 757 |
+
# (atom_v, CPY_M, 1, STAGE)
|
| 758 |
+
tAsA = thr_copy_A.partition_D(sA)
|
| 759 |
+
# m-major
|
| 760 |
+
tAsA = cute.group_modes(tAsA, 0, 3)
|
| 761 |
+
|
| 762 |
+
is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
|
| 763 |
+
if const_expr(not is_even_m_smem):
|
| 764 |
+
limit_m = min(limit_m, tile_shape_mk[0])
|
| 765 |
+
elems_per_load = cute.size(tAsA.shape[0][0])
|
| 766 |
+
cA = cute.make_identity_tensor(tile_shape_mk)
|
| 767 |
+
tAcA = thr_copy_A.partition_S(cA)
|
| 768 |
+
t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
|
| 769 |
+
# Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
|
| 770 |
+
# since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
|
| 771 |
+
# This is so that when we do the comparison, t0AcA is known at compile time.
|
| 772 |
+
limit_m = limit_m - tAcA[0][0]
|
| 773 |
+
limit_k = limit_k - tAcA[0][1]
|
| 774 |
+
# Read and cache indices for A
|
| 775 |
+
rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
|
| 776 |
+
cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
|
| 777 |
+
tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
|
| 778 |
+
for m in cutlass.range(rows_per_thread, unroll_full=True):
|
| 779 |
+
tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
|
| 780 |
+
threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
|
| 781 |
+
# This is very convoluted but idk a better way
|
| 782 |
+
# for tile_M=128, flat_divide gives (8, 16, K),
|
| 783 |
+
# then logical_divide gives ((8, 1), (8, 2), K).
|
| 784 |
+
tidx = thr_copy_A.thr_idx
|
| 785 |
+
tAmA = cute.logical_divide(
|
| 786 |
+
cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
|
| 787 |
+
)[None, (tidx % threads_per_col, None), None] # ((8, 1), 2, K)
|
| 788 |
+
|
| 789 |
+
def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]:
|
| 790 |
+
# Prefetch mAIdx early, even before smem is free
|
| 791 |
+
tApA_k = None
|
| 792 |
+
if const_expr(pred):
|
| 793 |
+
tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
|
| 794 |
+
limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
|
| 795 |
+
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
| 796 |
+
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
| 797 |
+
gAIdx_cur = gAIdx[None, src_idx]
|
| 798 |
+
k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
|
| 799 |
+
for k in cutlass.range(cols_per_thread):
|
| 800 |
+
col_idx = tAcA[0, 0, k][1]
|
| 801 |
+
if const_expr(not pred):
|
| 802 |
+
k_idx[k] = gAIdx_cur[col_idx]
|
| 803 |
+
else:
|
| 804 |
+
if tApA_k[k]:
|
| 805 |
+
k_idx[k] = gAIdx_cur[col_idx]
|
| 806 |
+
else:
|
| 807 |
+
k_idx[k] = -1
|
| 808 |
+
return k_idx, tApA_k
|
| 809 |
+
|
| 810 |
+
def prefetch_from_smem_fn(
|
| 811 |
+
a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False
|
| 812 |
+
) -> Tuple[cute.Tensor, cute.Tensor]:
|
| 813 |
+
tApA_k = None
|
| 814 |
+
if const_expr(pred):
|
| 815 |
+
tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
|
| 816 |
+
limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
|
| 817 |
+
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
| 818 |
+
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
| 819 |
+
a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
|
| 820 |
+
sAIdx_cur = sAIdx[None, dst_idx]
|
| 821 |
+
k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
|
| 822 |
+
for k in cutlass.range(cols_per_thread):
|
| 823 |
+
col_idx = tAcA[0, 0, k][1]
|
| 824 |
+
k_idx[k] = sAIdx_cur[col_idx]
|
| 825 |
+
cute.arch.sync_warp()
|
| 826 |
+
with cute.arch.elect_one():
|
| 827 |
+
a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
|
| 828 |
+
return k_idx, tApA_k
|
| 829 |
+
|
| 830 |
+
def copy_fn(
|
| 831 |
+
src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False
|
| 832 |
+
):
|
| 833 |
+
k_idx, tApA_k = k_idx_tApA_k
|
| 834 |
+
tApA_k_pred = None
|
| 835 |
+
if const_expr(pred):
|
| 836 |
+
tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2) # (1, cols_per_thread)
|
| 837 |
+
for k in cutlass.range_constexpr(tAcA.shape[2]):
|
| 838 |
+
# copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2))
|
| 839 |
+
for m in cutlass.range_constexpr(tAcA.shape[1]):
|
| 840 |
+
if tApA_m[m]:
|
| 841 |
+
cute.copy(
|
| 842 |
+
thr_copy_A,
|
| 843 |
+
tAmA[None, m, k_idx[k]],
|
| 844 |
+
tAsA[(None, m, k), dst_idx],
|
| 845 |
+
pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k],
|
| 846 |
+
)
|
| 847 |
+
|
| 848 |
+
return copy_fn, prefetch_from_gmem_fn if const_expr(
|
| 849 |
+
gAIdx is not None
|
| 850 |
+
) else prefetch_from_smem_fn
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
@cute.jit
|
| 854 |
+
def gather_m_get_tma_copy_fn(
|
| 855 |
+
tma_atom: cute.CopyAtom,
|
| 856 |
+
mA: cute.Tensor, # (whatever, K)
|
| 857 |
+
sA: cute.Tensor, # ((4, 32), (64, 1), STAGE)
|
| 858 |
+
sAIdx: cute.Tensor, # (tile_M),
|
| 859 |
+
warp_idx: Int32,
|
| 860 |
+
num_warps: int,
|
| 861 |
+
num_cta: int = 1,
|
| 862 |
+
) -> Callable:
|
| 863 |
+
tile_M = cute.size(sAIdx, mode=[0])
|
| 864 |
+
tile_K = cute.size(sA[None, None, 0]) // tile_M
|
| 865 |
+
assert tile_M % 4 == 0
|
| 866 |
+
# cta_group = 1 if tma_atom.op.cta_group == CtaGroup.ONE else 2
|
| 867 |
+
cta_group = num_cta # Somehow all tma_atom has CtaGroup.ONE inside the kernel
|
| 868 |
+
|
| 869 |
+
copy_AIdx_s2r = cute.make_tiled_copy_tv(
|
| 870 |
+
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128),
|
| 871 |
+
cute.make_layout(num_warps), # thr_layout
|
| 872 |
+
cute.make_layout(4), # val_layout
|
| 873 |
+
)
|
| 874 |
+
warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx)
|
| 875 |
+
tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx)
|
| 876 |
+
# ((4, 1), 8, (64, 1), STAGE)
|
| 877 |
+
tSR_sA = warp_copy_AIdx_s2r.partition_S(sA)
|
| 878 |
+
tSR_rAIdx = load_s2r(tSR_sAIdx)
|
| 879 |
+
tma_desc_ptr = get_tma_desc_addr(tma_atom)
|
| 880 |
+
tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)
|
| 881 |
+
|
| 882 |
+
def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer):
|
| 883 |
+
col_idx = tile_K * src_idx
|
| 884 |
+
for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
|
| 885 |
+
row_indices = [tSR_rAIdx[v, m] for v in range(4)]
|
| 886 |
+
smem_ptr = tSR_sA[None, m, None, dst_idx].iterator
|
| 887 |
+
with cute.arch.elect_one():
|
| 888 |
+
tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices)
|
| 889 |
+
|
| 890 |
+
return copy_fn
|
build/torch211-cxx11-cu128-x86_64-linux/quack/cute_dsl_utils.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Tri Dao.
|
| 2 |
+
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
from dataclasses import dataclass, fields
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from triton.tools.disasm import extract
|
| 11 |
+
except ImportError:
|
| 12 |
+
extract = None
|
| 13 |
+
|
| 14 |
+
import cutlass
|
| 15 |
+
import cutlass.cute as cute
|
| 16 |
+
from cutlass import Int32, Int64, Float16, BFloat16, Float32
|
| 17 |
+
from cutlass.base_dsl.typing import JitArgument
|
| 18 |
+
from cutlass.cutlass_dsl import NumericMeta
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
|
| 25 |
+
cute_compile_og = cute.compile
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
torch2cute_dtype_map = {
|
| 29 |
+
torch.float16: Float16,
|
| 30 |
+
torch.bfloat16: BFloat16,
|
| 31 |
+
torch.float32: Float32,
|
| 32 |
+
torch.int32: Int32,
|
| 33 |
+
torch.int64: Int64,
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@lru_cache
|
| 38 |
+
def get_max_active_clusters(cluster_size):
|
| 39 |
+
return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@lru_cache
|
| 43 |
+
def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
|
| 44 |
+
return torch.cuda.get_device_capability(device)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
|
| 48 |
+
class ParamsBase:
|
| 49 |
+
def __extract_mlir_values__(self):
|
| 50 |
+
all_fields = [getattr(self, field.name) for field in fields(self)]
|
| 51 |
+
non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
|
| 52 |
+
values, self._values_pos = [], []
|
| 53 |
+
for obj in non_constexpr_fields:
|
| 54 |
+
obj_values = cutlass.extract_mlir_values(obj)
|
| 55 |
+
values += obj_values
|
| 56 |
+
self._values_pos.append(len(obj_values))
|
| 57 |
+
return values
|
| 58 |
+
|
| 59 |
+
def __new_from_mlir_values__(self, values):
|
| 60 |
+
all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
|
| 61 |
+
constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
|
| 62 |
+
non_constexpr_fields = {
|
| 63 |
+
n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
|
| 64 |
+
}
|
| 65 |
+
for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
|
| 66 |
+
non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
|
| 67 |
+
values = values[n_items:]
|
| 68 |
+
return self.__class__(**non_constexpr_fields, **constexpr_fields)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
|
| 72 |
+
class ArgumentsBase(JitArgument):
|
| 73 |
+
def __c_pointers__(self):
|
| 74 |
+
all_fields = [getattr(self, field.name) for field in fields(self)]
|
| 75 |
+
non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
|
| 76 |
+
c_ptrs = []
|
| 77 |
+
for obj in non_constexpr_fields:
|
| 78 |
+
if hasattr(obj, "__c_pointers__"):
|
| 79 |
+
c_ptrs.extend(obj.__c_pointers__())
|
| 80 |
+
return c_ptrs
|
| 81 |
+
|
| 82 |
+
def __get_mlir_types__(self):
|
| 83 |
+
all_fields = [getattr(self, field.name) for field in fields(self)]
|
| 84 |
+
non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
|
| 85 |
+
types, self._values_pos = [], []
|
| 86 |
+
for obj in non_constexpr_fields:
|
| 87 |
+
if hasattr(obj, "__get_mlir_types__"):
|
| 88 |
+
obj_types = obj.__get_mlir_types__()
|
| 89 |
+
types.extend(obj_types)
|
| 90 |
+
self._values_pos.append(len(obj_types))
|
| 91 |
+
else:
|
| 92 |
+
self._values_pos.append(0)
|
| 93 |
+
return types
|
| 94 |
+
|
| 95 |
+
def __new_from_mlir_values__(self, values):
|
| 96 |
+
all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
|
| 97 |
+
constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
|
| 98 |
+
non_constexpr_fields = {
|
| 99 |
+
n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
|
| 100 |
+
}
|
| 101 |
+
for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
|
| 102 |
+
non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
|
| 103 |
+
values = values[n_items:]
|
| 104 |
+
return self.__class__(**non_constexpr_fields, **constexpr_fields)
|
build/torch211-cxx11-cu128-x86_64-linux/quack/layout_utils.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import cutlass
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
|
| 7 |
+
from cutlass import Int32, const_expr
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def transpose_view(a: cute.Tensor) -> cute.Tensor:
|
| 11 |
+
"""Transpose the first two dimensions of a tensor on smem."""
|
| 12 |
+
shape = (a.shape[1], a.shape[0], *a.shape[2:])
|
| 13 |
+
order = (1, 0, *range(2, cute.rank(a)))
|
| 14 |
+
return cute.composition(a, cute.make_ordered_layout(shape, order=order))
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
|
| 18 |
+
return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
|
| 22 |
+
shape = (*a.shape[:dim], size, *a.shape[dim:])
|
| 23 |
+
stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:])
|
| 24 |
+
return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@cute.jit
|
| 28 |
+
def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
|
| 29 |
+
assert t.element_type.width == 16
|
| 30 |
+
assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
|
| 31 |
+
t_u32 = cute.recast_tensor(t, Int32)
|
| 32 |
+
|
| 33 |
+
quad_idx = cute.arch.lane_idx() % 4
|
| 34 |
+
lane_03 = quad_idx == 0 or quad_idx == 3
|
| 35 |
+
selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
|
| 36 |
+
selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
|
| 37 |
+
# upper_map = [0, 3, 1, 2]
|
| 38 |
+
# lower_map = [1, 2, 0, 3]
|
| 39 |
+
# upper_idx = upper_map[quad_idx]
|
| 40 |
+
# indexing isn't supported so we have to do arithmetic
|
| 41 |
+
upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
|
| 42 |
+
lower_idx = upper_idx ^ 1
|
| 43 |
+
|
| 44 |
+
# 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
|
| 45 |
+
width = 4
|
| 46 |
+
mask = cute.arch.WARP_SIZE - width
|
| 47 |
+
clamp = cute.arch.WARP_SIZE - 1
|
| 48 |
+
mask_and_clamp = mask << 8 | clamp
|
| 49 |
+
|
| 50 |
+
for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
|
| 51 |
+
upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
|
| 52 |
+
upper0 = upper if lane_03 else lower
|
| 53 |
+
lower0 = lower if lane_03 else upper
|
| 54 |
+
upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
|
| 55 |
+
lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
|
| 56 |
+
t_u32[i * 2 + 0] = cute.arch.prmt(upper0, lower0, selector_upper)
|
| 57 |
+
t_u32[i * 2 + 1] = cute.arch.prmt(upper0, lower0, selector_lower)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@cute.jit
|
| 61 |
+
def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None:
|
| 62 |
+
"""Permute and shuffle within 4 threads to change the layout from
|
| 63 |
+
T0 | T1 | T2 | T3
|
| 64 |
+
a b | c d | e f | g h
|
| 65 |
+
to
|
| 66 |
+
T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
|
| 67 |
+
a | b | c | d | e | f | g | h
|
| 68 |
+
This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
assert t.element_type.width == 32
|
| 72 |
+
assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"
|
| 73 |
+
|
| 74 |
+
quad_idx = cute.arch.lane_idx() % 4
|
| 75 |
+
# left_map = [0, 2, 1, 3]
|
| 76 |
+
# right_map = [2, 0, 3, 1]
|
| 77 |
+
# indexing isn't supported so we have to do arithmetic
|
| 78 |
+
left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
|
| 79 |
+
right_idx = left_idx ^ 0b10
|
| 80 |
+
|
| 81 |
+
# 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
|
| 82 |
+
width = 4
|
| 83 |
+
mask = cute.arch.WARP_SIZE - width
|
| 84 |
+
clamp = cute.arch.WARP_SIZE - 1
|
| 85 |
+
mask_and_clamp = mask << 8 | clamp
|
| 86 |
+
|
| 87 |
+
for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
|
| 88 |
+
for r in cutlass.range(2, unroll_full=True):
|
| 89 |
+
left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
|
| 90 |
+
# a b | c d | e f | g h -> a b | c d | f e | h g
|
| 91 |
+
left0 = left if quad_idx < 2 else right
|
| 92 |
+
right0 = right if quad_idx < 2 else left
|
| 93 |
+
# a b | c d | f e | h g -> a b | f d | c e | h g
|
| 94 |
+
left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
|
| 95 |
+
# a b | f d | c e | h g -> a e | f b | c g | h d
|
| 96 |
+
right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
|
| 97 |
+
# a e | f b | c g | h d -> a e | b f | c g | d h
|
| 98 |
+
t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0
|
| 99 |
+
t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0
|
| 100 |
+
t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@cute.jit
|
| 104 |
+
def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None:
|
| 105 |
+
"""Permute and shuffle within 4 threads to change the layout from
|
| 106 |
+
T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
|
| 107 |
+
a | b | c | d | e | f | g | h
|
| 108 |
+
to
|
| 109 |
+
T0 | T1 | T2 | T3
|
| 110 |
+
a b | c d | e f | g h
|
| 111 |
+
This is so that we can use LDSM (instead of LDS.64) to store C registers without bank conflict.
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
assert t.element_type.width == 32
|
| 115 |
+
assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"
|
| 116 |
+
|
| 117 |
+
quad_idx = cute.arch.lane_idx() % 4
|
| 118 |
+
# left_map = [0, 2, 1, 3]
|
| 119 |
+
# right_map = [1, 3, 0, 2]
|
| 120 |
+
# indexing isn't supported so we have to do arithmetic
|
| 121 |
+
left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
|
| 122 |
+
right_idx = left_idx ^ 0b01
|
| 123 |
+
|
| 124 |
+
# 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
|
| 125 |
+
width = 4
|
| 126 |
+
mask = cute.arch.WARP_SIZE - width
|
| 127 |
+
clamp = cute.arch.WARP_SIZE - 1
|
| 128 |
+
mask_and_clamp = mask << 8 | clamp
|
| 129 |
+
|
| 130 |
+
# This is just the inverse of permute_Cregs_b32_for_stsm
|
| 131 |
+
for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
|
| 132 |
+
t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
|
| 133 |
+
for r in cutlass.range(2, unroll_full=True):
|
| 134 |
+
left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
|
| 135 |
+
# a e | b f | c g | d h -> a e | f b | c g | h d
|
| 136 |
+
left0 = left if quad_idx % 2 == 0 else right
|
| 137 |
+
right0 = right if quad_idx % 2 == 0 else left
|
| 138 |
+
# a e | f b | c g | h d -> a b | f d | c e | h g
|
| 139 |
+
right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
|
| 140 |
+
# a b | f d | c e | h g -> a b | c d | f e | h g
|
| 141 |
+
left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
|
| 142 |
+
# a b | c d | f e | h g -> a b | c d | e f | g h
|
| 143 |
+
t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0
|
| 144 |
+
t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@cute.jit
|
| 148 |
+
def concat_layout(*layouts: cute.Layout) -> cute.Layout:
|
| 149 |
+
return cute.make_layout(
|
| 150 |
+
tuple(l.shape for l in layouts),
|
| 151 |
+
stride=tuple(l.stride for l in layouts),
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout:
|
| 156 |
+
"""
|
| 157 |
+
For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
|
| 158 |
+
For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
|
| 159 |
+
"""
|
| 160 |
+
acc_layout_col_major = cute.make_layout(acc_layout.shape)
|
| 161 |
+
shape = (
|
| 162 |
+
(acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M
|
| 163 |
+
(
|
| 164 |
+
acc_layout_col_major.shape[0][0],
|
| 165 |
+
*acc_layout_col_major.shape[0][2:],
|
| 166 |
+
acc_layout_col_major.shape[2],
|
| 167 |
+
), # MMA_N
|
| 168 |
+
*acc_layout_col_major.shape[3:],
|
| 169 |
+
)
|
| 170 |
+
stride = (
|
| 171 |
+
(acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M
|
| 172 |
+
(
|
| 173 |
+
acc_layout_col_major.stride[0][0],
|
| 174 |
+
*acc_layout_col_major.stride[0][2:],
|
| 175 |
+
acc_layout_col_major.stride[2],
|
| 176 |
+
), # MMA_N
|
| 177 |
+
*acc_layout_col_major.stride[3:],
|
| 178 |
+
)
|
| 179 |
+
if const_expr(transpose):
|
| 180 |
+
shape = (shape[1], shape[0], *shape[2:])
|
| 181 |
+
stride = (stride[1], stride[0], *stride[2:])
|
| 182 |
+
acc_layout_mn = cute.make_layout(shape, stride=stride)
|
| 183 |
+
return cute.composition(acc_layout, acc_layout_mn)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
|
| 187 |
+
return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose))
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def reshape_acc_to_mn(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
|
| 191 |
+
return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose))
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
@cute.jit
|
| 195 |
+
def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
|
| 196 |
+
# For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
|
| 197 |
+
# For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
| 198 |
+
# For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
|
| 199 |
+
# TODO: Sm90 FP8
|
| 200 |
+
if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90
|
| 201 |
+
l = cute.logical_divide(
|
| 202 |
+
acc_layout, ((None, None, 2), None, None)
|
| 203 |
+
) # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
|
| 204 |
+
rA_mma_view = cute.make_layout(
|
| 205 |
+
(
|
| 206 |
+
(l.shape[0][0], l.shape[0][1], l.shape[0][2][0]),
|
| 207 |
+
l.shape[1],
|
| 208 |
+
(l.shape[0][2][1], l.shape[2]),
|
| 209 |
+
),
|
| 210 |
+
stride=(
|
| 211 |
+
(l.stride[0][0], l.stride[0][1], l.stride[0][2][0]),
|
| 212 |
+
l.stride[1],
|
| 213 |
+
(l.stride[0][2][1], l.stride[2]),
|
| 214 |
+
),
|
| 215 |
+
)
|
| 216 |
+
else: # Sm80
|
| 217 |
+
# (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
|
| 218 |
+
l = cute.logical_divide(acc_layout, (None, None, 2))
|
| 219 |
+
rA_mma_view = cute.make_layout(
|
| 220 |
+
(
|
| 221 |
+
(l.shape[0], l.shape[2][0]),
|
| 222 |
+
l.shape[1],
|
| 223 |
+
l.shape[2][1],
|
| 224 |
+
),
|
| 225 |
+
stride=(
|
| 226 |
+
(l.stride[0], l.stride[2][0]),
|
| 227 |
+
l.stride[1],
|
| 228 |
+
l.stride[2][1],
|
| 229 |
+
),
|
| 230 |
+
)
|
| 231 |
+
return rA_mma_view
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor:
|
| 235 |
+
return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout))
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def convert_layout_zero_stride(
|
| 239 |
+
input: cute.Tensor | cute.Layout, ref_layout: cute.Layout
|
| 240 |
+
) -> cute.Layout:
|
| 241 |
+
layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input
|
| 242 |
+
# Group the modes with non-zero stride in the ref_layout together,
|
| 243 |
+
# and the modes with zero stride together
|
| 244 |
+
layout_flat = cute.flatten(layout)
|
| 245 |
+
ref_layout_flat = cute.flatten(ref_layout)
|
| 246 |
+
nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0]
|
| 247 |
+
zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0]
|
| 248 |
+
# There's an edge case when all modes are zero stride
|
| 249 |
+
new_shape = (
|
| 250 |
+
tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,),
|
| 251 |
+
tuple(layout_flat[i].shape for i in zero_modes),
|
| 252 |
+
)
|
| 253 |
+
new_stride = (
|
| 254 |
+
tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,),
|
| 255 |
+
tuple(layout_flat[i].stride for i in zero_modes),
|
| 256 |
+
)
|
| 257 |
+
out_layout = cute.make_layout(new_shape, stride=new_stride)
|
| 258 |
+
if const_expr(isinstance(input, cute.Tensor)):
|
| 259 |
+
return cute.make_tensor(input.iterator, out_layout)
|
| 260 |
+
else:
|
| 261 |
+
return out_layout
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def mma_partition_C_vec(
|
| 265 |
+
sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
|
| 266 |
+
) -> cute.Tensor:
|
| 267 |
+
assert cute.rank(sVec) == 2
|
| 268 |
+
assert sVec.stride[0] == 1
|
| 269 |
+
stage = sVec.shape[1]
|
| 270 |
+
shape = (
|
| 271 |
+
(sVec.shape[0], expand_shape, stage)
|
| 272 |
+
if const_expr(is_colvec)
|
| 273 |
+
else (expand_shape, sVec.shape[0], stage)
|
| 274 |
+
)
|
| 275 |
+
stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
|
| 276 |
+
sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
|
| 277 |
+
tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_C(sVec_mma))
|
| 278 |
+
return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def mma_partition_A_vec(
|
| 282 |
+
sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
|
| 283 |
+
) -> cute.Tensor:
|
| 284 |
+
assert cute.rank(sVec) == 2
|
| 285 |
+
assert sVec.stride[0] == 1
|
| 286 |
+
stage = sVec.shape[1]
|
| 287 |
+
shape = (
|
| 288 |
+
(sVec.shape[0], expand_shape, stage)
|
| 289 |
+
if const_expr(is_colvec)
|
| 290 |
+
else (expand_shape, sVec.shape[0], stage)
|
| 291 |
+
)
|
| 292 |
+
stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
|
| 293 |
+
sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
|
| 294 |
+
tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
|
| 295 |
+
return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
|
build/torch211-cxx11-cu128-x86_64-linux/quantize.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""Transformer Engine NVFP4 quantization helper.
|
| 5 |
+
|
| 6 |
+
This file is intended as a customer-facing example for preparing KV tensors
|
| 7 |
+
for the KVFP4 attention kernel:
|
| 8 |
+
- BF16/FP16 K/V input
|
| 9 |
+
- packed E2M1 FP4 data from Transformer Engine
|
| 10 |
+
- E4M3 block scales in cuBLAS/cuDNN 128x4 tiled layout
|
| 11 |
+
- one FP32 tensor/global scale per tensor
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import Tuple
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
NVFP4_BLOCK_SIZE = 16
|
| 23 |
+
NVFP4_FP4_MAX = 6.0
|
| 24 |
+
NVFP4_FP8_E4M3_MAX = 448.0
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass(frozen=True)
|
| 28 |
+
class Nvfp4QuantizedTensor:
|
| 29 |
+
"""Packed NVFP4 tensor plus dequantization metadata.
|
| 30 |
+
|
| 31 |
+
Attributes
|
| 32 |
+
----------
|
| 33 |
+
data : torch.Tensor
|
| 34 |
+
Packed E2M1 FP4 data from Transformer Engine. The last dimension is
|
| 35 |
+
half of the original logical last dimension because each byte stores
|
| 36 |
+
two FP4 values.
|
| 37 |
+
scale_128x4 : torch.Tensor
|
| 38 |
+
E4M3 block scales in cuBLAS/cuDNN 128x4 tiled rowwise storage.
|
| 39 |
+
global_scale : torch.Tensor
|
| 40 |
+
FP32 tensor/global dequant scale.
|
| 41 |
+
logical_scale_shape : tuple[int, int]
|
| 42 |
+
Logical 2D scale shape ``(rows, cols)`` before 128x4 swizzling.
|
| 43 |
+
original_shape : tuple[int, ...]
|
| 44 |
+
Original BF16/FP16 tensor shape before quantization.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
data: torch.Tensor
|
| 48 |
+
scale_128x4: torch.Tensor
|
| 49 |
+
global_scale: torch.Tensor
|
| 50 |
+
logical_scale_shape: Tuple[int, int]
|
| 51 |
+
original_shape: Tuple[int, ...]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _round_up(x: int, multiple: int) -> int:
|
| 55 |
+
return ((int(x) + multiple - 1) // multiple) * multiple
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def nvfp4_scale_128x4_offset(
|
| 59 |
+
row: torch.Tensor,
|
| 60 |
+
col: torch.Tensor,
|
| 61 |
+
scale_cols: int,
|
| 62 |
+
) -> torch.Tensor:
|
| 63 |
+
"""Return flat offsets for cuBLAS/cuDNN 128x4 rowwise scale storage.
|
| 64 |
+
|
| 65 |
+
Parameters
|
| 66 |
+
----------
|
| 67 |
+
row : torch.Tensor
|
| 68 |
+
Logical row indices.
|
| 69 |
+
col : torch.Tensor
|
| 70 |
+
Logical scale-column indices.
|
| 71 |
+
scale_cols : int
|
| 72 |
+
Logical number of scale columns before padding to a multiple of 4.
|
| 73 |
+
|
| 74 |
+
Returns
|
| 75 |
+
-------
|
| 76 |
+
torch.Tensor
|
| 77 |
+
Flat offsets into the padded 128x4 tiled storage.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
tiles_n = _round_up(scale_cols, 4) // 4
|
| 81 |
+
tile_m = row // 128
|
| 82 |
+
tile_n = col // 4
|
| 83 |
+
outer = row % 128
|
| 84 |
+
inner = col % 4
|
| 85 |
+
return (
|
| 86 |
+
(tile_m * tiles_n + tile_n) * 512
|
| 87 |
+
+ (outer % 32) * 16
|
| 88 |
+
+ (outer // 32) * 4
|
| 89 |
+
+ inner
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def swizzle_nvfp4_scale_to_128x4(
|
| 94 |
+
scale: torch.Tensor,
|
| 95 |
+
*,
|
| 96 |
+
rows: int,
|
| 97 |
+
cols: int,
|
| 98 |
+
) -> torch.Tensor:
|
| 99 |
+
"""Convert TE logical rowwise scales to cuBLAS/cuDNN 128x4 tiled layout.
|
| 100 |
+
|
| 101 |
+
Parameters
|
| 102 |
+
----------
|
| 103 |
+
scale : torch.Tensor
|
| 104 |
+
Logical rowwise scale tensor with at least shape ``[rows, cols]``.
|
| 105 |
+
rows : int
|
| 106 |
+
Number of logical rows to convert.
|
| 107 |
+
cols : int
|
| 108 |
+
Number of logical scale columns to convert.
|
| 109 |
+
|
| 110 |
+
Returns
|
| 111 |
+
-------
|
| 112 |
+
torch.Tensor
|
| 113 |
+
Scale tensor padded to ``round_up(rows, 128)`` by ``round_up(cols, 4)``
|
| 114 |
+
and swizzled into 128x4 tiled storage.
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
if scale.ndim != 2:
|
| 118 |
+
raise ValueError(f"scale must be 2D, got shape {tuple(scale.shape)}")
|
| 119 |
+
|
| 120 |
+
rows = int(rows)
|
| 121 |
+
cols = int(cols)
|
| 122 |
+
padded_rows = _round_up(rows, 128)
|
| 123 |
+
padded_cols = _round_up(cols, 4)
|
| 124 |
+
if scale.shape[0] < rows or scale.shape[1] < cols:
|
| 125 |
+
raise ValueError(
|
| 126 |
+
"scale is smaller than the requested logical shape: "
|
| 127 |
+
f"got {tuple(scale.shape)}, need at least {(rows, cols)}"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
logical = scale[:rows, :cols].contiguous()
|
| 131 |
+
if logical.shape != (padded_rows, padded_cols):
|
| 132 |
+
logical = torch.nn.functional.pad(
|
| 133 |
+
logical.to(torch.float32),
|
| 134 |
+
(0, padded_cols - cols, 0, padded_rows - rows),
|
| 135 |
+
).to(scale.dtype)
|
| 136 |
+
swizzled = torch.empty_like(logical)
|
| 137 |
+
|
| 138 |
+
row = torch.arange(padded_rows, device=scale.device, dtype=torch.int64)[:, None]
|
| 139 |
+
col = torch.arange(padded_cols, device=scale.device, dtype=torch.int64)[None, :]
|
| 140 |
+
offset = nvfp4_scale_128x4_offset(row, col, padded_cols).reshape(-1)
|
| 141 |
+
swizzled.reshape(-1)[offset] = logical.reshape(-1)
|
| 142 |
+
return swizzled
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def nvfp4_global_scale_from_amax(amax: torch.Tensor) -> torch.Tensor:
|
| 146 |
+
"""Compute TE NVFP4 tensor/global dequant scale from rowwise amax.
|
| 147 |
+
|
| 148 |
+
Parameters
|
| 149 |
+
----------
|
| 150 |
+
amax : torch.Tensor
|
| 151 |
+
Rowwise absolute maxima returned by Transformer Engine.
|
| 152 |
+
|
| 153 |
+
Returns
|
| 154 |
+
-------
|
| 155 |
+
torch.Tensor
|
| 156 |
+
FP32 global scale equal to ``amax / (448 * 6)``.
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
return amax.to(torch.float32) / (NVFP4_FP8_E4M3_MAX * NVFP4_FP4_MAX)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _import_te_nvfp4_quantizer():
|
| 163 |
+
try:
|
| 164 |
+
from transformer_engine.pytorch.tensor import NVFP4Quantizer
|
| 165 |
+
except Exception as exc: # pragma: no cover - environment dependent
|
| 166 |
+
raise RuntimeError(
|
| 167 |
+
"Transformer Engine NVFP4 quantization is unavailable. Install a "
|
| 168 |
+
"Transformer Engine build with its PyTorch dependencies, including "
|
| 169 |
+
"FlashAttention v3 when required by that TE build."
|
| 170 |
+
) from exc
|
| 171 |
+
return NVFP4Quantizer
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def quantize_bf16_to_nvfp4_128x4(x: torch.Tensor) -> Nvfp4QuantizedTensor:
|
| 175 |
+
"""Quantize a BF16/FP16 tensor to NVFP4 using Transformer Engine.
|
| 176 |
+
|
| 177 |
+
TE returns rowwise scales in logical padded layout. This helper returns
|
| 178 |
+
the scales in physical 128x4 tiled storage, so the attention kernel can
|
| 179 |
+
load them with ``nvfp4_scale_128x4_offset``.
|
| 180 |
+
|
| 181 |
+
Parameters
|
| 182 |
+
----------
|
| 183 |
+
x : torch.Tensor
|
| 184 |
+
CUDA BF16 or FP16 tensor. The last dimension must be divisible by 16,
|
| 185 |
+
and the flattened row dimension ``prod(x.shape[:-1])`` must also be
|
| 186 |
+
divisible by 16.
|
| 187 |
+
|
| 188 |
+
Returns
|
| 189 |
+
-------
|
| 190 |
+
Nvfp4QuantizedTensor
|
| 191 |
+
Packed FP4 data, 128x4-swizzled block scales, global scale, and shape
|
| 192 |
+
metadata needed by the KVFP4 attention kernel or by reference
|
| 193 |
+
dequantization.
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
if not x.is_cuda:
|
| 197 |
+
raise ValueError("NVFP4 quantization requires a CUDA tensor")
|
| 198 |
+
if x.dtype not in (torch.bfloat16, torch.float16):
|
| 199 |
+
raise TypeError(f"x must be bf16 or fp16, got {x.dtype}")
|
| 200 |
+
if x.ndim < 2:
|
| 201 |
+
raise ValueError(f"x must have at least 2 dimensions, got {x.ndim}")
|
| 202 |
+
if x.shape[-1] % NVFP4_BLOCK_SIZE != 0:
|
| 203 |
+
raise ValueError(
|
| 204 |
+
f"last dimension must be divisible by {NVFP4_BLOCK_SIZE}, got {x.shape[-1]}"
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
rows = 1
|
| 208 |
+
for dim in x.shape[:-1]:
|
| 209 |
+
rows *= int(dim)
|
| 210 |
+
if rows % NVFP4_BLOCK_SIZE != 0:
|
| 211 |
+
raise ValueError(
|
| 212 |
+
"flattened row dimension must be divisible by "
|
| 213 |
+
f"{NVFP4_BLOCK_SIZE}, got {rows}"
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
NVFP4Quantizer = _import_te_nvfp4_quantizer()
|
| 217 |
+
quantizer = NVFP4Quantizer(rowwise=True, columnwise=False)
|
| 218 |
+
qx = quantizer.quantize(x.contiguous())
|
| 219 |
+
meta = qx.get_metadata()
|
| 220 |
+
|
| 221 |
+
data = meta["rowwise_data"]
|
| 222 |
+
if data.dtype is not torch.uint8:
|
| 223 |
+
data = data.view(torch.uint8)
|
| 224 |
+
logical_scale = meta["rowwise_scale_inv"]
|
| 225 |
+
amax = meta["amax_rowwise"]
|
| 226 |
+
scale_cols = int(x.shape[-1]) // NVFP4_BLOCK_SIZE
|
| 227 |
+
scale_128x4 = swizzle_nvfp4_scale_to_128x4(
|
| 228 |
+
logical_scale,
|
| 229 |
+
rows=rows,
|
| 230 |
+
cols=scale_cols,
|
| 231 |
+
)
|
| 232 |
+
global_scale = nvfp4_global_scale_from_amax(amax).contiguous()
|
| 233 |
+
|
| 234 |
+
return Nvfp4QuantizedTensor(
|
| 235 |
+
data=data,
|
| 236 |
+
scale_128x4=scale_128x4,
|
| 237 |
+
global_scale=global_scale,
|
| 238 |
+
logical_scale_shape=(rows, scale_cols),
|
| 239 |
+
original_shape=tuple(int(v) for v in x.shape),
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def quantize_kv_bf16_to_nvfp4_128x4(
|
| 244 |
+
k: torch.Tensor,
|
| 245 |
+
v: torch.Tensor,
|
| 246 |
+
) -> tuple[Nvfp4QuantizedTensor, Nvfp4QuantizedTensor]:
|
| 247 |
+
"""Quantize BF16/FP16 K and V tensors independently for KVFP4 attention.
|
| 248 |
+
|
| 249 |
+
Parameters
|
| 250 |
+
----------
|
| 251 |
+
k : torch.Tensor
|
| 252 |
+
CUDA BF16 or FP16 K tensor.
|
| 253 |
+
v : torch.Tensor
|
| 254 |
+
CUDA BF16 or FP16 V tensor.
|
| 255 |
+
|
| 256 |
+
Returns
|
| 257 |
+
-------
|
| 258 |
+
tuple[Nvfp4QuantizedTensor, Nvfp4QuantizedTensor]
|
| 259 |
+
Quantized K and V tensors with independent scales.
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
return quantize_bf16_to_nvfp4_128x4(k), quantize_bf16_to_nvfp4_128x4(v)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def dequantize_nvfp4_128x4_to_bf16(
|
| 266 |
+
qx: Nvfp4QuantizedTensor,
|
| 267 |
+
*,
|
| 268 |
+
include_global_scale: bool = True,
|
| 269 |
+
) -> torch.Tensor:
|
| 270 |
+
"""Reference dequantization for validation.
|
| 271 |
+
|
| 272 |
+
This mirrors the kernel contract:
|
| 273 |
+
x = e2m1 * E4M3_block_scale_1x16 * FP32_global_scale
|
| 274 |
+
|
| 275 |
+
Parameters
|
| 276 |
+
----------
|
| 277 |
+
qx : Nvfp4QuantizedTensor
|
| 278 |
+
Quantized tensor returned by ``quantize_bf16_to_nvfp4_128x4``.
|
| 279 |
+
include_global_scale : bool, optional
|
| 280 |
+
If True, multiply by ``qx.global_scale`` after applying per-block
|
| 281 |
+
scales.
|
| 282 |
+
|
| 283 |
+
Returns
|
| 284 |
+
-------
|
| 285 |
+
torch.Tensor
|
| 286 |
+
BF16 tensor with shape ``qx.original_shape``.
|
| 287 |
+
"""
|
| 288 |
+
|
| 289 |
+
data = qx.data if qx.data.dtype is torch.uint8 else qx.data.view(torch.uint8)
|
| 290 |
+
if data.shape[-1] * 2 != qx.original_shape[-1]:
|
| 291 |
+
raise ValueError(
|
| 292 |
+
"packed data last dimension does not match original shape: "
|
| 293 |
+
f"{data.shape[-1]} packed vs {qx.original_shape[-1]} logical"
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
rows, scale_cols = qx.logical_scale_shape
|
| 297 |
+
logical_dim = int(qx.original_shape[-1])
|
| 298 |
+
if scale_cols * NVFP4_BLOCK_SIZE != logical_dim:
|
| 299 |
+
raise ValueError(
|
| 300 |
+
"logical scale columns do not match original last dimension: "
|
| 301 |
+
f"{scale_cols} scale cols vs dim {logical_dim}"
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
fp4_lut = torch.tensor(
|
| 305 |
+
[
|
| 306 |
+
0.0,
|
| 307 |
+
0.5,
|
| 308 |
+
1.0,
|
| 309 |
+
1.5,
|
| 310 |
+
2.0,
|
| 311 |
+
3.0,
|
| 312 |
+
4.0,
|
| 313 |
+
6.0,
|
| 314 |
+
-0.0,
|
| 315 |
+
-0.5,
|
| 316 |
+
-1.0,
|
| 317 |
+
-1.5,
|
| 318 |
+
-2.0,
|
| 319 |
+
-3.0,
|
| 320 |
+
-4.0,
|
| 321 |
+
-6.0,
|
| 322 |
+
],
|
| 323 |
+
dtype=torch.float32,
|
| 324 |
+
device=data.device,
|
| 325 |
+
)
|
| 326 |
+
packed = data.reshape(rows, logical_dim // 2)
|
| 327 |
+
lo = packed & 0x0F
|
| 328 |
+
hi = packed >> 4
|
| 329 |
+
values = torch.empty((rows, logical_dim), dtype=torch.float32, device=data.device)
|
| 330 |
+
values[:, 0::2] = fp4_lut[lo.long()]
|
| 331 |
+
values[:, 1::2] = fp4_lut[hi.long()]
|
| 332 |
+
|
| 333 |
+
row = torch.arange(rows, device=data.device, dtype=torch.int64)[:, None]
|
| 334 |
+
col = torch.arange(scale_cols, device=data.device, dtype=torch.int64)[None, :]
|
| 335 |
+
offset = nvfp4_scale_128x4_offset(row, col, scale_cols)
|
| 336 |
+
scale_u8 = qx.scale_128x4.reshape(-1)[offset.reshape(-1)].reshape(rows, scale_cols)
|
| 337 |
+
scale = scale_u8.view(torch.float8_e4m3fn).to(torch.float32)
|
| 338 |
+
scale = scale.repeat_interleave(NVFP4_BLOCK_SIZE, dim=1)
|
| 339 |
+
out = values * scale
|
| 340 |
+
if include_global_scale:
|
| 341 |
+
global_scale = qx.global_scale.reshape(-1)[0].to(torch.float32)
|
| 342 |
+
out = out * global_scale
|
| 343 |
+
return out.reshape(qx.original_shape).to(torch.bfloat16)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def _example() -> None:
|
| 347 |
+
device = torch.device("cuda")
|
| 348 |
+
k = torch.randn(128, 2, 128, device=device, dtype=torch.bfloat16)
|
| 349 |
+
v = torch.randn_like(k)
|
| 350 |
+
k_q, v_q = quantize_kv_bf16_to_nvfp4_128x4(k, v)
|
| 351 |
+
print("K FP4 data:", tuple(k_q.data.shape), k_q.data.dtype)
|
| 352 |
+
print("K scale 128x4:", tuple(k_q.scale_128x4.shape), k_q.scale_128x4.dtype)
|
| 353 |
+
print("K global scale:", tuple(k_q.global_scale.shape), k_q.global_scale.dtype)
|
| 354 |
+
print("V FP4 data:", tuple(v_q.data.shape), v_q.data.dtype)
|
| 355 |
+
print("V scale 128x4:", tuple(v_q.scale_128x4.shape), v_q.scale_128x4.dtype)
|
| 356 |
+
print("V global scale:", tuple(v_q.global_scale.shape), v_q.global_scale.dtype)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
if __name__ == "__main__":
|
| 360 |
+
if not torch.cuda.is_available():
|
| 361 |
+
raise RuntimeError("quantize.py requires CUDA")
|
| 362 |
+
_example()
|
build/torch211-cxx11-cu128-x86_64-linux/sparse_index_utils.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""Host-side q2k <-> k2q index conversion for sparse attention.
|
| 5 |
+
|
| 6 |
+
These utilities prepare sparse metadata on the Python side for tests,
|
| 7 |
+
benchmarks, and other offline preprocessing flows. They are not kernel
|
| 8 |
+
runtime helpers, so they intentionally live outside `src/common`.
|
| 9 |
+
|
| 10 |
+
Sparse attention pattern:
|
| 11 |
+
- Each Q token independently selects up to topK KV blocks (blk_kv tokens each).
|
| 12 |
+
- Under GQA, all Q heads in one group share the same sparsity pattern,
|
| 13 |
+
so indices are defined at the head_kv level.
|
| 14 |
+
|
| 15 |
+
Shapes:
|
| 16 |
+
q2k_indices: [batch, head_kv, Sq, topK] int32, valid values in [0, num_kv_blocks),
|
| 17 |
+
trailing unused slots padded with -1
|
| 18 |
+
k2q_indices: [batch, head_kv, Nkv, Sq] int32, padded with -1
|
| 19 |
+
k2q_counts: [batch, head_kv, Nkv] int32
|
| 20 |
+
|
| 21 |
+
CSR reverse-index format:
|
| 22 |
+
q2k_indices: [head_kv, total_q, topK] int32, values are batch-local kv_block indices
|
| 23 |
+
k2q_row_ptr: [head_kv, total_rows + 1] int32
|
| 24 |
+
k2q_q_indices: [head_kv, total_q * topK] int32, values are batch-local q_idx
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
from typing import Optional, Tuple
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
|
| 31 |
+
from .src.sm100.prepare_k2q_csr import SparseK2qCsrBuilderSm100
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def q2k_to_k2q(
|
| 35 |
+
q2k_indices: torch.Tensor,
|
| 36 |
+
num_kv_blocks: int,
|
| 37 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 38 |
+
"""Convert q2k sparse indices to k2q representation.
|
| 39 |
+
|
| 40 |
+
For each KV block, find which Q tokens attend to it.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
q2k_indices: [batch, head_kv, Sq, topK] int32.
|
| 44 |
+
For each Q token, the KV blocks it attends to. Unused slots must
|
| 45 |
+
be padded with -1.
|
| 46 |
+
num_kv_blocks: Total number of KV blocks (= Skv / blk_kv).
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
k2q_indices: [batch, head_kv, num_kv_blocks, Sq] int32.
|
| 50 |
+
For each KV block, the Q token indices that attend to it,
|
| 51 |
+
left-packed and padded with -1. Last dim fixed to Sq (upper bound).
|
| 52 |
+
k2q_counts: [batch, head_kv, num_kv_blocks] int32.
|
| 53 |
+
Actual number of Q tokens per KV block.
|
| 54 |
+
"""
|
| 55 |
+
B, H, Sq, topK = q2k_indices.shape
|
| 56 |
+
device = q2k_indices.device
|
| 57 |
+
N = Sq * topK
|
| 58 |
+
|
| 59 |
+
kv_flat = q2k_indices.reshape(B, H, N).long()
|
| 60 |
+
valid_flat = kv_flat >= 0
|
| 61 |
+
q_flat = (
|
| 62 |
+
torch.arange(Sq, device=device)
|
| 63 |
+
.unsqueeze(-1)
|
| 64 |
+
.expand(Sq, topK)
|
| 65 |
+
.reshape(N)
|
| 66 |
+
.unsqueeze(0)
|
| 67 |
+
.unsqueeze(0)
|
| 68 |
+
.expand(B, H, N)
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
k2q_counts = torch.zeros(B, H, num_kv_blocks, dtype=torch.int32, device=device)
|
| 72 |
+
safe_kv_flat = torch.where(valid_flat, kv_flat, torch.zeros_like(kv_flat))
|
| 73 |
+
k2q_counts.scatter_add_(
|
| 74 |
+
2,
|
| 75 |
+
safe_kv_flat,
|
| 76 |
+
valid_flat.to(torch.int32),
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
sort_keys = torch.where(
|
| 80 |
+
valid_flat,
|
| 81 |
+
kv_flat,
|
| 82 |
+
torch.full_like(kv_flat, num_kv_blocks),
|
| 83 |
+
)
|
| 84 |
+
sorted_kv, sort_idx = sort_keys.sort(dim=-1, stable=True)
|
| 85 |
+
sorted_q = q_flat.gather(-1, sort_idx)
|
| 86 |
+
sorted_valid = valid_flat.gather(-1, sort_idx)
|
| 87 |
+
|
| 88 |
+
offsets = torch.zeros(B, H, num_kv_blocks, dtype=torch.int64, device=device)
|
| 89 |
+
offsets[:, :, 1:] = k2q_counts[:, :, :-1].cumsum(dim=-1).long()
|
| 90 |
+
|
| 91 |
+
global_pos = torch.arange(N, device=device).unsqueeze(0).unsqueeze(0).expand(B, H, N)
|
| 92 |
+
group_offset = offsets.gather(2, sorted_kv.clamp(max=num_kv_blocks - 1))
|
| 93 |
+
pos_in_group = global_pos - group_offset
|
| 94 |
+
|
| 95 |
+
k2q_indices = torch.full(
|
| 96 |
+
(B, H, num_kv_blocks, Sq), -1, dtype=torch.int32, device=device
|
| 97 |
+
)
|
| 98 |
+
flat_k2q = k2q_indices.reshape(B, H, -1)
|
| 99 |
+
flat_idx = sorted_kv.clamp(max=num_kv_blocks - 1) * Sq + pos_in_group
|
| 100 |
+
for b in range(B):
|
| 101 |
+
for h in range(H):
|
| 102 |
+
valid = sorted_valid[b, h]
|
| 103 |
+
flat_k2q[b, h, flat_idx[b, h, valid]] = sorted_q[b, h, valid].int()
|
| 104 |
+
|
| 105 |
+
return k2q_indices, k2q_counts
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def k2q_to_q2k(
|
| 109 |
+
k2q_indices: torch.Tensor,
|
| 110 |
+
k2q_counts: torch.Tensor,
|
| 111 |
+
Sq: int,
|
| 112 |
+
topK: int,
|
| 113 |
+
) -> torch.Tensor:
|
| 114 |
+
"""Convert dense k2q indices back to q2k representation.
|
| 115 |
+
|
| 116 |
+
Parameters
|
| 117 |
+
----------
|
| 118 |
+
k2q_indices : torch.Tensor
|
| 119 |
+
Shape ``[batch, head_kv, num_kv_blocks, Sq]`` and dtype int32. Values
|
| 120 |
+
are Q token indices padded with ``-1``.
|
| 121 |
+
k2q_counts : torch.Tensor
|
| 122 |
+
Shape ``[batch, head_kv, num_kv_blocks]`` and dtype int32. Number of
|
| 123 |
+
valid Q indices per KV block.
|
| 124 |
+
Sq : int
|
| 125 |
+
Q sequence length per batch item in this dense reference format.
|
| 126 |
+
topK : int
|
| 127 |
+
Maximum number of KV blocks selected per Q token.
|
| 128 |
+
|
| 129 |
+
Returns
|
| 130 |
+
-------
|
| 131 |
+
torch.Tensor
|
| 132 |
+
Shape ``[batch, head_kv, Sq, topK]``, dtype int32. Entries are sorted
|
| 133 |
+
by KV block index with ``-1`` padding at the tail.
|
| 134 |
+
"""
|
| 135 |
+
B, H, Nkv, _ = k2q_indices.shape
|
| 136 |
+
device = k2q_indices.device
|
| 137 |
+
|
| 138 |
+
q2k = torch.full((B, H, Sq, topK), -1, dtype=torch.int32, device=device)
|
| 139 |
+
counters = torch.zeros(B, H, Sq, dtype=torch.int64, device=device)
|
| 140 |
+
|
| 141 |
+
for b in range(B):
|
| 142 |
+
for h in range(H):
|
| 143 |
+
for kv_blk in range(Nkv):
|
| 144 |
+
count = k2q_counts[b, h, kv_blk].item()
|
| 145 |
+
for j in range(count):
|
| 146 |
+
qt = k2q_indices[b, h, kv_blk, j].item()
|
| 147 |
+
if qt < 0:
|
| 148 |
+
continue
|
| 149 |
+
p = counters[b, h, qt].item()
|
| 150 |
+
if p < topK:
|
| 151 |
+
q2k[b, h, qt, p] = kv_blk
|
| 152 |
+
counters[b, h, qt] += 1
|
| 153 |
+
|
| 154 |
+
q2k_sort_key = torch.where(q2k < 0, torch.full_like(q2k, Nkv), q2k)
|
| 155 |
+
_, sort_idx = q2k_sort_key.sort(dim=-1)
|
| 156 |
+
q2k = q2k.gather(-1, sort_idx)
|
| 157 |
+
return q2k
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _validate_cu_seqlens(cu_seqlens: torch.Tensor, *, name: str) -> None:
|
| 161 |
+
if cu_seqlens.dtype != torch.int32:
|
| 162 |
+
raise TypeError(f"{name} must be torch.int32, got {cu_seqlens.dtype}")
|
| 163 |
+
if cu_seqlens.ndim != 1:
|
| 164 |
+
raise ValueError(f"{name} must be rank-1, got shape {tuple(cu_seqlens.shape)}")
|
| 165 |
+
if cu_seqlens.numel() < 1:
|
| 166 |
+
raise ValueError(f"{name} must have at least one element")
|
| 167 |
+
if not cu_seqlens.is_contiguous():
|
| 168 |
+
raise ValueError(f"{name} must be contiguous")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _rows_per_batch(cu_seqlens_k: torch.Tensor, kv_block_size: int) -> torch.Tensor:
|
| 172 |
+
seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
|
| 173 |
+
return (seqlens_k + kv_block_size - 1) // kv_block_size
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _build_packed_row_map(rows_per_batch: torch.Tensor) -> tuple[torch.Tensor, int]:
|
| 177 |
+
rows_per_batch_cpu = rows_per_batch.to("cpu", non_blocking=False).tolist()
|
| 178 |
+
batch = len(rows_per_batch_cpu)
|
| 179 |
+
max_rows = max(rows_per_batch_cpu, default=0)
|
| 180 |
+
row_dtype = (
|
| 181 |
+
torch.int32
|
| 182 |
+
if sum(rows_per_batch_cpu) < torch.iinfo(torch.int32).max
|
| 183 |
+
else torch.int64
|
| 184 |
+
)
|
| 185 |
+
row_map_cpu = torch.full((batch, max_rows), -1, dtype=row_dtype)
|
| 186 |
+
row_linear = 0
|
| 187 |
+
for kv_block_idx in range(max_rows):
|
| 188 |
+
for batch_idx, row_count in enumerate(rows_per_batch_cpu):
|
| 189 |
+
if kv_block_idx < row_count:
|
| 190 |
+
row_map_cpu[batch_idx, kv_block_idx] = row_linear
|
| 191 |
+
row_linear += 1
|
| 192 |
+
return row_map_cpu.to(rows_per_batch.device), row_linear
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def build_k2q_csr_torch_reference(
|
| 196 |
+
q2k_indices: torch.Tensor,
|
| 197 |
+
cu_seqlens_q: torch.Tensor,
|
| 198 |
+
cu_seqlens_k: torch.Tensor,
|
| 199 |
+
kv_block_size: int,
|
| 200 |
+
) -> tuple:
|
| 201 |
+
"""Torch reference for q2k -> k2q CSR conversion.
|
| 202 |
+
|
| 203 |
+
Parameters
|
| 204 |
+
----------
|
| 205 |
+
q2k_indices : torch.Tensor
|
| 206 |
+
Shape ``[head_kv, total_q, topK]``, dtype int32. Values are
|
| 207 |
+
batch-local KV block indices padded with ``-1``.
|
| 208 |
+
cu_seqlens_q : torch.Tensor
|
| 209 |
+
Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths.
|
| 210 |
+
cu_seqlens_k : torch.Tensor
|
| 211 |
+
Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths.
|
| 212 |
+
kv_block_size : int
|
| 213 |
+
Number of KV tokens per sparse block.
|
| 214 |
+
|
| 215 |
+
Returns
|
| 216 |
+
-------
|
| 217 |
+
tuple[torch.Tensor, torch.Tensor]
|
| 218 |
+
``(k2q_row_ptr, k2q_q_indices)`` where ``k2q_row_ptr`` has shape
|
| 219 |
+
``[head_kv, total_rows + 1]`` and ``k2q_q_indices`` has shape
|
| 220 |
+
``[head_kv, total_q * topK]``.
|
| 221 |
+
"""
|
| 222 |
+
if kv_block_size <= 0:
|
| 223 |
+
raise ValueError(f"kv_block_size must be > 0, got {kv_block_size}")
|
| 224 |
+
if q2k_indices.dtype != torch.int32:
|
| 225 |
+
raise TypeError(f"q2k_indices must be torch.int32, got {q2k_indices.dtype}")
|
| 226 |
+
if q2k_indices.ndim != 3:
|
| 227 |
+
raise ValueError(
|
| 228 |
+
"q2k_indices must have shape [head_kv, total_q, topK], "
|
| 229 |
+
f"got {tuple(q2k_indices.shape)}"
|
| 230 |
+
)
|
| 231 |
+
_validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q")
|
| 232 |
+
_validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k")
|
| 233 |
+
if cu_seqlens_q.shape != cu_seqlens_k.shape:
|
| 234 |
+
raise ValueError("cu_seqlens_q and cu_seqlens_k must have the same shape [B + 1]")
|
| 235 |
+
if q2k_indices.device != cu_seqlens_q.device or q2k_indices.device != cu_seqlens_k.device:
|
| 236 |
+
raise ValueError("q2k_indices, cu_seqlens_q, and cu_seqlens_k must be on the same device")
|
| 237 |
+
|
| 238 |
+
head_kv, total_q, topk = q2k_indices.shape
|
| 239 |
+
if total_q != int(cu_seqlens_q[-1].item()):
|
| 240 |
+
raise ValueError(
|
| 241 |
+
f"q2k_indices.shape[1] ({total_q}) must equal cu_seqlens_q[-1] "
|
| 242 |
+
f"({int(cu_seqlens_q[-1].item())})"
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
rows_per_batch = _rows_per_batch(cu_seqlens_k, kv_block_size)
|
| 246 |
+
row_map, total_rows = _build_packed_row_map(rows_per_batch)
|
| 247 |
+
nnz_upper_bound = total_q * topk
|
| 248 |
+
|
| 249 |
+
k2q_row_ptr = torch.zeros((head_kv, total_rows + 1), dtype=torch.int32, device=q2k_indices.device)
|
| 250 |
+
k2q_q_indices = torch.full(
|
| 251 |
+
(head_kv, nnz_upper_bound), -1, dtype=torch.int32, device=q2k_indices.device
|
| 252 |
+
)
|
| 253 |
+
if total_rows == 0 or total_q == 0 or topk == 0:
|
| 254 |
+
return k2q_row_ptr, k2q_q_indices
|
| 255 |
+
|
| 256 |
+
counts = torch.zeros((head_kv, total_rows), dtype=torch.int32, device=q2k_indices.device)
|
| 257 |
+
total_entries = total_q * topk
|
| 258 |
+
row_dtype = torch.int32 if total_rows < torch.iinfo(torch.int32).max else torch.int64
|
| 259 |
+
row_all = torch.empty((head_kv, total_entries), dtype=row_dtype, device=q2k_indices.device)
|
| 260 |
+
q_all = torch.empty((head_kv, total_entries), dtype=torch.int32, device=q2k_indices.device)
|
| 261 |
+
valid_all = torch.empty((head_kv, total_entries), dtype=torch.bool, device=q2k_indices.device)
|
| 262 |
+
rows_per_batch_cpu = rows_per_batch.to("cpu", non_blocking=False).tolist()
|
| 263 |
+
q_cu_cpu = cu_seqlens_q.to("cpu", non_blocking=False).tolist()
|
| 264 |
+
entry_cursor = 0
|
| 265 |
+
|
| 266 |
+
for batch_idx, kv_rows in enumerate(rows_per_batch_cpu):
|
| 267 |
+
q_start = q_cu_cpu[batch_idx]
|
| 268 |
+
q_end = q_cu_cpu[batch_idx + 1]
|
| 269 |
+
q_len = q_end - q_start
|
| 270 |
+
if q_len == 0:
|
| 271 |
+
continue
|
| 272 |
+
num_entries = q_len * topk
|
| 273 |
+
q2k_batch = q2k_indices[:, q_start:q_end, :]
|
| 274 |
+
valid_batch = q2k_batch >= 0
|
| 275 |
+
if valid_batch.any():
|
| 276 |
+
max_valid_kv = int(q2k_batch[valid_batch].max().item())
|
| 277 |
+
if max_valid_kv >= kv_rows:
|
| 278 |
+
raise ValueError(
|
| 279 |
+
f"q2k_indices references kv_block {max_valid_kv} for batch {batch_idx}, "
|
| 280 |
+
f"but that batch only has {kv_rows} logical kv blocks"
|
| 281 |
+
)
|
| 282 |
+
kv_flat = q2k_batch.reshape(head_kv, num_entries).long()
|
| 283 |
+
valid_flat = valid_batch.reshape(head_kv, num_entries)
|
| 284 |
+
safe_kv_flat = torch.where(valid_flat, kv_flat, torch.zeros_like(kv_flat))
|
| 285 |
+
row_map_batch = row_map[batch_idx]
|
| 286 |
+
row_flat = row_map_batch[safe_kv_flat]
|
| 287 |
+
q_flat = (
|
| 288 |
+
torch.arange(q_len, device=q2k_indices.device, dtype=torch.int32)
|
| 289 |
+
.view(1, q_len, 1)
|
| 290 |
+
.expand(head_kv, q_len, topk)
|
| 291 |
+
.reshape(head_kv, num_entries)
|
| 292 |
+
)
|
| 293 |
+
row_all[:, entry_cursor : entry_cursor + num_entries] = row_flat
|
| 294 |
+
q_all[:, entry_cursor : entry_cursor + num_entries] = q_flat
|
| 295 |
+
valid_all[:, entry_cursor : entry_cursor + num_entries] = valid_flat
|
| 296 |
+
counts.scatter_add_(1, row_flat.to(torch.int64), valid_flat.to(torch.int32))
|
| 297 |
+
entry_cursor += num_entries
|
| 298 |
+
|
| 299 |
+
k2q_row_ptr[:, 1:] = counts.cumsum(dim=1, dtype=torch.int32)
|
| 300 |
+
|
| 301 |
+
sort_stride = max(total_q, 1)
|
| 302 |
+
invalid_key = total_rows * sort_stride
|
| 303 |
+
max_sort_key = invalid_key + max(total_q - 1, 0)
|
| 304 |
+
if max_sort_key < torch.iinfo(torch.int32).max:
|
| 305 |
+
sort_keys = torch.full_like(row_all, invalid_key, dtype=torch.int32)
|
| 306 |
+
sort_keys[valid_all] = row_all[valid_all] * sort_stride + q_all[valid_all]
|
| 307 |
+
else:
|
| 308 |
+
sort_keys = torch.full_like(row_all, invalid_key, dtype=torch.int64)
|
| 309 |
+
sort_keys[valid_all] = (
|
| 310 |
+
row_all[valid_all].to(torch.int64) * sort_stride
|
| 311 |
+
+ q_all[valid_all].to(torch.int64)
|
| 312 |
+
)
|
| 313 |
+
_, sort_idx = sort_keys.sort(dim=1, stable=True)
|
| 314 |
+
sorted_q = q_all.gather(1, sort_idx)
|
| 315 |
+
|
| 316 |
+
valid_counts = valid_all.sum(dim=1)
|
| 317 |
+
write_mask = (
|
| 318 |
+
torch.arange(total_entries, device=q2k_indices.device)
|
| 319 |
+
.unsqueeze(0)
|
| 320 |
+
.expand(head_kv, -1)
|
| 321 |
+
< valid_counts.unsqueeze(1)
|
| 322 |
+
)
|
| 323 |
+
k2q_q_indices[write_mask] = sorted_q[write_mask]
|
| 324 |
+
|
| 325 |
+
return k2q_row_ptr, k2q_q_indices
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
_K2Q_CSR_BUILDER = SparseK2qCsrBuilderSm100()
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def build_k2q_csr(
|
| 332 |
+
q2k_indices: torch.Tensor,
|
| 333 |
+
cu_seqlens_q: torch.Tensor,
|
| 334 |
+
cu_seqlens_k: torch.Tensor,
|
| 335 |
+
kv_block_size: int,
|
| 336 |
+
*,
|
| 337 |
+
total_k: Optional[int] = None,
|
| 338 |
+
max_seqlen_k: Optional[int] = None,
|
| 339 |
+
max_seqlen_q: Optional[int] = None,
|
| 340 |
+
total_rows: Optional[int] = None,
|
| 341 |
+
qhead_per_kv: int = 1,
|
| 342 |
+
return_schedule: bool = False,
|
| 343 |
+
) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, object]:
|
| 344 |
+
"""Build the public k2q CSR reverse index on GPU.
|
| 345 |
+
|
| 346 |
+
Runtime construction does not read device-side ``cu_seqlens`` on the host,
|
| 347 |
+
so callers must provide size hints such as ``total_k`` from already-known
|
| 348 |
+
tensor shapes.
|
| 349 |
+
|
| 350 |
+
Parameters
|
| 351 |
+
----------
|
| 352 |
+
q2k_indices : torch.Tensor
|
| 353 |
+
Shape ``[head_kv, total_q, topK]``, dtype int32, contiguous. Values are
|
| 354 |
+
batch-local KV block indices with trailing ``-1`` padding.
|
| 355 |
+
cu_seqlens_q : torch.Tensor
|
| 356 |
+
Shape ``[batch_size + 1]``, dtype int32. Prefix sums of Q lengths.
|
| 357 |
+
cu_seqlens_k : torch.Tensor
|
| 358 |
+
Shape ``[batch_size + 1]``, dtype int32. Prefix sums of KV lengths.
|
| 359 |
+
kv_block_size : int
|
| 360 |
+
Number of KV tokens per sparse block.
|
| 361 |
+
total_k : int
|
| 362 |
+
Total KV token count. Required; normally ``k.shape[0]`` for dense KV
|
| 363 |
+
or ``sum(kv_segment_lens)`` for paged KV.
|
| 364 |
+
max_seqlen_k : int, optional
|
| 365 |
+
Maximum KV sequence length. Passing this avoids recomputing a bound.
|
| 366 |
+
max_seqlen_q : int, optional
|
| 367 |
+
Maximum Q sequence length.
|
| 368 |
+
total_rows : int, optional
|
| 369 |
+
Total number of packed KV-block rows across the batch. If omitted,
|
| 370 |
+
the builder derives it from ``cu_seqlens_k`` and ``kv_block_size``.
|
| 371 |
+
qhead_per_kv : int, optional
|
| 372 |
+
Number of Q heads per KV head under GQA.
|
| 373 |
+
return_schedule : bool, optional
|
| 374 |
+
If True, also return the sparse forward schedule object produced by the
|
| 375 |
+
SM100 builder.
|
| 376 |
+
|
| 377 |
+
Returns
|
| 378 |
+
-------
|
| 379 |
+
tuple[torch.Tensor, torch.Tensor] or tuple[torch.Tensor, torch.Tensor, object]
|
| 380 |
+
``(k2q_row_ptr, k2q_q_indices)`` or
|
| 381 |
+
``(k2q_row_ptr, k2q_q_indices, schedule)``. CSR tensors are int32 on
|
| 382 |
+
the same CUDA device as ``q2k_indices``.
|
| 383 |
+
"""
|
| 384 |
+
if total_k is None:
|
| 385 |
+
raise ValueError("build_k2q_csr requires total_k from k.shape[0]")
|
| 386 |
+
if kv_block_size <= 0:
|
| 387 |
+
raise ValueError(f"kv_block_size must be > 0, got {kv_block_size}")
|
| 388 |
+
if q2k_indices.dtype != torch.int32:
|
| 389 |
+
raise TypeError(f"q2k_indices must be torch.int32, got {q2k_indices.dtype}")
|
| 390 |
+
if q2k_indices.ndim != 3:
|
| 391 |
+
raise ValueError(f"q2k_indices must be rank-3, got shape {tuple(q2k_indices.shape)}")
|
| 392 |
+
if not q2k_indices.is_contiguous():
|
| 393 |
+
raise ValueError("q2k_indices must be contiguous with layout [head_kv, total_q, topK]")
|
| 394 |
+
_validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q")
|
| 395 |
+
_validate_cu_seqlens(cu_seqlens_k, name="cu_seqlens_k")
|
| 396 |
+
if cu_seqlens_q.shape != cu_seqlens_k.shape:
|
| 397 |
+
raise ValueError("cu_seqlens_q and cu_seqlens_k must have the same shape [B + 1]")
|
| 398 |
+
if q2k_indices.device != cu_seqlens_q.device or q2k_indices.device != cu_seqlens_k.device:
|
| 399 |
+
raise ValueError("q2k_indices, cu_seqlens_q, and cu_seqlens_k must be on the same device")
|
| 400 |
+
return _K2Q_CSR_BUILDER(
|
| 401 |
+
q2k_indices,
|
| 402 |
+
cu_seqlens_q,
|
| 403 |
+
cu_seqlens_k,
|
| 404 |
+
total_k=int(total_k),
|
| 405 |
+
blk_kv=int(kv_block_size),
|
| 406 |
+
max_seqlen_k=max_seqlen_k,
|
| 407 |
+
max_seqlen_q=max_seqlen_q,
|
| 408 |
+
total_rows=total_rows,
|
| 409 |
+
qhead_per_kv=qhead_per_kv,
|
| 410 |
+
return_schedule=return_schedule,
|
| 411 |
+
)
|
build/torch211-cxx11-cu128-x86_64-linux/src/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/aot_cache.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""Persistent AOT cache for CuTe DSL compiled kernels.
|
| 5 |
+
|
| 6 |
+
Saves compiled TVM FFI kernels as .o files on first compile,
|
| 7 |
+
loads them on subsequent runs to skip JIT compilation.
|
| 8 |
+
|
| 9 |
+
Environment variables:
|
| 10 |
+
MM_SPARSE_ATTN_AOT_CACHE: Override cache directory
|
| 11 |
+
(default: ~/.cache/minfer/mm_sparse_attn)
|
| 12 |
+
MM_SPARSE_ATTN_AOT_DISABLE=1: Disable AOT cache entirely
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import hashlib
|
| 16 |
+
import os
|
| 17 |
+
import time
|
| 18 |
+
|
| 19 |
+
import cutlass.cute as cute
|
| 20 |
+
|
| 21 |
+
_AOT_CACHE_DIR = os.environ.get(
|
| 22 |
+
"MM_SPARSE_ATTN_AOT_CACHE",
|
| 23 |
+
os.path.expanduser("~/.cache/minfer/mm_sparse_attn"),
|
| 24 |
+
)
|
| 25 |
+
_AOT_DISABLE = os.environ.get("MM_SPARSE_ATTN_AOT_DISABLE", "0") == "1"
|
| 26 |
+
|
| 27 |
+
_loaded_modules: dict[str, object] = {}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _key_to_path(key: tuple) -> str:
|
| 31 |
+
h = hashlib.sha256(repr(key).encode()).hexdigest()[:16]
|
| 32 |
+
name = str(key[0]).replace("/", "_")
|
| 33 |
+
return os.path.join(_AOT_CACHE_DIR, f"{name}_{h}")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def try_load_aot(key: tuple):
|
| 37 |
+
if _AOT_DISABLE:
|
| 38 |
+
return None
|
| 39 |
+
obj_path = _key_to_path(key) + ".o"
|
| 40 |
+
if not os.path.isfile(obj_path):
|
| 41 |
+
return None
|
| 42 |
+
func_name = str(key[0])
|
| 43 |
+
try:
|
| 44 |
+
if obj_path not in _loaded_modules:
|
| 45 |
+
_loaded_modules[obj_path] = cute.runtime.load_module(
|
| 46 |
+
obj_path, enable_tvm_ffi=True
|
| 47 |
+
)
|
| 48 |
+
return getattr(_loaded_modules[obj_path], func_name)
|
| 49 |
+
except Exception as e:
|
| 50 |
+
print(f"[aot_cache] Failed to load {obj_path}: {e}")
|
| 51 |
+
return None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def save_aot(key: tuple, compiled) -> None:
|
| 55 |
+
if _AOT_DISABLE:
|
| 56 |
+
return
|
| 57 |
+
if not hasattr(compiled, "export_to_c"):
|
| 58 |
+
return
|
| 59 |
+
obj_path = _key_to_path(key) + ".o"
|
| 60 |
+
os.makedirs(_AOT_CACHE_DIR, exist_ok=True)
|
| 61 |
+
tmp_path = obj_path + f".tmp.{os.getpid()}"
|
| 62 |
+
func_name = str(key[0])
|
| 63 |
+
try:
|
| 64 |
+
t0 = time.time()
|
| 65 |
+
compiled.export_to_c(tmp_path, function_name=func_name)
|
| 66 |
+
os.replace(tmp_path, obj_path)
|
| 67 |
+
dt = time.time() - t0
|
| 68 |
+
print(f"[aot_cache] Saved {func_name} -> {obj_path} ({dt:.1f}s)")
|
| 69 |
+
except Exception as e:
|
| 70 |
+
print(f"[aot_cache] Failed to save {func_name}: {e}")
|
| 71 |
+
if os.path.exists(tmp_path):
|
| 72 |
+
os.remove(tmp_path)
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/barrier.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
import cutlass
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
from cutlass import Int32
|
| 7 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 8 |
+
from cutlass._mlir.dialects import llvm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dsl_user_op
|
| 12 |
+
def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
|
| 13 |
+
lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 14 |
+
state = llvm.inline_asm(
|
| 15 |
+
T.i32(),
|
| 16 |
+
[lock_ptr_i64],
|
| 17 |
+
"ld.global.acquire.gpu.b32 $0, [$1];",
|
| 18 |
+
"=r,l",
|
| 19 |
+
has_side_effects=True,
|
| 20 |
+
is_align_stack=False,
|
| 21 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 22 |
+
)
|
| 23 |
+
return cutlass.Int32(state)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dsl_user_op
|
| 27 |
+
def red_relaxed(
|
| 28 |
+
lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
|
| 29 |
+
) -> None:
|
| 30 |
+
lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 31 |
+
llvm.inline_asm(
|
| 32 |
+
None,
|
| 33 |
+
[lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],
|
| 34 |
+
"red.relaxed.gpu.global.add.s32 [$0], $1;",
|
| 35 |
+
"l,r",
|
| 36 |
+
has_side_effects=True,
|
| 37 |
+
is_align_stack=False,
|
| 38 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dsl_user_op
|
| 43 |
+
def red_release(
|
| 44 |
+
lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
|
| 45 |
+
) -> None:
|
| 46 |
+
lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 47 |
+
llvm.inline_asm(
|
| 48 |
+
None,
|
| 49 |
+
[lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],
|
| 50 |
+
"red.release.gpu.global.add.s32 [$0], $1;",
|
| 51 |
+
"l,r",
|
| 52 |
+
has_side_effects=True,
|
| 53 |
+
is_align_stack=False,
|
| 54 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@cute.jit
|
| 59 |
+
def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None:
|
| 60 |
+
flag_ptr = lock_ptr + flag_offset
|
| 61 |
+
if thread_idx == 0:
|
| 62 |
+
read_val = Int32(0)
|
| 63 |
+
while read_val != val:
|
| 64 |
+
read_val = ld_acquire(flag_ptr)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@cute.jit
|
| 68 |
+
def arrive_inc(
|
| 69 |
+
lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32]
|
| 70 |
+
) -> None:
|
| 71 |
+
flag_ptr = lock_ptr + flag_offset
|
| 72 |
+
if thread_idx == 0:
|
| 73 |
+
red_release(flag_ptr, val)
|
| 74 |
+
# red_relaxed(flag_ptr, val)
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/blackwell_helpers.py
ADDED
|
@@ -0,0 +1,1093 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import cutlass
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
from cutlass import Int32, Boolean, const_expr
|
| 9 |
+
from cutlass.cute.nvgpu import tcgen05
|
| 10 |
+
from cutlass._mlir.dialects import llvm
|
| 11 |
+
|
| 12 |
+
from . import mma_sm100_desc as sm100_desc
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@cute.jit
|
| 16 |
+
def gemm_w_idx(
|
| 17 |
+
tiled_mma: cute.TiledMma,
|
| 18 |
+
acc: cute.Tensor,
|
| 19 |
+
tCrA: cute.Tensor,
|
| 20 |
+
tCrB: cute.Tensor,
|
| 21 |
+
A_idx: Optional[Int32] = None,
|
| 22 |
+
B_idx: Optional[Int32] = None,
|
| 23 |
+
zero_init: bool | Boolean = False,
|
| 24 |
+
swap_AB: bool = False,
|
| 25 |
+
num_unroll_groups: int = 1,
|
| 26 |
+
) -> None:
|
| 27 |
+
if const_expr(swap_AB):
|
| 28 |
+
return gemm_w_idx(
|
| 29 |
+
tiled_mma, acc, tCrB, tCrA, B_idx, A_idx, zero_init=zero_init, swap_AB=False
|
| 30 |
+
)
|
| 31 |
+
else:
|
| 32 |
+
rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
|
| 33 |
+
rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
|
| 34 |
+
|
| 35 |
+
mma_atom = cute.make_mma_atom(tiled_mma.op)
|
| 36 |
+
for k in cutlass.range(
|
| 37 |
+
cute.size(tCrA.shape[2]), unroll=cute.size(tCrA.shape[2]) // num_unroll_groups
|
| 38 |
+
):
|
| 39 |
+
mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)
|
| 40 |
+
cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@cute.jit
|
| 44 |
+
def gemm_ptx_w_idx(
|
| 45 |
+
tiled_mma: cute.TiledMma,
|
| 46 |
+
acc: cute.Tensor,
|
| 47 |
+
tCrA: cute.Tensor,
|
| 48 |
+
tCrB: cute.Tensor,
|
| 49 |
+
sA: Optional[cute.Tensor],
|
| 50 |
+
sB: cute.Tensor,
|
| 51 |
+
A_idx: Optional[Int32] = None,
|
| 52 |
+
B_idx: Optional[Int32] = None,
|
| 53 |
+
zero_init: bool | Boolean = False,
|
| 54 |
+
cta_group: int = 1,
|
| 55 |
+
**kwargs,
|
| 56 |
+
) -> None:
|
| 57 |
+
rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
|
| 58 |
+
rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
|
| 59 |
+
sA_cur = None
|
| 60 |
+
if const_expr(sA is not None):
|
| 61 |
+
sA_cur = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx]
|
| 62 |
+
sB_cur = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx]
|
| 63 |
+
mma_atom = cute.make_mma_atom(tiled_mma.op)
|
| 64 |
+
acc_tmem_addr = acc.iterator.toint()
|
| 65 |
+
gemm_ptx_partial(
|
| 66 |
+
mma_atom.op,
|
| 67 |
+
acc_tmem_addr,
|
| 68 |
+
rA,
|
| 69 |
+
rB,
|
| 70 |
+
sA_cur,
|
| 71 |
+
sB_cur,
|
| 72 |
+
zero_init=zero_init,
|
| 73 |
+
cta_group=cta_group,
|
| 74 |
+
**kwargs,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@cute.jit
|
| 79 |
+
def gemm(
|
| 80 |
+
tiled_mma: cute.TiledMma,
|
| 81 |
+
acc: cute.Tensor,
|
| 82 |
+
tCrA: cute.Tensor,
|
| 83 |
+
tCrB: cute.Tensor,
|
| 84 |
+
zero_init: bool | Boolean = False,
|
| 85 |
+
) -> None:
|
| 86 |
+
mma_atom = cute.make_mma_atom(tiled_mma.op)
|
| 87 |
+
for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
|
| 88 |
+
mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)
|
| 89 |
+
cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def i64_to_i32x2(i: int) -> Tuple[int, int]:
|
| 93 |
+
"""Convert a 64-bit integer to a tuple of two 32-bit integers."""
|
| 94 |
+
return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@cute.jit
|
| 98 |
+
def gemm_ptx(
|
| 99 |
+
op: cute.nvgpu.tcgen05.mma.MmaOp,
|
| 100 |
+
acc: cute.Tensor,
|
| 101 |
+
tCrA: cute.Tensor,
|
| 102 |
+
tCrB: cute.Tensor,
|
| 103 |
+
sA: Optional[cute.Tensor],
|
| 104 |
+
sB: cute.Tensor,
|
| 105 |
+
zero_init: bool | Boolean = False,
|
| 106 |
+
) -> None:
|
| 107 |
+
is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
|
| 108 |
+
if const_expr(not is_ts):
|
| 109 |
+
assert sA is not None, "sA must be provided when a_src is not TMEM"
|
| 110 |
+
sA_layout = sA.layout if sA is not None else None
|
| 111 |
+
sB_layout = sB.layout
|
| 112 |
+
idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
|
| 113 |
+
if const_expr(not is_ts):
|
| 114 |
+
sA_swizzle = sA.iterator.type.swizzle_type
|
| 115 |
+
smem_desc_base_a: int = const_expr(
|
| 116 |
+
sm100_desc.make_smem_desc_base(
|
| 117 |
+
cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
|
| 118 |
+
sA_swizzle,
|
| 119 |
+
sm100_desc.Major.K
|
| 120 |
+
if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
|
| 121 |
+
else sm100_desc.Major.MN,
|
| 122 |
+
)
|
| 123 |
+
)
|
| 124 |
+
smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
|
| 125 |
+
smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
|
| 126 |
+
smem_desc_a_hi = const_expr(smem_desc_a_hi)
|
| 127 |
+
else:
|
| 128 |
+
smem_desc_base_a = None
|
| 129 |
+
smem_desc_base_a_lo, smem_desc_a_hi = None, None
|
| 130 |
+
sB_swizzle = sB.iterator.type.swizzle_type
|
| 131 |
+
smem_desc_base_b: int = const_expr(
|
| 132 |
+
sm100_desc.make_smem_desc_base(
|
| 133 |
+
cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
|
| 134 |
+
sB_swizzle,
|
| 135 |
+
sm100_desc.Major.K
|
| 136 |
+
if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
|
| 137 |
+
else sm100_desc.Major.MN,
|
| 138 |
+
)
|
| 139 |
+
)
|
| 140 |
+
smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
|
| 141 |
+
smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
|
| 142 |
+
smem_desc_b_hi = const_expr(smem_desc_b_hi)
|
| 143 |
+
|
| 144 |
+
if const_expr(not is_ts):
|
| 145 |
+
smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr(
|
| 146 |
+
sA[None, None, 0].iterator
|
| 147 |
+
)
|
| 148 |
+
else:
|
| 149 |
+
smem_desc_start_a_lo = None
|
| 150 |
+
smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr(
|
| 151 |
+
sB[None, None, 0].iterator
|
| 152 |
+
)
|
| 153 |
+
for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
|
| 154 |
+
if const_expr(not is_ts):
|
| 155 |
+
smem_desc_a_lo = smem_desc_start_a_lo + (
|
| 156 |
+
(cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4
|
| 157 |
+
)
|
| 158 |
+
smem_desc_b_lo = smem_desc_start_b_lo + (
|
| 159 |
+
(cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
|
| 160 |
+
)
|
| 161 |
+
# with cute.arch.elect_one():
|
| 162 |
+
# cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo)
|
| 163 |
+
# cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct)
|
| 164 |
+
with cute.arch.elect_one():
|
| 165 |
+
if const_expr(not is_ts):
|
| 166 |
+
llvm.inline_asm(
|
| 167 |
+
None,
|
| 168 |
+
[
|
| 169 |
+
acc.iterator.toint().ir_value(),
|
| 170 |
+
smem_desc_a_lo.ir_value(),
|
| 171 |
+
smem_desc_b_lo.ir_value(),
|
| 172 |
+
Int32(not zero_init or k != 0).ir_value(),
|
| 173 |
+
],
|
| 174 |
+
"{\n\t"
|
| 175 |
+
".reg .pred p;\n\t"
|
| 176 |
+
".reg .b64 smem_desc_a, smem_desc_b;\n\t"
|
| 177 |
+
".reg .b32 idesc;\n\t"
|
| 178 |
+
f"mov.b32 idesc, {hex(idesc)};\n\t"
|
| 179 |
+
f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t"
|
| 180 |
+
f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
|
| 181 |
+
"setp.ne.b32 p, $3, 0;\n\t"
|
| 182 |
+
f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t"
|
| 183 |
+
"}\n",
|
| 184 |
+
"r,r,r,r",
|
| 185 |
+
has_side_effects=True,
|
| 186 |
+
is_align_stack=False,
|
| 187 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 188 |
+
)
|
| 189 |
+
else:
|
| 190 |
+
llvm.inline_asm(
|
| 191 |
+
None,
|
| 192 |
+
[
|
| 193 |
+
acc.iterator.toint().ir_value(),
|
| 194 |
+
tCrA[None, None, k].iterator.toint().ir_value(),
|
| 195 |
+
smem_desc_b_lo.ir_value(),
|
| 196 |
+
Int32(not zero_init or k != 0).ir_value(),
|
| 197 |
+
],
|
| 198 |
+
"{\n\t"
|
| 199 |
+
".reg .pred p;\n\t"
|
| 200 |
+
".reg .b64 smem_desc_b;\n\t"
|
| 201 |
+
f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
|
| 202 |
+
"setp.ne.b32 p, $3, 0;\n\t"
|
| 203 |
+
f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t"
|
| 204 |
+
"}\n",
|
| 205 |
+
"r,r,r,r",
|
| 206 |
+
has_side_effects=True,
|
| 207 |
+
is_align_stack=False,
|
| 208 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@cute.jit
|
| 213 |
+
def gemm_ptx_loop(
|
| 214 |
+
op: cute.nvgpu.tcgen05.mma.MmaOp,
|
| 215 |
+
acc: cute.Tensor,
|
| 216 |
+
tCrA: cute.Tensor,
|
| 217 |
+
tCrB: cute.Tensor,
|
| 218 |
+
sA: Optional[cute.Tensor],
|
| 219 |
+
sB: cute.Tensor,
|
| 220 |
+
zero_init: bool | Boolean = False,
|
| 221 |
+
) -> None:
|
| 222 |
+
is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
|
| 223 |
+
if const_expr(not is_ts):
|
| 224 |
+
assert sA is not None, "sA must be provided when a_src is not TMEM"
|
| 225 |
+
sA_layout = sA.layout if sA is not None else tCrA.layout
|
| 226 |
+
sB_layout = sB.layout
|
| 227 |
+
idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
|
| 228 |
+
if const_expr(not is_ts):
|
| 229 |
+
sA_swizzle = sA.iterator.type.swizzle_type
|
| 230 |
+
smem_desc_base_a: int = const_expr(
|
| 231 |
+
sm100_desc.make_smem_desc_base(
|
| 232 |
+
cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
|
| 233 |
+
sA_swizzle,
|
| 234 |
+
sm100_desc.Major.K
|
| 235 |
+
if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
|
| 236 |
+
else sm100_desc.Major.MN,
|
| 237 |
+
)
|
| 238 |
+
)
|
| 239 |
+
smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
|
| 240 |
+
smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
|
| 241 |
+
smem_desc_a_hi = const_expr(smem_desc_a_hi)
|
| 242 |
+
else:
|
| 243 |
+
smem_desc_base_a = None
|
| 244 |
+
smem_desc_base_a_lo, smem_desc_a_hi = None, None
|
| 245 |
+
sB_swizzle = sB.iterator.type.swizzle_type
|
| 246 |
+
smem_desc_base_b: int = const_expr(
|
| 247 |
+
sm100_desc.make_smem_desc_base(
|
| 248 |
+
cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
|
| 249 |
+
sB_swizzle,
|
| 250 |
+
sm100_desc.Major.K
|
| 251 |
+
if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
|
| 252 |
+
else sm100_desc.Major.MN,
|
| 253 |
+
)
|
| 254 |
+
)
|
| 255 |
+
smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
|
| 256 |
+
smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
|
| 257 |
+
smem_desc_b_hi = const_expr(smem_desc_b_hi)
|
| 258 |
+
|
| 259 |
+
if const_expr(not is_ts):
|
| 260 |
+
offset_a = [
|
| 261 |
+
(cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4
|
| 262 |
+
for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
|
| 263 |
+
]
|
| 264 |
+
else:
|
| 265 |
+
offset_a = [
|
| 266 |
+
cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
|
| 267 |
+
for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
|
| 268 |
+
]
|
| 269 |
+
offset_a_diff = [
|
| 270 |
+
offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
|
| 271 |
+
]
|
| 272 |
+
offset_b = [
|
| 273 |
+
(cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
|
| 274 |
+
for k in cutlass.range_constexpr(cute.size(tCrB.shape[2]))
|
| 275 |
+
]
|
| 276 |
+
offset_b_diff = [
|
| 277 |
+
offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2]))
|
| 278 |
+
]
|
| 279 |
+
|
| 280 |
+
if const_expr(not is_ts):
|
| 281 |
+
smem_desc_start_a_lo = Int32(
|
| 282 |
+
smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
|
| 283 |
+
)
|
| 284 |
+
else:
|
| 285 |
+
smem_desc_start_a_lo = None
|
| 286 |
+
smem_desc_start_b_lo = Int32(
|
| 287 |
+
smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
|
| 288 |
+
)
|
| 289 |
+
pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
|
| 290 |
+
if const_expr(not is_ts):
|
| 291 |
+
llvm.inline_asm(
|
| 292 |
+
None,
|
| 293 |
+
[
|
| 294 |
+
acc.iterator.toint().ir_value(),
|
| 295 |
+
Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
|
| 296 |
+
Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
|
| 297 |
+
Int32(not zero_init).ir_value(),
|
| 298 |
+
],
|
| 299 |
+
"{\n\t"
|
| 300 |
+
".reg .pred leader_thread;\n\t"
|
| 301 |
+
".reg .pred p;\n\t"
|
| 302 |
+
".reg .b32 idesc;\n\t"
|
| 303 |
+
".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
|
| 304 |
+
".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
|
| 305 |
+
".reg .b64 smem_desc_a, smem_desc_b;\n\t"
|
| 306 |
+
"elect.sync _|leader_thread, -1;\n\t"
|
| 307 |
+
f"mov.b32 idesc, {hex(idesc)};\n\t"
|
| 308 |
+
"mov.b32 smem_desc_a_lo, $1;\n\t"
|
| 309 |
+
"mov.b32 smem_desc_b_lo, $2;\n\t"
|
| 310 |
+
f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
|
| 311 |
+
f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
|
| 312 |
+
f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
|
| 313 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 314 |
+
"setp.ne.b32 p, $3, 0;\n\t"
|
| 315 |
+
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
|
| 316 |
+
+ "".join(
|
| 317 |
+
(
|
| 318 |
+
f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
|
| 319 |
+
f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
|
| 320 |
+
f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
|
| 321 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 322 |
+
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
|
| 323 |
+
)
|
| 324 |
+
for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
|
| 325 |
+
)
|
| 326 |
+
+ "}\n",
|
| 327 |
+
"r,r,r,r",
|
| 328 |
+
has_side_effects=True,
|
| 329 |
+
is_align_stack=False,
|
| 330 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 331 |
+
)
|
| 332 |
+
else:
|
| 333 |
+
llvm.inline_asm(
|
| 334 |
+
None,
|
| 335 |
+
[
|
| 336 |
+
acc.iterator.toint().ir_value(),
|
| 337 |
+
Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
|
| 338 |
+
Int32(smem_desc_start_b_lo).ir_value(),
|
| 339 |
+
Int32(not zero_init).ir_value(),
|
| 340 |
+
],
|
| 341 |
+
"{\n\t"
|
| 342 |
+
".reg .pred leader_thread;\n\t"
|
| 343 |
+
".reg .pred p;\n\t"
|
| 344 |
+
".reg .b32 idesc;\n\t"
|
| 345 |
+
".reg .b32 tmem_a;\n\t"
|
| 346 |
+
".reg .b32 smem_desc_b_lo;\n\t"
|
| 347 |
+
".reg .b32 smem_desc_b_hi;\n\t"
|
| 348 |
+
".reg .b64 smem_desc_b;\n\t"
|
| 349 |
+
"elect.sync _|leader_thread, -1;\n\t"
|
| 350 |
+
f"mov.b32 idesc, {hex(idesc)};\n\t"
|
| 351 |
+
"mov.b32 tmem_a, $1;\n\t"
|
| 352 |
+
"mov.b32 smem_desc_b_lo, $2;\n\t"
|
| 353 |
+
f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
|
| 354 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 355 |
+
"setp.ne.b32 p, $3, 0;\n\t"
|
| 356 |
+
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
|
| 357 |
+
+ "".join(
|
| 358 |
+
(
|
| 359 |
+
# f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
|
| 360 |
+
f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
|
| 361 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 362 |
+
# f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t"
|
| 363 |
+
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
|
| 364 |
+
)
|
| 365 |
+
for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
|
| 366 |
+
)
|
| 367 |
+
+ "}\n",
|
| 368 |
+
"r,r,r,r",
|
| 369 |
+
has_side_effects=True,
|
| 370 |
+
is_align_stack=False,
|
| 371 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
@cute.jit
|
| 376 |
+
def gemm_ptx_partial(
|
| 377 |
+
op: cute.nvgpu.tcgen05.mma.MmaOp,
|
| 378 |
+
acc_tmem_addr: Int32,
|
| 379 |
+
tCrA: cute.Tensor,
|
| 380 |
+
tCrB: cute.Tensor,
|
| 381 |
+
sA: Optional[cute.Tensor],
|
| 382 |
+
sB: cute.Tensor,
|
| 383 |
+
mbar_ptr: Optional[cutlass.Pointer] = None,
|
| 384 |
+
mbar_phase: Optional[Int32] = None,
|
| 385 |
+
split_arrive: Optional[int] = None,
|
| 386 |
+
zero_init: bool | Boolean = False,
|
| 387 |
+
# sA_offset: Int32 = 0,
|
| 388 |
+
# acc_offset: Int32 = 0,
|
| 389 |
+
tA_addr: Optional[Int32] = None,
|
| 390 |
+
cta_group: int = 1,
|
| 391 |
+
mma_kind: str = "f16",
|
| 392 |
+
) -> None:
|
| 393 |
+
# acc_tmem_addr += acc_offset
|
| 394 |
+
is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
|
| 395 |
+
if const_expr(not is_ts):
|
| 396 |
+
assert sA is not None, "sA must be provided when a_src is not TMEM"
|
| 397 |
+
sA_layout = sA.layout if sA is not None else tCrA.layout
|
| 398 |
+
sB_layout = sB.layout
|
| 399 |
+
idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
|
| 400 |
+
if const_expr(not is_ts):
|
| 401 |
+
sA_swizzle = sA.iterator.type.swizzle_type
|
| 402 |
+
smem_desc_base_a: int = const_expr(
|
| 403 |
+
sm100_desc.make_smem_desc_base(
|
| 404 |
+
cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
|
| 405 |
+
sA_swizzle,
|
| 406 |
+
sm100_desc.Major.K
|
| 407 |
+
if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
|
| 408 |
+
else sm100_desc.Major.MN,
|
| 409 |
+
)
|
| 410 |
+
)
|
| 411 |
+
smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
|
| 412 |
+
smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
|
| 413 |
+
smem_desc_a_hi = const_expr(smem_desc_a_hi)
|
| 414 |
+
else:
|
| 415 |
+
smem_desc_base_a = None
|
| 416 |
+
smem_desc_base_a_lo, smem_desc_a_hi = None, None
|
| 417 |
+
sB_swizzle = sB.iterator.type.swizzle_type
|
| 418 |
+
smem_desc_base_b: int = const_expr(
|
| 419 |
+
sm100_desc.make_smem_desc_base(
|
| 420 |
+
cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
|
| 421 |
+
sB_swizzle,
|
| 422 |
+
sm100_desc.Major.K
|
| 423 |
+
if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
|
| 424 |
+
else sm100_desc.Major.MN,
|
| 425 |
+
)
|
| 426 |
+
)
|
| 427 |
+
smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
|
| 428 |
+
smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
|
| 429 |
+
smem_desc_b_hi = const_expr(smem_desc_b_hi)
|
| 430 |
+
|
| 431 |
+
tCrA_layout = (
|
| 432 |
+
tCrA.layout
|
| 433 |
+
if const_expr(not is_ts)
|
| 434 |
+
else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout)
|
| 435 |
+
)
|
| 436 |
+
offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))]
|
| 437 |
+
offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))]
|
| 438 |
+
offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))]
|
| 439 |
+
offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]
|
| 440 |
+
|
| 441 |
+
if const_expr(not is_ts):
|
| 442 |
+
smem_desc_start_a_lo = Int32(
|
| 443 |
+
smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
|
| 444 |
+
)
|
| 445 |
+
# ) + sA_offset
|
| 446 |
+
else:
|
| 447 |
+
smem_desc_start_a_lo = None
|
| 448 |
+
smem_desc_start_b_lo = Int32(
|
| 449 |
+
smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
|
| 450 |
+
)
|
| 451 |
+
pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
|
| 452 |
+
if const_expr(not is_ts):
|
| 453 |
+
assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM"
|
| 454 |
+
llvm.inline_asm(
|
| 455 |
+
None,
|
| 456 |
+
[
|
| 457 |
+
# acc.iterator.toint().ir_value(),
|
| 458 |
+
Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
|
| 459 |
+
Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
|
| 460 |
+
Int32(not zero_init).ir_value(),
|
| 461 |
+
Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
|
| 462 |
+
],
|
| 463 |
+
"{\n\t"
|
| 464 |
+
".reg .pred leader_thread;\n\t"
|
| 465 |
+
".reg .pred p;\n\t"
|
| 466 |
+
".reg .b32 idesc;\n\t"
|
| 467 |
+
".reg .b32 tmem_acc;\n\t"
|
| 468 |
+
".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t"
|
| 469 |
+
".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
|
| 470 |
+
".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
|
| 471 |
+
".reg .b64 smem_desc_a, smem_desc_b;\n\t"
|
| 472 |
+
"elect.sync _|leader_thread, -1;\n\t"
|
| 473 |
+
f"mov.b32 idesc, {hex(idesc)};\n\t"
|
| 474 |
+
# f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
|
| 475 |
+
f"mov.b32 tmem_acc, $3;\n\t"
|
| 476 |
+
"mov.b32 smem_desc_a_lo_start, $0;\n\t"
|
| 477 |
+
"mov.b32 smem_desc_b_lo_start, $1;\n\t"
|
| 478 |
+
f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
|
| 479 |
+
f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
|
| 480 |
+
f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t"
|
| 481 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
|
| 482 |
+
"setp.ne.b32 p, $2, 0;\n\t"
|
| 483 |
+
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
|
| 484 |
+
+ "".join(
|
| 485 |
+
(
|
| 486 |
+
# f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
|
| 487 |
+
# f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
|
| 488 |
+
f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t"
|
| 489 |
+
f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
|
| 490 |
+
f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
|
| 491 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 492 |
+
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
|
| 493 |
+
)
|
| 494 |
+
for k in range(1, cute.size(tCrA.shape[2]))
|
| 495 |
+
)
|
| 496 |
+
+ "}\n",
|
| 497 |
+
# "r,r,r",
|
| 498 |
+
"r,r,r,r",
|
| 499 |
+
has_side_effects=True,
|
| 500 |
+
is_align_stack=False,
|
| 501 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 502 |
+
)
|
| 503 |
+
else:
|
| 504 |
+
# For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to
|
| 505 |
+
# explicitly pass in the tA_addr for correctness.
|
| 506 |
+
tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr
|
| 507 |
+
input_args = [
|
| 508 |
+
# Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(),
|
| 509 |
+
Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(),
|
| 510 |
+
Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
|
| 511 |
+
Int32(not zero_init).ir_value(),
|
| 512 |
+
Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
|
| 513 |
+
]
|
| 514 |
+
if const_expr(mbar_ptr is not None):
|
| 515 |
+
assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None"
|
| 516 |
+
assert split_arrive is not None, (
|
| 517 |
+
"split_arrive must be provided when mbar_ptr is not None"
|
| 518 |
+
)
|
| 519 |
+
split_arrive_idx = split_arrive // op.shape_mnk[2]
|
| 520 |
+
input_args.append(mbar_ptr.toint().ir_value())
|
| 521 |
+
input_args.append(Int32(mbar_phase).ir_value())
|
| 522 |
+
mbar_wait_str = (
|
| 523 |
+
".reg .pred P1; \n\t"
|
| 524 |
+
"LAB_WAIT: \n\t"
|
| 525 |
+
"mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t"
|
| 526 |
+
"@P1 bra DONE; \n\t"
|
| 527 |
+
"bra LAB_WAIT; \n\t"
|
| 528 |
+
"DONE: \n\t"
|
| 529 |
+
)
|
| 530 |
+
else:
|
| 531 |
+
mbar_wait_str = ""
|
| 532 |
+
llvm.inline_asm(
|
| 533 |
+
None,
|
| 534 |
+
# [
|
| 535 |
+
# # acc.iterator.toint().ir_value(),
|
| 536 |
+
# Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
|
| 537 |
+
# Int32(smem_desc_start_b_lo).ir_value(),
|
| 538 |
+
# Int32(not zero_init).ir_value(),
|
| 539 |
+
# ],
|
| 540 |
+
input_args,
|
| 541 |
+
"{\n\t"
|
| 542 |
+
".reg .pred leader_thread;\n\t"
|
| 543 |
+
".reg .pred p;\n\t"
|
| 544 |
+
".reg .b32 idesc;\n\t"
|
| 545 |
+
".reg .b32 tmem_acc;\n\t"
|
| 546 |
+
".reg .b32 tmem_a;\n\t"
|
| 547 |
+
".reg .b32 smem_desc_b_lo_start;\n\t"
|
| 548 |
+
".reg .b32 smem_desc_b_lo;\n\t"
|
| 549 |
+
".reg .b32 smem_desc_b_hi;\n\t"
|
| 550 |
+
".reg .b64 smem_desc_b;\n\t"
|
| 551 |
+
"elect.sync _|leader_thread, -1;\n\t"
|
| 552 |
+
f"mov.b32 idesc, {hex(idesc)};\n\t"
|
| 553 |
+
# f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
|
| 554 |
+
f"mov.b32 tmem_acc, $3;\n\t"
|
| 555 |
+
f"mov.b32 tmem_a, $0;\n\t"
|
| 556 |
+
f"mov.b32 smem_desc_b_lo_start, $1;\n\t"
|
| 557 |
+
f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
|
| 558 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
|
| 559 |
+
"setp.ne.b32 p, $2, 0;\n\t"
|
| 560 |
+
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
|
| 561 |
+
+ "".join(
|
| 562 |
+
(
|
| 563 |
+
# f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
|
| 564 |
+
# f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
|
| 565 |
+
f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
|
| 566 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 567 |
+
# f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t"
|
| 568 |
+
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
|
| 569 |
+
)
|
| 570 |
+
for k in range(
|
| 571 |
+
1,
|
| 572 |
+
cute.size(tCrA.shape[2]) if const_expr(mbar_ptr is None) else split_arrive_idx,
|
| 573 |
+
)
|
| 574 |
+
)
|
| 575 |
+
+ mbar_wait_str
|
| 576 |
+
+ (
|
| 577 |
+
"".join(
|
| 578 |
+
(
|
| 579 |
+
f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
|
| 580 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 581 |
+
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
|
| 582 |
+
)
|
| 583 |
+
for k in range(split_arrive_idx, cute.size(tCrA.shape[2]))
|
| 584 |
+
)
|
| 585 |
+
if const_expr(mbar_ptr is not None)
|
| 586 |
+
else ""
|
| 587 |
+
)
|
| 588 |
+
+ "}\n",
|
| 589 |
+
"r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r",
|
| 590 |
+
has_side_effects=True,
|
| 591 |
+
is_align_stack=False,
|
| 592 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
@cute.jit
|
| 597 |
+
def gemm_ptx_partial1(
|
| 598 |
+
op: cute.nvgpu.tcgen05.mma.MmaOp,
|
| 599 |
+
acc_tmem_addr: cutlass.Constexpr[int],
|
| 600 |
+
tCrA: cute.Tensor,
|
| 601 |
+
tCrB: cute.Tensor,
|
| 602 |
+
sA_base_addr_for_desc: Int32,
|
| 603 |
+
sA_addr_offset_for_desc: cutlass.Constexpr[int],
|
| 604 |
+
sA_stage: Int32,
|
| 605 |
+
sB_base_addr_for_desc: Int32,
|
| 606 |
+
sB_addr_offset_for_desc: cutlass.Constexpr[int],
|
| 607 |
+
sB_stage: Int32,
|
| 608 |
+
sA_layout: Optional[cute.Layout],
|
| 609 |
+
sB_layout: Optional[cute.Layout],
|
| 610 |
+
sA_swizzle: Optional[cute.Swizzle],
|
| 611 |
+
sB_swizzle: cute.Swizzle,
|
| 612 |
+
zero_init: bool | Boolean = False,
|
| 613 |
+
) -> None:
|
| 614 |
+
is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
|
| 615 |
+
if const_expr(not is_ts):
|
| 616 |
+
assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM"
|
| 617 |
+
assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM"
|
| 618 |
+
idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
|
| 619 |
+
if const_expr(not is_ts):
|
| 620 |
+
smem_desc_base_a: int = const_expr(
|
| 621 |
+
sm100_desc.make_smem_desc_base(
|
| 622 |
+
cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
|
| 623 |
+
sA_swizzle,
|
| 624 |
+
sm100_desc.Major.K
|
| 625 |
+
if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
|
| 626 |
+
else sm100_desc.Major.MN,
|
| 627 |
+
)
|
| 628 |
+
)
|
| 629 |
+
smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
|
| 630 |
+
smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
|
| 631 |
+
smem_desc_a_hi = const_expr(smem_desc_a_hi)
|
| 632 |
+
else:
|
| 633 |
+
smem_desc_base_a = None
|
| 634 |
+
smem_desc_base_a_lo, smem_desc_a_hi = None, None
|
| 635 |
+
smem_desc_base_b: int = const_expr(
|
| 636 |
+
sm100_desc.make_smem_desc_base(
|
| 637 |
+
cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
|
| 638 |
+
sB_swizzle,
|
| 639 |
+
sm100_desc.Major.K
|
| 640 |
+
if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
|
| 641 |
+
else sm100_desc.Major.MN,
|
| 642 |
+
)
|
| 643 |
+
)
|
| 644 |
+
smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
|
| 645 |
+
smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
|
| 646 |
+
smem_desc_b_hi = const_expr(smem_desc_b_hi)
|
| 647 |
+
mask = [Int32(0)] * 4
|
| 648 |
+
|
| 649 |
+
if const_expr(not is_ts):
|
| 650 |
+
offset_a = [
|
| 651 |
+
(cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4
|
| 652 |
+
for k in range(cute.size(tCrA.shape[2]))
|
| 653 |
+
]
|
| 654 |
+
else:
|
| 655 |
+
offset_a = [
|
| 656 |
+
cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
|
| 657 |
+
for k in range(cute.size(tCrA.shape[2]))
|
| 658 |
+
]
|
| 659 |
+
offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))]
|
| 660 |
+
offset_b = [
|
| 661 |
+
(cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4
|
| 662 |
+
for k in range(cute.size(tCrB.shape[2]))
|
| 663 |
+
]
|
| 664 |
+
offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]
|
| 665 |
+
|
| 666 |
+
if const_expr(not is_ts):
|
| 667 |
+
# smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator))
|
| 668 |
+
smem_desc_start_a_lo = const_expr(smem_desc_base_a_lo)
|
| 669 |
+
else:
|
| 670 |
+
smem_desc_start_a_lo = None
|
| 671 |
+
# smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator))
|
| 672 |
+
smem_desc_start_b_lo = const_expr(smem_desc_base_b_lo)
|
| 673 |
+
pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
|
| 674 |
+
if const_expr(not is_ts):
|
| 675 |
+
llvm.inline_asm(
|
| 676 |
+
None,
|
| 677 |
+
[
|
| 678 |
+
# acc.iterator.toint().ir_value(),
|
| 679 |
+
# Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
|
| 680 |
+
Int32(sA_base_addr_for_desc).ir_value(),
|
| 681 |
+
Int32(sA_stage).ir_value(),
|
| 682 |
+
# Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
|
| 683 |
+
Int32(sB_base_addr_for_desc).ir_value(),
|
| 684 |
+
Int32(sB_stage).ir_value(),
|
| 685 |
+
Int32(not zero_init).ir_value(),
|
| 686 |
+
mask[0].ir_value(),
|
| 687 |
+
mask[1].ir_value(),
|
| 688 |
+
mask[2].ir_value(),
|
| 689 |
+
mask[3].ir_value(),
|
| 690 |
+
],
|
| 691 |
+
"{\n\t"
|
| 692 |
+
".reg .pred leader_thread;\n\t"
|
| 693 |
+
".reg .pred p;\n\t"
|
| 694 |
+
".reg .b32 idesc;\n\t"
|
| 695 |
+
".reg .b32 tmem_acc;\n\t"
|
| 696 |
+
".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
|
| 697 |
+
".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
|
| 698 |
+
".reg .b64 smem_desc_a, smem_desc_b;\n\t"
|
| 699 |
+
"elect.sync _|leader_thread, -1;\n\t"
|
| 700 |
+
f"mov.b32 idesc, {hex(idesc)};\n\t"
|
| 701 |
+
f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
|
| 702 |
+
# "mov.b32 smem_desc_a_lo, $0;\n\t"
|
| 703 |
+
# f"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\n\t"
|
| 704 |
+
f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t"
|
| 705 |
+
# "mov.b32 smem_desc_b_lo, $2;\n\t"
|
| 706 |
+
f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t"
|
| 707 |
+
f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
|
| 708 |
+
f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
|
| 709 |
+
f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
|
| 710 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 711 |
+
"setp.ne.b32 p, $4, 0;\n\t"
|
| 712 |
+
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t"
|
| 713 |
+
+ "".join(
|
| 714 |
+
(
|
| 715 |
+
f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
|
| 716 |
+
f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
|
| 717 |
+
f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
|
| 718 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 719 |
+
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t"
|
| 720 |
+
)
|
| 721 |
+
for k in range(1, cute.size(tCrA.shape[2]))
|
| 722 |
+
)
|
| 723 |
+
+ "}\n",
|
| 724 |
+
"r,r,r,r,r,r,r,r,r",
|
| 725 |
+
has_side_effects=True,
|
| 726 |
+
is_align_stack=False,
|
| 727 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 728 |
+
)
|
| 729 |
+
else:
|
| 730 |
+
llvm.inline_asm(
|
| 731 |
+
None,
|
| 732 |
+
[
|
| 733 |
+
# acc.iterator.toint().ir_value(),
|
| 734 |
+
Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
|
| 735 |
+
Int32(smem_desc_start_b_lo).ir_value(),
|
| 736 |
+
Int32(not zero_init).ir_value(),
|
| 737 |
+
mask[0].ir_value(),
|
| 738 |
+
mask[1].ir_value(),
|
| 739 |
+
mask[2].ir_value(),
|
| 740 |
+
mask[3].ir_value(),
|
| 741 |
+
],
|
| 742 |
+
"{\n\t"
|
| 743 |
+
".reg .pred leader_thread;\n\t"
|
| 744 |
+
".reg .pred p;\n\t"
|
| 745 |
+
".reg .b32 idesc;\n\t"
|
| 746 |
+
".reg .b32 tmem_a;\n\t"
|
| 747 |
+
".reg .b32 smem_desc_b_lo;\n\t"
|
| 748 |
+
".reg .b32 smem_desc_b_hi;\n\t"
|
| 749 |
+
".reg .b64 smem_desc_b;\n\t"
|
| 750 |
+
"elect.sync _|leader_thread, -1;\n\t"
|
| 751 |
+
f"mov.b32 idesc, {hex(idesc)};\n\t"
|
| 752 |
+
f"mov.b32 tmem_a, $1;\n\t"
|
| 753 |
+
f"mov.b32 smem_desc_b_lo, $2;\n\t"
|
| 754 |
+
f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
|
| 755 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 756 |
+
"setp.ne.b32 p, $3, 0;\n\t"
|
| 757 |
+
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t"
|
| 758 |
+
+ "".join(
|
| 759 |
+
(
|
| 760 |
+
f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
|
| 761 |
+
f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
|
| 762 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 763 |
+
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t"
|
| 764 |
+
)
|
| 765 |
+
for k in range(1, cute.size(tCrA.shape[2]))
|
| 766 |
+
)
|
| 767 |
+
+ "}\n",
|
| 768 |
+
"r,r,r,r,r,r,r,r",
|
| 769 |
+
has_side_effects=True,
|
| 770 |
+
is_align_stack=False,
|
| 771 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
@cute.jit
|
| 776 |
+
def gemm_ptx_precomputed(
|
| 777 |
+
acc_tmem_addr: Int32,
|
| 778 |
+
smem_desc_start_a: Int32, # If TS, then this is the tmem start address for A
|
| 779 |
+
smem_desc_start_b: Int32,
|
| 780 |
+
idesc: int,
|
| 781 |
+
smem_desc_base_a: Optional[int],
|
| 782 |
+
smem_desc_base_b: int,
|
| 783 |
+
tCrA_layout: cute.Layout,
|
| 784 |
+
tCrB_layout: cute.Layout,
|
| 785 |
+
mbar_ptr: Optional[cutlass.Pointer] = None,
|
| 786 |
+
mbar_phase: Optional[Int32] = None,
|
| 787 |
+
zero_init: bool | Boolean = False,
|
| 788 |
+
cta_group: int = 1,
|
| 789 |
+
) -> None:
|
| 790 |
+
# acc_tmem_addr += acc_offset
|
| 791 |
+
is_ts = const_expr(smem_desc_base_a is None)
|
| 792 |
+
num_k_tile = cute.size(tCrA_layout.shape[2])
|
| 793 |
+
if const_expr(not is_ts):
|
| 794 |
+
smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
|
| 795 |
+
else:
|
| 796 |
+
smem_desc_base_a_lo, smem_desc_a_hi = None, None
|
| 797 |
+
smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
|
| 798 |
+
|
| 799 |
+
tCrA_layout = (
|
| 800 |
+
tCrA_layout
|
| 801 |
+
if const_expr(not is_ts)
|
| 802 |
+
# else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout)
|
| 803 |
+
# currently hard-coding the width to 16
|
| 804 |
+
else cute.recast_layout(32, 16, tCrA_layout)
|
| 805 |
+
)
|
| 806 |
+
offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)]
|
| 807 |
+
offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, num_k_tile)]
|
| 808 |
+
offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)]
|
| 809 |
+
offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, num_k_tile)]
|
| 810 |
+
|
| 811 |
+
smem_desc_start_a_lo = None
|
| 812 |
+
if const_expr(not is_ts):
|
| 813 |
+
smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a)
|
| 814 |
+
# smem_desc_start_a_lo = smem_desc_start_a
|
| 815 |
+
smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b)
|
| 816 |
+
pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
|
| 817 |
+
if const_expr(not is_ts):
|
| 818 |
+
assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM"
|
| 819 |
+
llvm.inline_asm(
|
| 820 |
+
None,
|
| 821 |
+
[
|
| 822 |
+
# acc.iterator.toint().ir_value(),
|
| 823 |
+
Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
|
| 824 |
+
Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
|
| 825 |
+
Int32(not zero_init).ir_value(),
|
| 826 |
+
Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
|
| 827 |
+
],
|
| 828 |
+
"{\n\t"
|
| 829 |
+
".reg .pred leader_thread;\n\t"
|
| 830 |
+
".reg .pred p;\n\t"
|
| 831 |
+
".reg .b32 idesc;\n\t"
|
| 832 |
+
".reg .b32 tmem_acc;\n\t"
|
| 833 |
+
".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t"
|
| 834 |
+
".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
|
| 835 |
+
".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
|
| 836 |
+
".reg .b64 smem_desc_a, smem_desc_b;\n\t"
|
| 837 |
+
"elect.sync _|leader_thread, -1;\n\t"
|
| 838 |
+
f"mov.b32 idesc, {hex(idesc)};\n\t"
|
| 839 |
+
# f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
|
| 840 |
+
f"mov.b32 tmem_acc, $3;\n\t"
|
| 841 |
+
"mov.b32 smem_desc_a_lo_start, $0;\n\t"
|
| 842 |
+
"mov.b32 smem_desc_b_lo_start, $1;\n\t"
|
| 843 |
+
f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
|
| 844 |
+
f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
|
| 845 |
+
f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t"
|
| 846 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
|
| 847 |
+
"setp.ne.b32 p, $2, 0;\n\t"
|
| 848 |
+
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
|
| 849 |
+
+ "".join(
|
| 850 |
+
(
|
| 851 |
+
# f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
|
| 852 |
+
# f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
|
| 853 |
+
f"add.s32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t"
|
| 854 |
+
f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
|
| 855 |
+
f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
|
| 856 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 857 |
+
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
|
| 858 |
+
)
|
| 859 |
+
for k in range(1, num_k_tile)
|
| 860 |
+
)
|
| 861 |
+
+ "}\n",
|
| 862 |
+
# "r,r,r",
|
| 863 |
+
"r,r,r,r",
|
| 864 |
+
has_side_effects=True,
|
| 865 |
+
is_align_stack=False,
|
| 866 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 867 |
+
)
|
| 868 |
+
else:
|
| 869 |
+
input_args = [
|
| 870 |
+
Int32(cute.arch.make_warp_uniform(smem_desc_start_a)).ir_value(),
|
| 871 |
+
Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
|
| 872 |
+
Int32(not zero_init).ir_value(),
|
| 873 |
+
Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
|
| 874 |
+
]
|
| 875 |
+
if const_expr(mbar_ptr is not None):
|
| 876 |
+
assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None"
|
| 877 |
+
input_args.append(mbar_ptr.toint().ir_value())
|
| 878 |
+
input_args.append(Int32(mbar_phase).ir_value())
|
| 879 |
+
mbar_wait_str = (
|
| 880 |
+
".reg .pred P1; \n\t"
|
| 881 |
+
"LAB_WAIT: \n\t"
|
| 882 |
+
"mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t"
|
| 883 |
+
"@P1 bra DONE; \n\t"
|
| 884 |
+
"bra LAB_WAIT; \n\t"
|
| 885 |
+
"DONE: \n\t"
|
| 886 |
+
)
|
| 887 |
+
else:
|
| 888 |
+
mbar_wait_str = ""
|
| 889 |
+
llvm.inline_asm(
|
| 890 |
+
None,
|
| 891 |
+
# [
|
| 892 |
+
# # acc.iterator.toint().ir_value(),
|
| 893 |
+
# Int32(tCrA_layout[None, None, 0].iterator.toint()).ir_value(),
|
| 894 |
+
# Int32(smem_desc_start_b_lo).ir_value(),
|
| 895 |
+
# Int32(not zero_init).ir_value(),
|
| 896 |
+
# ],
|
| 897 |
+
input_args,
|
| 898 |
+
"{\n\t"
|
| 899 |
+
".reg .pred leader_thread;\n\t"
|
| 900 |
+
".reg .pred p;\n\t"
|
| 901 |
+
".reg .b32 idesc;\n\t"
|
| 902 |
+
".reg .b32 tmem_acc;\n\t"
|
| 903 |
+
".reg .b32 tmem_a;\n\t"
|
| 904 |
+
".reg .b32 smem_desc_b_lo_start;\n\t"
|
| 905 |
+
".reg .b32 smem_desc_b_lo;\n\t"
|
| 906 |
+
".reg .b32 smem_desc_b_hi;\n\t"
|
| 907 |
+
".reg .b64 smem_desc_b;\n\t"
|
| 908 |
+
"elect.sync _|leader_thread, -1;\n\t"
|
| 909 |
+
f"mov.b32 idesc, {hex(idesc)};\n\t"
|
| 910 |
+
# f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
|
| 911 |
+
f"mov.b32 tmem_acc, $3;\n\t"
|
| 912 |
+
f"mov.b32 tmem_a, $0;\n\t"
|
| 913 |
+
f"mov.b32 smem_desc_b_lo_start, $1;\n\t"
|
| 914 |
+
f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
|
| 915 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
|
| 916 |
+
"setp.ne.b32 p, $2, 0;\n\t"
|
| 917 |
+
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
|
| 918 |
+
+ "".join(
|
| 919 |
+
(
|
| 920 |
+
# f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
|
| 921 |
+
# f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
|
| 922 |
+
f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
|
| 923 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 924 |
+
# f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t"
|
| 925 |
+
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
|
| 926 |
+
)
|
| 927 |
+
for k in range(
|
| 928 |
+
1,
|
| 929 |
+
num_k_tile if const_expr(mbar_ptr is None) else num_k_tile // 4 * 3,
|
| 930 |
+
)
|
| 931 |
+
)
|
| 932 |
+
+ mbar_wait_str
|
| 933 |
+
+ (
|
| 934 |
+
"".join(
|
| 935 |
+
(
|
| 936 |
+
# f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
|
| 937 |
+
f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
|
| 938 |
+
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 939 |
+
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
|
| 940 |
+
)
|
| 941 |
+
for k in range(num_k_tile // 4 * 3, num_k_tile)
|
| 942 |
+
)
|
| 943 |
+
if const_expr(mbar_ptr is not None)
|
| 944 |
+
else ""
|
| 945 |
+
)
|
| 946 |
+
+ "}\n",
|
| 947 |
+
"r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r",
|
| 948 |
+
has_side_effects=True,
|
| 949 |
+
is_align_stack=False,
|
| 950 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 951 |
+
)
|
| 952 |
+
|
| 953 |
+
|
| 954 |
+
@cute.jit
|
| 955 |
+
def declare_ptx_smem_desc(
|
| 956 |
+
smem_desc_start_a: Int32, # If TS, then this is the tmem start address for A
|
| 957 |
+
smem_desc_base_a: Optional[int],
|
| 958 |
+
tCrA_layout: cute.Layout,
|
| 959 |
+
var_name_prefix: str = "smem_desc",
|
| 960 |
+
) -> None:
|
| 961 |
+
is_ts = const_expr(smem_desc_base_a is None)
|
| 962 |
+
num_k_tile = cute.size(tCrA_layout.shape[2])
|
| 963 |
+
smem_desc_base_a_lo, smem_desc_a_hi = None, None
|
| 964 |
+
if const_expr(not is_ts):
|
| 965 |
+
smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
|
| 966 |
+
tCrA_layout = (
|
| 967 |
+
tCrA_layout
|
| 968 |
+
if const_expr(not is_ts)
|
| 969 |
+
# else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout)
|
| 970 |
+
# currently hard-coding the width to 16
|
| 971 |
+
else cute.recast_layout(32, 16, tCrA_layout)
|
| 972 |
+
)
|
| 973 |
+
offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)]
|
| 974 |
+
smem_desc_start_a_lo = None
|
| 975 |
+
if const_expr(not is_ts):
|
| 976 |
+
smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a)
|
| 977 |
+
if const_expr(not is_ts):
|
| 978 |
+
llvm.inline_asm(
|
| 979 |
+
None,
|
| 980 |
+
[Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value()],
|
| 981 |
+
f".reg .b32 {var_name_prefix}_lo;\n\t"
|
| 982 |
+
f".reg .b64 {var_name_prefix}_<{num_k_tile}>;\n\t"
|
| 983 |
+
f"mov.b64 {var_name_prefix}_0, {{$0, {hex(smem_desc_a_hi)}}};\n\t"
|
| 984 |
+
+ "".join(
|
| 985 |
+
(
|
| 986 |
+
f"add.s32 {var_name_prefix}_lo, $0, {hex(offset_a[k])};\n\t"
|
| 987 |
+
f"mov.b64 {var_name_prefix}_{k}, {{{var_name_prefix}_lo, {hex(smem_desc_a_hi)}}};\n\t"
|
| 988 |
+
)
|
| 989 |
+
for k in range(1, num_k_tile)
|
| 990 |
+
),
|
| 991 |
+
"r",
|
| 992 |
+
has_side_effects=True,
|
| 993 |
+
is_align_stack=False,
|
| 994 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 995 |
+
)
|
| 996 |
+
|
| 997 |
+
|
| 998 |
+
@cute.jit
|
| 999 |
+
def declare_ptx_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp, var_name: str = "idesc") -> None:
|
| 1000 |
+
idesc = const_expr(sm100_desc.mma_op_to_idesc(op))
|
| 1001 |
+
llvm.inline_asm(
|
| 1002 |
+
None,
|
| 1003 |
+
[],
|
| 1004 |
+
f".reg .b32 {var_name};\n\t" # noqa
|
| 1005 |
+
f"mov.b32 {var_name}, {hex(idesc)};\n\t",
|
| 1006 |
+
constraints="",
|
| 1007 |
+
has_side_effects=True,
|
| 1008 |
+
is_align_stack=False,
|
| 1009 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 1010 |
+
)
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
@cute.jit
|
| 1014 |
+
def gemm_ptx_precomputed_varname(
|
| 1015 |
+
acc_tmem_addr: Int32,
|
| 1016 |
+
smem_desc_start_b: Int32,
|
| 1017 |
+
# idesc: int,
|
| 1018 |
+
smem_desc_base_b: int,
|
| 1019 |
+
tCrB_layout: cute.Layout,
|
| 1020 |
+
smem_var_name_prefix: str,
|
| 1021 |
+
idesc_var_name: str,
|
| 1022 |
+
smem_offset: int,
|
| 1023 |
+
zero_init: bool | Boolean = False,
|
| 1024 |
+
cta_group: int = 1,
|
| 1025 |
+
mma_kind: str = "f16",
|
| 1026 |
+
) -> None:
|
| 1027 |
+
is_ts = False
|
| 1028 |
+
num_k_tile = cute.size(tCrB_layout.shape[2])
|
| 1029 |
+
smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
|
| 1030 |
+
offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)]
|
| 1031 |
+
|
| 1032 |
+
smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b)
|
| 1033 |
+
pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
|
| 1034 |
+
if const_expr(not is_ts):
|
| 1035 |
+
llvm.inline_asm(
|
| 1036 |
+
None,
|
| 1037 |
+
[
|
| 1038 |
+
Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
|
| 1039 |
+
Int32(not zero_init).ir_value(),
|
| 1040 |
+
Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
|
| 1041 |
+
],
|
| 1042 |
+
"{\n\t"
|
| 1043 |
+
".reg .pred leader_thread;\n\t"
|
| 1044 |
+
".reg .pred p;\n\t"
|
| 1045 |
+
# ".reg .b32 idesc;\n\t"
|
| 1046 |
+
".reg .b32 tmem_acc;\n\t"
|
| 1047 |
+
".reg .b32 smem_desc_b_lo_start;\n\t"
|
| 1048 |
+
".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
|
| 1049 |
+
".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
|
| 1050 |
+
# ".reg .b64 smem_desc_b;\n\t"
|
| 1051 |
+
f".reg .b64 smem_desc_b_<{num_k_tile}>;\n\t"
|
| 1052 |
+
"elect.sync _|leader_thread, -1;\n\t"
|
| 1053 |
+
# f"mov.b32 idesc, {hex(idesc)};\n\t"
|
| 1054 |
+
# f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
|
| 1055 |
+
f"mov.b32 tmem_acc, $2;\n\t"
|
| 1056 |
+
"mov.b32 smem_desc_b_lo_start, $0;\n\t"
|
| 1057 |
+
f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
|
| 1058 |
+
f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_0;\n\t"
|
| 1059 |
+
f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
|
| 1060 |
+
f"mov.b64 {smem_var_name_prefix}_0, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
|
| 1061 |
+
f"mov.b64 smem_desc_b_0, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
|
| 1062 |
+
+ "".join(
|
| 1063 |
+
(
|
| 1064 |
+
f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t"
|
| 1065 |
+
f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
|
| 1066 |
+
f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
|
| 1067 |
+
f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
|
| 1068 |
+
f"mov.b64 smem_desc_b_{k}, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 1069 |
+
)
|
| 1070 |
+
for k in range(1, num_k_tile)
|
| 1071 |
+
)
|
| 1072 |
+
+ "setp.ne.b32 p, $1, 0;\n\t"
|
| 1073 |
+
# f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\n\t"
|
| 1074 |
+
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t"
|
| 1075 |
+
+ "".join(
|
| 1076 |
+
(
|
| 1077 |
+
# f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t"
|
| 1078 |
+
# f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
|
| 1079 |
+
# f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
|
| 1080 |
+
# f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
|
| 1081 |
+
# f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
|
| 1082 |
+
# f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\n\t"
|
| 1083 |
+
# f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\n\t"
|
| 1084 |
+
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{mma_kind} [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t"
|
| 1085 |
+
)
|
| 1086 |
+
for k in range(1, num_k_tile)
|
| 1087 |
+
)
|
| 1088 |
+
+ "}\n",
|
| 1089 |
+
"r,r,r",
|
| 1090 |
+
has_side_effects=True,
|
| 1091 |
+
is_align_stack=False,
|
| 1092 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 1093 |
+
)
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/block_info.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
import cutlass
|
| 8 |
+
import cutlass.cute as cute
|
| 9 |
+
from cutlass import Int32, const_expr
|
| 10 |
+
|
| 11 |
+
from ...src.common.seqlen_info import SeqlenInfoQK
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass(frozen=True)
|
| 15 |
+
class BlockInfo:
|
| 16 |
+
tile_m: cutlass.Constexpr[int]
|
| 17 |
+
tile_n: cutlass.Constexpr[int]
|
| 18 |
+
is_causal: cutlass.Constexpr[bool]
|
| 19 |
+
qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
|
| 20 |
+
|
| 21 |
+
@cute.jit
|
| 22 |
+
def get_n_block_min_max(
|
| 23 |
+
self,
|
| 24 |
+
seqlen_info: SeqlenInfoQK,
|
| 25 |
+
m_block: Int32,
|
| 26 |
+
split_idx: Int32 = 0,
|
| 27 |
+
num_splits: Int32 = 1,
|
| 28 |
+
) -> Tuple[Int32, Int32]:
|
| 29 |
+
n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
|
| 30 |
+
if const_expr(self.is_causal):
|
| 31 |
+
m_idx_max = (m_block + 1) * self.tile_m
|
| 32 |
+
if const_expr(self.qhead_per_kvhead_packgqa > 1):
|
| 33 |
+
m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)
|
| 34 |
+
n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
|
| 35 |
+
n_block_max = min(n_block_max, cute.ceil_div(n_idx, self.tile_n))
|
| 36 |
+
n_block_min = 0
|
| 37 |
+
if num_splits > 1:
|
| 38 |
+
num_n_blocks_per_split = (
|
| 39 |
+
Int32(0)
|
| 40 |
+
if n_block_max <= n_block_min
|
| 41 |
+
else (n_block_max - n_block_min + num_splits - 1) // num_splits
|
| 42 |
+
)
|
| 43 |
+
n_block_min = n_block_min + split_idx * num_n_blocks_per_split
|
| 44 |
+
n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max)
|
| 45 |
+
return n_block_min, n_block_max
|
| 46 |
+
|
| 47 |
+
@cute.jit
|
| 48 |
+
def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]:
|
| 49 |
+
m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m)
|
| 50 |
+
if const_expr(self.qhead_per_kvhead_packgqa > 1):
|
| 51 |
+
m_block_max = cute.ceil_div(
|
| 52 |
+
seqlen_info.seqlen_q * self.qhead_per_kvhead_packgqa, self.tile_m
|
| 53 |
+
)
|
| 54 |
+
m_block_min = 0
|
| 55 |
+
if const_expr(self.is_causal):
|
| 56 |
+
n_idx_min = n_block * self.tile_n
|
| 57 |
+
m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k
|
| 58 |
+
if const_expr(self.qhead_per_kvhead_packgqa > 1):
|
| 59 |
+
m_idx *= self.qhead_per_kvhead_packgqa
|
| 60 |
+
m_block_min = cutlass.max(m_block_min, m_idx // self.tile_m)
|
| 61 |
+
return m_block_min, m_block_max
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/copy_utils.py
ADDED
|
@@ -0,0 +1,1179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""Copy, store, and layout execution helpers.
|
| 5 |
+
|
| 6 |
+
`copy_utils.py` is the canonical owner for generic copy primitives, async
|
| 7 |
+
bulk copy orchestration, TMA copy adapters, and non-TMA store/layout helpers.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
from typing import Optional, Type, Callable
|
| 12 |
+
|
| 13 |
+
import cutlass
|
| 14 |
+
import cutlass.cute as cute
|
| 15 |
+
from cutlass import Float32, Int32, const_expr
|
| 16 |
+
from cutlass.cute.nvgpu import cpasync
|
| 17 |
+
import cutlass.utils.blackwell_helpers as sm100_utils
|
| 18 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 19 |
+
from cutlass._mlir.dialects import llvm
|
| 20 |
+
import cutlass.pipeline
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Generic Copy Primitives
|
| 24 |
+
|
| 25 |
+
@dsl_user_op
|
| 26 |
+
def cvt_copy(
|
| 27 |
+
atom: cute.CopyAtom,
|
| 28 |
+
src: cute.Tensor,
|
| 29 |
+
dst: cute.Tensor,
|
| 30 |
+
*,
|
| 31 |
+
pred: Optional[cute.Tensor] = None,
|
| 32 |
+
loc=None,
|
| 33 |
+
ip=None,
|
| 34 |
+
**kwargs,
|
| 35 |
+
) -> None:
|
| 36 |
+
assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
|
| 37 |
+
if const_expr(src.element_type != dst.element_type):
|
| 38 |
+
src_cvt = cute.make_rmem_tensor_like(src, dst.element_type, loc=loc, ip=ip)
|
| 39 |
+
src_cvt.store(src.load().to(dst.element_type))
|
| 40 |
+
src = src_cvt
|
| 41 |
+
cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dsl_user_op
|
| 45 |
+
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
|
| 46 |
+
dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip)
|
| 47 |
+
cute.autovec_copy(src, dst, loc=loc, ip=ip)
|
| 48 |
+
return dst
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dsl_user_op
|
| 52 |
+
def get_copy_atom(
|
| 53 |
+
dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
|
| 54 |
+
) -> cute.CopyAtom:
|
| 55 |
+
num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
|
| 56 |
+
copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
|
| 57 |
+
return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dsl_user_op
|
| 61 |
+
def make_tmem_copy(
|
| 62 |
+
tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None
|
| 63 |
+
) -> cute.CopyAtom:
|
| 64 |
+
num_dp, num_bits, num_rep, _ = sm100_utils.get_tmem_copy_properties(tmem_copy_atom)
|
| 65 |
+
assert num_dp == 32
|
| 66 |
+
assert num_bits == 32
|
| 67 |
+
tiler_mn = (cute.make_layout((128 * num_rep * num_wg // 32, 32), stride=(32, 1)),)
|
| 68 |
+
layout_tv = cute.make_layout(
|
| 69 |
+
((32, 4, num_wg), (num_rep, 32)), stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg))
|
| 70 |
+
)
|
| 71 |
+
return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@dsl_user_op
|
| 75 |
+
def copy(
|
| 76 |
+
src: cute.Tensor,
|
| 77 |
+
dst: cute.Tensor,
|
| 78 |
+
*,
|
| 79 |
+
pred: Optional[cute.Tensor] = None,
|
| 80 |
+
num_copy_elems: int = 1,
|
| 81 |
+
is_async: bool = False,
|
| 82 |
+
loc=None,
|
| 83 |
+
ip=None,
|
| 84 |
+
**kwargs,
|
| 85 |
+
) -> None:
|
| 86 |
+
copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
|
| 87 |
+
cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def tiled_copy_1d(
|
| 91 |
+
dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
|
| 92 |
+
) -> cute.TiledCopy:
|
| 93 |
+
num_copy_bits = num_copy_elems * dtype.width
|
| 94 |
+
copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
|
| 95 |
+
copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
|
| 96 |
+
thr_layout = cute.make_layout(num_threads)
|
| 97 |
+
val_layout = cute.make_layout(num_copy_elems)
|
| 98 |
+
return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def tiled_copy_2d(
|
| 102 |
+
dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
|
| 103 |
+
) -> cute.TiledCopy:
|
| 104 |
+
num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
|
| 105 |
+
copy_elems = num_copy_bits // dtype.width
|
| 106 |
+
copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
|
| 107 |
+
copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
|
| 108 |
+
gmem_threads_per_row = major_mode_size // copy_elems
|
| 109 |
+
assert num_threads % gmem_threads_per_row == 0
|
| 110 |
+
thr_layout = cute.make_ordered_layout(
|
| 111 |
+
(num_threads // gmem_threads_per_row, gmem_threads_per_row),
|
| 112 |
+
order=(1, 0),
|
| 113 |
+
)
|
| 114 |
+
val_layout = cute.make_layout((1, copy_elems))
|
| 115 |
+
return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@dsl_user_op
|
| 119 |
+
def atomic_add_fp32x4(
|
| 120 |
+
a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None
|
| 121 |
+
) -> None:
|
| 122 |
+
gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 123 |
+
# cache_hint = cutlass.Int64(0x12F0000000000000)
|
| 124 |
+
llvm.inline_asm(
|
| 125 |
+
None,
|
| 126 |
+
[
|
| 127 |
+
gmem_ptr_i64,
|
| 128 |
+
Float32(a).ir_value(loc=loc, ip=ip),
|
| 129 |
+
Float32(b).ir_value(loc=loc, ip=ip),
|
| 130 |
+
Float32(c).ir_value(loc=loc, ip=ip),
|
| 131 |
+
Float32(d).ir_value(loc=loc, ip=ip),
|
| 132 |
+
],
|
| 133 |
+
# [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()],
|
| 134 |
+
"{\n\t"
|
| 135 |
+
# ".reg .b128 abcd;\n\t"
|
| 136 |
+
# "mov.b128 abcd, {$1, $2, $3, $4};\n\t"
|
| 137 |
+
".reg .v4 .f32 abcd;\n\t"
|
| 138 |
+
# "mov.b128 abcd, {$1, $2, $3, $4};\n\t"
|
| 139 |
+
"mov.f32 abcd.x, $1;\n\t"
|
| 140 |
+
"mov.f32 abcd.y, $2;\n\t"
|
| 141 |
+
"mov.f32 abcd.z, $3;\n\t"
|
| 142 |
+
"mov.f32 abcd.w, $4;\n\t"
|
| 143 |
+
"red.global.add.v4.f32 [$0], abcd;\n\t"
|
| 144 |
+
# "red.global.add.L2::cache_hint.v4.f32 [$0], abcd, 0x14F0000000000000;\n\t"
|
| 145 |
+
"}\n",
|
| 146 |
+
# "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;",
|
| 147 |
+
# "red.global.add.L2::cache_hint.f32 [$0], $1, $2;",
|
| 148 |
+
"l,f,f,f,f",
|
| 149 |
+
# "l,f,l",
|
| 150 |
+
has_side_effects=True,
|
| 151 |
+
is_align_stack=False,
|
| 152 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# Store/Layout Helpers
|
| 157 |
+
|
| 158 |
+
@dsl_user_op
|
| 159 |
+
def atomic_add_i32(gmem_ptr, *, loc=None, ip=None):
|
| 160 |
+
"""Simple atomicAdd. Intended for use under a single-thread guard."""
|
| 161 |
+
result = llvm.inline_asm(
|
| 162 |
+
T.i32(),
|
| 163 |
+
[gmem_ptr.toint().ir_value(loc=loc, ip=ip)],
|
| 164 |
+
"atom.global.add.u32 $0, [$1], 1;\n",
|
| 165 |
+
"=r,l",
|
| 166 |
+
has_side_effects=True,
|
| 167 |
+
is_align_stack=False,
|
| 168 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 169 |
+
loc=loc,
|
| 170 |
+
ip=ip,
|
| 171 |
+
)
|
| 172 |
+
return Int32(result)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
@dsl_user_op
|
| 176 |
+
def atomic_add_broadcast_i32(gmem_ptr, *, loc=None, ip=None):
|
| 177 |
+
"""Lane-0 atomicAdd broadcast to the whole warp via shfl."""
|
| 178 |
+
result = llvm.inline_asm(
|
| 179 |
+
T.i32(),
|
| 180 |
+
[gmem_ptr.toint().ir_value(loc=loc, ip=ip)],
|
| 181 |
+
"{\n"
|
| 182 |
+
".reg .pred p;\n"
|
| 183 |
+
".reg .u32 lane, r;\n"
|
| 184 |
+
"mov.u32 lane, %laneid;\n"
|
| 185 |
+
"mov.u32 r, 0;\n"
|
| 186 |
+
"setp.eq.u32 p, lane, 0;\n"
|
| 187 |
+
"@p atom.global.add.u32 r, [$1], 1;\n"
|
| 188 |
+
"shfl.sync.idx.b32 r, r, 0, 31, 0xffffffff;\n"
|
| 189 |
+
"mov.u32 $0, r;\n"
|
| 190 |
+
"}\n",
|
| 191 |
+
"=r,l",
|
| 192 |
+
has_side_effects=True,
|
| 193 |
+
is_align_stack=False,
|
| 194 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 195 |
+
loc=loc,
|
| 196 |
+
ip=ip,
|
| 197 |
+
)
|
| 198 |
+
return Int32(result)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@dsl_user_op
|
| 202 |
+
def stg_128(
|
| 203 |
+
gmem_ptr: cute.Pointer,
|
| 204 |
+
v0: Float32,
|
| 205 |
+
v1: Float32,
|
| 206 |
+
v2: Float32,
|
| 207 |
+
v3: Float32,
|
| 208 |
+
*,
|
| 209 |
+
loc=None,
|
| 210 |
+
ip=None,
|
| 211 |
+
):
|
| 212 |
+
llvm.inline_asm(
|
| 213 |
+
llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]),
|
| 214 |
+
[
|
| 215 |
+
gmem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 216 |
+
Float32(v0).ir_value(loc=loc, ip=ip),
|
| 217 |
+
Float32(v1).ir_value(loc=loc, ip=ip),
|
| 218 |
+
Float32(v2).ir_value(loc=loc, ip=ip),
|
| 219 |
+
Float32(v3).ir_value(loc=loc, ip=ip),
|
| 220 |
+
],
|
| 221 |
+
"st.global.v4.f32 [$4], {$5, $6, $7, $8}; "
|
| 222 |
+
"mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; "
|
| 223 |
+
"mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;",
|
| 224 |
+
"=f,=f,=f,=f,l,f,f,f,f",
|
| 225 |
+
has_side_effects=True,
|
| 226 |
+
is_align_stack=False,
|
| 227 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 228 |
+
loc=loc,
|
| 229 |
+
ip=ip,
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@dsl_user_op
|
| 234 |
+
def stg_128_cs(
|
| 235 |
+
gmem_ptr: cute.Pointer,
|
| 236 |
+
v0: Float32,
|
| 237 |
+
v1: Float32,
|
| 238 |
+
v2: Float32,
|
| 239 |
+
v3: Float32,
|
| 240 |
+
*,
|
| 241 |
+
loc=None,
|
| 242 |
+
ip=None,
|
| 243 |
+
):
|
| 244 |
+
llvm.inline_asm(
|
| 245 |
+
llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]),
|
| 246 |
+
[
|
| 247 |
+
gmem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 248 |
+
Float32(v0).ir_value(loc=loc, ip=ip),
|
| 249 |
+
Float32(v1).ir_value(loc=loc, ip=ip),
|
| 250 |
+
Float32(v2).ir_value(loc=loc, ip=ip),
|
| 251 |
+
Float32(v3).ir_value(loc=loc, ip=ip),
|
| 252 |
+
],
|
| 253 |
+
"st.global.cs.v4.f32 [$4], {$5, $6, $7, $8}; "
|
| 254 |
+
"mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; "
|
| 255 |
+
"mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;",
|
| 256 |
+
"=f,=f,=f,=f,l,f,f,f,f",
|
| 257 |
+
has_side_effects=True,
|
| 258 |
+
is_align_stack=False,
|
| 259 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 260 |
+
loc=loc,
|
| 261 |
+
ip=ip,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
@dsl_user_op
|
| 266 |
+
def stg_64_bf16(
|
| 267 |
+
gmem_ptr: cute.Pointer,
|
| 268 |
+
v0: Float32,
|
| 269 |
+
v1: Float32,
|
| 270 |
+
v2: Float32,
|
| 271 |
+
v3: Float32,
|
| 272 |
+
*,
|
| 273 |
+
loc=None,
|
| 274 |
+
ip=None,
|
| 275 |
+
):
|
| 276 |
+
llvm.inline_asm(
|
| 277 |
+
llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]),
|
| 278 |
+
[
|
| 279 |
+
gmem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 280 |
+
Float32(v0).ir_value(loc=loc, ip=ip),
|
| 281 |
+
Float32(v1).ir_value(loc=loc, ip=ip),
|
| 282 |
+
Float32(v2).ir_value(loc=loc, ip=ip),
|
| 283 |
+
Float32(v3).ir_value(loc=loc, ip=ip),
|
| 284 |
+
],
|
| 285 |
+
"{\n"
|
| 286 |
+
".reg .b16 h0, h1, h2, h3;\n"
|
| 287 |
+
".reg .b32 p0, p1;\n"
|
| 288 |
+
"cvt.rn.bf16.f32 h0, $5;\n"
|
| 289 |
+
"cvt.rn.bf16.f32 h1, $6;\n"
|
| 290 |
+
"cvt.rn.bf16.f32 h2, $7;\n"
|
| 291 |
+
"cvt.rn.bf16.f32 h3, $8;\n"
|
| 292 |
+
"mov.b32 p0, {h0, h1};\n"
|
| 293 |
+
"mov.b32 p1, {h2, h3};\n"
|
| 294 |
+
"st.global.v2.b32 [$4], {p0, p1};\n"
|
| 295 |
+
"}\n"
|
| 296 |
+
"mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; "
|
| 297 |
+
"mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;",
|
| 298 |
+
"=f,=f,=f,=f,l,f,f,f,f",
|
| 299 |
+
has_side_effects=True,
|
| 300 |
+
is_align_stack=False,
|
| 301 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 302 |
+
loc=loc,
|
| 303 |
+
ip=ip,
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
@dsl_user_op
|
| 308 |
+
def stg_64_f16(
|
| 309 |
+
gmem_ptr: cute.Pointer,
|
| 310 |
+
v0: Float32,
|
| 311 |
+
v1: Float32,
|
| 312 |
+
v2: Float32,
|
| 313 |
+
v3: Float32,
|
| 314 |
+
*,
|
| 315 |
+
loc=None,
|
| 316 |
+
ip=None,
|
| 317 |
+
):
|
| 318 |
+
llvm.inline_asm(
|
| 319 |
+
llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]),
|
| 320 |
+
[
|
| 321 |
+
gmem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 322 |
+
Float32(v0).ir_value(loc=loc, ip=ip),
|
| 323 |
+
Float32(v1).ir_value(loc=loc, ip=ip),
|
| 324 |
+
Float32(v2).ir_value(loc=loc, ip=ip),
|
| 325 |
+
Float32(v3).ir_value(loc=loc, ip=ip),
|
| 326 |
+
],
|
| 327 |
+
"{\n"
|
| 328 |
+
".reg .f16 h0, h1, h2, h3;\n"
|
| 329 |
+
".reg .b32 p0, p1;\n"
|
| 330 |
+
"cvt.rn.f16.f32 h0, $5;\n"
|
| 331 |
+
"cvt.rn.f16.f32 h1, $6;\n"
|
| 332 |
+
"cvt.rn.f16.f32 h2, $7;\n"
|
| 333 |
+
"cvt.rn.f16.f32 h3, $8;\n"
|
| 334 |
+
"mov.b32 p0, {h0, h1};\n"
|
| 335 |
+
"mov.b32 p1, {h2, h3};\n"
|
| 336 |
+
"st.global.v2.b32 [$4], {p0, p1};\n"
|
| 337 |
+
"}\n"
|
| 338 |
+
"mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; "
|
| 339 |
+
"mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;",
|
| 340 |
+
"=f,=f,=f,=f,l,f,f,f,f",
|
| 341 |
+
has_side_effects=True,
|
| 342 |
+
is_align_stack=False,
|
| 343 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 344 |
+
loc=loc,
|
| 345 |
+
ip=ip,
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
@dsl_user_op
|
| 350 |
+
def stg_32_fp8_e4m3(
|
| 351 |
+
gmem_ptr: cute.Pointer,
|
| 352 |
+
v0: Float32,
|
| 353 |
+
v1: Float32,
|
| 354 |
+
v2: Float32,
|
| 355 |
+
v3: Float32,
|
| 356 |
+
*,
|
| 357 |
+
loc=None,
|
| 358 |
+
ip=None,
|
| 359 |
+
):
|
| 360 |
+
llvm.inline_asm(
|
| 361 |
+
llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]),
|
| 362 |
+
[
|
| 363 |
+
gmem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 364 |
+
Float32(v0).ir_value(loc=loc, ip=ip),
|
| 365 |
+
Float32(v1).ir_value(loc=loc, ip=ip),
|
| 366 |
+
Float32(v2).ir_value(loc=loc, ip=ip),
|
| 367 |
+
Float32(v3).ir_value(loc=loc, ip=ip),
|
| 368 |
+
],
|
| 369 |
+
"{\n"
|
| 370 |
+
".reg .b16 h0, h1;\n"
|
| 371 |
+
".reg .b32 p0;\n"
|
| 372 |
+
"cvt.rn.satfinite.e4m3x2.f32 h0, $6, $5;\n"
|
| 373 |
+
"cvt.rn.satfinite.e4m3x2.f32 h1, $8, $7;\n"
|
| 374 |
+
"mov.b32 p0, {h0, h1};\n"
|
| 375 |
+
"st.global.b32 [$4], p0;\n"
|
| 376 |
+
"}\n"
|
| 377 |
+
"mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; "
|
| 378 |
+
"mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000;",
|
| 379 |
+
"=f,=f,=f,=f,l,f,f,f,f",
|
| 380 |
+
has_side_effects=True,
|
| 381 |
+
is_align_stack=False,
|
| 382 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 383 |
+
loc=loc,
|
| 384 |
+
ip=ip,
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
@dsl_user_op
|
| 389 |
+
def sts_32_bf16(
|
| 390 |
+
smem_ptr: cute.Pointer,
|
| 391 |
+
v0: Float32,
|
| 392 |
+
v1: Float32,
|
| 393 |
+
*,
|
| 394 |
+
loc=None,
|
| 395 |
+
ip=None,
|
| 396 |
+
):
|
| 397 |
+
"""Store two bf16 values to shared memory as one 32-bit transaction."""
|
| 398 |
+
llvm.inline_asm(
|
| 399 |
+
None,
|
| 400 |
+
[
|
| 401 |
+
smem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 402 |
+
Float32(v0).ir_value(loc=loc, ip=ip),
|
| 403 |
+
Float32(v1).ir_value(loc=loc, ip=ip),
|
| 404 |
+
],
|
| 405 |
+
"{\n"
|
| 406 |
+
".reg .u32 sa;\n"
|
| 407 |
+
".reg .b16 h0, h1;\n"
|
| 408 |
+
".reg .b32 p0;\n"
|
| 409 |
+
"cvt.u32.u64 sa, $0;\n"
|
| 410 |
+
"cvt.rn.bf16.f32 h0, $1;\n"
|
| 411 |
+
"cvt.rn.bf16.f32 h1, $2;\n"
|
| 412 |
+
"mov.b32 p0, {h0, h1};\n"
|
| 413 |
+
"st.shared.b32 [sa], p0;\n"
|
| 414 |
+
"}\n",
|
| 415 |
+
"l,f,f",
|
| 416 |
+
has_side_effects=True,
|
| 417 |
+
is_align_stack=False,
|
| 418 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 419 |
+
loc=loc,
|
| 420 |
+
ip=ip,
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
@dsl_user_op
|
| 425 |
+
def sts_32_f16(
|
| 426 |
+
smem_ptr: cute.Pointer,
|
| 427 |
+
v0: Float32,
|
| 428 |
+
v1: Float32,
|
| 429 |
+
*,
|
| 430 |
+
loc=None,
|
| 431 |
+
ip=None,
|
| 432 |
+
):
|
| 433 |
+
"""Store two fp16 values to shared memory as one 32-bit transaction."""
|
| 434 |
+
llvm.inline_asm(
|
| 435 |
+
None,
|
| 436 |
+
[
|
| 437 |
+
smem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 438 |
+
Float32(v0).ir_value(loc=loc, ip=ip),
|
| 439 |
+
Float32(v1).ir_value(loc=loc, ip=ip),
|
| 440 |
+
],
|
| 441 |
+
"{\n"
|
| 442 |
+
".reg .u32 sa;\n"
|
| 443 |
+
".reg .f16 h0, h1;\n"
|
| 444 |
+
".reg .b32 p0;\n"
|
| 445 |
+
"cvt.u32.u64 sa, $0;\n"
|
| 446 |
+
"cvt.rn.f16.f32 h0, $1;\n"
|
| 447 |
+
"cvt.rn.f16.f32 h1, $2;\n"
|
| 448 |
+
"mov.b32 p0, {h0, h1};\n"
|
| 449 |
+
"st.shared.b32 [sa], p0;\n"
|
| 450 |
+
"}\n",
|
| 451 |
+
"l,f,f",
|
| 452 |
+
has_side_effects=True,
|
| 453 |
+
is_align_stack=False,
|
| 454 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 455 |
+
loc=loc,
|
| 456 |
+
ip=ip,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
@dsl_user_op
|
| 461 |
+
def stg_128_bf16(
|
| 462 |
+
gmem_ptr: cute.Pointer,
|
| 463 |
+
v0: Float32,
|
| 464 |
+
v1: Float32,
|
| 465 |
+
v2: Float32,
|
| 466 |
+
v3: Float32,
|
| 467 |
+
v4: Float32,
|
| 468 |
+
v5: Float32,
|
| 469 |
+
v6: Float32,
|
| 470 |
+
v7: Float32,
|
| 471 |
+
*,
|
| 472 |
+
loc=None,
|
| 473 |
+
ip=None,
|
| 474 |
+
):
|
| 475 |
+
llvm.inline_asm(
|
| 476 |
+
llvm.StructType.get_literal(
|
| 477 |
+
[T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()]
|
| 478 |
+
),
|
| 479 |
+
[
|
| 480 |
+
gmem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 481 |
+
Float32(v0).ir_value(loc=loc, ip=ip),
|
| 482 |
+
Float32(v1).ir_value(loc=loc, ip=ip),
|
| 483 |
+
Float32(v2).ir_value(loc=loc, ip=ip),
|
| 484 |
+
Float32(v3).ir_value(loc=loc, ip=ip),
|
| 485 |
+
Float32(v4).ir_value(loc=loc, ip=ip),
|
| 486 |
+
Float32(v5).ir_value(loc=loc, ip=ip),
|
| 487 |
+
Float32(v6).ir_value(loc=loc, ip=ip),
|
| 488 |
+
Float32(v7).ir_value(loc=loc, ip=ip),
|
| 489 |
+
],
|
| 490 |
+
"{\n"
|
| 491 |
+
".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n"
|
| 492 |
+
".reg .b32 p0, p1, p2, p3;\n"
|
| 493 |
+
"cvt.rn.bf16.f32 h0, $9;\n"
|
| 494 |
+
"cvt.rn.bf16.f32 h1, $10;\n"
|
| 495 |
+
"cvt.rn.bf16.f32 h2, $11;\n"
|
| 496 |
+
"cvt.rn.bf16.f32 h3, $12;\n"
|
| 497 |
+
"cvt.rn.bf16.f32 h4, $13;\n"
|
| 498 |
+
"cvt.rn.bf16.f32 h5, $14;\n"
|
| 499 |
+
"cvt.rn.bf16.f32 h6, $15;\n"
|
| 500 |
+
"cvt.rn.bf16.f32 h7, $16;\n"
|
| 501 |
+
"mov.b32 p0, {h0, h1};\n"
|
| 502 |
+
"mov.b32 p1, {h2, h3};\n"
|
| 503 |
+
"mov.b32 p2, {h4, h5};\n"
|
| 504 |
+
"mov.b32 p3, {h6, h7};\n"
|
| 505 |
+
"st.global.v4.b32 [$8], {p0, p1, p2, p3};\n"
|
| 506 |
+
"}\n"
|
| 507 |
+
"mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; "
|
| 508 |
+
"mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; "
|
| 509 |
+
"mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; "
|
| 510 |
+
"mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;",
|
| 511 |
+
"=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f",
|
| 512 |
+
has_side_effects=True,
|
| 513 |
+
is_align_stack=False,
|
| 514 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 515 |
+
loc=loc,
|
| 516 |
+
ip=ip,
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
@dsl_user_op
|
| 521 |
+
def stg_128_bf16_cs(
|
| 522 |
+
gmem_ptr: cute.Pointer,
|
| 523 |
+
v0: Float32,
|
| 524 |
+
v1: Float32,
|
| 525 |
+
v2: Float32,
|
| 526 |
+
v3: Float32,
|
| 527 |
+
v4: Float32,
|
| 528 |
+
v5: Float32,
|
| 529 |
+
v6: Float32,
|
| 530 |
+
v7: Float32,
|
| 531 |
+
*,
|
| 532 |
+
loc=None,
|
| 533 |
+
ip=None,
|
| 534 |
+
):
|
| 535 |
+
llvm.inline_asm(
|
| 536 |
+
llvm.StructType.get_literal(
|
| 537 |
+
[T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()]
|
| 538 |
+
),
|
| 539 |
+
[
|
| 540 |
+
gmem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 541 |
+
Float32(v0).ir_value(loc=loc, ip=ip),
|
| 542 |
+
Float32(v1).ir_value(loc=loc, ip=ip),
|
| 543 |
+
Float32(v2).ir_value(loc=loc, ip=ip),
|
| 544 |
+
Float32(v3).ir_value(loc=loc, ip=ip),
|
| 545 |
+
Float32(v4).ir_value(loc=loc, ip=ip),
|
| 546 |
+
Float32(v5).ir_value(loc=loc, ip=ip),
|
| 547 |
+
Float32(v6).ir_value(loc=loc, ip=ip),
|
| 548 |
+
Float32(v7).ir_value(loc=loc, ip=ip),
|
| 549 |
+
],
|
| 550 |
+
"{\n"
|
| 551 |
+
".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n"
|
| 552 |
+
".reg .b32 p0, p1, p2, p3;\n"
|
| 553 |
+
"cvt.rn.bf16.f32 h0, $9;\n"
|
| 554 |
+
"cvt.rn.bf16.f32 h1, $10;\n"
|
| 555 |
+
"cvt.rn.bf16.f32 h2, $11;\n"
|
| 556 |
+
"cvt.rn.bf16.f32 h3, $12;\n"
|
| 557 |
+
"cvt.rn.bf16.f32 h4, $13;\n"
|
| 558 |
+
"cvt.rn.bf16.f32 h5, $14;\n"
|
| 559 |
+
"cvt.rn.bf16.f32 h6, $15;\n"
|
| 560 |
+
"cvt.rn.bf16.f32 h7, $16;\n"
|
| 561 |
+
"mov.b32 p0, {h0, h1};\n"
|
| 562 |
+
"mov.b32 p1, {h2, h3};\n"
|
| 563 |
+
"mov.b32 p2, {h4, h5};\n"
|
| 564 |
+
"mov.b32 p3, {h6, h7};\n"
|
| 565 |
+
"st.global.cs.v4.b32 [$8], {p0, p1, p2, p3};\n"
|
| 566 |
+
"}\n"
|
| 567 |
+
"mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; "
|
| 568 |
+
"mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; "
|
| 569 |
+
"mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; "
|
| 570 |
+
"mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;",
|
| 571 |
+
"=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f",
|
| 572 |
+
has_side_effects=True,
|
| 573 |
+
is_align_stack=False,
|
| 574 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 575 |
+
loc=loc,
|
| 576 |
+
ip=ip,
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
@dsl_user_op
|
| 581 |
+
def stg_128_f16(
|
| 582 |
+
gmem_ptr: cute.Pointer,
|
| 583 |
+
v0: Float32,
|
| 584 |
+
v1: Float32,
|
| 585 |
+
v2: Float32,
|
| 586 |
+
v3: Float32,
|
| 587 |
+
v4: Float32,
|
| 588 |
+
v5: Float32,
|
| 589 |
+
v6: Float32,
|
| 590 |
+
v7: Float32,
|
| 591 |
+
*,
|
| 592 |
+
loc=None,
|
| 593 |
+
ip=None,
|
| 594 |
+
):
|
| 595 |
+
llvm.inline_asm(
|
| 596 |
+
llvm.StructType.get_literal(
|
| 597 |
+
[T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()]
|
| 598 |
+
),
|
| 599 |
+
[
|
| 600 |
+
gmem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 601 |
+
Float32(v0).ir_value(loc=loc, ip=ip),
|
| 602 |
+
Float32(v1).ir_value(loc=loc, ip=ip),
|
| 603 |
+
Float32(v2).ir_value(loc=loc, ip=ip),
|
| 604 |
+
Float32(v3).ir_value(loc=loc, ip=ip),
|
| 605 |
+
Float32(v4).ir_value(loc=loc, ip=ip),
|
| 606 |
+
Float32(v5).ir_value(loc=loc, ip=ip),
|
| 607 |
+
Float32(v6).ir_value(loc=loc, ip=ip),
|
| 608 |
+
Float32(v7).ir_value(loc=loc, ip=ip),
|
| 609 |
+
],
|
| 610 |
+
"{\n"
|
| 611 |
+
".reg .f16 h0, h1, h2, h3, h4, h5, h6, h7;\n"
|
| 612 |
+
".reg .b32 p0, p1, p2, p3;\n"
|
| 613 |
+
"cvt.rn.f16.f32 h0, $9;\n"
|
| 614 |
+
"cvt.rn.f16.f32 h1, $10;\n"
|
| 615 |
+
"cvt.rn.f16.f32 h2, $11;\n"
|
| 616 |
+
"cvt.rn.f16.f32 h3, $12;\n"
|
| 617 |
+
"cvt.rn.f16.f32 h4, $13;\n"
|
| 618 |
+
"cvt.rn.f16.f32 h5, $14;\n"
|
| 619 |
+
"cvt.rn.f16.f32 h6, $15;\n"
|
| 620 |
+
"cvt.rn.f16.f32 h7, $16;\n"
|
| 621 |
+
"mov.b32 p0, {h0, h1};\n"
|
| 622 |
+
"mov.b32 p1, {h2, h3};\n"
|
| 623 |
+
"mov.b32 p2, {h4, h5};\n"
|
| 624 |
+
"mov.b32 p3, {h6, h7};\n"
|
| 625 |
+
"st.global.v4.b32 [$8], {p0, p1, p2, p3};\n"
|
| 626 |
+
"}\n"
|
| 627 |
+
"mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; "
|
| 628 |
+
"mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; "
|
| 629 |
+
"mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; "
|
| 630 |
+
"mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;",
|
| 631 |
+
"=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f",
|
| 632 |
+
has_side_effects=True,
|
| 633 |
+
is_align_stack=False,
|
| 634 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 635 |
+
loc=loc,
|
| 636 |
+
ip=ip,
|
| 637 |
+
)
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
@dsl_user_op
|
| 641 |
+
def stg_128_f16_cs(
|
| 642 |
+
gmem_ptr: cute.Pointer,
|
| 643 |
+
v0: Float32,
|
| 644 |
+
v1: Float32,
|
| 645 |
+
v2: Float32,
|
| 646 |
+
v3: Float32,
|
| 647 |
+
v4: Float32,
|
| 648 |
+
v5: Float32,
|
| 649 |
+
v6: Float32,
|
| 650 |
+
v7: Float32,
|
| 651 |
+
*,
|
| 652 |
+
loc=None,
|
| 653 |
+
ip=None,
|
| 654 |
+
):
|
| 655 |
+
llvm.inline_asm(
|
| 656 |
+
llvm.StructType.get_literal(
|
| 657 |
+
[T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32(), T.f32()]
|
| 658 |
+
),
|
| 659 |
+
[
|
| 660 |
+
gmem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 661 |
+
Float32(v0).ir_value(loc=loc, ip=ip),
|
| 662 |
+
Float32(v1).ir_value(loc=loc, ip=ip),
|
| 663 |
+
Float32(v2).ir_value(loc=loc, ip=ip),
|
| 664 |
+
Float32(v3).ir_value(loc=loc, ip=ip),
|
| 665 |
+
Float32(v4).ir_value(loc=loc, ip=ip),
|
| 666 |
+
Float32(v5).ir_value(loc=loc, ip=ip),
|
| 667 |
+
Float32(v6).ir_value(loc=loc, ip=ip),
|
| 668 |
+
Float32(v7).ir_value(loc=loc, ip=ip),
|
| 669 |
+
],
|
| 670 |
+
"{\n"
|
| 671 |
+
".reg .f16 h0, h1, h2, h3, h4, h5, h6, h7;\n"
|
| 672 |
+
".reg .b32 p0, p1, p2, p3;\n"
|
| 673 |
+
"cvt.rn.f16.f32 h0, $9;\n"
|
| 674 |
+
"cvt.rn.f16.f32 h1, $10;\n"
|
| 675 |
+
"cvt.rn.f16.f32 h2, $11;\n"
|
| 676 |
+
"cvt.rn.f16.f32 h3, $12;\n"
|
| 677 |
+
"cvt.rn.f16.f32 h4, $13;\n"
|
| 678 |
+
"cvt.rn.f16.f32 h5, $14;\n"
|
| 679 |
+
"cvt.rn.f16.f32 h6, $15;\n"
|
| 680 |
+
"cvt.rn.f16.f32 h7, $16;\n"
|
| 681 |
+
"mov.b32 p0, {h0, h1};\n"
|
| 682 |
+
"mov.b32 p1, {h2, h3};\n"
|
| 683 |
+
"mov.b32 p2, {h4, h5};\n"
|
| 684 |
+
"mov.b32 p3, {h6, h7};\n"
|
| 685 |
+
"st.global.cs.v4.b32 [$8], {p0, p1, p2, p3};\n"
|
| 686 |
+
"}\n"
|
| 687 |
+
"mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; "
|
| 688 |
+
"mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; "
|
| 689 |
+
"mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; "
|
| 690 |
+
"mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000;",
|
| 691 |
+
"=f,=f,=f,=f,=f,=f,=f,=f,l,f,f,f,f,f,f,f,f",
|
| 692 |
+
has_side_effects=True,
|
| 693 |
+
is_align_stack=False,
|
| 694 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 695 |
+
loc=loc,
|
| 696 |
+
ip=ip,
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
@dsl_user_op
|
| 701 |
+
def stg_128_fp8_e4m3_cs(
|
| 702 |
+
gmem_ptr: cute.Pointer,
|
| 703 |
+
v0: Float32,
|
| 704 |
+
v1: Float32,
|
| 705 |
+
v2: Float32,
|
| 706 |
+
v3: Float32,
|
| 707 |
+
v4: Float32,
|
| 708 |
+
v5: Float32,
|
| 709 |
+
v6: Float32,
|
| 710 |
+
v7: Float32,
|
| 711 |
+
v8: Float32,
|
| 712 |
+
v9: Float32,
|
| 713 |
+
v10: Float32,
|
| 714 |
+
v11: Float32,
|
| 715 |
+
v12: Float32,
|
| 716 |
+
v13: Float32,
|
| 717 |
+
v14: Float32,
|
| 718 |
+
v15: Float32,
|
| 719 |
+
*,
|
| 720 |
+
loc=None,
|
| 721 |
+
ip=None,
|
| 722 |
+
):
|
| 723 |
+
llvm.inline_asm(
|
| 724 |
+
llvm.StructType.get_literal(
|
| 725 |
+
[
|
| 726 |
+
T.f32(),
|
| 727 |
+
T.f32(),
|
| 728 |
+
T.f32(),
|
| 729 |
+
T.f32(),
|
| 730 |
+
T.f32(),
|
| 731 |
+
T.f32(),
|
| 732 |
+
T.f32(),
|
| 733 |
+
T.f32(),
|
| 734 |
+
T.f32(),
|
| 735 |
+
T.f32(),
|
| 736 |
+
T.f32(),
|
| 737 |
+
T.f32(),
|
| 738 |
+
T.f32(),
|
| 739 |
+
T.f32(),
|
| 740 |
+
T.f32(),
|
| 741 |
+
T.f32(),
|
| 742 |
+
]
|
| 743 |
+
),
|
| 744 |
+
[
|
| 745 |
+
gmem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 746 |
+
Float32(v0).ir_value(loc=loc, ip=ip),
|
| 747 |
+
Float32(v1).ir_value(loc=loc, ip=ip),
|
| 748 |
+
Float32(v2).ir_value(loc=loc, ip=ip),
|
| 749 |
+
Float32(v3).ir_value(loc=loc, ip=ip),
|
| 750 |
+
Float32(v4).ir_value(loc=loc, ip=ip),
|
| 751 |
+
Float32(v5).ir_value(loc=loc, ip=ip),
|
| 752 |
+
Float32(v6).ir_value(loc=loc, ip=ip),
|
| 753 |
+
Float32(v7).ir_value(loc=loc, ip=ip),
|
| 754 |
+
Float32(v8).ir_value(loc=loc, ip=ip),
|
| 755 |
+
Float32(v9).ir_value(loc=loc, ip=ip),
|
| 756 |
+
Float32(v10).ir_value(loc=loc, ip=ip),
|
| 757 |
+
Float32(v11).ir_value(loc=loc, ip=ip),
|
| 758 |
+
Float32(v12).ir_value(loc=loc, ip=ip),
|
| 759 |
+
Float32(v13).ir_value(loc=loc, ip=ip),
|
| 760 |
+
Float32(v14).ir_value(loc=loc, ip=ip),
|
| 761 |
+
Float32(v15).ir_value(loc=loc, ip=ip),
|
| 762 |
+
],
|
| 763 |
+
"{\n"
|
| 764 |
+
".reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;\n"
|
| 765 |
+
".reg .b32 p0, p1, p2, p3;\n"
|
| 766 |
+
"cvt.rn.satfinite.e4m3x2.f32 h0, $18, $17;\n"
|
| 767 |
+
"cvt.rn.satfinite.e4m3x2.f32 h1, $20, $19;\n"
|
| 768 |
+
"cvt.rn.satfinite.e4m3x2.f32 h2, $22, $21;\n"
|
| 769 |
+
"cvt.rn.satfinite.e4m3x2.f32 h3, $24, $23;\n"
|
| 770 |
+
"cvt.rn.satfinite.e4m3x2.f32 h4, $26, $25;\n"
|
| 771 |
+
"cvt.rn.satfinite.e4m3x2.f32 h5, $28, $27;\n"
|
| 772 |
+
"cvt.rn.satfinite.e4m3x2.f32 h6, $30, $29;\n"
|
| 773 |
+
"cvt.rn.satfinite.e4m3x2.f32 h7, $32, $31;\n"
|
| 774 |
+
"mov.b32 p0, {h0, h1};\n"
|
| 775 |
+
"mov.b32 p1, {h2, h3};\n"
|
| 776 |
+
"mov.b32 p2, {h4, h5};\n"
|
| 777 |
+
"mov.b32 p3, {h6, h7};\n"
|
| 778 |
+
"st.global.cs.v4.b32 [$16], {p0, p1, p2, p3};\n"
|
| 779 |
+
"}\n"
|
| 780 |
+
"mov.f32 $0, 0f00000000; mov.f32 $1, 0f00000000; "
|
| 781 |
+
"mov.f32 $2, 0f00000000; mov.f32 $3, 0f00000000; "
|
| 782 |
+
"mov.f32 $4, 0f00000000; mov.f32 $5, 0f00000000; "
|
| 783 |
+
"mov.f32 $6, 0f00000000; mov.f32 $7, 0f00000000; "
|
| 784 |
+
"mov.f32 $8, 0f00000000; mov.f32 $9, 0f00000000; "
|
| 785 |
+
"mov.f32 $10, 0f00000000; mov.f32 $11, 0f00000000; "
|
| 786 |
+
"mov.f32 $12, 0f00000000; mov.f32 $13, 0f00000000; "
|
| 787 |
+
"mov.f32 $14, 0f00000000; mov.f32 $15, 0f00000000;",
|
| 788 |
+
(
|
| 789 |
+
"=f,=f,=f,=f,=f,=f,=f,=f,"
|
| 790 |
+
"=f,=f,=f,=f,=f,=f,=f,=f,"
|
| 791 |
+
"l,f,f,f,f,f,f,f,f,f,f,f,f,f,f,f,f"
|
| 792 |
+
),
|
| 793 |
+
has_side_effects=True,
|
| 794 |
+
is_align_stack=False,
|
| 795 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 796 |
+
loc=loc,
|
| 797 |
+
ip=ip,
|
| 798 |
+
)
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
def convert_layout_from_tmem16x256b_to_acc_sm90(acc_layout: cute.Layout) -> cute.Layout:
|
| 802 |
+
acc_layout_col_major = cute.make_layout(acc_layout.shape)
|
| 803 |
+
acc_layout_mn = cute.make_layout(
|
| 804 |
+
(
|
| 805 |
+
acc_layout_col_major.shape[0][0],
|
| 806 |
+
acc_layout_col_major.shape[0][1],
|
| 807 |
+
acc_layout_col_major.shape[1],
|
| 808 |
+
*acc_layout_col_major.shape[2:],
|
| 809 |
+
),
|
| 810 |
+
stride=(
|
| 811 |
+
acc_layout_col_major.stride[0][0],
|
| 812 |
+
acc_layout_col_major.stride[0][1],
|
| 813 |
+
acc_layout_col_major.stride[1],
|
| 814 |
+
*acc_layout_col_major.stride[2:],
|
| 815 |
+
),
|
| 816 |
+
)
|
| 817 |
+
return cute.composition(acc_layout, acc_layout_mn)
|
| 818 |
+
|
| 819 |
+
|
| 820 |
+
def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
|
| 821 |
+
acc_layout_col_major = cute.make_layout(acc_layout.shape)
|
| 822 |
+
acc_layout_mn = cute.make_layout(
|
| 823 |
+
(
|
| 824 |
+
(acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),
|
| 825 |
+
(
|
| 826 |
+
acc_layout_col_major.shape[0][0],
|
| 827 |
+
*acc_layout_col_major.shape[0][2:],
|
| 828 |
+
acc_layout_col_major.shape[2],
|
| 829 |
+
),
|
| 830 |
+
*acc_layout_col_major.shape[3:],
|
| 831 |
+
),
|
| 832 |
+
stride=(
|
| 833 |
+
(acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),
|
| 834 |
+
(
|
| 835 |
+
acc_layout_col_major.stride[0][0],
|
| 836 |
+
*acc_layout_col_major.stride[0][2:],
|
| 837 |
+
acc_layout_col_major.stride[2],
|
| 838 |
+
),
|
| 839 |
+
*acc_layout_col_major.stride[3:],
|
| 840 |
+
),
|
| 841 |
+
)
|
| 842 |
+
return cute.composition(acc_layout, acc_layout_mn)
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
def make_16x256b_tensor_mn_view(tensor: cute.Tensor) -> cute.Tensor:
|
| 846 |
+
layout = convert_layout_acc_mn(
|
| 847 |
+
convert_layout_from_tmem16x256b_to_acc_sm90(tensor.layout)
|
| 848 |
+
)
|
| 849 |
+
return cute.make_tensor(tensor.iterator, layout)
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
def real_col_to_stg128_fake_col(col: Int32) -> Int32:
|
| 853 |
+
nt = col // Int32(16)
|
| 854 |
+
col16 = col - nt * Int32(16)
|
| 855 |
+
pair = col16 // Int32(2)
|
| 856 |
+
rank = pair % Int32(4)
|
| 857 |
+
kv = (pair // Int32(4)) * Int32(2) + (col16 % Int32(2))
|
| 858 |
+
return nt * Int32(16) + rank * Int32(4) + kv
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
def stg128_fake_col_to_real_col(fake_col: Int32) -> Int32:
|
| 862 |
+
nt = fake_col // Int32(16)
|
| 863 |
+
fake16 = fake_col - nt * Int32(16)
|
| 864 |
+
rank = fake16 // Int32(4)
|
| 865 |
+
kv = fake16 % Int32(4)
|
| 866 |
+
return nt * Int32(16) + rank * Int32(2) + (kv // Int32(2)) * Int32(8) + (kv % Int32(2))
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
def real_col_to_stg128_half_fake_col(col: Int32) -> Int32:
|
| 870 |
+
nt = col // Int32(32)
|
| 871 |
+
col32 = col - nt * Int32(32)
|
| 872 |
+
lane = (col32 % Int32(8)) // Int32(2)
|
| 873 |
+
group = col32 // Int32(8)
|
| 874 |
+
elem = col32 % Int32(2)
|
| 875 |
+
return nt * Int32(32) + lane * Int32(8) + group * Int32(2) + elem
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
def stg128_half_fake_col_to_real_col(fake_col: Int32) -> Int32:
|
| 879 |
+
nt = fake_col // Int32(32)
|
| 880 |
+
fake32 = fake_col - nt * Int32(32)
|
| 881 |
+
lane = fake32 // Int32(8)
|
| 882 |
+
lane_slot = fake32 - lane * Int32(8)
|
| 883 |
+
group = lane_slot // Int32(2)
|
| 884 |
+
elem = lane_slot - group * Int32(2)
|
| 885 |
+
return nt * Int32(32) + group * Int32(8) + lane * Int32(2) + elem
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
def real_col_to_stg128_fp8_fake_col(col: Int32) -> Int32:
|
| 889 |
+
nt = col // Int32(64)
|
| 890 |
+
col64 = col - nt * Int32(64)
|
| 891 |
+
lane = (col64 % Int32(8)) // Int32(2)
|
| 892 |
+
group = col64 // Int32(8)
|
| 893 |
+
elem = col64 % Int32(2)
|
| 894 |
+
return nt * Int32(64) + lane * Int32(16) + group * Int32(2) + elem
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
def stg128_fp8_fake_col_to_real_col(fake_col: Int32) -> Int32:
|
| 898 |
+
nt = fake_col // Int32(64)
|
| 899 |
+
fake64 = fake_col - nt * Int32(64)
|
| 900 |
+
lane = fake64 // Int32(16)
|
| 901 |
+
lane_slot = fake64 - lane * Int32(16)
|
| 902 |
+
group = lane_slot // Int32(2)
|
| 903 |
+
elem = lane_slot - group * Int32(2)
|
| 904 |
+
return nt * Int32(64) + group * Int32(8) + lane * Int32(2) + elem
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
# Cluster & Bulk Async Ops
|
| 908 |
+
|
| 909 |
+
@dsl_user_op
|
| 910 |
+
def set_block_rank(
|
| 911 |
+
smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
|
| 912 |
+
) -> Int32:
|
| 913 |
+
"""Map the given smem pointer to the address at another CTA rank in the cluster."""
|
| 914 |
+
smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 915 |
+
return Int32(
|
| 916 |
+
llvm.inline_asm(
|
| 917 |
+
T.i32(),
|
| 918 |
+
[smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
|
| 919 |
+
"mapa.shared::cluster.u32 $0, $1, $2;",
|
| 920 |
+
"=r,r,r",
|
| 921 |
+
has_side_effects=False,
|
| 922 |
+
is_align_stack=False,
|
| 923 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 924 |
+
)
|
| 925 |
+
)
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
@dsl_user_op
|
| 929 |
+
def store_shared_remote_fp32x4(
|
| 930 |
+
a: Float32,
|
| 931 |
+
b: Float32,
|
| 932 |
+
c: Float32,
|
| 933 |
+
d: Float32,
|
| 934 |
+
smem_ptr: cute.Pointer,
|
| 935 |
+
mbar_ptr: cute.Pointer,
|
| 936 |
+
peer_cta_rank_in_cluster: Int32,
|
| 937 |
+
*,
|
| 938 |
+
loc=None,
|
| 939 |
+
ip=None,
|
| 940 |
+
) -> None:
|
| 941 |
+
remote_smem_ptr_i32 = set_block_rank(
|
| 942 |
+
smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
|
| 943 |
+
).ir_value()
|
| 944 |
+
remote_mbar_ptr_i32 = set_block_rank(
|
| 945 |
+
mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
|
| 946 |
+
).ir_value()
|
| 947 |
+
llvm.inline_asm(
|
| 948 |
+
None,
|
| 949 |
+
[
|
| 950 |
+
remote_smem_ptr_i32,
|
| 951 |
+
remote_mbar_ptr_i32,
|
| 952 |
+
Float32(a).ir_value(loc=loc, ip=ip),
|
| 953 |
+
Float32(b).ir_value(loc=loc, ip=ip),
|
| 954 |
+
Float32(c).ir_value(loc=loc, ip=ip),
|
| 955 |
+
Float32(d).ir_value(loc=loc, ip=ip),
|
| 956 |
+
],
|
| 957 |
+
"{\n\t"
|
| 958 |
+
".reg .v4 .f32 abcd;\n\t"
|
| 959 |
+
"mov.f32 abcd.x, $2;\n\t"
|
| 960 |
+
"mov.f32 abcd.y, $3;\n\t"
|
| 961 |
+
"mov.f32 abcd.z, $4;\n\t"
|
| 962 |
+
"mov.f32 abcd.w, $5;\n\t"
|
| 963 |
+
"st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, [$1];\n\t"
|
| 964 |
+
"}\n",
|
| 965 |
+
"r,r,f,f,f,f",
|
| 966 |
+
has_side_effects=True,
|
| 967 |
+
is_align_stack=False,
|
| 968 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 969 |
+
)
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
@dsl_user_op
|
| 973 |
+
def cpasync_bulk_s2cluster(
|
| 974 |
+
smem_src_ptr: cute.Pointer,
|
| 975 |
+
smem_dst_ptr: cute.Pointer,
|
| 976 |
+
mbar_ptr: cute.Pointer,
|
| 977 |
+
size: int | Int32,
|
| 978 |
+
peer_cta_rank_in_cluster: Int32,
|
| 979 |
+
*,
|
| 980 |
+
loc=None,
|
| 981 |
+
ip=None,
|
| 982 |
+
):
|
| 983 |
+
smem_src_ptr_i32 = smem_src_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 984 |
+
smem_dst_ptr_i32 = set_block_rank(
|
| 985 |
+
smem_dst_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
|
| 986 |
+
).ir_value()
|
| 987 |
+
mbar_ptr_i32 = set_block_rank(mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value()
|
| 988 |
+
llvm.inline_asm(
|
| 989 |
+
None,
|
| 990 |
+
[
|
| 991 |
+
smem_dst_ptr_i32,
|
| 992 |
+
smem_src_ptr_i32,
|
| 993 |
+
mbar_ptr_i32,
|
| 994 |
+
Int32(size).ir_value(loc=loc, ip=ip),
|
| 995 |
+
],
|
| 996 |
+
"cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [$0], [$1], $3, [$2];",
|
| 997 |
+
"r,r,r,r",
|
| 998 |
+
has_side_effects=True,
|
| 999 |
+
is_align_stack=False,
|
| 1000 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
|
| 1004 |
+
@dsl_user_op
|
| 1005 |
+
def cpasync_bulk_g2s(
|
| 1006 |
+
gmem_ptr: cute.Pointer,
|
| 1007 |
+
smem_ptr: cute.Pointer,
|
| 1008 |
+
tma_bar_ptr: cute.Pointer,
|
| 1009 |
+
size: int | Int32,
|
| 1010 |
+
*,
|
| 1011 |
+
loc=None,
|
| 1012 |
+
ip=None,
|
| 1013 |
+
):
|
| 1014 |
+
gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 1015 |
+
smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 1016 |
+
mbar_ptr_i32 = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 1017 |
+
llvm.inline_asm(
|
| 1018 |
+
None,
|
| 1019 |
+
[gmem_ptr_i64, smem_ptr_i32, mbar_ptr_i32, Int32(size).ir_value()],
|
| 1020 |
+
"cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];",
|
| 1021 |
+
"l,r,r,r",
|
| 1022 |
+
has_side_effects=True,
|
| 1023 |
+
is_align_stack=False,
|
| 1024 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 1025 |
+
)
|
| 1026 |
+
|
| 1027 |
+
|
| 1028 |
+
@dsl_user_op
|
| 1029 |
+
def cpasync_reduce_bulk_add_f32(
|
| 1030 |
+
smem_ptr: cute.Pointer,
|
| 1031 |
+
gmem_ptr: cute.Pointer,
|
| 1032 |
+
store_bytes: int | Int32,
|
| 1033 |
+
*,
|
| 1034 |
+
loc=None,
|
| 1035 |
+
ip=None,
|
| 1036 |
+
):
|
| 1037 |
+
smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
|
| 1038 |
+
# cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST
|
| 1039 |
+
llvm.inline_asm(
|
| 1040 |
+
None,
|
| 1041 |
+
[gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()],
|
| 1042 |
+
"cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;",
|
| 1043 |
+
"l,r,r",
|
| 1044 |
+
# [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()],
|
| 1045 |
+
# "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;",
|
| 1046 |
+
# "l,r,r,l",
|
| 1047 |
+
has_side_effects=True,
|
| 1048 |
+
is_align_stack=False,
|
| 1049 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 1050 |
+
)
|
| 1051 |
+
|
| 1052 |
+
|
| 1053 |
+
def cpasync_bulk_get_copy_fn(
|
| 1054 |
+
src_tensor: cute.Tensor,
|
| 1055 |
+
dst_tensor: cute.Tensor,
|
| 1056 |
+
single_stage: bool = False,
|
| 1057 |
+
**kwargs,
|
| 1058 |
+
) -> Callable:
|
| 1059 |
+
# src_is_smem = const_expr(
|
| 1060 |
+
# isinstance(src_tensor.iterator, cute.Pointer)
|
| 1061 |
+
# and src_tensor.memspace == cute.AddressSpace.smem
|
| 1062 |
+
# )
|
| 1063 |
+
group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0))
|
| 1064 |
+
group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0))
|
| 1065 |
+
# ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
|
| 1066 |
+
src = cute.group_modes(src_tensor, 0, group_rank_src)
|
| 1067 |
+
dst = cute.group_modes(dst_tensor, 0, group_rank_dst)
|
| 1068 |
+
|
| 1069 |
+
def copy_bulk(src_idx, dst_idx, **new_kwargs):
|
| 1070 |
+
size = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8)
|
| 1071 |
+
cpasync_bulk_g2s(
|
| 1072 |
+
src[None, src_idx].iterator,
|
| 1073 |
+
dst[None, dst_idx].iterator,
|
| 1074 |
+
size=size,
|
| 1075 |
+
**new_kwargs,
|
| 1076 |
+
**kwargs,
|
| 1077 |
+
)
|
| 1078 |
+
|
| 1079 |
+
def copy_bulk_single_stage(**new_kwargs):
|
| 1080 |
+
size = const_expr(cute.size(src.shape) * src.element_type.width // 8)
|
| 1081 |
+
cpasync_bulk_g2s(src.iterator, dst.iterator, size=size, **new_kwargs, **kwargs)
|
| 1082 |
+
|
| 1083 |
+
return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage
|
| 1084 |
+
|
| 1085 |
+
|
| 1086 |
+
# TMA Copy Adapters
|
| 1087 |
+
|
| 1088 |
+
def tma_get_copy_fn(
|
| 1089 |
+
atom: cute.CopyAtom,
|
| 1090 |
+
cta_coord: cute.Coord,
|
| 1091 |
+
cta_layout: cute.Layout,
|
| 1092 |
+
src_tensor: cute.Tensor,
|
| 1093 |
+
dst_tensor: cute.Tensor,
|
| 1094 |
+
filter_zeros: bool = False,
|
| 1095 |
+
single_stage: bool = False,
|
| 1096 |
+
**kwargs,
|
| 1097 |
+
) -> Callable:
|
| 1098 |
+
src_is_smem = const_expr(
|
| 1099 |
+
isinstance(src_tensor.iterator, cute.Pointer)
|
| 1100 |
+
and src_tensor.memspace == cute.AddressSpace.smem
|
| 1101 |
+
)
|
| 1102 |
+
smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
|
| 1103 |
+
group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0))
|
| 1104 |
+
group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0))
|
| 1105 |
+
# ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
|
| 1106 |
+
s, g = cpasync.tma_partition(
|
| 1107 |
+
atom,
|
| 1108 |
+
cta_coord,
|
| 1109 |
+
cta_layout,
|
| 1110 |
+
cute.group_modes(smem_tensor, 0, group_rank_smem),
|
| 1111 |
+
cute.group_modes(gmem_tensor, 0, group_rank_gmem),
|
| 1112 |
+
)
|
| 1113 |
+
if const_expr(filter_zeros):
|
| 1114 |
+
s = cute.filter_zeros(s)
|
| 1115 |
+
g = cute.filter_zeros(g)
|
| 1116 |
+
src, dst = (s, g) if src_is_smem else (g, s)
|
| 1117 |
+
|
| 1118 |
+
def copy_tma(src_idx, dst_idx, **new_kwargs):
|
| 1119 |
+
cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)
|
| 1120 |
+
|
| 1121 |
+
def copy_tma_single_stage(**new_kwargs):
|
| 1122 |
+
cute.copy(atom, src, dst, **new_kwargs, **kwargs)
|
| 1123 |
+
|
| 1124 |
+
return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g
|
| 1125 |
+
|
| 1126 |
+
|
| 1127 |
+
def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
|
| 1128 |
+
def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
|
| 1129 |
+
copy(
|
| 1130 |
+
src_idx=src_idx,
|
| 1131 |
+
dst_idx=producer_state.index,
|
| 1132 |
+
tma_bar_ptr=pipeline.producer_get_barrier(producer_state),
|
| 1133 |
+
**new_kwargs,
|
| 1134 |
+
)
|
| 1135 |
+
|
| 1136 |
+
return copy_fn
|
| 1137 |
+
|
| 1138 |
+
|
| 1139 |
+
__all__ = [
|
| 1140 |
+
"atomic_add_broadcast_i32",
|
| 1141 |
+
"atomic_add_fp32x4",
|
| 1142 |
+
"atomic_add_i32",
|
| 1143 |
+
"convert_layout_acc_mn",
|
| 1144 |
+
"convert_layout_from_tmem16x256b_to_acc_sm90",
|
| 1145 |
+
"copy",
|
| 1146 |
+
"cpasync_bulk_g2s",
|
| 1147 |
+
"cpasync_bulk_get_copy_fn",
|
| 1148 |
+
"cpasync_bulk_s2cluster",
|
| 1149 |
+
"cpasync_reduce_bulk_add_f32",
|
| 1150 |
+
"cvt_copy",
|
| 1151 |
+
"get_copy_atom",
|
| 1152 |
+
"load_s2r",
|
| 1153 |
+
"make_16x256b_tensor_mn_view",
|
| 1154 |
+
"make_tmem_copy",
|
| 1155 |
+
"real_col_to_stg128_fake_col",
|
| 1156 |
+
"real_col_to_stg128_fp8_fake_col",
|
| 1157 |
+
"real_col_to_stg128_half_fake_col",
|
| 1158 |
+
"set_block_rank",
|
| 1159 |
+
"stg128_fake_col_to_real_col",
|
| 1160 |
+
"stg128_fp8_fake_col_to_real_col",
|
| 1161 |
+
"stg128_half_fake_col_to_real_col",
|
| 1162 |
+
"stg_128",
|
| 1163 |
+
"stg_128_cs",
|
| 1164 |
+
"stg_128_bf16",
|
| 1165 |
+
"stg_128_bf16_cs",
|
| 1166 |
+
"stg_128_f16",
|
| 1167 |
+
"stg_128_f16_cs",
|
| 1168 |
+
"stg_128_fp8_e4m3_cs",
|
| 1169 |
+
"stg_32_fp8_e4m3",
|
| 1170 |
+
"stg_64_bf16",
|
| 1171 |
+
"stg_64_f16",
|
| 1172 |
+
"sts_32_bf16",
|
| 1173 |
+
"sts_32_f16",
|
| 1174 |
+
"store_shared_remote_fp32x4",
|
| 1175 |
+
"tiled_copy_1d",
|
| 1176 |
+
"tiled_copy_2d",
|
| 1177 |
+
"tma_get_copy_fn",
|
| 1178 |
+
"tma_producer_copy_fn",
|
| 1179 |
+
]
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/cute_dsl_utils.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import pathlib
|
| 7 |
+
import time
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
from functools import partial, lru_cache
|
| 10 |
+
from dataclasses import dataclass, fields
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger("minimax")
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from triton.tools.disasm import extract
|
| 18 |
+
except ImportError:
|
| 19 |
+
extract = None
|
| 20 |
+
|
| 21 |
+
import cutlass
|
| 22 |
+
import cutlass.cute as cute
|
| 23 |
+
from cutlass.base_dsl.typing import JitArgument
|
| 24 |
+
from cutlass.cutlass_dsl import NumericMeta
|
| 25 |
+
from cutlass.cute.runtime import from_dlpack
|
| 26 |
+
|
| 27 |
+
StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
|
| 31 |
+
cute_compile_og = cute.compile
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
torch2cute_dtype_map = {
|
| 35 |
+
torch.float16: cutlass.Float16,
|
| 36 |
+
torch.bfloat16: cutlass.BFloat16,
|
| 37 |
+
torch.float32: cutlass.Float32,
|
| 38 |
+
torch.float8_e4m3fn: cutlass.Float8E4M3FN,
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@lru_cache
|
| 43 |
+
def get_max_active_clusters(cluster_size):
|
| 44 |
+
return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@lru_cache
|
| 48 |
+
def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
|
| 49 |
+
return torch.cuda.get_device_capability(device)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass
|
| 53 |
+
class ArgumentsBase(JitArgument):
|
| 54 |
+
def __c_pointers__(self):
|
| 55 |
+
all_fields = [getattr(self, field.name) for field in fields(self)]
|
| 56 |
+
non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
|
| 57 |
+
c_ptrs = []
|
| 58 |
+
for obj in non_constexpr_fields:
|
| 59 |
+
if hasattr(obj, "__c_pointers__"):
|
| 60 |
+
c_ptrs.extend(obj.__c_pointers__())
|
| 61 |
+
return c_ptrs
|
| 62 |
+
|
| 63 |
+
def __get_mlir_types__(self):
|
| 64 |
+
all_fields = [getattr(self, field.name) for field in fields(self)]
|
| 65 |
+
non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
|
| 66 |
+
types, self._values_pos = [], []
|
| 67 |
+
for obj in non_constexpr_fields:
|
| 68 |
+
if hasattr(obj, "__get_mlir_types__"):
|
| 69 |
+
obj_types = obj.__get_mlir_types__()
|
| 70 |
+
types.extend(obj_types)
|
| 71 |
+
self._values_pos.append(len(obj_types))
|
| 72 |
+
else:
|
| 73 |
+
self._values_pos.append(0)
|
| 74 |
+
return types
|
| 75 |
+
|
| 76 |
+
def __new_from_mlir_values__(self, values):
|
| 77 |
+
all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
|
| 78 |
+
constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
|
| 79 |
+
non_constexpr_fields = {
|
| 80 |
+
n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
|
| 81 |
+
}
|
| 82 |
+
for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
|
| 83 |
+
non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
|
| 84 |
+
values = values[n_items:]
|
| 85 |
+
return self.__class__(**non_constexpr_fields, **constexpr_fields)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def load_cubin_module_data_patched(cubin_data, filepath):
|
| 89 |
+
pathlib.Path(filepath).write_bytes(cubin_data)
|
| 90 |
+
return load_cubin_module_data_og(cubin_data)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def cute_compile_patched(*args, **kwargs):
|
| 94 |
+
"""A patched version of cute.compile.
|
| 95 |
+
|
| 96 |
+
Behaviour:
|
| 97 |
+
- Dumps SASS to a file if ``CUTE_CUBIN_PATH`` is set.
|
| 98 |
+
- Logs JIT compile wall time at DEBUG level via the ``minimax`` logger,
|
| 99 |
+
tagged with the kernel's class name when available. Enable with
|
| 100 |
+
``logging.getLogger("minimax").setLevel(logging.DEBUG)`` or env
|
| 101 |
+
``MINIMAX_LOG_COMPILE=1``; this is how we distinguish a slow JIT
|
| 102 |
+
(~2-10s) from a kernel hang (>30s = deadlock, see CLAUDE.md).
|
| 103 |
+
"""
|
| 104 |
+
cubin_path = os.getenv("CUTE_CUBIN_PATH", None)
|
| 105 |
+
if cubin_path is not None:
|
| 106 |
+
cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
|
| 107 |
+
load_cubin_module_data_patched, filepath=cubin_path
|
| 108 |
+
)
|
| 109 |
+
kernel_obj = args[0] if args else kwargs.get("op")
|
| 110 |
+
kernel_name = type(kernel_obj).__name__ if kernel_obj is not None else "<unknown>"
|
| 111 |
+
t0 = time.time()
|
| 112 |
+
output = cute_compile_og(*args, **kwargs)
|
| 113 |
+
dt = time.time() - t0
|
| 114 |
+
logger.debug("[%s] compiled in %.1fs", kernel_name, dt)
|
| 115 |
+
if cubin_path is not None:
|
| 116 |
+
cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
|
| 117 |
+
if extract is not None:
|
| 118 |
+
sass = extract(cubin_path, None)
|
| 119 |
+
pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
|
| 120 |
+
return output
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
if os.getenv("MINIMAX_LOG_COMPILE", "0") == "1":
|
| 124 |
+
if not logger.handlers:
|
| 125 |
+
_h = logging.StreamHandler()
|
| 126 |
+
_h.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s"))
|
| 127 |
+
logger.addHandler(_h)
|
| 128 |
+
logger.setLevel(logging.DEBUG)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# Monkey-patch cute.compile so every JIT compile across the repo gets timed
|
| 132 |
+
# without touching individual call sites. Idempotent: only patches once.
|
| 133 |
+
if cute.compile is not cute_compile_patched:
|
| 134 |
+
cute.compile = cute_compile_patched
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def assume_strides_aligned(t):
|
| 138 |
+
"""Assume all strides except the last are divisible by 128 bits.
|
| 139 |
+
|
| 140 |
+
Python int strides (e.g., stride=0 from GQA expand) are kept as-is
|
| 141 |
+
since they're static and don't need alignment assumptions.
|
| 142 |
+
"""
|
| 143 |
+
divby = 128 // t.element_type.width
|
| 144 |
+
strides = tuple(s if isinstance(s, int) else cute.assume(s, divby=divby) for s in t.stride[:-1])
|
| 145 |
+
return (*strides, t.stride[-1])
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def assume_tensor_aligned(t):
|
| 149 |
+
"""Rebuild a tensor with 128-bit aligned stride assumptions. Passes through None."""
|
| 150 |
+
if t is None:
|
| 151 |
+
return None
|
| 152 |
+
return cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=assume_strides_aligned(t)))
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True):
|
| 156 |
+
"""Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1."""
|
| 157 |
+
tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
|
| 158 |
+
if fully_dynamic:
|
| 159 |
+
return tensor.mark_layout_dynamic()
|
| 160 |
+
if leading_dim == -1:
|
| 161 |
+
leading_dim = t.ndim - 1
|
| 162 |
+
return tensor.mark_layout_dynamic(leading_dim=leading_dim)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def to_cute_aux_tensor(t, enable_tvm_ffi=True):
|
| 166 |
+
"""Convert torch tensor to cute tensor for TVM FFI, tailored to FlexAttention aux tensors.
|
| 167 |
+
This allows the user to specify alignment and leading dimension for aux tensors used in
|
| 168 |
+
custom score_mod callables.
|
| 169 |
+
"""
|
| 170 |
+
assumed_align: int = getattr(t, "__assumed_align__", None)
|
| 171 |
+
leading_dim: int = getattr(t, "__leading_dim__", None)
|
| 172 |
+
fully_dynamic: bool = leading_dim is None
|
| 173 |
+
|
| 174 |
+
return to_cute_tensor(
|
| 175 |
+
t,
|
| 176 |
+
assumed_align=assumed_align,
|
| 177 |
+
leading_dim=leading_dim,
|
| 178 |
+
fully_dynamic=fully_dynamic,
|
| 179 |
+
enable_tvm_ffi=enable_tvm_ffi,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]:
|
| 184 |
+
"""Return tuple of bools indicating which dims have stride=0 (broadcast).
|
| 185 |
+
|
| 186 |
+
This is useful for compile keys since CuTe's mark_layout_dynamic() keeps
|
| 187 |
+
stride=0 as static, meaning kernels compiled with different broadcast
|
| 188 |
+
patterns are not interchangeable.
|
| 189 |
+
"""
|
| 190 |
+
return tuple(s == 0 for s in tensor.stride())
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/fast_math.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
import cutlass
|
| 5 |
+
import cutlass.cute as cute
|
| 6 |
+
from cutlass import Int32
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@cute.jit
|
| 10 |
+
def clz(x: Int32) -> Int32:
|
| 11 |
+
# for i in cutlass.range_constexpr(32):
|
| 12 |
+
# if (1 << (31 - i)) & x:
|
| 13 |
+
# return Int32(i)
|
| 14 |
+
# return Int32(32)
|
| 15 |
+
# Early exit is not supported yet
|
| 16 |
+
res = Int32(32)
|
| 17 |
+
done = False
|
| 18 |
+
for i in cutlass.range(32):
|
| 19 |
+
if ((1 << (31 - i)) & x) and not done:
|
| 20 |
+
res = Int32(i)
|
| 21 |
+
done = True
|
| 22 |
+
return res
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/mask.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
from typing import Callable, Optional, TypeAlias
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
import cutlass
|
| 8 |
+
import cutlass.cute as cute
|
| 9 |
+
from cutlass import Float32, Int32, Uint32, const_expr
|
| 10 |
+
|
| 11 |
+
from ...src.common import utils as utils
|
| 12 |
+
from ...src.common.seqlen_info import SeqlenInfoQK
|
| 13 |
+
|
| 14 |
+
MaskGenFn: TypeAlias = Callable[[int], Uint32]
|
| 15 |
+
MASK_R2P_CHUNK_SIZE: int = 32
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@cute.jit
|
| 19 |
+
def r2p_bitmask_below(limit: Int32, s: int) -> Uint32:
|
| 20 |
+
m = max((s + 1) * MASK_R2P_CHUNK_SIZE - limit, 0)
|
| 21 |
+
return utils.shr_u32(Uint32(0xFFFFFFFF), Uint32(m))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@cute.jit
|
| 25 |
+
def r2p_bitmask_above(limit: Int32, s: int) -> Uint32:
|
| 26 |
+
n = max(limit - s * MASK_R2P_CHUNK_SIZE, 0)
|
| 27 |
+
return utils.shl_u32(Uint32(0xFFFFFFFF), Uint32(n))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@cute.jit
|
| 31 |
+
def mask_r2p_lambda(
|
| 32 |
+
X: cute.Tensor,
|
| 33 |
+
mask_gen_fn: cutlass.Constexpr[MaskGenFn],
|
| 34 |
+
rank1: bool = False,
|
| 35 |
+
) -> None:
|
| 36 |
+
ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape))
|
| 37 |
+
for s in cutlass.range_constexpr(cute.ceil_div(ncol, MASK_R2P_CHUNK_SIZE)):
|
| 38 |
+
mask = mask_gen_fn(s)
|
| 39 |
+
for i in cutlass.range_constexpr(min(MASK_R2P_CHUNK_SIZE, ncol - s * MASK_R2P_CHUNK_SIZE)):
|
| 40 |
+
in_bound = cutlass.Boolean(mask & (Uint32(1) << i))
|
| 41 |
+
c = s * MASK_R2P_CHUNK_SIZE + i
|
| 42 |
+
if const_expr(rank1):
|
| 43 |
+
X[c] = X[c] if in_bound else -Float32.inf
|
| 44 |
+
else:
|
| 45 |
+
for r in cutlass.range_constexpr(cute.size(X.shape[0])):
|
| 46 |
+
X[r, c] = X[r, c] if in_bound else -Float32.inf
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@cute.jit
|
| 50 |
+
def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32:
|
| 51 |
+
return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass(frozen=True)
|
| 55 |
+
class AttentionMask:
|
| 56 |
+
tile_m: cutlass.Constexpr[int]
|
| 57 |
+
tile_n: cutlass.Constexpr[int]
|
| 58 |
+
seqlen_info: SeqlenInfoQK
|
| 59 |
+
qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
|
| 60 |
+
swap_AB: cutlass.Constexpr[bool] = False
|
| 61 |
+
|
| 62 |
+
@property
|
| 63 |
+
def seqlen_q(self) -> Int32:
|
| 64 |
+
return self.seqlen_info.seqlen_q
|
| 65 |
+
|
| 66 |
+
@property
|
| 67 |
+
def seqlen_k(self) -> Int32:
|
| 68 |
+
return self.seqlen_info.seqlen_k
|
| 69 |
+
|
| 70 |
+
@cute.jit
|
| 71 |
+
def apply_mask_sm100(
|
| 72 |
+
self,
|
| 73 |
+
acc_S: cute.Tensor,
|
| 74 |
+
tScS_t2r: cute.Tensor,
|
| 75 |
+
m_block: Int32,
|
| 76 |
+
n_block: Int32,
|
| 77 |
+
mask_seqlen: cutlass.Constexpr[bool],
|
| 78 |
+
mask_causal: cutlass.Constexpr[bool],
|
| 79 |
+
row_idx: Optional[Int32] = None,
|
| 80 |
+
kv_valid_cols: Optional[Int32] = None,
|
| 81 |
+
kv_block_col_start: Optional[Int32] = None,
|
| 82 |
+
) -> None:
|
| 83 |
+
if const_expr(not mask_seqlen and not mask_causal):
|
| 84 |
+
return
|
| 85 |
+
|
| 86 |
+
col_limit = Int32(self.tile_n)
|
| 87 |
+
if const_expr(mask_seqlen):
|
| 88 |
+
if const_expr(kv_valid_cols is not None):
|
| 89 |
+
col_limit = kv_valid_cols
|
| 90 |
+
else:
|
| 91 |
+
col_limit = self.seqlen_k - n_block * Int32(self.tile_n)
|
| 92 |
+
|
| 93 |
+
if const_expr(mask_causal):
|
| 94 |
+
if const_expr(row_idx is None):
|
| 95 |
+
row_axis = 0 if const_expr(not self.swap_AB) else 1
|
| 96 |
+
row_idx_cur = tScS_t2r[0][row_axis] + m_block * Int32(self.tile_m)
|
| 97 |
+
if const_expr(self.qhead_per_kvhead_packgqa > 1):
|
| 98 |
+
row_idx_cur = row_idx_cur // Int32(self.qhead_per_kvhead_packgqa)
|
| 99 |
+
else:
|
| 100 |
+
row_idx_cur = row_idx
|
| 101 |
+
if const_expr(kv_block_col_start is not None):
|
| 102 |
+
block_col_start = kv_block_col_start
|
| 103 |
+
else:
|
| 104 |
+
block_col_start = n_block * Int32(self.tile_n)
|
| 105 |
+
causal_col_limit = (
|
| 106 |
+
row_idx_cur + self.seqlen_k - self.seqlen_q
|
| 107 |
+
- block_col_start + Int32(1)
|
| 108 |
+
)
|
| 109 |
+
col_limit = (
|
| 110 |
+
cutlass.min(col_limit, causal_col_limit)
|
| 111 |
+
if const_expr(mask_seqlen)
|
| 112 |
+
else causal_col_limit
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
if col_limit < Int32(self.tile_n):
|
| 116 |
+
mask_r2p_lambda(
|
| 117 |
+
acc_S,
|
| 118 |
+
lambda s: r2p_bitmask_below(col_limit, s),
|
| 119 |
+
rank1=True,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
@cute.jit
|
| 123 |
+
def apply_mask_sm100_transposed(
|
| 124 |
+
self,
|
| 125 |
+
acc_S: cute.Tensor,
|
| 126 |
+
tScS_t2r: cute.Tensor,
|
| 127 |
+
t0ScS_t2r: cute.Tensor,
|
| 128 |
+
m_block: cutlass.Int32,
|
| 129 |
+
n_block: cutlass.Int32,
|
| 130 |
+
mask_seqlen: cutlass.Constexpr,
|
| 131 |
+
mask_causal: cutlass.Constexpr,
|
| 132 |
+
is_full_block: bool = False,
|
| 133 |
+
check_m_boundary: bool = True,
|
| 134 |
+
valid_tok_count: Optional[Int32] = None,
|
| 135 |
+
q_idx_tile: Optional[cute.Tensor] = None,
|
| 136 |
+
masked_tok_count: Optional[Int32] = None,
|
| 137 |
+
) -> None:
|
| 138 |
+
del is_full_block, check_m_boundary
|
| 139 |
+
del t0ScS_t2r
|
| 140 |
+
row_axis = 0 if const_expr(not self.swap_AB) else 1
|
| 141 |
+
col_axis = 1 if const_expr(not self.swap_AB) else 0
|
| 142 |
+
|
| 143 |
+
if const_expr(valid_tok_count is not None):
|
| 144 |
+
kv_block_col_start = n_block * Int32(self.tile_n)
|
| 145 |
+
causal_q_offset = self.seqlen_k - self.seqlen_q
|
| 146 |
+
nfrag = const_expr(cute.size(acc_S.shape))
|
| 147 |
+
for i in cutlass.range(nfrag, unroll_full=True):
|
| 148 |
+
row_idx = tScS_t2r[i][row_axis]
|
| 149 |
+
tok_idx = row_idx // Int32(self.qhead_per_kvhead_packgqa)
|
| 150 |
+
acc_S[i] = -Float32.inf if tok_idx >= valid_tok_count else acc_S[i]
|
| 151 |
+
if const_expr(mask_seqlen):
|
| 152 |
+
kv_idx = kv_block_col_start + tScS_t2r[i][col_axis]
|
| 153 |
+
acc_S[i] = -Float32.inf if kv_idx >= self.seqlen_k else acc_S[i]
|
| 154 |
+
if const_expr(mask_causal):
|
| 155 |
+
if const_expr(q_idx_tile is not None):
|
| 156 |
+
causal_tok_count = (
|
| 157 |
+
masked_tok_count
|
| 158 |
+
if const_expr(masked_tok_count is not None)
|
| 159 |
+
else Int32(0)
|
| 160 |
+
)
|
| 161 |
+
if tok_idx < causal_tok_count:
|
| 162 |
+
q_idx = q_idx_tile[tok_idx]
|
| 163 |
+
kv_idx = kv_block_col_start + tScS_t2r[i][col_axis]
|
| 164 |
+
acc_S[i] = (
|
| 165 |
+
-Float32.inf if kv_idx > q_idx + causal_q_offset else acc_S[i]
|
| 166 |
+
)
|
| 167 |
+
return
|
| 168 |
+
|
| 169 |
+
thr_col_offset = tScS_t2r[0][col_axis]
|
| 170 |
+
seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
|
| 171 |
+
|
| 172 |
+
if const_expr(not mask_causal):
|
| 173 |
+
if const_expr(mask_seqlen) and seqlenk_col_limit <= 0:
|
| 174 |
+
for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
|
| 175 |
+
acc_S[i] = -cutlass.Float32.inf
|
| 176 |
+
return
|
| 177 |
+
|
| 178 |
+
thr_row_offset = tScS_t2r[0][row_axis]
|
| 179 |
+
seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset
|
| 180 |
+
row_limit_top = seqlenq_row_limit - seqlenk_col_limit
|
| 181 |
+
if const_expr(mask_seqlen) and seqlenk_col_limit <= 0:
|
| 182 |
+
row_limit_top = self.tile_m
|
| 183 |
+
num_rep = cute.size(tScS_t2r, mode=[0])
|
| 184 |
+
row_limit = row_to_r2p_idx(row_limit_top, num_rep, 2)
|
| 185 |
+
mask_r2p_lambda(
|
| 186 |
+
acc_S,
|
| 187 |
+
lambda s: r2p_bitmask_above(row_limit, s),
|
| 188 |
+
rank1=True,
|
| 189 |
+
)
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/mma_sm100_desc.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
#
|
| 4 |
+
# The bit-field encodings, enum values, and descriptor layout below mirror the
|
| 5 |
+
# SM100 tcgen05 MMA instruction descriptor as documented and
|
| 6 |
+
# implemented in NVIDIA CUTLASS (BSD-3-Clause). The numeric values MUST stay
|
| 7 |
+
# identical to the hardware/ISA encodings; see the "Third-party licenses"
|
| 8 |
+
# section of README.md at the repo root for attribution.
|
| 9 |
+
|
| 10 |
+
from enum import IntEnum
|
| 11 |
+
|
| 12 |
+
import cutlass
|
| 13 |
+
import cutlass.cute as cute
|
| 14 |
+
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
# Enumerations that match the HW encodings (values MUST stay identical)
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Major(IntEnum): # matrix "layout" in the ISA docs
|
| 21 |
+
K = 0
|
| 22 |
+
MN = 1
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ScaleIn(IntEnum): # negate flags
|
| 26 |
+
One = 0
|
| 27 |
+
Neg = 1
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Saturate(IntEnum):
|
| 31 |
+
False_ = 0
|
| 32 |
+
True_ = 1
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class CFormat(IntEnum): # 2-bit field (bits 4-5)
|
| 36 |
+
F16 = 0
|
| 37 |
+
F32 = 1
|
| 38 |
+
S32 = 2
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class F16F32Format(IntEnum): # 3-bit field (A/B element type)
|
| 42 |
+
F16 = 0
|
| 43 |
+
BF16 = 1
|
| 44 |
+
TF32 = 2
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class S8Format(IntEnum):
|
| 48 |
+
UINT8 = 0
|
| 49 |
+
INT8 = 1
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class MXF8F6F4Format(IntEnum):
|
| 53 |
+
E4M3 = 0
|
| 54 |
+
E5M2 = 1
|
| 55 |
+
E2M3 = 3
|
| 56 |
+
E3M2 = 4
|
| 57 |
+
E2M1 = 5
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class MaxShift(IntEnum):
|
| 61 |
+
NoShift = 0
|
| 62 |
+
MaxShift8 = 1
|
| 63 |
+
MaxShift16 = 2
|
| 64 |
+
MaxShift32 = 3
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ---------------------------------------------------------------------------
|
| 68 |
+
# CUTLASS-type -> encoding helpers
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def to_UMMA_format(cutlass_type) -> int:
|
| 73 |
+
"""
|
| 74 |
+
Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B.
|
| 75 |
+
"""
|
| 76 |
+
if cutlass_type is cutlass.Int8:
|
| 77 |
+
return S8Format.INT8
|
| 78 |
+
# Unsigned 8-bit (if available in your CUTLASS build)
|
| 79 |
+
if cutlass_type is cutlass.Uint8:
|
| 80 |
+
return S8Format.UINT8
|
| 81 |
+
# FP-16 / BF-16
|
| 82 |
+
if cutlass_type is cutlass.Float16:
|
| 83 |
+
return F16F32Format.F16
|
| 84 |
+
if cutlass_type is cutlass.BFloat16:
|
| 85 |
+
return F16F32Format.BF16
|
| 86 |
+
# TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits)
|
| 87 |
+
if cutlass_type is cutlass.TFloat32:
|
| 88 |
+
return F16F32Format.TF32
|
| 89 |
+
# Float-8 / Float-6 / Float-4
|
| 90 |
+
if cutlass_type is cutlass.Float8E4M3FN:
|
| 91 |
+
return MXF8F6F4Format.E4M3
|
| 92 |
+
if cutlass_type is cutlass.Float8E5M2:
|
| 93 |
+
return MXF8F6F4Format.E5M2
|
| 94 |
+
raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def to_C_format(cutlass_type) -> int:
|
| 98 |
+
"""
|
| 99 |
+
Map a CUTLASS scalar class to the 2-bit accumulator encoding.
|
| 100 |
+
"""
|
| 101 |
+
if cutlass_type is cutlass.Float16:
|
| 102 |
+
return CFormat.F16
|
| 103 |
+
if cutlass_type is cutlass.Float32:
|
| 104 |
+
return CFormat.F32
|
| 105 |
+
if cutlass_type is cutlass.Int32:
|
| 106 |
+
return CFormat.S32
|
| 107 |
+
raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
# The constructor – accepts only CUTLASS scalar classes
|
| 112 |
+
# ---------------------------------------------------------------------------
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def make_instr_desc(
|
| 116 |
+
a_type, # CUTLASS scalar class, e.g. cutlass.Int8
|
| 117 |
+
b_type,
|
| 118 |
+
c_type,
|
| 119 |
+
M: int, # 64, 128 or 256
|
| 120 |
+
N: int, # 8 … 256 (multiple of 8)
|
| 121 |
+
a_major: Major,
|
| 122 |
+
b_major: Major,
|
| 123 |
+
a_neg: ScaleIn = ScaleIn.One,
|
| 124 |
+
b_neg: ScaleIn = ScaleIn.One,
|
| 125 |
+
c_sat: Saturate = Saturate.False_,
|
| 126 |
+
is_sparse: bool = False,
|
| 127 |
+
max_shift: MaxShift = MaxShift.NoShift,
|
| 128 |
+
) -> int:
|
| 129 |
+
"""
|
| 130 |
+
Build the 32-bit instruction descriptor for SM100 MMA.
|
| 131 |
+
All matrix/accumulator **types must be CUTLASS scalar classes** –
|
| 132 |
+
passing integers is forbidden.
|
| 133 |
+
"""
|
| 134 |
+
# --- encode element formats -------------------------------------------------
|
| 135 |
+
a_fmt = int(to_UMMA_format(a_type))
|
| 136 |
+
b_fmt = int(to_UMMA_format(b_type))
|
| 137 |
+
c_fmt = int(to_C_format(c_type))
|
| 138 |
+
is_f8f6f4 = a_type in (cutlass.Float8E4M3FN, cutlass.Float8E5M2)
|
| 139 |
+
|
| 140 |
+
# --- range checks on M/N -----------------------------------------------------
|
| 141 |
+
if M not in (64, 128, 256):
|
| 142 |
+
raise ValueError("M must be 64, 128 or 256")
|
| 143 |
+
if N < 8 or N > 256 or (N & 7):
|
| 144 |
+
raise ValueError("N must be a multiple of 8 in the range 8…256")
|
| 145 |
+
|
| 146 |
+
m_dim = M >> 4 # 5-bit field
|
| 147 |
+
n_dim = N >> 3 # 6-bit field
|
| 148 |
+
|
| 149 |
+
# fmt: off
|
| 150 |
+
# --- pack the bit-fields -----------------------------------------------------
|
| 151 |
+
desc = 0
|
| 152 |
+
desc |= (0 & 0x3) << 0 # sparse_id2 (always 0 here)
|
| 153 |
+
desc |= (int(is_sparse) & 0x1) << 2 # sparse_flag
|
| 154 |
+
desc |= (int(c_sat) & 0x1) << 3 # saturate
|
| 155 |
+
desc |= (c_fmt & 0x3) << 4 # c_format
|
| 156 |
+
desc |= (a_fmt & 0x7) << 7 # a_format
|
| 157 |
+
desc |= (b_fmt & 0x7) << 10 # b_format
|
| 158 |
+
desc |= (int(a_neg) & 0x1) << 13 # a_negate
|
| 159 |
+
desc |= (int(b_neg) & 0x1) << 14 # b_negate
|
| 160 |
+
desc |= (int(a_major) & 0x1) << 15 # a_major
|
| 161 |
+
desc |= (int(b_major) & 0x1) << 16 # b_major
|
| 162 |
+
desc |= (n_dim & 0x3F) << 17 # n_dim (6 bits)
|
| 163 |
+
# CUTLASS' tcgen05 lowering sets bit 23 for dense f8f6f4 MMAs; keep this
|
| 164 |
+
# descriptor aligned with generated/reference SM100 FP8 kernels.
|
| 165 |
+
desc |= (int(is_f8f6f4) & 0x1) << 23
|
| 166 |
+
desc |= (m_dim & 0x1F) << 24 # m_dim (5 bits)
|
| 167 |
+
desc |= (int(max_shift) & 0x3) << 30 # max_shift (2 bits)
|
| 168 |
+
# fmt: on
|
| 169 |
+
|
| 170 |
+
return desc & 0xFFFF_FFFF # ensure 32-bit result
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp):
|
| 174 |
+
return make_instr_desc(
|
| 175 |
+
op.a_dtype,
|
| 176 |
+
op.b_dtype,
|
| 177 |
+
op.acc_dtype,
|
| 178 |
+
op.shape_mnk[0],
|
| 179 |
+
op.shape_mnk[1],
|
| 180 |
+
Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN,
|
| 181 |
+
Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class LayoutType(IntEnum): # occupies the top-3 bits [61:64)
|
| 186 |
+
SWIZZLE_NONE = 0 # (a.k.a. "INTERLEAVE" in older docs)
|
| 187 |
+
SWIZZLE_128B_BASE32B = 1
|
| 188 |
+
SWIZZLE_128B = 2
|
| 189 |
+
SWIZZLE_64B = 4
|
| 190 |
+
SWIZZLE_32B = 6
|
| 191 |
+
# values 3,5,7 are reserved / illegal for UMMA
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# ---------------------------------------------------------------------------
|
| 195 |
+
# Helpers – figure out the SWIZZLE_* family from the tensor layout
|
| 196 |
+
# ---------------------------------------------------------------------------
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _layout_type(swizzle: cute.Swizzle) -> LayoutType:
|
| 200 |
+
B, M, S = swizzle.num_bits, swizzle.num_base, swizzle.num_shift
|
| 201 |
+
|
| 202 |
+
if M == 4: # Swizzle<*,4,3>
|
| 203 |
+
if S != 3:
|
| 204 |
+
raise ValueError("Unexpected swizzle shift – want S==3 for M==4")
|
| 205 |
+
return {
|
| 206 |
+
0: LayoutType.SWIZZLE_NONE,
|
| 207 |
+
1: LayoutType.SWIZZLE_32B,
|
| 208 |
+
2: LayoutType.SWIZZLE_64B,
|
| 209 |
+
3: LayoutType.SWIZZLE_128B,
|
| 210 |
+
}[B] # KeyError ⇒ invalid B→ raise
|
| 211 |
+
if M == 5: # Swizzle<2,5,2> (the only legal triple for M==5)
|
| 212 |
+
if (B, S) != (2, 2):
|
| 213 |
+
raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B")
|
| 214 |
+
return LayoutType.SWIZZLE_128B_BASE32B
|
| 215 |
+
|
| 216 |
+
# Any other (M,B,S) triple is not a UMMA-legal shared-memory layout
|
| 217 |
+
raise ValueError("Unsupported swizzle triple for UMMA smem descriptor")
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int:
|
| 221 |
+
"""
|
| 222 |
+
Convert a 2-D *shared-memory* Cute layout into the SM100 64-bit
|
| 223 |
+
smem-descriptor, without the smem start address.
|
| 224 |
+
layout must correspond to layout of an uint128 tensor.
|
| 225 |
+
"""
|
| 226 |
+
# ------------------------------------------------------------------ meta
|
| 227 |
+
layout_type = _layout_type(swizzle) # resolve SWIZZLE_* family
|
| 228 |
+
|
| 229 |
+
VERSION = 1 # bits 46–47
|
| 230 |
+
LBO_MODE = 0 # bit 52
|
| 231 |
+
BASE_OFFSET = 0 # bits 49–51 (CUTLASS always 0)
|
| 232 |
+
|
| 233 |
+
# ---------------------------------------------------------- strides (units: uint128_t = 16 B)
|
| 234 |
+
swizzle_atom_mn_size = {
|
| 235 |
+
LayoutType.SWIZZLE_NONE: 1,
|
| 236 |
+
LayoutType.SWIZZLE_32B: 2,
|
| 237 |
+
LayoutType.SWIZZLE_64B: 4,
|
| 238 |
+
LayoutType.SWIZZLE_128B: 8,
|
| 239 |
+
LayoutType.SWIZZLE_128B_BASE32B: 8,
|
| 240 |
+
}[layout_type]
|
| 241 |
+
|
| 242 |
+
if major is Major.MN:
|
| 243 |
+
swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8
|
| 244 |
+
canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size))
|
| 245 |
+
if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):
|
| 246 |
+
raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.")
|
| 247 |
+
stride_00 = canonical_layout.stride[0][0]
|
| 248 |
+
if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1:
|
| 249 |
+
raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.")
|
| 250 |
+
stride_10 = canonical_layout.stride[1][0]
|
| 251 |
+
if stride_10 != swizzle_atom_mn_size:
|
| 252 |
+
raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.")
|
| 253 |
+
stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1]
|
| 254 |
+
if layout_type is LayoutType.SWIZZLE_NONE:
|
| 255 |
+
stride_byte_offset, leading_byte_offset = stride_01, stride_11
|
| 256 |
+
else:
|
| 257 |
+
stride_byte_offset, leading_byte_offset = stride_11, stride_01
|
| 258 |
+
else:
|
| 259 |
+
if layout_type == LayoutType.SWIZZLE_128B_BASE32B:
|
| 260 |
+
raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K")
|
| 261 |
+
if not cute.size(layout.shape[0]) % 8 == 0:
|
| 262 |
+
raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.")
|
| 263 |
+
canonical_layout = cute.logical_divide(layout, (8, 2))
|
| 264 |
+
if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):
|
| 265 |
+
raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.")
|
| 266 |
+
stride_00 = canonical_layout.stride[0][0]
|
| 267 |
+
if stride_00 != swizzle_atom_mn_size:
|
| 268 |
+
raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.")
|
| 269 |
+
stride_10 = canonical_layout.stride[1][0]
|
| 270 |
+
if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1:
|
| 271 |
+
raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.")
|
| 272 |
+
stride_01 = canonical_layout.stride[0][1]
|
| 273 |
+
stride_byte_offset, leading_byte_offset = stride_01, stride_10
|
| 274 |
+
|
| 275 |
+
# ------------------------------------------------------------------ pack
|
| 276 |
+
desc = 0
|
| 277 |
+
# leading_byte_offset_ [16:30)
|
| 278 |
+
desc |= (leading_byte_offset & 0x3FFF) << 16
|
| 279 |
+
# stride_byte_offset_ [32:46)
|
| 280 |
+
desc |= (stride_byte_offset & 0x3FFF) << 32
|
| 281 |
+
# version_ [46:48)
|
| 282 |
+
desc |= (VERSION & 0x3) << 46
|
| 283 |
+
# base_offset_ [49:52)
|
| 284 |
+
desc |= (BASE_OFFSET & 0x7) << 49
|
| 285 |
+
# lbo_mode_ [52:53)
|
| 286 |
+
desc |= (LBO_MODE & 0x1) << 52
|
| 287 |
+
# layout_type_ [61:64)
|
| 288 |
+
desc |= (int(layout_type) & 0x7) << 61
|
| 289 |
+
|
| 290 |
+
return desc & 0xFFFF_FFFF_FFFF_FFFF # force 64-bit width
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32:
|
| 294 |
+
# 14 bits, remove 4 LSB (bits 0-13 in desc)
|
| 295 |
+
return (start_addr.toint() & 0x3FFFF) >> 4
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def smem_desc_base_from_tensor(sA: cute.Tensor, major: Major) -> int:
|
| 299 |
+
sA_swizzle = sA.iterator.type.swizzle_type
|
| 300 |
+
return make_smem_desc_base(
|
| 301 |
+
cute.recast_layout(128, sA.element_type.width, sA.layout[0]),
|
| 302 |
+
sA_swizzle,
|
| 303 |
+
major,
|
| 304 |
+
)
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/named_barrier.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
import enum
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class NamedBarrierFwdSm100(enum.IntEnum):
|
| 8 |
+
Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
|
| 9 |
+
TmemPtr = enum.auto()
|
| 10 |
+
SoftmaxStatsW0 = enum.auto()
|
| 11 |
+
SoftmaxStatsW1 = enum.auto()
|
| 12 |
+
SoftmaxStatsW2 = enum.auto()
|
| 13 |
+
SoftmaxStatsW3 = enum.auto()
|
| 14 |
+
SoftmaxStatsW4 = enum.auto()
|
| 15 |
+
SoftmaxStatsW5 = enum.auto()
|
| 16 |
+
SoftmaxStatsW6 = enum.auto()
|
| 17 |
+
SoftmaxStatsW7 = enum.auto()
|
| 18 |
+
LoadWG = enum.auto()
|
| 19 |
+
StoreEpilogue = enum.auto()
|
| 20 |
+
KvLoad = enum.auto()
|
| 21 |
+
KvDequantK = enum.auto()
|
| 22 |
+
KvDequantV = enum.auto()
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/pack_gqa.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""PackGQA primitives for GQA (grouped-query attention) tile layouts.
|
| 5 |
+
|
| 6 |
+
Contains:
|
| 7 |
+
- ``pack_gqa_layout`` / ``unpack_gqa_layout``: fold/unfold ``qhead_per_kvhead``
|
| 8 |
+
into the seqlen dimension of a tensor layout (zero-copy view).
|
| 9 |
+
- ``PackGQA``: base class with ``compute_ptr`` / ``load_Q`` / ``store_LSE`` /
|
| 10 |
+
``store_O`` helpers for kernels that treat ``(qhead_per_kvhead × seqlen_q)``
|
| 11 |
+
as a single packed row dimension.
|
| 12 |
+
- ``PackGQAComb``: subclass used by the K2 combine kernel; adds ``load_LSE``
|
| 13 |
+
for coalesced GMEM→SMEM async copies when LSE_partial is laid out with H_q
|
| 14 |
+
innermost (stride-1).
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Optional
|
| 19 |
+
|
| 20 |
+
import cutlass
|
| 21 |
+
import cutlass.cute as cute
|
| 22 |
+
from cutlass import Float32, Int32, const_expr
|
| 23 |
+
from cutlass.cute import FastDivmodDivisor
|
| 24 |
+
|
| 25 |
+
from ...quack import layout_utils
|
| 26 |
+
|
| 27 |
+
from . import utils
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx):
|
| 31 |
+
"""Reshape a tensor to fold qhead_per_kvhead into the seqlen dimension (mode 0).
|
| 32 |
+
|
| 33 |
+
The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1)
|
| 34 |
+
are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept
|
| 35 |
+
as-is (e.g. batch).
|
| 36 |
+
|
| 37 |
+
For Q/O tensors (head_idx=2):
|
| 38 |
+
(seqlen_q, headdim, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...)
|
| 39 |
+
For LSE tensors (head_idx=1):
|
| 40 |
+
(seqlen_q, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...)
|
| 41 |
+
"""
|
| 42 |
+
head_stride = T.stride[head_idx]
|
| 43 |
+
shape_packed = (
|
| 44 |
+
(qhead_per_kvhead, T.shape[0]),
|
| 45 |
+
*[T.shape[i] for i in range(1, head_idx)],
|
| 46 |
+
nheads_kv,
|
| 47 |
+
*[T.shape[i] for i in range(head_idx + 1, len(T.shape))],
|
| 48 |
+
)
|
| 49 |
+
stride_packed = (
|
| 50 |
+
(head_stride, T.stride[0]),
|
| 51 |
+
*[T.stride[i] for i in range(1, head_idx)],
|
| 52 |
+
head_stride * qhead_per_kvhead,
|
| 53 |
+
*[T.stride[i] for i in range(head_idx + 1, len(T.shape))],
|
| 54 |
+
)
|
| 55 |
+
return cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def unpack_gqa_layout(T, qhead_per_kvhead, head_idx):
|
| 59 |
+
"""Reverse of pack_gqa_layout: unfold qhead_per_kvhead from the seqlen dimension (mode 0).
|
| 60 |
+
|
| 61 |
+
The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1)
|
| 62 |
+
are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept
|
| 63 |
+
as-is (e.g. batch).
|
| 64 |
+
|
| 65 |
+
For Q/O tensors (head_idx=2):
|
| 66 |
+
((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) -> (seqlen_q, headdim, nheads, batch, ...)
|
| 67 |
+
For LSE tensors (head_idx=1):
|
| 68 |
+
((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) -> (seqlen_q, nheads, batch, ...)
|
| 69 |
+
"""
|
| 70 |
+
seqlen_stride = T.stride[0][1]
|
| 71 |
+
head_stride = T.stride[0][0]
|
| 72 |
+
shape_unpacked = (
|
| 73 |
+
T.shape[0][1],
|
| 74 |
+
*[T.shape[i] for i in range(1, head_idx)],
|
| 75 |
+
T.shape[head_idx] * qhead_per_kvhead,
|
| 76 |
+
*[T.shape[i] for i in range(head_idx + 1, len(T.shape))],
|
| 77 |
+
)
|
| 78 |
+
stride_unpacked = (
|
| 79 |
+
seqlen_stride,
|
| 80 |
+
*[T.stride[i] for i in range(1, head_idx)],
|
| 81 |
+
head_stride,
|
| 82 |
+
*[T.stride[i] for i in range(head_idx + 1, len(T.shape))],
|
| 83 |
+
)
|
| 84 |
+
return cute.make_tensor(T.iterator, cute.make_layout(shape_unpacked, stride=stride_unpacked))
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@dataclass
|
| 88 |
+
class PackGQA:
|
| 89 |
+
m_block_size: cutlass.Constexpr[int]
|
| 90 |
+
head_dim_padded: cutlass.Constexpr[int]
|
| 91 |
+
check_hdim_oob: cutlass.Constexpr[bool]
|
| 92 |
+
qhead_per_kvhead: cutlass.Constexpr[bool]
|
| 93 |
+
|
| 94 |
+
@cute.jit
|
| 95 |
+
def compute_ptr(
|
| 96 |
+
self,
|
| 97 |
+
tensor: cute.Tensor,
|
| 98 |
+
cRows: cute.Tensor,
|
| 99 |
+
tidx: cutlass.Int32,
|
| 100 |
+
block: cutlass.Int32,
|
| 101 |
+
threads_per_row: cutlass.Constexpr[int],
|
| 102 |
+
num_threads: cutlass.Constexpr[int],
|
| 103 |
+
):
|
| 104 |
+
num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row)
|
| 105 |
+
tPrPtr = cute.make_rmem_tensor(num_ptr_per_thread, cutlass.Int64)
|
| 106 |
+
for i in cutlass.range_constexpr(num_ptr_per_thread):
|
| 107 |
+
row = i * num_threads + cRows[tidx % threads_per_row][0]
|
| 108 |
+
idx = block * self.m_block_size + row
|
| 109 |
+
m_idx = idx // self.qhead_per_kvhead
|
| 110 |
+
h_idx = idx - m_idx * self.qhead_per_kvhead
|
| 111 |
+
tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint()
|
| 112 |
+
return tPrPtr
|
| 113 |
+
|
| 114 |
+
@cute.jit
|
| 115 |
+
def load_Q(
|
| 116 |
+
self,
|
| 117 |
+
mQ: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim)
|
| 118 |
+
sQ: cute.Tensor, # (m_block_size, head_dim_padded)
|
| 119 |
+
gmem_tiled_copy: cute.TiledCopy,
|
| 120 |
+
tidx: cutlass.Int32,
|
| 121 |
+
block: cutlass.Int32,
|
| 122 |
+
seqlen: cutlass.Int32,
|
| 123 |
+
):
|
| 124 |
+
gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
|
| 125 |
+
cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
|
| 126 |
+
tQsQ = gmem_thr_copy.partition_D(sQ)
|
| 127 |
+
tQcQ = gmem_thr_copy.partition_S(cQ)
|
| 128 |
+
t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ)
|
| 129 |
+
tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1])
|
| 130 |
+
tQcQ_row = tQcQ[0, None, 0]
|
| 131 |
+
threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]
|
| 132 |
+
assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
|
| 133 |
+
num_threads = gmem_tiled_copy.size
|
| 134 |
+
tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads)
|
| 135 |
+
for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
|
| 136 |
+
q_ptr_i64 = utils.shuffle_sync(
|
| 137 |
+
tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row
|
| 138 |
+
)
|
| 139 |
+
q_gmem_ptr = cute.make_ptr(
|
| 140 |
+
mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
|
| 141 |
+
)
|
| 142 |
+
if (
|
| 143 |
+
t0QcQ[0, m, 0][0]
|
| 144 |
+
< seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0]
|
| 145 |
+
):
|
| 146 |
+
mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,))
|
| 147 |
+
elems_per_load = cute.size(tQsQ.shape[0][0])
|
| 148 |
+
mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,))
|
| 149 |
+
for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])):
|
| 150 |
+
ki = tQcQ[0, 0, k][1] // elems_per_load
|
| 151 |
+
cute.copy(
|
| 152 |
+
gmem_thr_copy,
|
| 153 |
+
mQ_cur_copy[None, ki],
|
| 154 |
+
tQsQ[None, m, k],
|
| 155 |
+
pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
@cute.jit
|
| 159 |
+
def store_LSE(
|
| 160 |
+
self,
|
| 161 |
+
mLSE: cute.Tensor, # (qhead_per_kvhead, seqlen_q)
|
| 162 |
+
tLSErLSE: cute.Tensor, # (m_block_size, head_dim_padded)
|
| 163 |
+
tiled_mma: cute.TiledMma,
|
| 164 |
+
tidx: cutlass.Int32,
|
| 165 |
+
block: cutlass.Int32,
|
| 166 |
+
seqlen: cutlass.Int32,
|
| 167 |
+
):
|
| 168 |
+
thr_mma = tiled_mma.get_slice(tidx)
|
| 169 |
+
caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
|
| 170 |
+
taccOcO = thr_mma.partition_C(caccO)
|
| 171 |
+
taccOcO_row = layout_utils.reshape_acc_to_mn(taccOcO)[None, 0]
|
| 172 |
+
assert cute.size(tLSErLSE) == cute.size(taccOcO_row)
|
| 173 |
+
threads_per_row = tiled_mma.tv_layout_C.shape[0][0]
|
| 174 |
+
assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
|
| 175 |
+
assert cute.size(tLSErLSE) <= threads_per_row
|
| 176 |
+
num_threads = tiled_mma.size
|
| 177 |
+
tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads)
|
| 178 |
+
for m in cutlass.range_constexpr(cute.size(tLSErLSE)):
|
| 179 |
+
lse_ptr_i64 = utils.shuffle_sync(
|
| 180 |
+
tPrLSEPtr[m // threads_per_row],
|
| 181 |
+
m % threads_per_row,
|
| 182 |
+
width=threads_per_row,
|
| 183 |
+
)
|
| 184 |
+
lse_gmem_ptr = cute.make_ptr(
|
| 185 |
+
mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4
|
| 186 |
+
)
|
| 187 |
+
row = block * self.m_block_size + taccOcO_row[m][0]
|
| 188 |
+
# Only the thread corresponding to column 0 writes out the lse to gmem
|
| 189 |
+
if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead:
|
| 190 |
+
mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,))
|
| 191 |
+
mLSE_copy[0] = tLSErLSE[m]
|
| 192 |
+
|
| 193 |
+
@cute.jit
|
| 194 |
+
def store_O(
|
| 195 |
+
self,
|
| 196 |
+
mO: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim)
|
| 197 |
+
tOrO: cute.Tensor, # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy
|
| 198 |
+
gmem_tiled_copy: cute.TiledCopy,
|
| 199 |
+
tidx: cutlass.Int32,
|
| 200 |
+
block: cutlass.Int32,
|
| 201 |
+
seqlen: cutlass.Int32,
|
| 202 |
+
):
|
| 203 |
+
gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
|
| 204 |
+
cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
|
| 205 |
+
tOcO = gmem_thr_copy.partition_S(cO)
|
| 206 |
+
t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO)
|
| 207 |
+
tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
|
| 208 |
+
tOcO_row = tOcO[0, None, 0]
|
| 209 |
+
threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]
|
| 210 |
+
assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
|
| 211 |
+
num_threads = gmem_tiled_copy.size
|
| 212 |
+
tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads)
|
| 213 |
+
for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
|
| 214 |
+
o_ptr_i64 = utils.shuffle_sync(
|
| 215 |
+
tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row
|
| 216 |
+
)
|
| 217 |
+
o_gmem_ptr = cute.make_ptr(
|
| 218 |
+
mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
|
| 219 |
+
)
|
| 220 |
+
if (
|
| 221 |
+
t0OcO[0, m, 0][0]
|
| 222 |
+
< seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0]
|
| 223 |
+
):
|
| 224 |
+
mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,))
|
| 225 |
+
elems_per_load = cute.size(tOrO.shape[0][0])
|
| 226 |
+
mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,))
|
| 227 |
+
for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])):
|
| 228 |
+
ki = tOcO[0, 0, k][1] // elems_per_load
|
| 229 |
+
cute.copy(
|
| 230 |
+
gmem_thr_copy,
|
| 231 |
+
tOrO[None, m, k],
|
| 232 |
+
mO_cur_copy[None, ki],
|
| 233 |
+
pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
@dataclass
|
| 238 |
+
class PackGQAComb(PackGQA):
|
| 239 |
+
"""PackGQA subclass for the K2 combine kernel.
|
| 240 |
+
|
| 241 |
+
Inherits ``compute_ptr`` / ``load_Q`` / ``store_LSE`` / ``store_O`` from
|
| 242 |
+
``PackGQA``. Adds ``load_LSE`` for coalesced GMEM→SMEM async copies when
|
| 243 |
+
LSE_partial is laid out with H_q innermost.
|
| 244 |
+
|
| 245 |
+
K2 combine treats each query head independently (no GQA grouping in combine
|
| 246 |
+
itself), so ``qhead_per_kvhead`` is set to ``num_heads_q`` by the caller —
|
| 247 |
+
all heads are folded into one "group" per Sq position.
|
| 248 |
+
"""
|
| 249 |
+
|
| 250 |
+
@cute.jit
|
| 251 |
+
def load_LSE(
|
| 252 |
+
self,
|
| 253 |
+
mLSE_partial: cute.Tensor,
|
| 254 |
+
# Packed layout after caller-side reshape:
|
| 255 |
+
# shape ((qhead_per_kvhead, seqlen_q), num_splits)
|
| 256 |
+
# stride ((1, qhead_per_kvhead), ...)
|
| 257 |
+
# — H_q is the innermost (stride-1) element of the packed first dim.
|
| 258 |
+
sLSE: cute.Tensor,
|
| 259 |
+
# SMEM destination: ``(topk, m_block_size)`` fp32.
|
| 260 |
+
topk: cutlass.Constexpr[int],
|
| 261 |
+
# Explicit topk so the identity tensor shape is a plain int,
|
| 262 |
+
# avoiding compound-shape traps from sLSE.shape[0] after tile_to_shape.
|
| 263 |
+
gmem_tiled_copy: cute.TiledCopy,
|
| 264 |
+
tidx: Int32,
|
| 265 |
+
block: Int32,
|
| 266 |
+
num_splits: Int32,
|
| 267 |
+
seqlen: Int32,
|
| 268 |
+
num_heads_divmod: FastDivmodDivisor,
|
| 269 |
+
mCounter: Optional[cute.Tensor] = None,
|
| 270 |
+
batch_idx: Optional[Int32] = None,
|
| 271 |
+
qhead_per_kvhead: Int32 = Int32(1),
|
| 272 |
+
# divmod for ``m_pos = idx // qhead_per_kvhead``; passed explicitly so
|
| 273 |
+
# caller controls whether the divisor is constexpr or a runtime value.
|
| 274 |
+
):
|
| 275 |
+
"""Coalesced GMEM→SMEM async load of LSE_partial for one tile.
|
| 276 |
+
|
| 277 |
+
For each (split, row) slot this thread owns in the tile, compute the
|
| 278 |
+
GMEM coordinate ``(h_pos, m_pos)`` via PackGQA divmod and copy one fp32.
|
| 279 |
+
Out-of-bounds rows (``m_pos >= seqlen``) and splits (``si >= num_splits``)
|
| 280 |
+
are filled with ``-inf`` so they flow cleanly through downstream reductions.
|
| 281 |
+
|
| 282 |
+
Coalescing: adjacent thread rows correspond to adjacent ``h_pos`` values
|
| 283 |
+
(head varies fast under ``divmod(idx, qhead_per_kvhead)``), which map to
|
| 284 |
+
adjacent GMEM addresses when H_q is stride-1 — one sector per warp.
|
| 285 |
+
"""
|
| 286 |
+
gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
|
| 287 |
+
cLSE = cute.make_identity_tensor((topk, self.m_block_size))
|
| 288 |
+
tLSEcLSE = gmem_thr_copy.partition_S(cLSE)
|
| 289 |
+
tLSEsLSE = gmem_thr_copy.partition_D(sLSE)
|
| 290 |
+
|
| 291 |
+
for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
|
| 292 |
+
mi = tLSEcLSE[0, 0, m][1]
|
| 293 |
+
idx = block * self.m_block_size + mi
|
| 294 |
+
m_pos, h_pos = divmod(idx, num_heads_divmod)
|
| 295 |
+
|
| 296 |
+
if m_pos < seqlen:
|
| 297 |
+
row_count = (
|
| 298 |
+
mCounter[batch_idx, m_pos, h_pos // qhead_per_kvhead]
|
| 299 |
+
if const_expr(mCounter is not None)
|
| 300 |
+
else num_splits
|
| 301 |
+
)
|
| 302 |
+
for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
|
| 303 |
+
si = tLSEcLSE[0, s, 0][0]
|
| 304 |
+
if si < num_splits and si < row_count:
|
| 305 |
+
# Build a 1-element GMEM tensor at ((h_pos, m_pos), si),
|
| 306 |
+
# matching PackGQA.store_LSE's ptr pattern so cute.copy
|
| 307 |
+
# receives a proper Tensor, not a scalar.
|
| 308 |
+
src_ptr_i64 = utils.elem_pointer(
|
| 309 |
+
mLSE_partial, ((h_pos, m_pos), si)).toint()
|
| 310 |
+
src_ptr = cute.make_ptr(
|
| 311 |
+
Float32, src_ptr_i64,
|
| 312 |
+
cute.AddressSpace.gmem, assumed_align=4,
|
| 313 |
+
)
|
| 314 |
+
src_t = cute.make_tensor(src_ptr, (1,))
|
| 315 |
+
cute.copy(gmem_thr_copy, src_t, tLSEsLSE[None, s, m])
|
| 316 |
+
else:
|
| 317 |
+
tLSEsLSE[None, s, m].fill(-Float32.inf)
|
| 318 |
+
else:
|
| 319 |
+
for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
|
| 320 |
+
tLSEsLSE[None, s, m].fill(-Float32.inf)
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/paged_kv.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
import cutlass
|
| 7 |
+
import cutlass.cute as cute
|
| 8 |
+
from cutlass import Int32, const_expr
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass(frozen=True)
|
| 12 |
+
class PagedKVManager:
|
| 13 |
+
mPageTable: cute.Tensor
|
| 14 |
+
page_size: cutlass.Constexpr[int]
|
| 15 |
+
n_block_size: cutlass.Constexpr[int]
|
| 16 |
+
|
| 17 |
+
@staticmethod
|
| 18 |
+
def create(
|
| 19 |
+
mPageTable: cute.Tensor,
|
| 20 |
+
*,
|
| 21 |
+
page_size: int,
|
| 22 |
+
n_block_size: int,
|
| 23 |
+
):
|
| 24 |
+
if page_size != n_block_size:
|
| 25 |
+
raise ValueError(
|
| 26 |
+
f"page_size ({page_size}) must equal blk_kv ({n_block_size})"
|
| 27 |
+
)
|
| 28 |
+
return PagedKVManager(
|
| 29 |
+
mPageTable,
|
| 30 |
+
page_size=page_size,
|
| 31 |
+
n_block_size=n_block_size,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
@cute.jit
|
| 35 |
+
def logical_length(
|
| 36 |
+
self,
|
| 37 |
+
batch_idx: Int32,
|
| 38 |
+
num_kv_blocks: Int32,
|
| 39 |
+
mSeqUsedK=None,
|
| 40 |
+
) -> Int32:
|
| 41 |
+
if const_expr(mSeqUsedK is not None):
|
| 42 |
+
return mSeqUsedK[batch_idx]
|
| 43 |
+
return num_kv_blocks * Int32(self.n_block_size)
|
| 44 |
+
|
| 45 |
+
@cute.jit
|
| 46 |
+
def valid_cols_in_block(
|
| 47 |
+
self,
|
| 48 |
+
batch_idx: Int32,
|
| 49 |
+
kv_block_idx: Int32,
|
| 50 |
+
num_kv_blocks: Int32,
|
| 51 |
+
mSeqUsedK=None,
|
| 52 |
+
) -> Int32:
|
| 53 |
+
seqlen_k = self.logical_length(batch_idx, num_kv_blocks, mSeqUsedK)
|
| 54 |
+
block_start = kv_block_idx * Int32(self.n_block_size)
|
| 55 |
+
remaining = seqlen_k - block_start
|
| 56 |
+
remaining = cutlass.max(remaining, Int32(0))
|
| 57 |
+
return cutlass.min(remaining, Int32(self.n_block_size))
|
| 58 |
+
|
| 59 |
+
@cute.jit
|
| 60 |
+
def physical_block_index(
|
| 61 |
+
self,
|
| 62 |
+
batch_idx: Int32,
|
| 63 |
+
kv_block_idx: Int32,
|
| 64 |
+
) -> Int32:
|
| 65 |
+
return self.mPageTable[batch_idx, kv_block_idx]
|
| 66 |
+
|
| 67 |
+
__all__ = ["PagedKVManager"]
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/pipeline.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
# import math
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
|
| 8 |
+
import cutlass.cute as cute
|
| 9 |
+
from cutlass import Boolean, Int32, const_expr
|
| 10 |
+
from cutlass.cutlass_dsl import if_generate, dsl_user_op
|
| 11 |
+
from cutlass.pipeline import PipelineState
|
| 12 |
+
from cutlass.pipeline import PipelineUserType
|
| 13 |
+
from cutlass.pipeline import NamedBarrier as NamedBarrierOg
|
| 14 |
+
from cutlass.pipeline import PipelineAsync as PipelineAsyncOg
|
| 15 |
+
from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
|
| 16 |
+
from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
|
| 17 |
+
from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg
|
| 18 |
+
from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg
|
| 19 |
+
import cutlass.pipeline as cutlass_pipeline
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def make_pipeline_state(type: PipelineUserType, stages: int):
|
| 23 |
+
"""Compatibility wrapper for FA-style helpers now vendored into src.common."""
|
| 24 |
+
return cutlass_pipeline.make_pipeline_state(type, stages)
|
| 25 |
+
|
| 26 |
+
@dataclass(frozen=True)
|
| 27 |
+
class NamedBarrier(NamedBarrierOg):
|
| 28 |
+
@staticmethod
|
| 29 |
+
def create(*args, **kwargs):
|
| 30 |
+
obj = NamedBarrierOg.create(*args, **kwargs)
|
| 31 |
+
# Can't assign to __class__ directly since the dataclass is frozen
|
| 32 |
+
object.__setattr__(obj, "__class__", NamedBarrier)
|
| 33 |
+
return obj
|
| 34 |
+
|
| 35 |
+
@dsl_user_op
|
| 36 |
+
def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
|
| 37 |
+
"""
|
| 38 |
+
The aligned flavor of arrive is used when all threads in the CTA will execute the
|
| 39 |
+
same instruction. See PTX documentation.
|
| 40 |
+
"""
|
| 41 |
+
cute.arch.barrier_arrive(
|
| 42 |
+
barrier_id=self.barrier_id + index,
|
| 43 |
+
number_of_threads=self.num_threads,
|
| 44 |
+
loc=loc,
|
| 45 |
+
ip=ip,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
@dsl_user_op
|
| 49 |
+
def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
|
| 50 |
+
cute.arch.barrier(
|
| 51 |
+
barrier_id=self.barrier_id + index,
|
| 52 |
+
number_of_threads=self.num_threads,
|
| 53 |
+
loc=loc,
|
| 54 |
+
ip=ip,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass(frozen=True)
|
| 59 |
+
class PipelineAsync(PipelineAsyncOg):
|
| 60 |
+
@staticmethod
|
| 61 |
+
def create(*args, **kwargs):
|
| 62 |
+
obj = PipelineAsyncOg.create(*args, **kwargs)
|
| 63 |
+
# Can't assign to __class__ directly since the dataclass is frozen
|
| 64 |
+
# obj.__class__ = PipelineAsync
|
| 65 |
+
object.__setattr__(obj, "__class__", PipelineAsync)
|
| 66 |
+
return obj
|
| 67 |
+
|
| 68 |
+
@dsl_user_op
|
| 69 |
+
def producer_acquire_w_index_phase(
|
| 70 |
+
self,
|
| 71 |
+
index: Int32,
|
| 72 |
+
phase: Int32,
|
| 73 |
+
try_acquire_token: Optional[Boolean] = None,
|
| 74 |
+
*,
|
| 75 |
+
loc=None,
|
| 76 |
+
ip=None,
|
| 77 |
+
):
|
| 78 |
+
if_generate(
|
| 79 |
+
try_acquire_token is None or try_acquire_token == 0,
|
| 80 |
+
lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
|
| 81 |
+
loc=loc,
|
| 82 |
+
ip=ip,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
@dsl_user_op
|
| 86 |
+
def producer_try_acquire_w_index_phase(
|
| 87 |
+
self,
|
| 88 |
+
index: Int32,
|
| 89 |
+
phase: Int32,
|
| 90 |
+
*,
|
| 91 |
+
loc=None,
|
| 92 |
+
ip=None,
|
| 93 |
+
):
|
| 94 |
+
return self.sync_object_empty.try_wait(index, phase, loc=loc, ip=ip)
|
| 95 |
+
|
| 96 |
+
@dsl_user_op
|
| 97 |
+
def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 98 |
+
self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip)
|
| 99 |
+
|
| 100 |
+
@dsl_user_op
|
| 101 |
+
def consumer_wait_w_index_phase(
|
| 102 |
+
self,
|
| 103 |
+
index: Int32,
|
| 104 |
+
phase: Int32,
|
| 105 |
+
try_wait_token: Optional[Boolean] = None,
|
| 106 |
+
*,
|
| 107 |
+
loc=None,
|
| 108 |
+
ip=None,
|
| 109 |
+
):
|
| 110 |
+
if_generate(
|
| 111 |
+
try_wait_token is None or try_wait_token == 0,
|
| 112 |
+
lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
|
| 113 |
+
loc=loc,
|
| 114 |
+
ip=ip,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
@dsl_user_op
|
| 118 |
+
def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 119 |
+
self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@dataclass(frozen=True)
|
| 123 |
+
class PipelineTmaAsync(PipelineTmaAsyncOg):
|
| 124 |
+
"""
|
| 125 |
+
Override producer_acquire to take in extra_tx_count parameter.
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
@staticmethod
|
| 129 |
+
def create(*args, **kwargs):
|
| 130 |
+
obj = PipelineTmaAsyncOg.create(*args, **kwargs)
|
| 131 |
+
# Can't assign to __class__ directly since the dataclass is frozen
|
| 132 |
+
object.__setattr__(obj, "__class__", PipelineTmaAsync)
|
| 133 |
+
return obj
|
| 134 |
+
|
| 135 |
+
@dsl_user_op
|
| 136 |
+
def producer_acquire(
|
| 137 |
+
self,
|
| 138 |
+
state: PipelineState,
|
| 139 |
+
try_acquire_token: Optional[Boolean] = None,
|
| 140 |
+
extra_tx_count: int = 0,
|
| 141 |
+
*,
|
| 142 |
+
loc=None,
|
| 143 |
+
ip=None,
|
| 144 |
+
):
|
| 145 |
+
"""
|
| 146 |
+
TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
|
| 147 |
+
"""
|
| 148 |
+
if_generate(
|
| 149 |
+
try_acquire_token is None or try_acquire_token == 0,
|
| 150 |
+
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
|
| 151 |
+
loc=loc,
|
| 152 |
+
ip=ip,
|
| 153 |
+
)
|
| 154 |
+
if const_expr(extra_tx_count == 0):
|
| 155 |
+
self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip)
|
| 156 |
+
else:
|
| 157 |
+
tx_count = self.sync_object_full.tx_count + extra_tx_count
|
| 158 |
+
self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@dataclass(frozen=True)
|
| 162 |
+
class PipelineTmaUmma(PipelineTmaUmmaOg):
|
| 163 |
+
"""
|
| 164 |
+
Override producer_acquire to take in extra_tx_count parameter.
|
| 165 |
+
"""
|
| 166 |
+
|
| 167 |
+
@staticmethod
|
| 168 |
+
def create(*args, **kwargs):
|
| 169 |
+
obj = PipelineTmaUmmaOg.create(*args, **kwargs)
|
| 170 |
+
# Can't assign to __class__ directly since the dataclass is frozen
|
| 171 |
+
# obj.__class__ = PipelineTmaUmma
|
| 172 |
+
object.__setattr__(obj, "__class__", PipelineTmaUmma)
|
| 173 |
+
return obj
|
| 174 |
+
|
| 175 |
+
@dsl_user_op
|
| 176 |
+
def producer_acquire(
|
| 177 |
+
self,
|
| 178 |
+
state: PipelineState,
|
| 179 |
+
try_acquire_token: Optional[Boolean] = None,
|
| 180 |
+
extra_tx_count: int = 0,
|
| 181 |
+
*,
|
| 182 |
+
loc=None,
|
| 183 |
+
ip=None,
|
| 184 |
+
):
|
| 185 |
+
"""
|
| 186 |
+
TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
|
| 187 |
+
"""
|
| 188 |
+
if_generate(
|
| 189 |
+
try_acquire_token is None or try_acquire_token == 0,
|
| 190 |
+
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
|
| 191 |
+
loc=loc,
|
| 192 |
+
ip=ip,
|
| 193 |
+
)
|
| 194 |
+
if const_expr(extra_tx_count == 0):
|
| 195 |
+
if_generate(
|
| 196 |
+
self.is_leader_cta,
|
| 197 |
+
lambda: self.sync_object_full.arrive(
|
| 198 |
+
state.index, self.producer_mask, loc=loc, ip=ip
|
| 199 |
+
),
|
| 200 |
+
loc=loc,
|
| 201 |
+
ip=ip,
|
| 202 |
+
)
|
| 203 |
+
else:
|
| 204 |
+
tx_count = self.sync_object_full.tx_count + extra_tx_count
|
| 205 |
+
if_generate(
|
| 206 |
+
self.is_leader_cta,
|
| 207 |
+
lambda: self.sync_object_full.arrive_and_expect_tx(
|
| 208 |
+
state.index, tx_count, loc=loc, ip=ip
|
| 209 |
+
),
|
| 210 |
+
loc=loc,
|
| 211 |
+
ip=ip,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
@dsl_user_op
|
| 215 |
+
def producer_acquire_w_index_phase(
|
| 216 |
+
self,
|
| 217 |
+
index: Int32,
|
| 218 |
+
phase: Int32,
|
| 219 |
+
try_acquire_token: Optional[Boolean] = None,
|
| 220 |
+
*,
|
| 221 |
+
loc=None,
|
| 222 |
+
ip=None,
|
| 223 |
+
):
|
| 224 |
+
"""
|
| 225 |
+
TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
|
| 226 |
+
"""
|
| 227 |
+
if_generate(
|
| 228 |
+
try_acquire_token is None or try_acquire_token == 0,
|
| 229 |
+
lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
|
| 230 |
+
loc=loc,
|
| 231 |
+
ip=ip,
|
| 232 |
+
)
|
| 233 |
+
if_generate(
|
| 234 |
+
self.is_leader_cta,
|
| 235 |
+
lambda: self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip),
|
| 236 |
+
loc=loc,
|
| 237 |
+
ip=ip,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
@dsl_user_op
|
| 241 |
+
def consumer_wait_w_index_phase(
|
| 242 |
+
self,
|
| 243 |
+
index: Int32,
|
| 244 |
+
phase: Int32,
|
| 245 |
+
try_wait_token: Optional[Boolean] = None,
|
| 246 |
+
*,
|
| 247 |
+
loc=None,
|
| 248 |
+
ip=None,
|
| 249 |
+
):
|
| 250 |
+
if_generate(
|
| 251 |
+
try_wait_token is None or try_wait_token == 0,
|
| 252 |
+
lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
|
| 253 |
+
loc=loc,
|
| 254 |
+
ip=ip,
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
@dsl_user_op
|
| 258 |
+
def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 259 |
+
"""
|
| 260 |
+
UMMA consumer release buffer empty, cta_group needs to be provided.
|
| 261 |
+
"""
|
| 262 |
+
self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
@dataclass(frozen=True)
|
| 266 |
+
class PipelineUmmaAsync(PipelineUmmaAsyncOg):
|
| 267 |
+
@staticmethod
|
| 268 |
+
def create(*args, **kwargs):
|
| 269 |
+
obj = PipelineUmmaAsyncOg.create(*args, **kwargs)
|
| 270 |
+
# Can't assign to __class__ directly since the dataclass is frozen
|
| 271 |
+
object.__setattr__(obj, "__class__", PipelineUmmaAsync)
|
| 272 |
+
return obj
|
| 273 |
+
|
| 274 |
+
@dsl_user_op
|
| 275 |
+
def producer_acquire_w_index_phase(
|
| 276 |
+
self,
|
| 277 |
+
index: Int32,
|
| 278 |
+
phase: Int32,
|
| 279 |
+
try_acquire_token: Optional[Boolean] = None,
|
| 280 |
+
*,
|
| 281 |
+
loc=None,
|
| 282 |
+
ip=None,
|
| 283 |
+
):
|
| 284 |
+
if_generate(
|
| 285 |
+
try_acquire_token is None or try_acquire_token == 0,
|
| 286 |
+
lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
|
| 287 |
+
loc=loc,
|
| 288 |
+
ip=ip,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
@dsl_user_op
|
| 292 |
+
def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 293 |
+
"""
|
| 294 |
+
UMMA producer commit buffer full, cta_group needs to be provided.
|
| 295 |
+
"""
|
| 296 |
+
self.sync_object_full.arrive(index, self.producer_mask, self.cta_group, loc=loc, ip=ip)
|
| 297 |
+
|
| 298 |
+
@dsl_user_op
|
| 299 |
+
def consumer_wait_w_index_phase(
|
| 300 |
+
self,
|
| 301 |
+
index: Int32,
|
| 302 |
+
phase: Int32,
|
| 303 |
+
try_wait_token: Optional[Boolean] = None,
|
| 304 |
+
*,
|
| 305 |
+
loc=None,
|
| 306 |
+
ip=None,
|
| 307 |
+
):
|
| 308 |
+
if_generate(
|
| 309 |
+
try_wait_token is None or try_wait_token == 0,
|
| 310 |
+
lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
|
| 311 |
+
loc=loc,
|
| 312 |
+
ip=ip,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
@dsl_user_op
|
| 316 |
+
def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 317 |
+
self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
@dataclass(frozen=True)
|
| 321 |
+
class PipelineAsyncUmma(PipelineAsyncUmmaOg):
|
| 322 |
+
@staticmethod
|
| 323 |
+
def create(*args, **kwargs):
|
| 324 |
+
obj = PipelineAsyncUmmaOg.create(*args, **kwargs)
|
| 325 |
+
# Can't assign to __class__ directly since the dataclass is frozen
|
| 326 |
+
object.__setattr__(obj, "__class__", PipelineAsyncUmma)
|
| 327 |
+
return obj
|
| 328 |
+
|
| 329 |
+
@dsl_user_op
|
| 330 |
+
def producer_acquire_w_index_phase(
|
| 331 |
+
self,
|
| 332 |
+
index: Int32,
|
| 333 |
+
phase: Int32,
|
| 334 |
+
try_acquire_token: Optional[Boolean] = None,
|
| 335 |
+
*,
|
| 336 |
+
loc=None,
|
| 337 |
+
ip=None,
|
| 338 |
+
):
|
| 339 |
+
if_generate(
|
| 340 |
+
try_acquire_token is None or try_acquire_token == 0,
|
| 341 |
+
lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
|
| 342 |
+
loc=loc,
|
| 343 |
+
ip=ip,
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
@dsl_user_op
|
| 347 |
+
def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 348 |
+
self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip)
|
| 349 |
+
|
| 350 |
+
@dsl_user_op
|
| 351 |
+
def consumer_wait_w_index_phase(
|
| 352 |
+
self,
|
| 353 |
+
index: Int32,
|
| 354 |
+
phase: Int32,
|
| 355 |
+
try_wait_token: Optional[Boolean] = None,
|
| 356 |
+
*,
|
| 357 |
+
loc=None,
|
| 358 |
+
ip=None,
|
| 359 |
+
):
|
| 360 |
+
if_generate(
|
| 361 |
+
try_wait_token is None or try_wait_token == 0,
|
| 362 |
+
lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
|
| 363 |
+
loc=loc,
|
| 364 |
+
ip=ip,
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
@dsl_user_op
|
| 368 |
+
def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
|
| 369 |
+
"""
|
| 370 |
+
UMMA consumer release buffer empty, cta_group needs to be provided.
|
| 371 |
+
"""
|
| 372 |
+
self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/seqlen_info.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
import cutlass
|
| 8 |
+
import cutlass.cute as cute
|
| 9 |
+
from cutlass import Int32, const_expr
|
| 10 |
+
|
| 11 |
+
from ...quack import copy_utils
|
| 12 |
+
|
| 13 |
+
"""
|
| 14 |
+
This consolidates all the info related to sequence length. This is so that we can do all
|
| 15 |
+
the gmem reads once at the beginning of each tile, rather than having to repeat these reads
|
| 16 |
+
to compute various things like n_block_min, n_block_max, etc.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass(frozen=True)
|
| 21 |
+
class SeqlenInfo:
|
| 22 |
+
offset: Int32
|
| 23 |
+
offset_padded: Int32
|
| 24 |
+
seqlen: Int32
|
| 25 |
+
has_cu_seqlens: cutlass.Constexpr[bool] = False
|
| 26 |
+
|
| 27 |
+
@staticmethod
|
| 28 |
+
def create(
|
| 29 |
+
batch_idx: Int32,
|
| 30 |
+
seqlen_static: Int32,
|
| 31 |
+
cu_seqlens: Optional[cute.Tensor] = None,
|
| 32 |
+
seqused: Optional[cute.Tensor] = None,
|
| 33 |
+
tile: cutlass.Constexpr[int] = 128,
|
| 34 |
+
):
|
| 35 |
+
offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx]
|
| 36 |
+
offset_padded = (
|
| 37 |
+
0
|
| 38 |
+
if const_expr(cu_seqlens is None)
|
| 39 |
+
# Add divby so that the compiler knows the alignment when moving by offset_padded
|
| 40 |
+
else cute.assume((offset + batch_idx * tile) // tile * tile, divby=tile)
|
| 41 |
+
)
|
| 42 |
+
if const_expr(seqused is not None):
|
| 43 |
+
seqlen = seqused[batch_idx]
|
| 44 |
+
elif const_expr(cu_seqlens is not None):
|
| 45 |
+
seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
|
| 46 |
+
else:
|
| 47 |
+
seqlen = seqlen_static
|
| 48 |
+
return SeqlenInfo(offset, offset_padded, seqlen, has_cu_seqlens=cu_seqlens is not None)
|
| 49 |
+
|
| 50 |
+
def offset_batch(
|
| 51 |
+
self,
|
| 52 |
+
mT: cute.Tensor,
|
| 53 |
+
batch_idx: Int32,
|
| 54 |
+
dim: int,
|
| 55 |
+
padded: cutlass.Constexpr[bool] = False,
|
| 56 |
+
multiple: int = 1,
|
| 57 |
+
) -> cute.Tensor:
|
| 58 |
+
"""Offset a tensor by batch index. batch dim is at position `dim`, seqlen is at dim=0."""
|
| 59 |
+
if const_expr(not self.has_cu_seqlens):
|
| 60 |
+
idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mT) - 1 - dim)
|
| 61 |
+
return mT[idx]
|
| 62 |
+
else:
|
| 63 |
+
off = multiple * (self.offset if const_expr(not padded) else self.offset_padded)
|
| 64 |
+
offset = off if const_expr(cute.rank(mT.shape[0]) == 1) else (0, off)
|
| 65 |
+
idx = (offset,) + (None,) * (cute.rank(mT) - 1)
|
| 66 |
+
return cute.domain_offset(idx, mT)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@dataclass(frozen=True)
|
| 70 |
+
class SeqlenInfoQK:
|
| 71 |
+
offset_q: Int32
|
| 72 |
+
offset_k: Int32
|
| 73 |
+
padded_offset_q: Int32
|
| 74 |
+
padded_offset_k: Int32
|
| 75 |
+
seqlen_q: Int32
|
| 76 |
+
seqlen_k: Int32
|
| 77 |
+
has_cu_seqlens_q: cutlass.Constexpr[bool]
|
| 78 |
+
has_cu_seqlens_k: cutlass.Constexpr[bool]
|
| 79 |
+
has_seqused_q: cutlass.Constexpr[bool]
|
| 80 |
+
has_seqused_k: cutlass.Constexpr[bool]
|
| 81 |
+
|
| 82 |
+
@staticmethod
|
| 83 |
+
def create(
|
| 84 |
+
batch_idx: Int32,
|
| 85 |
+
seqlen_q_static: Int32,
|
| 86 |
+
seqlen_k_static: Int32,
|
| 87 |
+
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
| 88 |
+
mCuSeqlensK: Optional[cute.Tensor] = None,
|
| 89 |
+
mSeqUsedQ: Optional[cute.Tensor] = None,
|
| 90 |
+
mSeqUsedK: Optional[cute.Tensor] = None,
|
| 91 |
+
tile_m: cutlass.Constexpr[Int32] = 128,
|
| 92 |
+
tile_n: cutlass.Constexpr[Int32] = 128,
|
| 93 |
+
):
|
| 94 |
+
offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]
|
| 95 |
+
offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx]
|
| 96 |
+
padded_offset_q = (
|
| 97 |
+
0
|
| 98 |
+
if const_expr(mCuSeqlensQ is None)
|
| 99 |
+
else cute.assume((offset_q + batch_idx * tile_m) // tile_m * tile_m, divby=tile_m)
|
| 100 |
+
)
|
| 101 |
+
padded_offset_k = (
|
| 102 |
+
0
|
| 103 |
+
if const_expr(mCuSeqlensK is None)
|
| 104 |
+
else cute.assume((offset_k + batch_idx * tile_n) // tile_n * tile_n, divby=tile_n)
|
| 105 |
+
)
|
| 106 |
+
if const_expr(mSeqUsedQ is not None):
|
| 107 |
+
seqlen_q = mSeqUsedQ[batch_idx]
|
| 108 |
+
else:
|
| 109 |
+
seqlen_q = (
|
| 110 |
+
seqlen_q_static
|
| 111 |
+
if const_expr(mCuSeqlensQ is None)
|
| 112 |
+
else mCuSeqlensQ[batch_idx + 1] - offset_q
|
| 113 |
+
)
|
| 114 |
+
if const_expr(mSeqUsedK is not None):
|
| 115 |
+
seqlen_k = mSeqUsedK[batch_idx]
|
| 116 |
+
else:
|
| 117 |
+
seqlen_k = (
|
| 118 |
+
seqlen_k_static
|
| 119 |
+
if const_expr(mCuSeqlensK is None)
|
| 120 |
+
else mCuSeqlensK[batch_idx + 1] - offset_k
|
| 121 |
+
)
|
| 122 |
+
return SeqlenInfoQK(
|
| 123 |
+
offset_q,
|
| 124 |
+
offset_k,
|
| 125 |
+
padded_offset_q,
|
| 126 |
+
padded_offset_k,
|
| 127 |
+
seqlen_q,
|
| 128 |
+
seqlen_k,
|
| 129 |
+
has_cu_seqlens_q=mCuSeqlensQ is not None,
|
| 130 |
+
has_cu_seqlens_k=mCuSeqlensK is not None,
|
| 131 |
+
has_seqused_q=mSeqUsedQ is not None,
|
| 132 |
+
has_seqused_k=mSeqUsedK is not None,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
def offset_batch_Q(
|
| 136 |
+
self,
|
| 137 |
+
mQ: cute.Tensor,
|
| 138 |
+
batch_idx: Int32,
|
| 139 |
+
dim: int,
|
| 140 |
+
padded: cutlass.Constexpr[bool] = False,
|
| 141 |
+
ragged: cutlass.Constexpr[bool] = False,
|
| 142 |
+
) -> cute.Tensor:
|
| 143 |
+
"""Seqlen must be the first dimension of mQ"""
|
| 144 |
+
if const_expr(not ragged):
|
| 145 |
+
if const_expr(not self.has_cu_seqlens_q):
|
| 146 |
+
idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)
|
| 147 |
+
return mQ[idx]
|
| 148 |
+
else:
|
| 149 |
+
offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q
|
| 150 |
+
offset_q = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (None, offset_q)
|
| 151 |
+
idx = (offset_q,) + (None,) * (cute.rank(mQ) - 1)
|
| 152 |
+
return cute.domain_offset(idx, mQ)
|
| 153 |
+
else:
|
| 154 |
+
if const_expr(not self.has_cu_seqlens_q):
|
| 155 |
+
offset_q = 0
|
| 156 |
+
idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)
|
| 157 |
+
mQ = mQ[idx]
|
| 158 |
+
else:
|
| 159 |
+
offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q
|
| 160 |
+
if const_expr(cute.rank(mQ.shape[0]) == 1):
|
| 161 |
+
return copy_utils.offset_ragged_tensor(
|
| 162 |
+
mQ, offset_q, self.seqlen_q, ragged_dim=0, ptr_shift=True
|
| 163 |
+
)
|
| 164 |
+
else: # PackGQA
|
| 165 |
+
assert cute.rank(mQ.shape[0]) == 2
|
| 166 |
+
# Unpack before calling offset_ragged_tensor, then pack
|
| 167 |
+
idx = ((None, None),) + (None,) * (cute.rank(mQ) - 1)
|
| 168 |
+
mQ = mQ[idx]
|
| 169 |
+
mQ = copy_utils.offset_ragged_tensor(
|
| 170 |
+
mQ, offset_q, self.seqlen_q, ragged_dim=1, ptr_shift=True
|
| 171 |
+
)
|
| 172 |
+
return cute.group_modes(mQ, 0, 2)
|
| 173 |
+
|
| 174 |
+
def offset_batch_K(
|
| 175 |
+
self,
|
| 176 |
+
mK: cute.Tensor,
|
| 177 |
+
batch_idx: Int32,
|
| 178 |
+
dim: int,
|
| 179 |
+
padded: cutlass.Constexpr[bool] = False,
|
| 180 |
+
ragged: cutlass.Constexpr[bool] = False,
|
| 181 |
+
multiple: int = 1,
|
| 182 |
+
) -> cute.Tensor:
|
| 183 |
+
"""Seqlen must be the first dimension of mK"""
|
| 184 |
+
if const_expr(not ragged):
|
| 185 |
+
if const_expr(not self.has_cu_seqlens_k):
|
| 186 |
+
idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)
|
| 187 |
+
return mK[idx]
|
| 188 |
+
else:
|
| 189 |
+
offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k
|
| 190 |
+
offset_k *= multiple
|
| 191 |
+
idx = (offset_k,) + (None,) * (cute.rank(mK) - 1)
|
| 192 |
+
return cute.domain_offset(idx, mK)
|
| 193 |
+
else:
|
| 194 |
+
if const_expr(not self.has_cu_seqlens_k):
|
| 195 |
+
offset_k = 0
|
| 196 |
+
idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)
|
| 197 |
+
mK = mK[idx]
|
| 198 |
+
else:
|
| 199 |
+
offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k
|
| 200 |
+
offset_k *= multiple
|
| 201 |
+
return copy_utils.offset_ragged_tensor(
|
| 202 |
+
mK, offset_k, self.seqlen_k, ragged_dim=0, ptr_shift=True
|
| 203 |
+
)
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/softmax.py
ADDED
|
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""Online softmax primitives.
|
| 5 |
+
|
| 6 |
+
Contains:
|
| 7 |
+
- ``Softmax``: SM80/90 base class with online softmax + finalize + rescale_O.
|
| 8 |
+
The ``rescale_O`` path branches on ``arch >= 100`` to emit SM100 packed
|
| 9 |
+
``fmul.f32x2`` (2× CUDA-core throughput) when available.
|
| 10 |
+
- ``SoftmaxSm100``: SM100-specific subclass exposing fused ``update_row_max``,
|
| 11 |
+
``scale_apply_exp2_convert`` etc. used by the UTCMMA warp-specialized kernel.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import math
|
| 15 |
+
import operator
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import Tuple
|
| 18 |
+
|
| 19 |
+
import cutlass
|
| 20 |
+
import cutlass.cute as cute
|
| 21 |
+
from cutlass import Float32
|
| 22 |
+
|
| 23 |
+
from ...quack import layout_utils
|
| 24 |
+
from ...quack.cute_dsl_utils import ParamsBase
|
| 25 |
+
|
| 26 |
+
from . import utils
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class Softmax(ParamsBase):
|
| 31 |
+
scale_log2: Float32
|
| 32 |
+
num_rows: cutlass.Constexpr[int]
|
| 33 |
+
row_max: cute.Tensor
|
| 34 |
+
row_sum: cute.Tensor
|
| 35 |
+
arch: cutlass.Constexpr[int] = 80
|
| 36 |
+
softmax_scale: Float32 | None = None
|
| 37 |
+
|
| 38 |
+
@staticmethod
|
| 39 |
+
def create(
|
| 40 |
+
scale_log2: Float32,
|
| 41 |
+
num_rows: cutlass.Constexpr[int],
|
| 42 |
+
arch: cutlass.Constexpr[int] = 80,
|
| 43 |
+
softmax_scale: Float32 | None = None,
|
| 44 |
+
):
|
| 45 |
+
row_max = cute.make_rmem_tensor(num_rows, Float32)
|
| 46 |
+
row_sum = cute.make_rmem_tensor(num_rows, Float32)
|
| 47 |
+
return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale)
|
| 48 |
+
|
| 49 |
+
def reset(self) -> None:
|
| 50 |
+
self.row_max.fill(-Float32.inf)
|
| 51 |
+
self.row_sum.fill(0.0)
|
| 52 |
+
|
| 53 |
+
def _compute_row_max(
|
| 54 |
+
self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None
|
| 55 |
+
) -> Float32:
|
| 56 |
+
return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch)
|
| 57 |
+
|
| 58 |
+
def _compute_row_sum(
|
| 59 |
+
self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None
|
| 60 |
+
) -> Float32:
|
| 61 |
+
return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch)
|
| 62 |
+
|
| 63 |
+
@cute.jit
|
| 64 |
+
def online_softmax(
|
| 65 |
+
self,
|
| 66 |
+
acc_S: cute.Tensor,
|
| 67 |
+
is_first: cutlass.Constexpr[bool] = False,
|
| 68 |
+
check_inf: cutlass.Constexpr[bool] = True,
|
| 69 |
+
) -> cute.Tensor:
|
| 70 |
+
"""Apply online softmax and return the row_scale to rescale O.
|
| 71 |
+
|
| 72 |
+
On SM100+ the inner ``acc_S_row * scale_log2 - row_max_scaled`` is
|
| 73 |
+
rewritten as explicit ``fma_packed_f32x2`` intrinsics — the DSL
|
| 74 |
+
compiler does not fuse TensorSSA ``mul + sub`` into FFMA2 (NCU
|
| 75 |
+
confirms: FFMA2 count is 0 for the TensorSSA path). The packed
|
| 76 |
+
rewrite issues one FFMA.F32X2 per pair, halving the scalar FFMA
|
| 77 |
+
instruction count for the softmax scale/subtract stage.
|
| 78 |
+
"""
|
| 79 |
+
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
|
| 80 |
+
row_scale = cute.make_rmem_tensor_like(self.row_max, Float32)
|
| 81 |
+
|
| 82 |
+
row_max = self.row_max
|
| 83 |
+
row_sum = self.row_sum
|
| 84 |
+
scale_log2 = self.scale_log2
|
| 85 |
+
arch = self.arch
|
| 86 |
+
|
| 87 |
+
for r in cutlass.range(cute.size(row_max), unroll_full=True):
|
| 88 |
+
acc_S_row_slice = acc_S_mn[r, None]
|
| 89 |
+
acc_S_row = acc_S_row_slice.load()
|
| 90 |
+
|
| 91 |
+
row_max_cur = utils.fmax_reduce(
|
| 92 |
+
acc_S_row,
|
| 93 |
+
init_val=row_max[r] if cutlass.const_expr(not is_first) else None,
|
| 94 |
+
arch=arch,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
row_max_cur = cute.arch.warp_reduction_max(row_max_cur, threads_in_group=4)
|
| 98 |
+
row_max_prev = row_max[r]
|
| 99 |
+
row_max[r] = row_max_cur
|
| 100 |
+
|
| 101 |
+
if cutlass.const_expr(check_inf):
|
| 102 |
+
row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur
|
| 103 |
+
|
| 104 |
+
row_max_cur_scaled = row_max_cur * scale_log2
|
| 105 |
+
minus_row_max_scaled = -row_max_cur_scaled
|
| 106 |
+
n = cute.size(acc_S_row_slice)
|
| 107 |
+
|
| 108 |
+
if cutlass.const_expr(arch >= 100 and n % 2 == 0):
|
| 109 |
+
# SM100 packed f32x2 FMA path: scale + subtract in one pass.
|
| 110 |
+
for i in cutlass.range(0, n, 2, unroll_full=True):
|
| 111 |
+
acc_S_row_slice[i], acc_S_row_slice[i + 1] = cute.arch.fma_packed_f32x2(
|
| 112 |
+
(acc_S_row_slice[i], acc_S_row_slice[i + 1]),
|
| 113 |
+
(scale_log2, scale_log2),
|
| 114 |
+
(minus_row_max_scaled, minus_row_max_scaled),
|
| 115 |
+
)
|
| 116 |
+
for i in cutlass.range(n, unroll_full=True):
|
| 117 |
+
acc_S_row_slice[i] = cute.math.exp2(acc_S_row_slice[i], fastmath=True)
|
| 118 |
+
acc_S_row_exp = acc_S_row_slice.load()
|
| 119 |
+
else:
|
| 120 |
+
acc_S_row_exp = cute.math.exp2(
|
| 121 |
+
acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True
|
| 122 |
+
)
|
| 123 |
+
acc_S_row_slice.store(acc_S_row_exp)
|
| 124 |
+
|
| 125 |
+
if cutlass.const_expr(is_first):
|
| 126 |
+
acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch)
|
| 127 |
+
row_scale[r] = 1.0
|
| 128 |
+
else:
|
| 129 |
+
row_scale[r] = cute.math.exp2(
|
| 130 |
+
(row_max_prev - row_max_cur) * scale_log2, fastmath=True
|
| 131 |
+
)
|
| 132 |
+
acc_S_row_sum = utils.fadd_reduce(
|
| 133 |
+
acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
row_sum[r] = acc_S_row_sum
|
| 137 |
+
|
| 138 |
+
return row_scale
|
| 139 |
+
|
| 140 |
+
@cute.jit
|
| 141 |
+
def finalize(
|
| 142 |
+
self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None
|
| 143 |
+
) -> cute.Tensor:
|
| 144 |
+
"""Finalize the online softmax by computing the scale and logsumexp.
|
| 145 |
+
|
| 146 |
+
On SM100+ with an even ``num_rows`` and no sink_val, the loop is
|
| 147 |
+
unrolled in pairs so the key per-row arithmetic ― rcp*final_scale,
|
| 148 |
+
max*scale_log2 + log2(sum), and the final *LN2 ― collapses into one
|
| 149 |
+
``mul_packed_f32x2`` + one ``fma_packed_f32x2`` + one more
|
| 150 |
+
``mul_packed_f32x2`` per row pair. Sink_val path stays scalar (rare).
|
| 151 |
+
"""
|
| 152 |
+
if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)):
|
| 153 |
+
assert cute.size(sink_val) == cute.size(self.row_sum)
|
| 154 |
+
row_sum = self.row_sum
|
| 155 |
+
row_max = self.row_max
|
| 156 |
+
scale_log2 = self.scale_log2
|
| 157 |
+
|
| 158 |
+
row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4))
|
| 159 |
+
row_scale = cute.make_rmem_tensor_like(row_max, Float32)
|
| 160 |
+
|
| 161 |
+
LN2 = math.log(2.0)
|
| 162 |
+
num_rows = cute.size(row_sum)
|
| 163 |
+
use_packed = cutlass.const_expr(
|
| 164 |
+
self.arch >= 100 and num_rows % 2 == 0 and sink_val is None
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
if use_packed:
|
| 168 |
+
for r in cutlass.range(0, num_rows, 2, unroll_full=True):
|
| 169 |
+
s0 = row_sum[r]
|
| 170 |
+
s1 = row_sum[r + 1]
|
| 171 |
+
m0 = row_max[r]
|
| 172 |
+
m1 = row_max[r + 1]
|
| 173 |
+
bad0 = s0 == 0.0 or s0 != s0
|
| 174 |
+
bad1 = s1 == 0.0 or s1 != s1
|
| 175 |
+
|
| 176 |
+
# row_scale = rcp_approx(safe_sum) * final_scale — rcp is scalar
|
| 177 |
+
# (no packed rcp intrinsic); the trailing multiply packs.
|
| 178 |
+
rcp0 = cute.arch.rcp_approx(1.0 if bad0 else s0)
|
| 179 |
+
rcp1 = cute.arch.rcp_approx(1.0 if bad1 else s1)
|
| 180 |
+
row_scale[r], row_scale[r + 1] = cute.arch.mul_packed_f32x2(
|
| 181 |
+
(rcp0, rcp1), (final_scale, final_scale)
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# LSE = (row_max * scale_log2 + log2(row_sum)) * LN2
|
| 185 |
+
# packed FMA for (max*scale_log2 + log2_sum), packed MUL for *LN2.
|
| 186 |
+
log0 = cute.math.log2(s0, fastmath=True)
|
| 187 |
+
log1 = cute.math.log2(s1, fastmath=True)
|
| 188 |
+
lse_pre_0, lse_pre_1 = cute.arch.fma_packed_f32x2(
|
| 189 |
+
(m0, m1), (scale_log2, scale_log2), (log0, log1)
|
| 190 |
+
)
|
| 191 |
+
lse_0, lse_1 = cute.arch.mul_packed_f32x2(
|
| 192 |
+
(lse_pre_0, lse_pre_1), (LN2, LN2)
|
| 193 |
+
)
|
| 194 |
+
row_sum[r] = -Float32.inf if bad0 else lse_0
|
| 195 |
+
row_sum[r + 1] = -Float32.inf if bad1 else lse_1
|
| 196 |
+
else:
|
| 197 |
+
for r in cutlass.range(num_rows, unroll_full=True):
|
| 198 |
+
if cutlass.const_expr(sink_val is not None):
|
| 199 |
+
sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r]
|
| 200 |
+
LOG2_E = math.log2(math.e)
|
| 201 |
+
row_sum[r] += cute.math.exp2(
|
| 202 |
+
sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]
|
| 206 |
+
row_scale[r] = (
|
| 207 |
+
cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0)
|
| 208 |
+
) * final_scale
|
| 209 |
+
row_sum_cur = row_sum[r]
|
| 210 |
+
row_sum[r] = (
|
| 211 |
+
(row_max[r] * scale_log2 + cute.math.log2(row_sum_cur, fastmath=True)) * LN2
|
| 212 |
+
if not acc_O_mn_row_is_zero_or_nan
|
| 213 |
+
else -Float32.inf
|
| 214 |
+
)
|
| 215 |
+
return row_scale
|
| 216 |
+
|
| 217 |
+
@cute.jit
|
| 218 |
+
def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None:
|
| 219 |
+
"""Scale each row of acc_O by the given scale tensor."""
|
| 220 |
+
acc_O_mn = layout_utils.reshape_acc_to_mn(acc_O)
|
| 221 |
+
assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0])
|
| 222 |
+
n = cute.size(acc_O_mn, mode=[1])
|
| 223 |
+
if cutlass.const_expr(self.arch >= 100 and n % 2 == 0):
|
| 224 |
+
# SM100: pack adjacent pairs into fmul.f32x2 (2× CUDA-core throughput).
|
| 225 |
+
for r in cutlass.range(cute.size(row_scale), unroll_full=True):
|
| 226 |
+
scale = row_scale[r]
|
| 227 |
+
for j in cutlass.range(0, n, 2, unroll_full=True):
|
| 228 |
+
acc_O_mn[r, j], acc_O_mn[r, j + 1] = cute.arch.mul_packed_f32x2(
|
| 229 |
+
(acc_O_mn[r, j], acc_O_mn[r, j + 1]), (scale, scale)
|
| 230 |
+
)
|
| 231 |
+
else:
|
| 232 |
+
for r in cutlass.range(cute.size(row_scale), unroll_full=True):
|
| 233 |
+
acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r])
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@dataclass
|
| 237 |
+
class SoftmaxSm100(Softmax):
|
| 238 |
+
"""SM100-specific softmax: single-row, explicit f32x2 pack for FMA/exp2 paths."""
|
| 239 |
+
|
| 240 |
+
rescale_threshold: cutlass.Constexpr[float] = 0.0
|
| 241 |
+
|
| 242 |
+
@staticmethod
|
| 243 |
+
def create(
|
| 244 |
+
scale_log2: Float32,
|
| 245 |
+
rescale_threshold: cutlass.Constexpr[float] = 0.0,
|
| 246 |
+
softmax_scale: Float32 | None = None,
|
| 247 |
+
):
|
| 248 |
+
num_rows = 1
|
| 249 |
+
arch = 100
|
| 250 |
+
row_max = cute.make_rmem_tensor(num_rows, Float32)
|
| 251 |
+
row_sum = cute.make_rmem_tensor(num_rows, Float32)
|
| 252 |
+
return SoftmaxSm100(
|
| 253 |
+
scale_log2,
|
| 254 |
+
num_rows,
|
| 255 |
+
row_max,
|
| 256 |
+
row_sum,
|
| 257 |
+
arch,
|
| 258 |
+
softmax_scale,
|
| 259 |
+
rescale_threshold=rescale_threshold,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
@cute.jit
|
| 263 |
+
def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]:
|
| 264 |
+
if cutlass.const_expr(is_first):
|
| 265 |
+
row_max_new = self._compute_row_max(acc_S_row)
|
| 266 |
+
row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
|
| 267 |
+
acc_scale = 0.0
|
| 268 |
+
else:
|
| 269 |
+
row_max_old = self.row_max[0]
|
| 270 |
+
row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old)
|
| 271 |
+
row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
|
| 272 |
+
acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2
|
| 273 |
+
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
|
| 274 |
+
if cutlass.const_expr(self.rescale_threshold > 0.0):
|
| 275 |
+
if acc_scale_ >= -self.rescale_threshold:
|
| 276 |
+
row_max_new = row_max_old
|
| 277 |
+
row_max_safe = row_max_old
|
| 278 |
+
acc_scale = 1.0
|
| 279 |
+
self.row_max[0] = row_max_new
|
| 280 |
+
return row_max_safe, acc_scale
|
| 281 |
+
|
| 282 |
+
@cute.jit
|
| 283 |
+
def update_row_max_deferred_exp2(
|
| 284 |
+
self,
|
| 285 |
+
acc_S_row: cute.TensorSSA,
|
| 286 |
+
is_first: int,
|
| 287 |
+
) -> Tuple[Float32, Float32]:
|
| 288 |
+
"""update_row_max variant that publishes the log2-delta (un-exp2'd) so
|
| 289 |
+
the consumer can do the exp2 only when an actual rescale fires.
|
| 290 |
+
|
| 291 |
+
Returns ``(row_max_safe, acc_scale_log2_or_zero)`` where:
|
| 292 |
+
- ``row_max_safe`` is the same row-max as ``update_row_max`` (with
|
| 293 |
+
``rescale_threshold`` rollback applied).
|
| 294 |
+
- ``acc_scale_log2_or_zero`` is ``0.0`` for the first iteration or when
|
| 295 |
+
the threshold rollback fired (consumer treats as no rescale), else
|
| 296 |
+
the raw log2-domain value ``(row_max_old - row_max_safe)*scale_log2``
|
| 297 |
+
(consumer computes ``cute.math.exp2`` and rescales).
|
| 298 |
+
|
| 299 |
+
This keeps MUFU.EX2 off the sm_stats publication critical path that
|
| 300 |
+
gates the correction WG's consumer wait.
|
| 301 |
+
"""
|
| 302 |
+
publish = Float32(0.0)
|
| 303 |
+
if cutlass.const_expr(is_first):
|
| 304 |
+
row_max_new = self._compute_row_max(acc_S_row)
|
| 305 |
+
row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
|
| 306 |
+
else:
|
| 307 |
+
row_max_old = self.row_max[0]
|
| 308 |
+
row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old)
|
| 309 |
+
row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
|
| 310 |
+
acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2
|
| 311 |
+
if cutlass.const_expr(self.rescale_threshold > 0.0):
|
| 312 |
+
if acc_scale_ >= -self.rescale_threshold:
|
| 313 |
+
row_max_new = row_max_old
|
| 314 |
+
row_max_safe = row_max_old
|
| 315 |
+
# publish stays 0.0 (signal: no rescale needed)
|
| 316 |
+
else:
|
| 317 |
+
publish = acc_scale_
|
| 318 |
+
else:
|
| 319 |
+
publish = acc_scale_
|
| 320 |
+
self.row_max[0] = row_max_new
|
| 321 |
+
return row_max_safe, publish
|
| 322 |
+
|
| 323 |
+
@cute.jit
|
| 324 |
+
def update_row_max_only(self, acc_S_row: cute.TensorSSA, is_first: int) -> None:
|
| 325 |
+
if cutlass.const_expr(is_first):
|
| 326 |
+
row_max_new = self._compute_row_max(acc_S_row)
|
| 327 |
+
else:
|
| 328 |
+
row_max_new = self._compute_row_max(acc_S_row, init_val=self.row_max[0])
|
| 329 |
+
self.row_max[0] = row_max_new
|
| 330 |
+
|
| 331 |
+
def update_row_sum(
|
| 332 |
+
self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False
|
| 333 |
+
) -> None:
|
| 334 |
+
init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None
|
| 335 |
+
self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val)
|
| 336 |
+
|
| 337 |
+
@cute.jit
|
| 338 |
+
def compute_scaled_exp2_row_sum(
|
| 339 |
+
self,
|
| 340 |
+
acc_S_row: cute.Tensor,
|
| 341 |
+
scale: Float32,
|
| 342 |
+
) -> Float32:
|
| 343 |
+
return utils.fadd_exp2_scaled_reduce(acc_S_row, scale, arch=self.arch)
|
| 344 |
+
|
| 345 |
+
@cute.jit
|
| 346 |
+
def scale_subtract_rowmax(
|
| 347 |
+
self,
|
| 348 |
+
acc_S_row: cute.Tensor,
|
| 349 |
+
row_max: Float32,
|
| 350 |
+
):
|
| 351 |
+
assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
|
| 352 |
+
row_max_scaled = row_max * self.scale_log2
|
| 353 |
+
for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True):
|
| 354 |
+
acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2(
|
| 355 |
+
(acc_S_row[i], acc_S_row[i + 1]),
|
| 356 |
+
(self.scale_log2, self.scale_log2),
|
| 357 |
+
(-row_max_scaled, -row_max_scaled),
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
@cute.jit
|
| 361 |
+
def apply_exp2_convert(
|
| 362 |
+
self,
|
| 363 |
+
acc_S_row: cute.Tensor,
|
| 364 |
+
acc_S_row_converted: cute.Tensor,
|
| 365 |
+
ex2_emu_freq: cutlass.Constexpr[int] = 0,
|
| 366 |
+
ex2_emu_res: cutlass.Constexpr[int] = 4,
|
| 367 |
+
ex2_emu_start_frg: cutlass.Constexpr[int] = 0,
|
| 368 |
+
):
|
| 369 |
+
assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
|
| 370 |
+
frg_tile = 32
|
| 371 |
+
assert frg_tile % 2 == 0
|
| 372 |
+
frg_cnt = cute.size(acc_S_row) // frg_tile
|
| 373 |
+
assert cute.size(acc_S_row) % frg_tile == 0
|
| 374 |
+
acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile))
|
| 375 |
+
acc_S_row_converted_frg = cute.logical_divide(
|
| 376 |
+
acc_S_row_converted, cute.make_layout(frg_tile)
|
| 377 |
+
)
|
| 378 |
+
for j in cutlass.range_constexpr(frg_cnt):
|
| 379 |
+
for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):
|
| 380 |
+
if cutlass.const_expr(ex2_emu_freq == 0):
|
| 381 |
+
acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
|
| 382 |
+
acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
|
| 383 |
+
else:
|
| 384 |
+
if cutlass.const_expr(
|
| 385 |
+
k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res
|
| 386 |
+
or j >= frg_cnt - 1
|
| 387 |
+
or j < ex2_emu_start_frg
|
| 388 |
+
):
|
| 389 |
+
acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
|
| 390 |
+
acc_S_row_frg[k + 1, j] = cute.math.exp2(
|
| 391 |
+
acc_S_row_frg[k + 1, j], fastmath=True
|
| 392 |
+
)
|
| 393 |
+
else:
|
| 394 |
+
acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2(
|
| 395 |
+
acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]
|
| 396 |
+
)
|
| 397 |
+
acc_S_row_converted_frg[None, j].store(
|
| 398 |
+
acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
@cute.jit
|
| 402 |
+
def scale_apply_exp2_convert(
|
| 403 |
+
self,
|
| 404 |
+
acc_S_row: cute.Tensor,
|
| 405 |
+
row_max: Float32,
|
| 406 |
+
acc_S_row_converted: cute.Tensor,
|
| 407 |
+
):
|
| 408 |
+
assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
|
| 409 |
+
minus_row_max_scaled = -row_max * self.scale_log2
|
| 410 |
+
for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2):
|
| 411 |
+
acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2(
|
| 412 |
+
(acc_S_row[i], acc_S_row[i + 1]),
|
| 413 |
+
(self.scale_log2, self.scale_log2),
|
| 414 |
+
(minus_row_max_scaled, minus_row_max_scaled),
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
frg_tile = 32
|
| 418 |
+
assert frg_tile % 2 == 0
|
| 419 |
+
frg_cnt = cute.size(acc_S_row) // frg_tile
|
| 420 |
+
assert cute.size(acc_S_row) % frg_tile == 0
|
| 421 |
+
acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile))
|
| 422 |
+
acc_S_row_converted_frg = cute.logical_divide(
|
| 423 |
+
acc_S_row_converted, cute.make_layout(frg_tile)
|
| 424 |
+
)
|
| 425 |
+
for j in cutlass.range_constexpr(frg_cnt):
|
| 426 |
+
for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):
|
| 427 |
+
acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
|
| 428 |
+
acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
|
| 429 |
+
acc_S_row_converted_frg[None, j].store(
|
| 430 |
+
acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
@cute.jit
|
| 434 |
+
def scale_apply_exp2_convert_sum(
|
| 435 |
+
self,
|
| 436 |
+
acc_S_row: cute.Tensor,
|
| 437 |
+
row_max: Float32,
|
| 438 |
+
acc_S_row_converted: cute.Tensor,
|
| 439 |
+
init_sum: Float32,
|
| 440 |
+
ex2_emu_freq: cutlass.Constexpr[int] = 0,
|
| 441 |
+
ex2_emu_res: cutlass.Constexpr[int] = 4,
|
| 442 |
+
ex2_emu_start_frg: cutlass.Constexpr[int] = 0,
|
| 443 |
+
) -> Float32:
|
| 444 |
+
# When ex2_emu_freq > 0, the (k % ex2_emu_freq) >= ex2_emu_freq - ex2_emu_res
|
| 445 |
+
# pairs in the inner loop use the FFMA2-based polynomial ex2 emulation
|
| 446 |
+
# (ex2_emulation_2) instead of MUFU exp2 — mirrors prefill's
|
| 447 |
+
# apply_exp2_convert. This removes the MUFU "wait" stall that dominates
|
| 448 |
+
# the second-largest stall bucket in decode (~22% of total).
|
| 449 |
+
assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
|
| 450 |
+
minus_row_max_scaled = -row_max * self.scale_log2
|
| 451 |
+
acc_sum = (init_sum, Float32(0.0))
|
| 452 |
+
|
| 453 |
+
frg_tile = 32
|
| 454 |
+
assert frg_tile % 2 == 0
|
| 455 |
+
frg_cnt = cute.size(acc_S_row) // frg_tile
|
| 456 |
+
assert cute.size(acc_S_row) % frg_tile == 0
|
| 457 |
+
acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile))
|
| 458 |
+
acc_S_row_converted_frg = cute.logical_divide(
|
| 459 |
+
acc_S_row_converted, cute.make_layout(frg_tile)
|
| 460 |
+
)
|
| 461 |
+
for j in cutlass.range_constexpr(frg_cnt):
|
| 462 |
+
for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):
|
| 463 |
+
acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = cute.arch.fma_packed_f32x2(
|
| 464 |
+
(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]),
|
| 465 |
+
(self.scale_log2, self.scale_log2),
|
| 466 |
+
(minus_row_max_scaled, minus_row_max_scaled),
|
| 467 |
+
)
|
| 468 |
+
if cutlass.const_expr(ex2_emu_freq == 0):
|
| 469 |
+
acc_S_row_frg[k, j] = cute.math.exp2(
|
| 470 |
+
acc_S_row_frg[k, j], fastmath=True)
|
| 471 |
+
acc_S_row_frg[k + 1, j] = cute.math.exp2(
|
| 472 |
+
acc_S_row_frg[k + 1, j], fastmath=True)
|
| 473 |
+
else:
|
| 474 |
+
use_real = cutlass.const_expr(
|
| 475 |
+
k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res
|
| 476 |
+
or j >= frg_cnt - 1
|
| 477 |
+
or j < ex2_emu_start_frg
|
| 478 |
+
)
|
| 479 |
+
if cutlass.const_expr(use_real):
|
| 480 |
+
acc_S_row_frg[k, j] = cute.math.exp2(
|
| 481 |
+
acc_S_row_frg[k, j], fastmath=True)
|
| 482 |
+
acc_S_row_frg[k + 1, j] = cute.math.exp2(
|
| 483 |
+
acc_S_row_frg[k + 1, j], fastmath=True)
|
| 484 |
+
else:
|
| 485 |
+
acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = (
|
| 486 |
+
utils.ex2_emulation_2(
|
| 487 |
+
acc_S_row_frg[k, j],
|
| 488 |
+
acc_S_row_frg[k + 1, j],
|
| 489 |
+
)
|
| 490 |
+
)
|
| 491 |
+
acc_sum = cute.arch.add_packed_f32x2(
|
| 492 |
+
acc_sum,
|
| 493 |
+
(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]),
|
| 494 |
+
)
|
| 495 |
+
acc_S_row_converted_frg[None, j].store(
|
| 496 |
+
acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)
|
| 497 |
+
)
|
| 498 |
+
return acc_sum[0] + acc_sum[1]
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/tile_scheduler.py
ADDED
|
@@ -0,0 +1,967 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
from enum import IntEnum, auto
|
| 5 |
+
from typing import Optional, Tuple, Protocol, runtime_checkable
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from typing import override
|
| 10 |
+
except ImportError: # Python < 3.12
|
| 11 |
+
from typing_extensions import override
|
| 12 |
+
|
| 13 |
+
import cutlass
|
| 14 |
+
from cutlass.pipeline import PipelineClcFetchAsync, PipelineState
|
| 15 |
+
from cutlass._mlir import ir
|
| 16 |
+
import cutlass.cute as cute
|
| 17 |
+
from cutlass import Int32, const_expr
|
| 18 |
+
from cutlass.cute import FastDivmodDivisor
|
| 19 |
+
from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams
|
| 20 |
+
|
| 21 |
+
from ...quack.cute_dsl_utils import ParamsBase
|
| 22 |
+
|
| 23 |
+
from ...src.common import utils as utils
|
| 24 |
+
from ...src.common.fast_math import clz
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SchedulingMode(IntEnum):
|
| 28 |
+
NONE = auto()
|
| 29 |
+
STATIC = auto()
|
| 30 |
+
DYNAMIC = auto()
|
| 31 |
+
CLC = auto()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class ClcState(ParamsBase):
|
| 36 |
+
"""Owns the runtime state shared by CLC-capable tile schedulers.
|
| 37 |
+
|
| 38 |
+
`SparseAttentionForwardSm100` constructs this state because it owns the CLC
|
| 39 |
+
response buffer, mbarrier storage, and launch geometry needed to initialize
|
| 40 |
+
the hardware scheduler and async pipeline. Individual tile schedulers then
|
| 41 |
+
consume this state and map the returned hardware work tiles into their own
|
| 42 |
+
logical `WorkTileInfo` coordinates.
|
| 43 |
+
|
| 44 |
+
To add CLC support to a scheduler:
|
| 45 |
+
- implement `clc_problem_shape(params)` so the kernel can create the hardware scheduler
|
| 46 |
+
- accept `clc: ClcState | None` in `create(...)` / `__init__`
|
| 47 |
+
- map `clc.initial_work_tile_info()` and `clc.get_current_work()` into scheduler coordinates
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
_hw_scheduler: ClcDynamicPersistentTileScheduler
|
| 51 |
+
_pipeline: PipelineClcFetchAsync
|
| 52 |
+
_consumer_state: PipelineState
|
| 53 |
+
_producer_state: PipelineState
|
| 54 |
+
|
| 55 |
+
@staticmethod
|
| 56 |
+
def create(
|
| 57 |
+
*,
|
| 58 |
+
hw_scheduler: ClcDynamicPersistentTileScheduler,
|
| 59 |
+
pipeline: PipelineClcFetchAsync,
|
| 60 |
+
consumer_state: PipelineState,
|
| 61 |
+
producer_state: PipelineState,
|
| 62 |
+
) -> "ClcState":
|
| 63 |
+
return ClcState(hw_scheduler, pipeline, consumer_state, producer_state)
|
| 64 |
+
|
| 65 |
+
def initial_work_tile_info(self):
|
| 66 |
+
return self._hw_scheduler.initial_work_tile_info()
|
| 67 |
+
|
| 68 |
+
def get_current_work(self):
|
| 69 |
+
return self._hw_scheduler.get_current_work()
|
| 70 |
+
|
| 71 |
+
def prefetch_next_work(self, *, loc=None, ip=None):
|
| 72 |
+
self._pipeline.producer_acquire(self._producer_state, loc=loc, ip=ip)
|
| 73 |
+
mbarrier_addr = self._pipeline.producer_get_barrier(self._producer_state, loc=loc, ip=ip)
|
| 74 |
+
self._hw_scheduler.advance_to_next_work(mbarrier_addr, loc=loc, ip=ip)
|
| 75 |
+
self._producer_state.advance(loc=loc, ip=ip)
|
| 76 |
+
|
| 77 |
+
def consumer_wait(self, *, loc=None, ip=None):
|
| 78 |
+
self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip)
|
| 79 |
+
|
| 80 |
+
def consumer_release(self, *, loc=None, ip=None):
|
| 81 |
+
self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip)
|
| 82 |
+
self._consumer_state.advance(loc=loc, ip=ip)
|
| 83 |
+
|
| 84 |
+
def producer_tail(self, *, loc=None, ip=None):
|
| 85 |
+
self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class WorkTileInfo(cutlass.utils.WorkTileInfo):
|
| 89 |
+
"""Altered WorkTileInfo which includes four axes: (block, head, batch, split)"""
|
| 90 |
+
|
| 91 |
+
@override
|
| 92 |
+
def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo":
|
| 93 |
+
assert len(values) == 5
|
| 94 |
+
new_tile_idx = cutlass.new_from_mlir_values(self._tile_idx, values[:-1])
|
| 95 |
+
new_is_valid_tile = cutlass.new_from_mlir_values(self._is_valid_tile, [values[-1]])
|
| 96 |
+
return WorkTileInfo(new_tile_idx, new_is_valid_tile)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@runtime_checkable
|
| 100 |
+
class TileSchedulerProtocol(Protocol):
|
| 101 |
+
"""Protocol defining the interface all tile schedulers must implement.
|
| 102 |
+
|
| 103 |
+
Schedulers are responsible for:
|
| 104 |
+
1. Coordinate mapping: linear tile index -> (m_block, head, batch, split)
|
| 105 |
+
2. Work distribution: how to get the next tile (static grid-stride vs CLC dynamic)
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def get_current_work(self) -> WorkTileInfo:
|
| 109 |
+
"""Get the current work tile coordinates."""
|
| 110 |
+
...
|
| 111 |
+
|
| 112 |
+
def initial_work_tile_info(self) -> WorkTileInfo:
|
| 113 |
+
"""Get the initial work tile for this CTA."""
|
| 114 |
+
...
|
| 115 |
+
|
| 116 |
+
def advance_to_next_work(self, *, loc=None, ip=None):
|
| 117 |
+
"""Consumer-side advance: move to next tile and return it.
|
| 118 |
+
|
| 119 |
+
For static schedulers: grid-stride increment + get_current_work.
|
| 120 |
+
For CLC schedulers: consumer wait + get_current_work + consumer release + state advance.
|
| 121 |
+
"""
|
| 122 |
+
...
|
| 123 |
+
|
| 124 |
+
def prefetch_next_work(self, *, loc=None, ip=None) -> None:
|
| 125 |
+
"""Producer-side prefetch of next work tile (no-op for static schedulers).
|
| 126 |
+
|
| 127 |
+
For CLC schedulers: producer acquire + issue CLC query + producer state advance.
|
| 128 |
+
Only called by the scheduler warp.
|
| 129 |
+
"""
|
| 130 |
+
...
|
| 131 |
+
|
| 132 |
+
def producer_tail(self, *, loc=None, ip=None) -> None:
|
| 133 |
+
"""Producer-side cleanup after the last tile.
|
| 134 |
+
|
| 135 |
+
No-op for static schedulers. For CLC schedulers: pipeline producer_tail.
|
| 136 |
+
"""
|
| 137 |
+
...
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@dataclass
|
| 141 |
+
class TileSchedulerArguments(ParamsBase):
|
| 142 |
+
num_block: Int32
|
| 143 |
+
num_head: Int32
|
| 144 |
+
num_batch: Int32
|
| 145 |
+
num_splits: Int32
|
| 146 |
+
seqlen_k: Int32
|
| 147 |
+
headdim: Int32
|
| 148 |
+
headdim_v: Int32
|
| 149 |
+
total_q: Int32
|
| 150 |
+
tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]
|
| 151 |
+
cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
|
| 152 |
+
mCuSeqlensQ: Optional[cute.Tensor] = None
|
| 153 |
+
mSeqUsedQ: Optional[cute.Tensor] = None
|
| 154 |
+
qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
|
| 155 |
+
element_size: cutlass.Constexpr[int] = 2
|
| 156 |
+
is_persistent: cutlass.Constexpr[bool] = False
|
| 157 |
+
lpt: cutlass.Constexpr[bool] = False
|
| 158 |
+
is_split_kv: cutlass.Constexpr[bool] = False
|
| 159 |
+
head_swizzle: cutlass.Constexpr[bool] = False
|
| 160 |
+
use_cluster_idx: cutlass.Constexpr[bool] = False
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class SingleTileScheduler:
|
| 164 |
+
@dataclass
|
| 165 |
+
class Params(ParamsBase):
|
| 166 |
+
num_block: Int32
|
| 167 |
+
num_head: Int32
|
| 168 |
+
num_batch: Int32
|
| 169 |
+
num_splits: Int32
|
| 170 |
+
num_splits_divmod: FastDivmodDivisor
|
| 171 |
+
is_split_kv: cutlass.Constexpr[bool] = False
|
| 172 |
+
cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
|
| 173 |
+
use_cluster_idx: cutlass.Constexpr[bool] = False
|
| 174 |
+
|
| 175 |
+
@staticmethod
|
| 176 |
+
def create(
|
| 177 |
+
args: TileSchedulerArguments, *, loc=None, ip=None
|
| 178 |
+
) -> "SingleTileScheduler.Params":
|
| 179 |
+
return SingleTileScheduler.Params(
|
| 180 |
+
args.num_block,
|
| 181 |
+
args.num_head,
|
| 182 |
+
args.num_batch,
|
| 183 |
+
args.num_splits,
|
| 184 |
+
FastDivmodDivisor(args.num_splits),
|
| 185 |
+
args.is_split_kv,
|
| 186 |
+
args.cluster_shape_mn,
|
| 187 |
+
args.use_cluster_idx,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None):
|
| 191 |
+
self.params = params
|
| 192 |
+
self._blk_coord = blk_coord
|
| 193 |
+
self._is_first_block = True
|
| 194 |
+
self._loc = loc
|
| 195 |
+
self._ip = ip
|
| 196 |
+
|
| 197 |
+
@staticmethod
|
| 198 |
+
def to_underlying_arguments(
|
| 199 |
+
args: TileSchedulerArguments,
|
| 200 |
+
*,
|
| 201 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 202 |
+
loc=None,
|
| 203 |
+
ip=None,
|
| 204 |
+
) -> Params:
|
| 205 |
+
assert scheduling_mode == SchedulingMode.STATIC, (
|
| 206 |
+
f"SingleTileScheduler only supports STATIC, got {scheduling_mode!r}"
|
| 207 |
+
)
|
| 208 |
+
return SingleTileScheduler.Params.create(args, loc=loc, ip=ip)
|
| 209 |
+
|
| 210 |
+
@staticmethod
|
| 211 |
+
def create(
|
| 212 |
+
params: Params, clc: ClcState | None = None, *, loc=None, ip=None
|
| 213 |
+
) -> "SingleTileScheduler":
|
| 214 |
+
if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx):
|
| 215 |
+
blk_coord = cute.arch.block_idx()
|
| 216 |
+
else:
|
| 217 |
+
blk_coord = cute.arch.cluster_idx()
|
| 218 |
+
return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip)
|
| 219 |
+
|
| 220 |
+
# called by host
|
| 221 |
+
@staticmethod
|
| 222 |
+
def get_grid_shape(
|
| 223 |
+
params: Params,
|
| 224 |
+
*,
|
| 225 |
+
loc=None,
|
| 226 |
+
ip=None,
|
| 227 |
+
) -> Tuple[Int32, Int32, Int32]:
|
| 228 |
+
# TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1)
|
| 229 |
+
assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
|
| 230 |
+
if const_expr(params.use_cluster_idx):
|
| 231 |
+
# Grid must have num_block * cluster_m physical blocks so that there are num_block clusters
|
| 232 |
+
grid_x = params.num_block * params.cluster_shape_mn[0]
|
| 233 |
+
else:
|
| 234 |
+
grid_x = cute.round_up(params.num_block, params.cluster_shape_mn[0])
|
| 235 |
+
return (
|
| 236 |
+
grid_x,
|
| 237 |
+
params.num_head * params.num_splits,
|
| 238 |
+
params.num_batch,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
|
| 242 |
+
block_idx, head_idx, batch_idx = self._blk_coord
|
| 243 |
+
if const_expr(self.params.is_split_kv):
|
| 244 |
+
head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod)
|
| 245 |
+
else:
|
| 246 |
+
split_idx = Int32(0)
|
| 247 |
+
return WorkTileInfo(
|
| 248 |
+
(block_idx, head_idx, batch_idx, split_idx),
|
| 249 |
+
self._is_first_block,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
def initial_work_tile_info(self, *, loc=None, ip=None):
|
| 253 |
+
return self.get_current_work(loc=loc, ip=ip)
|
| 254 |
+
|
| 255 |
+
def prefetch_next_work(self, *, loc=None, ip=None):
|
| 256 |
+
pass
|
| 257 |
+
|
| 258 |
+
def advance_to_next_work(self, *, loc=None, ip=None):
|
| 259 |
+
self._is_first_block = False
|
| 260 |
+
return self.get_current_work()
|
| 261 |
+
|
| 262 |
+
def producer_tail(self, *, loc=None, ip=None):
|
| 263 |
+
pass
|
| 264 |
+
|
| 265 |
+
def __extract_mlir_values__(self):
|
| 266 |
+
values, self._values_pos = [], []
|
| 267 |
+
for obj in [self.params, self._blk_coord]:
|
| 268 |
+
obj_values = cutlass.extract_mlir_values(obj)
|
| 269 |
+
values += obj_values
|
| 270 |
+
self._values_pos.append(len(obj_values))
|
| 271 |
+
return values
|
| 272 |
+
|
| 273 |
+
def __new_from_mlir_values__(self, values):
|
| 274 |
+
obj_list = []
|
| 275 |
+
for obj, n_items in zip([self.params, self._blk_coord], self._values_pos):
|
| 276 |
+
obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
|
| 277 |
+
values = values[n_items:]
|
| 278 |
+
return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class StaticPersistentTileScheduler:
|
| 282 |
+
@dataclass
|
| 283 |
+
class Params(ParamsBase):
|
| 284 |
+
num_block_cluster_divmod: FastDivmodDivisor
|
| 285 |
+
num_head_divmod: FastDivmodDivisor
|
| 286 |
+
total_blocks_cluster: Int32
|
| 287 |
+
cluster_shape_m: cutlass.Constexpr[int] = 1
|
| 288 |
+
|
| 289 |
+
@staticmethod
|
| 290 |
+
def create(
|
| 291 |
+
args: TileSchedulerArguments, *, loc=None, ip=None
|
| 292 |
+
) -> "StaticPersistentTileScheduler.Params":
|
| 293 |
+
num_block_cluster = cute.ceil_div(args.num_block, cute.size(args.cluster_shape_mn))
|
| 294 |
+
total_blocks_cluster = num_block_cluster * args.num_head * args.num_batch
|
| 295 |
+
return StaticPersistentTileScheduler.Params(
|
| 296 |
+
FastDivmodDivisor(num_block_cluster),
|
| 297 |
+
FastDivmodDivisor(args.num_head),
|
| 298 |
+
total_blocks_cluster,
|
| 299 |
+
cluster_shape_m=args.cluster_shape_mn[0],
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):
|
| 303 |
+
self.params = params
|
| 304 |
+
self._tile_idx = tile_idx
|
| 305 |
+
self._loc = loc
|
| 306 |
+
self._ip = ip
|
| 307 |
+
|
| 308 |
+
@staticmethod
|
| 309 |
+
def to_underlying_arguments(
|
| 310 |
+
args: TileSchedulerArguments,
|
| 311 |
+
*,
|
| 312 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 313 |
+
loc=None,
|
| 314 |
+
ip=None,
|
| 315 |
+
) -> Params:
|
| 316 |
+
assert scheduling_mode == SchedulingMode.STATIC, (
|
| 317 |
+
f"StaticPersistentTileScheduler only supports STATIC, got {scheduling_mode!r}"
|
| 318 |
+
)
|
| 319 |
+
return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip)
|
| 320 |
+
|
| 321 |
+
@staticmethod
|
| 322 |
+
def create(
|
| 323 |
+
params: Params, clc: ClcState | None = None, *, loc=None, ip=None
|
| 324 |
+
) -> "StaticPersistentTileScheduler":
|
| 325 |
+
if const_expr(cute.size(params.cluster_shape_m) == 1):
|
| 326 |
+
tile_idx = cute.arch.block_idx()[0]
|
| 327 |
+
else:
|
| 328 |
+
tile_idx = cute.arch.cluster_idx()[0]
|
| 329 |
+
return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip)
|
| 330 |
+
|
| 331 |
+
@staticmethod
|
| 332 |
+
def get_grid_shape(
|
| 333 |
+
params: Params,
|
| 334 |
+
*,
|
| 335 |
+
usable_SM_count=0,
|
| 336 |
+
loc=None,
|
| 337 |
+
ip=None,
|
| 338 |
+
) -> Tuple[Int32, Int32, Int32]:
|
| 339 |
+
hardware_info = cutlass.utils.HardwareInfo()
|
| 340 |
+
cluster_shape_m = int(params.cluster_shape_m)
|
| 341 |
+
if usable_SM_count > 0:
|
| 342 |
+
sm_count = usable_SM_count
|
| 343 |
+
else:
|
| 344 |
+
sm_count = hardware_info.get_device_multiprocessor_count()
|
| 345 |
+
max_ctas = (sm_count // cluster_shape_m) * cluster_shape_m
|
| 346 |
+
max_ctas = max(max_ctas, cluster_shape_m)
|
| 347 |
+
grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * cluster_shape_m)
|
| 348 |
+
return (grid_x, Int32(1), Int32(1))
|
| 349 |
+
|
| 350 |
+
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
|
| 351 |
+
hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod)
|
| 352 |
+
batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod)
|
| 353 |
+
is_valid = self._tile_idx < self.params.total_blocks_cluster
|
| 354 |
+
return WorkTileInfo(
|
| 355 |
+
(Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
def initial_work_tile_info(self, *, loc=None, ip=None):
|
| 359 |
+
return self.get_current_work(loc=loc, ip=ip)
|
| 360 |
+
|
| 361 |
+
def prefetch_next_work(self, *, loc=None, ip=None):
|
| 362 |
+
pass
|
| 363 |
+
|
| 364 |
+
def advance_to_next_work(self, *, loc=None, ip=None):
|
| 365 |
+
if const_expr(self.params.cluster_shape_m == 1):
|
| 366 |
+
self._tile_idx += cute.arch.grid_dim()[0]
|
| 367 |
+
else:
|
| 368 |
+
self._tile_idx += cute.arch.cluster_dim()[0]
|
| 369 |
+
return self.get_current_work()
|
| 370 |
+
|
| 371 |
+
def producer_tail(self, *, loc=None, ip=None):
|
| 372 |
+
pass
|
| 373 |
+
|
| 374 |
+
def __extract_mlir_values__(self):
|
| 375 |
+
values, self._values_pos = [], []
|
| 376 |
+
for obj in [self.params, self._tile_idx]:
|
| 377 |
+
obj_values = cutlass.extract_mlir_values(obj)
|
| 378 |
+
values += obj_values
|
| 379 |
+
self._values_pos.append(len(obj_values))
|
| 380 |
+
return values
|
| 381 |
+
|
| 382 |
+
def __new_from_mlir_values__(self, values):
|
| 383 |
+
obj_list = []
|
| 384 |
+
for obj, n_items in zip(
|
| 385 |
+
[self.params, self._tile_idx],
|
| 386 |
+
self._values_pos,
|
| 387 |
+
):
|
| 388 |
+
obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
|
| 389 |
+
values = values[n_items:]
|
| 390 |
+
return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class SingleTileLPTScheduler:
|
| 394 |
+
@dataclass
|
| 395 |
+
class Params(ParamsBase):
|
| 396 |
+
total_blocks: Int32
|
| 397 |
+
num_splits: Int32
|
| 398 |
+
num_block: Int32
|
| 399 |
+
num_head: Int32
|
| 400 |
+
num_batch: Int32
|
| 401 |
+
l2_minor: Int32
|
| 402 |
+
num_head_divmod: FastDivmodDivisor
|
| 403 |
+
l2_minor_divmod: FastDivmodDivisor
|
| 404 |
+
l2_major_divmod: FastDivmodDivisor
|
| 405 |
+
l2_minor_residual_divmod: FastDivmodDivisor
|
| 406 |
+
num_hb_quotient: Int32
|
| 407 |
+
num_splits_divmod: FastDivmodDivisor
|
| 408 |
+
is_split_kv: cutlass.Constexpr[bool] = False
|
| 409 |
+
cluster_shape_m: cutlass.Constexpr[int] = 1
|
| 410 |
+
scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC
|
| 411 |
+
lpt: cutlass.Constexpr[bool] = True
|
| 412 |
+
use_cluster_idx: cutlass.Constexpr[bool] = True
|
| 413 |
+
|
| 414 |
+
@staticmethod
|
| 415 |
+
@cute.jit
|
| 416 |
+
def create(
|
| 417 |
+
args: TileSchedulerArguments,
|
| 418 |
+
*,
|
| 419 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 420 |
+
loc=None,
|
| 421 |
+
ip=None,
|
| 422 |
+
) -> "SingleTileLPTScheduler.Params":
|
| 423 |
+
assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), (
|
| 424 |
+
f"Only STATIC and CLC are supported, got {scheduling_mode!r}"
|
| 425 |
+
)
|
| 426 |
+
size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
|
| 427 |
+
size_one_head = size_one_kv_head
|
| 428 |
+
size_l2 = 50 * 1024 * 1024 # 40 MB for K & V
|
| 429 |
+
# Swizzle is the size of each "section". Round swizzle to a power of 2
|
| 430 |
+
# Need to be careful about the case where only one head will fit
|
| 431 |
+
# swizzle is how many heads can fit in L2
|
| 432 |
+
# Seems faster if swizzle is a power of 2
|
| 433 |
+
log2_floor = lambda n: 31 - clz(n)
|
| 434 |
+
swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
|
| 435 |
+
# If we're in the last section (called residual), we don't want to divide by
|
| 436 |
+
# swizzle. Instead we want to divide by the remainder.
|
| 437 |
+
num_hb_quotient = (args.num_head * args.num_batch) // swizzle
|
| 438 |
+
num_hb_remainder = (args.num_head * args.num_batch) % swizzle
|
| 439 |
+
return SingleTileLPTScheduler.Params(
|
| 440 |
+
total_blocks=args.num_block * args.num_head * args.num_batch,
|
| 441 |
+
num_block=args.num_block,
|
| 442 |
+
num_head=args.num_head,
|
| 443 |
+
num_batch=args.num_batch,
|
| 444 |
+
l2_minor=Int32(swizzle),
|
| 445 |
+
num_head_divmod=FastDivmodDivisor(args.num_head),
|
| 446 |
+
l2_minor_divmod=FastDivmodDivisor(swizzle),
|
| 447 |
+
l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block),
|
| 448 |
+
l2_minor_residual_divmod=FastDivmodDivisor(max(num_hb_remainder, 1)),
|
| 449 |
+
num_hb_quotient=Int32(num_hb_quotient),
|
| 450 |
+
num_splits=args.num_splits,
|
| 451 |
+
num_splits_divmod=FastDivmodDivisor(args.num_splits),
|
| 452 |
+
is_split_kv=args.is_split_kv,
|
| 453 |
+
cluster_shape_m=args.cluster_shape_mn[0],
|
| 454 |
+
scheduling_mode=scheduling_mode,
|
| 455 |
+
lpt=args.lpt,
|
| 456 |
+
use_cluster_idx=args.use_cluster_idx,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
def __init__(
|
| 460 |
+
self,
|
| 461 |
+
params: Params,
|
| 462 |
+
tile_idx: Int32,
|
| 463 |
+
split_idx: Int32,
|
| 464 |
+
clc: ClcState | None = None,
|
| 465 |
+
*,
|
| 466 |
+
loc=None,
|
| 467 |
+
ip=None,
|
| 468 |
+
):
|
| 469 |
+
self.params = params
|
| 470 |
+
self._tile_idx = tile_idx
|
| 471 |
+
self._split_idx = split_idx
|
| 472 |
+
self.clc = clc
|
| 473 |
+
self._loc = loc
|
| 474 |
+
self._ip = ip
|
| 475 |
+
|
| 476 |
+
@staticmethod
|
| 477 |
+
def to_underlying_arguments(
|
| 478 |
+
args: TileSchedulerArguments,
|
| 479 |
+
*,
|
| 480 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 481 |
+
loc=None,
|
| 482 |
+
ip=None,
|
| 483 |
+
) -> Params:
|
| 484 |
+
return SingleTileLPTScheduler.Params.create(
|
| 485 |
+
args, scheduling_mode=scheduling_mode, loc=loc, ip=ip
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
@staticmethod
|
| 489 |
+
def _clc_grid_shape(params: Params):
|
| 490 |
+
num_batch_splits = (
|
| 491 |
+
params.num_batch * params.num_splits
|
| 492 |
+
if const_expr(params.is_split_kv)
|
| 493 |
+
else params.num_batch
|
| 494 |
+
)
|
| 495 |
+
return (
|
| 496 |
+
cute.round_up(params.num_block, params.cluster_shape_m),
|
| 497 |
+
params.num_head,
|
| 498 |
+
num_batch_splits,
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
@staticmethod
|
| 502 |
+
@cute.jit
|
| 503 |
+
def clc_problem_shape(params: Params):
|
| 504 |
+
return ClcDynamicPersistentTileSchedulerParams(
|
| 505 |
+
problem_shape_ntile_mnl=SingleTileLPTScheduler._clc_grid_shape(params),
|
| 506 |
+
cluster_shape_mnk=(params.cluster_shape_m, 1, 1),
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
@staticmethod
|
| 510 |
+
@cute.jit
|
| 511 |
+
def create(
|
| 512 |
+
params: Params, clc: ClcState | None = None, *, loc=None, ip=None
|
| 513 |
+
) -> "SingleTileLPTScheduler":
|
| 514 |
+
if const_expr(params.scheduling_mode == SchedulingMode.CLC):
|
| 515 |
+
return SingleTileLPTScheduler(
|
| 516 |
+
params, cute.arch.block_idx()[0], Int32(0), clc, loc=loc, ip=ip
|
| 517 |
+
)
|
| 518 |
+
tile_idx, split_idx, _ = cute.arch.block_idx()
|
| 519 |
+
return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
|
| 520 |
+
|
| 521 |
+
@staticmethod
|
| 522 |
+
def get_grid_shape(
|
| 523 |
+
params: Params,
|
| 524 |
+
*,
|
| 525 |
+
loc=None,
|
| 526 |
+
ip=None,
|
| 527 |
+
) -> Tuple[Int32, Int32, Int32]:
|
| 528 |
+
if const_expr(params.scheduling_mode == SchedulingMode.CLC):
|
| 529 |
+
return SingleTileLPTScheduler._clc_grid_shape(params)
|
| 530 |
+
return (params.total_blocks, params.num_splits, Int32(1))
|
| 531 |
+
|
| 532 |
+
@cute.jit
|
| 533 |
+
def clc_work_to_coords(self, work) -> WorkTileInfo:
|
| 534 |
+
"""Convert CLC response (block, head, batch_split) to WorkTileInfo.
|
| 535 |
+
|
| 536 |
+
CLC returns raw grid coordinates — no L2 swizzle (hardware decides order).
|
| 537 |
+
We only apply cluster division, optional LPT block reversal, and split_kv unpacking.
|
| 538 |
+
"""
|
| 539 |
+
block_idx = work.tile_idx[0]
|
| 540 |
+
if const_expr(self.params.cluster_shape_m > 1):
|
| 541 |
+
block_idx = block_idx // self.params.cluster_shape_m
|
| 542 |
+
if const_expr(self.params.lpt):
|
| 543 |
+
# Longest-processing-time-first: reverse block order
|
| 544 |
+
if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx):
|
| 545 |
+
num_block = self.params.num_block // self.params.cluster_shape_m
|
| 546 |
+
else:
|
| 547 |
+
num_block = self.params.num_block
|
| 548 |
+
block_idx = num_block - 1 - block_idx
|
| 549 |
+
split_idx = Int32(0)
|
| 550 |
+
if const_expr(self.params.is_split_kv):
|
| 551 |
+
batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod)
|
| 552 |
+
else:
|
| 553 |
+
batch_idx = work.tile_idx[2]
|
| 554 |
+
if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx):
|
| 555 |
+
bidx_in_cluster = cute.arch.block_in_cluster_idx()
|
| 556 |
+
block_idx = block_idx * self.params.cluster_shape_m + bidx_in_cluster[0]
|
| 557 |
+
return WorkTileInfo(
|
| 558 |
+
(Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)),
|
| 559 |
+
work.is_valid_tile,
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
@cute.jit
|
| 563 |
+
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
|
| 564 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 565 |
+
work = self.clc.get_current_work()
|
| 566 |
+
self._tile_idx = work.tile_idx[0]
|
| 567 |
+
return self.clc_work_to_coords(work)
|
| 568 |
+
# Static path: L2-swizzled coordinate mapping
|
| 569 |
+
params = self.params
|
| 570 |
+
# Implement LPT scheduling coordinate calculation
|
| 571 |
+
bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod)
|
| 572 |
+
# If we're in the last section (called residual), we don't want to divide by
|
| 573 |
+
# swizzle. Instead we want to divide by the remainder.
|
| 574 |
+
block, bidhb_residual = 0, 0
|
| 575 |
+
if bidhb < params.num_hb_quotient:
|
| 576 |
+
block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod)
|
| 577 |
+
else:
|
| 578 |
+
block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod)
|
| 579 |
+
bidhb_actual = bidhb * params.l2_minor + bidhb_residual
|
| 580 |
+
batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
|
| 581 |
+
# Longest-processing-time-first
|
| 582 |
+
if const_expr(params.lpt):
|
| 583 |
+
block = params.num_block - 1 - block
|
| 584 |
+
is_valid = self._tile_idx < params.total_blocks
|
| 585 |
+
return WorkTileInfo(
|
| 586 |
+
(Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
@cute.jit
|
| 590 |
+
def initial_work_tile_info(self, *, loc=None, ip=None):
|
| 591 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 592 |
+
work = self.clc.initial_work_tile_info()
|
| 593 |
+
self._tile_idx = work.tile_idx[0]
|
| 594 |
+
return self.clc_work_to_coords(work)
|
| 595 |
+
return self.get_current_work(loc=loc, ip=ip)
|
| 596 |
+
|
| 597 |
+
def prefetch_next_work(self, *, loc=None, ip=None):
|
| 598 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 599 |
+
self.clc.prefetch_next_work(loc=loc, ip=ip)
|
| 600 |
+
|
| 601 |
+
def advance_to_next_work(self, *, loc=None, ip=None):
|
| 602 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 603 |
+
self.clc.consumer_wait(loc=loc, ip=ip)
|
| 604 |
+
work = self.get_current_work()
|
| 605 |
+
self.clc.consumer_release(loc=loc, ip=ip)
|
| 606 |
+
return work
|
| 607 |
+
# Single tile scheduler - set to invalid tile_idx to indicate no more work
|
| 608 |
+
self._tile_idx = self.params.total_blocks
|
| 609 |
+
return self.get_current_work()
|
| 610 |
+
|
| 611 |
+
def producer_tail(self, *, loc=None, ip=None):
|
| 612 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 613 |
+
self.clc.producer_tail(loc=loc, ip=ip)
|
| 614 |
+
|
| 615 |
+
def __extract_mlir_values__(self):
|
| 616 |
+
values, self._values_pos = [], []
|
| 617 |
+
objs = [self.params, self._tile_idx, self._split_idx]
|
| 618 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 619 |
+
objs += [self.clc]
|
| 620 |
+
for obj in objs:
|
| 621 |
+
obj_values = cutlass.extract_mlir_values(obj)
|
| 622 |
+
values += obj_values
|
| 623 |
+
self._values_pos.append(len(obj_values))
|
| 624 |
+
return values
|
| 625 |
+
|
| 626 |
+
def __new_from_mlir_values__(self, values):
|
| 627 |
+
obj_list = []
|
| 628 |
+
objs = [self.params, self._tile_idx, self._split_idx]
|
| 629 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 630 |
+
objs += [self.clc]
|
| 631 |
+
for obj, n_items in zip(objs, self._values_pos):
|
| 632 |
+
obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
|
| 633 |
+
values = values[n_items:]
|
| 634 |
+
return self.__class__(*obj_list, loc=self._loc)
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
class SingleTileVarlenScheduler:
|
| 638 |
+
@dataclass
|
| 639 |
+
class Params(ParamsBase):
|
| 640 |
+
num_head: Int32
|
| 641 |
+
num_batch: Int32
|
| 642 |
+
total_q: Int32
|
| 643 |
+
num_splits: Int32
|
| 644 |
+
max_kvblock_in_l2: Int32
|
| 645 |
+
tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]
|
| 646 |
+
mCuSeqlensQ: Optional[cute.Tensor] = None
|
| 647 |
+
mSeqUsedQ: Optional[cute.Tensor] = None
|
| 648 |
+
qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
|
| 649 |
+
lpt: cutlass.Constexpr[bool] = False
|
| 650 |
+
is_split_kv: cutlass.Constexpr[bool] = False
|
| 651 |
+
head_swizzle: cutlass.Constexpr[bool] = False
|
| 652 |
+
cluster_shape_m: cutlass.Constexpr[int] = 1
|
| 653 |
+
scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC
|
| 654 |
+
|
| 655 |
+
@staticmethod
|
| 656 |
+
@cute.jit
|
| 657 |
+
def create(
|
| 658 |
+
args: TileSchedulerArguments,
|
| 659 |
+
*,
|
| 660 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 661 |
+
loc=None,
|
| 662 |
+
ip=None,
|
| 663 |
+
) -> "SingleTileVarlenScheduler.Params":
|
| 664 |
+
assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), (
|
| 665 |
+
f"Only STATIC and CLC are supported, got {scheduling_mode!r}"
|
| 666 |
+
)
|
| 667 |
+
size_l2 = 50 * 1024 * 1024 # 50 MB for K & V
|
| 668 |
+
kv_block_size = (
|
| 669 |
+
(args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
|
| 670 |
+
)
|
| 671 |
+
if args.head_swizzle:
|
| 672 |
+
kv_block_size += args.headdim * 4 * args.tile_shape_mn[1]
|
| 673 |
+
max_kvblock_in_l2 = size_l2 // kv_block_size
|
| 674 |
+
assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
|
| 675 |
+
"At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
|
| 676 |
+
)
|
| 677 |
+
assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
|
| 678 |
+
# TODO: Support varlen CLC with cluster_shape_m > 1 by refactoring the
|
| 679 |
+
# flattened-tile decode so cluster unpacking semantics are explicit.
|
| 680 |
+
assert scheduling_mode != SchedulingMode.CLC or args.cluster_shape_mn[0] == 1, (
|
| 681 |
+
"Varlen CLC currently requires cluster_shape_mn[0] == 1"
|
| 682 |
+
)
|
| 683 |
+
return SingleTileVarlenScheduler.Params(
|
| 684 |
+
num_head=args.num_head,
|
| 685 |
+
num_batch=args.num_batch,
|
| 686 |
+
total_q=args.total_q,
|
| 687 |
+
num_splits=args.num_splits,
|
| 688 |
+
max_kvblock_in_l2=max_kvblock_in_l2,
|
| 689 |
+
tile_shape_mn=args.tile_shape_mn,
|
| 690 |
+
mCuSeqlensQ=args.mCuSeqlensQ,
|
| 691 |
+
mSeqUsedQ=args.mSeqUsedQ,
|
| 692 |
+
qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa,
|
| 693 |
+
lpt=args.lpt,
|
| 694 |
+
is_split_kv=args.is_split_kv,
|
| 695 |
+
head_swizzle=args.head_swizzle,
|
| 696 |
+
cluster_shape_m=args.cluster_shape_mn[0],
|
| 697 |
+
scheduling_mode=scheduling_mode,
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
def __init__(
|
| 701 |
+
self,
|
| 702 |
+
params: Params,
|
| 703 |
+
tile_idx: Int32,
|
| 704 |
+
split_idx: Int32,
|
| 705 |
+
clc: ClcState | None = None,
|
| 706 |
+
*,
|
| 707 |
+
loc=None,
|
| 708 |
+
ip=None,
|
| 709 |
+
):
|
| 710 |
+
self.params = params
|
| 711 |
+
self._tile_idx = tile_idx
|
| 712 |
+
self._split_idx = split_idx
|
| 713 |
+
self._is_first_block = True
|
| 714 |
+
self.clc = clc
|
| 715 |
+
self._loc = loc
|
| 716 |
+
self._ip = ip
|
| 717 |
+
|
| 718 |
+
@staticmethod
|
| 719 |
+
def to_underlying_arguments(
|
| 720 |
+
args: TileSchedulerArguments,
|
| 721 |
+
*,
|
| 722 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 723 |
+
loc=None,
|
| 724 |
+
ip=None,
|
| 725 |
+
) -> Params:
|
| 726 |
+
return SingleTileVarlenScheduler.Params.create(
|
| 727 |
+
args, scheduling_mode=scheduling_mode, loc=loc, ip=ip
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
@staticmethod
|
| 731 |
+
@cute.jit
|
| 732 |
+
def clc_problem_shape(params: Params):
|
| 733 |
+
return ClcDynamicPersistentTileSchedulerParams(
|
| 734 |
+
problem_shape_ntile_mnl=SingleTileVarlenScheduler.get_grid_shape(params),
|
| 735 |
+
cluster_shape_mnk=(1, 1, 1),
|
| 736 |
+
)
|
| 737 |
+
|
| 738 |
+
@staticmethod
|
| 739 |
+
@cute.jit
|
| 740 |
+
def create(
|
| 741 |
+
params: Params, clc: ClcState | None = None, *, loc=None, ip=None
|
| 742 |
+
) -> "SingleTileVarlenScheduler":
|
| 743 |
+
if const_expr(params.scheduling_mode == SchedulingMode.CLC):
|
| 744 |
+
block_idx = cute.arch.block_idx()
|
| 745 |
+
split_idx = Int32(0)
|
| 746 |
+
if const_expr(params.is_split_kv):
|
| 747 |
+
split_idx = block_idx[1]
|
| 748 |
+
return SingleTileVarlenScheduler(
|
| 749 |
+
params,
|
| 750 |
+
block_idx[0],
|
| 751 |
+
split_idx,
|
| 752 |
+
clc,
|
| 753 |
+
loc=loc,
|
| 754 |
+
ip=ip,
|
| 755 |
+
)
|
| 756 |
+
tile_idx, split_idx, _ = cute.arch.block_idx()
|
| 757 |
+
return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
|
| 758 |
+
|
| 759 |
+
# called by host
|
| 760 |
+
@staticmethod
|
| 761 |
+
def get_grid_shape(
|
| 762 |
+
params: Params,
|
| 763 |
+
*,
|
| 764 |
+
loc=None,
|
| 765 |
+
ip=None,
|
| 766 |
+
) -> Tuple[Int32, Int32, Int32]:
|
| 767 |
+
total_blocks_max = (
|
| 768 |
+
params.total_q
|
| 769 |
+
+ params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1)
|
| 770 |
+
) // params.tile_shape_mn[0]
|
| 771 |
+
# Round down to nearest multiple of cluster since odd excess is always padding.
|
| 772 |
+
total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m
|
| 773 |
+
return (total_blocks_max * params.num_head, params.num_splits, Int32(1))
|
| 774 |
+
|
| 775 |
+
@cute.jit
|
| 776 |
+
def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32:
|
| 777 |
+
params = self.params
|
| 778 |
+
batch_idx = lane + bidb_start
|
| 779 |
+
if cutlass.const_expr(params.mSeqUsedQ is not None):
|
| 780 |
+
seqlen = Int32(0)
|
| 781 |
+
if batch_idx < params.num_batch:
|
| 782 |
+
seqlen = params.mSeqUsedQ[batch_idx]
|
| 783 |
+
else:
|
| 784 |
+
assert params.mCuSeqlensQ is not None
|
| 785 |
+
cur_cu_seqlen = Int32(0)
|
| 786 |
+
if batch_idx <= params.num_batch:
|
| 787 |
+
cur_cu_seqlen = params.mCuSeqlensQ[batch_idx]
|
| 788 |
+
next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
|
| 789 |
+
seqlen = next_cu_seqlen - cur_cu_seqlen
|
| 790 |
+
if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1):
|
| 791 |
+
seqlen *= params.qhead_per_kvhead_packgqa
|
| 792 |
+
return (
|
| 793 |
+
cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), params.cluster_shape_m)
|
| 794 |
+
if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1
|
| 795 |
+
else Int32(0)
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
@cute.jit
|
| 799 |
+
def _varlen_coord_map(self) -> WorkTileInfo:
|
| 800 |
+
"""Map self._tile_idx to (block, head, batch) via warp-level prefix sums."""
|
| 801 |
+
params = self.params
|
| 802 |
+
lane_idx = cute.arch.lane_idx()
|
| 803 |
+
num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0)
|
| 804 |
+
num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
|
| 805 |
+
# Total number of blocks for the next 31 batches
|
| 806 |
+
m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1)
|
| 807 |
+
# Same for all lanes
|
| 808 |
+
group_end_tile = m_blocks_in_group * params.num_head
|
| 809 |
+
# if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group)
|
| 810 |
+
block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0)
|
| 811 |
+
next_tile_idx = self._tile_idx // params.cluster_shape_m
|
| 812 |
+
while group_end_tile <= next_tile_idx:
|
| 813 |
+
batch_idx += cute.arch.WARP_SIZE - 1
|
| 814 |
+
if batch_idx >= params.num_batch:
|
| 815 |
+
batch_idx = Int32(params.num_batch)
|
| 816 |
+
group_end_tile = next_tile_idx + 1
|
| 817 |
+
else:
|
| 818 |
+
num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx)
|
| 819 |
+
num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
|
| 820 |
+
m_blocks_in_group = cute.arch.shuffle_sync(
|
| 821 |
+
num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1
|
| 822 |
+
)
|
| 823 |
+
group_end_tile += m_blocks_in_group * params.num_head
|
| 824 |
+
is_valid = False
|
| 825 |
+
if batch_idx >= params.num_batch:
|
| 826 |
+
block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch)
|
| 827 |
+
else:
|
| 828 |
+
group_start_tile = group_end_tile - m_blocks_in_group * params.num_head
|
| 829 |
+
# if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx)
|
| 830 |
+
# The next problem to process is the first one that does not have ending tile position
|
| 831 |
+
# that is greater than or equal to tile index.
|
| 832 |
+
batch_idx_in_group = cute.arch.popc(
|
| 833 |
+
cute.arch.vote_ballot_sync(
|
| 834 |
+
group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx
|
| 835 |
+
)
|
| 836 |
+
)
|
| 837 |
+
batch_idx += batch_idx_in_group
|
| 838 |
+
num_m_blocks_prev_lane = (
|
| 839 |
+
0
|
| 840 |
+
if batch_idx_in_group == 0
|
| 841 |
+
else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1)
|
| 842 |
+
)
|
| 843 |
+
num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group)
|
| 844 |
+
mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head
|
| 845 |
+
if cutlass.const_expr(params.lpt or params.head_swizzle):
|
| 846 |
+
# This is a version of the SingleTileLPTScheduler, complicated by the fact that
|
| 847 |
+
# the seqlen can vary per batch.
|
| 848 |
+
# TODO: is there any case where num_m_blocks is 0?
|
| 849 |
+
# TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here
|
| 850 |
+
num_n_blocks = (
|
| 851 |
+
num_m_blocks
|
| 852 |
+
* params.tile_shape_mn[0]
|
| 853 |
+
* params.cluster_shape_m
|
| 854 |
+
// params.qhead_per_kvhead_packgqa
|
| 855 |
+
// params.tile_shape_mn[1]
|
| 856 |
+
)
|
| 857 |
+
# nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head)
|
| 858 |
+
# Seems faster to have this be a power of 2
|
| 859 |
+
nheads_in_l2 = (
|
| 860 |
+
16
|
| 861 |
+
if num_n_blocks * 16 <= params.max_kvblock_in_l2
|
| 862 |
+
else (
|
| 863 |
+
8
|
| 864 |
+
if num_n_blocks * 8 <= params.max_kvblock_in_l2
|
| 865 |
+
else (
|
| 866 |
+
4
|
| 867 |
+
if num_n_blocks * 4 <= params.max_kvblock_in_l2
|
| 868 |
+
else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1)
|
| 869 |
+
)
|
| 870 |
+
)
|
| 871 |
+
)
|
| 872 |
+
nheads_in_l2 = min(nheads_in_l2, params.num_head)
|
| 873 |
+
mh_in_l2 = nheads_in_l2 * num_m_blocks
|
| 874 |
+
section_idx = mh_block // mh_in_l2
|
| 875 |
+
l2_mod = mh_block - section_idx * mh_in_l2
|
| 876 |
+
# Deal with tail section
|
| 877 |
+
nheads_in_this_section = (
|
| 878 |
+
nheads_in_l2
|
| 879 |
+
if nheads_in_l2 * (section_idx + 1) <= params.num_head
|
| 880 |
+
else params.num_head - section_idx * nheads_in_l2
|
| 881 |
+
)
|
| 882 |
+
block = l2_mod // nheads_in_this_section
|
| 883 |
+
head_idx_residual = l2_mod - block * nheads_in_this_section
|
| 884 |
+
head_idx = section_idx * nheads_in_l2 + head_idx_residual
|
| 885 |
+
if cutlass.const_expr(params.lpt):
|
| 886 |
+
block = num_m_blocks - 1 - block
|
| 887 |
+
else:
|
| 888 |
+
head_idx = mh_block // num_m_blocks
|
| 889 |
+
block = mh_block - head_idx * num_m_blocks
|
| 890 |
+
is_valid = self._is_first_block and batch_idx < params.num_batch
|
| 891 |
+
if cutlass.const_expr(params.cluster_shape_m > 1):
|
| 892 |
+
bidx_in_cluster = cute.arch.block_in_cluster_idx()
|
| 893 |
+
block = block * params.cluster_shape_m + bidx_in_cluster[0]
|
| 894 |
+
# if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid)
|
| 895 |
+
split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0)
|
| 896 |
+
return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid)
|
| 897 |
+
|
| 898 |
+
@cute.jit
|
| 899 |
+
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
|
| 900 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 901 |
+
clc_work = self.clc.get_current_work()
|
| 902 |
+
# Default to grid_dim (one past last valid flat index) so _varlen_coord_map
|
| 903 |
+
# returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when
|
| 904 |
+
# invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural
|
| 905 |
+
# mismatch on self inside the runtime if.
|
| 906 |
+
new_tile_idx = cute.arch.grid_dim()[0]
|
| 907 |
+
new_split_idx = Int32(0)
|
| 908 |
+
if clc_work.is_valid_tile:
|
| 909 |
+
new_tile_idx = clc_work.tile_idx[0]
|
| 910 |
+
if const_expr(self.params.is_split_kv):
|
| 911 |
+
new_split_idx = clc_work.tile_idx[1]
|
| 912 |
+
self._tile_idx = new_tile_idx
|
| 913 |
+
self._split_idx = new_split_idx
|
| 914 |
+
return self._varlen_coord_map()
|
| 915 |
+
|
| 916 |
+
@cute.jit
|
| 917 |
+
def initial_work_tile_info(self, *, loc=None, ip=None):
|
| 918 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 919 |
+
clc_work = self.clc.initial_work_tile_info()
|
| 920 |
+
# See get_current_work for why grid_dim and local-then-assign.
|
| 921 |
+
new_tile_idx = cute.arch.grid_dim()[0]
|
| 922 |
+
new_split_idx = Int32(0)
|
| 923 |
+
if clc_work.is_valid_tile:
|
| 924 |
+
new_tile_idx = clc_work.tile_idx[0]
|
| 925 |
+
if const_expr(self.params.is_split_kv):
|
| 926 |
+
new_split_idx = clc_work.tile_idx[1]
|
| 927 |
+
self._tile_idx = new_tile_idx
|
| 928 |
+
self._split_idx = new_split_idx
|
| 929 |
+
return self._varlen_coord_map()
|
| 930 |
+
|
| 931 |
+
def prefetch_next_work(self, *, loc=None, ip=None):
|
| 932 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 933 |
+
self.clc.prefetch_next_work(loc=loc, ip=ip)
|
| 934 |
+
|
| 935 |
+
def advance_to_next_work(self, *, loc=None, ip=None):
|
| 936 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 937 |
+
self.clc.consumer_wait(loc=loc, ip=ip)
|
| 938 |
+
work = self.get_current_work()
|
| 939 |
+
self.clc.consumer_release(loc=loc, ip=ip)
|
| 940 |
+
return work
|
| 941 |
+
self._is_first_block = False
|
| 942 |
+
return self.get_current_work()
|
| 943 |
+
|
| 944 |
+
def producer_tail(self, *, loc=None, ip=None):
|
| 945 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 946 |
+
self.clc.producer_tail(loc=loc, ip=ip)
|
| 947 |
+
|
| 948 |
+
def __extract_mlir_values__(self):
|
| 949 |
+
values, self._values_pos = [], []
|
| 950 |
+
objs = [self.params, self._tile_idx, self._split_idx]
|
| 951 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 952 |
+
objs += [self.clc]
|
| 953 |
+
for obj in objs:
|
| 954 |
+
obj_values = cutlass.extract_mlir_values(obj)
|
| 955 |
+
values += obj_values
|
| 956 |
+
self._values_pos.append(len(obj_values))
|
| 957 |
+
return values
|
| 958 |
+
|
| 959 |
+
def __new_from_mlir_values__(self, values):
|
| 960 |
+
obj_list = []
|
| 961 |
+
objs = [self.params, self._tile_idx, self._split_idx]
|
| 962 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 963 |
+
objs += [self.clc]
|
| 964 |
+
for obj, n_items in zip(objs, self._values_pos):
|
| 965 |
+
obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
|
| 966 |
+
values = values[n_items:]
|
| 967 |
+
return self.__class__(*obj_list, loc=self._loc)
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/tma_utils.py
ADDED
|
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""Raw TMA ops and descriptor builders.
|
| 5 |
+
|
| 6 |
+
`tma_utils.py` is the canonical owner for raw TMA inline-asm helpers and TMA
|
| 7 |
+
descriptor construction. Non-TMA store/layout helpers are re-exported from
|
| 8 |
+
`copy_utils.py` for backward compatibility.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import ctypes
|
| 12 |
+
|
| 13 |
+
from cutlass import Int32, Int64
|
| 14 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 15 |
+
from cutlass._mlir.dialects import llvm
|
| 16 |
+
import cutlass._mlir.dialects.cute as cute_ir
|
| 17 |
+
import cutlass._mlir.dialects.cute_nvgpu as cute_nvgpu_ir
|
| 18 |
+
from cutlass._mlir.dialects import _cute_nvgpu_ops_gen as cute_nvgpu_gen
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Raw TMA Ops
|
| 22 |
+
|
| 23 |
+
TMA_CACHE_EVICT_FIRST = 0x12F0000000000000
|
| 24 |
+
TMA_CACHE_EVICT_LAST = 0x14F0000000000000
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dsl_user_op
|
| 28 |
+
def tma_tile_load(
|
| 29 |
+
smem_ptr,
|
| 30 |
+
smem_byte_offset,
|
| 31 |
+
tma_desc_ptr,
|
| 32 |
+
col_idx,
|
| 33 |
+
row_idx,
|
| 34 |
+
mbar_ptr,
|
| 35 |
+
*,
|
| 36 |
+
loc=None,
|
| 37 |
+
ip=None,
|
| 38 |
+
):
|
| 39 |
+
"""cp.async.bulk.tensor.2d.shared::cta.global.tile with mbar completion."""
|
| 40 |
+
llvm.inline_asm(
|
| 41 |
+
T.i32(),
|
| 42 |
+
[
|
| 43 |
+
smem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 44 |
+
Int32(smem_byte_offset).ir_value(loc=loc, ip=ip),
|
| 45 |
+
tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 46 |
+
Int32(col_idx).ir_value(loc=loc, ip=ip),
|
| 47 |
+
Int32(row_idx).ir_value(loc=loc, ip=ip),
|
| 48 |
+
mbar_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 49 |
+
],
|
| 50 |
+
"{\n"
|
| 51 |
+
".reg .u32 sa, ma;\n"
|
| 52 |
+
"cvt.u32.u64 sa, $1;\n"
|
| 53 |
+
"add.u32 sa, sa, $2;\n"
|
| 54 |
+
"cvt.u32.u64 ma, $6;\n"
|
| 55 |
+
"cp.async.bulk.tensor.2d.shared::cta.global.tile"
|
| 56 |
+
".mbarrier::complete_tx::bytes [sa], [$3, {$4, $5}], [ma];\n"
|
| 57 |
+
"mov.u32 $0, 0;\n"
|
| 58 |
+
"}\n",
|
| 59 |
+
"=r,l,r,l,r,r,l",
|
| 60 |
+
has_side_effects=True,
|
| 61 |
+
is_align_stack=False,
|
| 62 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 63 |
+
loc=loc,
|
| 64 |
+
ip=ip,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dsl_user_op
|
| 69 |
+
def tma_gather4(
|
| 70 |
+
smem_ptr,
|
| 71 |
+
smem_byte_offset,
|
| 72 |
+
tma_desc_ptr,
|
| 73 |
+
col_idx,
|
| 74 |
+
row0,
|
| 75 |
+
row1,
|
| 76 |
+
row2,
|
| 77 |
+
row3,
|
| 78 |
+
mbar_ptr,
|
| 79 |
+
*,
|
| 80 |
+
loc=None,
|
| 81 |
+
ip=None,
|
| 82 |
+
):
|
| 83 |
+
"""cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4 with mbar."""
|
| 84 |
+
llvm.inline_asm(
|
| 85 |
+
T.i32(),
|
| 86 |
+
[
|
| 87 |
+
smem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 88 |
+
Int32(smem_byte_offset).ir_value(loc=loc, ip=ip),
|
| 89 |
+
tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 90 |
+
Int32(col_idx).ir_value(loc=loc, ip=ip),
|
| 91 |
+
Int32(row0).ir_value(loc=loc, ip=ip),
|
| 92 |
+
Int32(row1).ir_value(loc=loc, ip=ip),
|
| 93 |
+
Int32(row2).ir_value(loc=loc, ip=ip),
|
| 94 |
+
Int32(row3).ir_value(loc=loc, ip=ip),
|
| 95 |
+
mbar_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 96 |
+
],
|
| 97 |
+
"{\n"
|
| 98 |
+
".reg .u32 sa, ma;\n"
|
| 99 |
+
"cvt.u32.u64 sa, $1;\n"
|
| 100 |
+
"add.u32 sa, sa, $2;\n"
|
| 101 |
+
"cvt.u32.u64 ma, $9;\n"
|
| 102 |
+
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4"
|
| 103 |
+
".mbarrier::complete_tx::bytes [sa], [$3, {$4, $5, $6, $7, $8}], [ma];\n"
|
| 104 |
+
"mov.u32 $0, 0;\n"
|
| 105 |
+
"}\n",
|
| 106 |
+
"=r,l,r,l,r,r,r,r,r,l",
|
| 107 |
+
has_side_effects=True,
|
| 108 |
+
is_align_stack=False,
|
| 109 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 110 |
+
loc=loc,
|
| 111 |
+
ip=ip,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@dsl_user_op
|
| 116 |
+
def prefetch_tma_desc_raw(tma_desc_ptr, *, loc=None, ip=None):
|
| 117 |
+
"""Prefetch a raw TMA descriptor pointer into the descriptor cache."""
|
| 118 |
+
ptr_i64 = tma_desc_ptr.toint().ir_value(loc=loc, ip=ip)
|
| 119 |
+
ptr_i64_align_ty = cute_ir.ConstrainedIntType.get(128, ptr_i64.type.width)
|
| 120 |
+
ptr_i64_align = cute_ir.assume(ptr_i64_align_ty, ptr_i64, loc=loc, ip=ip)
|
| 121 |
+
ptr_ty = cute_ir.PtrType.get(
|
| 122 |
+
cute_nvgpu_ir.TmaDescriptorTiledType.get(),
|
| 123 |
+
cute_ir.AddressSpace.gmem,
|
| 124 |
+
128,
|
| 125 |
+
)
|
| 126 |
+
desc_ptr = cute_ir.inttoptr(ptr_ty, ptr_i64_align, loc=loc, ip=ip)
|
| 127 |
+
cute_nvgpu_gen.arch_prefetch_tma_desc(desc_ptr.value, loc=loc, ip=ip)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@dsl_user_op
|
| 131 |
+
def tma_tile_prefetch(
|
| 132 |
+
tma_desc_ptr,
|
| 133 |
+
col_idx,
|
| 134 |
+
row_idx,
|
| 135 |
+
cache_hint=TMA_CACHE_EVICT_FIRST,
|
| 136 |
+
*,
|
| 137 |
+
loc=None,
|
| 138 |
+
ip=None,
|
| 139 |
+
):
|
| 140 |
+
"""cp.async.bulk.prefetch.tensor.2d.L2.global.tile with cache hint."""
|
| 141 |
+
llvm.inline_asm(
|
| 142 |
+
None,
|
| 143 |
+
[
|
| 144 |
+
tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 145 |
+
Int32(col_idx).ir_value(loc=loc, ip=ip),
|
| 146 |
+
Int32(row_idx).ir_value(loc=loc, ip=ip),
|
| 147 |
+
Int64(cache_hint).ir_value(loc=loc, ip=ip),
|
| 148 |
+
],
|
| 149 |
+
"cp.async.bulk.prefetch.tensor.2d.L2.global.tile.L2::cache_hint "
|
| 150 |
+
"[$0, {$1, $2}], $3;\n",
|
| 151 |
+
"l,r,r,l",
|
| 152 |
+
has_side_effects=True,
|
| 153 |
+
is_align_stack=False,
|
| 154 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 155 |
+
loc=loc,
|
| 156 |
+
ip=ip,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@dsl_user_op
|
| 161 |
+
def tma_gather4_prefetch(
|
| 162 |
+
tma_desc_ptr,
|
| 163 |
+
col_idx,
|
| 164 |
+
row0,
|
| 165 |
+
row1,
|
| 166 |
+
row2,
|
| 167 |
+
row3,
|
| 168 |
+
cache_hint=TMA_CACHE_EVICT_LAST,
|
| 169 |
+
*,
|
| 170 |
+
loc=None,
|
| 171 |
+
ip=None,
|
| 172 |
+
):
|
| 173 |
+
"""cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4 with cache hint."""
|
| 174 |
+
llvm.inline_asm(
|
| 175 |
+
None,
|
| 176 |
+
[
|
| 177 |
+
tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 178 |
+
Int32(col_idx).ir_value(loc=loc, ip=ip),
|
| 179 |
+
Int32(row0).ir_value(loc=loc, ip=ip),
|
| 180 |
+
Int32(row1).ir_value(loc=loc, ip=ip),
|
| 181 |
+
Int32(row2).ir_value(loc=loc, ip=ip),
|
| 182 |
+
Int32(row3).ir_value(loc=loc, ip=ip),
|
| 183 |
+
Int64(cache_hint).ir_value(loc=loc, ip=ip),
|
| 184 |
+
],
|
| 185 |
+
"cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4.L2::cache_hint "
|
| 186 |
+
"[$0, {$1, $2, $3, $4, $5}], $6;\n",
|
| 187 |
+
"l,r,r,r,r,r,l",
|
| 188 |
+
has_side_effects=True,
|
| 189 |
+
is_align_stack=False,
|
| 190 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 191 |
+
loc=loc,
|
| 192 |
+
ip=ip,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
@dsl_user_op
|
| 197 |
+
def tma_tile_load_cached(
|
| 198 |
+
smem_ptr,
|
| 199 |
+
smem_byte_offset,
|
| 200 |
+
tma_desc_ptr,
|
| 201 |
+
col_idx,
|
| 202 |
+
row_idx,
|
| 203 |
+
mbar_ptr,
|
| 204 |
+
cache_hint=TMA_CACHE_EVICT_FIRST,
|
| 205 |
+
*,
|
| 206 |
+
loc=None,
|
| 207 |
+
ip=None,
|
| 208 |
+
):
|
| 209 |
+
"""cp.async.bulk.tensor.2d.shared::cta.global.tile with cache hint and mbar."""
|
| 210 |
+
llvm.inline_asm(
|
| 211 |
+
T.i32(),
|
| 212 |
+
[
|
| 213 |
+
smem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 214 |
+
Int32(smem_byte_offset).ir_value(loc=loc, ip=ip),
|
| 215 |
+
tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 216 |
+
Int32(col_idx).ir_value(loc=loc, ip=ip),
|
| 217 |
+
Int32(row_idx).ir_value(loc=loc, ip=ip),
|
| 218 |
+
mbar_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 219 |
+
Int64(cache_hint).ir_value(loc=loc, ip=ip),
|
| 220 |
+
],
|
| 221 |
+
"{\n"
|
| 222 |
+
".reg .u32 sa, ma;\n"
|
| 223 |
+
"cvt.u32.u64 sa, $1;\n"
|
| 224 |
+
"add.u32 sa, sa, $2;\n"
|
| 225 |
+
"cvt.u32.u64 ma, $6;\n"
|
| 226 |
+
"cp.async.bulk.tensor.2d.shared::cta.global.tile"
|
| 227 |
+
".mbarrier::complete_tx::bytes.L2::cache_hint "
|
| 228 |
+
"[sa], [$3, {$4, $5}], [ma], $7;\n"
|
| 229 |
+
"mov.u32 $0, 0;\n"
|
| 230 |
+
"}\n",
|
| 231 |
+
"=r,l,r,l,r,r,l,l",
|
| 232 |
+
has_side_effects=True,
|
| 233 |
+
is_align_stack=False,
|
| 234 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 235 |
+
loc=loc,
|
| 236 |
+
ip=ip,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
@dsl_user_op
|
| 241 |
+
def tma_gather4_cached(
|
| 242 |
+
smem_ptr,
|
| 243 |
+
smem_byte_offset,
|
| 244 |
+
tma_desc_ptr,
|
| 245 |
+
col_idx,
|
| 246 |
+
row0,
|
| 247 |
+
row1,
|
| 248 |
+
row2,
|
| 249 |
+
row3,
|
| 250 |
+
mbar_ptr,
|
| 251 |
+
cache_hint=TMA_CACHE_EVICT_LAST,
|
| 252 |
+
*,
|
| 253 |
+
loc=None,
|
| 254 |
+
ip=None,
|
| 255 |
+
):
|
| 256 |
+
"""cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4 with cache hint."""
|
| 257 |
+
llvm.inline_asm(
|
| 258 |
+
None,
|
| 259 |
+
[
|
| 260 |
+
smem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 261 |
+
Int32(smem_byte_offset).ir_value(loc=loc, ip=ip),
|
| 262 |
+
tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 263 |
+
Int32(col_idx).ir_value(loc=loc, ip=ip),
|
| 264 |
+
Int32(row0).ir_value(loc=loc, ip=ip),
|
| 265 |
+
Int32(row1).ir_value(loc=loc, ip=ip),
|
| 266 |
+
Int32(row2).ir_value(loc=loc, ip=ip),
|
| 267 |
+
Int32(row3).ir_value(loc=loc, ip=ip),
|
| 268 |
+
mbar_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 269 |
+
Int64(cache_hint).ir_value(loc=loc, ip=ip),
|
| 270 |
+
],
|
| 271 |
+
"{\n"
|
| 272 |
+
".reg .u32 sa, ma;\n"
|
| 273 |
+
"cvt.u32.u64 sa, $0;\n"
|
| 274 |
+
"add.u32 sa, sa, $1;\n"
|
| 275 |
+
"cvt.u32.u64 ma, $8;\n"
|
| 276 |
+
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4"
|
| 277 |
+
".mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint "
|
| 278 |
+
"[sa], [$2, {$3, $4, $5, $6, $7}], [ma], $9;\n"
|
| 279 |
+
"}\n",
|
| 280 |
+
"l,r,l,r,r,r,r,r,l,l",
|
| 281 |
+
has_side_effects=True,
|
| 282 |
+
is_align_stack=False,
|
| 283 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 284 |
+
loc=loc,
|
| 285 |
+
ip=ip,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
@dsl_user_op
|
| 290 |
+
def tma_tile_store(
|
| 291 |
+
tma_desc_ptr,
|
| 292 |
+
col_idx,
|
| 293 |
+
row_idx,
|
| 294 |
+
smem_ptr,
|
| 295 |
+
smem_byte_offset,
|
| 296 |
+
*,
|
| 297 |
+
loc=None,
|
| 298 |
+
ip=None,
|
| 299 |
+
):
|
| 300 |
+
"""cp.async.bulk.tensor.2d.global.shared::cta.bulk_group store."""
|
| 301 |
+
llvm.inline_asm(
|
| 302 |
+
T.i32(),
|
| 303 |
+
[
|
| 304 |
+
tma_desc_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 305 |
+
Int32(col_idx).ir_value(loc=loc, ip=ip),
|
| 306 |
+
Int32(row_idx).ir_value(loc=loc, ip=ip),
|
| 307 |
+
smem_ptr.toint().ir_value(loc=loc, ip=ip),
|
| 308 |
+
Int32(smem_byte_offset).ir_value(loc=loc, ip=ip),
|
| 309 |
+
],
|
| 310 |
+
"{\n"
|
| 311 |
+
".reg .u32 sa;\n"
|
| 312 |
+
"cvt.u32.u64 sa, $4;\n"
|
| 313 |
+
"add.u32 sa, sa, $5;\n"
|
| 314 |
+
"cp.async.bulk.tensor.2d.global.shared::cta.bulk_group"
|
| 315 |
+
" [$1, {$2, $3}], [sa];\n"
|
| 316 |
+
"mov.u32 $0, 0;\n"
|
| 317 |
+
"}\n",
|
| 318 |
+
"=r,l,r,r,l,r",
|
| 319 |
+
has_side_effects=True,
|
| 320 |
+
is_align_stack=False,
|
| 321 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 322 |
+
loc=loc,
|
| 323 |
+
ip=ip,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
# Descriptor Builders
|
| 328 |
+
|
| 329 |
+
_TMA_DESC_BYTES = 128
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def _encode_tma_desc_2d_bytes(tensor_2d, *, box_x, box_y, context: str) -> bytes:
|
| 333 |
+
import torch
|
| 334 |
+
import cuda.bindings.driver as cuda
|
| 335 |
+
|
| 336 |
+
if tensor_2d.ndim != 2:
|
| 337 |
+
raise ValueError(f"{context} tensor must be rank-2, got {tuple(tensor_2d.shape)}")
|
| 338 |
+
rows, cols = tensor_2d.shape
|
| 339 |
+
if tensor_2d.stride(-1) != 1:
|
| 340 |
+
raise ValueError(f"{context} tensor must be contiguous in the last dimension")
|
| 341 |
+
dtype_map = {
|
| 342 |
+
torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
|
| 343 |
+
torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
| 344 |
+
torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
| 345 |
+
}
|
| 346 |
+
if tensor_2d.dtype not in dtype_map:
|
| 347 |
+
raise TypeError(f"Unsupported dtype for {context} TMA descriptor: {tensor_2d.dtype}")
|
| 348 |
+
|
| 349 |
+
sizes = [cuda.cuuint64_t(cols), cuda.cuuint64_t(rows)]
|
| 350 |
+
strides = [cuda.cuuint64_t(tensor_2d.stride(0) * tensor_2d.element_size())]
|
| 351 |
+
box = [cuda.cuuint32_t(box_x), cuda.cuuint32_t(box_y)]
|
| 352 |
+
elem_stride = [cuda.cuuint32_t(1), cuda.cuuint32_t(1)]
|
| 353 |
+
err, tm = cuda.cuTensorMapEncodeTiled(
|
| 354 |
+
dtype_map[tensor_2d.dtype],
|
| 355 |
+
2,
|
| 356 |
+
tensor_2d.data_ptr(),
|
| 357 |
+
sizes,
|
| 358 |
+
strides,
|
| 359 |
+
box,
|
| 360 |
+
elem_stride,
|
| 361 |
+
cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
|
| 362 |
+
cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
|
| 363 |
+
cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
|
| 364 |
+
cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
|
| 365 |
+
)
|
| 366 |
+
assert err == cuda.CUresult.CUDA_SUCCESS, f"TMA encode failed: {err}"
|
| 367 |
+
buf = (ctypes.c_uint8 * _TMA_DESC_BYTES).from_address(tm.getPtr())
|
| 368 |
+
return bytes(buf)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def _desc_bytes_to_device_tensor(desc_bytes: bytes | bytearray, *, device):
|
| 372 |
+
import torch
|
| 373 |
+
|
| 374 |
+
desc_bytes = bytes(desc_bytes)
|
| 375 |
+
device = torch.device(device)
|
| 376 |
+
if device.type != "cuda":
|
| 377 |
+
raise ValueError(f"TMA descriptors require a CUDA device, got {device}")
|
| 378 |
+
|
| 379 |
+
host_desc = torch.empty((len(desc_bytes),), dtype=torch.uint8, pin_memory=True)
|
| 380 |
+
host_desc.copy_(torch.frombuffer(bytearray(desc_bytes), dtype=torch.uint8))
|
| 381 |
+
device_desc = torch.empty((len(desc_bytes),), dtype=torch.uint8, device=device)
|
| 382 |
+
stream = torch.cuda.current_stream(device)
|
| 383 |
+
with torch.cuda.stream(stream):
|
| 384 |
+
device_desc.copy_(host_desc, non_blocking=True)
|
| 385 |
+
device_desc.record_stream(stream)
|
| 386 |
+
# Keep the staging buffer alive for the async copy without caching descriptors.
|
| 387 |
+
device_desc._tma_host_desc = host_desc
|
| 388 |
+
return device_desc
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def create_flat_gather4_tma_desc(tensor_2d, box_x=64):
|
| 392 |
+
"""Create a gather4 CUtensorMap descriptor for a flat 2D row-major tensor."""
|
| 393 |
+
if tensor_2d.ndim != 2:
|
| 394 |
+
raise ValueError(
|
| 395 |
+
f"tensor_2d must be rank-2 [rows, dim], got {tuple(tensor_2d.shape)}"
|
| 396 |
+
)
|
| 397 |
+
desc = _encode_tma_desc_2d_bytes(
|
| 398 |
+
tensor_2d,
|
| 399 |
+
box_x=box_x,
|
| 400 |
+
box_y=1,
|
| 401 |
+
context="gather4",
|
| 402 |
+
)
|
| 403 |
+
return _desc_bytes_to_device_tensor(desc, device=tensor_2d.device)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def create_q_gather4_tma_desc(q_flat, box_x=64):
|
| 407 |
+
return create_flat_gather4_tma_desc(q_flat, box_x=box_x)
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def create_strided_2d_tma_desc(tensor_2d, *, box_x, box_y):
|
| 411 |
+
"""Create a CUtensorMap descriptor for a rank-2 tensor with arbitrary row stride."""
|
| 412 |
+
desc = _encode_tma_desc_2d_bytes(
|
| 413 |
+
tensor_2d,
|
| 414 |
+
box_x=box_x,
|
| 415 |
+
box_y=box_y,
|
| 416 |
+
context="strided 2D",
|
| 417 |
+
)
|
| 418 |
+
return _desc_bytes_to_device_tensor(desc, device=tensor_2d.device)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def create_flat_kv_tma_descs(kv_flat, *, box_x=64, box_y=128):
|
| 422 |
+
"""Create per-KV-head token-major TMA descriptors for flat [total_k, H, D] storage."""
|
| 423 |
+
import torch
|
| 424 |
+
|
| 425 |
+
if kv_flat.ndim != 3:
|
| 426 |
+
raise ValueError(
|
| 427 |
+
f"kv_flat must be rank-3 [total_k, H, D], got {tuple(kv_flat.shape)}"
|
| 428 |
+
)
|
| 429 |
+
total_k, head_kv, dim = kv_flat.shape
|
| 430 |
+
row_stride = head_kv * dim
|
| 431 |
+
desc_table = bytearray()
|
| 432 |
+
for h in range(head_kv):
|
| 433 |
+
head_view = torch.as_strided(
|
| 434 |
+
kv_flat,
|
| 435 |
+
size=(total_k, dim),
|
| 436 |
+
stride=(row_stride, 1),
|
| 437 |
+
storage_offset=h * dim,
|
| 438 |
+
)
|
| 439 |
+
desc_table.extend(
|
| 440 |
+
_encode_tma_desc_2d_bytes(
|
| 441 |
+
head_view,
|
| 442 |
+
box_x=box_x,
|
| 443 |
+
box_y=box_y,
|
| 444 |
+
context="flat KV",
|
| 445 |
+
)
|
| 446 |
+
)
|
| 447 |
+
return _desc_bytes_to_device_tensor(desc_table, device=kv_flat.device).reshape(
|
| 448 |
+
head_kv, _TMA_DESC_BYTES
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# Compatibility Re-exports
|
| 453 |
+
|
| 454 |
+
from .copy_utils import (
|
| 455 |
+
atomic_add_broadcast_i32,
|
| 456 |
+
atomic_add_i32,
|
| 457 |
+
convert_layout_acc_mn,
|
| 458 |
+
convert_layout_from_tmem16x256b_to_acc_sm90,
|
| 459 |
+
make_16x256b_tensor_mn_view,
|
| 460 |
+
real_col_to_stg128_fake_col,
|
| 461 |
+
real_col_to_stg128_fp8_fake_col,
|
| 462 |
+
real_col_to_stg128_half_fake_col,
|
| 463 |
+
stg128_fake_col_to_real_col,
|
| 464 |
+
stg128_fp8_fake_col_to_real_col,
|
| 465 |
+
stg128_half_fake_col_to_real_col,
|
| 466 |
+
stg_128,
|
| 467 |
+
stg_128_cs,
|
| 468 |
+
stg_128_bf16,
|
| 469 |
+
stg_128_bf16_cs,
|
| 470 |
+
stg_128_f16,
|
| 471 |
+
stg_128_f16_cs,
|
| 472 |
+
stg_128_fp8_e4m3_cs,
|
| 473 |
+
stg_32_fp8_e4m3,
|
| 474 |
+
stg_64_bf16,
|
| 475 |
+
stg_64_f16,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
__all__ = [
|
| 480 |
+
"TMA_CACHE_EVICT_FIRST",
|
| 481 |
+
"TMA_CACHE_EVICT_LAST",
|
| 482 |
+
"atomic_add_broadcast_i32",
|
| 483 |
+
"atomic_add_i32",
|
| 484 |
+
"convert_layout_acc_mn",
|
| 485 |
+
"convert_layout_from_tmem16x256b_to_acc_sm90",
|
| 486 |
+
"create_flat_gather4_tma_desc",
|
| 487 |
+
"create_flat_kv_tma_descs",
|
| 488 |
+
"create_q_gather4_tma_desc",
|
| 489 |
+
"create_strided_2d_tma_desc",
|
| 490 |
+
"make_16x256b_tensor_mn_view",
|
| 491 |
+
"prefetch_tma_desc_raw",
|
| 492 |
+
"real_col_to_stg128_fake_col",
|
| 493 |
+
"real_col_to_stg128_fp8_fake_col",
|
| 494 |
+
"real_col_to_stg128_half_fake_col",
|
| 495 |
+
"stg128_fake_col_to_real_col",
|
| 496 |
+
"stg128_fp8_fake_col_to_real_col",
|
| 497 |
+
"stg128_half_fake_col_to_real_col",
|
| 498 |
+
"stg_128",
|
| 499 |
+
"stg_128_cs",
|
| 500 |
+
"stg_128_bf16",
|
| 501 |
+
"stg_128_bf16_cs",
|
| 502 |
+
"stg_128_f16",
|
| 503 |
+
"stg_128_f16_cs",
|
| 504 |
+
"stg_128_fp8_e4m3_cs",
|
| 505 |
+
"stg_32_fp8_e4m3",
|
| 506 |
+
"stg_64_bf16",
|
| 507 |
+
"stg_64_f16",
|
| 508 |
+
"tma_gather4",
|
| 509 |
+
"tma_gather4_cached",
|
| 510 |
+
"tma_gather4_prefetch",
|
| 511 |
+
"tma_tile_load",
|
| 512 |
+
"tma_tile_load_cached",
|
| 513 |
+
"tma_tile_prefetch",
|
| 514 |
+
"tma_tile_store",
|
| 515 |
+
]
|
build/torch211-cxx11-cu128-x86_64-linux/src/common/utils.py
ADDED
|
@@ -0,0 +1,1088 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
import hashlib
|
| 6 |
+
import inspect
|
| 7 |
+
from typing import Type, Callable, Optional, Tuple, overload
|
| 8 |
+
|
| 9 |
+
import cutlass
|
| 10 |
+
import cutlass.cute as cute
|
| 11 |
+
|
| 12 |
+
from cutlass import Float32, const_expr
|
| 13 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 14 |
+
from cutlass._mlir.dialects import nvvm, llvm
|
| 15 |
+
from cutlass.cute.runtime import from_dlpack
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
from ...quack import activation
|
| 19 |
+
_MIXER_ATTRS = ("__vec_size__",)
|
| 20 |
+
|
| 21 |
+
# Obtained from sollya:
|
| 22 |
+
# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative);
|
| 23 |
+
POLY_EX2 = {
|
| 24 |
+
0: (1.0),
|
| 25 |
+
1: (
|
| 26 |
+
1.0,
|
| 27 |
+
0.922497093677520751953125,
|
| 28 |
+
),
|
| 29 |
+
2: (
|
| 30 |
+
1.0,
|
| 31 |
+
0.6657850742340087890625,
|
| 32 |
+
0.330107033252716064453125,
|
| 33 |
+
),
|
| 34 |
+
3: (
|
| 35 |
+
1.0,
|
| 36 |
+
0.695146143436431884765625,
|
| 37 |
+
0.227564394474029541015625,
|
| 38 |
+
0.077119089663028717041015625,
|
| 39 |
+
),
|
| 40 |
+
4: (
|
| 41 |
+
1.0,
|
| 42 |
+
0.693042695522308349609375,
|
| 43 |
+
0.2412912547588348388671875,
|
| 44 |
+
5.2225358784198760986328125e-2,
|
| 45 |
+
1.3434938155114650726318359375e-2,
|
| 46 |
+
),
|
| 47 |
+
5: (
|
| 48 |
+
1.0,
|
| 49 |
+
0.693151414394378662109375,
|
| 50 |
+
0.24016360938549041748046875,
|
| 51 |
+
5.5802188813686370849609375e-2,
|
| 52 |
+
9.01452265679836273193359375e-3,
|
| 53 |
+
1.86810153536498546600341796875e-3,
|
| 54 |
+
),
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _compute_base_hash(func: Callable) -> str:
|
| 59 |
+
"""Compute hash from source code or bytecode and closure values."""
|
| 60 |
+
try:
|
| 61 |
+
data = inspect.getsource(func).encode()
|
| 62 |
+
except (OSError, TypeError):
|
| 63 |
+
if hasattr(func, "__code__") and func.__code__ is not None:
|
| 64 |
+
data = func.__code__.co_code
|
| 65 |
+
else:
|
| 66 |
+
data = repr(func).encode()
|
| 67 |
+
|
| 68 |
+
hasher = hashlib.sha256(data)
|
| 69 |
+
|
| 70 |
+
if hasattr(func, "__closure__") and func.__closure__ is not None:
|
| 71 |
+
for cell in func.__closure__:
|
| 72 |
+
hasher.update(repr(cell.cell_contents).encode())
|
| 73 |
+
|
| 74 |
+
return hasher.hexdigest()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def hash_callable(
|
| 78 |
+
func: Callable, mixer_attrs: Tuple[str] = _MIXER_ATTRS, set_cute_hash: bool = True
|
| 79 |
+
) -> str:
|
| 80 |
+
"""Hash a callable based on the source code or bytecode and closure values.
|
| 81 |
+
Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__``
|
| 82 |
+
attribute, that value is returned immediately as the base hash, then
|
| 83 |
+
metadata dunders are mixed in to produce the final dict-key hash.
|
| 84 |
+
set_cute_hash: whether or not to set func.__cute_hash__
|
| 85 |
+
"""
|
| 86 |
+
# Resolve base hash
|
| 87 |
+
if hasattr(func, "__cute_hash__"):
|
| 88 |
+
base_hash = func.__cute_hash__
|
| 89 |
+
else:
|
| 90 |
+
# Unwrap decorated functions (e.g., cute.jit wrappers).
|
| 91 |
+
base_func = getattr(func, "__wrapped__", func)
|
| 92 |
+
|
| 93 |
+
if hasattr(base_func, "__cute_hash__"):
|
| 94 |
+
base_hash = base_func.__cute_hash__
|
| 95 |
+
else:
|
| 96 |
+
base_hash = _compute_base_hash(base_func)
|
| 97 |
+
|
| 98 |
+
if set_cute_hash:
|
| 99 |
+
base_func.__cute_hash__ = base_hash
|
| 100 |
+
|
| 101 |
+
# Mix in mutable metadata dunders
|
| 102 |
+
mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs)
|
| 103 |
+
|
| 104 |
+
if all(v is None for v in mixer_values):
|
| 105 |
+
return base_hash
|
| 106 |
+
|
| 107 |
+
hasher = hashlib.sha256(base_hash.encode())
|
| 108 |
+
|
| 109 |
+
for attr, val in zip(_MIXER_ATTRS, mixer_values):
|
| 110 |
+
hasher.update(f"{attr}={val!r}".encode())
|
| 111 |
+
|
| 112 |
+
return hasher.hexdigest()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
LOG2_E = math.log2(math.e)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def compute_softmax_scale_log2(softmax_scale):
|
| 119 |
+
"""Compute softmax_scale_log2 from softmax_scale.
|
| 120 |
+
|
| 121 |
+
Returns (softmax_scale_log2, None).
|
| 122 |
+
"""
|
| 123 |
+
return softmax_scale * LOG2_E, None
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
|
| 127 |
+
return (
|
| 128 |
+
from_dlpack(x, assumed_align=alignment)
|
| 129 |
+
.mark_layout_dynamic(leading_dim=leading_dim)
|
| 130 |
+
.mark_compact_shape_dynamic(
|
| 131 |
+
mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility
|
| 132 |
+
)
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def make_tiled_copy_A(
|
| 137 |
+
copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
|
| 138 |
+
) -> cute.TiledCopy:
|
| 139 |
+
if const_expr(swapAB):
|
| 140 |
+
return cute.make_tiled_copy_B(copy_atom, tiled_mma)
|
| 141 |
+
else:
|
| 142 |
+
return cute.make_tiled_copy_A(copy_atom, tiled_mma)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def make_tiled_copy_B(
|
| 146 |
+
copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
|
| 147 |
+
) -> cute.TiledCopy:
|
| 148 |
+
if const_expr(swapAB):
|
| 149 |
+
return cute.make_tiled_copy_A(copy_atom, tiled_mma)
|
| 150 |
+
else:
|
| 151 |
+
return cute.make_tiled_copy_B(copy_atom, tiled_mma)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def mma_make_fragment_A(
|
| 155 |
+
smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
|
| 156 |
+
) -> cute.Tensor:
|
| 157 |
+
if const_expr(swapAB):
|
| 158 |
+
return mma_make_fragment_B(smem, thr_mma)
|
| 159 |
+
else:
|
| 160 |
+
return thr_mma.make_fragment_A(thr_mma.partition_A(smem))
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def mma_make_fragment_B(
|
| 164 |
+
smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
|
| 165 |
+
) -> cute.Tensor:
|
| 166 |
+
if const_expr(swapAB):
|
| 167 |
+
return mma_make_fragment_A(smem, thr_mma)
|
| 168 |
+
else:
|
| 169 |
+
return thr_mma.make_fragment_B(thr_mma.partition_B(smem))
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_smem_store_atom(
|
| 173 |
+
arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
|
| 174 |
+
) -> cute.CopyAtom:
|
| 175 |
+
if const_expr(arch < 90 or element_type.width != 16):
|
| 176 |
+
return cute.make_copy_atom(
|
| 177 |
+
cute.nvgpu.CopyUniversalOp(),
|
| 178 |
+
element_type,
|
| 179 |
+
num_bits_per_copy=2 * element_type.width,
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
return cute.make_copy_atom(
|
| 183 |
+
cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
|
| 184 |
+
element_type,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@cute.jit
|
| 189 |
+
def warp_reduce(
|
| 190 |
+
val: cute.TensorSSA | cute.Numeric,
|
| 191 |
+
op: Callable,
|
| 192 |
+
width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
|
| 193 |
+
) -> cute.TensorSSA | cute.Numeric:
|
| 194 |
+
if const_expr(isinstance(val, cute.TensorSSA)):
|
| 195 |
+
res = cute.make_rmem_tensor(val.shape, val.dtype)
|
| 196 |
+
res.store(val)
|
| 197 |
+
for i in cutlass.range_constexpr(cute.size(val.shape)):
|
| 198 |
+
res[i] = warp_reduce(res[i], op, width)
|
| 199 |
+
return res.load()
|
| 200 |
+
else:
|
| 201 |
+
for i in cutlass.range_constexpr(int(math.log2(width))):
|
| 202 |
+
val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
|
| 203 |
+
return val
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@dsl_user_op
|
| 207 |
+
def fmax(
|
| 208 |
+
a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
|
| 209 |
+
) -> Float32:
|
| 210 |
+
from cutlass import CUDA_VERSION
|
| 211 |
+
|
| 212 |
+
# * NVVM call based on nvvm version
|
| 213 |
+
if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9:
|
| 214 |
+
# Old API: requires explicit result type as first positional argument
|
| 215 |
+
return Float32(
|
| 216 |
+
nvvm.fmax(
|
| 217 |
+
T.f32(),
|
| 218 |
+
Float32(a).ir_value(loc=loc, ip=ip),
|
| 219 |
+
Float32(b).ir_value(loc=loc, ip=ip),
|
| 220 |
+
c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
|
| 221 |
+
loc=loc,
|
| 222 |
+
ip=ip,
|
| 223 |
+
)
|
| 224 |
+
)
|
| 225 |
+
else:
|
| 226 |
+
# New API: infers result type automatically
|
| 227 |
+
return Float32(
|
| 228 |
+
nvvm.fmax(
|
| 229 |
+
Float32(a).ir_value(loc=loc, ip=ip),
|
| 230 |
+
Float32(b).ir_value(loc=loc, ip=ip),
|
| 231 |
+
c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
|
| 232 |
+
loc=loc,
|
| 233 |
+
ip=ip,
|
| 234 |
+
)
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@cute.jit
|
| 239 |
+
def fmax_reduce(
|
| 240 |
+
x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
|
| 241 |
+
) -> Float32:
|
| 242 |
+
if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
|
| 243 |
+
res = cute.make_rmem_tensor(x.shape, Float32)
|
| 244 |
+
res.store(x)
|
| 245 |
+
local_max = [res[0], res[1], res[2], res[3]]
|
| 246 |
+
for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
|
| 247 |
+
local_max[0] = fmax(local_max[0], res[i + 0])
|
| 248 |
+
local_max[1] = fmax(local_max[1], res[i + 1])
|
| 249 |
+
local_max[2] = fmax(local_max[2], res[i + 2])
|
| 250 |
+
local_max[3] = fmax(local_max[3], res[i + 3])
|
| 251 |
+
local_max[0] = fmax(local_max[0], local_max[1])
|
| 252 |
+
local_max[2] = fmax(local_max[2], local_max[3])
|
| 253 |
+
local_max[0] = fmax(local_max[0], local_max[2])
|
| 254 |
+
return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
|
| 255 |
+
else:
|
| 256 |
+
res = cute.make_rmem_tensor(x.shape, Float32)
|
| 257 |
+
res.store(x)
|
| 258 |
+
local_max_0 = (
|
| 259 |
+
fmax(init_val, res[0], res[1])
|
| 260 |
+
if const_expr(init_val is not None)
|
| 261 |
+
else fmax(res[0], res[1])
|
| 262 |
+
)
|
| 263 |
+
local_max = [
|
| 264 |
+
local_max_0,
|
| 265 |
+
fmax(res[2], res[3]),
|
| 266 |
+
fmax(res[4], res[5]),
|
| 267 |
+
fmax(res[6], res[7]),
|
| 268 |
+
]
|
| 269 |
+
for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
|
| 270 |
+
local_max[0] = fmax(local_max[0], res[i], res[i + 1])
|
| 271 |
+
local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3])
|
| 272 |
+
local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5])
|
| 273 |
+
local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7])
|
| 274 |
+
local_max[0] = fmax(local_max[0], local_max[1])
|
| 275 |
+
return fmax(local_max[0], local_max[2], local_max[3])
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
@cute.jit
|
| 279 |
+
def fadd_reduce(
|
| 280 |
+
x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
|
| 281 |
+
) -> Float32:
|
| 282 |
+
if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
|
| 283 |
+
if const_expr(init_val is None):
|
| 284 |
+
init_val = Float32.zero
|
| 285 |
+
return x.reduce(cute.ReductionOp.ADD, init_val, 0)
|
| 286 |
+
else:
|
| 287 |
+
res = cute.make_rmem_tensor(x.shape, Float32)
|
| 288 |
+
res.store(x)
|
| 289 |
+
local_sum_0 = (
|
| 290 |
+
cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1]))
|
| 291 |
+
if const_expr(init_val is not None)
|
| 292 |
+
else (res[0], res[1])
|
| 293 |
+
)
|
| 294 |
+
local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])]
|
| 295 |
+
for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
|
| 296 |
+
local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1]))
|
| 297 |
+
local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3]))
|
| 298 |
+
local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5]))
|
| 299 |
+
local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7]))
|
| 300 |
+
local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1])
|
| 301 |
+
local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3])
|
| 302 |
+
local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2])
|
| 303 |
+
return local_sum[0][0] + local_sum[0][1]
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
@cute.jit
|
| 307 |
+
def fadd_exp2_scaled_reduce(
|
| 308 |
+
x: cute.Tensor, scale: Float32, arch: cutlass.Constexpr[int] = 80
|
| 309 |
+
) -> Float32:
|
| 310 |
+
assert cute.size(x.shape) % 2 == 0, "x must have an even number of elements"
|
| 311 |
+
if const_expr(arch < 100):
|
| 312 |
+
return fadd_reduce(cute.math.exp2(x.load() * scale, fastmath=True), arch=arch)
|
| 313 |
+
elif const_expr(cute.size(x.shape) % 8 == 0):
|
| 314 |
+
local_sum = [
|
| 315 |
+
(Float32(0.0), Float32(0.0)),
|
| 316 |
+
(Float32(0.0), Float32(0.0)),
|
| 317 |
+
(Float32(0.0), Float32(0.0)),
|
| 318 |
+
(Float32(0.0), Float32(0.0)),
|
| 319 |
+
]
|
| 320 |
+
for i in cutlass.range_constexpr(0, cute.size(x.shape), 8):
|
| 321 |
+
acc0, acc1 = cute.arch.mul_packed_f32x2(
|
| 322 |
+
(x[i + 0], x[i + 1]), (scale, scale)
|
| 323 |
+
)
|
| 324 |
+
acc2, acc3 = cute.arch.mul_packed_f32x2(
|
| 325 |
+
(x[i + 2], x[i + 3]), (scale, scale)
|
| 326 |
+
)
|
| 327 |
+
acc4, acc5 = cute.arch.mul_packed_f32x2(
|
| 328 |
+
(x[i + 4], x[i + 5]), (scale, scale)
|
| 329 |
+
)
|
| 330 |
+
acc6, acc7 = cute.arch.mul_packed_f32x2(
|
| 331 |
+
(x[i + 6], x[i + 7]), (scale, scale)
|
| 332 |
+
)
|
| 333 |
+
acc0 = cute.math.exp2(acc0, fastmath=True)
|
| 334 |
+
acc1 = cute.math.exp2(acc1, fastmath=True)
|
| 335 |
+
acc2 = cute.math.exp2(acc2, fastmath=True)
|
| 336 |
+
acc3 = cute.math.exp2(acc3, fastmath=True)
|
| 337 |
+
acc4 = cute.math.exp2(acc4, fastmath=True)
|
| 338 |
+
acc5 = cute.math.exp2(acc5, fastmath=True)
|
| 339 |
+
acc6 = cute.math.exp2(acc6, fastmath=True)
|
| 340 |
+
acc7 = cute.math.exp2(acc7, fastmath=True)
|
| 341 |
+
local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (acc0, acc1))
|
| 342 |
+
local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (acc2, acc3))
|
| 343 |
+
local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (acc4, acc5))
|
| 344 |
+
local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (acc6, acc7))
|
| 345 |
+
local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1])
|
| 346 |
+
local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3])
|
| 347 |
+
local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2])
|
| 348 |
+
return local_sum[0][0] + local_sum[0][1]
|
| 349 |
+
else:
|
| 350 |
+
row_sum = Float32(0.0)
|
| 351 |
+
for i in cutlass.range_constexpr(0, cute.size(x.shape), 2):
|
| 352 |
+
acc0, acc1 = cute.arch.mul_packed_f32x2(
|
| 353 |
+
(x[i], x[i + 1]), (scale, scale)
|
| 354 |
+
)
|
| 355 |
+
acc0 = cute.math.exp2(acc0, fastmath=True)
|
| 356 |
+
acc1 = cute.math.exp2(acc1, fastmath=True)
|
| 357 |
+
row_sum += acc0 + acc1
|
| 358 |
+
return row_sum
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
@dsl_user_op
|
| 362 |
+
def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:
|
| 363 |
+
nvvm.atomicrmw(
|
| 364 |
+
res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value()
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
@dsl_user_op
|
| 369 |
+
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
|
| 370 |
+
return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
@cute.jit
|
| 374 |
+
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
|
| 375 |
+
# Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
|
| 376 |
+
tApA = cute.make_rmem_tensor(
|
| 377 |
+
cute.make_layout(
|
| 378 |
+
(cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
|
| 379 |
+
stride=(cute.size(tAcA, mode=[2]), 0, 1),
|
| 380 |
+
),
|
| 381 |
+
cutlass.Boolean,
|
| 382 |
+
)
|
| 383 |
+
for rest_v in cutlass.range_constexpr(tApA.shape[0]):
|
| 384 |
+
for rest_k in cutlass.range_constexpr(tApA.shape[2]):
|
| 385 |
+
tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
|
| 386 |
+
return tApA
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:
|
| 390 |
+
warp_group_idx = cute.arch.thread_idx()[0] // 128
|
| 391 |
+
if const_expr(sync):
|
| 392 |
+
warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx)
|
| 393 |
+
return warp_group_idx
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
@cute.jit
|
| 397 |
+
def shuffle_sync(
|
| 398 |
+
value: cute.Numeric,
|
| 399 |
+
offset: cute.typing.Int,
|
| 400 |
+
width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
|
| 401 |
+
) -> cute.Numeric:
|
| 402 |
+
assert value.width % 32 == 0, "value type must be a multiple of 32 bits"
|
| 403 |
+
# 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
|
| 404 |
+
mask = cute.arch.WARP_SIZE - width
|
| 405 |
+
clamp = cute.arch.WARP_SIZE - 1
|
| 406 |
+
mask_and_clamp = mask << 8 | clamp
|
| 407 |
+
# important: need stride 1 and not 0 for recast_tensor to work
|
| 408 |
+
val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value))
|
| 409 |
+
val[0] = value
|
| 410 |
+
val_i32 = cute.recast_tensor(val, cutlass.Int32)
|
| 411 |
+
for i in cutlass.range_constexpr(cute.size(val_i32)):
|
| 412 |
+
val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp)
|
| 413 |
+
return val[0]
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
@dsl_user_op
|
| 417 |
+
def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
|
| 418 |
+
"""
|
| 419 |
+
Left-shift val by shift bits using PTX shl.b32 (sign-agnostic).
|
| 420 |
+
|
| 421 |
+
Named ``shl_u32`` (not ``shl_b32``) because python type annotations
|
| 422 |
+
distinguish signed/unsigned.
|
| 423 |
+
|
| 424 |
+
PTX semantics (9.7.8.8): "Shift amounts greater than the register width N
|
| 425 |
+
are clamped to N." So ``shl.b32 d, a, 32`` is well-defined and yields 0.
|
| 426 |
+
|
| 427 |
+
This differs from C/C++ and LLVM IR, where shifting by >= the type width is
|
| 428 |
+
undefined behavior. CuTeDSL compiles through MLIR -> LLVM IR, so a plain
|
| 429 |
+
Python-level ``Uint32(x) << Uint32(n)`` inherits LLVM's UB: the optimizer
|
| 430 |
+
may treat the result as poison and eliminate dependent code. Inline PTX
|
| 431 |
+
bypasses the LLVM IR shift entirely -- the instruction is emitted verbatim
|
| 432 |
+
into PTX where clamping makes it safe for all shift amounts.
|
| 433 |
+
"""
|
| 434 |
+
return cutlass.Uint32(
|
| 435 |
+
llvm.inline_asm(
|
| 436 |
+
T.i32(),
|
| 437 |
+
[
|
| 438 |
+
cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
|
| 439 |
+
cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
|
| 440 |
+
],
|
| 441 |
+
"shl.b32 $0, $1, $2;",
|
| 442 |
+
"=r,r,r",
|
| 443 |
+
has_side_effects=False,
|
| 444 |
+
is_align_stack=False,
|
| 445 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 446 |
+
)
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
@dsl_user_op
|
| 451 |
+
def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
|
| 452 |
+
"""
|
| 453 |
+
Unsigned right-shift val by shift bits using PTX shr.u32 (zero-fills).
|
| 454 |
+
|
| 455 |
+
See ``shl_u32`` docstring for why inline PTX is used instead of plain
|
| 456 |
+
CuTeDSL shift operators (LLVM shift-by-type-width UB).
|
| 457 |
+
"""
|
| 458 |
+
return cutlass.Uint32(
|
| 459 |
+
llvm.inline_asm(
|
| 460 |
+
T.i32(),
|
| 461 |
+
[
|
| 462 |
+
cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
|
| 463 |
+
cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
|
| 464 |
+
],
|
| 465 |
+
"shr.u32 $0, $1, $2;",
|
| 466 |
+
"=r,r,r",
|
| 467 |
+
has_side_effects=False,
|
| 468 |
+
is_align_stack=False,
|
| 469 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 470 |
+
)
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
@cute.jit
|
| 475 |
+
def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
|
| 476 |
+
if const_expr(lane is None):
|
| 477 |
+
lane = cute.arch.lane_idx()
|
| 478 |
+
for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
|
| 479 |
+
offset = 1 << i
|
| 480 |
+
# Very important that we set mask_and_clamp to 0
|
| 481 |
+
partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
|
| 482 |
+
if lane >= offset:
|
| 483 |
+
val += partial_sum
|
| 484 |
+
return val
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
@dsl_user_op
|
| 488 |
+
def cvt_f16x2_f32(
|
| 489 |
+
a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
|
| 490 |
+
) -> cutlass.Int32:
|
| 491 |
+
assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
|
| 492 |
+
return cutlass.Int32(
|
| 493 |
+
llvm.inline_asm(
|
| 494 |
+
T.i32(),
|
| 495 |
+
[Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)],
|
| 496 |
+
f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;",
|
| 497 |
+
"=r,f,f",
|
| 498 |
+
has_side_effects=False,
|
| 499 |
+
is_align_stack=False,
|
| 500 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 501 |
+
)
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
@dsl_user_op
|
| 506 |
+
def cvt_fp8x4_e4m3_f32(
|
| 507 |
+
a: float | Float32,
|
| 508 |
+
b: float | Float32,
|
| 509 |
+
c: float | Float32,
|
| 510 |
+
d: float | Float32,
|
| 511 |
+
*,
|
| 512 |
+
loc=None,
|
| 513 |
+
ip=None,
|
| 514 |
+
) -> cutlass.Int32:
|
| 515 |
+
return cutlass.Int32(
|
| 516 |
+
llvm.inline_asm(
|
| 517 |
+
T.i32(),
|
| 518 |
+
[
|
| 519 |
+
Float32(a).ir_value(loc=loc, ip=ip),
|
| 520 |
+
Float32(b).ir_value(loc=loc, ip=ip),
|
| 521 |
+
Float32(c).ir_value(loc=loc, ip=ip),
|
| 522 |
+
Float32(d).ir_value(loc=loc, ip=ip),
|
| 523 |
+
],
|
| 524 |
+
"{\n"
|
| 525 |
+
".reg .b16 h0, h1;\n"
|
| 526 |
+
"cvt.rn.satfinite.e4m3x2.f32 h0, $2, $1;\n"
|
| 527 |
+
"cvt.rn.satfinite.e4m3x2.f32 h1, $4, $3;\n"
|
| 528 |
+
"mov.b32 $0, {h0, h1};\n"
|
| 529 |
+
"}\n",
|
| 530 |
+
"=r,f,f,f,f",
|
| 531 |
+
has_side_effects=False,
|
| 532 |
+
is_align_stack=False,
|
| 533 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 534 |
+
)
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
@dsl_user_op
|
| 539 |
+
def cvt_fp8x4_e4m3_bf16x4(
|
| 540 |
+
src: cutlass.Int32,
|
| 541 |
+
*,
|
| 542 |
+
loc=None,
|
| 543 |
+
ip=None,
|
| 544 |
+
) -> Tuple[cutlass.Int32, cutlass.Int32]:
|
| 545 |
+
"""Convert packed e4m3x4 bits into two packed bf16x2 registers."""
|
| 546 |
+
out0 = cutlass.Int32(
|
| 547 |
+
llvm.inline_asm(
|
| 548 |
+
T.i32(),
|
| 549 |
+
[cutlass.Int32(src).ir_value(loc=loc, ip=ip)],
|
| 550 |
+
"{\n\t"
|
| 551 |
+
".reg .b32 q, mant, out, bias, zero;\n\t"
|
| 552 |
+
"prmt.b32 q, $1, $1, 0x1302;\n\t"
|
| 553 |
+
"and.b32 out, q, 0x80008000;\n\t"
|
| 554 |
+
"and.b32 mant, q, 0x7f007f00;\n\t"
|
| 555 |
+
"shr.u32 mant, mant, 4;\n\t"
|
| 556 |
+
"or.b32 out, out, mant;\n\t"
|
| 557 |
+
"mov.b32 bias, 0x7b807b80;\n\t"
|
| 558 |
+
"mov.b32 zero, 0;\n\t"
|
| 559 |
+
"fma.rn.bf16x2 $0, out, bias, zero;\n\t"
|
| 560 |
+
"}\n",
|
| 561 |
+
"=r,r",
|
| 562 |
+
has_side_effects=False,
|
| 563 |
+
is_align_stack=False,
|
| 564 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 565 |
+
)
|
| 566 |
+
)
|
| 567 |
+
out1 = cutlass.Int32(
|
| 568 |
+
llvm.inline_asm(
|
| 569 |
+
T.i32(),
|
| 570 |
+
[cutlass.Int32(src).ir_value(loc=loc, ip=ip)],
|
| 571 |
+
"{\n\t"
|
| 572 |
+
".reg .b32 q, qs, mant, out, bias, zero;\n\t"
|
| 573 |
+
"prmt.b32 q, $1, $1, 0x1302;\n\t"
|
| 574 |
+
"shl.b32 qs, q, 8;\n\t"
|
| 575 |
+
"and.b32 out, qs, 0x80008000;\n\t"
|
| 576 |
+
"and.b32 mant, qs, 0x7f007f00;\n\t"
|
| 577 |
+
"shr.u32 mant, mant, 4;\n\t"
|
| 578 |
+
"or.b32 out, out, mant;\n\t"
|
| 579 |
+
"mov.b32 bias, 0x7b807b80;\n\t"
|
| 580 |
+
"mov.b32 zero, 0;\n\t"
|
| 581 |
+
"fma.rn.bf16x2 $0, out, bias, zero;\n\t"
|
| 582 |
+
"}\n",
|
| 583 |
+
"=r,r",
|
| 584 |
+
has_side_effects=False,
|
| 585 |
+
is_align_stack=False,
|
| 586 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 587 |
+
)
|
| 588 |
+
)
|
| 589 |
+
return out0, out1
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
@dsl_user_op
|
| 593 |
+
def cvt_fp4x2_e2m1_f16x2(
|
| 594 |
+
src: cutlass.Int32,
|
| 595 |
+
*,
|
| 596 |
+
loc=None,
|
| 597 |
+
ip=None,
|
| 598 |
+
) -> cutlass.Int32:
|
| 599 |
+
"""Convert one packed E2M1 byte into one packed f16x2 register."""
|
| 600 |
+
|
| 601 |
+
return cutlass.Int32(
|
| 602 |
+
llvm.inline_asm(
|
| 603 |
+
T.i32(),
|
| 604 |
+
[cutlass.Int32(src).ir_value(loc=loc, ip=ip)],
|
| 605 |
+
"{\n\t"
|
| 606 |
+
".reg .b8 byte0;\n\t"
|
| 607 |
+
"mov.b32 {byte0, _, _, _}, $1;\n\t"
|
| 608 |
+
"cvt.rn.f16x2.e2m1x2 $0, byte0;\n\t"
|
| 609 |
+
"}\n",
|
| 610 |
+
"=r,r",
|
| 611 |
+
has_side_effects=False,
|
| 612 |
+
is_align_stack=False,
|
| 613 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 614 |
+
)
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
@dsl_user_op
|
| 619 |
+
def cvt_fp4x8_e2m1_f16x8(
|
| 620 |
+
src: cutlass.Int32,
|
| 621 |
+
*,
|
| 622 |
+
loc=None,
|
| 623 |
+
ip=None,
|
| 624 |
+
) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32]:
|
| 625 |
+
"""Convert four packed E2M1 bytes into four packed f16x2 registers."""
|
| 626 |
+
|
| 627 |
+
out = llvm.inline_asm(
|
| 628 |
+
llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]),
|
| 629 |
+
[cutlass.Int32(src).ir_value(loc=loc, ip=ip)],
|
| 630 |
+
"{\n\t"
|
| 631 |
+
".reg .b8 byte0, byte1, byte2, byte3;\n\t"
|
| 632 |
+
"mov.b32 {byte0, byte1, byte2, byte3}, $4;\n\t"
|
| 633 |
+
"cvt.rn.f16x2.e2m1x2 $0, byte0;\n\t"
|
| 634 |
+
"cvt.rn.f16x2.e2m1x2 $1, byte1;\n\t"
|
| 635 |
+
"cvt.rn.f16x2.e2m1x2 $2, byte2;\n\t"
|
| 636 |
+
"cvt.rn.f16x2.e2m1x2 $3, byte3;\n\t"
|
| 637 |
+
"}\n",
|
| 638 |
+
"=r,=r,=r,=r,r",
|
| 639 |
+
has_side_effects=False,
|
| 640 |
+
is_align_stack=False,
|
| 641 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 642 |
+
)
|
| 643 |
+
out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip))
|
| 644 |
+
out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip))
|
| 645 |
+
out2 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [2], loc=loc, ip=ip))
|
| 646 |
+
out3 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [3], loc=loc, ip=ip))
|
| 647 |
+
return out0, out1, out2, out3
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
@dsl_user_op
|
| 651 |
+
def cvt_fp4x8_e2m1_bf16x8(
|
| 652 |
+
src: cutlass.Int32,
|
| 653 |
+
*,
|
| 654 |
+
loc=None,
|
| 655 |
+
ip=None,
|
| 656 |
+
) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32]:
|
| 657 |
+
"""Convert four packed E2M1 bytes into four packed bf16x2 registers."""
|
| 658 |
+
|
| 659 |
+
from cutlass import CUDA_VERSION
|
| 660 |
+
|
| 661 |
+
if CUDA_VERSION.major > 13 or (CUDA_VERSION.major == 13 and CUDA_VERSION.minor >= 2):
|
| 662 |
+
out = llvm.inline_asm(
|
| 663 |
+
llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]),
|
| 664 |
+
[cutlass.Int32(src).ir_value(loc=loc, ip=ip)],
|
| 665 |
+
"{\n\t"
|
| 666 |
+
".reg .b8 byte0, byte1, byte2, byte3;\n\t"
|
| 667 |
+
"mov.b32 {byte0, byte1, byte2, byte3}, $4;\n\t"
|
| 668 |
+
"cvt.rn.bf16x2.e2m1x2 $0, byte0;\n\t"
|
| 669 |
+
"cvt.rn.bf16x2.e2m1x2 $1, byte1;\n\t"
|
| 670 |
+
"cvt.rn.bf16x2.e2m1x2 $2, byte2;\n\t"
|
| 671 |
+
"cvt.rn.bf16x2.e2m1x2 $3, byte3;\n\t"
|
| 672 |
+
"}\n",
|
| 673 |
+
"=r,=r,=r,=r,r",
|
| 674 |
+
has_side_effects=False,
|
| 675 |
+
is_align_stack=False,
|
| 676 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 677 |
+
)
|
| 678 |
+
out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip))
|
| 679 |
+
out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip))
|
| 680 |
+
out2 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [2], loc=loc, ip=ip))
|
| 681 |
+
out3 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [3], loc=loc, ip=ip))
|
| 682 |
+
return out0, out1, out2, out3
|
| 683 |
+
|
| 684 |
+
f16_pair0, f16_pair1, f16_pair2, f16_pair3 = cvt_fp4x8_e2m1_f16x8(
|
| 685 |
+
src, loc=loc, ip=ip
|
| 686 |
+
)
|
| 687 |
+
return (
|
| 688 |
+
cvt_f16x2_to_bf16x2(f16_pair0, loc=loc, ip=ip),
|
| 689 |
+
cvt_f16x2_to_bf16x2(f16_pair1, loc=loc, ip=ip),
|
| 690 |
+
cvt_f16x2_to_bf16x2(f16_pair2, loc=loc, ip=ip),
|
| 691 |
+
cvt_f16x2_to_bf16x2(f16_pair3, loc=loc, ip=ip),
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
@dsl_user_op
|
| 696 |
+
def cvt_fp4x8_e2m1_scaled_e4m3x8(
|
| 697 |
+
src: cutlass.Int32,
|
| 698 |
+
scale_e4m3: cutlass.Int32,
|
| 699 |
+
*,
|
| 700 |
+
loc=None,
|
| 701 |
+
ip=None,
|
| 702 |
+
) -> Tuple[cutlass.Int32, cutlass.Int32]:
|
| 703 |
+
"""Scale eight packed E2M1 values by one E4M3 byte and convert to E4M3."""
|
| 704 |
+
|
| 705 |
+
from cutlass import CUDA_VERSION
|
| 706 |
+
|
| 707 |
+
if CUDA_VERSION.major > 13 or (CUDA_VERSION.major == 13 and CUDA_VERSION.minor >= 2):
|
| 708 |
+
out = llvm.inline_asm(
|
| 709 |
+
llvm.StructType.get_literal([T.i32(), T.i32()]),
|
| 710 |
+
[
|
| 711 |
+
cutlass.Int32(src).ir_value(loc=loc, ip=ip),
|
| 712 |
+
cutlass.Int32(scale_e4m3).ir_value(loc=loc, ip=ip),
|
| 713 |
+
],
|
| 714 |
+
"{\n\t"
|
| 715 |
+
".reg .b32 tmp, ra;\n\t"
|
| 716 |
+
".reg .b8 byte0, byte1, byte2, byte3;\n\t"
|
| 717 |
+
"prmt.b32 tmp, $3, 0, 0;\n\t"
|
| 718 |
+
"mov.b32 {byte0, byte1, byte2, byte3}, $2;\n\t"
|
| 719 |
+
"mov.b32 ra, {byte0, byte1, _, _};\n\t"
|
| 720 |
+
"mul.e4m3x4.e2m1x4.e4m3x4.satfinite $0, ra, tmp;\n\t"
|
| 721 |
+
"mov.b32 ra, {_, _, byte2, byte3};\n\t"
|
| 722 |
+
"mul.e4m3x4.e2m1x4.e4m3x4.satfinite $1, ra, tmp;\n\t"
|
| 723 |
+
"}\n",
|
| 724 |
+
"=r,=r,r,r",
|
| 725 |
+
has_side_effects=False,
|
| 726 |
+
is_align_stack=False,
|
| 727 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 728 |
+
)
|
| 729 |
+
out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip))
|
| 730 |
+
out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip))
|
| 731 |
+
return out0, out1
|
| 732 |
+
|
| 733 |
+
out = llvm.inline_asm(
|
| 734 |
+
llvm.StructType.get_literal([T.i32(), T.i32()]),
|
| 735 |
+
[
|
| 736 |
+
cutlass.Int32(src).ir_value(loc=loc, ip=ip),
|
| 737 |
+
cutlass.Int32(scale_e4m3).ir_value(loc=loc, ip=ip),
|
| 738 |
+
],
|
| 739 |
+
"{\n\t"
|
| 740 |
+
".reg .b32 sf_bytes, sf_f16x2;\n\t"
|
| 741 |
+
".reg .b16 sf_pair, e0, e1, e2, e3;\n\t"
|
| 742 |
+
".reg .b8 byte0, byte1, byte2, byte3;\n\t"
|
| 743 |
+
".reg .b32 h0, h1, h2, h3;\n\t"
|
| 744 |
+
"prmt.b32 sf_bytes, $3, 0, 0;\n\t"
|
| 745 |
+
"mov.b32 {sf_pair, _}, sf_bytes;\n\t"
|
| 746 |
+
"cvt.rn.f16x2.e4m3x2 sf_f16x2, sf_pair;\n\t"
|
| 747 |
+
"mov.b32 {byte0, byte1, byte2, byte3}, $2;\n\t"
|
| 748 |
+
"cvt.rn.f16x2.e2m1x2 h0, byte0;\n\t"
|
| 749 |
+
"cvt.rn.f16x2.e2m1x2 h1, byte1;\n\t"
|
| 750 |
+
"cvt.rn.f16x2.e2m1x2 h2, byte2;\n\t"
|
| 751 |
+
"cvt.rn.f16x2.e2m1x2 h3, byte3;\n\t"
|
| 752 |
+
"mul.rn.f16x2 h0, h0, sf_f16x2;\n\t"
|
| 753 |
+
"mul.rn.f16x2 h1, h1, sf_f16x2;\n\t"
|
| 754 |
+
"mul.rn.f16x2 h2, h2, sf_f16x2;\n\t"
|
| 755 |
+
"mul.rn.f16x2 h3, h3, sf_f16x2;\n\t"
|
| 756 |
+
"cvt.rn.satfinite.e4m3x2.f16x2 e0, h0;\n\t"
|
| 757 |
+
"cvt.rn.satfinite.e4m3x2.f16x2 e1, h1;\n\t"
|
| 758 |
+
"cvt.rn.satfinite.e4m3x2.f16x2 e2, h2;\n\t"
|
| 759 |
+
"cvt.rn.satfinite.e4m3x2.f16x2 e3, h3;\n\t"
|
| 760 |
+
"mov.b32 $0, {e0, e1};\n\t"
|
| 761 |
+
"mov.b32 $1, {e2, e3};\n\t"
|
| 762 |
+
"}\n",
|
| 763 |
+
"=r,=r,r,r",
|
| 764 |
+
has_side_effects=False,
|
| 765 |
+
is_align_stack=False,
|
| 766 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 767 |
+
)
|
| 768 |
+
out0 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [0], loc=loc, ip=ip))
|
| 769 |
+
out1 = cutlass.Int32(llvm.extractvalue(T.i32(), out, [1], loc=loc, ip=ip))
|
| 770 |
+
return out0, out1
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
@dsl_user_op
|
| 774 |
+
def cvt_f16x2_to_bf16x2(
|
| 775 |
+
src: cutlass.Int32,
|
| 776 |
+
*,
|
| 777 |
+
loc=None,
|
| 778 |
+
ip=None,
|
| 779 |
+
) -> cutlass.Int32:
|
| 780 |
+
"""Convert a packed f16x2 register into a packed bf16x2 register."""
|
| 781 |
+
|
| 782 |
+
return cutlass.Int32(
|
| 783 |
+
llvm.inline_asm(
|
| 784 |
+
T.i32(),
|
| 785 |
+
[cutlass.Int32(src).ir_value(loc=loc, ip=ip)],
|
| 786 |
+
"{\n\t"
|
| 787 |
+
".reg .b16 h0, h1;\n\t"
|
| 788 |
+
".reg .f32 f0, f1;\n\t"
|
| 789 |
+
"mov.b32 {h0, h1}, $1;\n\t"
|
| 790 |
+
"cvt.f32.f16 f0, h0;\n\t"
|
| 791 |
+
"cvt.f32.f16 f1, h1;\n\t"
|
| 792 |
+
"cvt.rn.bf16x2.f32 $0, f1, f0;\n\t"
|
| 793 |
+
"}\n",
|
| 794 |
+
"=r,r",
|
| 795 |
+
has_side_effects=False,
|
| 796 |
+
is_align_stack=False,
|
| 797 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 798 |
+
)
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
@dsl_user_op
|
| 803 |
+
def mul_bf16x2(
|
| 804 |
+
a: cutlass.Int32,
|
| 805 |
+
b: cutlass.Int32,
|
| 806 |
+
*,
|
| 807 |
+
loc=None,
|
| 808 |
+
ip=None,
|
| 809 |
+
) -> cutlass.Int32:
|
| 810 |
+
"""Multiply two packed bf16x2 registers."""
|
| 811 |
+
|
| 812 |
+
return cutlass.Int32(
|
| 813 |
+
llvm.inline_asm(
|
| 814 |
+
T.i32(),
|
| 815 |
+
[
|
| 816 |
+
cutlass.Int32(a).ir_value(loc=loc, ip=ip),
|
| 817 |
+
cutlass.Int32(b).ir_value(loc=loc, ip=ip),
|
| 818 |
+
],
|
| 819 |
+
"mul.rn.bf16x2 $0, $1, $2;",
|
| 820 |
+
"=r,r,r",
|
| 821 |
+
has_side_effects=False,
|
| 822 |
+
is_align_stack=False,
|
| 823 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 824 |
+
)
|
| 825 |
+
)
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
@cute.jit
|
| 829 |
+
def cvt_fp8_e4m3_to_bf16x2_replicated(src: cutlass.Int32) -> cutlass.Int32:
|
| 830 |
+
"""Decode one E4M3 byte and replicate it into a packed bf16x2 register."""
|
| 831 |
+
|
| 832 |
+
src_u8 = src & cutlass.Int32(0xFF)
|
| 833 |
+
packed = src_u8 * cutlass.Int32(0x01010101)
|
| 834 |
+
out0, _ = cvt_fp8x4_e4m3_bf16x4(packed)
|
| 835 |
+
return out0
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
@overload
|
| 839 |
+
def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
@overload
|
| 843 |
+
def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
@cute.jit
|
| 847 |
+
def cvt_f16(src: cute.Tensor, dst_or_dtype):
|
| 848 |
+
"""Convert Float32 tensor to Float16/BFloat16.
|
| 849 |
+
|
| 850 |
+
Args:
|
| 851 |
+
src: Source tensor with Float32 element type
|
| 852 |
+
dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16)
|
| 853 |
+
|
| 854 |
+
Returns:
|
| 855 |
+
None if dst is a tensor, or a new tensor if dtype is provided
|
| 856 |
+
"""
|
| 857 |
+
if const_expr(isinstance(dst_or_dtype, type)):
|
| 858 |
+
# dtype variant: create new tensor and call the tensor variant
|
| 859 |
+
dtype = dst_or_dtype
|
| 860 |
+
dst = cute.make_rmem_tensor(src.shape, dtype)
|
| 861 |
+
cvt_f16(src, dst)
|
| 862 |
+
return dst
|
| 863 |
+
else:
|
| 864 |
+
# tensor variant: write to dst
|
| 865 |
+
dst = dst_or_dtype
|
| 866 |
+
assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
|
| 867 |
+
assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
|
| 868 |
+
assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (
|
| 869 |
+
"dst must be BFloat16 or Float16"
|
| 870 |
+
)
|
| 871 |
+
assert src.element_type is Float32, "src must be Float32"
|
| 872 |
+
dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
|
| 873 |
+
assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
|
| 874 |
+
for i in cutlass.range_constexpr(cute.size(dst_i32)):
|
| 875 |
+
dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type)
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
@cute.jit
|
| 879 |
+
def cvt_f32(src: cute.Tensor, dst: cute.Tensor) -> None:
|
| 880 |
+
"""Convert a Float32 rmem tensor to dst's element type.
|
| 881 |
+
|
| 882 |
+
fp8 path uses the reference fp8 quantize pattern: fragment-by-fragment
|
| 883 |
+
``.store(.load().to(fp8))`` over groups of ``frg_tile=4``. This lets the
|
| 884 |
+
DSL emit ``cvt.rn.satfinite.e4m3x2.f32`` pairs and pack the resulting fp8
|
| 885 |
+
bytes within a 32-bit register cell in the order DSL chooses, which is
|
| 886 |
+
expected to match the K-adjacency that SM100 fp8 UMMA fragment_A reads.
|
| 887 |
+
"""
|
| 888 |
+
if const_expr(dst.element_type in [cutlass.BFloat16, cutlass.Float16]):
|
| 889 |
+
cvt_f16(src, dst)
|
| 890 |
+
elif const_expr(dst.element_type is cutlass.Float8E4M3FN):
|
| 891 |
+
assert src.element_type is Float32, "src must be Float32"
|
| 892 |
+
assert cute.size(src.shape) == cute.size(dst.shape), "dst and src must have the same size"
|
| 893 |
+
assert cute.size(src.shape) % 4 == 0, "src must have a multiple of 4 elements"
|
| 894 |
+
frg_tile = 4
|
| 895 |
+
src_frg = cute.logical_divide(src, cute.make_layout(frg_tile))
|
| 896 |
+
dst_frg = cute.logical_divide(dst, cute.make_layout(frg_tile))
|
| 897 |
+
for i in cutlass.range_constexpr(cute.size(src_frg, mode=[1])):
|
| 898 |
+
dst_frg[None, i].store(src_frg[None, i].load().to(dst.element_type))
|
| 899 |
+
else:
|
| 900 |
+
assert src.element_type is Float32, "src must be Float32"
|
| 901 |
+
dst_view = cute.make_tensor(dst.iterator, src.layout)
|
| 902 |
+
dst_view.store(src.load().to(dst.element_type))
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
@dsl_user_op
|
| 906 |
+
@cute.jit
|
| 907 |
+
def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32:
|
| 908 |
+
deg = len(poly) - 1
|
| 909 |
+
out = poly[deg]
|
| 910 |
+
for i in cutlass.range_constexpr(deg - 1, -1, -1):
|
| 911 |
+
out = out * x + poly[i]
|
| 912 |
+
return out
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
@dsl_user_op
|
| 916 |
+
@cute.jit
|
| 917 |
+
def evaluate_polynomial_2(
|
| 918 |
+
x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
|
| 919 |
+
) -> Tuple[Float32, Float32]:
|
| 920 |
+
deg = len(poly) - 1
|
| 921 |
+
out = (poly[deg], poly[deg])
|
| 922 |
+
for i in cutlass.range_constexpr(deg - 1, -1, -1):
|
| 923 |
+
out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i]))
|
| 924 |
+
return out
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
@dsl_user_op
|
| 928 |
+
def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32:
|
| 929 |
+
# There's probably a way to call llvm or nvvm to do this instead of ptx
|
| 930 |
+
return cutlass.Float32(
|
| 931 |
+
llvm.inline_asm(
|
| 932 |
+
T.f32(),
|
| 933 |
+
[Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
|
| 934 |
+
"add.rm.ftz.f32 $0, $1, $2;",
|
| 935 |
+
"=f,f,f",
|
| 936 |
+
has_side_effects=False,
|
| 937 |
+
is_align_stack=False,
|
| 938 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 939 |
+
)
|
| 940 |
+
)
|
| 941 |
+
|
| 942 |
+
|
| 943 |
+
@dsl_user_op
|
| 944 |
+
def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32:
|
| 945 |
+
return cutlass.Float32(
|
| 946 |
+
llvm.inline_asm(
|
| 947 |
+
T.f32(),
|
| 948 |
+
[
|
| 949 |
+
Float32(x_rounded).ir_value(loc=loc, ip=ip),
|
| 950 |
+
Float32(frac_ex2).ir_value(loc=loc, ip=ip),
|
| 951 |
+
],
|
| 952 |
+
"{\n\t"
|
| 953 |
+
".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t"
|
| 954 |
+
"mov.b32 x_rounded_i, $1;\n\t"
|
| 955 |
+
"mov.b32 frac_ex_i, $2;\n\t"
|
| 956 |
+
"shl.b32 x_rounded_e, x_rounded_i, 23;\n\t"
|
| 957 |
+
# add.u32 generates IMAD instruction and add.s32 generates LEA instruction
|
| 958 |
+
# IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik
|
| 959 |
+
"add.s32 out_i, x_rounded_e, frac_ex_i;\n\t"
|
| 960 |
+
"mov.b32 $0, out_i;\n\t"
|
| 961 |
+
"}\n",
|
| 962 |
+
"=f,f,f",
|
| 963 |
+
has_side_effects=False,
|
| 964 |
+
is_align_stack=False,
|
| 965 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 966 |
+
)
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
@dsl_user_op
|
| 971 |
+
def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32:
|
| 972 |
+
assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported"
|
| 973 |
+
# We assume x <= 127.0
|
| 974 |
+
fp32_round_int = float(2**23 + 2**22)
|
| 975 |
+
x_clamped = cute.arch.fmax(x, -127.0)
|
| 976 |
+
# We want to round down here, so that the fractional part is in [0, 1)
|
| 977 |
+
x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip)
|
| 978 |
+
# The integer floor of x is now in the last 8 bits of x_rounded
|
| 979 |
+
# We assume the next 2 ops round to nearest even. The rounding mode is important.
|
| 980 |
+
x_rounded_back = x_rounded - fp32_round_int
|
| 981 |
+
x_frac = x_clamped - x_rounded_back
|
| 982 |
+
x_frac_ex2 = evaluate_polynomial(x_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
|
| 983 |
+
return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip)
|
| 984 |
+
|
| 985 |
+
|
| 986 |
+
@dsl_user_op
|
| 987 |
+
def ex2_emulation_2(
|
| 988 |
+
x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None
|
| 989 |
+
) -> Tuple[Float32, Float32]:
|
| 990 |
+
# We assume x <= 127.0 and y <= 127.0
|
| 991 |
+
fp32_round_int = float(2**23 + 2**22)
|
| 992 |
+
xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
|
| 993 |
+
# We want to round down here, so that the fractional part is in [0, 1)
|
| 994 |
+
xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd="rm")
|
| 995 |
+
# The integer floor of x & y are now in the last 8 bits of xy_rounded
|
| 996 |
+
# We want the next 2 ops to round to nearest even. The rounding mode is important.
|
| 997 |
+
xy_rounded_back = activation.sub_packed_f32x2(
|
| 998 |
+
xy_rounded, (fp32_round_int, fp32_round_int)
|
| 999 |
+
)
|
| 1000 |
+
xy_frac = activation.sub_packed_f32x2(xy_clamped, xy_rounded_back)
|
| 1001 |
+
xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
|
| 1002 |
+
x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip)
|
| 1003 |
+
y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip)
|
| 1004 |
+
return x_out, y_out
|
| 1005 |
+
|
| 1006 |
+
|
| 1007 |
+
@dsl_user_op
|
| 1008 |
+
def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
|
| 1009 |
+
out_f32x2 = llvm.inline_asm(
|
| 1010 |
+
llvm.StructType.get_literal([T.f32(), T.f32()]),
|
| 1011 |
+
[Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()],
|
| 1012 |
+
"{\n\t"
|
| 1013 |
+
".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
|
| 1014 |
+
".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
|
| 1015 |
+
".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
|
| 1016 |
+
"max.ftz.f32 f1, $2, 0fC2FE0000;\n\t"
|
| 1017 |
+
"max.ftz.f32 f2, $3, 0fC2FE0000;\n\t"
|
| 1018 |
+
"mov.b64 l1, {f1, f2};\n\t"
|
| 1019 |
+
"mov.f32 f3, 0f4B400000;\n\t"
|
| 1020 |
+
"mov.b64 l2, {f3, f3};\n\t"
|
| 1021 |
+
"add.rm.ftz.f32x2 l7, l1, l2;\n\t"
|
| 1022 |
+
"sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
|
| 1023 |
+
"sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
|
| 1024 |
+
"mov.f32 f7, 0f3D9DF09D;\n\t"
|
| 1025 |
+
"mov.b64 l6, {f7, f7};\n\t"
|
| 1026 |
+
"mov.f32 f6, 0f3E6906A4;\n\t"
|
| 1027 |
+
"mov.b64 l5, {f6, f6};\n\t"
|
| 1028 |
+
"mov.f32 f5, 0f3F31F519;\n\t"
|
| 1029 |
+
"mov.b64 l4, {f5, f5};\n\t"
|
| 1030 |
+
"mov.f32 f4, 0f3F800000;\n\t"
|
| 1031 |
+
"mov.b64 l3, {f4, f4};\n\t"
|
| 1032 |
+
"fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
|
| 1033 |
+
"fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
|
| 1034 |
+
"fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
|
| 1035 |
+
"mov.b64 {r1, r2}, l7;\n\t"
|
| 1036 |
+
"mov.b64 {r3, r4}, l10;\n\t"
|
| 1037 |
+
"shl.b32 r5, r1, 23;\n\t"
|
| 1038 |
+
"add.s32 r7, r5, r3;\n\t"
|
| 1039 |
+
"shl.b32 r6, r2, 23;\n\t"
|
| 1040 |
+
"add.s32 r8, r6, r4;\n\t"
|
| 1041 |
+
"mov.b32 $0, r7;\n\t"
|
| 1042 |
+
"mov.b32 $1, r8;\n\t"
|
| 1043 |
+
"}\n",
|
| 1044 |
+
"=r,=r,f,f",
|
| 1045 |
+
has_side_effects=False,
|
| 1046 |
+
is_align_stack=False,
|
| 1047 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 1048 |
+
)
|
| 1049 |
+
out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
|
| 1050 |
+
out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
|
| 1051 |
+
return out0, out1
|
| 1052 |
+
|
| 1053 |
+
|
| 1054 |
+
@dsl_user_op
|
| 1055 |
+
def domain_offset_aligned(
|
| 1056 |
+
coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
|
| 1057 |
+
) -> cute.Tensor:
|
| 1058 |
+
assert isinstance(tensor.iterator, cute.Pointer)
|
| 1059 |
+
# We assume that applying the offset does not change the pointer alignment
|
| 1060 |
+
new_ptr = cute.make_ptr(
|
| 1061 |
+
tensor.element_type,
|
| 1062 |
+
elem_pointer(tensor, coord).toint(),
|
| 1063 |
+
tensor.memspace,
|
| 1064 |
+
assumed_align=tensor.iterator.alignment,
|
| 1065 |
+
)
|
| 1066 |
+
return cute.make_tensor(new_ptr, tensor.layout)
|
| 1067 |
+
|
| 1068 |
+
|
| 1069 |
+
@cute.jit
|
| 1070 |
+
def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
|
| 1071 |
+
"""Convert a scalar to a cute TensorSSA of shape (1,) and given dtype"""
|
| 1072 |
+
vec = cute.make_rmem_tensor(1, dtype)
|
| 1073 |
+
vec[0] = a
|
| 1074 |
+
return vec.load()
|
| 1075 |
+
|
| 1076 |
+
|
| 1077 |
+
def ssa_to_scalar(val):
|
| 1078 |
+
"""Could inline but nice for reflecting the above api"""
|
| 1079 |
+
return val[0]
|
| 1080 |
+
|
| 1081 |
+
|
| 1082 |
+
# ------------------------------------------------------------------
|
| 1083 |
+
# Host-side Python helpers (not @cute.jit — called from PyTorch host code)
|
| 1084 |
+
# ------------------------------------------------------------------
|
| 1085 |
+
|
| 1086 |
+
def default_softmax_scale(dim: int) -> float:
|
| 1087 |
+
"""Default softmax scale: 1 / sqrt(dim)."""
|
| 1088 |
+
return 1.0 / math.sqrt(dim)
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""SM100 sparse attention kernels."""
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/build_k2q_csr/__init__.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""q2k -> k2q CSR builder backed by the precompiled Torch ops.
|
| 5 |
+
|
| 6 |
+
The CUDA implementation lives in ``csrc/build_k2q_csr.cu`` and is built
|
| 7 |
+
ahead of time by kernel-builder; it is reached through the ``_ops``
|
| 8 |
+
namespace instead of being JIT-compiled at import time.
|
| 9 |
+
|
| 10 |
+
The kernel pipeline is tuned and verified for SM100; other
|
| 11 |
+
architectures are not supported.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from ...._ops import ops
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def run_build_k2q_csr(
|
| 22 |
+
q2k: torch.Tensor,
|
| 23 |
+
cu_seqlens_q: torch.Tensor,
|
| 24 |
+
cu_seqlens_k: torch.Tensor,
|
| 25 |
+
row_ptr: torch.Tensor,
|
| 26 |
+
q_idx: torch.Tensor,
|
| 27 |
+
topk: int,
|
| 28 |
+
blk_kv: int,
|
| 29 |
+
total_rows: int,
|
| 30 |
+
max_kv_blocks: int,
|
| 31 |
+
) -> None:
|
| 32 |
+
"""In-place fill of ``row_ptr`` and ``q_idx``.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
q2k: int32 [H, total_q, topK] contiguous (CUDA).
|
| 36 |
+
cu_seqlens_q: int32 [B+1] contiguous (CUDA).
|
| 37 |
+
cu_seqlens_k: int32 [B+1] contiguous (CUDA).
|
| 38 |
+
row_ptr: int32 [H, total_rows + 1] CUDA, written in place.
|
| 39 |
+
q_idx: int32 [H, total_q * topK] CUDA, written in place
|
| 40 |
+
(trailing slots set to -1).
|
| 41 |
+
topk: must be in {4, 8, 16, 32}.
|
| 42 |
+
blk_kv: must equal 128.
|
| 43 |
+
total_rows: sum over batches of ceil(seqlen_k / blk_kv).
|
| 44 |
+
max_kv_blocks: max over batches of ceil(seqlen_k / blk_kv); upper bound
|
| 45 |
+
used to size the row_map workspace and clamp valid kv ids.
|
| 46 |
+
"""
|
| 47 |
+
ops.run_build_k2q_csr(
|
| 48 |
+
q2k,
|
| 49 |
+
cu_seqlens_q,
|
| 50 |
+
cu_seqlens_k,
|
| 51 |
+
row_ptr,
|
| 52 |
+
q_idx,
|
| 53 |
+
int(topk),
|
| 54 |
+
int(blk_kv),
|
| 55 |
+
int(total_rows),
|
| 56 |
+
int(max_kv_blocks),
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def run_build_k2q_csr_with_schedule(
|
| 61 |
+
q2k: torch.Tensor,
|
| 62 |
+
cu_seqlens_q: torch.Tensor,
|
| 63 |
+
cu_seqlens_k: torch.Tensor,
|
| 64 |
+
row_ptr: torch.Tensor,
|
| 65 |
+
q_idx: torch.Tensor,
|
| 66 |
+
scheduler_metadata: torch.Tensor,
|
| 67 |
+
work_count: torch.Tensor,
|
| 68 |
+
qsplit_idx: torch.Tensor,
|
| 69 |
+
split_counts: torch.Tensor,
|
| 70 |
+
topk: int,
|
| 71 |
+
blk_kv: int,
|
| 72 |
+
total_rows: int,
|
| 73 |
+
max_kv_blocks: int,
|
| 74 |
+
target_q_per_cta: int,
|
| 75 |
+
work_capacity: int,
|
| 76 |
+
max_seqlen_q: int,
|
| 77 |
+
) -> None:
|
| 78 |
+
"""In-place fill of CSR plus fused sparse attention schedule metadata."""
|
| 79 |
+
ops.run_build_k2q_csr_with_schedule(
|
| 80 |
+
q2k,
|
| 81 |
+
cu_seqlens_q,
|
| 82 |
+
cu_seqlens_k,
|
| 83 |
+
row_ptr,
|
| 84 |
+
q_idx,
|
| 85 |
+
scheduler_metadata,
|
| 86 |
+
work_count,
|
| 87 |
+
qsplit_idx,
|
| 88 |
+
split_counts,
|
| 89 |
+
int(topk),
|
| 90 |
+
int(blk_kv),
|
| 91 |
+
int(total_rows),
|
| 92 |
+
int(max_kv_blocks),
|
| 93 |
+
int(target_q_per_cta),
|
| 94 |
+
int(work_capacity),
|
| 95 |
+
int(max_seqlen_q),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def is_supported(topk: int, blk_kv: int) -> bool:
|
| 100 |
+
return int(topk) in (4, 8, 16, 32) and int(blk_kv) == 128
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
__all__ = ["run_build_k2q_csr", "run_build_k2q_csr_with_schedule", "is_supported"]
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/decode_schedule.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""Split-KV schedule for paged fp8 decode attention.
|
| 5 |
+
|
| 6 |
+
The public PageKV representation remains this repo's rectangular page table:
|
| 7 |
+
``page_table [B, max_pages]`` plus ``seqused_k [B]``. The schedule only
|
| 8 |
+
describes how query tiles and KV chunks are split into work items.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Optional
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class DecodeAttentionSchedule:
|
| 21 |
+
split_kv: bool
|
| 22 |
+
cta_tile_q: int
|
| 23 |
+
num_q_tiles: int
|
| 24 |
+
kv_chunk_size_pages: int
|
| 25 |
+
kv_chunk_size_tokens: int
|
| 26 |
+
work_count: int
|
| 27 |
+
padded_work_count: int
|
| 28 |
+
partial_rows: int
|
| 29 |
+
max_split_count: int
|
| 30 |
+
max_grid_size: int
|
| 31 |
+
active_blocks_per_sm: int
|
| 32 |
+
num_sms: int
|
| 33 |
+
base_cta: int
|
| 34 |
+
request_indices: torch.Tensor
|
| 35 |
+
qo_tile_indices: torch.Tensor
|
| 36 |
+
kv_tile_indices: torch.Tensor
|
| 37 |
+
merge_indptr: torch.Tensor
|
| 38 |
+
o_indptr: torch.Tensor
|
| 39 |
+
block_valid_mask: torch.Tensor
|
| 40 |
+
kv_pages: torch.Tensor
|
| 41 |
+
split_counts: torch.Tensor
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _require_i32_cuda_1d(tensor: torch.Tensor, *, name: str) -> None:
|
| 45 |
+
if tensor.dtype != torch.int32:
|
| 46 |
+
raise TypeError(f"{name} must be torch.int32")
|
| 47 |
+
if tensor.ndim != 1:
|
| 48 |
+
raise ValueError(f"{name} must be rank-1")
|
| 49 |
+
if not tensor.is_cuda:
|
| 50 |
+
raise ValueError(f"{name} must be a CUDA tensor")
|
| 51 |
+
if not tensor.is_contiguous():
|
| 52 |
+
raise ValueError(f"{name} must be contiguous")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def prepare_decode_schedule(
|
| 56 |
+
*,
|
| 57 |
+
seqused_k: torch.Tensor,
|
| 58 |
+
page_size: int,
|
| 59 |
+
seqlen_q: int,
|
| 60 |
+
num_qo_heads: int,
|
| 61 |
+
num_kv_heads: int,
|
| 62 |
+
head_dim: int,
|
| 63 |
+
max_seqlen_k: int,
|
| 64 |
+
enable_cuda_graph: bool = False,
|
| 65 |
+
max_grid_size: Optional[int] = None,
|
| 66 |
+
fixed_split_size: Optional[int] = None,
|
| 67 |
+
disable_split_kv: bool = False,
|
| 68 |
+
) -> DecodeAttentionSchedule:
|
| 69 |
+
"""Build paged decode split-KV schedule on the GPU.
|
| 70 |
+
|
| 71 |
+
A single CUDA kernel reads ``seqused_k`` on device and writes all
|
| 72 |
+
schedule index arrays. Only a small summary tensor is D2H-synced so
|
| 73 |
+
the wrapper can size O_partial / pick the kernel grid / choose the
|
| 74 |
+
split-vs-non-split compile path.
|
| 75 |
+
|
| 76 |
+
``max_seqlen_k`` is the host-side worst-case bound used to pad the
|
| 77 |
+
work-tile arrays. It must satisfy ``max(seqused_k) <= max_seqlen_k``.
|
| 78 |
+
"""
|
| 79 |
+
_require_i32_cuda_1d(seqused_k, name="seqused_k")
|
| 80 |
+
# Hard cap: current single-CTA schedule kernel stores per-batch state
|
| 81 |
+
# in shared memory. Larger batches require a multi-CTA cooperative
|
| 82 |
+
# scheduler (unimplemented). Fail fast at the Python boundary so the
|
| 83 |
+
# error doesn't surface from inside the CUDA extension.
|
| 84 |
+
if int(seqused_k.shape[0]) > 1024:
|
| 85 |
+
raise NotImplementedError(
|
| 86 |
+
"decode schedule currently supports batch <= 1024 "
|
| 87 |
+
f"(got batch={int(seqused_k.shape[0])}). Larger batches need "
|
| 88 |
+
"the multi-CTA scheduler — not yet implemented."
|
| 89 |
+
)
|
| 90 |
+
# Two API-boundary checks tied to the kernel's packed-GQA layout
|
| 91 |
+
# (q_tokens_per_group = m_block_size / qhead_per_kv = 128/16 = 8):
|
| 92 |
+
#
|
| 93 |
+
# (1) seqused_k[b] >= seqlen_q. The kernel computes the causal mask as
|
| 94 |
+
# col_limit = row_idx + seqlen_k - seqlen_q + 1. For row 0 (first
|
| 95 |
+
# q-token in the packed group) this is col_limit = seqlen_k - seqlen_q
|
| 96 |
+
# + 1, which goes <= 0 whenever seqlen_k < seqlen_q. That all-masked
|
| 97 |
+
# row then enters a mask-codegen path with PTX-undefined shift counts
|
| 98 |
+
# and the kernel hangs. The condition is also semantically invalid
|
| 99 |
+
# in batched-decode: you can't emit seqlen_q new tokens with fewer
|
| 100 |
+
# than seqlen_q total context tokens (seqlen_k includes them).
|
| 101 |
+
#
|
| 102 |
+
# (2) seqused_k[b] % page_size ∈ {0, 8, 16, ..., 120}. Same hang fires
|
| 103 |
+
# when the LAST partial page has < q_tokens_per_group=8 valid
|
| 104 |
+
# columns, because then the *last MMA tile* hits the same all-masked
|
| 105 |
+
# row case for the trailing q-tokens.
|
| 106 |
+
#
|
| 107 |
+
# Both are tracked as a separate kernel-level TODO (un-pack the
|
| 108 |
+
# all-masked row → skip mask call, or saturate causal_col_limit at >= 1
|
| 109 |
+
# in mask.py). Until then, fail fast at the Python boundary with a
|
| 110 |
+
# clear message rather than letting the kernel timeout.
|
| 111 |
+
seqlen_q_i = int(seqlen_q)
|
| 112 |
+
bad_q = seqused_k < seqlen_q_i
|
| 113 |
+
if bool(bad_q.any().item()):
|
| 114 |
+
bad_idx = int(torch.nonzero(bad_q, as_tuple=True)[0][0].item())
|
| 115 |
+
bad_val = int(seqused_k[bad_idx].item())
|
| 116 |
+
raise ValueError(
|
| 117 |
+
f"decode kernel requires seqused_k[b] >= seqlen_q (= {seqlen_q_i}) "
|
| 118 |
+
f"for every batch. Got seqused_k[{bad_idx}]={bad_val}. "
|
| 119 |
+
f"This is also a batched-decode invariant: seqlen_k must include "
|
| 120 |
+
f"the seqlen_q new tokens being emitted."
|
| 121 |
+
)
|
| 122 |
+
rem = seqused_k % int(page_size)
|
| 123 |
+
bad_rem = (rem > 0) & (rem < seqlen_q_i)
|
| 124 |
+
if bool(bad_rem.any().item()):
|
| 125 |
+
bad_idx = int(torch.nonzero(bad_rem, as_tuple=True)[0][0].item())
|
| 126 |
+
bad_val = int(seqused_k[bad_idx].item())
|
| 127 |
+
raise ValueError(
|
| 128 |
+
f"decode kernel requires seqused_k[b] % page_size ∈ "
|
| 129 |
+
f"{{0, {seqlen_q_i}, {seqlen_q_i*2}, ..., {(page_size//seqlen_q_i)*seqlen_q_i}}}. "
|
| 130 |
+
f"Got seqused_k[{bad_idx}]={bad_val}, last partial page has "
|
| 131 |
+
f"{bad_val % int(page_size)} valid columns (< seqlen_q={seqlen_q_i}). "
|
| 132 |
+
f"Round seqused_k up to the next multiple of {seqlen_q_i} OR to "
|
| 133 |
+
f"a multiple of {page_size}."
|
| 134 |
+
)
|
| 135 |
+
if int(page_size) <= 0:
|
| 136 |
+
raise ValueError("page_size must be positive")
|
| 137 |
+
if int(seqlen_q) <= 0:
|
| 138 |
+
raise ValueError("seqlen_q must be positive")
|
| 139 |
+
if int(num_qo_heads) <= 0 or int(num_kv_heads) <= 0:
|
| 140 |
+
raise ValueError("head counts must be positive")
|
| 141 |
+
if int(num_qo_heads) % int(num_kv_heads) != 0:
|
| 142 |
+
raise ValueError("num_qo_heads must be divisible by num_kv_heads")
|
| 143 |
+
if int(num_qo_heads) // int(num_kv_heads) != 16:
|
| 144 |
+
raise NotImplementedError("decode schedule currently supports only qhead_per_kv=16")
|
| 145 |
+
if int(head_dim) != 128:
|
| 146 |
+
raise NotImplementedError("decode schedule currently supports only head_dim=128")
|
| 147 |
+
if int(max_seqlen_k) <= 0:
|
| 148 |
+
raise ValueError("max_seqlen_k must be positive")
|
| 149 |
+
|
| 150 |
+
from ...src.sm100.fwd_decode.build_decode_schedule import build_decode_schedule
|
| 151 |
+
|
| 152 |
+
raw = build_decode_schedule(
|
| 153 |
+
seqused_k,
|
| 154 |
+
page_size=int(page_size),
|
| 155 |
+
seqlen_q=int(seqlen_q),
|
| 156 |
+
num_qo_heads=int(num_qo_heads),
|
| 157 |
+
num_kv_heads=int(num_kv_heads),
|
| 158 |
+
head_dim=int(head_dim),
|
| 159 |
+
max_seqlen_k=int(max_seqlen_k),
|
| 160 |
+
enable_cuda_graph=bool(enable_cuda_graph),
|
| 161 |
+
max_grid_size=0 if max_grid_size is None else int(max_grid_size),
|
| 162 |
+
fixed_split_size=-1 if fixed_split_size is None else int(fixed_split_size),
|
| 163 |
+
disable_split_kv=bool(disable_split_kv),
|
| 164 |
+
)
|
| 165 |
+
return DecodeAttentionSchedule(
|
| 166 |
+
split_kv=bool(raw["split_kv"]),
|
| 167 |
+
cta_tile_q=int(raw["cta_tile_q"]),
|
| 168 |
+
num_q_tiles=int(raw["num_q_tiles"]),
|
| 169 |
+
kv_chunk_size_pages=int(raw["kv_chunk_size_pages"]),
|
| 170 |
+
kv_chunk_size_tokens=int(raw["kv_chunk_size_tokens"]),
|
| 171 |
+
work_count=int(raw["work_count"]),
|
| 172 |
+
padded_work_count=int(raw["padded_work_count"]),
|
| 173 |
+
partial_rows=int(raw["partial_rows"]),
|
| 174 |
+
max_split_count=int(raw["max_split_count"]),
|
| 175 |
+
max_grid_size=int(raw["max_grid_size"]),
|
| 176 |
+
active_blocks_per_sm=int(raw["active_blocks_per_sm"]),
|
| 177 |
+
num_sms=int(raw["num_sms"]),
|
| 178 |
+
base_cta=int(raw["base_cta"]),
|
| 179 |
+
request_indices=raw["request_indices"],
|
| 180 |
+
qo_tile_indices=raw["qo_tile_indices"],
|
| 181 |
+
kv_tile_indices=raw["kv_tile_indices"],
|
| 182 |
+
merge_indptr=raw["merge_indptr"],
|
| 183 |
+
o_indptr=raw["o_indptr"],
|
| 184 |
+
block_valid_mask=raw["block_valid_mask"],
|
| 185 |
+
kv_pages=raw["kv_pages"],
|
| 186 |
+
split_counts=raw["split_counts"],
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
__all__ = [
|
| 191 |
+
"DecodeAttentionSchedule",
|
| 192 |
+
"prepare_decode_schedule",
|
| 193 |
+
]
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fp4_indexer.py
ADDED
|
@@ -0,0 +1,1956 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""SM100 FP4 sparse-attention indexer kernels."""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Literal
|
| 10 |
+
|
| 11 |
+
import cuda.bindings.driver as cuda
|
| 12 |
+
import cutlass
|
| 13 |
+
import cutlass.cute as cute
|
| 14 |
+
import cutlass.pipeline as pipeline
|
| 15 |
+
import cutlass.utils as utils
|
| 16 |
+
import cutlass.utils.blackwell_helpers as sm100_utils
|
| 17 |
+
import cutlass.utils.blockscaled_layout as blockscaled_utils
|
| 18 |
+
import torch
|
| 19 |
+
from cutlass import Float32, Int32, const_expr
|
| 20 |
+
from cutlass.cute.nvgpu import cpasync, tcgen05
|
| 21 |
+
|
| 22 |
+
from ...src.common import pipeline as common_pipeline
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
FP4_FORMAT = Literal["mxfp4", "nvfp4"]
|
| 26 |
+
_FP4_PACKED_D_BYTES = 64
|
| 27 |
+
_HEAD_DIM = 128
|
| 28 |
+
_BLOCK_K = 128
|
| 29 |
+
_PAGE_SIZE = 128
|
| 30 |
+
_MMA_TILER_MN = (128, 128)
|
| 31 |
+
_MMA_INST_SHAPE_K = 64
|
| 32 |
+
_NON_CAUSAL_K_TILES_PER_CTA = 16
|
| 33 |
+
_CAUSAL_K_TILES_PER_CTA = 16
|
| 34 |
+
_DECODE_PACK_Q_LEN = 8
|
| 35 |
+
_DECODE_QHEAD_PER_KV = 16
|
| 36 |
+
_DECODE_K_TILES_PER_CTA = 16
|
| 37 |
+
_AB_DTYPE = cutlass.Float4E2M1FN
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass(frozen=True)
|
| 41 |
+
class Fp4FormatSpec:
|
| 42 |
+
name: FP4_FORMAT
|
| 43 |
+
sf_vec_size: int
|
| 44 |
+
scale_groups: int
|
| 45 |
+
torch_scale_dtype: torch.dtype
|
| 46 |
+
cutlass_scale_dtype: type
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
_FORMAT_SPECS: dict[str, Fp4FormatSpec] = {
|
| 50 |
+
"mxfp4": Fp4FormatSpec(
|
| 51 |
+
name="mxfp4",
|
| 52 |
+
sf_vec_size=32,
|
| 53 |
+
scale_groups=4,
|
| 54 |
+
torch_scale_dtype=torch.float8_e8m0fnu,
|
| 55 |
+
cutlass_scale_dtype=cutlass.Float8E8M0FNU,
|
| 56 |
+
),
|
| 57 |
+
"nvfp4": Fp4FormatSpec(
|
| 58 |
+
name="nvfp4",
|
| 59 |
+
sf_vec_size=16,
|
| 60 |
+
scale_groups=8,
|
| 61 |
+
torch_scale_dtype=torch.float8_e4m3fn,
|
| 62 |
+
cutlass_scale_dtype=cutlass.Float8E4M3FN,
|
| 63 |
+
),
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def normalize_fp4_format(fmt: str) -> Fp4FormatSpec:
|
| 68 |
+
key = str(fmt).lower()
|
| 69 |
+
try:
|
| 70 |
+
return _FORMAT_SPECS[key]
|
| 71 |
+
except KeyError as exc:
|
| 72 |
+
raise ValueError(f"format must be one of {sorted(_FORMAT_SPECS)}, got {fmt!r}") from exc
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def ceil_div(x: int, y: int) -> int:
|
| 76 |
+
return (int(x) + int(y) - 1) // int(y)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def k_tiles_per_cta_for(causal: bool) -> int:
|
| 80 |
+
return _CAUSAL_K_TILES_PER_CTA if bool(causal) else _NON_CAUSAL_K_TILES_PER_CTA
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Fp4IndexerScaleReorderSm100:
|
| 84 |
+
"""Reorder public FP4 indexer scales to the 1CTA blockscaled MMA layout."""
|
| 85 |
+
|
| 86 |
+
def __init__(self, *, fmt: str):
|
| 87 |
+
spec = normalize_fp4_format(fmt)
|
| 88 |
+
self.fmt = spec.name
|
| 89 |
+
self.sf_dtype = spec.cutlass_scale_dtype
|
| 90 |
+
self.scale_groups = spec.scale_groups
|
| 91 |
+
self.threads_per_cta = 256
|
| 92 |
+
|
| 93 |
+
@cute.jit
|
| 94 |
+
def __call__(
|
| 95 |
+
self,
|
| 96 |
+
q_scale_ptr: cute.Pointer,
|
| 97 |
+
k_scale_ptr: cute.Pointer,
|
| 98 |
+
q_scale_mma_ptr: cute.Pointer,
|
| 99 |
+
k_scale_mma_ptr: cute.Pointer,
|
| 100 |
+
problem_size: tuple,
|
| 101 |
+
stream: cuda.CUstream,
|
| 102 |
+
):
|
| 103 |
+
total_q, heads_q, page_count, heads_k = problem_size
|
| 104 |
+
rest_q_m = cute.ceil_div(total_q, 128)
|
| 105 |
+
rest_g = cute.ceil_div(self.scale_groups, 4)
|
| 106 |
+
k_l = page_count * heads_k
|
| 107 |
+
|
| 108 |
+
q_scale = cute.make_tensor(
|
| 109 |
+
q_scale_ptr,
|
| 110 |
+
cute.make_layout(
|
| 111 |
+
(total_q, heads_q, self.scale_groups),
|
| 112 |
+
stride=(heads_q * self.scale_groups, self.scale_groups, 1),
|
| 113 |
+
),
|
| 114 |
+
)
|
| 115 |
+
k_scale = cute.make_tensor(
|
| 116 |
+
k_scale_ptr,
|
| 117 |
+
cute.make_layout(
|
| 118 |
+
(page_count, heads_k, _PAGE_SIZE, self.scale_groups),
|
| 119 |
+
stride=(
|
| 120 |
+
heads_k * _PAGE_SIZE * self.scale_groups,
|
| 121 |
+
_PAGE_SIZE * self.scale_groups,
|
| 122 |
+
self.scale_groups,
|
| 123 |
+
1,
|
| 124 |
+
),
|
| 125 |
+
),
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
q_mma_layout = cute.make_ordered_layout(
|
| 129 |
+
(32, 4, rest_q_m, 4, rest_g, heads_q),
|
| 130 |
+
order=(2, 1, 4, 0, 3, 5),
|
| 131 |
+
)
|
| 132 |
+
k_mma_layout = cute.make_ordered_layout(
|
| 133 |
+
(32, 4, 1, 4, rest_g, k_l),
|
| 134 |
+
order=(2, 1, 4, 0, 3, 5),
|
| 135 |
+
)
|
| 136 |
+
q_scale_mma = cute.make_tensor(q_scale_mma_ptr, q_mma_layout)
|
| 137 |
+
k_scale_mma = cute.make_tensor(k_scale_mma_ptr, k_mma_layout)
|
| 138 |
+
q_scale_mma = cute.group_modes(q_scale_mma, 0, 3)
|
| 139 |
+
q_scale_mma = cute.group_modes(q_scale_mma, 1, 3)
|
| 140 |
+
k_scale_mma = cute.group_modes(k_scale_mma, 0, 3)
|
| 141 |
+
k_scale_mma = cute.group_modes(k_scale_mma, 1, 3)
|
| 142 |
+
|
| 143 |
+
q_scale_count = total_q * heads_q * Int32(self.scale_groups)
|
| 144 |
+
k_scale_count = page_count * heads_k * Int32(_PAGE_SIZE * self.scale_groups)
|
| 145 |
+
total_scale_count = q_scale_count + k_scale_count
|
| 146 |
+
grid_ctas = cute.ceil_div(total_scale_count, self.threads_per_cta)
|
| 147 |
+
self.kernel(
|
| 148 |
+
q_scale,
|
| 149 |
+
k_scale,
|
| 150 |
+
q_scale_mma,
|
| 151 |
+
k_scale_mma,
|
| 152 |
+
heads_q,
|
| 153 |
+
heads_k,
|
| 154 |
+
q_scale_count,
|
| 155 |
+
total_scale_count,
|
| 156 |
+
).launch(
|
| 157 |
+
grid=(grid_ctas, 1, 1),
|
| 158 |
+
block=[self.threads_per_cta, 1, 1],
|
| 159 |
+
cluster=(1, 1, 1),
|
| 160 |
+
stream=stream,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
@cute.kernel
|
| 164 |
+
def kernel(
|
| 165 |
+
self,
|
| 166 |
+
q_scale: cute.Tensor,
|
| 167 |
+
k_scale: cute.Tensor,
|
| 168 |
+
q_scale_mma: cute.Tensor,
|
| 169 |
+
k_scale_mma: cute.Tensor,
|
| 170 |
+
heads_q: Int32,
|
| 171 |
+
heads_k: Int32,
|
| 172 |
+
q_scale_count: Int32,
|
| 173 |
+
total_scale_count: Int32,
|
| 174 |
+
):
|
| 175 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 176 |
+
block_idx, _, _ = cute.arch.block_idx()
|
| 177 |
+
grid_dim, _, _ = cute.arch.grid_dim()
|
| 178 |
+
linear = block_idx * Int32(self.threads_per_cta) + tidx
|
| 179 |
+
stride = grid_dim * Int32(self.threads_per_cta)
|
| 180 |
+
|
| 181 |
+
while linear < total_scale_count:
|
| 182 |
+
if linear < q_scale_count:
|
| 183 |
+
group = linear % Int32(self.scale_groups)
|
| 184 |
+
tmp = linear // Int32(self.scale_groups)
|
| 185 |
+
head = tmp % heads_q
|
| 186 |
+
row = tmp // heads_q
|
| 187 |
+
q_scale_mma[row, group, head] = q_scale[row, head, group]
|
| 188 |
+
else:
|
| 189 |
+
k_linear = linear - q_scale_count
|
| 190 |
+
group = k_linear % Int32(self.scale_groups)
|
| 191 |
+
tmp = k_linear // Int32(self.scale_groups)
|
| 192 |
+
row = tmp % Int32(_PAGE_SIZE)
|
| 193 |
+
tmp = tmp // Int32(_PAGE_SIZE)
|
| 194 |
+
head = tmp % heads_k
|
| 195 |
+
page = tmp // heads_k
|
| 196 |
+
scale_l = page * heads_k + head
|
| 197 |
+
k_scale_mma[row, group, scale_l] = k_scale[page, head, row, group]
|
| 198 |
+
linear += stride
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class Fp4IndexerStagedMmaSm100:
|
| 202 |
+
"""Single-kernel FP4 indexer for preordered MMA scale storage."""
|
| 203 |
+
|
| 204 |
+
def __init__(
|
| 205 |
+
self,
|
| 206 |
+
*,
|
| 207 |
+
fmt: str,
|
| 208 |
+
causal: bool,
|
| 209 |
+
preordered_q_scale_tma: bool = False,
|
| 210 |
+
compact_schedule: bool = False,
|
| 211 |
+
use_tmem_load_red: bool = False,
|
| 212 |
+
):
|
| 213 |
+
spec = normalize_fp4_format(fmt)
|
| 214 |
+
self.fmt = spec.name
|
| 215 |
+
self.is_causal = bool(causal)
|
| 216 |
+
self.preordered_q_scale_tma = bool(preordered_q_scale_tma)
|
| 217 |
+
self.compact_schedule = bool(compact_schedule)
|
| 218 |
+
self.use_tmem_load_red = bool(use_tmem_load_red)
|
| 219 |
+
self.sf_vec_size = spec.sf_vec_size
|
| 220 |
+
self.sf_dtype = spec.cutlass_scale_dtype
|
| 221 |
+
self.scale_groups = spec.scale_groups
|
| 222 |
+
self.use_nvfp4 = spec.name == "nvfp4"
|
| 223 |
+
self.epi_threads_per_cta = 128
|
| 224 |
+
self.epi_warps_per_group = 4
|
| 225 |
+
self.num_epi_warpgroups = 2
|
| 226 |
+
self.mma_warp_id = self.epi_warps_per_group * self.num_epi_warpgroups
|
| 227 |
+
self.load_warp_id = self.mma_warp_id + 1
|
| 228 |
+
self.threads_per_cta = 384
|
| 229 |
+
self.num_tmem_alloc_cols = 512
|
| 230 |
+
self.num_q_stage = 1
|
| 231 |
+
self.num_acc_stage = 3
|
| 232 |
+
self.num_ab_stage = 3
|
| 233 |
+
self.k_tiles_per_cta = k_tiles_per_cta_for(self.is_causal)
|
| 234 |
+
|
| 235 |
+
@cute.jit
|
| 236 |
+
def __call__(
|
| 237 |
+
self,
|
| 238 |
+
q_ptr: cute.Pointer,
|
| 239 |
+
k_ptr: cute.Pointer,
|
| 240 |
+
q_scale_ptr: cute.Pointer,
|
| 241 |
+
k_scale_ptr: cute.Pointer,
|
| 242 |
+
scores_ptr: cute.Pointer,
|
| 243 |
+
kv_indices_ptr: cute.Pointer,
|
| 244 |
+
cu_seqlens_q_ptr: cute.Pointer,
|
| 245 |
+
cu_seqlens_k_ptr: cute.Pointer,
|
| 246 |
+
cu_page_offsets_ptr: cute.Pointer,
|
| 247 |
+
qo_offset_ptr: cute.Pointer,
|
| 248 |
+
problem_size: tuple,
|
| 249 |
+
stream: cuda.CUstream,
|
| 250 |
+
):
|
| 251 |
+
(
|
| 252 |
+
m,
|
| 253 |
+
_,
|
| 254 |
+
k,
|
| 255 |
+
_,
|
| 256 |
+
lk,
|
| 257 |
+
heads_q,
|
| 258 |
+
heads_k,
|
| 259 |
+
batch,
|
| 260 |
+
max_k_tiles,
|
| 261 |
+
total_q,
|
| 262 |
+
has_qo_offset,
|
| 263 |
+
compact_task_count,
|
| 264 |
+
) = problem_size
|
| 265 |
+
page_count = lk // heads_k
|
| 266 |
+
self.mma_tiler = (_MMA_TILER_MN[0], _MMA_TILER_MN[1], _MMA_INST_SHAPE_K * 2)
|
| 267 |
+
self.cta_tile_shape_mnk = self.mma_tiler
|
| 268 |
+
|
| 269 |
+
q_tma_tensor = cute.make_tensor(
|
| 270 |
+
cute.recast_ptr(q_ptr, dtype=_AB_DTYPE),
|
| 271 |
+
cute.make_layout(
|
| 272 |
+
(total_q, _HEAD_DIM, heads_q),
|
| 273 |
+
stride=(heads_q * _HEAD_DIM, 1, _HEAD_DIM),
|
| 274 |
+
),
|
| 275 |
+
)
|
| 276 |
+
k_tma_tensor = cute.make_tensor(
|
| 277 |
+
cute.recast_ptr(k_ptr, dtype=_AB_DTYPE),
|
| 278 |
+
cute.make_layout(
|
| 279 |
+
(_PAGE_SIZE, _HEAD_DIM, heads_k, page_count),
|
| 280 |
+
stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM, heads_k * _PAGE_SIZE * _HEAD_DIM),
|
| 281 |
+
),
|
| 282 |
+
)
|
| 283 |
+
q_scale_tensor = cute.make_tensor(
|
| 284 |
+
q_scale_ptr,
|
| 285 |
+
blockscaled_utils.tile_atom_to_shape_SF(
|
| 286 |
+
(total_q, _HEAD_DIM, heads_q),
|
| 287 |
+
self.sf_vec_size,
|
| 288 |
+
),
|
| 289 |
+
)
|
| 290 |
+
k_scale_tensor = cute.make_tensor(
|
| 291 |
+
k_scale_ptr,
|
| 292 |
+
blockscaled_utils.tile_atom_to_shape_SF(
|
| 293 |
+
(_PAGE_SIZE, _HEAD_DIM, page_count * heads_k),
|
| 294 |
+
self.sf_vec_size,
|
| 295 |
+
),
|
| 296 |
+
)
|
| 297 |
+
scores_tensor = cute.make_tensor(
|
| 298 |
+
scores_ptr,
|
| 299 |
+
cute.make_layout((heads_q, max_k_tiles, total_q), stride=(max_k_tiles * total_q, total_q, 1)),
|
| 300 |
+
)
|
| 301 |
+
kv_indices_tensor = cute.make_tensor(
|
| 302 |
+
kv_indices_ptr,
|
| 303 |
+
cute.make_layout((page_count,), stride=(1,)),
|
| 304 |
+
)
|
| 305 |
+
cu_layout = cute.make_layout((batch + 1,), stride=(1,))
|
| 306 |
+
cu_q_tensor = cute.make_tensor(cu_seqlens_q_ptr, cu_layout)
|
| 307 |
+
cu_k_tensor = cute.make_tensor(cu_seqlens_k_ptr, cu_layout)
|
| 308 |
+
cu_page_offsets_tensor = cute.make_tensor(cu_page_offsets_ptr, cu_layout)
|
| 309 |
+
qo_offset_tensor = cute.make_tensor(qo_offset_ptr, cute.make_layout((batch,), stride=(1,)))
|
| 310 |
+
|
| 311 |
+
if const_expr(self.use_nvfp4):
|
| 312 |
+
mma_op = tcgen05.MmaMXF4NVF4Op(
|
| 313 |
+
self.sf_dtype,
|
| 314 |
+
(*_MMA_TILER_MN, _MMA_INST_SHAPE_K),
|
| 315 |
+
tcgen05.CtaGroup.ONE,
|
| 316 |
+
tcgen05.OperandSource.SMEM,
|
| 317 |
+
)
|
| 318 |
+
else:
|
| 319 |
+
mma_op = tcgen05.MmaMXF4Op(
|
| 320 |
+
(*_MMA_TILER_MN, _MMA_INST_SHAPE_K),
|
| 321 |
+
tcgen05.CtaGroup.ONE,
|
| 322 |
+
tcgen05.OperandSource.SMEM,
|
| 323 |
+
)
|
| 324 |
+
tiled_mma = cute.make_tiled_mma(mma_op)
|
| 325 |
+
q_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_q_stage)
|
| 326 |
+
k_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_ab_stage)
|
| 327 |
+
q_scale_smem_layout = blockscaled_utils.make_smem_layout_sfa(
|
| 328 |
+
tiled_mma,
|
| 329 |
+
self.mma_tiler,
|
| 330 |
+
self.sf_vec_size,
|
| 331 |
+
self.num_q_stage,
|
| 332 |
+
)
|
| 333 |
+
k_scale_smem_layout = blockscaled_utils.make_smem_layout_sfb(
|
| 334 |
+
tiled_mma,
|
| 335 |
+
self.mma_tiler,
|
| 336 |
+
self.sf_vec_size,
|
| 337 |
+
self.num_ab_stage,
|
| 338 |
+
)
|
| 339 |
+
cluster_layout_vmnk = cute.make_layout((1, 1, 1, 1))
|
| 340 |
+
tma_load_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)
|
| 341 |
+
q_smem_layout_stage = cute.slice_(q_smem_layout, (None, None, None, 0))
|
| 342 |
+
k_smem_layout_stage = cute.slice_(k_smem_layout, (None, None, None, 0))
|
| 343 |
+
tma_q = cute.nvgpu.make_tiled_tma_atom_A(
|
| 344 |
+
tma_load_op,
|
| 345 |
+
q_tma_tensor,
|
| 346 |
+
q_smem_layout_stage,
|
| 347 |
+
self.mma_tiler,
|
| 348 |
+
tiled_mma,
|
| 349 |
+
cluster_layout_vmnk.shape,
|
| 350 |
+
)
|
| 351 |
+
tma_k = cute.nvgpu.make_tiled_tma_atom_B(
|
| 352 |
+
tma_load_op,
|
| 353 |
+
k_tma_tensor,
|
| 354 |
+
k_smem_layout_stage,
|
| 355 |
+
self.mma_tiler,
|
| 356 |
+
tiled_mma,
|
| 357 |
+
cluster_layout_vmnk.shape,
|
| 358 |
+
)
|
| 359 |
+
if const_expr(self.preordered_q_scale_tma):
|
| 360 |
+
tma_qs = cute.nvgpu.make_tiled_tma_atom_A(
|
| 361 |
+
tma_load_op,
|
| 362 |
+
q_scale_tensor,
|
| 363 |
+
q_scale_smem_layout,
|
| 364 |
+
self.mma_tiler,
|
| 365 |
+
tiled_mma,
|
| 366 |
+
cluster_layout_vmnk.shape,
|
| 367 |
+
internal_type=cutlass.Int16,
|
| 368 |
+
)
|
| 369 |
+
else:
|
| 370 |
+
tma_qs = tma_q
|
| 371 |
+
tma_ks = cute.nvgpu.make_tiled_tma_atom_B(
|
| 372 |
+
tma_load_op,
|
| 373 |
+
k_scale_tensor,
|
| 374 |
+
k_scale_smem_layout,
|
| 375 |
+
self.mma_tiler,
|
| 376 |
+
tiled_mma,
|
| 377 |
+
cluster_layout_vmnk.shape,
|
| 378 |
+
internal_type=cutlass.Int16,
|
| 379 |
+
)
|
| 380 |
+
grid_q_tiles = cute.ceil_div(m, self.cta_tile_shape_mnk[0])
|
| 381 |
+
grid_k_groups = cute.ceil_div(max_k_tiles, self.k_tiles_per_cta)
|
| 382 |
+
if const_expr(self.compact_schedule):
|
| 383 |
+
grid_x = compact_task_count
|
| 384 |
+
else:
|
| 385 |
+
grid_x = grid_q_tiles * grid_k_groups
|
| 386 |
+
self.kernel(
|
| 387 |
+
tiled_mma,
|
| 388 |
+
tma_q,
|
| 389 |
+
tma_qs,
|
| 390 |
+
tma_k,
|
| 391 |
+
tma_ks,
|
| 392 |
+
q_scale_tensor,
|
| 393 |
+
k_scale_tensor,
|
| 394 |
+
scores_tensor,
|
| 395 |
+
kv_indices_tensor,
|
| 396 |
+
cu_q_tensor,
|
| 397 |
+
cu_k_tensor,
|
| 398 |
+
cu_page_offsets_tensor,
|
| 399 |
+
qo_offset_tensor,
|
| 400 |
+
q_smem_layout,
|
| 401 |
+
k_smem_layout,
|
| 402 |
+
q_scale_smem_layout,
|
| 403 |
+
k_scale_smem_layout,
|
| 404 |
+
heads_q,
|
| 405 |
+
heads_k,
|
| 406 |
+
has_qo_offset,
|
| 407 |
+
max_k_tiles,
|
| 408 |
+
grid_k_groups,
|
| 409 |
+
).launch(
|
| 410 |
+
grid=(grid_x, batch * heads_q, 1),
|
| 411 |
+
block=[self.threads_per_cta, 1, 1],
|
| 412 |
+
cluster=(1, 1, 1),
|
| 413 |
+
stream=stream,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
@cute.jit
|
| 417 |
+
def _group_has_visible(
|
| 418 |
+
self,
|
| 419 |
+
q_tile_start: Int32,
|
| 420 |
+
q_tile_last: Int32,
|
| 421 |
+
q_len: Int32,
|
| 422 |
+
group_first_ktile: Int32,
|
| 423 |
+
batch_k_tiles: Int32,
|
| 424 |
+
causal_offset: Int32,
|
| 425 |
+
):
|
| 426 |
+
visible = q_tile_start < q_len and group_first_ktile < batch_k_tiles
|
| 427 |
+
if const_expr(self.is_causal):
|
| 428 |
+
visible = visible and group_first_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset
|
| 429 |
+
return visible
|
| 430 |
+
|
| 431 |
+
@cute.jit
|
| 432 |
+
def _tile_has_visible(
|
| 433 |
+
self,
|
| 434 |
+
q_tile_start: Int32,
|
| 435 |
+
q_tile_last: Int32,
|
| 436 |
+
q_len: Int32,
|
| 437 |
+
ktile: Int32,
|
| 438 |
+
batch_k_tiles: Int32,
|
| 439 |
+
causal_offset: Int32,
|
| 440 |
+
):
|
| 441 |
+
visible = q_tile_start < q_len and ktile < batch_k_tiles
|
| 442 |
+
if const_expr(self.is_causal):
|
| 443 |
+
visible = visible and ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset
|
| 444 |
+
return visible
|
| 445 |
+
|
| 446 |
+
@cute.jit
|
| 447 |
+
def _tile_mask_free(self, q_tile_start: Int32, ktile: Int32, causal_offset: Int32):
|
| 448 |
+
if const_expr(self.is_causal):
|
| 449 |
+
return ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) <= q_tile_start + causal_offset
|
| 450 |
+
return True
|
| 451 |
+
|
| 452 |
+
@cute.jit
|
| 453 |
+
def _full_tile_coord_visible(
|
| 454 |
+
self,
|
| 455 |
+
coord_m: Int32,
|
| 456 |
+
target_m: Int32,
|
| 457 |
+
q_local: Int32,
|
| 458 |
+
k_local: Int32,
|
| 459 |
+
causal_offset: Int32,
|
| 460 |
+
):
|
| 461 |
+
visible = coord_m == target_m
|
| 462 |
+
if const_expr(self.is_causal):
|
| 463 |
+
visible = visible and k_local <= q_local + causal_offset
|
| 464 |
+
return visible
|
| 465 |
+
|
| 466 |
+
@cute.jit
|
| 467 |
+
def _partial_tile_coord_visible(
|
| 468 |
+
self,
|
| 469 |
+
coord_m: Int32,
|
| 470 |
+
target_m: Int32,
|
| 471 |
+
q_local: Int32,
|
| 472 |
+
k_local: Int32,
|
| 473 |
+
q_len: Int32,
|
| 474 |
+
k_len: Int32,
|
| 475 |
+
causal_offset: Int32,
|
| 476 |
+
):
|
| 477 |
+
visible = coord_m == target_m and q_local < q_len and k_local < k_len
|
| 478 |
+
if const_expr(self.is_causal):
|
| 479 |
+
visible = visible and k_local <= q_local + causal_offset
|
| 480 |
+
return visible
|
| 481 |
+
|
| 482 |
+
@cute.kernel
|
| 483 |
+
def kernel(
|
| 484 |
+
self,
|
| 485 |
+
tiled_mma: cute.TiledMma,
|
| 486 |
+
tma_q: cpasync.TmaInfo,
|
| 487 |
+
tma_qs: cpasync.TmaInfo,
|
| 488 |
+
tma_k: cpasync.TmaInfo,
|
| 489 |
+
tma_ks: cpasync.TmaInfo,
|
| 490 |
+
mQS: cute.Tensor,
|
| 491 |
+
mKS: cute.Tensor,
|
| 492 |
+
mScores: cute.Tensor,
|
| 493 |
+
mKvIndices: cute.Tensor,
|
| 494 |
+
mCuQ: cute.Tensor,
|
| 495 |
+
mCuK: cute.Tensor,
|
| 496 |
+
mCuPages: cute.Tensor,
|
| 497 |
+
mQoOffset: cute.Tensor,
|
| 498 |
+
q_smem_layout: cute.ComposedLayout,
|
| 499 |
+
k_smem_layout: cute.ComposedLayout,
|
| 500 |
+
q_scale_smem_layout: cute.Layout,
|
| 501 |
+
k_scale_smem_layout: cute.Layout,
|
| 502 |
+
heads_q: Int32,
|
| 503 |
+
heads_k: Int32,
|
| 504 |
+
has_qo_offset: Int32,
|
| 505 |
+
max_k_tiles: Int32,
|
| 506 |
+
k_group_count: Int32,
|
| 507 |
+
):
|
| 508 |
+
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
| 509 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 510 |
+
lane_idx = cute.arch.lane_idx()
|
| 511 |
+
epi_tidx = tidx % Int32(self.epi_threads_per_cta)
|
| 512 |
+
epi_warpgroup_idx = warp_idx // Int32(self.epi_warps_per_group)
|
| 513 |
+
task_idx, q_l, _ = cute.arch.block_idx()
|
| 514 |
+
batch_idx = q_l // heads_q
|
| 515 |
+
hq = q_l - batch_idx * heads_q
|
| 516 |
+
hk = hq // (heads_q // heads_k)
|
| 517 |
+
q_begin = mCuQ[batch_idx]
|
| 518 |
+
q_end = mCuQ[batch_idx + 1]
|
| 519 |
+
k_begin = mCuK[batch_idx]
|
| 520 |
+
k_end = mCuK[batch_idx + 1]
|
| 521 |
+
q_len = q_end - q_begin
|
| 522 |
+
k_len = k_end - k_begin
|
| 523 |
+
page_begin = mCuPages[batch_idx]
|
| 524 |
+
batch_k_tiles = (k_len + Int32(_PAGE_SIZE - 1)) // Int32(_PAGE_SIZE)
|
| 525 |
+
causal_offset = Int32(0)
|
| 526 |
+
if const_expr(self.is_causal):
|
| 527 |
+
causal_offset = k_len - q_len
|
| 528 |
+
if has_qo_offset != 0:
|
| 529 |
+
causal_offset = mQoOffset[batch_idx]
|
| 530 |
+
task_valid = True
|
| 531 |
+
q_tile_idx = Int32(0)
|
| 532 |
+
ktile_group = Int32(0)
|
| 533 |
+
if const_expr(self.compact_schedule):
|
| 534 |
+
remaining = task_idx
|
| 535 |
+
q_tile_count = (q_len + Int32(self.cta_tile_shape_mnk[0] - 1)) // Int32(self.cta_tile_shape_mnk[0])
|
| 536 |
+
batch_k_group_count = (batch_k_tiles + Int32(self.k_tiles_per_cta - 1)) // Int32(self.k_tiles_per_cta)
|
| 537 |
+
q_scan = Int32(0)
|
| 538 |
+
task_valid = False
|
| 539 |
+
while q_scan < q_tile_count and not task_valid:
|
| 540 |
+
q_scan_start = q_scan * Int32(self.cta_tile_shape_mnk[0])
|
| 541 |
+
q_scan_last = q_scan_start + Int32(self.cta_tile_shape_mnk[0] - 1)
|
| 542 |
+
if q_scan_last >= q_len:
|
| 543 |
+
q_scan_last = q_len - Int32(1)
|
| 544 |
+
visible_limit = q_scan_last + causal_offset
|
| 545 |
+
visible_group_count = Int32(0)
|
| 546 |
+
if visible_limit >= Int32(0):
|
| 547 |
+
visible_group_count = visible_limit // Int32(self.k_tiles_per_cta * _BLOCK_K) + Int32(1)
|
| 548 |
+
if visible_group_count > batch_k_group_count:
|
| 549 |
+
visible_group_count = batch_k_group_count
|
| 550 |
+
task_valid = remaining < visible_group_count
|
| 551 |
+
if not task_valid:
|
| 552 |
+
remaining -= visible_group_count
|
| 553 |
+
q_scan += Int32(1)
|
| 554 |
+
if task_valid:
|
| 555 |
+
q_tile_idx = q_scan
|
| 556 |
+
ktile_group = remaining
|
| 557 |
+
else:
|
| 558 |
+
q_len = Int32(0)
|
| 559 |
+
k_len = Int32(0)
|
| 560 |
+
else:
|
| 561 |
+
q_tile_idx = task_idx // k_group_count
|
| 562 |
+
ktile_group = task_idx - q_tile_idx * k_group_count
|
| 563 |
+
q_tile_start = q_tile_idx * Int32(self.cta_tile_shape_mnk[0])
|
| 564 |
+
q_tile_last = q_tile_start + Int32(self.cta_tile_shape_mnk[0] - 1)
|
| 565 |
+
if q_tile_last >= q_len:
|
| 566 |
+
q_tile_last = q_len - Int32(1)
|
| 567 |
+
q_tile_full = q_tile_start + Int32(self.cta_tile_shape_mnk[0] - 1) < q_len
|
| 568 |
+
q_tile_global_start = q_begin + q_tile_start
|
| 569 |
+
q_scale_tma_safe = q_tile_global_start == (q_tile_global_start // Int32(128)) * Int32(128)
|
| 570 |
+
group_first_ktile = ktile_group * Int32(self.k_tiles_per_cta)
|
| 571 |
+
group_has_visible = self._group_has_visible(
|
| 572 |
+
q_tile_start,
|
| 573 |
+
q_tile_last,
|
| 574 |
+
q_len,
|
| 575 |
+
group_first_ktile,
|
| 576 |
+
batch_k_tiles,
|
| 577 |
+
causal_offset,
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
@cute.struct
|
| 581 |
+
class SharedStorage:
|
| 582 |
+
acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
| 583 |
+
q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
|
| 584 |
+
qs_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
|
| 585 |
+
k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
| 586 |
+
tmem_holding_buf: cutlass.Int32
|
| 587 |
+
|
| 588 |
+
smem = utils.SmemAllocator()
|
| 589 |
+
storage = smem.allocate(SharedStorage)
|
| 590 |
+
sQ_public = smem.allocate_tensor(_AB_DTYPE, q_smem_layout.outer, 128, swizzle=q_smem_layout.inner)
|
| 591 |
+
sK_public = smem.allocate_tensor(_AB_DTYPE, k_smem_layout.outer, 128, swizzle=k_smem_layout.inner)
|
| 592 |
+
sQS_public = smem.allocate_tensor(self.sf_dtype, q_scale_smem_layout, 128)
|
| 593 |
+
sKS_public = smem.allocate_tensor(self.sf_dtype, k_scale_smem_layout, 128)
|
| 594 |
+
mQ_tma = tma_q.tma_tensor
|
| 595 |
+
mQS_tma = tma_qs.tma_tensor
|
| 596 |
+
mK_tma = tma_k.tma_tensor
|
| 597 |
+
mKS_tma = tma_ks.tma_tensor
|
| 598 |
+
thr_mma = tiled_mma.get_slice(0)
|
| 599 |
+
tCsQ = thr_mma.partition_A(sQ_public)
|
| 600 |
+
tCsK = thr_mma.partition_B(sK_public)
|
| 601 |
+
mQ_tma_cur = cute.domain_offset((q_begin, 0, 0), mQ_tma)
|
| 602 |
+
gQ_tma = cute.local_tile(
|
| 603 |
+
mQ_tma_cur,
|
| 604 |
+
cute.slice_(self.mma_tiler, (None, 0, None)),
|
| 605 |
+
(None, None, None),
|
| 606 |
+
)
|
| 607 |
+
tCgQ_tma = thr_mma.partition_A(gQ_tma)
|
| 608 |
+
tQsQ_tma, tQgQ_tma = cpasync.tma_partition(
|
| 609 |
+
tma_q.atom,
|
| 610 |
+
0,
|
| 611 |
+
cute.make_layout(1),
|
| 612 |
+
cute.group_modes(sQ_public, 0, 3),
|
| 613 |
+
cute.group_modes(tCgQ_tma, 0, 3),
|
| 614 |
+
)
|
| 615 |
+
if const_expr(self.preordered_q_scale_tma):
|
| 616 |
+
mQS_tma_cur = cute.domain_offset((q_begin, 0, 0), mQS_tma)
|
| 617 |
+
gQS_tma = cute.local_tile(
|
| 618 |
+
mQS_tma_cur,
|
| 619 |
+
cute.slice_(self.mma_tiler, (None, 0, None)),
|
| 620 |
+
(None, None, None),
|
| 621 |
+
)
|
| 622 |
+
tCgQS_tma = thr_mma.partition_A(gQS_tma)
|
| 623 |
+
tQsQS_tma, tQgQS_tma = cpasync.tma_partition(
|
| 624 |
+
tma_qs.atom,
|
| 625 |
+
0,
|
| 626 |
+
cute.make_layout(1),
|
| 627 |
+
cute.group_modes(sQS_public, 0, 3),
|
| 628 |
+
cute.group_modes(tCgQS_tma, 0, 3),
|
| 629 |
+
)
|
| 630 |
+
tQsQS_tma = cute.filter_zeros(tQsQS_tma)
|
| 631 |
+
tQgQS_tma = cute.filter_zeros(tQgQS_tma)
|
| 632 |
+
gK_tma = cute.local_tile(
|
| 633 |
+
mK_tma,
|
| 634 |
+
cute.slice_(self.mma_tiler, (0, None, None)),
|
| 635 |
+
(None, None, None, None),
|
| 636 |
+
)
|
| 637 |
+
tCgK_tma = thr_mma.partition_B(gK_tma)
|
| 638 |
+
tKsK_tma, tKgK_tma = cpasync.tma_partition(
|
| 639 |
+
tma_k.atom,
|
| 640 |
+
0,
|
| 641 |
+
cute.make_layout(1),
|
| 642 |
+
cute.group_modes(sK_public, 0, 3),
|
| 643 |
+
cute.group_modes(tCgK_tma, 0, 3),
|
| 644 |
+
)
|
| 645 |
+
gKS_tma = cute.local_tile(
|
| 646 |
+
mKS_tma,
|
| 647 |
+
cute.slice_(self.mma_tiler, (0, None, None)),
|
| 648 |
+
(None, None, None),
|
| 649 |
+
)
|
| 650 |
+
tCgKS_tma = thr_mma.partition_B(gKS_tma)
|
| 651 |
+
tKsKS_tma, tKgKS_tma = cpasync.tma_partition(
|
| 652 |
+
tma_ks.atom,
|
| 653 |
+
0,
|
| 654 |
+
cute.make_layout(1),
|
| 655 |
+
cute.group_modes(sKS_public, 0, 3),
|
| 656 |
+
cute.group_modes(tCgKS_tma, 0, 3),
|
| 657 |
+
)
|
| 658 |
+
tKsKS_tma = cute.filter_zeros(tKsKS_tma)
|
| 659 |
+
tKgKS_tma = cute.filter_zeros(tKgKS_tma)
|
| 660 |
+
sQS = sQS_public
|
| 661 |
+
sKS = sKS_public
|
| 662 |
+
|
| 663 |
+
tCrQ = tiled_mma.make_fragment_A(sQ_public)
|
| 664 |
+
tCrK = tiled_mma.make_fragment_B(sK_public)
|
| 665 |
+
tCcC = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler[:2]))
|
| 666 |
+
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
|
| 667 |
+
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
|
| 668 |
+
|
| 669 |
+
tmem = utils.TmemAllocator(
|
| 670 |
+
storage.tmem_holding_buf.ptr,
|
| 671 |
+
barrier_for_retrieve=pipeline.NamedBarrier(
|
| 672 |
+
barrier_id=1,
|
| 673 |
+
num_threads=32 * (self.mma_warp_id + 1),
|
| 674 |
+
),
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
acc_pipeline = common_pipeline.PipelineUmmaAsync.create(
|
| 678 |
+
barrier_storage=storage.acc_mbar_ptr.data_ptr(),
|
| 679 |
+
num_stages=self.num_acc_stage,
|
| 680 |
+
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
| 681 |
+
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, self.epi_threads_per_cta),
|
| 682 |
+
defer_sync=True,
|
| 683 |
+
)
|
| 684 |
+
acc_producer, _ = acc_pipeline.make_participants()
|
| 685 |
+
q_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_q.smem_layout)
|
| 686 |
+
k_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_k.smem_layout)
|
| 687 |
+
if const_expr(self.preordered_q_scale_tma):
|
| 688 |
+
qs_tma_copy_bytes = cute.size_in_bytes(
|
| 689 |
+
self.sf_dtype,
|
| 690 |
+
cute.select(tma_qs.smem_layout, mode=[0, 1, 2]),
|
| 691 |
+
)
|
| 692 |
+
ks_tma_copy_bytes = cute.size_in_bytes(
|
| 693 |
+
self.sf_dtype,
|
| 694 |
+
cute.select(tma_ks.smem_layout, mode=[0, 1, 2]),
|
| 695 |
+
)
|
| 696 |
+
k_pair_tma_copy_bytes = k_tma_copy_bytes + ks_tma_copy_bytes
|
| 697 |
+
q_producer, q_consumer = pipeline.PipelineTmaAsync.create(
|
| 698 |
+
barrier_storage=storage.q_mbar_ptr.data_ptr(),
|
| 699 |
+
num_stages=1,
|
| 700 |
+
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
| 701 |
+
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
| 702 |
+
tx_count=q_tma_copy_bytes,
|
| 703 |
+
defer_sync=True,
|
| 704 |
+
).make_participants()
|
| 705 |
+
if const_expr(self.preordered_q_scale_tma):
|
| 706 |
+
qs_producer, qs_consumer = pipeline.PipelineTmaAsync.create(
|
| 707 |
+
barrier_storage=storage.qs_mbar_ptr.data_ptr(),
|
| 708 |
+
num_stages=1,
|
| 709 |
+
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
| 710 |
+
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
| 711 |
+
tx_count=qs_tma_copy_bytes,
|
| 712 |
+
defer_sync=True,
|
| 713 |
+
).make_participants()
|
| 714 |
+
k_producer, k_consumer = pipeline.PipelineTmaAsync.create(
|
| 715 |
+
barrier_storage=storage.k_mbar_ptr.data_ptr(),
|
| 716 |
+
num_stages=self.num_ab_stage,
|
| 717 |
+
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
| 718 |
+
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
| 719 |
+
tx_count=k_pair_tma_copy_bytes,
|
| 720 |
+
defer_sync=True,
|
| 721 |
+
).make_participants()
|
| 722 |
+
cute.arch.mbarrier_init_fence()
|
| 723 |
+
cute.arch.barrier()
|
| 724 |
+
if warp_idx == self.load_warp_id:
|
| 725 |
+
if group_has_visible:
|
| 726 |
+
q_empty = q_producer.acquire_and_advance()
|
| 727 |
+
if const_expr(self.preordered_q_scale_tma):
|
| 728 |
+
if q_scale_tma_safe:
|
| 729 |
+
qs_empty = qs_producer.acquire_and_advance()
|
| 730 |
+
cute.copy(
|
| 731 |
+
tma_qs.atom,
|
| 732 |
+
tQgQS_tma[(None, q_tile_idx, 0, hq)],
|
| 733 |
+
tQsQS_tma[(None, qs_empty.index)],
|
| 734 |
+
tma_bar_ptr=qs_empty.barrier,
|
| 735 |
+
)
|
| 736 |
+
qs_empty.commit()
|
| 737 |
+
else:
|
| 738 |
+
for row_base in cutlass.range(0, Int32(self.cta_tile_shape_mnk[0]), 32):
|
| 739 |
+
row = row_base + lane_idx
|
| 740 |
+
q_local = q_tile_start + row
|
| 741 |
+
row_major = row // Int32(32)
|
| 742 |
+
row_atom = row - row_major * Int32(32)
|
| 743 |
+
for group in cutlass.range_constexpr(self.scale_groups):
|
| 744 |
+
group_i = Int32(group)
|
| 745 |
+
mma_k = group_i // Int32(_MMA_INST_SHAPE_K // self.sf_vec_size)
|
| 746 |
+
group_in_mma_k = group_i - mma_k * Int32(_MMA_INST_SHAPE_K // self.sf_vec_size)
|
| 747 |
+
sf_coord = ((((row_atom, row_major), Int32(0)), (Int32(0), group_in_mma_k)), Int32(0), mma_k, Int32(0))
|
| 748 |
+
q_scale_row = q_begin + q_local
|
| 749 |
+
if q_local >= q_len:
|
| 750 |
+
q_scale_row = q_begin
|
| 751 |
+
sQS[sf_coord] = mQS[q_scale_row, group_i * Int32(self.sf_vec_size), hq]
|
| 752 |
+
else:
|
| 753 |
+
for row_base in cutlass.range(0, Int32(self.cta_tile_shape_mnk[0]), 32):
|
| 754 |
+
row = row_base + lane_idx
|
| 755 |
+
q_local = q_tile_start + row
|
| 756 |
+
row_major = row // Int32(32)
|
| 757 |
+
row_atom = row - row_major * Int32(32)
|
| 758 |
+
for group in cutlass.range_constexpr(self.scale_groups):
|
| 759 |
+
group_i = Int32(group)
|
| 760 |
+
mma_k = group_i // Int32(_MMA_INST_SHAPE_K // self.sf_vec_size)
|
| 761 |
+
group_in_mma_k = group_i - mma_k * Int32(_MMA_INST_SHAPE_K // self.sf_vec_size)
|
| 762 |
+
sf_coord = ((((row_atom, row_major), Int32(0)), (Int32(0), group_in_mma_k)), Int32(0), mma_k, Int32(0))
|
| 763 |
+
q_scale_row = q_begin + q_local
|
| 764 |
+
if q_local >= q_len:
|
| 765 |
+
q_scale_row = q_begin
|
| 766 |
+
sQS[sf_coord] = mQS[q_scale_row, group_i * Int32(self.sf_vec_size), hq]
|
| 767 |
+
cute.copy(
|
| 768 |
+
tma_q.atom,
|
| 769 |
+
tQgQ_tma[(None, q_tile_idx, 0, hq)],
|
| 770 |
+
tQsQ_tma[(None, q_empty.index)],
|
| 771 |
+
tma_bar_ptr=q_empty.barrier,
|
| 772 |
+
)
|
| 773 |
+
q_empty.commit()
|
| 774 |
+
|
| 775 |
+
if warp_idx == self.mma_warp_id:
|
| 776 |
+
tmem_pool = tmem.reserve(self.num_tmem_alloc_cols)
|
| 777 |
+
tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32)
|
| 778 |
+
# Move block scales into TMEM and issue one FP4 GEMM per visible K tile.
|
| 779 |
+
tCtQS_layout = blockscaled_utils.make_tmem_layout_sfa(
|
| 780 |
+
tiled_mma,
|
| 781 |
+
self.mma_tiler,
|
| 782 |
+
self.sf_vec_size,
|
| 783 |
+
cute.slice_(q_scale_smem_layout, (None, None, None, 0)),
|
| 784 |
+
)
|
| 785 |
+
tCtKS_layout = blockscaled_utils.make_tmem_layout_sfb(
|
| 786 |
+
tiled_mma,
|
| 787 |
+
self.mma_tiler,
|
| 788 |
+
self.sf_vec_size,
|
| 789 |
+
cute.slice_(k_scale_smem_layout, (None, None, None, 0)),
|
| 790 |
+
)
|
| 791 |
+
tCtQS = tmem_pool.allocate_tensor(tCtQS_layout, self.sf_dtype)
|
| 792 |
+
tCtKS = tmem_pool.allocate_tensor(tCtKS_layout, self.sf_dtype)
|
| 793 |
+
copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), self.sf_dtype)
|
| 794 |
+
tCsQS_compact = cute.filter_zeros(sQS)
|
| 795 |
+
tCtQS_compact = cute.filter_zeros(tCtQS)
|
| 796 |
+
tiled_copy_s2t_qs = tcgen05.make_s2t_copy(copy_atom_s2t, tCtQS_compact)
|
| 797 |
+
thr_copy_s2t_qs = tiled_copy_s2t_qs.get_slice(0)
|
| 798 |
+
tCsQS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
|
| 799 |
+
tiled_copy_s2t_qs,
|
| 800 |
+
thr_copy_s2t_qs.partition_S(tCsQS_compact),
|
| 801 |
+
)
|
| 802 |
+
tCtQS_compact_s2t = thr_copy_s2t_qs.partition_D(tCtQS_compact)
|
| 803 |
+
tCsKS_compact = cute.filter_zeros(sKS)
|
| 804 |
+
tCtKS_compact = cute.filter_zeros(tCtKS)
|
| 805 |
+
tiled_copy_s2t_ks = tcgen05.make_s2t_copy(copy_atom_s2t, tCtKS_compact)
|
| 806 |
+
thr_copy_s2t_ks = tiled_copy_s2t_ks.get_slice(0)
|
| 807 |
+
tCsKS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
|
| 808 |
+
tiled_copy_s2t_ks,
|
| 809 |
+
thr_copy_s2t_ks.partition_S(tCsKS_compact),
|
| 810 |
+
)
|
| 811 |
+
tCtKS_compact_s2t = thr_copy_s2t_ks.partition_D(tCtKS_compact)
|
| 812 |
+
if group_has_visible:
|
| 813 |
+
q_full = q_consumer.wait_and_advance()
|
| 814 |
+
if const_expr(self.preordered_q_scale_tma):
|
| 815 |
+
if q_scale_tma_safe:
|
| 816 |
+
qs_full = qs_consumer.wait_and_advance()
|
| 817 |
+
qs_full.release()
|
| 818 |
+
q_full.release()
|
| 819 |
+
cute.copy(tiled_copy_s2t_qs, tCsQS_compact_s2t[(None, None, None, None, 0)], tCtQS_compact_s2t)
|
| 820 |
+
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
| 821 |
+
q_tile_crd = (None, None, None, 0)
|
| 822 |
+
if const_expr(self.is_causal):
|
| 823 |
+
causal_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles
|
| 824 |
+
causal_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1)
|
| 825 |
+
causal_group_full = causal_group_full and causal_group_last_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset
|
| 826 |
+
ktile = Int32(0)
|
| 827 |
+
if causal_group_full:
|
| 828 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 829 |
+
k_pair_full = k_consumer.wait_and_advance()
|
| 830 |
+
acc_empty = acc_producer.acquire_and_advance()
|
| 831 |
+
cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t)
|
| 832 |
+
k_tile_crd = (None, None, None, k_pair_full.index)
|
| 833 |
+
tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)]
|
| 834 |
+
cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
|
| 835 |
+
acc_empty.commit()
|
| 836 |
+
k_pair_full.release()
|
| 837 |
+
else:
|
| 838 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 839 |
+
ktile = group_first_ktile + Int32(ktile_inner)
|
| 840 |
+
if ktile < max_k_tiles:
|
| 841 |
+
tile_has_visible = self._tile_has_visible(
|
| 842 |
+
q_tile_start,
|
| 843 |
+
q_tile_last,
|
| 844 |
+
q_len,
|
| 845 |
+
ktile,
|
| 846 |
+
batch_k_tiles,
|
| 847 |
+
causal_offset,
|
| 848 |
+
)
|
| 849 |
+
if tile_has_visible:
|
| 850 |
+
k_pair_full = k_consumer.wait_and_advance()
|
| 851 |
+
acc_empty = acc_producer.acquire_and_advance()
|
| 852 |
+
cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t)
|
| 853 |
+
k_tile_crd = (None, None, None, k_pair_full.index)
|
| 854 |
+
tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)]
|
| 855 |
+
cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
|
| 856 |
+
acc_empty.commit()
|
| 857 |
+
k_pair_full.release()
|
| 858 |
+
else:
|
| 859 |
+
k_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles
|
| 860 |
+
ktile = Int32(0)
|
| 861 |
+
if k_group_full:
|
| 862 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 863 |
+
k_pair_full = k_consumer.wait_and_advance()
|
| 864 |
+
acc_empty = acc_producer.acquire_and_advance()
|
| 865 |
+
cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t)
|
| 866 |
+
k_tile_crd = (None, None, None, k_pair_full.index)
|
| 867 |
+
tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)]
|
| 868 |
+
cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
|
| 869 |
+
acc_empty.commit()
|
| 870 |
+
k_pair_full.release()
|
| 871 |
+
else:
|
| 872 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 873 |
+
ktile = group_first_ktile + Int32(ktile_inner)
|
| 874 |
+
if ktile < batch_k_tiles:
|
| 875 |
+
k_pair_full = k_consumer.wait_and_advance()
|
| 876 |
+
acc_empty = acc_producer.acquire_and_advance()
|
| 877 |
+
cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t)
|
| 878 |
+
k_tile_crd = (None, None, None, k_pair_full.index)
|
| 879 |
+
tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)]
|
| 880 |
+
cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
|
| 881 |
+
acc_empty.commit()
|
| 882 |
+
k_pair_full.release()
|
| 883 |
+
acc_producer.tail()
|
| 884 |
+
|
| 885 |
+
if warp_idx == self.load_warp_id:
|
| 886 |
+
if group_has_visible:
|
| 887 |
+
load_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles
|
| 888 |
+
if const_expr(self.is_causal):
|
| 889 |
+
load_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1)
|
| 890 |
+
load_group_full = load_group_full and load_group_last_ktile * Int32(_BLOCK_K) <= q_tile_last + causal_offset
|
| 891 |
+
ktile = Int32(0)
|
| 892 |
+
if load_group_full:
|
| 893 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 894 |
+
ktile = group_first_ktile + Int32(ktile_inner)
|
| 895 |
+
k_pair_empty = k_producer.acquire_and_advance()
|
| 896 |
+
physical_page = mKvIndices[page_begin + ktile]
|
| 897 |
+
cute.copy(
|
| 898 |
+
tma_k.atom,
|
| 899 |
+
tKgK_tma[(None, 0, 0, hk, physical_page)],
|
| 900 |
+
tKsK_tma[(None, k_pair_empty.index)],
|
| 901 |
+
tma_bar_ptr=k_pair_empty.barrier,
|
| 902 |
+
)
|
| 903 |
+
scale_l = physical_page * heads_k + hk
|
| 904 |
+
cute.copy(
|
| 905 |
+
tma_ks.atom,
|
| 906 |
+
tKgKS_tma[(None, 0, 0, scale_l)],
|
| 907 |
+
tKsKS_tma[(None, k_pair_empty.index)],
|
| 908 |
+
tma_bar_ptr=k_pair_empty.barrier,
|
| 909 |
+
)
|
| 910 |
+
k_pair_empty.commit()
|
| 911 |
+
else:
|
| 912 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 913 |
+
ktile = group_first_ktile + Int32(ktile_inner)
|
| 914 |
+
if ktile < max_k_tiles:
|
| 915 |
+
tile_has_visible = self._tile_has_visible(
|
| 916 |
+
q_tile_start,
|
| 917 |
+
q_tile_last,
|
| 918 |
+
q_len,
|
| 919 |
+
ktile,
|
| 920 |
+
batch_k_tiles,
|
| 921 |
+
causal_offset,
|
| 922 |
+
)
|
| 923 |
+
if tile_has_visible:
|
| 924 |
+
k_pair_empty = k_producer.acquire_and_advance()
|
| 925 |
+
physical_page = mKvIndices[page_begin + ktile]
|
| 926 |
+
cute.copy(
|
| 927 |
+
tma_k.atom,
|
| 928 |
+
tKgK_tma[(None, 0, 0, hk, physical_page)],
|
| 929 |
+
tKsK_tma[(None, k_pair_empty.index)],
|
| 930 |
+
tma_bar_ptr=k_pair_empty.barrier,
|
| 931 |
+
)
|
| 932 |
+
scale_l = physical_page * heads_k + hk
|
| 933 |
+
cute.copy(
|
| 934 |
+
tma_ks.atom,
|
| 935 |
+
tKgKS_tma[(None, 0, 0, scale_l)],
|
| 936 |
+
tKsKS_tma[(None, k_pair_empty.index)],
|
| 937 |
+
tma_bar_ptr=k_pair_empty.barrier,
|
| 938 |
+
)
|
| 939 |
+
k_pair_empty.commit()
|
| 940 |
+
k_producer.tail()
|
| 941 |
+
q_producer.tail()
|
| 942 |
+
if const_expr(self.preordered_q_scale_tma):
|
| 943 |
+
if q_scale_tma_safe:
|
| 944 |
+
qs_producer.tail()
|
| 945 |
+
|
| 946 |
+
if warp_idx < self.mma_warp_id:
|
| 947 |
+
tmem_pool = tmem.reserve(self.num_tmem_alloc_cols)
|
| 948 |
+
tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32)
|
| 949 |
+
# Load accumulators from TMEM, reduce per-row max, and store scores.
|
| 950 |
+
if const_expr(self.use_tmem_load_red):
|
| 951 |
+
copy_atom_t2r = cute.make_copy_atom(
|
| 952 |
+
tcgen05.LdRed32x32bOp(
|
| 953 |
+
tcgen05.Repetition.x128,
|
| 954 |
+
tcgen05.Pack.NONE,
|
| 955 |
+
tcgen05.TmemLoadRedOp.MAX,
|
| 956 |
+
),
|
| 957 |
+
Float32,
|
| 958 |
+
)
|
| 959 |
+
else:
|
| 960 |
+
copy_atom_t2r = cute.make_copy_atom(
|
| 961 |
+
tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE),
|
| 962 |
+
Float32,
|
| 963 |
+
)
|
| 964 |
+
tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None, None, 0)])
|
| 965 |
+
thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx)
|
| 966 |
+
tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc)
|
| 967 |
+
tTR_cC = thr_copy_t2r.partition_D(tCcC)
|
| 968 |
+
tTR_rAcc = cute.make_rmem_tensor(tTR_cC.shape, Float32)
|
| 969 |
+
if const_expr(self.use_tmem_load_red):
|
| 970 |
+
tTR_rRed = cute.make_rmem_tensor((1,), Float32)
|
| 971 |
+
q_local_store0 = q_tile_start + epi_tidx
|
| 972 |
+
q_global_store0 = q_begin + q_local_store0
|
| 973 |
+
if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
|
| 974 |
+
q_local_store1 = q_tile_start + epi_tidx + Int32(self.epi_threads_per_cta)
|
| 975 |
+
q_global_store1 = q_begin + q_local_store1
|
| 976 |
+
if group_has_visible:
|
| 977 |
+
visible_tile_count = Int32(0)
|
| 978 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 979 |
+
ktile = group_first_ktile + Int32(ktile_inner)
|
| 980 |
+
if ktile < max_k_tiles:
|
| 981 |
+
tile_has_visible = self._tile_has_visible(
|
| 982 |
+
q_tile_start,
|
| 983 |
+
q_tile_last,
|
| 984 |
+
q_len,
|
| 985 |
+
ktile,
|
| 986 |
+
batch_k_tiles,
|
| 987 |
+
causal_offset,
|
| 988 |
+
)
|
| 989 |
+
if tile_has_visible:
|
| 990 |
+
epilogue_owns_tile = epi_warpgroup_idx == Int32(
|
| 991 |
+
ktile_inner % self.num_epi_warpgroups
|
| 992 |
+
)
|
| 993 |
+
if epilogue_owns_tile:
|
| 994 |
+
acc_stage_index = visible_tile_count % Int32(self.num_acc_stage)
|
| 995 |
+
acc_stage_phase = (visible_tile_count // Int32(self.num_acc_stage)) % Int32(2)
|
| 996 |
+
tile_mask_free = self._tile_mask_free(q_tile_start, ktile, causal_offset)
|
| 997 |
+
k_tile_full = ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) < k_len
|
| 998 |
+
tile_full = q_tile_full and k_tile_full
|
| 999 |
+
acc_pipeline.consumer_wait_w_index_phase(acc_stage_index, acc_stage_phase)
|
| 1000 |
+
tTR_tAcc_stage = tTR_tAcc[(None, None, None, None, acc_stage_index)]
|
| 1001 |
+
if const_expr(self.use_tmem_load_red):
|
| 1002 |
+
cute.copy(tiled_copy_t2r, tTR_tAcc_stage, [tTR_rAcc, tTR_rRed])
|
| 1003 |
+
else:
|
| 1004 |
+
cute.copy(tiled_copy_t2r, tTR_tAcc_stage, tTR_rAcc)
|
| 1005 |
+
row_max0 = -Float32.inf
|
| 1006 |
+
row_max1 = -Float32.inf
|
| 1007 |
+
if tile_mask_free:
|
| 1008 |
+
if tile_full:
|
| 1009 |
+
if const_expr(not self.use_tmem_load_red or self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
|
| 1010 |
+
for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True):
|
| 1011 |
+
coord_m, _ = tTR_cC[i]
|
| 1012 |
+
if coord_m == epi_tidx:
|
| 1013 |
+
row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i])
|
| 1014 |
+
if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
|
| 1015 |
+
if coord_m == epi_tidx + Int32(self.epi_threads_per_cta):
|
| 1016 |
+
row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i])
|
| 1017 |
+
else:
|
| 1018 |
+
row_max0 = tTR_rRed[0]
|
| 1019 |
+
else:
|
| 1020 |
+
for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True):
|
| 1021 |
+
coord_m, coord_n = tTR_cC[i]
|
| 1022 |
+
q_local = q_tile_start + coord_m
|
| 1023 |
+
k_local = ktile * Int32(_BLOCK_K) + coord_n
|
| 1024 |
+
if coord_m == epi_tidx and q_local < q_len and k_local < k_len:
|
| 1025 |
+
row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i])
|
| 1026 |
+
if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
|
| 1027 |
+
if coord_m == epi_tidx + Int32(self.epi_threads_per_cta) and q_local < q_len and k_local < k_len:
|
| 1028 |
+
row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i])
|
| 1029 |
+
else:
|
| 1030 |
+
if tile_full:
|
| 1031 |
+
for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True):
|
| 1032 |
+
coord_m, coord_n = tTR_cC[i]
|
| 1033 |
+
q_local = q_tile_start + coord_m
|
| 1034 |
+
k_local = ktile * Int32(_BLOCK_K) + coord_n
|
| 1035 |
+
if self._full_tile_coord_visible(
|
| 1036 |
+
coord_m,
|
| 1037 |
+
epi_tidx,
|
| 1038 |
+
q_local,
|
| 1039 |
+
k_local,
|
| 1040 |
+
causal_offset,
|
| 1041 |
+
):
|
| 1042 |
+
row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i])
|
| 1043 |
+
if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
|
| 1044 |
+
if self._full_tile_coord_visible(
|
| 1045 |
+
coord_m,
|
| 1046 |
+
epi_tidx + Int32(self.epi_threads_per_cta),
|
| 1047 |
+
q_local,
|
| 1048 |
+
k_local,
|
| 1049 |
+
causal_offset,
|
| 1050 |
+
):
|
| 1051 |
+
row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i])
|
| 1052 |
+
else:
|
| 1053 |
+
for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True):
|
| 1054 |
+
coord_m, coord_n = tTR_cC[i]
|
| 1055 |
+
q_local = q_tile_start + coord_m
|
| 1056 |
+
k_local = ktile * Int32(_BLOCK_K) + coord_n
|
| 1057 |
+
if self._partial_tile_coord_visible(
|
| 1058 |
+
coord_m,
|
| 1059 |
+
epi_tidx,
|
| 1060 |
+
q_local,
|
| 1061 |
+
k_local,
|
| 1062 |
+
q_len,
|
| 1063 |
+
k_len,
|
| 1064 |
+
causal_offset,
|
| 1065 |
+
):
|
| 1066 |
+
row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i])
|
| 1067 |
+
if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
|
| 1068 |
+
if self._partial_tile_coord_visible(
|
| 1069 |
+
coord_m,
|
| 1070 |
+
epi_tidx + Int32(self.epi_threads_per_cta),
|
| 1071 |
+
q_local,
|
| 1072 |
+
k_local,
|
| 1073 |
+
q_len,
|
| 1074 |
+
k_len,
|
| 1075 |
+
causal_offset,
|
| 1076 |
+
):
|
| 1077 |
+
row_max1 = cute.arch.fmax(row_max1, tTR_rAcc[i])
|
| 1078 |
+
if q_tile_full:
|
| 1079 |
+
mScores[hq, ktile, q_global_store0] = row_max0
|
| 1080 |
+
elif q_local_store0 < q_len:
|
| 1081 |
+
mScores[hq, ktile, q_global_store0] = row_max0
|
| 1082 |
+
if const_expr(self.cta_tile_shape_mnk[0] > self.epi_threads_per_cta):
|
| 1083 |
+
if q_tile_full:
|
| 1084 |
+
mScores[hq, ktile, q_global_store1] = row_max1
|
| 1085 |
+
elif q_local_store1 < q_len:
|
| 1086 |
+
mScores[hq, ktile, q_global_store1] = row_max1
|
| 1087 |
+
cute.arch.fence_view_async_tmem_load()
|
| 1088 |
+
acc_pipeline.consumer_release_w_index(acc_stage_index)
|
| 1089 |
+
visible_tile_count += Int32(1)
|
| 1090 |
+
else:
|
| 1091 |
+
if const_expr(not self.compact_schedule):
|
| 1092 |
+
if epi_warpgroup_idx == Int32(0):
|
| 1093 |
+
if q_tile_full:
|
| 1094 |
+
mScores[hq, ktile, q_global_store0] = -Float32.inf
|
| 1095 |
+
elif q_local_store0 < q_len:
|
| 1096 |
+
mScores[hq, ktile, q_global_store0] = -Float32.inf
|
| 1097 |
+
else:
|
| 1098 |
+
if const_expr(not self.compact_schedule):
|
| 1099 |
+
if epi_warpgroup_idx == Int32(0):
|
| 1100 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 1101 |
+
ktile = ktile_group * Int32(self.k_tiles_per_cta) + Int32(ktile_inner)
|
| 1102 |
+
if ktile < max_k_tiles:
|
| 1103 |
+
if q_tile_full:
|
| 1104 |
+
mScores[hq, ktile, q_global_store0] = -Float32.inf
|
| 1105 |
+
elif q_local_store0 < q_len:
|
| 1106 |
+
mScores[hq, ktile, q_global_store0] = -Float32.inf
|
| 1107 |
+
cute.arch.barrier()
|
| 1108 |
+
tmem.free(tmem_pool.base_ptr)
|
| 1109 |
+
|
| 1110 |
+
class Fp4IndexerDecodeQPackSm100:
|
| 1111 |
+
"""Pack decode Q rows as ``[B * Hk, 128, 64]`` and pack Q scales to MMA storage."""
|
| 1112 |
+
|
| 1113 |
+
def __init__(self, *, fmt: str):
|
| 1114 |
+
spec = normalize_fp4_format(fmt)
|
| 1115 |
+
self.fmt = spec.name
|
| 1116 |
+
self.sf_dtype = spec.cutlass_scale_dtype
|
| 1117 |
+
self.scale_groups = spec.scale_groups
|
| 1118 |
+
self.threads_per_cta = 256
|
| 1119 |
+
|
| 1120 |
+
@cute.jit
|
| 1121 |
+
def __call__(
|
| 1122 |
+
self,
|
| 1123 |
+
q_ptr: cute.Pointer,
|
| 1124 |
+
q_scale_ptr: cute.Pointer,
|
| 1125 |
+
q_pack_ptr: cute.Pointer,
|
| 1126 |
+
q_scale_pack_ptr: cute.Pointer,
|
| 1127 |
+
cu_seqlens_q_ptr: cute.Pointer,
|
| 1128 |
+
problem_size: tuple,
|
| 1129 |
+
stream: cuda.CUstream,
|
| 1130 |
+
):
|
| 1131 |
+
total_q, heads_q, heads_k, batch = problem_size
|
| 1132 |
+
rest_q_m = cute.ceil_div(total_q, 128)
|
| 1133 |
+
rest_g = ceil_div(self.scale_groups, 4)
|
| 1134 |
+
q = cute.make_tensor(
|
| 1135 |
+
q_ptr,
|
| 1136 |
+
cute.make_layout(
|
| 1137 |
+
(total_q, heads_q, _FP4_PACKED_D_BYTES),
|
| 1138 |
+
stride=(heads_q * _FP4_PACKED_D_BYTES, _FP4_PACKED_D_BYTES, 1),
|
| 1139 |
+
),
|
| 1140 |
+
)
|
| 1141 |
+
q_scale = cute.make_tensor(
|
| 1142 |
+
q_scale_ptr,
|
| 1143 |
+
cute.make_layout(
|
| 1144 |
+
(heads_q, rest_q_m, rest_g, 32, 4, 4),
|
| 1145 |
+
stride=(512 * rest_q_m * rest_g, 512 * rest_g, 512, 16, 4, 1),
|
| 1146 |
+
),
|
| 1147 |
+
)
|
| 1148 |
+
q_pack_l = batch * heads_k
|
| 1149 |
+
q_pack = cute.make_tensor(
|
| 1150 |
+
q_pack_ptr,
|
| 1151 |
+
cute.make_layout(
|
| 1152 |
+
(q_pack_l, _PAGE_SIZE, _FP4_PACKED_D_BYTES),
|
| 1153 |
+
stride=(_PAGE_SIZE * _FP4_PACKED_D_BYTES, _FP4_PACKED_D_BYTES, 1),
|
| 1154 |
+
),
|
| 1155 |
+
)
|
| 1156 |
+
q_scale_pack = cute.make_tensor(
|
| 1157 |
+
q_scale_pack_ptr,
|
| 1158 |
+
cute.make_layout(
|
| 1159 |
+
(q_pack_l, 1, rest_g, 32, 4, 4),
|
| 1160 |
+
stride=(512 * rest_g, 512 * rest_g, 512, 16, 4, 1),
|
| 1161 |
+
),
|
| 1162 |
+
)
|
| 1163 |
+
cu_q = cute.make_tensor(cu_seqlens_q_ptr, cute.make_layout((batch + 1,), stride=(1,)))
|
| 1164 |
+
self.kernel(q, q_scale, q_pack, q_scale_pack, cu_q, heads_q, heads_k).launch(
|
| 1165 |
+
grid=(q_pack_l, 1, 1),
|
| 1166 |
+
block=[self.threads_per_cta, 1, 1],
|
| 1167 |
+
cluster=(1, 1, 1),
|
| 1168 |
+
stream=stream,
|
| 1169 |
+
)
|
| 1170 |
+
|
| 1171 |
+
@cute.kernel
|
| 1172 |
+
def kernel(
|
| 1173 |
+
self,
|
| 1174 |
+
mQ: cute.Tensor,
|
| 1175 |
+
mQS: cute.Tensor,
|
| 1176 |
+
mQPack: cute.Tensor,
|
| 1177 |
+
mQSPack: cute.Tensor,
|
| 1178 |
+
mCuQ: cute.Tensor,
|
| 1179 |
+
heads_q: Int32,
|
| 1180 |
+
heads_k: Int32,
|
| 1181 |
+
):
|
| 1182 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 1183 |
+
q_pack_l, _, _ = cute.arch.block_idx()
|
| 1184 |
+
batch_idx = q_pack_l // heads_k
|
| 1185 |
+
hk = q_pack_l - batch_idx * heads_k
|
| 1186 |
+
q_begin = mCuQ[batch_idx]
|
| 1187 |
+
q_end = mCuQ[batch_idx + 1]
|
| 1188 |
+
q_len = q_end - q_begin
|
| 1189 |
+
qhead_per_kv = heads_q // heads_k
|
| 1190 |
+
|
| 1191 |
+
linear = tidx
|
| 1192 |
+
while linear < Int32(_PAGE_SIZE * _FP4_PACKED_D_BYTES):
|
| 1193 |
+
row = linear // Int32(_FP4_PACKED_D_BYTES)
|
| 1194 |
+
byte = linear - row * Int32(_FP4_PACKED_D_BYTES)
|
| 1195 |
+
h_in_group = row // Int32(_DECODE_PACK_Q_LEN)
|
| 1196 |
+
q_local = row - h_in_group * Int32(_DECODE_PACK_Q_LEN)
|
| 1197 |
+
hq = hk * qhead_per_kv + h_in_group
|
| 1198 |
+
if q_local < q_len and h_in_group < qhead_per_kv:
|
| 1199 |
+
mQPack[q_pack_l, row, byte] = mQ[q_begin + q_local, hq, byte]
|
| 1200 |
+
else:
|
| 1201 |
+
mQPack[q_pack_l, row, byte] = cutlass.Uint8(0)
|
| 1202 |
+
linear += Int32(self.threads_per_cta)
|
| 1203 |
+
|
| 1204 |
+
scale_linear = tidx
|
| 1205 |
+
while scale_linear < Int32(_PAGE_SIZE * self.scale_groups):
|
| 1206 |
+
row = scale_linear // Int32(self.scale_groups)
|
| 1207 |
+
group = scale_linear - row * Int32(self.scale_groups)
|
| 1208 |
+
h_in_group = row // Int32(_DECODE_PACK_Q_LEN)
|
| 1209 |
+
q_local = row - h_in_group * Int32(_DECODE_PACK_Q_LEN)
|
| 1210 |
+
hq = hk * qhead_per_kv + h_in_group
|
| 1211 |
+
q_abs = q_begin + q_local
|
| 1212 |
+
if q_local >= q_len or h_in_group >= qhead_per_kv:
|
| 1213 |
+
q_abs = q_begin
|
| 1214 |
+
hq = hk * qhead_per_kv
|
| 1215 |
+
src_rest_m = q_abs // Int32(128)
|
| 1216 |
+
src_row = q_abs - src_rest_m * Int32(128)
|
| 1217 |
+
src_row_atom = src_row % Int32(32)
|
| 1218 |
+
src_row_major = src_row // Int32(32)
|
| 1219 |
+
dst_row_atom = row % Int32(32)
|
| 1220 |
+
dst_row_major = row // Int32(32)
|
| 1221 |
+
rest_g = group // Int32(4)
|
| 1222 |
+
group_in_rest = group - rest_g * Int32(4)
|
| 1223 |
+
mQSPack[q_pack_l, Int32(0), rest_g, dst_row_atom, dst_row_major, group_in_rest] = mQS[
|
| 1224 |
+
hq, src_rest_m, rest_g, src_row_atom, src_row_major, group_in_rest
|
| 1225 |
+
]
|
| 1226 |
+
scale_linear += Int32(self.threads_per_cta)
|
| 1227 |
+
|
| 1228 |
+
|
| 1229 |
+
class Fp4IndexerDecodePackedQSm100:
|
| 1230 |
+
"""Decode score kernel with M packed as ``qhead_per_kv * q_len == 128``."""
|
| 1231 |
+
|
| 1232 |
+
def __init__(self, *, fmt: str, causal: bool, compact_schedule: bool, use_tmem_load_red: bool = False):
|
| 1233 |
+
spec = normalize_fp4_format(fmt)
|
| 1234 |
+
self.fmt = spec.name
|
| 1235 |
+
self.is_causal = bool(causal)
|
| 1236 |
+
self.compact_schedule = bool(compact_schedule)
|
| 1237 |
+
self.use_tmem_load_red = bool(use_tmem_load_red)
|
| 1238 |
+
self.sf_vec_size = spec.sf_vec_size
|
| 1239 |
+
self.sf_dtype = spec.cutlass_scale_dtype
|
| 1240 |
+
self.use_nvfp4 = spec.name == "nvfp4"
|
| 1241 |
+
self.epi_threads_per_cta = 128
|
| 1242 |
+
self.epi_warps_per_group = 4
|
| 1243 |
+
self.num_epi_warpgroups = 2
|
| 1244 |
+
self.mma_warp_id = self.epi_warps_per_group * self.num_epi_warpgroups
|
| 1245 |
+
self.load_warp_id = self.mma_warp_id + 1
|
| 1246 |
+
self.threads_per_cta = 384
|
| 1247 |
+
self.num_tmem_alloc_cols = 512
|
| 1248 |
+
self.num_q_stage = 1
|
| 1249 |
+
self.num_acc_stage = 3
|
| 1250 |
+
self.num_ab_stage = 3
|
| 1251 |
+
self.k_tiles_per_cta = _DECODE_K_TILES_PER_CTA
|
| 1252 |
+
self.mma_tiler = (_MMA_TILER_MN[0], _MMA_TILER_MN[1], _MMA_INST_SHAPE_K * 2)
|
| 1253 |
+
self.cta_tile_shape_mnk = self.mma_tiler
|
| 1254 |
+
|
| 1255 |
+
@cute.jit
|
| 1256 |
+
def __call__(
|
| 1257 |
+
self,
|
| 1258 |
+
q_pack_ptr: cute.Pointer,
|
| 1259 |
+
k_ptr: cute.Pointer,
|
| 1260 |
+
q_scale_pack_ptr: cute.Pointer,
|
| 1261 |
+
k_scale_ptr: cute.Pointer,
|
| 1262 |
+
scores_ptr: cute.Pointer,
|
| 1263 |
+
kv_indices_ptr: cute.Pointer,
|
| 1264 |
+
cu_seqlens_q_ptr: cute.Pointer,
|
| 1265 |
+
cu_seqlens_k_ptr: cute.Pointer,
|
| 1266 |
+
cu_page_offsets_ptr: cute.Pointer,
|
| 1267 |
+
qo_offset_ptr: cute.Pointer,
|
| 1268 |
+
problem_size: tuple,
|
| 1269 |
+
stream: cuda.CUstream,
|
| 1270 |
+
):
|
| 1271 |
+
(
|
| 1272 |
+
_,
|
| 1273 |
+
_,
|
| 1274 |
+
_,
|
| 1275 |
+
_,
|
| 1276 |
+
lk,
|
| 1277 |
+
heads_q,
|
| 1278 |
+
heads_k,
|
| 1279 |
+
batch,
|
| 1280 |
+
max_k_tiles,
|
| 1281 |
+
total_q,
|
| 1282 |
+
has_qo_offset,
|
| 1283 |
+
) = problem_size
|
| 1284 |
+
page_count = lk // heads_k
|
| 1285 |
+
q_pack_l = batch * heads_k
|
| 1286 |
+
q_tma_tensor = cute.make_tensor(
|
| 1287 |
+
cute.recast_ptr(q_pack_ptr, dtype=_AB_DTYPE),
|
| 1288 |
+
cute.make_layout(
|
| 1289 |
+
(_PAGE_SIZE, _HEAD_DIM, q_pack_l),
|
| 1290 |
+
stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM),
|
| 1291 |
+
),
|
| 1292 |
+
)
|
| 1293 |
+
k_tma_tensor = cute.make_tensor(
|
| 1294 |
+
cute.recast_ptr(k_ptr, dtype=_AB_DTYPE),
|
| 1295 |
+
cute.make_layout(
|
| 1296 |
+
(_PAGE_SIZE, _HEAD_DIM, heads_k, page_count),
|
| 1297 |
+
stride=(_HEAD_DIM, 1, _PAGE_SIZE * _HEAD_DIM, heads_k * _PAGE_SIZE * _HEAD_DIM),
|
| 1298 |
+
),
|
| 1299 |
+
)
|
| 1300 |
+
q_scale_tensor = cute.make_tensor(
|
| 1301 |
+
q_scale_pack_ptr,
|
| 1302 |
+
blockscaled_utils.tile_atom_to_shape_SF(
|
| 1303 |
+
(_PAGE_SIZE, _HEAD_DIM, q_pack_l),
|
| 1304 |
+
self.sf_vec_size,
|
| 1305 |
+
),
|
| 1306 |
+
)
|
| 1307 |
+
k_scale_tensor = cute.make_tensor(
|
| 1308 |
+
k_scale_ptr,
|
| 1309 |
+
blockscaled_utils.tile_atom_to_shape_SF(
|
| 1310 |
+
(_PAGE_SIZE, _HEAD_DIM, page_count * heads_k),
|
| 1311 |
+
self.sf_vec_size,
|
| 1312 |
+
),
|
| 1313 |
+
)
|
| 1314 |
+
scores_tensor = cute.make_tensor(
|
| 1315 |
+
scores_ptr,
|
| 1316 |
+
cute.make_layout((heads_q, max_k_tiles, total_q), stride=(max_k_tiles * total_q, total_q, 1)),
|
| 1317 |
+
)
|
| 1318 |
+
kv_indices_tensor = cute.make_tensor(kv_indices_ptr, cute.make_layout((page_count,), stride=(1,)))
|
| 1319 |
+
cu_layout = cute.make_layout((batch + 1,), stride=(1,))
|
| 1320 |
+
cu_q_tensor = cute.make_tensor(cu_seqlens_q_ptr, cu_layout)
|
| 1321 |
+
cu_k_tensor = cute.make_tensor(cu_seqlens_k_ptr, cu_layout)
|
| 1322 |
+
cu_page_offsets_tensor = cute.make_tensor(cu_page_offsets_ptr, cu_layout)
|
| 1323 |
+
qo_offset_tensor = cute.make_tensor(qo_offset_ptr, cute.make_layout((batch,), stride=(1,)))
|
| 1324 |
+
|
| 1325 |
+
if const_expr(self.use_nvfp4):
|
| 1326 |
+
mma_op = tcgen05.MmaMXF4NVF4Op(
|
| 1327 |
+
self.sf_dtype,
|
| 1328 |
+
(*_MMA_TILER_MN, _MMA_INST_SHAPE_K),
|
| 1329 |
+
tcgen05.CtaGroup.ONE,
|
| 1330 |
+
tcgen05.OperandSource.SMEM,
|
| 1331 |
+
)
|
| 1332 |
+
else:
|
| 1333 |
+
mma_op = tcgen05.MmaMXF4Op(
|
| 1334 |
+
(*_MMA_TILER_MN, _MMA_INST_SHAPE_K),
|
| 1335 |
+
tcgen05.CtaGroup.ONE,
|
| 1336 |
+
tcgen05.OperandSource.SMEM,
|
| 1337 |
+
)
|
| 1338 |
+
tiled_mma = cute.make_tiled_mma(mma_op)
|
| 1339 |
+
q_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_q_stage)
|
| 1340 |
+
k_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, self.mma_tiler, _AB_DTYPE, self.num_ab_stage)
|
| 1341 |
+
q_scale_smem_layout = blockscaled_utils.make_smem_layout_sfa(
|
| 1342 |
+
tiled_mma,
|
| 1343 |
+
self.mma_tiler,
|
| 1344 |
+
self.sf_vec_size,
|
| 1345 |
+
self.num_q_stage,
|
| 1346 |
+
)
|
| 1347 |
+
k_scale_smem_layout = blockscaled_utils.make_smem_layout_sfb(
|
| 1348 |
+
tiled_mma,
|
| 1349 |
+
self.mma_tiler,
|
| 1350 |
+
self.sf_vec_size,
|
| 1351 |
+
self.num_ab_stage,
|
| 1352 |
+
)
|
| 1353 |
+
cluster_layout_vmnk = cute.make_layout((1, 1, 1, 1))
|
| 1354 |
+
tma_load_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)
|
| 1355 |
+
q_smem_layout_stage = cute.slice_(q_smem_layout, (None, None, None, 0))
|
| 1356 |
+
k_smem_layout_stage = cute.slice_(k_smem_layout, (None, None, None, 0))
|
| 1357 |
+
tma_q = cute.nvgpu.make_tiled_tma_atom_A(
|
| 1358 |
+
tma_load_op,
|
| 1359 |
+
q_tma_tensor,
|
| 1360 |
+
q_smem_layout_stage,
|
| 1361 |
+
self.mma_tiler,
|
| 1362 |
+
tiled_mma,
|
| 1363 |
+
cluster_layout_vmnk.shape,
|
| 1364 |
+
)
|
| 1365 |
+
tma_k = cute.nvgpu.make_tiled_tma_atom_B(
|
| 1366 |
+
tma_load_op,
|
| 1367 |
+
k_tma_tensor,
|
| 1368 |
+
k_smem_layout_stage,
|
| 1369 |
+
self.mma_tiler,
|
| 1370 |
+
tiled_mma,
|
| 1371 |
+
cluster_layout_vmnk.shape,
|
| 1372 |
+
)
|
| 1373 |
+
tma_qs = cute.nvgpu.make_tiled_tma_atom_A(
|
| 1374 |
+
tma_load_op,
|
| 1375 |
+
q_scale_tensor,
|
| 1376 |
+
q_scale_smem_layout,
|
| 1377 |
+
self.mma_tiler,
|
| 1378 |
+
tiled_mma,
|
| 1379 |
+
cluster_layout_vmnk.shape,
|
| 1380 |
+
internal_type=cutlass.Int16,
|
| 1381 |
+
)
|
| 1382 |
+
tma_ks = cute.nvgpu.make_tiled_tma_atom_B(
|
| 1383 |
+
tma_load_op,
|
| 1384 |
+
k_scale_tensor,
|
| 1385 |
+
k_scale_smem_layout,
|
| 1386 |
+
self.mma_tiler,
|
| 1387 |
+
tiled_mma,
|
| 1388 |
+
cluster_layout_vmnk.shape,
|
| 1389 |
+
internal_type=cutlass.Int16,
|
| 1390 |
+
)
|
| 1391 |
+
grid_k_groups = cute.ceil_div(max_k_tiles, self.k_tiles_per_cta)
|
| 1392 |
+
compact_k_groups = cute.ceil_div(page_count + batch * (self.k_tiles_per_cta - 1), self.k_tiles_per_cta)
|
| 1393 |
+
if const_expr(self.compact_schedule):
|
| 1394 |
+
grid = (compact_k_groups, heads_k, 1)
|
| 1395 |
+
else:
|
| 1396 |
+
grid = (grid_k_groups, batch * heads_k, 1)
|
| 1397 |
+
self.kernel(
|
| 1398 |
+
tiled_mma,
|
| 1399 |
+
tma_q,
|
| 1400 |
+
tma_qs,
|
| 1401 |
+
tma_k,
|
| 1402 |
+
tma_ks,
|
| 1403 |
+
scores_tensor,
|
| 1404 |
+
kv_indices_tensor,
|
| 1405 |
+
cu_q_tensor,
|
| 1406 |
+
cu_k_tensor,
|
| 1407 |
+
cu_page_offsets_tensor,
|
| 1408 |
+
qo_offset_tensor,
|
| 1409 |
+
q_smem_layout,
|
| 1410 |
+
k_smem_layout,
|
| 1411 |
+
q_scale_smem_layout,
|
| 1412 |
+
k_scale_smem_layout,
|
| 1413 |
+
heads_q,
|
| 1414 |
+
heads_k,
|
| 1415 |
+
batch,
|
| 1416 |
+
has_qo_offset,
|
| 1417 |
+
max_k_tiles,
|
| 1418 |
+
).launch(
|
| 1419 |
+
grid=grid,
|
| 1420 |
+
block=[self.threads_per_cta, 1, 1],
|
| 1421 |
+
cluster=(1, 1, 1),
|
| 1422 |
+
stream=stream,
|
| 1423 |
+
)
|
| 1424 |
+
|
| 1425 |
+
@cute.jit
|
| 1426 |
+
def _group_has_visible(
|
| 1427 |
+
self,
|
| 1428 |
+
q_len: Int32,
|
| 1429 |
+
group_first_ktile: Int32,
|
| 1430 |
+
batch_k_tiles: Int32,
|
| 1431 |
+
causal_offset: Int32,
|
| 1432 |
+
):
|
| 1433 |
+
visible = q_len > Int32(0) and group_first_ktile < batch_k_tiles
|
| 1434 |
+
if const_expr(self.is_causal):
|
| 1435 |
+
visible = visible and group_first_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset
|
| 1436 |
+
return visible
|
| 1437 |
+
|
| 1438 |
+
@cute.jit
|
| 1439 |
+
def _tile_has_visible(
|
| 1440 |
+
self,
|
| 1441 |
+
q_len: Int32,
|
| 1442 |
+
ktile: Int32,
|
| 1443 |
+
batch_k_tiles: Int32,
|
| 1444 |
+
causal_offset: Int32,
|
| 1445 |
+
):
|
| 1446 |
+
visible = ktile < batch_k_tiles
|
| 1447 |
+
if const_expr(self.is_causal):
|
| 1448 |
+
visible = visible and ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset
|
| 1449 |
+
return visible
|
| 1450 |
+
|
| 1451 |
+
@cute.jit
|
| 1452 |
+
def _tile_mask_free(self, ktile: Int32, causal_offset: Int32):
|
| 1453 |
+
if const_expr(self.is_causal):
|
| 1454 |
+
return ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) <= causal_offset
|
| 1455 |
+
return True
|
| 1456 |
+
|
| 1457 |
+
@cute.jit
|
| 1458 |
+
def _packed_coord_visible(
|
| 1459 |
+
self,
|
| 1460 |
+
coord_m: Int32,
|
| 1461 |
+
target_m: Int32,
|
| 1462 |
+
h_in_group: Int32,
|
| 1463 |
+
qhead_per_kv: Int32,
|
| 1464 |
+
q_local: Int32,
|
| 1465 |
+
q_len: Int32,
|
| 1466 |
+
k_local: Int32,
|
| 1467 |
+
k_len: Int32,
|
| 1468 |
+
causal_offset: Int32,
|
| 1469 |
+
):
|
| 1470 |
+
visible = coord_m == target_m and h_in_group < qhead_per_kv and q_local < q_len and k_local < k_len
|
| 1471 |
+
if const_expr(self.is_causal):
|
| 1472 |
+
visible = visible and k_local <= q_local + causal_offset
|
| 1473 |
+
return visible
|
| 1474 |
+
|
| 1475 |
+
@cute.kernel
|
| 1476 |
+
def kernel(
|
| 1477 |
+
self,
|
| 1478 |
+
tiled_mma: cute.TiledMma,
|
| 1479 |
+
tma_q: cpasync.TmaInfo,
|
| 1480 |
+
tma_qs: cpasync.TmaInfo,
|
| 1481 |
+
tma_k: cpasync.TmaInfo,
|
| 1482 |
+
tma_ks: cpasync.TmaInfo,
|
| 1483 |
+
mScores: cute.Tensor,
|
| 1484 |
+
mKvIndices: cute.Tensor,
|
| 1485 |
+
mCuQ: cute.Tensor,
|
| 1486 |
+
mCuK: cute.Tensor,
|
| 1487 |
+
mCuPages: cute.Tensor,
|
| 1488 |
+
mQoOffset: cute.Tensor,
|
| 1489 |
+
q_smem_layout: cute.ComposedLayout,
|
| 1490 |
+
k_smem_layout: cute.ComposedLayout,
|
| 1491 |
+
q_scale_smem_layout: cute.Layout,
|
| 1492 |
+
k_scale_smem_layout: cute.Layout,
|
| 1493 |
+
heads_q: Int32,
|
| 1494 |
+
heads_k: Int32,
|
| 1495 |
+
batch: Int32,
|
| 1496 |
+
has_qo_offset: Int32,
|
| 1497 |
+
max_k_tiles: Int32,
|
| 1498 |
+
):
|
| 1499 |
+
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
| 1500 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 1501 |
+
epi_tidx = tidx % Int32(self.epi_threads_per_cta)
|
| 1502 |
+
epi_warpgroup_idx = warp_idx // Int32(self.epi_warps_per_group)
|
| 1503 |
+
task_x, task_y, _ = cute.arch.block_idx()
|
| 1504 |
+
task_valid = True
|
| 1505 |
+
batch_idx = Int32(0)
|
| 1506 |
+
hk = Int32(0)
|
| 1507 |
+
ktile_group = Int32(0)
|
| 1508 |
+
q_l = Int32(0)
|
| 1509 |
+
if const_expr(self.compact_schedule):
|
| 1510 |
+
hk = task_y
|
| 1511 |
+
group_base = Int32(0)
|
| 1512 |
+
scan_batch = Int32(0)
|
| 1513 |
+
task_valid = False
|
| 1514 |
+
while scan_batch < batch and not task_valid:
|
| 1515 |
+
batch_pages = mCuPages[scan_batch + Int32(1)] - mCuPages[scan_batch]
|
| 1516 |
+
batch_groups = (batch_pages + Int32(self.k_tiles_per_cta - 1)) // Int32(self.k_tiles_per_cta)
|
| 1517 |
+
task_valid = task_x < group_base + batch_groups
|
| 1518 |
+
if not task_valid:
|
| 1519 |
+
group_base += batch_groups
|
| 1520 |
+
scan_batch += Int32(1)
|
| 1521 |
+
if task_valid:
|
| 1522 |
+
batch_idx = scan_batch
|
| 1523 |
+
ktile_group = task_x - group_base
|
| 1524 |
+
q_l = batch_idx * heads_k + hk
|
| 1525 |
+
else:
|
| 1526 |
+
ktile_group = task_x
|
| 1527 |
+
q_l = task_y
|
| 1528 |
+
batch_idx = q_l // heads_k
|
| 1529 |
+
hk = q_l - batch_idx * heads_k
|
| 1530 |
+
qhead_per_kv = heads_q // heads_k
|
| 1531 |
+
q_begin = mCuQ[batch_idx]
|
| 1532 |
+
q_end = mCuQ[batch_idx + 1]
|
| 1533 |
+
k_begin = mCuK[batch_idx]
|
| 1534 |
+
k_end = mCuK[batch_idx + 1]
|
| 1535 |
+
q_len = q_end - q_begin
|
| 1536 |
+
k_len = k_end - k_begin
|
| 1537 |
+
if const_expr(self.compact_schedule):
|
| 1538 |
+
if not task_valid:
|
| 1539 |
+
q_len = Int32(0)
|
| 1540 |
+
k_len = Int32(0)
|
| 1541 |
+
page_begin = mCuPages[batch_idx]
|
| 1542 |
+
batch_k_tiles = (k_len + Int32(_PAGE_SIZE - 1)) // Int32(_PAGE_SIZE)
|
| 1543 |
+
causal_offset = Int32(0)
|
| 1544 |
+
if const_expr(self.is_causal):
|
| 1545 |
+
causal_offset = k_len - q_len
|
| 1546 |
+
if has_qo_offset != 0:
|
| 1547 |
+
causal_offset = mQoOffset[batch_idx]
|
| 1548 |
+
group_first_ktile = ktile_group * Int32(self.k_tiles_per_cta)
|
| 1549 |
+
group_has_visible = self._group_has_visible(
|
| 1550 |
+
q_len,
|
| 1551 |
+
group_first_ktile,
|
| 1552 |
+
batch_k_tiles,
|
| 1553 |
+
causal_offset,
|
| 1554 |
+
)
|
| 1555 |
+
|
| 1556 |
+
@cute.struct
|
| 1557 |
+
class SharedStorage:
|
| 1558 |
+
acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
| 1559 |
+
q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
|
| 1560 |
+
k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
| 1561 |
+
tmem_holding_buf: cutlass.Int32
|
| 1562 |
+
|
| 1563 |
+
smem = utils.SmemAllocator()
|
| 1564 |
+
storage = smem.allocate(SharedStorage)
|
| 1565 |
+
sQ_public = smem.allocate_tensor(_AB_DTYPE, q_smem_layout.outer, 128, swizzle=q_smem_layout.inner)
|
| 1566 |
+
sK_public = smem.allocate_tensor(_AB_DTYPE, k_smem_layout.outer, 128, swizzle=k_smem_layout.inner)
|
| 1567 |
+
sQS_public = smem.allocate_tensor(self.sf_dtype, q_scale_smem_layout, 128)
|
| 1568 |
+
sKS_public = smem.allocate_tensor(self.sf_dtype, k_scale_smem_layout, 128)
|
| 1569 |
+
mQ_tma = tma_q.tma_tensor
|
| 1570 |
+
mQS_tma = tma_qs.tma_tensor
|
| 1571 |
+
mK_tma = tma_k.tma_tensor
|
| 1572 |
+
mKS_tma = tma_ks.tma_tensor
|
| 1573 |
+
thr_mma = tiled_mma.get_slice(0)
|
| 1574 |
+
tCrQ = tiled_mma.make_fragment_A(sQ_public)
|
| 1575 |
+
tCrK = tiled_mma.make_fragment_B(sK_public)
|
| 1576 |
+
tCcC = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler[:2]))
|
| 1577 |
+
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
|
| 1578 |
+
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
|
| 1579 |
+
|
| 1580 |
+
gQ_tma = cute.local_tile(
|
| 1581 |
+
mQ_tma,
|
| 1582 |
+
cute.slice_(self.mma_tiler, (None, 0, None)),
|
| 1583 |
+
(None, None, None),
|
| 1584 |
+
)
|
| 1585 |
+
tCgQ_tma = thr_mma.partition_A(gQ_tma)
|
| 1586 |
+
tQsQ_tma, tQgQ_tma = cpasync.tma_partition(
|
| 1587 |
+
tma_q.atom,
|
| 1588 |
+
0,
|
| 1589 |
+
cute.make_layout(1),
|
| 1590 |
+
cute.group_modes(sQ_public, 0, 3),
|
| 1591 |
+
cute.group_modes(tCgQ_tma, 0, 3),
|
| 1592 |
+
)
|
| 1593 |
+
gQS_tma = cute.local_tile(
|
| 1594 |
+
mQS_tma,
|
| 1595 |
+
cute.slice_(self.mma_tiler, (None, 0, None)),
|
| 1596 |
+
(None, None, None),
|
| 1597 |
+
)
|
| 1598 |
+
tCgQS_tma = thr_mma.partition_A(gQS_tma)
|
| 1599 |
+
tQsQS_tma, tQgQS_tma = cpasync.tma_partition(
|
| 1600 |
+
tma_qs.atom,
|
| 1601 |
+
0,
|
| 1602 |
+
cute.make_layout(1),
|
| 1603 |
+
cute.group_modes(sQS_public, 0, 3),
|
| 1604 |
+
cute.group_modes(tCgQS_tma, 0, 3),
|
| 1605 |
+
)
|
| 1606 |
+
tQsQS_tma = cute.filter_zeros(tQsQS_tma)
|
| 1607 |
+
tQgQS_tma = cute.filter_zeros(tQgQS_tma)
|
| 1608 |
+
gK_tma = cute.local_tile(
|
| 1609 |
+
mK_tma,
|
| 1610 |
+
cute.slice_(self.mma_tiler, (0, None, None)),
|
| 1611 |
+
(None, None, None, None),
|
| 1612 |
+
)
|
| 1613 |
+
tCgK_tma = thr_mma.partition_B(gK_tma)
|
| 1614 |
+
tKsK_tma, tKgK_tma = cpasync.tma_partition(
|
| 1615 |
+
tma_k.atom,
|
| 1616 |
+
0,
|
| 1617 |
+
cute.make_layout(1),
|
| 1618 |
+
cute.group_modes(sK_public, 0, 3),
|
| 1619 |
+
cute.group_modes(tCgK_tma, 0, 3),
|
| 1620 |
+
)
|
| 1621 |
+
gKS_tma = cute.local_tile(
|
| 1622 |
+
mKS_tma,
|
| 1623 |
+
cute.slice_(self.mma_tiler, (0, None, None)),
|
| 1624 |
+
(None, None, None),
|
| 1625 |
+
)
|
| 1626 |
+
tCgKS_tma = thr_mma.partition_B(gKS_tma)
|
| 1627 |
+
tKsKS_tma, tKgKS_tma = cpasync.tma_partition(
|
| 1628 |
+
tma_ks.atom,
|
| 1629 |
+
0,
|
| 1630 |
+
cute.make_layout(1),
|
| 1631 |
+
cute.group_modes(sKS_public, 0, 3),
|
| 1632 |
+
cute.group_modes(tCgKS_tma, 0, 3),
|
| 1633 |
+
)
|
| 1634 |
+
tKsKS_tma = cute.filter_zeros(tKsKS_tma)
|
| 1635 |
+
tKgKS_tma = cute.filter_zeros(tKgKS_tma)
|
| 1636 |
+
|
| 1637 |
+
tmem = utils.TmemAllocator(
|
| 1638 |
+
storage.tmem_holding_buf.ptr,
|
| 1639 |
+
barrier_for_retrieve=pipeline.NamedBarrier(
|
| 1640 |
+
barrier_id=1,
|
| 1641 |
+
num_threads=32 * (self.mma_warp_id + 1),
|
| 1642 |
+
),
|
| 1643 |
+
)
|
| 1644 |
+
acc_pipeline = common_pipeline.PipelineUmmaAsync.create(
|
| 1645 |
+
barrier_storage=storage.acc_mbar_ptr.data_ptr(),
|
| 1646 |
+
num_stages=self.num_acc_stage,
|
| 1647 |
+
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
| 1648 |
+
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, self.epi_threads_per_cta),
|
| 1649 |
+
defer_sync=True,
|
| 1650 |
+
)
|
| 1651 |
+
acc_producer, _ = acc_pipeline.make_participants()
|
| 1652 |
+
q_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_q.smem_layout)
|
| 1653 |
+
qs_tma_copy_bytes = cute.size_in_bytes(
|
| 1654 |
+
self.sf_dtype,
|
| 1655 |
+
cute.select(tma_qs.smem_layout, mode=[0, 1, 2]),
|
| 1656 |
+
)
|
| 1657 |
+
k_tma_copy_bytes = cute.size_in_bytes(_AB_DTYPE, tma_k.smem_layout)
|
| 1658 |
+
ks_tma_copy_bytes = cute.size_in_bytes(
|
| 1659 |
+
self.sf_dtype,
|
| 1660 |
+
cute.select(tma_ks.smem_layout, mode=[0, 1, 2]),
|
| 1661 |
+
)
|
| 1662 |
+
q_pair_tma_copy_bytes = q_tma_copy_bytes + qs_tma_copy_bytes
|
| 1663 |
+
k_pair_tma_copy_bytes = k_tma_copy_bytes + ks_tma_copy_bytes
|
| 1664 |
+
q_producer, q_consumer = pipeline.PipelineTmaAsync.create(
|
| 1665 |
+
barrier_storage=storage.q_mbar_ptr.data_ptr(),
|
| 1666 |
+
num_stages=1,
|
| 1667 |
+
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
| 1668 |
+
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
| 1669 |
+
tx_count=q_pair_tma_copy_bytes,
|
| 1670 |
+
defer_sync=True,
|
| 1671 |
+
).make_participants()
|
| 1672 |
+
k_producer, k_consumer = pipeline.PipelineTmaAsync.create(
|
| 1673 |
+
barrier_storage=storage.k_mbar_ptr.data_ptr(),
|
| 1674 |
+
num_stages=self.num_ab_stage,
|
| 1675 |
+
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
| 1676 |
+
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
| 1677 |
+
tx_count=k_pair_tma_copy_bytes,
|
| 1678 |
+
defer_sync=True,
|
| 1679 |
+
).make_participants()
|
| 1680 |
+
cute.arch.mbarrier_init_fence()
|
| 1681 |
+
cute.arch.barrier()
|
| 1682 |
+
|
| 1683 |
+
if warp_idx == self.load_warp_id:
|
| 1684 |
+
if group_has_visible:
|
| 1685 |
+
q_pair_empty = q_producer.acquire_and_advance()
|
| 1686 |
+
cute.copy(
|
| 1687 |
+
tma_q.atom,
|
| 1688 |
+
tQgQ_tma[(None, 0, 0, q_l)],
|
| 1689 |
+
tQsQ_tma[(None, q_pair_empty.index)],
|
| 1690 |
+
tma_bar_ptr=q_pair_empty.barrier,
|
| 1691 |
+
)
|
| 1692 |
+
cute.copy(
|
| 1693 |
+
tma_qs.atom,
|
| 1694 |
+
tQgQS_tma[(None, 0, 0, q_l)],
|
| 1695 |
+
tQsQS_tma[(None, q_pair_empty.index)],
|
| 1696 |
+
tma_bar_ptr=q_pair_empty.barrier,
|
| 1697 |
+
)
|
| 1698 |
+
q_pair_empty.commit()
|
| 1699 |
+
load_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles
|
| 1700 |
+
if const_expr(self.is_causal):
|
| 1701 |
+
load_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1)
|
| 1702 |
+
load_group_full = load_group_full and load_group_last_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset
|
| 1703 |
+
if load_group_full:
|
| 1704 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 1705 |
+
ktile = group_first_ktile + Int32(ktile_inner)
|
| 1706 |
+
k_pair_empty = k_producer.acquire_and_advance()
|
| 1707 |
+
physical_page = mKvIndices[page_begin + ktile]
|
| 1708 |
+
cute.copy(
|
| 1709 |
+
tma_k.atom,
|
| 1710 |
+
tKgK_tma[(None, 0, 0, hk, physical_page)],
|
| 1711 |
+
tKsK_tma[(None, k_pair_empty.index)],
|
| 1712 |
+
tma_bar_ptr=k_pair_empty.barrier,
|
| 1713 |
+
)
|
| 1714 |
+
scale_l = physical_page * heads_k + hk
|
| 1715 |
+
cute.copy(
|
| 1716 |
+
tma_ks.atom,
|
| 1717 |
+
tKgKS_tma[(None, 0, 0, scale_l)],
|
| 1718 |
+
tKsKS_tma[(None, k_pair_empty.index)],
|
| 1719 |
+
tma_bar_ptr=k_pair_empty.barrier,
|
| 1720 |
+
)
|
| 1721 |
+
k_pair_empty.commit()
|
| 1722 |
+
else:
|
| 1723 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 1724 |
+
ktile = group_first_ktile + Int32(ktile_inner)
|
| 1725 |
+
if ktile < max_k_tiles:
|
| 1726 |
+
tile_has_visible = self._tile_has_visible(
|
| 1727 |
+
q_len,
|
| 1728 |
+
ktile,
|
| 1729 |
+
batch_k_tiles,
|
| 1730 |
+
causal_offset,
|
| 1731 |
+
)
|
| 1732 |
+
if tile_has_visible:
|
| 1733 |
+
k_pair_empty = k_producer.acquire_and_advance()
|
| 1734 |
+
physical_page = mKvIndices[page_begin + ktile]
|
| 1735 |
+
cute.copy(
|
| 1736 |
+
tma_k.atom,
|
| 1737 |
+
tKgK_tma[(None, 0, 0, hk, physical_page)],
|
| 1738 |
+
tKsK_tma[(None, k_pair_empty.index)],
|
| 1739 |
+
tma_bar_ptr=k_pair_empty.barrier,
|
| 1740 |
+
)
|
| 1741 |
+
scale_l = physical_page * heads_k + hk
|
| 1742 |
+
cute.copy(
|
| 1743 |
+
tma_ks.atom,
|
| 1744 |
+
tKgKS_tma[(None, 0, 0, scale_l)],
|
| 1745 |
+
tKsKS_tma[(None, k_pair_empty.index)],
|
| 1746 |
+
tma_bar_ptr=k_pair_empty.barrier,
|
| 1747 |
+
)
|
| 1748 |
+
k_pair_empty.commit()
|
| 1749 |
+
k_producer.tail()
|
| 1750 |
+
q_producer.tail()
|
| 1751 |
+
|
| 1752 |
+
if warp_idx == self.mma_warp_id:
|
| 1753 |
+
tmem_pool = tmem.reserve(self.num_tmem_alloc_cols)
|
| 1754 |
+
tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32)
|
| 1755 |
+
tCtQS_layout = blockscaled_utils.make_tmem_layout_sfa(
|
| 1756 |
+
tiled_mma,
|
| 1757 |
+
self.mma_tiler,
|
| 1758 |
+
self.sf_vec_size,
|
| 1759 |
+
cute.slice_(q_scale_smem_layout, (None, None, None, 0)),
|
| 1760 |
+
)
|
| 1761 |
+
tCtKS_layout = blockscaled_utils.make_tmem_layout_sfb(
|
| 1762 |
+
tiled_mma,
|
| 1763 |
+
self.mma_tiler,
|
| 1764 |
+
self.sf_vec_size,
|
| 1765 |
+
cute.slice_(k_scale_smem_layout, (None, None, None, 0)),
|
| 1766 |
+
)
|
| 1767 |
+
tCtQS = tmem_pool.allocate_tensor(tCtQS_layout, self.sf_dtype)
|
| 1768 |
+
tCtKS = tmem_pool.allocate_tensor(tCtKS_layout, self.sf_dtype)
|
| 1769 |
+
copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), self.sf_dtype)
|
| 1770 |
+
tCsQS_compact = cute.filter_zeros(sQS_public)
|
| 1771 |
+
tCtQS_compact = cute.filter_zeros(tCtQS)
|
| 1772 |
+
tiled_copy_s2t_qs = tcgen05.make_s2t_copy(copy_atom_s2t, tCtQS_compact)
|
| 1773 |
+
thr_copy_s2t_qs = tiled_copy_s2t_qs.get_slice(0)
|
| 1774 |
+
tCsQS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
|
| 1775 |
+
tiled_copy_s2t_qs,
|
| 1776 |
+
thr_copy_s2t_qs.partition_S(tCsQS_compact),
|
| 1777 |
+
)
|
| 1778 |
+
tCtQS_compact_s2t = thr_copy_s2t_qs.partition_D(tCtQS_compact)
|
| 1779 |
+
tCsKS_compact = cute.filter_zeros(sKS_public)
|
| 1780 |
+
tCtKS_compact = cute.filter_zeros(tCtKS)
|
| 1781 |
+
tiled_copy_s2t_ks = tcgen05.make_s2t_copy(copy_atom_s2t, tCtKS_compact)
|
| 1782 |
+
thr_copy_s2t_ks = tiled_copy_s2t_ks.get_slice(0)
|
| 1783 |
+
tCsKS_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
|
| 1784 |
+
tiled_copy_s2t_ks,
|
| 1785 |
+
thr_copy_s2t_ks.partition_S(tCsKS_compact),
|
| 1786 |
+
)
|
| 1787 |
+
tCtKS_compact_s2t = thr_copy_s2t_ks.partition_D(tCtKS_compact)
|
| 1788 |
+
if group_has_visible:
|
| 1789 |
+
q_pair_full = q_consumer.wait_and_advance()
|
| 1790 |
+
q_pair_full.release()
|
| 1791 |
+
cute.copy(tiled_copy_s2t_qs, tCsQS_compact_s2t[(None, None, None, None, 0)], tCtQS_compact_s2t)
|
| 1792 |
+
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
| 1793 |
+
q_tile_crd = (None, None, None, 0)
|
| 1794 |
+
if const_expr(self.is_causal):
|
| 1795 |
+
causal_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles
|
| 1796 |
+
causal_group_last_ktile = group_first_ktile + Int32(self.k_tiles_per_cta - 1)
|
| 1797 |
+
causal_group_full = causal_group_full and causal_group_last_ktile * Int32(_BLOCK_K) <= (q_len - Int32(1)) + causal_offset
|
| 1798 |
+
if causal_group_full:
|
| 1799 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 1800 |
+
k_pair_full = k_consumer.wait_and_advance()
|
| 1801 |
+
acc_empty = acc_producer.acquire_and_advance()
|
| 1802 |
+
cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t)
|
| 1803 |
+
k_tile_crd = (None, None, None, k_pair_full.index)
|
| 1804 |
+
tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)]
|
| 1805 |
+
cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
|
| 1806 |
+
acc_empty.commit()
|
| 1807 |
+
k_pair_full.release()
|
| 1808 |
+
else:
|
| 1809 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 1810 |
+
ktile = group_first_ktile + Int32(ktile_inner)
|
| 1811 |
+
if ktile < max_k_tiles:
|
| 1812 |
+
tile_has_visible = self._tile_has_visible(
|
| 1813 |
+
q_len,
|
| 1814 |
+
ktile,
|
| 1815 |
+
batch_k_tiles,
|
| 1816 |
+
causal_offset,
|
| 1817 |
+
)
|
| 1818 |
+
if tile_has_visible:
|
| 1819 |
+
k_pair_full = k_consumer.wait_and_advance()
|
| 1820 |
+
acc_empty = acc_producer.acquire_and_advance()
|
| 1821 |
+
cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t)
|
| 1822 |
+
k_tile_crd = (None, None, None, k_pair_full.index)
|
| 1823 |
+
tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)]
|
| 1824 |
+
cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
|
| 1825 |
+
acc_empty.commit()
|
| 1826 |
+
k_pair_full.release()
|
| 1827 |
+
else:
|
| 1828 |
+
k_group_full = group_first_ktile + Int32(self.k_tiles_per_cta) <= batch_k_tiles
|
| 1829 |
+
if k_group_full:
|
| 1830 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 1831 |
+
k_pair_full = k_consumer.wait_and_advance()
|
| 1832 |
+
acc_empty = acc_producer.acquire_and_advance()
|
| 1833 |
+
cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t)
|
| 1834 |
+
k_tile_crd = (None, None, None, k_pair_full.index)
|
| 1835 |
+
tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)]
|
| 1836 |
+
cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
|
| 1837 |
+
acc_empty.commit()
|
| 1838 |
+
k_pair_full.release()
|
| 1839 |
+
else:
|
| 1840 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 1841 |
+
ktile = group_first_ktile + Int32(ktile_inner)
|
| 1842 |
+
if ktile < batch_k_tiles:
|
| 1843 |
+
k_pair_full = k_consumer.wait_and_advance()
|
| 1844 |
+
acc_empty = acc_producer.acquire_and_advance()
|
| 1845 |
+
cute.copy(tiled_copy_s2t_ks, tCsKS_compact_s2t[(None, None, None, None, k_pair_full.index)], tCtKS_compact_s2t)
|
| 1846 |
+
k_tile_crd = (None, None, None, k_pair_full.index)
|
| 1847 |
+
tCtAcc_stage = tCtAcc[(None, None, None, acc_empty.index)]
|
| 1848 |
+
cute.gemm(tiled_mma, tCtAcc_stage, [tCrQ[q_tile_crd], tCtQS], [tCrK[k_tile_crd], tCtKS], tCtAcc_stage)
|
| 1849 |
+
acc_empty.commit()
|
| 1850 |
+
k_pair_full.release()
|
| 1851 |
+
acc_producer.tail()
|
| 1852 |
+
|
| 1853 |
+
if warp_idx < self.mma_warp_id:
|
| 1854 |
+
tmem_pool = tmem.reserve(self.num_tmem_alloc_cols)
|
| 1855 |
+
tCtAcc = tmem_pool.allocate_tensor(tCtAcc_fake.layout, Float32)
|
| 1856 |
+
if const_expr(self.use_tmem_load_red):
|
| 1857 |
+
copy_atom_t2r = cute.make_copy_atom(
|
| 1858 |
+
tcgen05.LdRed32x32bOp(
|
| 1859 |
+
tcgen05.Repetition.x128,
|
| 1860 |
+
tcgen05.Pack.NONE,
|
| 1861 |
+
tcgen05.TmemLoadRedOp.MAX,
|
| 1862 |
+
),
|
| 1863 |
+
Float32,
|
| 1864 |
+
)
|
| 1865 |
+
else:
|
| 1866 |
+
copy_atom_t2r = cute.make_copy_atom(
|
| 1867 |
+
tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE),
|
| 1868 |
+
Float32,
|
| 1869 |
+
)
|
| 1870 |
+
tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None, None, 0)])
|
| 1871 |
+
thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx)
|
| 1872 |
+
tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc)
|
| 1873 |
+
tTR_cC = thr_copy_t2r.partition_D(tCcC)
|
| 1874 |
+
tTR_rAcc = cute.make_rmem_tensor(tTR_cC.shape, Float32)
|
| 1875 |
+
if const_expr(self.use_tmem_load_red):
|
| 1876 |
+
tTR_rRed = cute.make_rmem_tensor((1,), Float32)
|
| 1877 |
+
h_store = epi_tidx // Int32(_DECODE_PACK_Q_LEN)
|
| 1878 |
+
q_local_store = epi_tidx - h_store * Int32(_DECODE_PACK_Q_LEN)
|
| 1879 |
+
h_global_store = hk * qhead_per_kv + h_store
|
| 1880 |
+
q_global_store = q_begin + q_local_store
|
| 1881 |
+
if group_has_visible:
|
| 1882 |
+
visible_tile_count = Int32(0)
|
| 1883 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 1884 |
+
ktile = group_first_ktile + Int32(ktile_inner)
|
| 1885 |
+
if ktile < max_k_tiles:
|
| 1886 |
+
tile_has_visible = self._tile_has_visible(
|
| 1887 |
+
q_len,
|
| 1888 |
+
ktile,
|
| 1889 |
+
batch_k_tiles,
|
| 1890 |
+
causal_offset,
|
| 1891 |
+
)
|
| 1892 |
+
if tile_has_visible:
|
| 1893 |
+
epilogue_owns_tile = epi_warpgroup_idx == Int32(
|
| 1894 |
+
ktile_inner % self.num_epi_warpgroups
|
| 1895 |
+
)
|
| 1896 |
+
if epilogue_owns_tile:
|
| 1897 |
+
acc_stage_index = visible_tile_count % Int32(self.num_acc_stage)
|
| 1898 |
+
acc_stage_phase = (visible_tile_count // Int32(self.num_acc_stage)) % Int32(2)
|
| 1899 |
+
tile_mask_free = self._tile_mask_free(ktile, causal_offset)
|
| 1900 |
+
k_tile_full = ktile * Int32(_BLOCK_K) + Int32(_BLOCK_K - 1) < k_len
|
| 1901 |
+
q_pack_full = q_len == Int32(_DECODE_PACK_Q_LEN)
|
| 1902 |
+
tile_full = q_pack_full and k_tile_full
|
| 1903 |
+
acc_pipeline.consumer_wait_w_index_phase(acc_stage_index, acc_stage_phase)
|
| 1904 |
+
tTR_tAcc_stage = tTR_tAcc[(None, None, None, None, acc_stage_index)]
|
| 1905 |
+
if const_expr(self.use_tmem_load_red):
|
| 1906 |
+
cute.copy(tiled_copy_t2r, tTR_tAcc_stage, [tTR_rAcc, tTR_rRed])
|
| 1907 |
+
else:
|
| 1908 |
+
cute.copy(tiled_copy_t2r, tTR_tAcc_stage, tTR_rAcc)
|
| 1909 |
+
row_max0 = -Float32.inf
|
| 1910 |
+
if tile_mask_free and tile_full:
|
| 1911 |
+
if const_expr(self.use_tmem_load_red):
|
| 1912 |
+
row_max0 = tTR_rRed[0]
|
| 1913 |
+
else:
|
| 1914 |
+
for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True):
|
| 1915 |
+
coord_m, _ = tTR_cC[i]
|
| 1916 |
+
if coord_m == epi_tidx:
|
| 1917 |
+
row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i])
|
| 1918 |
+
else:
|
| 1919 |
+
for i in cutlass.range(cute.size(tTR_rAcc), unroll_full=True):
|
| 1920 |
+
coord_m, coord_n = tTR_cC[i]
|
| 1921 |
+
h_in_group = coord_m // Int32(_DECODE_PACK_Q_LEN)
|
| 1922 |
+
q_local = coord_m - h_in_group * Int32(_DECODE_PACK_Q_LEN)
|
| 1923 |
+
k_local = ktile * Int32(_BLOCK_K) + coord_n
|
| 1924 |
+
valid = self._packed_coord_visible(
|
| 1925 |
+
coord_m,
|
| 1926 |
+
epi_tidx,
|
| 1927 |
+
h_in_group,
|
| 1928 |
+
qhead_per_kv,
|
| 1929 |
+
q_local,
|
| 1930 |
+
q_len,
|
| 1931 |
+
k_local,
|
| 1932 |
+
k_len,
|
| 1933 |
+
causal_offset,
|
| 1934 |
+
)
|
| 1935 |
+
if valid:
|
| 1936 |
+
row_max0 = cute.arch.fmax(row_max0, tTR_rAcc[i])
|
| 1937 |
+
if h_store < qhead_per_kv and q_local_store < q_len:
|
| 1938 |
+
mScores[h_global_store, ktile, q_global_store] = row_max0
|
| 1939 |
+
cute.arch.fence_view_async_tmem_load()
|
| 1940 |
+
acc_pipeline.consumer_release_w_index(acc_stage_index)
|
| 1941 |
+
visible_tile_count += Int32(1)
|
| 1942 |
+
else:
|
| 1943 |
+
if const_expr(not self.compact_schedule):
|
| 1944 |
+
if epi_warpgroup_idx == Int32(0):
|
| 1945 |
+
if h_store < qhead_per_kv and q_local_store < q_len:
|
| 1946 |
+
mScores[h_global_store, ktile, q_global_store] = -Float32.inf
|
| 1947 |
+
else:
|
| 1948 |
+
if const_expr(not self.compact_schedule):
|
| 1949 |
+
if epi_warpgroup_idx == Int32(0):
|
| 1950 |
+
for ktile_inner in cutlass.range_constexpr(self.k_tiles_per_cta):
|
| 1951 |
+
ktile = group_first_ktile + Int32(ktile_inner)
|
| 1952 |
+
if ktile < max_k_tiles:
|
| 1953 |
+
if h_store < qhead_per_kv and q_local_store < q_len:
|
| 1954 |
+
mScores[h_global_store, ktile, q_global_store] = -Float32.inf
|
| 1955 |
+
cute.arch.barrier()
|
| 1956 |
+
tmem.free(tmem_pool.base_ptr)
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""SM100 forward kernels and combine paths."""
|
| 5 |
+
|
| 6 |
+
from .atten_fwd_nvfp4_kv import SparseAttentionForwardNvfp4KvSm100
|
| 7 |
+
|
| 8 |
+
__all__ = ["SparseAttentionForwardNvfp4KvSm100"]
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/atten_fwd_nvfp4_kv.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd/combine.py
ADDED
|
@@ -0,0 +1,1498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""Sparse forward combine kernel and public launcher.
|
| 5 |
+
|
| 6 |
+
This keeps the local fake-layout -> real-layout epilogue needed by the lean
|
| 7 |
+
sparse forward path.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
# Modified Step 7: O_out write with SMEM fake->real column permutation.
|
| 11 |
+
# O_partial dim is in STG.128 fake layout; O_out dim is real layout.
|
| 12 |
+
import math
|
| 13 |
+
from typing import Type, Optional
|
| 14 |
+
from functools import partial
|
| 15 |
+
|
| 16 |
+
import cuda.bindings.driver as cuda
|
| 17 |
+
|
| 18 |
+
import cutlass
|
| 19 |
+
import cutlass.cute as cute
|
| 20 |
+
import torch
|
| 21 |
+
from cutlass.cute.nvgpu import cpasync
|
| 22 |
+
from cutlass import Float32, Int32, Int64, Boolean, const_expr
|
| 23 |
+
|
| 24 |
+
from ....src.common import utils
|
| 25 |
+
from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map
|
| 26 |
+
from ....src.common.seqlen_info import SeqlenInfo
|
| 27 |
+
from cutlass.cute import FastDivmodDivisor
|
| 28 |
+
|
| 29 |
+
from ....src.common.pack_gqa import PackGQAComb
|
| 30 |
+
from ....src.common.tma_utils import (
|
| 31 |
+
stg128_fake_col_to_real_col,
|
| 32 |
+
stg128_fp8_fake_col_to_real_col,
|
| 33 |
+
stg128_half_fake_col_to_real_col,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SparseAttentionForwardCombine:
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
dtype: Type[cutlass.Numeric],
|
| 41 |
+
dtype_partial: Type[cutlass.Numeric],
|
| 42 |
+
head_dim: int,
|
| 43 |
+
tile_m: int = 8,
|
| 44 |
+
k_block_size: int = 64,
|
| 45 |
+
topk: int = 16,
|
| 46 |
+
num_threads: int = 256,
|
| 47 |
+
stages: int = 4,
|
| 48 |
+
use_pdl: bool = False,
|
| 49 |
+
min_blocks_per_mp: int = 0,
|
| 50 |
+
):
|
| 51 |
+
"""
|
| 52 |
+
Forward combine kernel for split attention computation.
|
| 53 |
+
|
| 54 |
+
:param dtype: output data type
|
| 55 |
+
:param dtype_partial: partial accumulation data type
|
| 56 |
+
:param head_dim: head dimension
|
| 57 |
+
:param tile_m: m block size
|
| 58 |
+
:param k_block_size: k block size
|
| 59 |
+
:param topk: exact number of split partials
|
| 60 |
+
:param num_threads: number of threads
|
| 61 |
+
:param varlen: whether using variable length sequences
|
| 62 |
+
:param stages: number of pipeline stages
|
| 63 |
+
"""
|
| 64 |
+
self.dtype = dtype
|
| 65 |
+
self.dtype_partial = dtype_partial
|
| 66 |
+
self.head_dim = head_dim
|
| 67 |
+
self.tile_m = tile_m
|
| 68 |
+
self.k_block_size = k_block_size
|
| 69 |
+
self.topk = topk
|
| 70 |
+
self.num_threads = num_threads
|
| 71 |
+
self.is_even_k = head_dim % k_block_size == 0
|
| 72 |
+
self.stages = stages
|
| 73 |
+
self.use_pdl = use_pdl
|
| 74 |
+
self.min_blocks_per_mp = min_blocks_per_mp
|
| 75 |
+
self.use_stg128_half_layout = dtype_partial in (cutlass.BFloat16, cutlass.Float16)
|
| 76 |
+
self.use_stg128_fp8_layout = dtype_partial is cutlass.Float8E4M3FN
|
| 77 |
+
|
| 78 |
+
@staticmethod
|
| 79 |
+
def can_implement(
|
| 80 |
+
dtype,
|
| 81 |
+
dtype_partial,
|
| 82 |
+
head_dim,
|
| 83 |
+
tile_m,
|
| 84 |
+
k_block_size,
|
| 85 |
+
topk,
|
| 86 |
+
num_threads,
|
| 87 |
+
) -> bool:
|
| 88 |
+
"""Check if the kernel can be implemented with the given parameters."""
|
| 89 |
+
if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]:
|
| 90 |
+
return False
|
| 91 |
+
if dtype_partial not in [
|
| 92 |
+
cutlass.Float16,
|
| 93 |
+
cutlass.BFloat16,
|
| 94 |
+
cutlass.Float8E4M3FN,
|
| 95 |
+
Float32,
|
| 96 |
+
]:
|
| 97 |
+
return False
|
| 98 |
+
if head_dim % 8 != 0:
|
| 99 |
+
return False
|
| 100 |
+
if num_threads % 32 != 0:
|
| 101 |
+
return False
|
| 102 |
+
if tile_m % 8 != 0:
|
| 103 |
+
return False
|
| 104 |
+
if topk > 256:
|
| 105 |
+
return False
|
| 106 |
+
if (tile_m * topk) % num_threads != 0:
|
| 107 |
+
return False
|
| 108 |
+
return True
|
| 109 |
+
|
| 110 |
+
def _setup_attributes(self):
|
| 111 |
+
# GMEM copy setup for O partial
|
| 112 |
+
universal_copy_bits = 128
|
| 113 |
+
async_copy_elems = universal_copy_bits // self.dtype_partial.width
|
| 114 |
+
assert self.k_block_size % async_copy_elems == 0
|
| 115 |
+
|
| 116 |
+
k_block_gmem = (
|
| 117 |
+
128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32)
|
| 118 |
+
)
|
| 119 |
+
gmem_threads_per_row = k_block_gmem // async_copy_elems
|
| 120 |
+
assert self.num_threads % gmem_threads_per_row == 0
|
| 121 |
+
|
| 122 |
+
# Async copy atom for O partial load
|
| 123 |
+
atom_async_copy_partial = cute.make_copy_atom(
|
| 124 |
+
cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
|
| 125 |
+
self.dtype_partial,
|
| 126 |
+
num_bits_per_copy=universal_copy_bits,
|
| 127 |
+
)
|
| 128 |
+
tOpartial_layout = cute.make_ordered_layout(
|
| 129 |
+
(self.num_threads // gmem_threads_per_row, gmem_threads_per_row),
|
| 130 |
+
order=(1, 0),
|
| 131 |
+
)
|
| 132 |
+
vOpartial_layout = cute.make_layout((1, async_copy_elems))
|
| 133 |
+
self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv(
|
| 134 |
+
atom_async_copy_partial, tOpartial_layout, vOpartial_layout
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# GMEM copy setup for final O (use universal copy for store).
|
| 138 |
+
# Keep this independent from O_partial: fp8 partial uses 16 elements
|
| 139 |
+
# per 128b transaction, while bf16/fp16 O stores must remain 8-wide.
|
| 140 |
+
output_copy_elems = universal_copy_bits // self.dtype.width
|
| 141 |
+
assert self.k_block_size % output_copy_elems == 0
|
| 142 |
+
gmem_threads_per_row_o = k_block_gmem // output_copy_elems
|
| 143 |
+
assert self.num_threads % gmem_threads_per_row_o == 0
|
| 144 |
+
atom_universal_copy = cute.make_copy_atom(
|
| 145 |
+
cute.nvgpu.CopyUniversalOp(),
|
| 146 |
+
self.dtype,
|
| 147 |
+
num_bits_per_copy=universal_copy_bits,
|
| 148 |
+
)
|
| 149 |
+
tO_layout = cute.make_ordered_layout(
|
| 150 |
+
(self.num_threads // gmem_threads_per_row_o, gmem_threads_per_row_o),
|
| 151 |
+
order=(1, 0),
|
| 152 |
+
)
|
| 153 |
+
vO_layout = cute.make_layout((1, output_copy_elems))
|
| 154 |
+
self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
|
| 155 |
+
atom_universal_copy,
|
| 156 |
+
tO_layout,
|
| 157 |
+
vO_layout,
|
| 158 |
+
)
|
| 159 |
+
# LSE copy setup with async copy (alignment = 1)
|
| 160 |
+
lse_copy_bits = Float32.width # 1 element per copy, width is in bits
|
| 161 |
+
m_block_smem = (
|
| 162 |
+
128
|
| 163 |
+
if self.tile_m % 128 == 0
|
| 164 |
+
else (
|
| 165 |
+
64
|
| 166 |
+
if self.tile_m % 64 == 0
|
| 167 |
+
else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8))
|
| 168 |
+
)
|
| 169 |
+
)
|
| 170 |
+
gmem_threads_per_row_lse = m_block_smem
|
| 171 |
+
assert self.num_threads % gmem_threads_per_row_lse == 0
|
| 172 |
+
|
| 173 |
+
# Async copy atom for LSE load
|
| 174 |
+
atom_async_copy_lse = cute.make_copy_atom(
|
| 175 |
+
cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS),
|
| 176 |
+
Float32,
|
| 177 |
+
num_bits_per_copy=lse_copy_bits,
|
| 178 |
+
)
|
| 179 |
+
tLSE_layout = cute.make_ordered_layout(
|
| 180 |
+
(self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse),
|
| 181 |
+
order=(1, 0),
|
| 182 |
+
)
|
| 183 |
+
vLSE_layout = cute.make_layout(1)
|
| 184 |
+
self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(
|
| 185 |
+
atom_async_copy_lse, tLSE_layout, vLSE_layout
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 189 |
+
# Shared memory
|
| 190 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 191 |
+
|
| 192 |
+
# Shared memory to register copy for LSE
|
| 193 |
+
self.smem_threads_per_col_lse = self.num_threads // m_block_smem
|
| 194 |
+
assert 32 % self.smem_threads_per_col_lse == 0 # Must divide warp size
|
| 195 |
+
|
| 196 |
+
s2r_layout_atom_lse = cute.make_ordered_layout(
|
| 197 |
+
(self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse),
|
| 198 |
+
order=(0, 1),
|
| 199 |
+
)
|
| 200 |
+
self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv(
|
| 201 |
+
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32),
|
| 202 |
+
s2r_layout_atom_lse,
|
| 203 |
+
cute.make_layout(1),
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# LSE shared memory layout with swizzling to avoid bank conflicts
|
| 207 |
+
# This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts
|
| 208 |
+
if const_expr(m_block_smem == 8):
|
| 209 |
+
smem_lse_swizzle = cute.make_swizzle(5, 0, 5)
|
| 210 |
+
elif const_expr(m_block_smem == 16):
|
| 211 |
+
smem_lse_swizzle = cute.make_swizzle(4, 0, 4)
|
| 212 |
+
else:
|
| 213 |
+
smem_lse_swizzle = cute.make_swizzle(3, 2, 3)
|
| 214 |
+
lse_atom_splits = min(self.topk, 8)
|
| 215 |
+
smem_layout_atom_lse = cute.make_composed_layout(
|
| 216 |
+
smem_lse_swizzle,
|
| 217 |
+
0,
|
| 218 |
+
cute.make_ordered_layout((lse_atom_splits, m_block_smem), order=(1, 0)),
|
| 219 |
+
)
|
| 220 |
+
self.smem_layout_lse = cute.tile_to_shape(
|
| 221 |
+
smem_layout_atom_lse, (self.topk, self.tile_m), (0, 1)
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
# O_partial staging layout.
|
| 225 |
+
if const_expr(
|
| 226 |
+
self.dtype_partial
|
| 227 |
+
in [cutlass.Float16, cutlass.BFloat16, cutlass.Float8E4M3FN]
|
| 228 |
+
):
|
| 229 |
+
smem_layout_atom_o = _get_cpasync_smem_layout_atom(
|
| 230 |
+
self.dtype_partial, self.k_block_size
|
| 231 |
+
)
|
| 232 |
+
self.smem_layout_o = cute.tile_to_shape(
|
| 233 |
+
smem_layout_atom_o,
|
| 234 |
+
(self.tile_m, self.k_block_size, self.stages),
|
| 235 |
+
(0, 1, 2),
|
| 236 |
+
)
|
| 237 |
+
else:
|
| 238 |
+
self.smem_layout_o = cute.make_ordered_layout(
|
| 239 |
+
(self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2)
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
@cute.jit
|
| 243 |
+
def __call__(
|
| 244 |
+
self,
|
| 245 |
+
mO_partial: cute.Tensor,
|
| 246 |
+
mLSE_partial: cute.Tensor,
|
| 247 |
+
mO: cute.Tensor,
|
| 248 |
+
mLSE: Optional[cute.Tensor] = None,
|
| 249 |
+
mLSE_temperature_partial: Optional[cute.Tensor] = None,
|
| 250 |
+
mLSE_temperature: Optional[cute.Tensor] = None,
|
| 251 |
+
cu_seqlens: Optional[cute.Tensor] = None,
|
| 252 |
+
seqused: Optional[cute.Tensor] = None,
|
| 253 |
+
num_splits_dynamic_ptr: Optional[cute.Tensor] = None,
|
| 254 |
+
varlen_batch_idx: Optional[cute.Tensor] = None,
|
| 255 |
+
semaphore_to_reset: Optional[cute.Tensor] = None,
|
| 256 |
+
mSplitCounts: Optional[cute.Tensor] = None,
|
| 257 |
+
mOutputScale: Optional[cute.Tensor] = None,
|
| 258 |
+
qhead_per_kvhead: Int32 = Int32(1),
|
| 259 |
+
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
|
| 260 |
+
stream: cuda.CUstream = None,
|
| 261 |
+
):
|
| 262 |
+
# Type checking
|
| 263 |
+
if const_expr(not (mO_partial.element_type == self.dtype_partial)):
|
| 264 |
+
raise TypeError("O partial tensor must match dtype_partial")
|
| 265 |
+
if const_expr(not (mO.element_type == self.dtype)):
|
| 266 |
+
raise TypeError("O tensor must match dtype")
|
| 267 |
+
if const_expr(mLSE_partial.element_type not in [Float32]):
|
| 268 |
+
raise TypeError("LSE partial tensor must be Float32")
|
| 269 |
+
if const_expr(mLSE is not None and mLSE.element_type not in [Float32]):
|
| 270 |
+
raise TypeError("LSE tensor must be Float32")
|
| 271 |
+
if const_expr(
|
| 272 |
+
mLSE_temperature_partial is not None
|
| 273 |
+
and mLSE_temperature_partial.element_type not in [Float32]
|
| 274 |
+
):
|
| 275 |
+
raise TypeError("temperature LSE partial tensor must be Float32")
|
| 276 |
+
if const_expr(mLSE_temperature is not None and mLSE_temperature.element_type not in [Float32]):
|
| 277 |
+
raise TypeError("temperature LSE tensor must be Float32")
|
| 278 |
+
if const_expr((mLSE_temperature_partial is None) != (mLSE_temperature is None)):
|
| 279 |
+
raise ValueError(
|
| 280 |
+
"temperature LSE partial and output tensors must either both be provided or both be None"
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
# Shape validation - input tensors are in user format, need to be converted to kernel format
|
| 284 |
+
if const_expr(len(mO_partial.shape) not in [4, 5]):
|
| 285 |
+
raise ValueError(
|
| 286 |
+
"O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)"
|
| 287 |
+
)
|
| 288 |
+
if const_expr(len(mLSE_partial.shape) not in [3, 4]):
|
| 289 |
+
raise ValueError(
|
| 290 |
+
"LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)"
|
| 291 |
+
)
|
| 292 |
+
if const_expr(len(mO.shape) not in [3, 4]):
|
| 293 |
+
raise ValueError(
|
| 294 |
+
"O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)"
|
| 295 |
+
)
|
| 296 |
+
if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]):
|
| 297 |
+
raise ValueError(
|
| 298 |
+
"LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)"
|
| 299 |
+
)
|
| 300 |
+
if const_expr(mLSE_temperature_partial is not None and len(mLSE_temperature_partial.shape) not in [3, 4]):
|
| 301 |
+
raise ValueError(
|
| 302 |
+
"temperature LSE partial tensor must have 3 or 4 dimensions: "
|
| 303 |
+
"(num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)"
|
| 304 |
+
)
|
| 305 |
+
if const_expr(mLSE_temperature is not None and len(mLSE_temperature.shape) not in [2, 3]):
|
| 306 |
+
raise ValueError(
|
| 307 |
+
"temperature LSE tensor must have 2 or 3 dimensions: "
|
| 308 |
+
"(batch, seqlen, nheads) or (total_q, nheads)"
|
| 309 |
+
)
|
| 310 |
+
if const_expr(mSplitCounts is not None):
|
| 311 |
+
if const_expr(mSplitCounts.element_type not in [Int32]):
|
| 312 |
+
raise TypeError("split_counts tensor must be Int32")
|
| 313 |
+
if const_expr(cu_seqlens is not None):
|
| 314 |
+
if const_expr(len(mSplitCounts.shape) != 2):
|
| 315 |
+
raise ValueError("varlen split_counts tensor must have shape (total_q, nheads_kv)")
|
| 316 |
+
elif const_expr(len(mSplitCounts.shape) != 3):
|
| 317 |
+
raise ValueError("batched split_counts tensor must have shape (batch, seqlen, nheads_kv)")
|
| 318 |
+
if const_expr(mOutputScale is not None and mOutputScale.element_type not in [Float32]):
|
| 319 |
+
raise TypeError("output_scale tensor must be Float32")
|
| 320 |
+
|
| 321 |
+
mO_partial, mO = [assume_tensor_aligned(t) for t in (mO_partial, mO)]
|
| 322 |
+
# (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
|
| 323 |
+
# or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h)
|
| 324 |
+
O_partial_layout_transpose = (
|
| 325 |
+
[2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
|
| 326 |
+
)
|
| 327 |
+
# (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h)
|
| 328 |
+
mO_partial = cute.make_tensor(
|
| 329 |
+
mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)
|
| 330 |
+
)
|
| 331 |
+
O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1]
|
| 332 |
+
mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))
|
| 333 |
+
# (num_splits, b, h, seqlen) -> (seqlen, num_splits, h, b)
|
| 334 |
+
# Input is pre-transposed: [topK, B, Hq, Sq] with Sq innermost for K2-friendly reads.
|
| 335 |
+
# or (num_splits, total_q, h) -> (total_q, num_splits, h)
|
| 336 |
+
LSE_partial_layout_transpose = [3, 0, 2, 1] if const_expr(cu_seqlens is None) else [1, 0, 2]
|
| 337 |
+
mLSE_partial = cute.make_tensor(
|
| 338 |
+
mLSE_partial.iterator,
|
| 339 |
+
cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose),
|
| 340 |
+
)
|
| 341 |
+
# (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h)
|
| 342 |
+
LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1]
|
| 343 |
+
mLSE = (
|
| 344 |
+
cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
|
| 345 |
+
if mLSE is not None
|
| 346 |
+
else None
|
| 347 |
+
)
|
| 348 |
+
mLSE_temperature_partial = (
|
| 349 |
+
cute.make_tensor(
|
| 350 |
+
mLSE_temperature_partial.iterator,
|
| 351 |
+
cute.select(mLSE_temperature_partial.layout, mode=LSE_partial_layout_transpose),
|
| 352 |
+
)
|
| 353 |
+
if mLSE_temperature_partial is not None
|
| 354 |
+
else None
|
| 355 |
+
)
|
| 356 |
+
mLSE_temperature = (
|
| 357 |
+
cute.make_tensor(
|
| 358 |
+
mLSE_temperature.iterator,
|
| 359 |
+
cute.select(mLSE_temperature.layout, mode=LSE_layout_transpose),
|
| 360 |
+
)
|
| 361 |
+
if mLSE_temperature is not None
|
| 362 |
+
else None
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
# Determine if we have variable length sequences
|
| 366 |
+
varlen = const_expr(cu_seqlens is not None or seqused is not None)
|
| 367 |
+
|
| 368 |
+
self._setup_attributes()
|
| 369 |
+
|
| 370 |
+
# Output-dtype permutation buffer for Step 7 (tile_m × k_block_size).
|
| 371 |
+
# Accumulation stays fp32; the final dtype conversion happens before
|
| 372 |
+
# the fake→real SMEM scatter to reduce half-output SMEM pressure.
|
| 373 |
+
if const_expr(self.dtype in [cutlass.Float16, cutlass.BFloat16]):
|
| 374 |
+
smem_layout_perm = cute.make_layout(
|
| 375 |
+
(self.tile_m, self.k_block_size),
|
| 376 |
+
stride=(self.k_block_size + 16, 1),
|
| 377 |
+
)
|
| 378 |
+
else:
|
| 379 |
+
smem_layout_perm = cute.make_ordered_layout(
|
| 380 |
+
(self.tile_m, self.k_block_size), order=(1, 0)
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
@cute.struct
|
| 384 |
+
class SharedStorage:
|
| 385 |
+
sLSE: cute.struct.Align[
|
| 386 |
+
cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
|
| 387 |
+
]
|
| 388 |
+
sLSETemperature: cute.struct.Align[
|
| 389 |
+
cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
|
| 390 |
+
]
|
| 391 |
+
sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.tile_m], 128]
|
| 392 |
+
sO: cute.struct.Align[
|
| 393 |
+
cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
|
| 394 |
+
]
|
| 395 |
+
sO_perm: cute.struct.Align[
|
| 396 |
+
cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_perm)], 128
|
| 397 |
+
]
|
| 398 |
+
|
| 399 |
+
smem_size = SharedStorage.size_in_bytes()
|
| 400 |
+
|
| 401 |
+
# Grid: (ceil(seqlen/tile_m), ceil(dim/k_block), num_head * batch)
|
| 402 |
+
# Head separated from seqlen → enables future TMA (contiguous Sq tiles)
|
| 403 |
+
seqlen = mO_partial.shape[0]
|
| 404 |
+
num_head = mO_partial.shape[3]
|
| 405 |
+
batch_size = (
|
| 406 |
+
mO_partial.shape[4]
|
| 407 |
+
if const_expr(cu_seqlens is None)
|
| 408 |
+
else Int32(cu_seqlens.shape[0] - 1)
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
seqlen_divmod = FastDivmodDivisor(seqlen)
|
| 412 |
+
head_divmod = FastDivmodDivisor(num_head)
|
| 413 |
+
|
| 414 |
+
grid_dim = (
|
| 415 |
+
cute.ceil_div(seqlen * num_head, self.tile_m),
|
| 416 |
+
cute.ceil_div(self.head_dim, self.k_block_size),
|
| 417 |
+
batch_size,
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
self.kernel(
|
| 421 |
+
mO_partial,
|
| 422 |
+
mLSE_partial,
|
| 423 |
+
mO,
|
| 424 |
+
mLSE,
|
| 425 |
+
mLSE_temperature_partial,
|
| 426 |
+
mLSE_temperature,
|
| 427 |
+
cu_seqlens,
|
| 428 |
+
seqused,
|
| 429 |
+
num_splits_dynamic_ptr,
|
| 430 |
+
varlen_batch_idx,
|
| 431 |
+
semaphore_to_reset,
|
| 432 |
+
mSplitCounts,
|
| 433 |
+
mOutputScale,
|
| 434 |
+
qhead_per_kvhead,
|
| 435 |
+
SharedStorage,
|
| 436 |
+
self.smem_layout_lse,
|
| 437 |
+
self.smem_layout_o,
|
| 438 |
+
smem_layout_perm,
|
| 439 |
+
self.gmem_tiled_copy_O_partial,
|
| 440 |
+
self.gmem_tiled_copy_O,
|
| 441 |
+
self.gmem_tiled_copy_LSE,
|
| 442 |
+
self.s2r_tiled_copy_LSE,
|
| 443 |
+
seqlen_divmod,
|
| 444 |
+
head_divmod,
|
| 445 |
+
self.use_pdl,
|
| 446 |
+
varlen,
|
| 447 |
+
).launch(
|
| 448 |
+
grid=grid_dim,
|
| 449 |
+
block=[self.num_threads, 1, 1],
|
| 450 |
+
smem=smem_size,
|
| 451 |
+
stream=stream,
|
| 452 |
+
min_blocks_per_mp=self.min_blocks_per_mp,
|
| 453 |
+
use_pdl=self.use_pdl,
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
@cute.jit
|
| 457 |
+
def decode_flat_row_idx(
|
| 458 |
+
self,
|
| 459 |
+
idx: Int32,
|
| 460 |
+
head_divmod: FastDivmodDivisor,
|
| 461 |
+
):
|
| 462 |
+
"""Decode flattened tile rows under the H_q-innermost contract."""
|
| 463 |
+
q_idx_local, head_idx = divmod(idx, head_divmod)
|
| 464 |
+
return q_idx_local, head_idx
|
| 465 |
+
|
| 466 |
+
@cute.kernel
|
| 467 |
+
def kernel(
|
| 468 |
+
self,
|
| 469 |
+
mO_partial: cute.Tensor,
|
| 470 |
+
mLSE_partial: cute.Tensor,
|
| 471 |
+
mO: cute.Tensor,
|
| 472 |
+
mLSE: Optional[cute.Tensor],
|
| 473 |
+
mLSE_temperature_partial: Optional[cute.Tensor],
|
| 474 |
+
mLSE_temperature: Optional[cute.Tensor],
|
| 475 |
+
cu_seqlens: Optional[cute.Tensor],
|
| 476 |
+
seqused: Optional[cute.Tensor],
|
| 477 |
+
num_splits_dynamic_ptr: Optional[cute.Tensor],
|
| 478 |
+
varlen_batch_idx: Optional[cute.Tensor],
|
| 479 |
+
semaphore_to_reset: Optional[cute.Tensor],
|
| 480 |
+
mSplitCounts: Optional[cute.Tensor],
|
| 481 |
+
mOutputScale: Optional[cute.Tensor],
|
| 482 |
+
qhead_per_kvhead: Int32,
|
| 483 |
+
SharedStorage: cutlass.Constexpr,
|
| 484 |
+
smem_layout_lse: cute.Layout | cute.ComposedLayout,
|
| 485 |
+
smem_layout_o: cute.Layout | cute.ComposedLayout,
|
| 486 |
+
smem_layout_perm: cute.Layout,
|
| 487 |
+
gmem_tiled_copy_O_partial: cute.TiledCopy,
|
| 488 |
+
gmem_tiled_copy_O: cute.TiledCopy,
|
| 489 |
+
gmem_tiled_copy_LSE: cute.TiledCopy,
|
| 490 |
+
s2r_tiled_copy_LSE: cute.TiledCopy,
|
| 491 |
+
seqlen_divmod: FastDivmodDivisor,
|
| 492 |
+
head_divmod: FastDivmodDivisor,
|
| 493 |
+
use_pdl: cutlass.Constexpr[bool],
|
| 494 |
+
varlen: cutlass.Constexpr[bool],
|
| 495 |
+
):
|
| 496 |
+
# Thread and block indices
|
| 497 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 498 |
+
m_block, k_block, maybe_virtual_batch = cute.arch.block_idx()
|
| 499 |
+
|
| 500 |
+
batch_idx = (
|
| 501 |
+
varlen_batch_idx[maybe_virtual_batch]
|
| 502 |
+
if const_expr(varlen_batch_idx is not None)
|
| 503 |
+
else maybe_virtual_batch
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 507 |
+
# Get shared memory buffer
|
| 508 |
+
# ///////////////////////////////////////////////////////////////////////////////
|
| 509 |
+
smem = cutlass.utils.SmemAllocator()
|
| 510 |
+
storage = smem.allocate(SharedStorage)
|
| 511 |
+
sLSE = storage.sLSE.get_tensor(smem_layout_lse)
|
| 512 |
+
sLSE_temperature = storage.sLSETemperature.get_tensor(smem_layout_lse)
|
| 513 |
+
sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,))
|
| 514 |
+
sO = storage.sO.get_tensor(smem_layout_o)
|
| 515 |
+
sO_perm_buf = storage.sO_perm.get_tensor(smem_layout_perm)
|
| 516 |
+
|
| 517 |
+
# Handle semaphore reset — wait for dependent grids first
|
| 518 |
+
if const_expr(use_pdl and semaphore_to_reset is not None):
|
| 519 |
+
if (
|
| 520 |
+
tidx == 0
|
| 521 |
+
and m_block == cute.arch.grid_dim()[0] - 1
|
| 522 |
+
and k_block == cute.arch.grid_dim()[1] - 1
|
| 523 |
+
and maybe_virtual_batch == cute.arch.grid_dim()[2] - 1
|
| 524 |
+
):
|
| 525 |
+
cute.arch.griddepcontrol_wait()
|
| 526 |
+
semaphore_to_reset[0] = 0
|
| 527 |
+
|
| 528 |
+
if const_expr(num_splits_dynamic_ptr is not None):
|
| 529 |
+
raise ValueError("K2 combine requires compile-time exact topK")
|
| 530 |
+
num_splits = Int32(self.topk)
|
| 531 |
+
# Handle variable length sequences using SeqlenInfo
|
| 532 |
+
seqlen_info = SeqlenInfo.create(
|
| 533 |
+
batch_idx=batch_idx,
|
| 534 |
+
seqlen_static=mO_partial.shape[0],
|
| 535 |
+
cu_seqlens=cu_seqlens,
|
| 536 |
+
seqused=seqused,
|
| 537 |
+
# Don't need to pass in tile size since we won't use offset_padded
|
| 538 |
+
)
|
| 539 |
+
seqlen, offset = seqlen_info.seqlen, seqlen_info.offset
|
| 540 |
+
|
| 541 |
+
num_head = mO_partial.shape[3]
|
| 542 |
+
max_idx = seqlen * num_head
|
| 543 |
+
output_scale = Float32(1.0)
|
| 544 |
+
if const_expr(mOutputScale is not None):
|
| 545 |
+
output_scale = mOutputScale[0]
|
| 546 |
+
|
| 547 |
+
if const_expr(not varlen) or m_block * self.tile_m < max_idx:
|
| 548 |
+
# Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial)
|
| 549 |
+
if const_expr(use_pdl):
|
| 550 |
+
cute.arch.griddepcontrol_wait()
|
| 551 |
+
|
| 552 |
+
# ===============================
|
| 553 |
+
# Step 1: Load LSE_partial from gmem to shared memory
|
| 554 |
+
# ===============================
|
| 555 |
+
# `cLSE` (identity tensor for row/split coord tracking) is reused
|
| 556 |
+
# later in steps 4-5, so it must be defined on both branches.
|
| 557 |
+
cLSE = cute.make_identity_tensor((self.topk, self.tile_m))
|
| 558 |
+
# Reshape mLSE_partial to PackGQA packed layout and delegate the
|
| 559 |
+
# tile load to PackGQAComb.load_LSE. The packed form folds (H_q, Sq)
|
| 560 |
+
# into one compound dim with H_q innermost (stride 1), so thread
|
| 561 |
+
# rows that vary along h_pos produce one-sector coalesced reads.
|
| 562 |
+
# Non-varlen path only — varlen keeps the original inline loop.
|
| 563 |
+
if const_expr(not varlen):
|
| 564 |
+
mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3)
|
| 565 |
+
# mLSE_partial_cur: (H_q, topK, Sq) — after initial transpose
|
| 566 |
+
# [3,0,2,1] on [topK,B,Sq,H_q] and dropping B.
|
| 567 |
+
# Reorder to (H_q, Sq, topK) then group modes 0..1 for packed dim:
|
| 568 |
+
mLSE_partial_reord = cute.make_tensor(
|
| 569 |
+
mLSE_partial_cur.iterator,
|
| 570 |
+
cute.select(mLSE_partial_cur.layout, mode=[0, 2, 1]),
|
| 571 |
+
)
|
| 572 |
+
mLSE_partial_packed = cute.group_modes(mLSE_partial_reord, 0, 2)
|
| 573 |
+
# shape ((H_q, Sq), topK) with H_q innermost.
|
| 574 |
+
packgqa = PackGQAComb(
|
| 575 |
+
m_block_size=self.tile_m,
|
| 576 |
+
head_dim_padded=0, # unused for LSE load
|
| 577 |
+
check_hdim_oob=False, # unused for LSE load
|
| 578 |
+
qhead_per_kvhead=1, # unused; num_heads_divmod is passed explicitly
|
| 579 |
+
)
|
| 580 |
+
packgqa.load_LSE(
|
| 581 |
+
mLSE_partial_packed,
|
| 582 |
+
sLSE,
|
| 583 |
+
self.topk,
|
| 584 |
+
gmem_tiled_copy_LSE,
|
| 585 |
+
tidx,
|
| 586 |
+
m_block,
|
| 587 |
+
num_splits,
|
| 588 |
+
seqlen,
|
| 589 |
+
head_divmod,
|
| 590 |
+
mSplitCounts,
|
| 591 |
+
batch_idx,
|
| 592 |
+
qhead_per_kvhead,
|
| 593 |
+
)
|
| 594 |
+
if const_expr(mLSE_temperature_partial is not None):
|
| 595 |
+
mLSE_temperature_partial_cur = seqlen_info.offset_batch(
|
| 596 |
+
mLSE_temperature_partial, batch_idx, dim=3)
|
| 597 |
+
mLSE_temperature_partial_reord = cute.make_tensor(
|
| 598 |
+
mLSE_temperature_partial_cur.iterator,
|
| 599 |
+
cute.select(mLSE_temperature_partial_cur.layout, mode=[0, 2, 1]),
|
| 600 |
+
)
|
| 601 |
+
mLSE_temperature_partial_packed = cute.group_modes(
|
| 602 |
+
mLSE_temperature_partial_reord, 0, 2)
|
| 603 |
+
packgqa.load_LSE(
|
| 604 |
+
mLSE_temperature_partial_packed,
|
| 605 |
+
sLSE_temperature,
|
| 606 |
+
self.topk,
|
| 607 |
+
gmem_tiled_copy_LSE,
|
| 608 |
+
tidx,
|
| 609 |
+
m_block,
|
| 610 |
+
num_splits,
|
| 611 |
+
seqlen,
|
| 612 |
+
head_divmod,
|
| 613 |
+
mSplitCounts,
|
| 614 |
+
batch_idx,
|
| 615 |
+
qhead_per_kvhead,
|
| 616 |
+
)
|
| 617 |
+
else:
|
| 618 |
+
# Varlen path keeps the same H_q-innermost flat-row contract:
|
| 619 |
+
# after transpose [1, 0, 2], mLSE_partial_cur is
|
| 620 |
+
# (q_local, split, head).
|
| 621 |
+
# mSplitCounts is the authoritative valid-split count per
|
| 622 |
+
# packed (q_abs, kv_head); masked splits stay at -inf and
|
| 623 |
+
# therefore drop out of the final kernel LSE_out reduction.
|
| 624 |
+
mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3)
|
| 625 |
+
mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,))
|
| 626 |
+
gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx)
|
| 627 |
+
tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE)
|
| 628 |
+
tLSEsLSE_temperature = gmem_thr_copy_LSE.partition_D(sLSE_temperature)
|
| 629 |
+
tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE)
|
| 630 |
+
if const_expr(mLSE_temperature_partial is not None):
|
| 631 |
+
mLSE_temperature_partial_cur = seqlen_info.offset_batch(
|
| 632 |
+
mLSE_temperature_partial, batch_idx, dim=3)
|
| 633 |
+
mLSE_temperature_partial_copy = cute.tiled_divide(
|
| 634 |
+
mLSE_temperature_partial_cur, (1,))
|
| 635 |
+
|
| 636 |
+
for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
|
| 637 |
+
mi = tLSEcLSE[0, 0, m][1]
|
| 638 |
+
idx = m_block * self.tile_m + mi
|
| 639 |
+
if idx < max_idx:
|
| 640 |
+
m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod)
|
| 641 |
+
row_count = (
|
| 642 |
+
mSplitCounts[offset + m_idx, head_idx // qhead_per_kvhead]
|
| 643 |
+
if const_expr(mSplitCounts is not None)
|
| 644 |
+
else num_splits
|
| 645 |
+
)
|
| 646 |
+
mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx]
|
| 647 |
+
if const_expr(mLSE_temperature_partial is not None):
|
| 648 |
+
mLSE_temperature_partial_cur_copy = (
|
| 649 |
+
mLSE_temperature_partial_copy[None, m_idx, None, head_idx])
|
| 650 |
+
for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
|
| 651 |
+
si = tLSEcLSE[0, s, 0][0]
|
| 652 |
+
if si < num_splits and si < row_count:
|
| 653 |
+
cute.copy(
|
| 654 |
+
gmem_thr_copy_LSE,
|
| 655 |
+
mLSE_partial_cur_copy[None, si],
|
| 656 |
+
tLSEsLSE[None, s, m],
|
| 657 |
+
)
|
| 658 |
+
if const_expr(mLSE_temperature_partial is not None):
|
| 659 |
+
cute.copy(
|
| 660 |
+
gmem_thr_copy_LSE,
|
| 661 |
+
mLSE_temperature_partial_cur_copy[None, si],
|
| 662 |
+
tLSEsLSE_temperature[None, s, m],
|
| 663 |
+
)
|
| 664 |
+
else:
|
| 665 |
+
tLSEsLSE[None, s, m].fill(-Float32.inf)
|
| 666 |
+
if const_expr(mLSE_temperature_partial is not None):
|
| 667 |
+
tLSEsLSE_temperature[None, s, m].fill(-Float32.inf)
|
| 668 |
+
else:
|
| 669 |
+
for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
|
| 670 |
+
tLSEsLSE[None, s, m].fill(-Float32.inf)
|
| 671 |
+
if const_expr(mLSE_temperature_partial is not None):
|
| 672 |
+
tLSEsLSE_temperature[None, s, m].fill(-Float32.inf)
|
| 673 |
+
cute.arch.cp_async_commit_group()
|
| 674 |
+
|
| 675 |
+
# ===============================
|
| 676 |
+
# Step 2: Load O_partial for pipeline stages
|
| 677 |
+
# ===============================
|
| 678 |
+
|
| 679 |
+
gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx)
|
| 680 |
+
cO = cute.make_identity_tensor((self.tile_m, self.k_block_size))
|
| 681 |
+
tOcO = gmem_thr_copy_O_partial.partition_D(cO)
|
| 682 |
+
tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO)
|
| 683 |
+
mO_partial_cur = seqlen_info.offset_batch(mO_partial, batch_idx, dim=4)
|
| 684 |
+
|
| 685 |
+
# Precompute per-row values for flattened (q_local, head) tiles.
|
| 686 |
+
num_rows = const_expr(cute.size(tOcO, mode=[1]))
|
| 687 |
+
tOmidx = cute.make_rmem_tensor(num_rows, cutlass.Int32)
|
| 688 |
+
tOhidx = cute.make_rmem_tensor(num_rows, cutlass.Int32)
|
| 689 |
+
tOSplitCount = cute.make_rmem_tensor(num_rows, cutlass.Int32)
|
| 690 |
+
tOrOptr = cute.make_rmem_tensor(num_rows, cutlass.Int64)
|
| 691 |
+
for m in cutlass.range(num_rows, unroll_full=True):
|
| 692 |
+
mi = tOcO[0, m, 0][0] # m coordinate in tile
|
| 693 |
+
idx = m_block * self.tile_m + mi
|
| 694 |
+
if idx >= max_idx:
|
| 695 |
+
tOhidx[m] = -1
|
| 696 |
+
tOmidx[m] = 0
|
| 697 |
+
tOSplitCount[m] = 0
|
| 698 |
+
tOrOptr[m] = cutlass.Int64(0)
|
| 699 |
+
else:
|
| 700 |
+
tOmidx[m], tOhidx[m] = self.decode_flat_row_idx(idx, head_divmod)
|
| 701 |
+
if const_expr(mSplitCounts is None):
|
| 702 |
+
tOSplitCount[m] = num_splits
|
| 703 |
+
elif const_expr(cu_seqlens is None):
|
| 704 |
+
tOSplitCount[m] = mSplitCounts[
|
| 705 |
+
batch_idx, tOmidx[m], tOhidx[m] // qhead_per_kvhead
|
| 706 |
+
]
|
| 707 |
+
else:
|
| 708 |
+
tOSplitCount[m] = mSplitCounts[
|
| 709 |
+
offset + tOmidx[m], tOhidx[m] // qhead_per_kvhead
|
| 710 |
+
]
|
| 711 |
+
tOrOptr[m] = utils.elem_pointer(
|
| 712 |
+
mO_partial_cur,
|
| 713 |
+
(tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m]),
|
| 714 |
+
).toint()
|
| 715 |
+
|
| 716 |
+
tOpO = None
|
| 717 |
+
if const_expr(not self.is_even_k):
|
| 718 |
+
tOpO = cute.make_rmem_tensor(cute.size(tOcO, mode=[2]), Boolean)
|
| 719 |
+
for k in cutlass.range(cute.size(tOpO), unroll_full=True):
|
| 720 |
+
tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size
|
| 721 |
+
# if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO)
|
| 722 |
+
|
| 723 |
+
load_O_partial = partial(
|
| 724 |
+
self.load_O_partial,
|
| 725 |
+
gmem_tiled_copy_O_partial,
|
| 726 |
+
tOrOptr,
|
| 727 |
+
tOsO_partial,
|
| 728 |
+
tOhidx,
|
| 729 |
+
tOSplitCount,
|
| 730 |
+
tOpO,
|
| 731 |
+
tOcO,
|
| 732 |
+
mO_partial_cur.layout,
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
# Load first few stages of O_partial
|
| 736 |
+
for stage in cutlass.range(self.stages - 1, unroll_full=True):
|
| 737 |
+
if stage < num_splits:
|
| 738 |
+
load_O_partial(stage, stage)
|
| 739 |
+
cute.arch.cp_async_commit_group()
|
| 740 |
+
|
| 741 |
+
# ===============================
|
| 742 |
+
# Step 3: Load and transpose LSE from smem to registers
|
| 743 |
+
# ===============================
|
| 744 |
+
|
| 745 |
+
# Wait for LSE and initial O partial stages to complete
|
| 746 |
+
cute.arch.cp_async_wait_group(self.stages - 1)
|
| 747 |
+
cute.arch.sync_threads()
|
| 748 |
+
# if cute.arch.thread_idx()[0] == 0:
|
| 749 |
+
# # cute.print_tensor(sLSE)
|
| 750 |
+
# for i in range(64):
|
| 751 |
+
# cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0])
|
| 752 |
+
# cute.arch.sync_threads()
|
| 753 |
+
|
| 754 |
+
s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx)
|
| 755 |
+
ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE)
|
| 756 |
+
ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE)
|
| 757 |
+
cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE)
|
| 758 |
+
if const_expr(mLSE_temperature_partial is not None):
|
| 759 |
+
ts2rsLSE_temperature = s2r_thr_copy_LSE.partition_S(sLSE_temperature)
|
| 760 |
+
ts2rrLSE_temperature = cute.make_rmem_tensor_like(ts2rsLSE_temperature)
|
| 761 |
+
cute.copy(
|
| 762 |
+
s2r_tiled_copy_LSE,
|
| 763 |
+
ts2rsLSE_temperature,
|
| 764 |
+
ts2rrLSE_temperature,
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
# ===============================
|
| 768 |
+
# Step 4: Compute final LSE along split dimension
|
| 769 |
+
# ===============================
|
| 770 |
+
|
| 771 |
+
final_lse = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32)
|
| 772 |
+
ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE)
|
| 773 |
+
# We compute the max valid split for each row to short-circuit the computation later
|
| 774 |
+
max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32)
|
| 775 |
+
assert cute.size(ts2rrLSE, mode=[0]) == 1
|
| 776 |
+
# Compute max, scales, and final LSE for each row. Invalid splits
|
| 777 |
+
# have already been filled with -inf, so Step 5 can write the
|
| 778 |
+
# kernel-native LSE_out directly.
|
| 779 |
+
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
| 780 |
+
# Find max LSE value across splits
|
| 781 |
+
threads_per_col = const_expr(self.smem_threads_per_col_lse)
|
| 782 |
+
lse_max = cute.arch.warp_reduction_max(
|
| 783 |
+
ts2rrLSE[None, None, m]
|
| 784 |
+
.load()
|
| 785 |
+
.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
|
| 786 |
+
threads_in_group=threads_per_col,
|
| 787 |
+
)
|
| 788 |
+
# if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max)
|
| 789 |
+
# Find max valid split index
|
| 790 |
+
max_valid_idx = -1
|
| 791 |
+
for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
|
| 792 |
+
if ts2rrLSE[0, s, m] != -Float32.inf:
|
| 793 |
+
max_valid_idx = ts2rcLSE[0, s, 0][0] # Get split coordinate
|
| 794 |
+
# if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx)
|
| 795 |
+
max_valid_split[m] = cute.arch.warp_reduction_max(
|
| 796 |
+
max_valid_idx, threads_in_group=threads_per_col
|
| 797 |
+
)
|
| 798 |
+
# Compute exp scales and sum
|
| 799 |
+
lse_max_cur = (
|
| 800 |
+
0.0 if lse_max == -Float32.inf else lse_max
|
| 801 |
+
) # In case all local LSEs are -inf
|
| 802 |
+
LOG2_E = math.log2(math.e)
|
| 803 |
+
lse_sum_cur = 0.0
|
| 804 |
+
for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
|
| 805 |
+
scale = cute.math.exp2(
|
| 806 |
+
ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E), fastmath=True
|
| 807 |
+
)
|
| 808 |
+
lse_sum_cur += scale
|
| 809 |
+
ts2rrLSE[0, s, m] = scale # Store scale for later use
|
| 810 |
+
lse_sum_cur = cute.arch.warp_reduction_sum(
|
| 811 |
+
lse_sum_cur, threads_in_group=threads_per_col
|
| 812 |
+
)
|
| 813 |
+
# Normalize scales
|
| 814 |
+
inv_sum = 0.0
|
| 815 |
+
if max_valid_split[m] < 0 or lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur:
|
| 816 |
+
final_lse[m] = -Float32.inf
|
| 817 |
+
else:
|
| 818 |
+
final_lse[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max
|
| 819 |
+
inv_sum = 1.0 / lse_sum_cur
|
| 820 |
+
ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)
|
| 821 |
+
# Store the scales exp(lse - lse_logsum) back to smem
|
| 822 |
+
cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)
|
| 823 |
+
|
| 824 |
+
if const_expr(mLSE_temperature_partial is not None):
|
| 825 |
+
final_lse_temperature = cute.make_rmem_tensor(
|
| 826 |
+
cute.size(ts2rrLSE_temperature, mode=[2]), Float32)
|
| 827 |
+
for m in cutlass.range(cute.size(ts2rrLSE_temperature, mode=[2]), unroll_full=True):
|
| 828 |
+
threads_per_col = const_expr(self.smem_threads_per_col_lse)
|
| 829 |
+
lse_temperature_max = cute.arch.warp_reduction_max(
|
| 830 |
+
ts2rrLSE_temperature[None, None, m]
|
| 831 |
+
.load()
|
| 832 |
+
.reduce(
|
| 833 |
+
cute.ReductionOp.MAX,
|
| 834 |
+
init_val=-Float32.inf,
|
| 835 |
+
reduction_profile=0,
|
| 836 |
+
),
|
| 837 |
+
threads_in_group=threads_per_col,
|
| 838 |
+
)
|
| 839 |
+
lse_temperature_max_cur = (
|
| 840 |
+
0.0 if lse_temperature_max == -Float32.inf else lse_temperature_max
|
| 841 |
+
)
|
| 842 |
+
LOG2_E = math.log2(math.e)
|
| 843 |
+
lse_temperature_sum_cur = 0.0
|
| 844 |
+
for s in cutlass.range(
|
| 845 |
+
cute.size(ts2rrLSE_temperature, mode=[1]), unroll_full=True):
|
| 846 |
+
scale = cute.math.exp2(
|
| 847 |
+
ts2rrLSE_temperature[0, s, m] * LOG2_E
|
| 848 |
+
- (lse_temperature_max_cur * LOG2_E),
|
| 849 |
+
fastmath=True,
|
| 850 |
+
)
|
| 851 |
+
lse_temperature_sum_cur += scale
|
| 852 |
+
lse_temperature_sum_cur = cute.arch.warp_reduction_sum(
|
| 853 |
+
lse_temperature_sum_cur, threads_in_group=threads_per_col
|
| 854 |
+
)
|
| 855 |
+
if (
|
| 856 |
+
max_valid_split[m] < 0
|
| 857 |
+
or lse_temperature_sum_cur == 0.0
|
| 858 |
+
or lse_temperature_sum_cur != lse_temperature_sum_cur
|
| 859 |
+
):
|
| 860 |
+
final_lse_temperature[m] = -Float32.inf
|
| 861 |
+
else:
|
| 862 |
+
final_lse_temperature[m] = (
|
| 863 |
+
cute.math.log(lse_temperature_sum_cur, fastmath=True)
|
| 864 |
+
+ lse_temperature_max
|
| 865 |
+
)
|
| 866 |
+
|
| 867 |
+
# Store max valid split to smem
|
| 868 |
+
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
| 869 |
+
if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
|
| 870 |
+
mi = ts2rcLSE[0, 0, m][1]
|
| 871 |
+
if mi < self.tile_m:
|
| 872 |
+
sMaxValidSplit[mi] = max_valid_split[m]
|
| 873 |
+
|
| 874 |
+
# ===============================
|
| 875 |
+
# Step 5: Store final LSE to gmem
|
| 876 |
+
# This writeback is the authoritative LSE_out returned by the
|
| 877 |
+
# public Sparse Attention / Sparse Page Attention interface.
|
| 878 |
+
# ===============================
|
| 879 |
+
|
| 880 |
+
if const_expr(mLSE is not None):
|
| 881 |
+
if const_expr(cu_seqlens is None):
|
| 882 |
+
mLSE_cur = mLSE[None, None, batch_idx]
|
| 883 |
+
else:
|
| 884 |
+
mLSE_cur = cute.domain_offset((offset, 0), mLSE)
|
| 885 |
+
if const_expr(mLSE_temperature is not None):
|
| 886 |
+
if const_expr(cu_seqlens is None):
|
| 887 |
+
mLSE_temperature_cur = mLSE_temperature[None, None, batch_idx]
|
| 888 |
+
else:
|
| 889 |
+
mLSE_temperature_cur = cute.domain_offset(
|
| 890 |
+
(offset, 0), mLSE_temperature)
|
| 891 |
+
if k_block == 0: # Only first k_block writes LSE when mLSE is provided
|
| 892 |
+
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
| 893 |
+
if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
|
| 894 |
+
mi = ts2rcLSE[0, 0, m][1]
|
| 895 |
+
idx = m_block * self.tile_m + mi
|
| 896 |
+
if idx < max_idx:
|
| 897 |
+
m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod)
|
| 898 |
+
mLSE_cur[m_idx, head_idx] = final_lse[m]
|
| 899 |
+
if const_expr(mLSE_temperature is not None):
|
| 900 |
+
mLSE_temperature_cur[m_idx, head_idx] = (
|
| 901 |
+
final_lse_temperature[m])
|
| 902 |
+
|
| 903 |
+
# ===============================
|
| 904 |
+
# Step 6: Read O_partial and accumulate final O
|
| 905 |
+
# ===============================
|
| 906 |
+
|
| 907 |
+
cute.arch.sync_threads()
|
| 908 |
+
|
| 909 |
+
# Get max valid split for this thread
|
| 910 |
+
thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]]
|
| 911 |
+
for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True):
|
| 912 |
+
thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]])
|
| 913 |
+
|
| 914 |
+
tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0])
|
| 915 |
+
tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32)
|
| 916 |
+
tOrO.fill(0.0)
|
| 917 |
+
|
| 918 |
+
stage_load = self.stages - 1
|
| 919 |
+
stage_compute = 0
|
| 920 |
+
|
| 921 |
+
# Main accumulation loop
|
| 922 |
+
for s in cutlass.range(thr_max_valid_split + 1, unroll=4):
|
| 923 |
+
# Get scales for this split
|
| 924 |
+
scale = cute.make_rmem_tensor(num_rows, Float32)
|
| 925 |
+
for m in cutlass.range(num_rows, unroll_full=True):
|
| 926 |
+
scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem
|
| 927 |
+
|
| 928 |
+
# Load next stage if needed
|
| 929 |
+
split_to_load = s + self.stages - 1
|
| 930 |
+
if split_to_load <= thr_max_valid_split:
|
| 931 |
+
load_O_partial(split_to_load, stage_load)
|
| 932 |
+
cute.arch.cp_async_commit_group()
|
| 933 |
+
stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1
|
| 934 |
+
|
| 935 |
+
# Wait for the current stage to be ready
|
| 936 |
+
cute.arch.cp_async_wait_group(self.stages - 1)
|
| 937 |
+
# We don't need __syncthreads() because each thread is just reading its own data from smem
|
| 938 |
+
# Copy from smem to registers
|
| 939 |
+
cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial)
|
| 940 |
+
stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1
|
| 941 |
+
|
| 942 |
+
# Accumulate scaled partial results
|
| 943 |
+
for m in cutlass.range(num_rows, unroll_full=True):
|
| 944 |
+
if tOhidx[m] >= 0 and scale[m] > 0.0:
|
| 945 |
+
tOrO[None, m, None].store(
|
| 946 |
+
tOrO[None, m, None].load()
|
| 947 |
+
+ scale[m] * tOrO_partial[None, m, None].load().to(Float32)
|
| 948 |
+
)
|
| 949 |
+
|
| 950 |
+
# Flush any outstanding async-copy groups before the local Step-7
|
| 951 |
+
# permutation buffer is read on the tail of the kernel.
|
| 952 |
+
cute.arch.cp_async_wait_group(0)
|
| 953 |
+
cute.arch.sync_threads()
|
| 954 |
+
|
| 955 |
+
# ===============================
|
| 956 |
+
# Step 7: Write final O to gmem (fake→real via SMEM)
|
| 957 |
+
# ===============================
|
| 958 |
+
|
| 959 |
+
mO_cur = seqlen_info.offset_batch(mO, batch_idx, dim=3)
|
| 960 |
+
if const_expr(cu_seqlens is None):
|
| 961 |
+
mO_cur = mO[None, None, None, batch_idx]
|
| 962 |
+
else:
|
| 963 |
+
mO_cur = cute.domain_offset((offset, 0, 0), mO)
|
| 964 |
+
mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur)
|
| 965 |
+
num_vals = const_expr(cute.size(tOcO, mode=[0]))
|
| 966 |
+
if const_expr(not use_pdl):
|
| 967 |
+
# Direct / standalone calls don't participate in the K1->K2
|
| 968 |
+
# dependency chain. Use a simple per-element real-column store
|
| 969 |
+
# path here to keep mixed-shape launches stable.
|
| 970 |
+
for m in cutlass.range(num_rows, unroll_full=True):
|
| 971 |
+
if tOhidx[m] >= 0:
|
| 972 |
+
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
|
| 973 |
+
if const_expr(self.is_even_k) or tOpO[k]:
|
| 974 |
+
for v in cutlass.range(num_vals, unroll_full=True):
|
| 975 |
+
fake_col = tOcO[v, 0, k][1]
|
| 976 |
+
if const_expr(self.use_stg128_fp8_layout):
|
| 977 |
+
real_col = stg128_fp8_fake_col_to_real_col(fake_col)
|
| 978 |
+
elif const_expr(self.use_stg128_half_layout):
|
| 979 |
+
real_col = stg128_half_fake_col_to_real_col(fake_col)
|
| 980 |
+
else:
|
| 981 |
+
real_col = stg128_fake_col_to_real_col(fake_col)
|
| 982 |
+
o_val = tOrO[v, m, k]
|
| 983 |
+
if const_expr(mOutputScale is not None):
|
| 984 |
+
o_val = o_val * output_scale
|
| 985 |
+
mO_cur[tOmidx[m], real_col, tOhidx[m]] = o_val.to(self.dtype)
|
| 986 |
+
else:
|
| 987 |
+
# 7a: fp32 accumulator -> output dtype SMEM with fake→real
|
| 988 |
+
# permutation. The dedicated permutation buffer stays separate
|
| 989 |
+
# from the O_partial pipeline staging buffer.
|
| 990 |
+
sO_perm = sO_perm_buf
|
| 991 |
+
|
| 992 |
+
if const_expr(self.dtype in [cutlass.BFloat16, cutlass.Float16]):
|
| 993 |
+
# O_partial uses a dtype-specific STG.128 fake layout, but
|
| 994 |
+
# sO_perm is in the final O dtype. For all supported fake
|
| 995 |
+
# layouts, adjacent fake pairs map to adjacent real columns,
|
| 996 |
+
# so write the final BF16/F16 O pair as one 32-bit SMEM store.
|
| 997 |
+
assert num_vals % 2 == 0
|
| 998 |
+
r2s_o_pair_atom = cute.make_copy_atom(
|
| 999 |
+
cute.nvgpu.CopyUniversalOp(),
|
| 1000 |
+
cutlass.Int32,
|
| 1001 |
+
num_bits_per_copy=32,
|
| 1002 |
+
)
|
| 1003 |
+
rO_pair_word = cute.make_rmem_tensor((1,), cutlass.Int32)
|
| 1004 |
+
sO_perm_i32_base = cute.make_ptr(
|
| 1005 |
+
dtype=cutlass.Int32,
|
| 1006 |
+
value=sO_perm.iterator.toint(),
|
| 1007 |
+
mem_space=sO_perm.iterator.memspace,
|
| 1008 |
+
assumed_align=4,
|
| 1009 |
+
)
|
| 1010 |
+
sO_perm_i32_row_stride = Int32((self.k_block_size + 16) // 2)
|
| 1011 |
+
for m in cutlass.range(num_rows, unroll_full=True):
|
| 1012 |
+
row_local = tOcO[0, m, 0][0]
|
| 1013 |
+
if tOhidx[m] >= 0:
|
| 1014 |
+
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
|
| 1015 |
+
for v_pair in cutlass.range(num_vals // 2, unroll_full=True):
|
| 1016 |
+
v = v_pair * 2
|
| 1017 |
+
fake_col = tOcO[v, 0, k][1]
|
| 1018 |
+
if const_expr(self.use_stg128_fp8_layout):
|
| 1019 |
+
real_col = stg128_fp8_fake_col_to_real_col(fake_col)
|
| 1020 |
+
elif const_expr(self.use_stg128_half_layout):
|
| 1021 |
+
real_col = stg128_half_fake_col_to_real_col(fake_col)
|
| 1022 |
+
else:
|
| 1023 |
+
real_col = stg128_fake_col_to_real_col(fake_col)
|
| 1024 |
+
o0 = tOrO[v, m, k]
|
| 1025 |
+
o1 = tOrO[v + 1, m, k]
|
| 1026 |
+
if const_expr(mOutputScale is not None):
|
| 1027 |
+
o0, o1 = cute.arch.mul_packed_f32x2(
|
| 1028 |
+
(o0, o1),
|
| 1029 |
+
(output_scale, output_scale),
|
| 1030 |
+
)
|
| 1031 |
+
rO_pair_word[0] = utils.cvt_f16x2_f32(o0, o1, self.dtype)
|
| 1032 |
+
smem_pair_ptr = cute.make_ptr(
|
| 1033 |
+
dtype=cutlass.Int32,
|
| 1034 |
+
value=(
|
| 1035 |
+
sO_perm_i32_base.toint()
|
| 1036 |
+
+ Int64(
|
| 1037 |
+
row_local * sO_perm_i32_row_stride
|
| 1038 |
+
+ real_col // Int32(2)
|
| 1039 |
+
)
|
| 1040 |
+
* Int64(4)
|
| 1041 |
+
),
|
| 1042 |
+
mem_space=sO_perm.iterator.memspace,
|
| 1043 |
+
assumed_align=4,
|
| 1044 |
+
)
|
| 1045 |
+
sO_pair = cute.make_tensor(
|
| 1046 |
+
smem_pair_ptr,
|
| 1047 |
+
cute.make_layout((1,), stride=(1,)),
|
| 1048 |
+
)
|
| 1049 |
+
cute.copy(r2s_o_pair_atom, rO_pair_word, sO_pair)
|
| 1050 |
+
else:
|
| 1051 |
+
# 7a: iterate over ALL val elements in mode[0].
|
| 1052 |
+
# tOcO[v, m, k][1] gives different fake_col for each v.
|
| 1053 |
+
r2s_o_scalar_atom = cute.make_copy_atom(
|
| 1054 |
+
cute.nvgpu.CopyUniversalOp(),
|
| 1055 |
+
self.dtype,
|
| 1056 |
+
num_bits_per_copy=self.dtype.width,
|
| 1057 |
+
)
|
| 1058 |
+
rO_scalar = cute.make_rmem_tensor((1,), self.dtype)
|
| 1059 |
+
for m in cutlass.range(num_rows, unroll_full=True):
|
| 1060 |
+
row_local = tOcO[0, m, 0][0]
|
| 1061 |
+
if tOhidx[m] >= 0:
|
| 1062 |
+
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
|
| 1063 |
+
for v in cutlass.range(num_vals, unroll_full=True):
|
| 1064 |
+
fake_col = tOcO[v, 0, k][1]
|
| 1065 |
+
if const_expr(self.use_stg128_fp8_layout):
|
| 1066 |
+
real_col = stg128_fp8_fake_col_to_real_col(fake_col)
|
| 1067 |
+
elif const_expr(self.use_stg128_half_layout):
|
| 1068 |
+
real_col = stg128_half_fake_col_to_real_col(fake_col)
|
| 1069 |
+
else:
|
| 1070 |
+
real_col = stg128_fake_col_to_real_col(fake_col)
|
| 1071 |
+
o_val = tOrO[v, m, k]
|
| 1072 |
+
if const_expr(mOutputScale is not None):
|
| 1073 |
+
o_val = o_val * output_scale
|
| 1074 |
+
rO_scalar[0] = o_val.to(self.dtype)
|
| 1075 |
+
smem_ptr = utils.elem_pointer(sO_perm, (row_local, real_col))
|
| 1076 |
+
smem_scalar_ptr = cute.make_ptr(
|
| 1077 |
+
dtype=self.dtype,
|
| 1078 |
+
value=smem_ptr.toint(),
|
| 1079 |
+
mem_space=sO_perm.iterator.memspace,
|
| 1080 |
+
assumed_align=self.dtype.width // 8,
|
| 1081 |
+
)
|
| 1082 |
+
sO_scalar = cute.make_tensor(
|
| 1083 |
+
smem_scalar_ptr,
|
| 1084 |
+
cute.make_layout((1,), stride=(1,)),
|
| 1085 |
+
)
|
| 1086 |
+
cute.copy(r2s_o_scalar_atom, rO_scalar, sO_scalar)
|
| 1087 |
+
|
| 1088 |
+
cute.arch.sync_threads()
|
| 1089 |
+
|
| 1090 |
+
# 7b: SMEM (real order, output dtype) → GMEM
|
| 1091 |
+
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
|
| 1092 |
+
tOcO_store = gmem_thr_copy_O.partition_D(cO)
|
| 1093 |
+
tOsO_store = gmem_thr_copy_O.partition_D(sO_perm)
|
| 1094 |
+
rO = cute.make_rmem_tensor(tOcO_store.shape, self.dtype)
|
| 1095 |
+
elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1]))
|
| 1096 |
+
num_store_rows = const_expr(cute.size(tOcO_store, mode=[1]))
|
| 1097 |
+
num_store_vals = const_expr(cute.size(tOcO_store, mode=[0]))
|
| 1098 |
+
tOpO_store = None
|
| 1099 |
+
if const_expr(not self.is_even_k):
|
| 1100 |
+
tOpO_store = cute.make_rmem_tensor(cute.size(tOcO_store, mode=[2]), Boolean)
|
| 1101 |
+
for k in cutlass.range(cute.size(tOpO_store), unroll_full=True):
|
| 1102 |
+
tOpO_store[k] = (
|
| 1103 |
+
tOcO_store[0, 0, k][1]
|
| 1104 |
+
< mO_partial.shape[1] - k_block * self.k_block_size
|
| 1105 |
+
)
|
| 1106 |
+
|
| 1107 |
+
# Read output dtype from SMEM (now in real column order).
|
| 1108 |
+
for m in cutlass.range(num_store_rows, unroll_full=True):
|
| 1109 |
+
for k in cutlass.range(cute.size(tOcO_store, mode=[2]), unroll_full=True):
|
| 1110 |
+
if const_expr(self.is_even_k) or tOpO_store[k]:
|
| 1111 |
+
cute.autovec_copy(tOsO_store[None, m, k], rO[None, m, k])
|
| 1112 |
+
|
| 1113 |
+
# Write bf16 to GMEM using gmem_tiled_copy_O (same as original FA Step 7)
|
| 1114 |
+
for m in cutlass.range(num_store_rows, unroll_full=True):
|
| 1115 |
+
row_local = tOcO_store[0, m, 0][0]
|
| 1116 |
+
idx = m_block * self.tile_m + row_local
|
| 1117 |
+
if idx < max_idx:
|
| 1118 |
+
m_idx, head_idx = self.decode_flat_row_idx(idx, head_divmod)
|
| 1119 |
+
mO_cur_copy = cute.tiled_divide(
|
| 1120 |
+
mO_cur[m_idx, None, head_idx], (elems_per_store,)
|
| 1121 |
+
)
|
| 1122 |
+
for k in cutlass.range(cute.size(tOcO_store, mode=[2]), unroll_full=True):
|
| 1123 |
+
k_idx = tOcO_store[0, 0, k][1] // elems_per_store
|
| 1124 |
+
if const_expr(self.is_even_k) or tOpO_store[k]:
|
| 1125 |
+
cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx])
|
| 1126 |
+
|
| 1127 |
+
@cute.jit
|
| 1128 |
+
def load_O_partial(
|
| 1129 |
+
self,
|
| 1130 |
+
gmem_tiled_copy_O_partial: cute.TiledCopy,
|
| 1131 |
+
tOrOptr: cute.Tensor,
|
| 1132 |
+
tOsO_partial: cute.Tensor,
|
| 1133 |
+
tOhidx: cute.Tensor,
|
| 1134 |
+
tOSplitCount: cute.Tensor,
|
| 1135 |
+
tOpO: Optional[cute.Tensor],
|
| 1136 |
+
tOcO: cute.Tensor,
|
| 1137 |
+
mO_cur_partial_layout: cute.Layout,
|
| 1138 |
+
split: Int32,
|
| 1139 |
+
stage: Int32,
|
| 1140 |
+
) -> None:
|
| 1141 |
+
elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1]))
|
| 1142 |
+
tOsO_partial_cur = tOsO_partial[None, None, None, stage]
|
| 1143 |
+
for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True):
|
| 1144 |
+
if tOhidx[m] >= 0:
|
| 1145 |
+
o_gmem_ptr = cute.make_ptr(
|
| 1146 |
+
tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16
|
| 1147 |
+
)
|
| 1148 |
+
mO_partial_cur = cute.make_tensor(
|
| 1149 |
+
o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))
|
| 1150 |
+
)
|
| 1151 |
+
mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
|
| 1152 |
+
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
|
| 1153 |
+
k_idx = tOcO[0, 0, k][1] // elems_per_load
|
| 1154 |
+
if split < tOSplitCount[m] and (const_expr(tOpO is None) or tOpO[k]):
|
| 1155 |
+
cute.copy(
|
| 1156 |
+
gmem_tiled_copy_O_partial,
|
| 1157 |
+
mO_partial_cur_copy[None, k_idx, split],
|
| 1158 |
+
tOsO_partial_cur[None, m, k],
|
| 1159 |
+
)
|
| 1160 |
+
else:
|
| 1161 |
+
tOsO_partial_cur[None, m, k].fill(0)
|
| 1162 |
+
|
| 1163 |
+
|
| 1164 |
+
def _get_cutlass_dtype(torch_dtype: torch.dtype):
|
| 1165 |
+
if torch_dtype not in torch2cute_dtype_map:
|
| 1166 |
+
raise TypeError(f"Unsupported dtype: {torch_dtype}")
|
| 1167 |
+
return torch2cute_dtype_map[torch_dtype]
|
| 1168 |
+
|
| 1169 |
+
|
| 1170 |
+
_combine_compile_cache = {}
|
| 1171 |
+
|
| 1172 |
+
|
| 1173 |
+
def _get_cpasync_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout:
|
| 1174 |
+
dtype_byte = const_expr(dtype.width // 8)
|
| 1175 |
+
bytes_per_row = const_expr(k_dim * dtype_byte)
|
| 1176 |
+
smem_k_block_size = (
|
| 1177 |
+
const_expr(
|
| 1178 |
+
128
|
| 1179 |
+
if bytes_per_row % 128 == 0
|
| 1180 |
+
else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16))
|
| 1181 |
+
)
|
| 1182 |
+
// dtype_byte
|
| 1183 |
+
)
|
| 1184 |
+
swizzle_bits = (
|
| 1185 |
+
4
|
| 1186 |
+
if smem_k_block_size == 128
|
| 1187 |
+
else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1))
|
| 1188 |
+
)
|
| 1189 |
+
swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4)
|
| 1190 |
+
return cute.make_composed_layout(
|
| 1191 |
+
cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base),
|
| 1192 |
+
0,
|
| 1193 |
+
cute.make_ordered_layout(
|
| 1194 |
+
(8 if const_expr(k_dim % 32 == 0) else 16, smem_k_block_size),
|
| 1195 |
+
order=(1, 0),
|
| 1196 |
+
),
|
| 1197 |
+
)
|
| 1198 |
+
|
| 1199 |
+
|
| 1200 |
+
def combine(
|
| 1201 |
+
o_partial_fake,
|
| 1202 |
+
lse_partial,
|
| 1203 |
+
o_out,
|
| 1204 |
+
lse_out,
|
| 1205 |
+
*,
|
| 1206 |
+
lse_temperature_partial=None,
|
| 1207 |
+
lse_temperature_out=None,
|
| 1208 |
+
cu_seqlens=None,
|
| 1209 |
+
seqused=None,
|
| 1210 |
+
split_counts=None,
|
| 1211 |
+
output_scale=None,
|
| 1212 |
+
use_pdl=False,
|
| 1213 |
+
):
|
| 1214 |
+
"""K2: merge sparse forward split partials into the final output.
|
| 1215 |
+
|
| 1216 |
+
STG.128 fake-layout handling remains an internal implementation detail.
|
| 1217 |
+
When lse_out is provided, the kernel writes the final authoritative
|
| 1218 |
+
log-sum-exp for each query row/head directly into that tensor.
|
| 1219 |
+
|
| 1220 |
+
Args:
|
| 1221 |
+
o_partial_fake:
|
| 1222 |
+
Batched: [num_splits, batch, Sq, head_q, dim]
|
| 1223 |
+
Varlen: [num_splits, total_q, head_q, dim]
|
| 1224 |
+
lse_partial:
|
| 1225 |
+
Batched: [num_splits, batch, Sq, head_q]
|
| 1226 |
+
Varlen: [num_splits, total_q, head_q]
|
| 1227 |
+
o_out:
|
| 1228 |
+
Batched: [batch, Sq, head_q, dim]
|
| 1229 |
+
Varlen: [total_q, head_q, dim]
|
| 1230 |
+
lse_out:
|
| 1231 |
+
Batched: [batch, Sq, head_q]
|
| 1232 |
+
Varlen: [total_q, head_q]
|
| 1233 |
+
lse_temperature_partial:
|
| 1234 |
+
Optional temperature-scaled LSE partial with the same shape as
|
| 1235 |
+
lse_partial.
|
| 1236 |
+
lse_temperature_out:
|
| 1237 |
+
Optional temperature-scaled final LSE with the same shape as
|
| 1238 |
+
lse_out.
|
| 1239 |
+
cu_seqlens: Optional [batch + 1] int32 for varlen-Q combine.
|
| 1240 |
+
seqused: Optional [batch] int32 effective lengths for combine.
|
| 1241 |
+
split_counts: Optional int32 rowwise valid split counts prepared from
|
| 1242 |
+
q2k metadata. Batched: [batch, seqlen, head_kv]. Varlen:
|
| 1243 |
+
[total_q, head_kv].
|
| 1244 |
+
output_scale: Optional fp32 tensor with at least one element. When
|
| 1245 |
+
provided, the final O accumulator is multiplied once before store.
|
| 1246 |
+
use_pdl: When True, wait on PDL dependencies from the producer K1
|
| 1247 |
+
kernel. When False, launch without PDL waits.
|
| 1248 |
+
"""
|
| 1249 |
+
D = o_partial_fake.shape[-1]
|
| 1250 |
+
num_splits = o_partial_fake.shape[0]
|
| 1251 |
+
return_temperature_lse = (
|
| 1252 |
+
lse_temperature_partial is not None or lse_temperature_out is not None
|
| 1253 |
+
)
|
| 1254 |
+
if (lse_temperature_partial is None) != (lse_temperature_out is None):
|
| 1255 |
+
raise ValueError(
|
| 1256 |
+
"lse_temperature_partial and lse_temperature_out must either both be provided or both be None"
|
| 1257 |
+
)
|
| 1258 |
+
if lse_temperature_partial is not None and lse_temperature_partial.shape != lse_partial.shape:
|
| 1259 |
+
raise ValueError(
|
| 1260 |
+
"lse_temperature_partial must have the same shape as lse_partial, "
|
| 1261 |
+
f"got {lse_temperature_partial.shape} vs {lse_partial.shape}"
|
| 1262 |
+
)
|
| 1263 |
+
if lse_temperature_out is not None:
|
| 1264 |
+
if lse_out is None:
|
| 1265 |
+
raise ValueError("lse_temperature_out requires lse_out")
|
| 1266 |
+
if lse_temperature_out.shape != lse_out.shape:
|
| 1267 |
+
raise ValueError(
|
| 1268 |
+
"lse_temperature_out must have the same shape as lse_out, "
|
| 1269 |
+
f"got {lse_temperature_out.shape} vs {lse_out.shape}"
|
| 1270 |
+
)
|
| 1271 |
+
if lse_temperature_out.dtype != torch.float32 or lse_temperature_partial.dtype != torch.float32:
|
| 1272 |
+
raise TypeError("temperature LSE tensors must be torch.float32")
|
| 1273 |
+
|
| 1274 |
+
partial_dtype = _get_cutlass_dtype(o_partial_fake.dtype)
|
| 1275 |
+
out_dtype = _get_cutlass_dtype(o_out.dtype)
|
| 1276 |
+
if output_scale is not None:
|
| 1277 |
+
if output_scale.dtype != torch.float32:
|
| 1278 |
+
raise TypeError(f"output_scale must be torch.float32, got {output_scale.dtype}")
|
| 1279 |
+
if output_scale.numel() < 1:
|
| 1280 |
+
raise ValueError("output_scale must contain at least one element")
|
| 1281 |
+
if output_scale.device != o_out.device:
|
| 1282 |
+
raise ValueError("output_scale must be on the same device as o_out")
|
| 1283 |
+
output_scale = output_scale.contiguous()
|
| 1284 |
+
if split_counts is not None:
|
| 1285 |
+
if split_counts.dtype != torch.int32:
|
| 1286 |
+
raise TypeError(f"split_counts must be torch.int32, got {split_counts.dtype}")
|
| 1287 |
+
if o_out.ndim == 4:
|
| 1288 |
+
if split_counts.ndim != 3:
|
| 1289 |
+
raise ValueError(
|
| 1290 |
+
f"batched split_counts must have shape [batch, seqlen, head_kv], got {split_counts.shape}"
|
| 1291 |
+
)
|
| 1292 |
+
if split_counts.shape[:2] != o_out.shape[:2]:
|
| 1293 |
+
raise ValueError(
|
| 1294 |
+
f"split_counts shape {split_counts.shape} must match batch/seqlen of o_out {o_out.shape}"
|
| 1295 |
+
)
|
| 1296 |
+
else:
|
| 1297 |
+
if cu_seqlens is None:
|
| 1298 |
+
raise ValueError("split_counts with varlen output requires cu_seqlens")
|
| 1299 |
+
if split_counts.ndim != 2:
|
| 1300 |
+
raise ValueError(
|
| 1301 |
+
f"varlen split_counts must have shape [total_q, head_kv], got {split_counts.shape}"
|
| 1302 |
+
)
|
| 1303 |
+
if split_counts.shape[0] != o_out.shape[0]:
|
| 1304 |
+
raise ValueError(
|
| 1305 |
+
f"split_counts total_q ({split_counts.shape[0]}) must match o_out total_q "
|
| 1306 |
+
f"({o_out.shape[0]})"
|
| 1307 |
+
)
|
| 1308 |
+
if o_out.shape[-2] % split_counts.shape[-1] != 0:
|
| 1309 |
+
raise ValueError(
|
| 1310 |
+
f"o_out heads ({o_out.shape[-2]}) must be divisible by split_counts heads ({split_counts.shape[-1]})"
|
| 1311 |
+
)
|
| 1312 |
+
qheadperkv = o_out.shape[-2] // split_counts.shape[-1]
|
| 1313 |
+
else:
|
| 1314 |
+
qheadperkv = 1
|
| 1315 |
+
if cu_seqlens is not None:
|
| 1316 |
+
if cu_seqlens.dtype != torch.int32:
|
| 1317 |
+
raise TypeError(f"cu_seqlens must be torch.int32, got {cu_seqlens.dtype}")
|
| 1318 |
+
if cu_seqlens.ndim != 1:
|
| 1319 |
+
raise ValueError(f"cu_seqlens must be rank-1, got {cu_seqlens.shape}")
|
| 1320 |
+
if not cu_seqlens.is_contiguous():
|
| 1321 |
+
raise ValueError("cu_seqlens must be contiguous")
|
| 1322 |
+
if seqused is not None:
|
| 1323 |
+
if seqused.dtype != torch.int32:
|
| 1324 |
+
raise TypeError(f"seqused must be torch.int32, got {seqused.dtype}")
|
| 1325 |
+
if seqused.ndim != 1:
|
| 1326 |
+
raise ValueError(f"seqused must be rank-1, got {seqused.shape}")
|
| 1327 |
+
if not seqused.is_contiguous():
|
| 1328 |
+
raise ValueError("seqused must be contiguous")
|
| 1329 |
+
|
| 1330 |
+
k_block_size = 128 if D > 64 else 64
|
| 1331 |
+
tile_m = 64
|
| 1332 |
+
has_cu_seqlens = cu_seqlens is not None
|
| 1333 |
+
has_seqused = seqused is not None
|
| 1334 |
+
has_lse = lse_out is not None
|
| 1335 |
+
has_split_counts = split_counts is not None
|
| 1336 |
+
has_output_scale = output_scale is not None
|
| 1337 |
+
min_blocks_per_mp = 3 if has_output_scale and use_pdl else 0
|
| 1338 |
+
|
| 1339 |
+
key = (
|
| 1340 |
+
"combine",
|
| 1341 |
+
D,
|
| 1342 |
+
k_block_size,
|
| 1343 |
+
tile_m,
|
| 1344 |
+
num_splits,
|
| 1345 |
+
partial_dtype,
|
| 1346 |
+
out_dtype,
|
| 1347 |
+
has_cu_seqlens,
|
| 1348 |
+
has_seqused,
|
| 1349 |
+
has_lse,
|
| 1350 |
+
bool(return_temperature_lse),
|
| 1351 |
+
has_split_counts,
|
| 1352 |
+
has_output_scale,
|
| 1353 |
+
use_pdl,
|
| 1354 |
+
min_blocks_per_mp,
|
| 1355 |
+
)
|
| 1356 |
+
if key not in _combine_compile_cache:
|
| 1357 |
+
from ....src.common.aot_cache import try_load_aot, save_aot
|
| 1358 |
+
|
| 1359 |
+
loaded = try_load_aot(key)
|
| 1360 |
+
if loaded is not None:
|
| 1361 |
+
_combine_compile_cache[key] = loaded
|
| 1362 |
+
else:
|
| 1363 |
+
from ....quack.compile_utils import make_fake_tensor
|
| 1364 |
+
|
| 1365 |
+
kernel = SparseAttentionForwardCombine(
|
| 1366 |
+
dtype=out_dtype,
|
| 1367 |
+
dtype_partial=partial_dtype,
|
| 1368 |
+
head_dim=D,
|
| 1369 |
+
tile_m=tile_m,
|
| 1370 |
+
k_block_size=k_block_size,
|
| 1371 |
+
topk=num_splits,
|
| 1372 |
+
use_pdl=use_pdl,
|
| 1373 |
+
min_blocks_per_mp=min_blocks_per_mp,
|
| 1374 |
+
# stages=2 halves per-block SMEM (168 KB -> 103 KB) -> 2 blocks/SM,
|
| 1375 |
+
# theoretical occupancy 12.5% -> 25%. NCU DRAM throughput 76.35%
|
| 1376 |
+
# -> 88.64%. Runtime latency within noise (kernel already at HBM
|
| 1377 |
+
# bandwidth ceiling in practice) but the cleaner SOL profile
|
| 1378 |
+
# matters for downstream NCU comparison.
|
| 1379 |
+
stages=2,
|
| 1380 |
+
)
|
| 1381 |
+
div = 128 // partial_dtype.width
|
| 1382 |
+
if has_cu_seqlens:
|
| 1383 |
+
total_q, nheads = (cute.sym_int64() for _ in range(2))
|
| 1384 |
+
mO_partial = make_fake_tensor(
|
| 1385 |
+
partial_dtype, (num_splits, total_q, nheads, D), divisibility=div
|
| 1386 |
+
)
|
| 1387 |
+
mLSE_partial = make_fake_tensor(
|
| 1388 |
+
Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=2
|
| 1389 |
+
)
|
| 1390 |
+
mO = make_fake_tensor(
|
| 1391 |
+
out_dtype, (total_q, nheads, D), divisibility=128 // out_dtype.width
|
| 1392 |
+
)
|
| 1393 |
+
mLSE = (
|
| 1394 |
+
make_fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=1)
|
| 1395 |
+
if has_lse
|
| 1396 |
+
else None
|
| 1397 |
+
)
|
| 1398 |
+
mLSE_temperature_partial = (
|
| 1399 |
+
make_fake_tensor(
|
| 1400 |
+
Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=2
|
| 1401 |
+
)
|
| 1402 |
+
if return_temperature_lse
|
| 1403 |
+
else None
|
| 1404 |
+
)
|
| 1405 |
+
mLSE_temperature = (
|
| 1406 |
+
make_fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=1)
|
| 1407 |
+
if return_temperature_lse
|
| 1408 |
+
else None
|
| 1409 |
+
)
|
| 1410 |
+
else:
|
| 1411 |
+
batch, sq, nheads = (cute.sym_int64() for _ in range(3))
|
| 1412 |
+
mO_partial = make_fake_tensor(
|
| 1413 |
+
partial_dtype, (num_splits, batch, sq, nheads, D), divisibility=div
|
| 1414 |
+
)
|
| 1415 |
+
mLSE_partial = make_fake_tensor(
|
| 1416 |
+
Float32, (num_splits, batch, sq, nheads), divisibility=1, leading_dim=3
|
| 1417 |
+
)
|
| 1418 |
+
mO = make_fake_tensor(
|
| 1419 |
+
out_dtype, (batch, sq, nheads, D), divisibility=128 // out_dtype.width
|
| 1420 |
+
)
|
| 1421 |
+
mLSE = (
|
| 1422 |
+
make_fake_tensor(Float32, (batch, sq, nheads), divisibility=1, leading_dim=2)
|
| 1423 |
+
if has_lse
|
| 1424 |
+
else None
|
| 1425 |
+
)
|
| 1426 |
+
mLSE_temperature_partial = (
|
| 1427 |
+
make_fake_tensor(
|
| 1428 |
+
Float32, (num_splits, batch, sq, nheads), divisibility=1, leading_dim=3
|
| 1429 |
+
)
|
| 1430 |
+
if return_temperature_lse
|
| 1431 |
+
else None
|
| 1432 |
+
)
|
| 1433 |
+
mLSE_temperature = (
|
| 1434 |
+
make_fake_tensor(Float32, (batch, sq, nheads), divisibility=1, leading_dim=2)
|
| 1435 |
+
if return_temperature_lse
|
| 1436 |
+
else None
|
| 1437 |
+
)
|
| 1438 |
+
if not has_split_counts:
|
| 1439 |
+
mSplitCounts = None
|
| 1440 |
+
elif has_cu_seqlens:
|
| 1441 |
+
total_q_ctr, nheads_kv = (cute.sym_int64() for _ in range(2))
|
| 1442 |
+
mSplitCounts = make_fake_tensor(
|
| 1443 |
+
Int32, (total_q_ctr, nheads_kv), divisibility=1, leading_dim=1
|
| 1444 |
+
)
|
| 1445 |
+
else:
|
| 1446 |
+
nheads_kv = cute.sym_int64()
|
| 1447 |
+
mSplitCounts = make_fake_tensor(
|
| 1448 |
+
Int32, (batch, sq, nheads_kv), divisibility=1, leading_dim=2
|
| 1449 |
+
)
|
| 1450 |
+
mOutputScale = (
|
| 1451 |
+
make_fake_tensor(Float32, (cute.sym_int64(),), divisibility=1, leading_dim=0)
|
| 1452 |
+
if has_output_scale
|
| 1453 |
+
else None
|
| 1454 |
+
)
|
| 1455 |
+
stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
|
| 1456 |
+
|
| 1457 |
+
_combine_compile_cache[key] = cute.compile(
|
| 1458 |
+
kernel,
|
| 1459 |
+
mO_partial,
|
| 1460 |
+
mLSE_partial,
|
| 1461 |
+
mO,
|
| 1462 |
+
mLSE,
|
| 1463 |
+
mLSE_temperature_partial,
|
| 1464 |
+
mLSE_temperature,
|
| 1465 |
+
None
|
| 1466 |
+
if cu_seqlens is None
|
| 1467 |
+
else make_fake_tensor(Int32, (cute.sym_int64(),), divisibility=1, leading_dim=0),
|
| 1468 |
+
None
|
| 1469 |
+
if seqused is None
|
| 1470 |
+
else make_fake_tensor(Int32, (cute.sym_int64(),), divisibility=1, leading_dim=0),
|
| 1471 |
+
None,
|
| 1472 |
+
None,
|
| 1473 |
+
None,
|
| 1474 |
+
mSplitCounts,
|
| 1475 |
+
mOutputScale,
|
| 1476 |
+
Int32(qheadperkv),
|
| 1477 |
+
stream,
|
| 1478 |
+
options="--enable-tvm-ffi",
|
| 1479 |
+
)
|
| 1480 |
+
save_aot(key, _combine_compile_cache[key])
|
| 1481 |
+
|
| 1482 |
+
with torch.cuda.nvtx.range("K2_Combine"):
|
| 1483 |
+
_combine_compile_cache[key](
|
| 1484 |
+
o_partial_fake,
|
| 1485 |
+
lse_partial,
|
| 1486 |
+
o_out,
|
| 1487 |
+
lse_out,
|
| 1488 |
+
lse_temperature_partial,
|
| 1489 |
+
lse_temperature_out,
|
| 1490 |
+
cu_seqlens,
|
| 1491 |
+
seqused,
|
| 1492 |
+
None,
|
| 1493 |
+
None,
|
| 1494 |
+
None,
|
| 1495 |
+
split_counts,
|
| 1496 |
+
output_scale,
|
| 1497 |
+
qheadperkv,
|
| 1498 |
+
)
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/__init__.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""CUTE DSL launchers for paged fp8 decode forward."""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from .atten_fwd import run_decode_attention
|
| 11 |
+
from .combine import run_decode_combine
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def decode_forward_paged_fp8(
|
| 15 |
+
q: torch.Tensor,
|
| 16 |
+
k: torch.Tensor,
|
| 17 |
+
v: torch.Tensor,
|
| 18 |
+
page_table: torch.Tensor,
|
| 19 |
+
seqused_k: torch.Tensor,
|
| 20 |
+
out: torch.Tensor,
|
| 21 |
+
lse: torch.Tensor,
|
| 22 |
+
request_indices: torch.Tensor,
|
| 23 |
+
qo_tile_indices: torch.Tensor,
|
| 24 |
+
kv_tile_indices: torch.Tensor,
|
| 25 |
+
block_valid_mask: torch.Tensor,
|
| 26 |
+
split_counts: torch.Tensor,
|
| 27 |
+
o_indptr: torch.Tensor,
|
| 28 |
+
merge_indptr: torch.Tensor,
|
| 29 |
+
O_partial: torch.Tensor | None,
|
| 30 |
+
LSE_partial: torch.Tensor | None,
|
| 31 |
+
*,
|
| 32 |
+
softmax_scale: float,
|
| 33 |
+
seqlen_q: int,
|
| 34 |
+
page_size: int,
|
| 35 |
+
kv_chunk_size_pages: int,
|
| 36 |
+
max_split_count: int,
|
| 37 |
+
split_kv: bool,
|
| 38 |
+
causal: bool,
|
| 39 |
+
return_lse: bool = True,
|
| 40 |
+
O_partial_dummy: torch.Tensor | None = None,
|
| 41 |
+
LSE_partial_dummy: torch.Tensor | None = None,
|
| 42 |
+
) -> None:
|
| 43 |
+
"""Launch dense paged fp8 decode forward and optional compressed combine.
|
| 44 |
+
|
| 45 |
+
``O_partial_dummy`` / ``LSE_partial_dummy`` are caller-provided pre-allocated
|
| 46 |
+
placeholder buffers for the non-split path. When supplied, ``run_decode_attention``
|
| 47 |
+
skips the per-call ``torch.empty`` it would otherwise need to satisfy the
|
| 48 |
+
kernel's positional arg signature, saving ~5us on small-kv calls.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
run_decode_attention(
|
| 52 |
+
q,
|
| 53 |
+
k,
|
| 54 |
+
v,
|
| 55 |
+
page_table,
|
| 56 |
+
seqused_k,
|
| 57 |
+
request_indices,
|
| 58 |
+
qo_tile_indices,
|
| 59 |
+
kv_tile_indices,
|
| 60 |
+
block_valid_mask,
|
| 61 |
+
split_counts,
|
| 62 |
+
o_indptr,
|
| 63 |
+
out,
|
| 64 |
+
lse,
|
| 65 |
+
O_partial,
|
| 66 |
+
LSE_partial,
|
| 67 |
+
softmax_scale=float(softmax_scale),
|
| 68 |
+
seqlen_q=int(seqlen_q),
|
| 69 |
+
page_size=int(page_size),
|
| 70 |
+
kv_chunk_size_pages=int(kv_chunk_size_pages),
|
| 71 |
+
split_kv=bool(split_kv),
|
| 72 |
+
causal=bool(causal),
|
| 73 |
+
return_lse=bool(return_lse),
|
| 74 |
+
O_partial_dummy=O_partial_dummy,
|
| 75 |
+
LSE_partial_dummy=LSE_partial_dummy,
|
| 76 |
+
)
|
| 77 |
+
if split_kv:
|
| 78 |
+
if O_partial is None or LSE_partial is None:
|
| 79 |
+
raise ValueError("split decode requires O_partial and LSE_partial")
|
| 80 |
+
qhead_per_kv = q.shape[1] // k.shape[1]
|
| 81 |
+
q_tokens_per_group = 128 // int(qhead_per_kv)
|
| 82 |
+
run_decode_combine(
|
| 83 |
+
O_partial,
|
| 84 |
+
LSE_partial,
|
| 85 |
+
split_counts,
|
| 86 |
+
o_indptr,
|
| 87 |
+
out,
|
| 88 |
+
lse,
|
| 89 |
+
seqlen_q=int(seqlen_q),
|
| 90 |
+
q_tokens_per_group=q_tokens_per_group,
|
| 91 |
+
max_split_count=int(max_split_count),
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
__all__ = ["decode_forward_paged_fp8", "run_decode_attention", "run_decode_combine"]
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/atten_fwd.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/build_decode_schedule/__init__.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""Paged decode split-KV scheduling backed by the precompiled Torch op.
|
| 5 |
+
|
| 6 |
+
The CUDA implementation lives in ``csrc/build_decode_schedule.cu`` and is
|
| 7 |
+
built ahead of time by kernel-builder. The op returns the schedule arrays
|
| 8 |
+
plus a fixed-order scalar summary, which is reassembled into the schedule
|
| 9 |
+
dict here.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from ....._ops import ops
|
| 17 |
+
|
| 18 |
+
# Order of the scalar summary returned by the op; must match
|
| 19 |
+
# csrc/build_decode_schedule.cu.
|
| 20 |
+
_SCALAR_KEYS = (
|
| 21 |
+
"split_kv",
|
| 22 |
+
"cta_tile_q",
|
| 23 |
+
"num_q_tiles",
|
| 24 |
+
"kv_chunk_size_pages",
|
| 25 |
+
"kv_chunk_size_tokens",
|
| 26 |
+
"work_count",
|
| 27 |
+
"padded_work_count",
|
| 28 |
+
"partial_rows",
|
| 29 |
+
"max_split_count",
|
| 30 |
+
"max_grid_size",
|
| 31 |
+
"active_blocks_per_sm",
|
| 32 |
+
"num_sms",
|
| 33 |
+
"base_cta",
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def build_decode_schedule(
|
| 38 |
+
seqused_k: torch.Tensor,
|
| 39 |
+
*,
|
| 40 |
+
page_size: int,
|
| 41 |
+
seqlen_q: int,
|
| 42 |
+
num_qo_heads: int,
|
| 43 |
+
num_kv_heads: int,
|
| 44 |
+
head_dim: int,
|
| 45 |
+
max_seqlen_k: int,
|
| 46 |
+
enable_cuda_graph: bool = False,
|
| 47 |
+
max_grid_size: int = 0,
|
| 48 |
+
fixed_split_size: int = -1,
|
| 49 |
+
disable_split_kv: bool = False,
|
| 50 |
+
) -> dict[str, object]:
|
| 51 |
+
"""GPU-only schedule build: single CUDA kernel produces all schedule
|
| 52 |
+
index arrays on device. Only a small summary tensor is D2H'd at the end
|
| 53 |
+
so the wrapper can size O_partial, pick the kernel grid, and choose
|
| 54 |
+
split/non-split compile path.
|
| 55 |
+
|
| 56 |
+
``max_seqlen_k`` is required as the host-side worst-case bound for
|
| 57 |
+
padding the work-tile arrays.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
(
|
| 61 |
+
request_indices,
|
| 62 |
+
qo_tile_indices,
|
| 63 |
+
kv_tile_indices,
|
| 64 |
+
block_valid_mask,
|
| 65 |
+
split_counts,
|
| 66 |
+
kv_pages,
|
| 67 |
+
merge_indptr,
|
| 68 |
+
o_indptr,
|
| 69 |
+
scalars,
|
| 70 |
+
) = ops.build_decode_schedule(
|
| 71 |
+
seqused_k,
|
| 72 |
+
int(page_size),
|
| 73 |
+
int(seqlen_q),
|
| 74 |
+
int(num_qo_heads),
|
| 75 |
+
int(num_kv_heads),
|
| 76 |
+
int(head_dim),
|
| 77 |
+
int(max_seqlen_k),
|
| 78 |
+
bool(enable_cuda_graph),
|
| 79 |
+
int(max_grid_size),
|
| 80 |
+
int(fixed_split_size),
|
| 81 |
+
bool(disable_split_kv),
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
raw: dict[str, object] = dict(zip(_SCALAR_KEYS, (int(s) for s in scalars)))
|
| 85 |
+
raw["split_kv"] = bool(raw["split_kv"])
|
| 86 |
+
raw["request_indices"] = request_indices
|
| 87 |
+
raw["qo_tile_indices"] = qo_tile_indices
|
| 88 |
+
raw["kv_tile_indices"] = kv_tile_indices
|
| 89 |
+
raw["block_valid_mask"] = block_valid_mask
|
| 90 |
+
raw["split_counts"] = split_counts
|
| 91 |
+
raw["kv_pages"] = kv_pages
|
| 92 |
+
raw["merge_indptr"] = merge_indptr
|
| 93 |
+
raw["o_indptr"] = o_indptr
|
| 94 |
+
|
| 95 |
+
# The CUDA kernel writes into worst-case-padded buffers (size =
|
| 96 |
+
# batch * num_q_tiles * max_pages_global) but only the first
|
| 97 |
+
# ``padded_work_count`` entries are valid. Downstream consumers
|
| 98 |
+
# (tile_scheduler) take grid size from ``request_indices.shape[0]``
|
| 99 |
+
# so we narrow the views to that count; the underlying allocation
|
| 100 |
+
# is unchanged so this is a view, no copy.
|
| 101 |
+
pad = int(raw["padded_work_count"])
|
| 102 |
+
for key in (
|
| 103 |
+
"request_indices",
|
| 104 |
+
"qo_tile_indices",
|
| 105 |
+
"kv_tile_indices",
|
| 106 |
+
"block_valid_mask",
|
| 107 |
+
):
|
| 108 |
+
raw[key] = raw[key].narrow(0, 0, pad)
|
| 109 |
+
return raw
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
__all__ = ["build_decode_schedule"]
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/combine.py
ADDED
|
@@ -0,0 +1,680 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""LDGSTS split-KV combine for paged decode attention."""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from functools import partial
|
| 8 |
+
from typing import Type
|
| 9 |
+
|
| 10 |
+
import cuda.bindings.driver as cuda
|
| 11 |
+
import cutlass
|
| 12 |
+
import cutlass.cute as cute
|
| 13 |
+
import torch
|
| 14 |
+
from cutlass import Float32, Int32, Int64, const_expr
|
| 15 |
+
from cutlass.cute import FastDivmodDivisor
|
| 16 |
+
from cutlass.cute.nvgpu import cpasync
|
| 17 |
+
|
| 18 |
+
from ....src.common.cute_dsl_utils import assume_tensor_aligned, torch2cute_dtype_map
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SparseDecodeForwardCombine:
|
| 22 |
+
"""Combine split-KV decode partials with FA-style LDGSTS staging.
|
| 23 |
+
|
| 24 |
+
``mO_partial`` and ``mLSE_partial`` use the split-major padded layout:
|
| 25 |
+
``partial_row = o_indptr[b] + split_idx * q_stride + q_token`` where
|
| 26 |
+
``q_stride = ceil_div(seqlen_q, q_tokens_per_group) * q_tokens_per_group``.
|
| 27 |
+
A CTA covers ``tile_m`` flattened ``(q_token, q_head)`` rows and one
|
| 28 |
+
``k_block_size`` slice of D. O_partial and LSE_partial are loaded to SMEM
|
| 29 |
+
via ``cpasync.CopyG2SOp`` before the split reduction.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
dtype: Type[cutlass.Numeric],
|
| 35 |
+
dtype_partial: Type[cutlass.Numeric],
|
| 36 |
+
head_dim: int,
|
| 37 |
+
*,
|
| 38 |
+
tile_m: int = 64,
|
| 39 |
+
k_block_size: int = 128,
|
| 40 |
+
max_splits: int = 4,
|
| 41 |
+
num_threads: int = 256,
|
| 42 |
+
stages: int = 2,
|
| 43 |
+
):
|
| 44 |
+
if head_dim != 128:
|
| 45 |
+
raise NotImplementedError(
|
| 46 |
+
f"SparseDecodeForwardCombine currently supports only D=128, got D={head_dim}"
|
| 47 |
+
)
|
| 48 |
+
if dtype not in [cutlass.BFloat16, cutlass.Float16, cutlass.Float32]:
|
| 49 |
+
raise TypeError(f"Unsupported output dtype: {dtype}")
|
| 50 |
+
if dtype_partial is not Float32:
|
| 51 |
+
raise TypeError("decode O_partial must be Float32")
|
| 52 |
+
if k_block_size != head_dim:
|
| 53 |
+
raise NotImplementedError("decode combine currently uses one D=128 k block")
|
| 54 |
+
if tile_m % 8 != 0:
|
| 55 |
+
raise ValueError("decode combine tile_m must be divisible by 8")
|
| 56 |
+
if max_splits < 1 or max_splits > 256:
|
| 57 |
+
raise ValueError("decode combine max_splits must be in [1, 256]")
|
| 58 |
+
|
| 59 |
+
self.dtype = dtype
|
| 60 |
+
self.dtype_partial = dtype_partial
|
| 61 |
+
self.head_dim = head_dim
|
| 62 |
+
self.tile_m = tile_m
|
| 63 |
+
self.k_block_size = k_block_size
|
| 64 |
+
self.max_splits = max_splits
|
| 65 |
+
self.num_threads = num_threads
|
| 66 |
+
self.stages = stages
|
| 67 |
+
self.is_even_k = head_dim % k_block_size == 0
|
| 68 |
+
|
| 69 |
+
def _setup_attributes(self) -> None:
|
| 70 |
+
universal_copy_bits = 128
|
| 71 |
+
async_copy_elems = universal_copy_bits // self.dtype_partial.width
|
| 72 |
+
assert self.k_block_size % async_copy_elems == 0
|
| 73 |
+
|
| 74 |
+
k_block_gmem = (
|
| 75 |
+
128
|
| 76 |
+
if self.k_block_size % 128 == 0
|
| 77 |
+
else (64 if self.k_block_size % 64 == 0 else 32)
|
| 78 |
+
)
|
| 79 |
+
gmem_threads_per_row = k_block_gmem // async_copy_elems
|
| 80 |
+
assert self.num_threads % gmem_threads_per_row == 0
|
| 81 |
+
|
| 82 |
+
atom_async_copy_partial = cute.make_copy_atom(
|
| 83 |
+
cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
|
| 84 |
+
self.dtype_partial,
|
| 85 |
+
num_bits_per_copy=universal_copy_bits,
|
| 86 |
+
)
|
| 87 |
+
tOpartial_layout = cute.make_ordered_layout(
|
| 88 |
+
(self.num_threads // gmem_threads_per_row, gmem_threads_per_row),
|
| 89 |
+
order=(1, 0),
|
| 90 |
+
)
|
| 91 |
+
vOpartial_layout = cute.make_layout((1, async_copy_elems))
|
| 92 |
+
self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv(
|
| 93 |
+
atom_async_copy_partial, tOpartial_layout, vOpartial_layout
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
atom_universal_copy = cute.make_copy_atom(
|
| 97 |
+
cute.nvgpu.CopyUniversalOp(),
|
| 98 |
+
self.dtype,
|
| 99 |
+
num_bits_per_copy=async_copy_elems * self.dtype.width,
|
| 100 |
+
)
|
| 101 |
+
self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
|
| 102 |
+
atom_universal_copy, tOpartial_layout, vOpartial_layout
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
lse_copy_bits = Float32.width
|
| 106 |
+
m_block_smem = (
|
| 107 |
+
128
|
| 108 |
+
if self.tile_m % 128 == 0
|
| 109 |
+
else (
|
| 110 |
+
64
|
| 111 |
+
if self.tile_m % 64 == 0
|
| 112 |
+
else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8))
|
| 113 |
+
)
|
| 114 |
+
)
|
| 115 |
+
gmem_threads_per_row_lse = m_block_smem
|
| 116 |
+
assert self.num_threads % gmem_threads_per_row_lse == 0
|
| 117 |
+
|
| 118 |
+
atom_async_copy_lse = cute.make_copy_atom(
|
| 119 |
+
cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS),
|
| 120 |
+
Float32,
|
| 121 |
+
num_bits_per_copy=lse_copy_bits,
|
| 122 |
+
)
|
| 123 |
+
tLSE_layout = cute.make_ordered_layout(
|
| 124 |
+
(self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse),
|
| 125 |
+
order=(1, 0),
|
| 126 |
+
)
|
| 127 |
+
self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(
|
| 128 |
+
atom_async_copy_lse, tLSE_layout, cute.make_layout(1)
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
self.smem_threads_per_col_lse = self.num_threads // m_block_smem
|
| 132 |
+
assert 32 % self.smem_threads_per_col_lse == 0
|
| 133 |
+
s2r_layout_atom_lse = cute.make_ordered_layout(
|
| 134 |
+
(self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse),
|
| 135 |
+
order=(0, 1),
|
| 136 |
+
)
|
| 137 |
+
self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv(
|
| 138 |
+
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32),
|
| 139 |
+
s2r_layout_atom_lse,
|
| 140 |
+
cute.make_layout(1),
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
if const_expr(m_block_smem == 8):
|
| 144 |
+
smem_lse_swizzle = cute.make_swizzle(5, 0, 5)
|
| 145 |
+
elif const_expr(m_block_smem == 16):
|
| 146 |
+
smem_lse_swizzle = cute.make_swizzle(4, 0, 4)
|
| 147 |
+
else:
|
| 148 |
+
smem_lse_swizzle = cute.make_swizzle(3, 2, 3)
|
| 149 |
+
lse_atom_splits = min(self.max_splits, 8)
|
| 150 |
+
smem_layout_atom_lse = cute.make_composed_layout(
|
| 151 |
+
smem_lse_swizzle,
|
| 152 |
+
0,
|
| 153 |
+
cute.make_ordered_layout((lse_atom_splits, m_block_smem), order=(1, 0)),
|
| 154 |
+
)
|
| 155 |
+
self.smem_layout_lse = cute.tile_to_shape(
|
| 156 |
+
smem_layout_atom_lse, (self.max_splits, self.tile_m), (0, 1)
|
| 157 |
+
)
|
| 158 |
+
self.smem_layout_o = cute.make_ordered_layout(
|
| 159 |
+
(self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2)
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
@cute.jit
|
| 163 |
+
def __call__(
|
| 164 |
+
self,
|
| 165 |
+
mO_partial: cute.Tensor, # [partial_rows, Hq, D] fp32
|
| 166 |
+
mLSE_partial: cute.Tensor, # [partial_rows, Hq] fp32
|
| 167 |
+
mSplitCounts: cute.Tensor, # [B] int32
|
| 168 |
+
mOIndptr: cute.Tensor, # [B + 1] int32
|
| 169 |
+
mO: cute.Tensor, # [total_q, Hq, D]
|
| 170 |
+
mLSE: cute.Tensor, # [total_q, Hq] fp32
|
| 171 |
+
seqlen_q: Int32,
|
| 172 |
+
q_tokens_per_group: Int32,
|
| 173 |
+
stream: cuda.CUstream = None,
|
| 174 |
+
):
|
| 175 |
+
if const_expr(mO_partial.element_type is not Float32):
|
| 176 |
+
raise TypeError("decode O_partial tensor must be Float32")
|
| 177 |
+
if const_expr(mLSE_partial.element_type is not Float32):
|
| 178 |
+
raise TypeError("decode LSE_partial tensor must be Float32")
|
| 179 |
+
if const_expr(mLSE.element_type is not Float32):
|
| 180 |
+
raise TypeError("decode LSE tensor must be Float32")
|
| 181 |
+
if const_expr(mO.element_type != self.dtype):
|
| 182 |
+
raise TypeError("decode O tensor dtype must match kernel dtype")
|
| 183 |
+
if const_expr(mSplitCounts.element_type is not Int32):
|
| 184 |
+
raise TypeError("decode split_counts tensor must be Int32")
|
| 185 |
+
if const_expr(mOIndptr.element_type is not Int32):
|
| 186 |
+
raise TypeError("decode o_indptr tensor must be Int32")
|
| 187 |
+
|
| 188 |
+
mO_partial, mLSE_partial, mSplitCounts, mOIndptr, mO, mLSE = [
|
| 189 |
+
assume_tensor_aligned(t)
|
| 190 |
+
for t in (mO_partial, mLSE_partial, mSplitCounts, mOIndptr, mO, mLSE)
|
| 191 |
+
]
|
| 192 |
+
self._setup_attributes()
|
| 193 |
+
|
| 194 |
+
@cute.struct
|
| 195 |
+
class SharedStorage:
|
| 196 |
+
sLSE: cute.struct.Align[
|
| 197 |
+
cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
|
| 198 |
+
]
|
| 199 |
+
sMaxValidSplit: cute.struct.Align[
|
| 200 |
+
cute.struct.MemRange[Int32, self.tile_m], 128
|
| 201 |
+
]
|
| 202 |
+
sO: cute.struct.Align[
|
| 203 |
+
cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
|
| 204 |
+
]
|
| 205 |
+
|
| 206 |
+
total_q = mO.shape[0]
|
| 207 |
+
head_q = mO.shape[1]
|
| 208 |
+
batch = mSplitCounts.shape[0]
|
| 209 |
+
head_divmod = FastDivmodDivisor(head_q)
|
| 210 |
+
grid = (
|
| 211 |
+
cute.ceil_div(seqlen_q * head_q, self.tile_m),
|
| 212 |
+
cute.ceil_div(self.head_dim, self.k_block_size),
|
| 213 |
+
batch,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
self.kernel(
|
| 217 |
+
mO_partial,
|
| 218 |
+
mLSE_partial,
|
| 219 |
+
mSplitCounts,
|
| 220 |
+
mOIndptr,
|
| 221 |
+
mO,
|
| 222 |
+
mLSE,
|
| 223 |
+
SharedStorage,
|
| 224 |
+
self.smem_layout_lse,
|
| 225 |
+
self.smem_layout_o,
|
| 226 |
+
self.gmem_tiled_copy_O_partial,
|
| 227 |
+
self.gmem_tiled_copy_O,
|
| 228 |
+
self.gmem_tiled_copy_LSE,
|
| 229 |
+
self.s2r_tiled_copy_LSE,
|
| 230 |
+
head_divmod,
|
| 231 |
+
Int32(total_q),
|
| 232 |
+
Int32(head_q),
|
| 233 |
+
seqlen_q,
|
| 234 |
+
q_tokens_per_group,
|
| 235 |
+
).launch(
|
| 236 |
+
grid=grid,
|
| 237 |
+
block=[self.num_threads, 1, 1],
|
| 238 |
+
smem=SharedStorage.size_in_bytes(),
|
| 239 |
+
stream=stream,
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
@cute.kernel
|
| 243 |
+
def kernel(
|
| 244 |
+
self,
|
| 245 |
+
mO_partial: cute.Tensor,
|
| 246 |
+
mLSE_partial: cute.Tensor,
|
| 247 |
+
mSplitCounts: cute.Tensor,
|
| 248 |
+
mOIndptr: cute.Tensor,
|
| 249 |
+
mO: cute.Tensor,
|
| 250 |
+
mLSE: cute.Tensor,
|
| 251 |
+
SharedStorage: cutlass.Constexpr,
|
| 252 |
+
smem_layout_lse: cute.Layout | cute.ComposedLayout,
|
| 253 |
+
smem_layout_o: cute.Layout,
|
| 254 |
+
gmem_tiled_copy_O_partial: cute.TiledCopy,
|
| 255 |
+
gmem_tiled_copy_O: cute.TiledCopy,
|
| 256 |
+
gmem_tiled_copy_LSE: cute.TiledCopy,
|
| 257 |
+
s2r_tiled_copy_LSE: cute.TiledCopy,
|
| 258 |
+
head_divmod: FastDivmodDivisor,
|
| 259 |
+
total_q: Int32,
|
| 260 |
+
head_q: Int32,
|
| 261 |
+
seqlen_q: Int32,
|
| 262 |
+
q_tokens_per_group: Int32,
|
| 263 |
+
):
|
| 264 |
+
tidx, _, _ = cute.arch.thread_idx()
|
| 265 |
+
m_block, k_block, batch_idx = cute.arch.block_idx()
|
| 266 |
+
|
| 267 |
+
smem = cutlass.utils.SmemAllocator()
|
| 268 |
+
storage = smem.allocate(SharedStorage)
|
| 269 |
+
sLSE = storage.sLSE.get_tensor(smem_layout_lse)
|
| 270 |
+
sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,))
|
| 271 |
+
sO = storage.sO.get_tensor(smem_layout_o)
|
| 272 |
+
|
| 273 |
+
split_count = mSplitCounts[batch_idx]
|
| 274 |
+
q_stride = (
|
| 275 |
+
(seqlen_q + q_tokens_per_group - Int32(1))
|
| 276 |
+
// q_tokens_per_group
|
| 277 |
+
) * q_tokens_per_group
|
| 278 |
+
max_idx = seqlen_q * head_q
|
| 279 |
+
|
| 280 |
+
if m_block * Int32(self.tile_m) < max_idx:
|
| 281 |
+
gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx)
|
| 282 |
+
tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE)
|
| 283 |
+
cLSE = cute.make_identity_tensor((self.max_splits, self.tile_m))
|
| 284 |
+
tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE)
|
| 285 |
+
|
| 286 |
+
for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
|
| 287 |
+
mi = tLSEcLSE[0, 0, m][1]
|
| 288 |
+
idx = m_block * Int32(self.tile_m) + mi
|
| 289 |
+
if idx < max_idx:
|
| 290 |
+
q_idx, q_head = divmod(idx, head_divmod)
|
| 291 |
+
partial_base = mOIndptr[batch_idx] + q_idx
|
| 292 |
+
for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
|
| 293 |
+
si = tLSEcLSE[0, s, 0][0]
|
| 294 |
+
if si < split_count:
|
| 295 |
+
partial_row = partial_base + si * q_stride
|
| 296 |
+
lse_ptr = (
|
| 297 |
+
mLSE_partial.iterator
|
| 298 |
+
+ Int64(partial_row) * Int64(head_q)
|
| 299 |
+
+ Int64(q_head)
|
| 300 |
+
)
|
| 301 |
+
lse_gmem_ptr = cute.make_ptr(
|
| 302 |
+
Float32,
|
| 303 |
+
lse_ptr.toint(),
|
| 304 |
+
cute.AddressSpace.gmem,
|
| 305 |
+
assumed_align=4,
|
| 306 |
+
)
|
| 307 |
+
lse_src = cute.make_tensor(lse_gmem_ptr, (1,))
|
| 308 |
+
cute.copy(
|
| 309 |
+
gmem_thr_copy_LSE,
|
| 310 |
+
lse_src,
|
| 311 |
+
tLSEsLSE[None, s, m],
|
| 312 |
+
)
|
| 313 |
+
else:
|
| 314 |
+
tLSEsLSE[None, s, m].fill(-Float32.inf)
|
| 315 |
+
else:
|
| 316 |
+
for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
|
| 317 |
+
tLSEsLSE[None, s, m].fill(-Float32.inf)
|
| 318 |
+
cute.arch.cp_async_commit_group()
|
| 319 |
+
|
| 320 |
+
gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx)
|
| 321 |
+
cO = cute.make_identity_tensor((self.tile_m, self.k_block_size))
|
| 322 |
+
tOcO = gmem_thr_copy_O_partial.partition_D(cO)
|
| 323 |
+
tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO)
|
| 324 |
+
|
| 325 |
+
num_rows = const_expr(cute.size(tOcO, mode=[1]))
|
| 326 |
+
tOqidx = cute.make_rmem_tensor(num_rows, Int32)
|
| 327 |
+
tOhidx = cute.make_rmem_tensor(num_rows, Int32)
|
| 328 |
+
for m in cutlass.range(num_rows, unroll_full=True):
|
| 329 |
+
mi = tOcO[0, m, 0][0]
|
| 330 |
+
idx = m_block * Int32(self.tile_m) + mi
|
| 331 |
+
if idx >= max_idx:
|
| 332 |
+
tOqidx[m] = Int32(0)
|
| 333 |
+
tOhidx[m] = -Int32(1)
|
| 334 |
+
else:
|
| 335 |
+
tOqidx[m], tOhidx[m] = divmod(idx, head_divmod)
|
| 336 |
+
|
| 337 |
+
load_O_partial = partial(
|
| 338 |
+
self.load_O_partial,
|
| 339 |
+
mO_partial,
|
| 340 |
+
mOIndptr,
|
| 341 |
+
gmem_tiled_copy_O_partial,
|
| 342 |
+
tOsO_partial,
|
| 343 |
+
tOqidx,
|
| 344 |
+
tOhidx,
|
| 345 |
+
tOcO,
|
| 346 |
+
batch_idx,
|
| 347 |
+
q_stride,
|
| 348 |
+
split_count,
|
| 349 |
+
head_q,
|
| 350 |
+
k_block,
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
for stage in cutlass.range(self.stages - 1, unroll_full=True):
|
| 354 |
+
if stage < split_count:
|
| 355 |
+
load_O_partial(stage, stage)
|
| 356 |
+
cute.arch.cp_async_commit_group()
|
| 357 |
+
|
| 358 |
+
cute.arch.cp_async_wait_group(self.stages - 1)
|
| 359 |
+
cute.arch.sync_threads()
|
| 360 |
+
|
| 361 |
+
s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx)
|
| 362 |
+
ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE)
|
| 363 |
+
ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE)
|
| 364 |
+
cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE)
|
| 365 |
+
|
| 366 |
+
lse_sum = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32)
|
| 367 |
+
ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE)
|
| 368 |
+
max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32)
|
| 369 |
+
assert cute.size(ts2rrLSE, mode=[0]) == 1
|
| 370 |
+
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
| 371 |
+
threads_per_col = const_expr(self.smem_threads_per_col_lse)
|
| 372 |
+
lse_max = cute.arch.warp_reduction_max(
|
| 373 |
+
ts2rrLSE[None, None, m]
|
| 374 |
+
.load()
|
| 375 |
+
.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
|
| 376 |
+
threads_in_group=threads_per_col,
|
| 377 |
+
)
|
| 378 |
+
max_valid_idx = -Int32(1)
|
| 379 |
+
for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
|
| 380 |
+
if ts2rrLSE[0, s, m] != -Float32.inf:
|
| 381 |
+
max_valid_idx = ts2rcLSE[0, s, 0][0]
|
| 382 |
+
max_valid_split[m] = cute.arch.warp_reduction_max(
|
| 383 |
+
max_valid_idx, threads_in_group=threads_per_col
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
lse_max_cur = Float32(0.0) if lse_max == -Float32.inf else lse_max
|
| 387 |
+
LOG2_E = Float32(math.log2(math.e))
|
| 388 |
+
lse_sum_cur = Float32(0.0)
|
| 389 |
+
for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
|
| 390 |
+
scale = cute.math.exp2(
|
| 391 |
+
(ts2rrLSE[0, s, m] - lse_max_cur) * LOG2_E,
|
| 392 |
+
fastmath=True,
|
| 393 |
+
)
|
| 394 |
+
lse_sum_cur += scale
|
| 395 |
+
ts2rrLSE[0, s, m] = scale
|
| 396 |
+
lse_sum_cur = cute.arch.warp_reduction_sum(
|
| 397 |
+
lse_sum_cur, threads_in_group=threads_per_col
|
| 398 |
+
)
|
| 399 |
+
lse_sum[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max
|
| 400 |
+
inv_sum = (
|
| 401 |
+
Float32(0.0)
|
| 402 |
+
if (lse_sum_cur == Float32(0.0) or lse_sum_cur != lse_sum_cur)
|
| 403 |
+
else cute.arch.rcp_approx(lse_sum_cur)
|
| 404 |
+
)
|
| 405 |
+
ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)
|
| 406 |
+
cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)
|
| 407 |
+
|
| 408 |
+
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
| 409 |
+
if ts2rcLSE[0, 0, m][0] == Int32(0):
|
| 410 |
+
mi = ts2rcLSE[0, 0, m][1]
|
| 411 |
+
if mi < Int32(self.tile_m):
|
| 412 |
+
sMaxValidSplit[mi] = max_valid_split[m]
|
| 413 |
+
|
| 414 |
+
if k_block == Int32(0):
|
| 415 |
+
for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
|
| 416 |
+
if ts2rcLSE[0, 0, m][0] == Int32(0):
|
| 417 |
+
mi = ts2rcLSE[0, 0, m][1]
|
| 418 |
+
idx = m_block * Int32(self.tile_m) + mi
|
| 419 |
+
if idx < max_idx:
|
| 420 |
+
q_idx, q_head = divmod(idx, head_divmod)
|
| 421 |
+
q_abs = batch_idx * seqlen_q + q_idx
|
| 422 |
+
mLSE[q_abs, q_head] = lse_sum[m]
|
| 423 |
+
|
| 424 |
+
cute.arch.sync_threads()
|
| 425 |
+
|
| 426 |
+
thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]]
|
| 427 |
+
for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True):
|
| 428 |
+
thr_max_valid_split = max(
|
| 429 |
+
thr_max_valid_split,
|
| 430 |
+
sMaxValidSplit[tOcO[0, m, 0][0]],
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0])
|
| 434 |
+
tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32)
|
| 435 |
+
tOrO.fill(Float32(0.0))
|
| 436 |
+
|
| 437 |
+
stage_load = self.stages - 1
|
| 438 |
+
stage_compute = 0
|
| 439 |
+
for s in cutlass.range(thr_max_valid_split + Int32(1), unroll=4):
|
| 440 |
+
scale = cute.make_rmem_tensor(num_rows, Float32)
|
| 441 |
+
for m in cutlass.range(num_rows, unroll_full=True):
|
| 442 |
+
scale[m] = sLSE[s, tOcO[0, m, 0][0]]
|
| 443 |
+
|
| 444 |
+
split_to_load = s + Int32(self.stages - 1)
|
| 445 |
+
if split_to_load <= thr_max_valid_split:
|
| 446 |
+
load_O_partial(split_to_load, stage_load)
|
| 447 |
+
cute.arch.cp_async_commit_group()
|
| 448 |
+
stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1
|
| 449 |
+
|
| 450 |
+
cute.arch.cp_async_wait_group(self.stages - 1)
|
| 451 |
+
cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial)
|
| 452 |
+
stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1
|
| 453 |
+
|
| 454 |
+
for m in cutlass.range(num_rows, unroll_full=True):
|
| 455 |
+
if tOhidx[m] >= Int32(0) and scale[m] > Float32(0.0):
|
| 456 |
+
tOrO[None, m, None].store(
|
| 457 |
+
tOrO[None, m, None].load()
|
| 458 |
+
+ scale[m] * tOrO_partial[None, m, None].load().to(Float32)
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
cute.arch.cp_async_wait_group(0)
|
| 462 |
+
cute.arch.sync_threads()
|
| 463 |
+
|
| 464 |
+
rO = cute.make_rmem_tensor_like(tOrO, self.dtype)
|
| 465 |
+
rO.store(tOrO.load().to(self.dtype))
|
| 466 |
+
elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1]))
|
| 467 |
+
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
|
| 468 |
+
for m in cutlass.range(num_rows, unroll_full=True):
|
| 469 |
+
if tOhidx[m] >= Int32(0):
|
| 470 |
+
q_abs = batch_idx * seqlen_q + tOqidx[m]
|
| 471 |
+
row_ptr = (
|
| 472 |
+
mO.iterator
|
| 473 |
+
+ (
|
| 474 |
+
(Int64(q_abs) * Int64(head_q) + Int64(tOhidx[m]))
|
| 475 |
+
* Int64(self.head_dim)
|
| 476 |
+
+ Int64(k_block * Int32(self.k_block_size))
|
| 477 |
+
)
|
| 478 |
+
)
|
| 479 |
+
row_gmem_ptr = cute.make_ptr(
|
| 480 |
+
mO.element_type,
|
| 481 |
+
row_ptr.toint(),
|
| 482 |
+
cute.AddressSpace.gmem,
|
| 483 |
+
assumed_align=16,
|
| 484 |
+
)
|
| 485 |
+
mO_row = cute.make_tensor(
|
| 486 |
+
row_gmem_ptr,
|
| 487 |
+
cute.make_layout((self.k_block_size,)),
|
| 488 |
+
)
|
| 489 |
+
mO_row_copy = cute.tiled_divide(mO_row, (elems_per_store,))
|
| 490 |
+
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
|
| 491 |
+
k_idx = tOcO[0, 0, k][1] // elems_per_store
|
| 492 |
+
cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_row_copy[None, k_idx])
|
| 493 |
+
|
| 494 |
+
@cute.jit
|
| 495 |
+
def load_O_partial(
|
| 496 |
+
self,
|
| 497 |
+
mO_partial: cute.Tensor,
|
| 498 |
+
mOIndptr: cute.Tensor,
|
| 499 |
+
gmem_tiled_copy_O_partial: cute.TiledCopy,
|
| 500 |
+
tOsO_partial: cute.Tensor,
|
| 501 |
+
tOqidx: cute.Tensor,
|
| 502 |
+
tOhidx: cute.Tensor,
|
| 503 |
+
tOcO: cute.Tensor,
|
| 504 |
+
batch_idx: Int32,
|
| 505 |
+
q_stride: Int32,
|
| 506 |
+
split_count: Int32,
|
| 507 |
+
head_q: Int32,
|
| 508 |
+
k_block: Int32,
|
| 509 |
+
split: Int32,
|
| 510 |
+
stage: Int32,
|
| 511 |
+
) -> None:
|
| 512 |
+
elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1]))
|
| 513 |
+
tOsO_partial_cur = tOsO_partial[None, None, None, stage]
|
| 514 |
+
for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True):
|
| 515 |
+
if tOhidx[m] >= Int32(0):
|
| 516 |
+
if split < split_count:
|
| 517 |
+
partial_row = mOIndptr[batch_idx] + split * q_stride + tOqidx[m]
|
| 518 |
+
row_ptr = (
|
| 519 |
+
mO_partial.iterator
|
| 520 |
+
+ (
|
| 521 |
+
(Int64(partial_row) * Int64(head_q) + Int64(tOhidx[m]))
|
| 522 |
+
* Int64(self.head_dim)
|
| 523 |
+
+ Int64(k_block * Int32(self.k_block_size))
|
| 524 |
+
)
|
| 525 |
+
)
|
| 526 |
+
row_gmem_ptr = cute.make_ptr(
|
| 527 |
+
mO_partial.element_type,
|
| 528 |
+
row_ptr.toint(),
|
| 529 |
+
cute.AddressSpace.gmem,
|
| 530 |
+
assumed_align=16,
|
| 531 |
+
)
|
| 532 |
+
mO_partial_row = cute.make_tensor(
|
| 533 |
+
row_gmem_ptr,
|
| 534 |
+
cute.make_layout((self.k_block_size,)),
|
| 535 |
+
)
|
| 536 |
+
mO_partial_row_copy = cute.tiled_divide(
|
| 537 |
+
mO_partial_row, (elems_per_load,))
|
| 538 |
+
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
|
| 539 |
+
k_idx = tOcO[0, 0, k][1] // elems_per_load
|
| 540 |
+
cute.copy(
|
| 541 |
+
gmem_tiled_copy_O_partial,
|
| 542 |
+
mO_partial_row_copy[None, k_idx],
|
| 543 |
+
tOsO_partial_cur[None, m, k],
|
| 544 |
+
)
|
| 545 |
+
else:
|
| 546 |
+
tOsO_partial_cur[None, m, None].fill(Float32(0.0))
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
_combine_compile_cache: dict[tuple[object, ...], object] = {}
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def _next_power_of_2(x: int) -> int:
|
| 553 |
+
return 1 << (max(int(x), 1) - 1).bit_length()
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def run_decode_combine(
|
| 557 |
+
O_partial: torch.Tensor,
|
| 558 |
+
LSE_partial: torch.Tensor,
|
| 559 |
+
split_counts: torch.Tensor,
|
| 560 |
+
o_indptr: torch.Tensor,
|
| 561 |
+
out: torch.Tensor,
|
| 562 |
+
lse: torch.Tensor,
|
| 563 |
+
*,
|
| 564 |
+
seqlen_q: int,
|
| 565 |
+
q_tokens_per_group: int,
|
| 566 |
+
max_split_count: int,
|
| 567 |
+
) -> None:
|
| 568 |
+
"""Launch LDGSTS decode split-KV combine."""
|
| 569 |
+
|
| 570 |
+
if O_partial.dtype != torch.float32:
|
| 571 |
+
raise TypeError(f"O_partial must be torch.float32, got {O_partial.dtype}")
|
| 572 |
+
if LSE_partial.dtype != torch.float32:
|
| 573 |
+
raise TypeError(f"LSE_partial must be torch.float32, got {LSE_partial.dtype}")
|
| 574 |
+
if lse.dtype != torch.float32:
|
| 575 |
+
raise TypeError(f"lse must be torch.float32, got {lse.dtype}")
|
| 576 |
+
if split_counts.dtype != torch.int32:
|
| 577 |
+
raise TypeError(f"split_counts must be torch.int32, got {split_counts.dtype}")
|
| 578 |
+
if o_indptr.dtype != torch.int32:
|
| 579 |
+
raise TypeError(f"o_indptr must be torch.int32, got {o_indptr.dtype}")
|
| 580 |
+
if out.ndim != 3 or O_partial.ndim != 3:
|
| 581 |
+
raise ValueError("decode combine expects O tensors with shape [rows, heads, D]")
|
| 582 |
+
if LSE_partial.ndim != 2 or lse.ndim != 2:
|
| 583 |
+
raise ValueError("decode combine expects LSE tensors with shape [rows, heads]")
|
| 584 |
+
if out.shape[1:] != O_partial.shape[1:]:
|
| 585 |
+
raise ValueError(f"O shape mismatch: out={out.shape}, O_partial={O_partial.shape}")
|
| 586 |
+
if lse.shape != out.shape[:2]:
|
| 587 |
+
raise ValueError(f"lse shape {lse.shape} must match out[:2] {out.shape[:2]}")
|
| 588 |
+
if LSE_partial.shape != O_partial.shape[:2]:
|
| 589 |
+
raise ValueError(
|
| 590 |
+
f"LSE_partial shape {LSE_partial.shape} must match O_partial[:2] {O_partial.shape[:2]}"
|
| 591 |
+
)
|
| 592 |
+
if split_counts.ndim != 1 or o_indptr.ndim != 1:
|
| 593 |
+
raise ValueError("split_counts and o_indptr must be rank-1 tensors")
|
| 594 |
+
if o_indptr.shape != (split_counts.shape[0] + 1,):
|
| 595 |
+
raise ValueError(
|
| 596 |
+
f"o_indptr shape {o_indptr.shape} must be ({split_counts.shape[0] + 1},)"
|
| 597 |
+
)
|
| 598 |
+
seqlen_q = int(seqlen_q)
|
| 599 |
+
q_tokens_per_group = int(q_tokens_per_group)
|
| 600 |
+
if seqlen_q <= 0:
|
| 601 |
+
raise ValueError("seqlen_q must be positive")
|
| 602 |
+
if q_tokens_per_group <= 0:
|
| 603 |
+
raise ValueError("q_tokens_per_group must be positive")
|
| 604 |
+
if out.shape[0] != split_counts.shape[0] * seqlen_q:
|
| 605 |
+
raise ValueError(
|
| 606 |
+
f"out rows {out.shape[0]} must equal batch*seqlen_q "
|
| 607 |
+
f"{split_counts.shape[0]}*{seqlen_q}"
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
max_split_count = int(max_split_count)
|
| 611 |
+
if max_split_count <= 0:
|
| 612 |
+
raise ValueError("max_split_count must be positive")
|
| 613 |
+
if max_split_count > 256:
|
| 614 |
+
raise NotImplementedError(
|
| 615 |
+
f"LDGSTS decode combine supports at most 256 splits, got {max_split_count}"
|
| 616 |
+
)
|
| 617 |
+
max_splits = max(4, _next_power_of_2(max_split_count))
|
| 618 |
+
tile_m = 64
|
| 619 |
+
k_block_size = int(out.shape[-1])
|
| 620 |
+
stages = 2
|
| 621 |
+
|
| 622 |
+
dtype = torch2cute_dtype_map[out.dtype]
|
| 623 |
+
key = (
|
| 624 |
+
"decode_combine_ldgsts",
|
| 625 |
+
out.shape[-1],
|
| 626 |
+
dtype,
|
| 627 |
+
O_partial.dtype,
|
| 628 |
+
seqlen_q,
|
| 629 |
+
q_tokens_per_group,
|
| 630 |
+
tile_m,
|
| 631 |
+
k_block_size,
|
| 632 |
+
max_splits,
|
| 633 |
+
stages,
|
| 634 |
+
)
|
| 635 |
+
if key not in _combine_compile_cache:
|
| 636 |
+
from ....quack.compile_utils import make_fake_tensor
|
| 637 |
+
|
| 638 |
+
total_q = cute.sym_int64()
|
| 639 |
+
batch = cute.sym_int64()
|
| 640 |
+
batch_plus_one = cute.sym_int64()
|
| 641 |
+
partial_rows = cute.sym_int64()
|
| 642 |
+
head_q = cute.sym_int64()
|
| 643 |
+
head_dim = int(out.shape[-1])
|
| 644 |
+
kernel = SparseDecodeForwardCombine(
|
| 645 |
+
dtype=dtype,
|
| 646 |
+
dtype_partial=Float32,
|
| 647 |
+
head_dim=head_dim,
|
| 648 |
+
tile_m=tile_m,
|
| 649 |
+
k_block_size=k_block_size,
|
| 650 |
+
max_splits=max_splits,
|
| 651 |
+
stages=stages,
|
| 652 |
+
)
|
| 653 |
+
_combine_compile_cache[key] = cute.compile(
|
| 654 |
+
kernel,
|
| 655 |
+
make_fake_tensor(Float32, (partial_rows, head_q, head_dim), divisibility=4),
|
| 656 |
+
make_fake_tensor(Float32, (partial_rows, head_q), divisibility=1, leading_dim=1),
|
| 657 |
+
make_fake_tensor(Int32, (batch,), divisibility=1, leading_dim=0),
|
| 658 |
+
make_fake_tensor(Int32, (batch_plus_one,), divisibility=1, leading_dim=0),
|
| 659 |
+
make_fake_tensor(dtype, (total_q, head_q, head_dim), divisibility=128 // dtype.width),
|
| 660 |
+
make_fake_tensor(Float32, (total_q, head_q), divisibility=1, leading_dim=1),
|
| 661 |
+
Int32(seqlen_q),
|
| 662 |
+
Int32(q_tokens_per_group),
|
| 663 |
+
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
|
| 664 |
+
options="--enable-tvm-ffi",
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
with torch.cuda.nvtx.range("Decode_Combine_LDGSTS"):
|
| 668 |
+
_combine_compile_cache[key](
|
| 669 |
+
O_partial,
|
| 670 |
+
LSE_partial,
|
| 671 |
+
split_counts,
|
| 672 |
+
o_indptr,
|
| 673 |
+
out,
|
| 674 |
+
lse,
|
| 675 |
+
seqlen_q,
|
| 676 |
+
q_tokens_per_group,
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
__all__ = ["SparseDecodeForwardCombine", "run_decode_combine"]
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/fwd_decode/tile_scheduler.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""Decode-specific tile scheduler for paged fp8 attention.
|
| 5 |
+
|
| 6 |
+
The pre-schedule step builds a dense worklist over decode KV chunks. Static
|
| 7 |
+
persistent scheduling walks a flattened ``(work_idx, head_kv_idx)`` task id.
|
| 8 |
+
CLC scheduling keeps BSA's hardware grid shape, ``(work_idx, head_kv_idx, 1)``,
|
| 9 |
+
and maps the canceled CTA coordinate back to the same logical task space.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import Tuple
|
| 14 |
+
|
| 15 |
+
import cutlass
|
| 16 |
+
import cutlass.cute as cute
|
| 17 |
+
from cutlass import Int32, const_expr
|
| 18 |
+
from cutlass.cute import FastDivmodDivisor
|
| 19 |
+
|
| 20 |
+
from ....quack.cute_dsl_utils import ParamsBase
|
| 21 |
+
|
| 22 |
+
from ....src.common.tile_scheduler import SchedulingMode, WorkTileInfo
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class DecodeTileSchedulerArguments(ParamsBase):
|
| 27 |
+
work_capacity: Int32
|
| 28 |
+
num_heads_kv: Int32
|
| 29 |
+
cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class DecodeTileScheduler:
|
| 33 |
+
"""Persistent scheduler over decode ``(work_idx, head_kv_idx)`` tasks."""
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class Params(ParamsBase):
|
| 37 |
+
work_capacity: Int32
|
| 38 |
+
num_heads_kv: Int32
|
| 39 |
+
num_heads_kv_divmod: FastDivmodDivisor
|
| 40 |
+
total_tasks: Int32
|
| 41 |
+
cluster_shape_m: cutlass.Constexpr[int] = 1
|
| 42 |
+
scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
params: Params,
|
| 47 |
+
task_idx: Int32,
|
| 48 |
+
clc_scheduler=None,
|
| 49 |
+
clc_pipeline=None,
|
| 50 |
+
clc_consumer_state=None,
|
| 51 |
+
clc_response_ptr=None,
|
| 52 |
+
*,
|
| 53 |
+
loc=None,
|
| 54 |
+
ip=None,
|
| 55 |
+
):
|
| 56 |
+
self.params = params
|
| 57 |
+
self._task_idx = task_idx
|
| 58 |
+
self._clc_scheduler = clc_scheduler
|
| 59 |
+
self._clc_pipeline = clc_pipeline
|
| 60 |
+
self._clc_consumer_state = clc_consumer_state
|
| 61 |
+
self._clc_response_ptr = clc_response_ptr
|
| 62 |
+
self._loc = loc
|
| 63 |
+
self._ip = ip
|
| 64 |
+
|
| 65 |
+
@staticmethod
|
| 66 |
+
def to_underlying_arguments(
|
| 67 |
+
args: DecodeTileSchedulerArguments,
|
| 68 |
+
*,
|
| 69 |
+
scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
|
| 70 |
+
loc=None,
|
| 71 |
+
ip=None,
|
| 72 |
+
) -> Params:
|
| 73 |
+
assert args.cluster_shape_mn[1] == 1, "Decode scheduler requires cluster N == 1"
|
| 74 |
+
total_tasks = args.work_capacity * args.num_heads_kv
|
| 75 |
+
return DecodeTileScheduler.Params(
|
| 76 |
+
args.work_capacity,
|
| 77 |
+
args.num_heads_kv,
|
| 78 |
+
FastDivmodDivisor(args.num_heads_kv),
|
| 79 |
+
total_tasks,
|
| 80 |
+
cluster_shape_m=args.cluster_shape_mn[0],
|
| 81 |
+
scheduling_mode=scheduling_mode,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
@staticmethod
|
| 85 |
+
def _clc_grid_shape(params: Params):
|
| 86 |
+
return (
|
| 87 |
+
cute.round_up(params.work_capacity, params.cluster_shape_m),
|
| 88 |
+
params.num_heads_kv,
|
| 89 |
+
Int32(1),
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
@staticmethod
|
| 93 |
+
@cute.jit
|
| 94 |
+
def create(
|
| 95 |
+
params: Params,
|
| 96 |
+
clc_response_ptr=None,
|
| 97 |
+
*,
|
| 98 |
+
loc=None,
|
| 99 |
+
ip=None,
|
| 100 |
+
) -> "DecodeTileScheduler":
|
| 101 |
+
if const_expr(params.scheduling_mode == SchedulingMode.CLC):
|
| 102 |
+
from cutlass.utils import (
|
| 103 |
+
ClcDynamicPersistentTileScheduler,
|
| 104 |
+
ClcDynamicPersistentTileSchedulerParams,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
cutlass_params = ClcDynamicPersistentTileSchedulerParams(
|
| 108 |
+
problem_shape_ntile_mnl=DecodeTileScheduler._clc_grid_shape(params),
|
| 109 |
+
cluster_shape_mnk=(params.cluster_shape_m, 1, 1),
|
| 110 |
+
)
|
| 111 |
+
block_idx = cute.arch.block_idx()
|
| 112 |
+
grid_dim = cute.arch.grid_dim()
|
| 113 |
+
clc_scheduler = ClcDynamicPersistentTileScheduler.create(
|
| 114 |
+
cutlass_params,
|
| 115 |
+
block_idx,
|
| 116 |
+
grid_dim,
|
| 117 |
+
clc_response_ptr,
|
| 118 |
+
)
|
| 119 |
+
return DecodeTileScheduler(
|
| 120 |
+
params,
|
| 121 |
+
block_idx[0],
|
| 122 |
+
clc_scheduler,
|
| 123 |
+
clc_response_ptr=clc_response_ptr,
|
| 124 |
+
loc=loc,
|
| 125 |
+
ip=ip,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
if const_expr(params.cluster_shape_m == 1):
|
| 129 |
+
task_idx = cute.arch.block_idx()[0]
|
| 130 |
+
else:
|
| 131 |
+
task_idx = cute.arch.cluster_idx()[0]
|
| 132 |
+
return DecodeTileScheduler(params, task_idx, loc=loc, ip=ip)
|
| 133 |
+
|
| 134 |
+
@staticmethod
|
| 135 |
+
def get_grid_shape(
|
| 136 |
+
params: Params,
|
| 137 |
+
*,
|
| 138 |
+
loc=None,
|
| 139 |
+
ip=None,
|
| 140 |
+
) -> Tuple[Int32, Int32, Int32]:
|
| 141 |
+
if const_expr(params.scheduling_mode == SchedulingMode.CLC):
|
| 142 |
+
return DecodeTileScheduler._clc_grid_shape(params)
|
| 143 |
+
hardware_info = cutlass.utils.HardwareInfo()
|
| 144 |
+
sm_count = hardware_info.get_device_multiprocessor_count()
|
| 145 |
+
max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m
|
| 146 |
+
grid_x = cutlass.min(max_ctas, params.total_tasks * params.cluster_shape_m)
|
| 147 |
+
return (grid_x, Int32(1), Int32(1))
|
| 148 |
+
|
| 149 |
+
@cute.jit
|
| 150 |
+
def _task_to_work(self, task_idx: Int32, is_valid) -> WorkTileInfo:
|
| 151 |
+
work_idx, head_kv_idx = divmod(task_idx, self.params.num_heads_kv_divmod)
|
| 152 |
+
return WorkTileInfo(
|
| 153 |
+
(Int32(work_idx), Int32(head_kv_idx), Int32(0), Int32(0)),
|
| 154 |
+
is_valid,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
@cute.jit
|
| 158 |
+
def _clc_work_to_coords(self, work) -> WorkTileInfo:
|
| 159 |
+
work_idx = work.tile_idx[0]
|
| 160 |
+
if const_expr(self.params.cluster_shape_m > 1):
|
| 161 |
+
work_idx = work_idx // self.params.cluster_shape_m
|
| 162 |
+
return WorkTileInfo(
|
| 163 |
+
(
|
| 164 |
+
Int32(work_idx),
|
| 165 |
+
Int32(work.tile_idx[1]),
|
| 166 |
+
Int32(0),
|
| 167 |
+
Int32(0),
|
| 168 |
+
),
|
| 169 |
+
work.is_valid_tile,
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
@cute.jit
|
| 173 |
+
def _clc_response_to_work(
|
| 174 |
+
self,
|
| 175 |
+
response_stage: Int32,
|
| 176 |
+
*,
|
| 177 |
+
loc=None,
|
| 178 |
+
ip=None,
|
| 179 |
+
) -> WorkTileInfo:
|
| 180 |
+
# CLC responses are 16B opaque records. The scheduler warp can query
|
| 181 |
+
# the next stage before all consumer warps have read the current one,
|
| 182 |
+
# so each pipeline stage needs its own response slot.
|
| 183 |
+
response_ptr = self._clc_response_ptr + response_stage * Int32(4)
|
| 184 |
+
m_idx, n_idx, l_idx, is_valid = cute.arch.clc_response(
|
| 185 |
+
response_ptr, loc=loc, ip=ip)
|
| 186 |
+
cute.arch.fence_proxy("async.shared", space="cta")
|
| 187 |
+
cta_idx_in_cluster = cute.arch.block_idx()[0] % Int32(
|
| 188 |
+
self.params.cluster_shape_m)
|
| 189 |
+
return WorkTileInfo(
|
| 190 |
+
(
|
| 191 |
+
Int32(m_idx) + cta_idx_in_cluster,
|
| 192 |
+
Int32(n_idx),
|
| 193 |
+
Int32(l_idx),
|
| 194 |
+
Int32(0),
|
| 195 |
+
),
|
| 196 |
+
is_valid,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
@cute.jit
|
| 200 |
+
def get_current_work(
|
| 201 |
+
self,
|
| 202 |
+
response_stage: Int32 = Int32(0),
|
| 203 |
+
*,
|
| 204 |
+
loc=None,
|
| 205 |
+
ip=None,
|
| 206 |
+
) -> WorkTileInfo:
|
| 207 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 208 |
+
work = self._clc_response_to_work(
|
| 209 |
+
response_stage, loc=loc, ip=ip)
|
| 210 |
+
self._task_idx = (
|
| 211 |
+
work.tile_idx[0] * self.params.num_heads_kv
|
| 212 |
+
+ work.tile_idx[1]
|
| 213 |
+
)
|
| 214 |
+
return self._clc_work_to_coords(work)
|
| 215 |
+
is_valid = self._task_idx < self.params.total_tasks
|
| 216 |
+
return self._task_to_work(self._task_idx, is_valid)
|
| 217 |
+
|
| 218 |
+
@cute.jit
|
| 219 |
+
def initial_work_tile_info(self, *, loc=None, ip=None):
|
| 220 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 221 |
+
work = self._clc_scheduler.initial_work_tile_info()
|
| 222 |
+
self._task_idx = (
|
| 223 |
+
work.tile_idx[0] * self.params.num_heads_kv
|
| 224 |
+
+ work.tile_idx[1]
|
| 225 |
+
)
|
| 226 |
+
return self._clc_work_to_coords(work)
|
| 227 |
+
return self.get_current_work(loc=loc, ip=ip)
|
| 228 |
+
|
| 229 |
+
def prefetch_next_work(self, *, loc=None, ip=None):
|
| 230 |
+
pass
|
| 231 |
+
|
| 232 |
+
def advance_to_next_work(
|
| 233 |
+
self,
|
| 234 |
+
*,
|
| 235 |
+
loc=None,
|
| 236 |
+
ip=None,
|
| 237 |
+
mbarrier_addr=None,
|
| 238 |
+
response_stage: Int32 = Int32(0),
|
| 239 |
+
):
|
| 240 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 241 |
+
assert mbarrier_addr is not None
|
| 242 |
+
response_ptr = self._clc_response_ptr + response_stage * Int32(4)
|
| 243 |
+
with cute.arch.elect_one():
|
| 244 |
+
cute.arch.issue_clc_query(
|
| 245 |
+
mbarrier_addr, response_ptr, loc=loc, ip=ip)
|
| 246 |
+
else:
|
| 247 |
+
assert mbarrier_addr is None
|
| 248 |
+
if const_expr(self.params.cluster_shape_m == 1):
|
| 249 |
+
self._task_idx += cute.arch.grid_dim()[0]
|
| 250 |
+
else:
|
| 251 |
+
self._task_idx += cute.arch.cluster_dim()[0]
|
| 252 |
+
|
| 253 |
+
def consumer_advance(self, *, loc=None, ip=None):
|
| 254 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 255 |
+
response_stage = self._clc_consumer_state.index
|
| 256 |
+
self._clc_pipeline.consumer_wait(self._clc_consumer_state)
|
| 257 |
+
work_tile = self.get_current_work(response_stage=response_stage)
|
| 258 |
+
self._clc_pipeline.consumer_release(self._clc_consumer_state)
|
| 259 |
+
self._clc_consumer_state.advance()
|
| 260 |
+
return work_tile
|
| 261 |
+
self.advance_to_next_work()
|
| 262 |
+
return self.get_current_work()
|
| 263 |
+
|
| 264 |
+
def set_clc_pipeline(self, clc_pipeline, clc_consumer_state):
|
| 265 |
+
self._clc_pipeline = clc_pipeline
|
| 266 |
+
self._clc_consumer_state = clc_consumer_state
|
| 267 |
+
|
| 268 |
+
def producer_tail(self, *, loc=None, ip=None):
|
| 269 |
+
pass
|
| 270 |
+
|
| 271 |
+
def __extract_mlir_values__(self):
|
| 272 |
+
values, self._values_pos = [], []
|
| 273 |
+
objs = [self.params, self._task_idx]
|
| 274 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 275 |
+
objs += [
|
| 276 |
+
self._clc_scheduler,
|
| 277 |
+
self._clc_pipeline,
|
| 278 |
+
self._clc_consumer_state,
|
| 279 |
+
self._clc_response_ptr,
|
| 280 |
+
]
|
| 281 |
+
for obj in objs:
|
| 282 |
+
obj_values = cutlass.extract_mlir_values(obj)
|
| 283 |
+
values += obj_values
|
| 284 |
+
self._values_pos.append(len(obj_values))
|
| 285 |
+
return values
|
| 286 |
+
|
| 287 |
+
def __new_from_mlir_values__(self, values):
|
| 288 |
+
obj_list = []
|
| 289 |
+
objs = [self.params, self._task_idx]
|
| 290 |
+
if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
|
| 291 |
+
objs += [
|
| 292 |
+
self._clc_scheduler,
|
| 293 |
+
self._clc_pipeline,
|
| 294 |
+
self._clc_consumer_state,
|
| 295 |
+
self._clc_response_ptr,
|
| 296 |
+
]
|
| 297 |
+
for obj, n_items in zip(objs, self._values_pos):
|
| 298 |
+
obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
|
| 299 |
+
values = values[n_items:]
|
| 300 |
+
return DecodeTileScheduler(*obj_list, loc=self._loc)
|
build/torch211-cxx11-cu128-x86_64-linux/src/sm100/prepare_k2q_csr.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
|
| 4 |
+
"""Sparse k2q CSR builder for SM100.
|
| 5 |
+
|
| 6 |
+
Thin dispatcher that calls the CUDA C++ kernel pipeline in
|
| 7 |
+
``src.sm100.build_k2q_csr``. Supports ``topK in {4, 8, 16, 32}`` and
|
| 8 |
+
``blk_kv == 128`` only — other shapes raise ``ValueError`` rather than
|
| 9 |
+
silently falling back to a torch-reference path.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from typing import Optional
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from ...src.sm100.prepare_scheduler import SparseAttentionSchedule, SPARSE_SCHEDULE_MODEL
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
_SUPPORTED_TOPK = (4, 8, 16, 32)
|
| 22 |
+
_SUPPORTED_BLK_KV = 128
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _ceil_div(x: int, y: int) -> int:
|
| 26 |
+
return (x + y - 1) // y
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SparseK2qCsrBuilderSm100:
|
| 30 |
+
"""Build the k2q CSR reverse index for sparse attention on SM100.
|
| 31 |
+
|
| 32 |
+
The public API matches the historical CUTE DSL builder so callers
|
| 33 |
+
(``sparse_index_utils.build_k2q_csr``, attention kernels) need no
|
| 34 |
+
changes. Internally the kernel pipeline runs five CUDA C++ kernels:
|
| 35 |
+
``build_row_map`` -> ``hist`` -> ``row_prefix`` -> ``tile_prefix_smem``
|
| 36 |
+
-> ``scatter`` (5 kernels + 2 ``cudaMemsetAsync``).
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self) -> None:
|
| 40 |
+
# No persistent state — the JIT-compiled extension is loaded
|
| 41 |
+
# lazily by ``src.sm100.build_k2q_csr`` on first call.
|
| 42 |
+
self._run = None
|
| 43 |
+
self._run_with_schedule = None
|
| 44 |
+
|
| 45 |
+
def _ensure_loaded(self) -> None:
|
| 46 |
+
if self._run is None:
|
| 47 |
+
from ...src.sm100.build_k2q_csr import (
|
| 48 |
+
run_build_k2q_csr,
|
| 49 |
+
run_build_k2q_csr_with_schedule,
|
| 50 |
+
)
|
| 51 |
+
self._run = run_build_k2q_csr
|
| 52 |
+
self._run_with_schedule = run_build_k2q_csr_with_schedule
|
| 53 |
+
|
| 54 |
+
def __call__(
|
| 55 |
+
self,
|
| 56 |
+
q2k_indices: torch.Tensor,
|
| 57 |
+
cu_seqlens_q: torch.Tensor,
|
| 58 |
+
cu_seqlens_k: torch.Tensor,
|
| 59 |
+
*,
|
| 60 |
+
total_k: int,
|
| 61 |
+
blk_kv: int = 128,
|
| 62 |
+
max_seqlen_k: Optional[int] = None,
|
| 63 |
+
max_seqlen_q: Optional[int] = None,
|
| 64 |
+
total_rows: Optional[int] = None,
|
| 65 |
+
qhead_per_kv: int = 1,
|
| 66 |
+
return_schedule: bool = False,
|
| 67 |
+
) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, SparseAttentionSchedule]:
|
| 68 |
+
# ---- Validation ----------------------------------------------------
|
| 69 |
+
if blk_kv != _SUPPORTED_BLK_KV:
|
| 70 |
+
raise ValueError(
|
| 71 |
+
f"SparseK2qCsrBuilderSm100 only supports blk_kv == "
|
| 72 |
+
f"{_SUPPORTED_BLK_KV}, got {blk_kv}"
|
| 73 |
+
)
|
| 74 |
+
if q2k_indices.dtype != torch.int32:
|
| 75 |
+
raise TypeError(
|
| 76 |
+
f"q2k_indices must be torch.int32, got {q2k_indices.dtype}"
|
| 77 |
+
)
|
| 78 |
+
if q2k_indices.ndim != 3:
|
| 79 |
+
raise ValueError(
|
| 80 |
+
f"q2k_indices must be rank-3 [head_kv, total_q, topK], "
|
| 81 |
+
f"got shape {tuple(q2k_indices.shape)}"
|
| 82 |
+
)
|
| 83 |
+
if not q2k_indices.is_contiguous():
|
| 84 |
+
raise ValueError("q2k_indices must be contiguous")
|
| 85 |
+
if cu_seqlens_q.dtype != torch.int32 or cu_seqlens_k.dtype != torch.int32:
|
| 86 |
+
raise TypeError("cu_seqlens_q and cu_seqlens_k must be torch.int32")
|
| 87 |
+
if cu_seqlens_q.ndim != 1 or cu_seqlens_k.ndim != 1:
|
| 88 |
+
raise ValueError("cu_seqlens_q and cu_seqlens_k must be rank-1")
|
| 89 |
+
if cu_seqlens_q.shape != cu_seqlens_k.shape:
|
| 90 |
+
raise ValueError(
|
| 91 |
+
"cu_seqlens_q and cu_seqlens_k must share shape [B + 1]"
|
| 92 |
+
)
|
| 93 |
+
if not (q2k_indices.is_cuda and cu_seqlens_q.is_cuda and cu_seqlens_k.is_cuda):
|
| 94 |
+
raise ValueError("all inputs must be CUDA tensors")
|
| 95 |
+
if (
|
| 96 |
+
q2k_indices.device != cu_seqlens_q.device
|
| 97 |
+
or q2k_indices.device != cu_seqlens_k.device
|
| 98 |
+
):
|
| 99 |
+
raise ValueError("all inputs must share a device")
|
| 100 |
+
if not cu_seqlens_q.is_contiguous() or not cu_seqlens_k.is_contiguous():
|
| 101 |
+
raise ValueError("cu_seqlens_q and cu_seqlens_k must be contiguous")
|
| 102 |
+
|
| 103 |
+
total_k = int(total_k)
|
| 104 |
+
if total_k < 0:
|
| 105 |
+
raise ValueError(f"total_k must be non-negative, got {total_k}")
|
| 106 |
+
|
| 107 |
+
head_kv, total_q, topk = (int(v) for v in q2k_indices.shape)
|
| 108 |
+
if topk not in _SUPPORTED_TOPK:
|
| 109 |
+
raise ValueError(
|
| 110 |
+
f"SparseK2qCsrBuilderSm100 only supports topK in "
|
| 111 |
+
f"{_SUPPORTED_TOPK}, got {topk}"
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
batch = int(cu_seqlens_q.shape[0] - 1)
|
| 115 |
+
if batch < 0:
|
| 116 |
+
raise ValueError("cu_seqlens tensors must have shape [B + 1]")
|
| 117 |
+
if return_schedule and max_seqlen_k is None:
|
| 118 |
+
raise ValueError("build_k2q_csr requires max_seqlen_k when return_schedule=True")
|
| 119 |
+
max_k_tokens = int(max_seqlen_k) if max_seqlen_k is not None else total_k
|
| 120 |
+
max_kv_blocks = _ceil_div(max(max_k_tokens, blk_kv), blk_kv)
|
| 121 |
+
if total_rows is not None:
|
| 122 |
+
total_rows = int(total_rows)
|
| 123 |
+
elif total_k % blk_kv == 0:
|
| 124 |
+
total_rows = total_k // blk_kv
|
| 125 |
+
else:
|
| 126 |
+
total_rows = _ceil_div(total_k + batch * (blk_kv - 1), blk_kv)
|
| 127 |
+
if total_rows < 0:
|
| 128 |
+
raise ValueError(f"total_rows must be non-negative, got {total_rows}")
|
| 129 |
+
total_rows = max(total_rows, 0)
|
| 130 |
+
nnz_upper_bound = total_q * topk
|
| 131 |
+
qhead_per_kv = int(qhead_per_kv)
|
| 132 |
+
if qhead_per_kv <= 0:
|
| 133 |
+
raise ValueError(f"qhead_per_kv must be positive, got {qhead_per_kv}")
|
| 134 |
+
if return_schedule:
|
| 135 |
+
if max_seqlen_q is None:
|
| 136 |
+
raise ValueError("build_k2q_csr requires max_seqlen_q when return_schedule=True")
|
| 137 |
+
max_seqlen_q = int(max_seqlen_q)
|
| 138 |
+
|
| 139 |
+
# ---- Output tensors ------------------------------------------------
|
| 140 |
+
device = q2k_indices.device
|
| 141 |
+
k2q_row_ptr = torch.empty(
|
| 142 |
+
(head_kv, total_rows + 1), dtype=torch.int32, device=device,
|
| 143 |
+
)
|
| 144 |
+
k2q_q_indices = torch.empty(
|
| 145 |
+
(head_kv, nnz_upper_bound), dtype=torch.int32, device=device,
|
| 146 |
+
)
|
| 147 |
+
schedule = None
|
| 148 |
+
if return_schedule:
|
| 149 |
+
target_q_per_cta = SPARSE_SCHEDULE_MODEL.balanced_target_q_per_cta(
|
| 150 |
+
total_q=total_q,
|
| 151 |
+
topk=topk,
|
| 152 |
+
blk_kv=blk_kv,
|
| 153 |
+
head_kv=head_kv,
|
| 154 |
+
qhead_per_kv=qhead_per_kv,
|
| 155 |
+
device=device,
|
| 156 |
+
)
|
| 157 |
+
scheduler_metadata_capacity = SPARSE_SCHEDULE_MODEL.flat_schedule_capacity(
|
| 158 |
+
total_rows=total_rows,
|
| 159 |
+
total_q=total_q,
|
| 160 |
+
topk=topk,
|
| 161 |
+
head_kv=head_kv,
|
| 162 |
+
target_q_per_cta=target_q_per_cta,
|
| 163 |
+
)
|
| 164 |
+
scheduler_metadata = torch.empty(
|
| 165 |
+
(scheduler_metadata_capacity, 6), dtype=torch.int32, device=device
|
| 166 |
+
)
|
| 167 |
+
work_count = torch.empty((1,), dtype=torch.int32, device=device)
|
| 168 |
+
qsplit_indices = torch.empty_like(k2q_q_indices)
|
| 169 |
+
split_counts = torch.empty(
|
| 170 |
+
(total_q, head_kv), dtype=torch.int32, device=device
|
| 171 |
+
)
|
| 172 |
+
schedule = SparseAttentionSchedule(
|
| 173 |
+
enabled=True,
|
| 174 |
+
scheduler_metadata=scheduler_metadata,
|
| 175 |
+
work_count=work_count,
|
| 176 |
+
qsplit_indices=qsplit_indices,
|
| 177 |
+
split_counts=split_counts,
|
| 178 |
+
target_q_per_cta=target_q_per_cta,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Empty workload short-circuit (the CUDA path also handles this,
|
| 182 |
+
# but doing it here saves a JIT load for trivial calls).
|
| 183 |
+
if total_rows == 0 or total_q == 0 or head_kv == 0 or topk == 0:
|
| 184 |
+
k2q_row_ptr.zero_()
|
| 185 |
+
k2q_q_indices.fill_(-1)
|
| 186 |
+
if schedule is not None:
|
| 187 |
+
schedule.work_count.zero_()
|
| 188 |
+
schedule.split_counts.zero_()
|
| 189 |
+
return k2q_row_ptr, k2q_q_indices, schedule
|
| 190 |
+
return k2q_row_ptr, k2q_q_indices
|
| 191 |
+
|
| 192 |
+
self._ensure_loaded()
|
| 193 |
+
with torch.cuda.nvtx.range("SparseK2qCsr_Pipeline"):
|
| 194 |
+
if schedule is None:
|
| 195 |
+
self._run(
|
| 196 |
+
q2k_indices,
|
| 197 |
+
cu_seqlens_q,
|
| 198 |
+
cu_seqlens_k,
|
| 199 |
+
k2q_row_ptr,
|
| 200 |
+
k2q_q_indices,
|
| 201 |
+
topk,
|
| 202 |
+
blk_kv,
|
| 203 |
+
total_rows,
|
| 204 |
+
max_kv_blocks,
|
| 205 |
+
)
|
| 206 |
+
else:
|
| 207 |
+
self._run_with_schedule(
|
| 208 |
+
q2k_indices,
|
| 209 |
+
cu_seqlens_q,
|
| 210 |
+
cu_seqlens_k,
|
| 211 |
+
k2q_row_ptr,
|
| 212 |
+
k2q_q_indices,
|
| 213 |
+
schedule.scheduler_metadata,
|
| 214 |
+
schedule.work_count,
|
| 215 |
+
schedule.qsplit_indices,
|
| 216 |
+
schedule.split_counts,
|
| 217 |
+
topk,
|
| 218 |
+
blk_kv,
|
| 219 |
+
total_rows,
|
| 220 |
+
max_kv_blocks,
|
| 221 |
+
schedule.target_q_per_cta,
|
| 222 |
+
schedule.work_capacity,
|
| 223 |
+
max_seqlen_q,
|
| 224 |
+
)
|
| 225 |
+
if schedule is not None:
|
| 226 |
+
return k2q_row_ptr, k2q_q_indices, schedule
|
| 227 |
+
return k2q_row_ptr, k2q_q_indices
|