---
library_name: kernels
license: apache-2.0
tags:
  - cuda
  - triton
  - native-cuda
  - minimax
  - sparse-attention
  - blackwell
---

# MiniMaxAI MSA Blackwell

Blackwell-family package for MiniMax MSA decode sparse attention, maintained by
FlashRT. The upstream package is
[`MiniMaxAI/msa`](https://huggingface.co/kernels/MiniMaxAI/msa), which targets
SM100 (compute capability 10.0). This package extends the decode-sparse path to
NVIDIA Blackwell devices with compute capability 12.x and has been validated in
the FlashRT MiniMax-Spark runtime on DGX Spark / GB10 / SM121.

## Load

```python
from kernels import get_kernel

msa = get_kernel(
    "flashrt/MiniMaxAI-msa-blackwell",
    version=1,
    trust_remote_code=True,
)
```
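Before loading, you may want to confirm that the device is in the target family. The sketch below is a minimal pre-load check that assumes you gate on the `(major, minor)` compute capability. `BUILT_CAPABILITIES` mirrors the `cuda-capabilities` list under Scope. The helper names are hypothetical and are not part of the package.

```python
# Hypothetical pre-load check; capability tuples are (major, minor).
# BUILT_CAPABILITIES mirrors `cuda-capabilities = ["12.0", "12.1"]`.
BUILT_CAPABILITIES = {(12, 0), (12, 1)}


def is_target_family(capability):
    """True for any compute capability 12.x device (the target family)."""
    major, _minor = capability
    return major == 12


def has_prebuilt_target(capability):
    """True only for the capabilities listed as builder targets."""
    return tuple(capability) in BUILT_CAPABILITIES


for cap in [(10, 0), (12, 0), (12, 1)]:
    print(cap, is_target_family(cap), has_prebuilt_target(cap))
```

With PyTorch installed, `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, which is `(12, 1)` on GB10 / SM121. You can therefore call `has_prebuilt_target(torch.cuda.get_device_capability())` before `get_kernel`.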

## What You Can Call

### Official MiniMaxAI/msa names

| Function/class | Status in this Blackwell package |
|---|---|
| `sparse_decode_atten_func` | Available. Blackwell paged BF16/FP16 single-token decode wrapper. |
| `SparseDecodePagedAttentionWrapper` | Available. `plan(...).run(...)` wrapper for the same decode path. |
| `build_k2q_csr` | Available. CSR construction helper for the official prefill API. |
| `SparseK2qCsrBuilderSm100` | Available compatibility class; `build()` delegates to `build_k2q_csr`. |
| `Nvfp4QuantizedTensor` | Available metadata dataclass. |
| `quantize_bf16_to_nvfp4_128x4` | Available when Transformer Engine NVFP4 support is installed. |
| `quantize_kv_bf16_to_nvfp4_128x4` | Available when Transformer Engine NVFP4 support is installed. |
| `dequantize_nvfp4_128x4_to_bf16` | Available reference dequantizer. |
| `swizzle_nvfp4_scale_to_128x4` | Available scale-layout helper. |
| `nvfp4_global_scale_from_amax` | Available scale helper. |
| `sparse_atten_func` | Available. Official CSR sparse prefill API backed by the Blackwell Triton BF16/FP16 prefill kernel. |
| `sparse_atten_nvfp4_kv_func` | Available. Built artifacts use native CUDA swizzled NVFP4 -> BF16 dequantization, then call Blackwell sparse prefill. |
| `fp4_indexer_block_scores` | Available. Built artifacts use the native CUDA Blackwell block-score kernel and return the official `[Hq, ceil(max_seqlen_k/128), total_q]` score layout. |
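NVFP4 stores 4-bit E2M1 values with one FP8 (E4M3) scale per 16-element block plus a per-tensor FP32 global scale. The pure-Python sketch below decodes that layout under several stated assumptions:

- Block scales are already converted to floats, so there is no 128x4 swizzle.
- Two codes are packed per byte, low nibble first.
- The global scale follows a common encode-side convention of `448 * 6 / amax`, so dequantization divides by it.

None of these names are the package's API, and `dequantize_nvfp4_128x4_to_bf16` or `nvfp4_global_scale_from_amax` may use different conventions.

```python
# Magnitudes of the 8 non-negative FP4 E2M1 codes; bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP8_E4M3_MAX = 448.0
FP4_E2M1_MAX = 6.0
BLOCK = 16  # NVFP4 elements sharing one block scale


def decode_e2m1(code):
    """Decode one 4-bit E2M1 code to a Python float."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def unpack_fp4(packed):
    """Split bytes into 4-bit codes, low nibble first (assumed order)."""
    codes = []
    for byte in packed:
        codes.append(byte & 0x0F)
        codes.append(byte >> 4)
    return codes


def global_scale_from_amax(amax):
    """Assumed encode-side global scale mapping amax to FP8_MAX * FP4_MAX."""
    return FP8_E4M3_MAX * FP4_E2M1_MAX / amax


def dequantize_nvfp4(packed, block_scales, global_scale):
    """Reference dequant: value = e2m1 * block_scale / global_scale."""
    codes = unpack_fp4(packed)
    if len(codes) != BLOCK * len(block_scales):
        raise ValueError("need exactly one block scale per 16 codes")
    return [
        decode_e2m1(code) * block_scales[i // BLOCK] / global_scale
        for i, code in enumerate(codes)
    ]
```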

### FlashRT Blackwell helper names

These are the direct low-level APIs used by the FlashRT MiniMax-Spark decode
path:

- `flash_decode_with_topk_idx`
- `flash_decode_with_gqa_share_sparse`
- `native_topk_from_scores`
- `native_nvfp4_dequant_swizzled_to_bf16`
- `has_native_ops`
- `naive_flash_decode_with_topk_idx`
- `naive_flash_decode_with_gqa_share_sparse`
- `get_cu_seqblocks`
- `robust_allocator`
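In a sparse decode pipeline, block scores (such as those returned by `fp4_indexer_block_scores`) are reduced to `topk` block ids, which is presumably the role of `native_topk_from_scores`. The sketch below shows that selection step for a single (head, query) pair, that is, the slice `scores[h, :, t]` of the official `[Hq, ceil(max_seqlen_k/128), total_q]` layout.

The following behaviors are assumptions, not documented behavior of the native kernel:

- Blocks past `seqlen_k` are masked out.
- Missing slots are padded with `-1`, as the prefill example does.
- Block ids are returned in ascending order.

```python
import heapq
import math

BLK = 128  # block size of the official score layout


def num_score_blocks(max_seqlen_k, blk=BLK):
    """Number of score rows per head: ceil(max_seqlen_k / blk)."""
    return math.ceil(max_seqlen_k / blk)


def topk_blocks(scores, seqlen_k, topk, blk=BLK):
    """Pick the top-k highest-scoring valid blocks for one (head, query).

    Blocks starting at or beyond seqlen_k are ignored; unused slots are -1.
    """
    valid = min(len(scores), math.ceil(seqlen_k / blk))
    best = heapq.nlargest(topk, range(valid), key=lambda b: scores[b])
    best.sort()  # hypothetical choice: ascending block ids
    return best + [-1] * (topk - len(best))
```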

## Decode Example

This example uses the official MiniMax decode-facing name
`sparse_decode_atten_func`.

```python
import torch
from kernels import get_kernel

msa = get_kernel("flashrt/MiniMaxAI-msa-blackwell", version=1, trust_remote_code=True)

B, Hq, Hkv, D = 1, 64, 4, 128
page_size = 128
num_pages = 32
topk = 16
device = "cuda"
dtype = torch.bfloat16

# Decode query: [B, Hq, D]
q = torch.randn(B, Hq, D, device=device, dtype=dtype)
# Paged K/V cache: [num_pages, Hkv, page_size, D]
k = torch.randn(num_pages, Hkv, page_size, D, device=device, dtype=dtype)
v = torch.randn_like(k)
# Page ids per batch entry: [B, num_pages]
page_table = torch.arange(num_pages, device=device, dtype=torch.int32).view(B, -1)
# Valid KV tokens per batch entry: [B]
seqused_k = torch.tensor([num_pages * page_size], device=device, dtype=torch.int32)
# Selected KV block ids per KV head and batch entry: [Hkv, B, topk]
q2k_indices = torch.arange(topk, device=device, dtype=torch.int32).view(1, 1, topk)
q2k_indices = q2k_indices.expand(Hkv, B, topk).contiguous()

out = msa.sparse_decode_atten_func(
    q,
    k,
    v,
    q2k_indices,
    page_table=page_table,
    seqused_k=seqused_k,
    seqlen_q=1,
    max_seqlen_k=num_pages * page_size,
    blk_kv=page_size,
)
```
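As a reference for what the decode examples compute, here is a pure-Python sketch of single-token block-sparse attention with GQA sharing. It makes the following assumptions:

- Each entry of `q2k_indices` names a KV block of `blk_kv` tokens.
- Query head `h` reads KV head `h // (Hq // Hkv)`.
- The softmax scale is `1 / sqrt(D)`.
- K and V are contiguous and un-paged, for clarity.

This is an illustration only, not the package's `naive_flash_decode_with_topk_idx`.

```python
import math


def sparse_decode_reference(q, k, v, q2k, seqlen_k, blk_kv):
    """Single-token block-sparse decode for one batch entry (reference sketch).

    q:    [Hq][D] query vectors
    k, v: [Hkv][seqlen_k][D] contiguous keys/values (no paging)
    q2k:  [Hkv][topk] selected block ids per KV head; -1 marks an unused slot
    Returns [Hq][D]. Assumes every KV head selects at least one valid block.
    """
    num_q_heads, num_kv_heads = len(q), len(k)
    group = num_q_heads // num_kv_heads  # query heads per KV head
    head_dim = len(q[0])
    scale = 1.0 / math.sqrt(head_dim)  # assumed softmax scale
    out = []
    for h in range(num_q_heads):
        kv_h = h // group  # assumed contiguous GQA grouping
        # Expand selected block ids into token positions, clipped to seqlen_k.
        tokens = [
            t
            for b in q2k[kv_h]
            if b >= 0
            for t in range(b * blk_kv, min((b + 1) * blk_kv, seqlen_k))
        ]
        logits = [scale * sum(a * c for a, c in zip(q[h], k[kv_h][t])) for t in tokens]
        # Numerically stable softmax over the selected tokens only.
        m = max(logits)
        weights = [math.exp(x - m) for x in logits]
        z = sum(weights)
        out.append(
            [sum(w * v[kv_h][t][d] for w, t in zip(weights, tokens)) / z for d in range(head_dim)]
        )
    return out
```

Under these assumptions, the validated shape (`Hq = 64`, `Hkv = 4`, `topk = 16`, `blk_kv = 128`) gives each KV head's block list to 16 query heads. Each KV head reads at most `16 * 128 = 2048` tokens, regardless of context length.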

## Wrapper Example

This reuses `q`, `k`, `v`, and the metadata tensors from the decode example above.

```python
wrapper = msa.SparseDecodePagedAttentionWrapper(blk_kv=128)
wrapper.plan(
    page_table=page_table,
    seqused_k=seqused_k,
    seqlen_q=1,
    max_seqlen_k=num_pages * page_size,
    q2k_indices=q2k_indices,
    num_qo_heads=Hq,
    num_kv_heads=Hkv,
    head_dim=D,
)
out = wrapper.run(q, k, v)
```

## Prefill Example

This example uses the official MiniMax CSR prefill-facing name
`sparse_atten_func`.

```python
import torch
from kernels import get_kernel

msa = get_kernel("flashrt/MiniMaxAI-msa-blackwell", version=1, trust_remote_code=True)

T, Hq, Hkv, D = 512, 64, 4, 128
page_size = 128
topk = 16
device = "cuda"
dtype = torch.bfloat16

q = torch.randn(T, Hq, D, device=device, dtype=dtype)
k = torch.randn(T, Hkv, D, device=device, dtype=dtype)
v = torch.randn_like(k)
cu = torch.tensor([0, T], device=device, dtype=torch.int32)

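# Causal block selection: query qi may use blocks 0..qi // page_size
# (at most topk of them); unused slots stay -1.
q2k = torch.full((Hkv, T, topk), -1, device=device, dtype=torch.int32)
for qi in range(T):
    blocks = torch.arange(qi // page_size + 1, device=device, dtype=torch.int32)
    q2k[:, qi, : min(topk, blocks.numel())] = blocks[:topk]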

k2q_row_ptr, k2q_q_indices = msa.build_k2q_csr(
    q2k, cu, cu, page_size, total_k=T
)
out = msa.sparse_atten_func(
    q,
    k,
    v,
    k2q_row_ptr,
    k2q_q_indices,
    topk,
    cu_seqlens_q=cu,
    cu_seqlens_k=cu,
    max_seqlen_q=T,
    max_seqlen_k=T,
    blk_kv=page_size,
)
```
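`build_k2q_csr` turns the per-query block selections (`q2k`) into the reverse map: for each KV block, the queries that selected it, stored in CSR form (`k2q_row_ptr`, `k2q_q_indices`). One plausible motivation is that a prefill kernel can then load each K/V block once and process all of its queries together.

The sketch below performs the inversion for a single sequence. Its row ordering (`head * num_kv_blocks + block`) is a hypothetical layout. The real helper presumably also handles multiple sequences through its `cu_seqlens` arguments.

```python
def build_k2q_csr_reference(q2k, num_kv_blocks):
    """Invert query->block selections into a block->query CSR per KV head.

    q2k: [Hkv][T][topk] selected block ids, -1 = unused.
    Returns (row_ptr, q_indices). For head h and block b, row = h * num_kv_blocks + b
    and the queries that selected it are q_indices[row_ptr[row]:row_ptr[row + 1]],
    in ascending query order.
    """
    buckets = [[] for _ in range(len(q2k) * num_kv_blocks)]
    for h, per_query in enumerate(q2k):
        for qi, blocks in enumerate(per_query):
            for b in blocks:
                if b >= 0:  # skip padding slots
                    buckets[h * num_kv_blocks + b].append(qi)
    # Flatten buckets; row_ptr holds the running offsets.
    row_ptr = [0]
    q_indices = []
    for bucket in buckets:
        q_indices.extend(bucket)
        row_ptr.append(len(q_indices))
    return row_ptr, q_indices
```

For a causal pattern with `T = 4`, a block size of 2, and `topk = 2`, block 0 is selected by every query, while block 1 is selected only by queries 2 and 3.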

## Direct FlashRT Decode Path

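Use this lower-level path if you already have `topk_idx` and your KV cache is in
FlashRT's paged layout: token-major `k_cache`/`v_cache` of shape
`[num_pages * page_size, Hkv, D]`, addressed through `req_to_token`. The snippet
reuses `B`, `Hq`, `Hkv`, `D`, `page_size`, `num_pages`, `topk`, `device`, and
`dtype` from the decode example.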

```python
q = torch.randn(B, Hq, D, device=device, dtype=dtype)
k_cache = torch.randn(num_pages * page_size, Hkv, D, device=device, dtype=dtype)
v_cache = torch.randn_like(k_cache)
req_to_token = torch.arange(num_pages * page_size, device=device, dtype=torch.int32).view(B, -1)
seq_lens = torch.tensor([num_pages * page_size], device=device, dtype=torch.int32)
slot_ids = torch.zeros(B, device=device, dtype=torch.int64)
topk_idx = torch.arange(topk, device=device, dtype=torch.int32).view(1, 1, topk)
topk_idx = topk_idx.expand(Hkv, B, topk).contiguous()

out = msa.flash_decode_with_gqa_share_sparse(
    q, None, k_cache, v_cache, req_to_token, seq_lens, slot_ids, page_size, topk_idx
)
```
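The direct path addresses KV through `req_to_token` and `slot_ids` instead of a page table. Assume SGLang-style semantics, where `req_to_token[slot_ids[b], t]` is the `k_cache`/`v_cache` row of logical token `t` for request `b`. Under that assumption, the physical rows touched by a set of selected blocks can be listed as below. Both the function and these semantics are hypothetical, for illustration only.

```python
def selected_cache_rows(req_to_token, slot_id, seq_len, block_ids, page_size):
    """Physical k_cache/v_cache rows touched by one request's selected blocks.

    Hypothetical semantics: req_to_token[slot_id][t] is the cache row of logical
    token t, and block b covers tokens [b * page_size, (b + 1) * page_size).
    """
    rows = []
    for b in block_ids:
        if b < 0:
            continue  # padding entry
        start = b * page_size
        end = min(start + page_size, seq_len)  # clip the last partial block
        rows.extend(req_to_token[slot_id][t] for t in range(start, end))
    return rows
```

With a non-identity mapping, the rows for consecutive logical tokens need not be contiguous in the cache. That is why the kernel takes `req_to_token` rather than assuming `k_cache` is ordered by position.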

## Scope

- Target family: NVIDIA Blackwell CUDA compute capability 12.x.
- Builder target: CUDA 12.8+ with `cuda-capabilities = ["12.0", "12.1"]`.
- Validated hardware: DGX Spark / GB10 / SM121.
- Validated MiniMax shape: query heads `64`, KV heads `4`, head dim `128`,
  sparse block/page size `128`, top-k blocks `16`.
- End-to-end model validation: MiniMax-Spark runtime on GB10 through `32768`
  context length.
- Standalone kernel long-context validation at context lengths (tokens) `128`,
  `2048`, `4096`, `32768`, `65536`, `131072`.
- Correctness gate: cosine similarity `>= 0.999` against paged FP32 PyTorch
  references; official decode wrapper output matches the direct Blackwell
  decode kernel.
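The correctness gate can be expressed as a small check. This sketch flattens both outputs and compares them with cosine similarity. The README does not say whether the real tests compare per tensor, per head, or per token, so treat the flattening as an assumption.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length flat sequences."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same number of elements")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def passes_gate(candidate, reference, threshold=0.999):
    """Apply the `>= 0.999` cosine gate to flattened outputs."""
    return cosine_similarity(candidate, reference) >= threshold
```

With PyTorch tensors, `torch.nn.functional.cosine_similarity(out.float().flatten(), ref.flatten(), dim=0)` computes the same quantity against an FP32 reference.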

## Implementation Notes

This package contains:

- native CUDA score-to-top-k helper;
- native CUDA tensor-core sparse decode route for the MiniMax-M3 Blackwell shape;
- native CUDA FP4 block-score indexer;
- native CUDA swizzled NVFP4 -> BF16 dequantization for the W4A16 quality path;
- Blackwell-validated sparse prefill attention wrapper;
- MiniMaxAI/msa-compatible Python API layer for decode, prefill, CSR, NVFP4,
  and FP4 block-score helpers.
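The "swizzled" scales and the `128x4` suffix in names like `swizzle_nvfp4_scale_to_128x4` refer to a tiled scale layout. Assume it matches the layout cuBLAS documents for block-scaled inputs:

- Scales are grouped into tiles of 128 rows by 4 scale columns.
- Tiles are stored row-major, with 512 entries each.
- Inside a tile, the entry at `(r, c)` sits at `(r % 32) * 16 + (r // 32) * 4 + c`.

The helpers below are hypothetical sketches of that indexing, not the package's implementation.

```python
def swizzled_scale_offset(row, col, num_col_tiles):
    """Offset of scale[row][col] in the assumed 128x4-tiled layout."""
    tile_row, r = divmod(row, 128)
    tile_col, c = divmod(col, 4)
    tile = tile_row * num_col_tiles + tile_col  # tiles stored row-major
    return tile * 512 + (r % 32) * 16 + (r // 32) * 4 + c


def swizzle_scales(scales):
    """Rearrange a [rows][cols] scale matrix, zero-padding to whole 128x4 tiles."""
    rows, cols = len(scales), len(scales[0])
    num_row_tiles = -(-rows // 128)  # ceiling division
    num_col_tiles = -(-cols // 4)
    out = [0] * (num_row_tiles * num_col_tiles * 512)
    for i in range(rows):
        for j in range(cols):
            out[swizzled_scale_offset(i, j, num_col_tiles)] = scales[i][j]
    return out
```

Under this layout, the 32-row interleave means rows `0`, `32`, `64`, and `96` of a tile are stored next to each other. A consumer that unswizzles or dequantizes must therefore use the same offset formula rather than plain row-major indexing.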

When loaded from built Hub artifacts, the decode, FP4 indexer, and NVFP4
dequantization hot paths use compiled CUDA ops. In source-tree mode the package
keeps reference implementations, so the API and correctness tests remain
runnable before a wheel or shared object has been built.

Source provenance and validation details are documented in `SYNC.md` and
`VALIDATION.md`.