---
library_name: kernels
license: apache-2.0
tags:
- cuda
- triton
- native-cuda
- minimax
- sparse-attention
- blackwell
---
# MiniMaxAI MSA Blackwell
Blackwell-family package for the MiniMax MSA sparse-attention kernels, focused on
sparse decode and maintained by FlashRT. The upstream package,
[`MiniMaxAI/msa`](https://huggingface.co/kernels/MiniMaxAI/msa), targets SM100
(compute capability 10.0). This package extends the sparse decode path to NVIDIA
Blackwell GPUs with compute capability 12.x. It has been validated in the FlashRT
MiniMax-Spark runtime on DGX Spark (GB10, SM121).
## Load
```python
from kernels import get_kernel
msa = get_kernel(
    "flashrt/MiniMaxAI-msa-blackwell",
    version=1,
    trust_remote_code=True,
)
```
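
This package targets compute capability 12.x, whereas the upstream package targets SM100 (10.0). The sketch below is a minimal guard you might run before calling `get_kernel`. It assumes PyTorch is installed. `is_blackwell_12x` is a hypothetical helper and is not part of this package.

```python
def is_blackwell_12x(capability: tuple[int, int]) -> bool:
    """Return True for compute capability 12.x (e.g. SM120, SM121 / GB10)."""
    major, _minor = capability
    return major == 12


try:
    import torch
except ImportError:  # the sketch still runs without PyTorch installed
    torch = None

if torch is not None and torch.cuda.is_available():
    cap = torch.cuda.get_device_capability()  # (major, minor), e.g. (12, 1) on GB10
    if not is_blackwell_12x(cap):
        print(f"compute capability {cap} is outside this package's 12.x target")
```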
## What You Can Call
### Official MiniMaxAI/msa names
| Function/class | Status in this Blackwell package |
|---|---|
| `sparse_decode_atten_func` | Available. Blackwell paged BF16/FP16 single-token decode wrapper. |
| `SparseDecodePagedAttentionWrapper` | Available. `plan(...).run(...)` wrapper for the same decode path. |
| `build_k2q_csr` | Available. CSR construction helper for the official prefill API. |
| `SparseK2qCsrBuilderSm100` | Available compatibility class; `build()` delegates to `build_k2q_csr`. |
| `Nvfp4QuantizedTensor` | Available metadata dataclass. |
| `quantize_bf16_to_nvfp4_128x4` | Available when Transformer Engine NVFP4 support is installed. |
| `quantize_kv_bf16_to_nvfp4_128x4` | Available when Transformer Engine NVFP4 support is installed. |
| `dequantize_nvfp4_128x4_to_bf16` | Available reference dequantizer. |
| `swizzle_nvfp4_scale_to_128x4` | Available scale-layout helper. |
| `nvfp4_global_scale_from_amax` | Available scale helper. |
| `sparse_atten_func` | Available. Official CSR sparse prefill API backed by the Blackwell Triton BF16/FP16 prefill kernel. |
| `sparse_atten_nvfp4_kv_func` | Available. Built artifacts use native CUDA swizzled NVFP4 -> BF16 dequantization, then call Blackwell sparse prefill. |
| `fp4_indexer_block_scores` | Available. Built artifacts use the native CUDA Blackwell block-score kernel and return the official `[Hq, ceil(max_seqlen_k/128), total_q]` score layout. |
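
`fp4_indexer_block_scores` returns one score per query head, 128-token KV block, and query token. The sketch below computes the expected output shape, which is useful as a sanity check on a returned tensor. `expected_block_score_shape` is hypothetical. It uses the 128-token block size stated in the table.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division without floating point."""
    return -(-a // b)


def expected_block_score_shape(num_q_heads: int, max_seqlen_k: int, total_q: int,
                               block_size: int = 128) -> tuple[int, int, int]:
    """Shape [Hq, ceil(max_seqlen_k / block_size), total_q] of the score tensor."""
    return (num_q_heads, ceil_div(max_seqlen_k, block_size), total_q)


# A partial trailing block still gets its own score column.
print(expected_block_score_shape(64, 4097, 512))  # (64, 33, 512)
```

With a real output, compare against `tuple(scores.shape)`.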
### FlashRT Blackwell helper names
These are the direct low-level APIs used by the FlashRT MiniMax-Spark decode
path:
- `flash_decode_with_topk_idx`
- `flash_decode_with_gqa_share_sparse`
- `native_topk_from_scores`
- `native_nvfp4_dequant_swizzled_to_bf16`
- `has_native_ops`
- `naive_flash_decode_with_topk_idx`
- `naive_flash_decode_with_gqa_share_sparse`
- `get_cu_seqblocks`
- `robust_allocator`
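
`native_topk_from_scores` turns block scores into selected block indices. Its exact tie-breaking, padding, and masking rules are not documented here, so the following is only a pure-Python reference sketch under these assumptions:

- Higher scores win.
- Ties go to the lower block index.
- Only the first `num_valid_blocks` blocks are eligible.
- Unused slots are padded with `-1`, the padding value the prefill example uses for `q2k`.

`topk_blocks_reference` is hypothetical.

```python
def topk_blocks_reference(scores: list[float], k: int, num_valid_blocks: int) -> list[int]:
    """Pick the k highest-scoring blocks among the first num_valid_blocks.

    Assumed conventions (not taken from the package): ties prefer the lower
    block index, results are ordered by descending score, missing slots are -1.
    """
    candidates = [(s, i) for i, s in enumerate(scores[:num_valid_blocks])]
    # Sort by descending score, then by ascending block index.
    candidates.sort(key=lambda pair: (-pair[0], pair[1]))
    chosen = [i for _, i in candidates[:k]]
    return chosen + [-1] * (k - len(chosen))


print(topk_blocks_reference([0.1, 0.9, 0.5, 0.9, -1.0], 3, 5))  # [1, 3, 2]
print(topk_blocks_reference([0.1, 0.9, 0.5, 0.9, -1.0], 3, 2))  # [1, 0, -1]
```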
## Decode Example
This example uses the official MiniMax decode-facing name
`sparse_decode_atten_func`.
```python
import torch
from kernels import get_kernel
msa = get_kernel("flashrt/MiniMaxAI-msa-blackwell", version=1, trust_remote_code=True)
B, Hq, Hkv, D = 1, 64, 4, 128
page_size = 128
num_pages = 32
topk = 16
device = "cuda"
dtype = torch.bfloat16
# Query for one decode step: [B, Hq, D].
q = torch.randn(B, Hq, D, device=device, dtype=dtype)
# Paged KV cache: [num_pages, Hkv, page_size, D].
k = torch.randn(num_pages, Hkv, page_size, D, device=device, dtype=dtype)
v = torch.randn_like(k)
# Physical pages owned by each request: [B, num_pages // B].
page_table = torch.arange(num_pages, device=device, dtype=torch.int32).view(B, -1)
# Number of valid KV tokens per request: [B].
seqused_k = torch.tensor([num_pages * page_size], device=device, dtype=torch.int32)
# Selected KV blocks per KV head and request: [Hkv, B, topk] (here blocks 0..topk-1).
q2k_indices = torch.arange(topk, device=device, dtype=torch.int32).view(1, 1, topk)
q2k_indices = q2k_indices.expand(Hkv, B, topk).contiguous()
out = msa.sparse_decode_atten_func(
    q,
    k,
    v,
    q2k_indices,
    page_table=page_table,
    seqused_k=seqused_k,
    seqlen_q=1,
    max_seqlen_k=num_pages * page_size,
    blk_kv=page_size,
)
```
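
In this example `blk_kv` equals `page_size`, so each entry of `q2k_indices[h, b]` names one logical KV block, which the page table maps to one physical page of `k`/`v`. The sketch below models that lookup in pure Python under three assumptions:

- Logical block `j` covers tokens `[j * blk_kv, (j + 1) * blk_kv)` of the request.
- Tokens past `seqused_k` are ignored.
- `-1` marks an unused slot.

`gather_selected_pages` is hypothetical and only illustrates the indexing, not the kernel.

```python
def gather_selected_pages(page_table_row: list[int], block_ids: list[int],
                          seqused: int, blk_kv: int) -> list[tuple[int, int]]:
    """Return (physical_page, valid_tokens) for each selected logical block.

    Assumes blk_kv == page_size, so logical block j lives entirely in
    physical page page_table_row[j]; -1 entries and empty blocks are skipped.
    """
    result = []
    for j in block_ids:
        if j < 0:
            continue  # padding slot
        start = j * blk_kv
        valid = max(0, min(blk_kv, seqused - start))  # the last block may be partial
        if valid:
            result.append((page_table_row[j], valid))
    return result


print(gather_selected_pages([7, 3, 9], [2, 0, -1], seqused=300, blk_kv=128))
# [(9, 44), (7, 128)]
```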
## Wrapper Example
```python
# Reuses msa, q, k, v and the metadata tensors from the decode example above.
wrapper = msa.SparseDecodePagedAttentionWrapper(blk_kv=page_size)
wrapper.plan(
    page_table=page_table,
    seqused_k=seqused_k,
    seqlen_q=1,
    max_seqlen_k=num_pages * page_size,
    q2k_indices=q2k_indices,
    num_qo_heads=Hq,
    num_kv_heads=Hkv,
    head_dim=D,
)
out = wrapper.run(q, k, v)
```
## Prefill Example
This example uses the official MiniMax CSR prefill-facing name
`sparse_atten_func`.
```python
import torch
from kernels import get_kernel
msa = get_kernel("flashrt/MiniMaxAI-msa-blackwell", version=1, trust_remote_code=True)
T, Hq, Hkv, D = 512, 64, 4, 128
page_size = 128
topk = 16
device = "cuda"
dtype = torch.bfloat16
q = torch.randn(T, Hq, D, device=device, dtype=dtype)
k = torch.randn(T, Hkv, D, device=device, dtype=dtype)
v = torch.randn_like(k)
# One sequence of T tokens: cumulative sequence offsets shared by q and k.
cu = torch.tensor([0, T], device=device, dtype=torch.int32)
# q2k[h, i, :] lists the KV blocks query i selects for KV head h; -1 marks an unused slot.
q2k = torch.full((Hkv, T, topk), -1, device=device, dtype=torch.int32)
for qi in range(T):
    # Causal selection: query qi may use blocks 0 .. qi // page_size.
    blocks = torch.arange(qi // page_size + 1, device=device, dtype=torch.int32)
    n = min(topk, blocks.numel())
    q2k[:, qi, :n] = blocks[:n]
k2q_row_ptr, k2q_q_indices = msa.build_k2q_csr(
    q2k, cu, cu, page_size, total_k=T
)
out = msa.sparse_atten_func(
    q,
    k,
    v,
    k2q_row_ptr,
    k2q_q_indices,
    topk,
    cu_seqlens_q=cu,
    cu_seqlens_k=cu,
    max_seqlen_q=T,
    max_seqlen_k=T,
    blk_kv=page_size,
)
```
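
`build_k2q_csr` inverts the query-to-block selection `q2k` (shape `[Hkv, total_q, topk]`) into a key-block-to-query CSR structure. With that structure, the prefill kernel can iterate over KV blocks and visit only the queries that selected each one. The package's exact row ordering and output dtypes are not specified here. The sketch below is a pure-Python reference for a single sequence under these assumed conventions:

- There is one CSR row per `(kv_head, kv_block)` pair, in head-major order.
- Query indices are ascending within a row.
- `-1` entries are ignored.

`build_k2q_csr_reference` is hypothetical.

```python
def build_k2q_csr_reference(q2k: list[list[list[int]]], num_kv_blocks: int):
    """Invert q2k[h][q] (selected KV blocks) into CSR (row_ptr, q_indices).

    Row r = h * num_kv_blocks + block lists the queries of KV head h that
    selected that block. Assumed layout, for illustration only.
    """
    num_heads = len(q2k)
    rows = [[] for _ in range(num_heads * num_kv_blocks)]
    for h, per_query in enumerate(q2k):
        for q, blocks in enumerate(per_query):
            for blk in set(blocks):  # ignore duplicate selections within one query
                if blk >= 0:
                    rows[h * num_kv_blocks + blk].append(q)
    row_ptr = [0]
    q_indices = []
    for row in rows:
        row.sort()
        q_indices.extend(row)
        row_ptr.append(len(q_indices))
    return row_ptr, q_indices


# One KV head, three queries, two KV blocks, topk = 2.
print(build_k2q_csr_reference([[[0, -1], [0, -1], [0, 1]]], num_kv_blocks=2))
# ([0, 3, 4], [0, 1, 2, 2])
```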
## Direct FlashRT Decode Path
Use this lower-level path if you already have `topk_idx` in FlashRT's paged KV
layout.
```python
# Reuses msa, B, Hq, Hkv, D, page_size, num_pages, topk, device and dtype
# from the decode example above.
q = torch.randn(B, Hq, D, device=device, dtype=dtype)
# Token-major KV cache: [num_pages * page_size, Hkv, D].
k_cache = torch.randn(num_pages * page_size, Hkv, D, device=device, dtype=dtype)
v_cache = torch.randn_like(k_cache)
# Token slots of each request, plus per-request sequence lengths.
req_to_token = torch.arange(num_pages * page_size, device=device, dtype=torch.int32).view(B, -1)
seq_lens = torch.tensor([num_pages * page_size], device=device, dtype=torch.int32)
slot_ids = torch.zeros(B, device=device, dtype=torch.int64)
# Selected block indices per KV head and request: [Hkv, B, topk].
topk_idx = torch.arange(topk, device=device, dtype=torch.int32).view(1, 1, topk)
topk_idx = topk_idx.expand(Hkv, B, topk).contiguous()
out = msa.flash_decode_with_gqa_share_sparse(
    q, None, k_cache, v_cache, req_to_token, seq_lens, slot_ids, page_size, topk_idx
)
```
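
`flash_decode_with_gqa_share_sparse` takes `topk_idx` per KV head, so with `Hq = 64` and `Hkv = 4` each group of 16 query heads attends over the same selected blocks. The sketch below is a pure-Python reference of that computation for one query head. It illustrates the math rather than the kernel, and rests on these assumptions:

- Query head `h` uses KV head `h // (Hq // Hkv)`.
- The softmax scale is `1/sqrt(D)`.
- `k_tokens`/`v_tokens` already hold that KV head's tokens in logical order.

`kv_head_for` and `sparse_decode_one_head` are hypothetical.

```python
import math


def kv_head_for(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """GQA grouping: consecutive query heads share one KV head (assumed)."""
    return q_head // (num_q_heads // num_kv_heads)


def sparse_decode_one_head(q, k_tokens, v_tokens, block_ids, block_size, seq_len):
    """Attention of one query vector over the tokens of the selected blocks.

    Blocks with id -1 are skipped; tokens at positions >= seq_len are masked.
    """
    scale = 1.0 / math.sqrt(len(q))  # assumed default softmax scale
    logits, values = [], []
    for blk in sorted(set(block_ids)):
        if blk < 0:
            continue
        for t in range(blk * block_size, min((blk + 1) * block_size, seq_len)):
            logits.append(scale * sum(a * b for a, b in zip(q, k_tokens[t])))
            values.append(v_tokens[t])
    if not logits:
        return [0.0] * len(q)
    m = max(logits)  # subtract the max for a numerically stable softmax
    weights = [math.exp(x - m) for x in logits]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


print(kv_head_for(17, 64, 4))  # 1
```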
## Scope
- Target family: NVIDIA Blackwell CUDA compute capability 12.x.
- Builder target: CUDA 12.8+ with `cuda-capabilities = ["12.0", "12.1"]`.
- Validated hardware: DGX Spark / GB10 / SM121.
- Validated MiniMax shape: query heads `64`, KV heads `4`, head dim `128`,
sparse block/page size `128`, top-k blocks `16`.
- End-to-end model validation: MiniMax-Spark runtime on GB10 at context lengths
up to `32768` tokens.
- Standalone kernel long-context validation at context lengths `128`, `2048`,
`4096`, `32768`, `65536`, and `131072` tokens.
- Correctness gate: cosine similarity `>= 0.999` against paged FP32 PyTorch
references; official decode wrapper output matches the direct Blackwell
decode kernel.
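
The correctness gate compares kernel output against an FP32 reference by cosine similarity. The sketch below applies that check to flattened outputs. It assumes a single cosine over the whole flattened tensor, whereas the package's tests may compute it per head or per token. `passes_cosine_gate` is hypothetical.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 1.0 if norm_a == norm_b else 0.0  # both zero counts as a match
    return dot / (norm_a * norm_b)


def passes_cosine_gate(kernel_out, reference_out, threshold: float = 0.999) -> bool:
    """True if the flattened outputs meet the cosine-similarity threshold."""
    return cosine_similarity(kernel_out, reference_out) >= threshold
```

With PyTorch tensors, `torch.nn.functional.cosine_similarity(out.float().flatten(), ref.flatten(), dim=0)` computes the same quantity.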
## Implementation Notes
This package contains:
- native CUDA score-to-top-k helper;
- native CUDA tensor-core sparse decode route for the MiniMax-M3 Blackwell shape;
- native CUDA FP4 block-score indexer;
- native CUDA swizzled NVFP4 -> BF16 dequantization for the W4A16 quality path;
- Blackwell-validated sparse prefill attention wrapper;
- MiniMaxAI/msa-compatible Python API layer for decode, prefill, CSR, NVFP4,
and FP4 block-score helpers.
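
For reference, NVFP4 stores each value as a 4-bit E2M1 code. Each code is scaled by a per-16-element FP8 (E4M3) block scale and a per-tensor FP32 scale. The sketch below dequantizes one 16-element block in pure Python, under three assumptions:

- Each byte packs two codes, with the first element in the low nibble.
- The block scale is plain, not swizzled.
- The tensor scale is a decode multiplier.

The package's kernel additionally handles the 128x4 swizzled scale layout, which is not modeled here. `dequant_nvfp4_block` is hypothetical.

```python
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(code: int) -> float:
    """Decode a 4-bit E2M1 code: bit 3 is the sign, bits 0-2 index the magnitude."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def dequant_nvfp4_block(packed: bytes, block_scale: float, tensor_scale: float) -> list[float]:
    """Dequantize 8 packed bytes (16 FP4 values) of one NVFP4 block.

    block_scale: the block's FP8 E4M3 scale, already converted to float.
    tensor_scale: per-tensor FP32 decode scale (assumed to multiply).
    """
    if len(packed) != 8:
        raise ValueError("one NVFP4 block is 16 values = 8 bytes")
    out = []
    for byte in packed:
        for code in (byte & 0x0F, byte >> 4):  # low nibble first (assumed)
            out.append(decode_e2m1(code) * block_scale * tensor_scale)
    return out
```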

When loaded from the prebuilt Hub artifacts, the decode, FP4 indexer, and NVFP4
dequantization hot paths use compiled CUDA ops. In source-tree mode, the package
falls back to reference implementations, so the API and correctness tests remain
runnable before a wheel or shared object has been built.
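
One way to picture that split is a dispatch that prefers the compiled op and otherwise falls back to the reference path. The sketch below assumes availability is known as a boolean. The package exposes `has_native_ops`, which appears intended for this kind of check, but its signature is not documented here. `choose_impl` and the stub functions are therefore hypothetical.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def choose_impl(native_available: bool, native_fn: Callable[..., T],
                reference_fn: Callable[..., T]) -> Callable[..., T]:
    """Prefer the compiled op; fall back to the reference path otherwise."""
    return native_fn if native_available else reference_fn


# Hypothetical stand-ins for a compiled kernel and its reference implementation.
def native_scale(xs):
    return [2.0 * x for x in xs]


def reference_scale(xs):
    return [2.0 * x for x in xs]


impl = choose_impl(False, native_scale, reference_scale)
print(impl.__name__)  # reference_scale
```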
Source provenance and validation details are documented in `SYNC.md` and
`VALIDATION.md`.