metadata
library_name: kernels
license: apache-2.0
tags:
  - cuda
  - triton
  - native-cuda
  - minimax
  - sparse-attention
  - blackwell

MiniMaxAI MSA Blackwell

Blackwell-family package for MiniMax MSA decode sparse attention, maintained by FlashRT. The upstream package is MiniMaxAI/msa, which targets SM100. This package extends the decode-sparse path to NVIDIA Blackwell compute capability 12.x and has been validated in the FlashRT MiniMax-Spark runtime on DGX Spark / GB10 / SM121.

Load

from kernels import get_kernel

msa = get_kernel(
    "flashrt/MiniMaxAI-msa-blackwell",
    version=1,
    trust_remote_code=True,
)
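
Because this package targets compute capability 12.x only, it can help to check the device before loading. The following is a minimal sketch, not part of the package: the helper names `is_blackwell_12x` and `describe_capability` are hypothetical, and the only real call referenced is `torch.cuda.get_device_capability()`, which returns a `(major, minor)` tuple.

```python
def is_blackwell_12x(capability):
    """Return True for CUDA compute capability 12.x, e.g. (12, 0) or (12, 1)."""
    major, _minor = capability
    return major == 12


def describe_capability(capability):
    """Format a (major, minor) tuple in the 'SM121' style used in this card."""
    major, minor = capability
    return f"SM{major}{minor}"


# Typical use on a CUDA machine (requires torch, not executed here):
#   import torch
#   cap = torch.cuda.get_device_capability()  # e.g. (12, 1) on GB10
#   if not is_blackwell_12x(cap):
#       raise RuntimeError(f"{describe_capability(cap)} is outside the 12.x target")
```

Keeping the check as a pure function of the capability tuple makes it usable in tests on machines without a GPU.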

What You Can Call

Official MiniMaxAI/msa names

| Function/class | Status in this Blackwell package |
| --- | --- |
| sparse_decode_atten_func | Available. Blackwell paged BF16/FP16 single-token decode wrapper. |
| SparseDecodePagedAttentionWrapper | Available. plan(...).run(...) wrapper for the same decode path. |
| build_k2q_csr | Available. CSR construction helper for the official prefill API. |
| SparseK2qCsrBuilderSm100 | Available compatibility class; build() delegates to build_k2q_csr. |
| Nvfp4QuantizedTensor | Available metadata dataclass. |
| quantize_bf16_to_nvfp4_128x4 | Available when Transformer Engine NVFP4 support is installed. |
| quantize_kv_bf16_to_nvfp4_128x4 | Available when Transformer Engine NVFP4 support is installed. |
| dequantize_nvfp4_128x4_to_bf16 | Available reference dequantizer. |
| swizzle_nvfp4_scale_to_128x4 | Available scale-layout helper. |
| nvfp4_global_scale_from_amax | Available scale helper. |
| sparse_atten_func | Available. Official CSR sparse prefill API backed by the Blackwell Triton BF16/FP16 prefill kernel. |
| sparse_atten_nvfp4_kv_func | Available. Built artifacts use native CUDA swizzled NVFP4 -> BF16 dequantization, then call Blackwell sparse prefill. |
| fp4_indexer_block_scores | Available. Built artifacts use the native CUDA Blackwell block-score kernel and return the official [Hq, ceil(max_seqlen_k/128), total_q] score layout. |
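
The [Hq, ceil(max_seqlen_k/128), total_q] score layout quoted for fp4_indexer_block_scores can be worked out by hand. The sketch below only does the shape arithmetic; `block_score_shape` is a hypothetical helper for illustration and does not call the kernel.

```python
def block_score_shape(num_q_heads, max_seqlen_k, total_q, block=128):
    """Shape of a [Hq, ceil(max_seqlen_k / block), total_q] score tensor."""
    num_blocks = (max_seqlen_k + block - 1) // block  # integer ceiling division
    return (num_q_heads, num_blocks, total_q)


# 64 query heads, a 4096-token KV sequence, one decode query:
print(block_score_shape(64, 4096, 1))   # (64, 32, 1)
# One extra token spills into a 33rd block:
print(block_score_shape(64, 4097, 1))   # (64, 33, 1)
```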

FlashRT Blackwell helper names

These are the direct low-level APIs used by the FlashRT MiniMax-Spark decode path:

  • flash_decode_with_topk_idx
  • flash_decode_with_gqa_share_sparse
  • native_topk_from_scores
  • native_nvfp4_dequant_swizzled_to_bf16
  • has_native_ops
  • naive_flash_decode_with_topk_idx
  • naive_flash_decode_with_gqa_share_sparse
  • get_cu_seqblocks
  • robust_allocator

Decode Example

This example uses the official MiniMax decode-facing name sparse_decode_atten_func.

import torch
from kernels import get_kernel

msa = get_kernel("flashrt/MiniMaxAI-msa-blackwell", version=1, trust_remote_code=True)

B, Hq, Hkv, D = 1, 64, 4, 128
page_size = 128
num_pages = 32
topk = 16
device = "cuda"
dtype = torch.bfloat16

q = torch.randn(B, Hq, D, device=device, dtype=dtype)
k = torch.randn(num_pages, Hkv, page_size, D, device=device, dtype=dtype)
v = torch.randn_like(k)
page_table = torch.arange(num_pages, device=device, dtype=torch.int32).view(B, -1)
seqused_k = torch.tensor([num_pages * page_size], device=device, dtype=torch.int32)
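# Select KV blocks 0..topk-1 and share the same selection across every KV head.
q2k_indices = torch.arange(topk, device=device, dtype=torch.int32).view(1, 1, topk)
q2k_indices = q2k_indices.expand(Hkv, B, topk).contiguous()  # shape: (Hkv, B, topk)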

out = msa.sparse_decode_atten_func(
    q,
    k,
    v,
    q2k_indices,
    page_table=page_table,
    seqused_k=seqused_k,
    seqlen_q=1,
    max_seqlen_k=num_pages * page_size,
    blk_kv=page_size,
)
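
The index bookkeeping behind this call can be sketched in plain Python. The sketch assumes two things that this card implies but does not spell out: each entry of q2k_indices names a KV block of blk_kv consecutive tokens, and query heads map to KV heads in contiguous groups (the usual GQA convention). All function names are hypothetical.

```python
def kv_head_for(q_head, num_q_heads, num_kv_heads):
    """Map a query head to its KV head, assuming contiguous GQA groups."""
    group = num_q_heads // num_kv_heads
    return q_head // group


def block_token_range(block_idx, blk_kv, seq_len):
    """Token positions covered by one KV block, clipped to the sequence length."""
    start = block_idx * blk_kv
    return range(start, min(start + blk_kv, seq_len))


def tokens_attended(block_indices, blk_kv, seq_len):
    """Count tokens covered by the selected blocks, skipping negative padding."""
    return sum(
        len(block_token_range(b, blk_kv, seq_len)) for b in block_indices if b >= 0
    )


# With Hq=64 and Hkv=4, query heads 16..31 share KV head 1.
print(kv_head_for(17, 64, 4))                # 1
# Blocks 0..15 of size 128 cover the first 2048 of 4096 cached tokens.
print(tokens_attended(range(16), 128, 4096))  # 2048
```

Under these assumptions, the example above attends to half of the cached sequence for every query head.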

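Wrapper Example

This snippet reuses msa, q, k, v, page_table, seqused_k, q2k_indices, and the shape constants defined in the Decode Example.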

wrapper = msa.SparseDecodePagedAttentionWrapper(blk_kv=128)
wrapper.plan(
    page_table=page_table,
    seqused_k=seqused_k,
    seqlen_q=1,
    max_seqlen_k=num_pages * page_size,
    q2k_indices=q2k_indices,
    num_qo_heads=Hq,
    num_kv_heads=Hkv,
    head_dim=D,
)
out = wrapper.run(q, k, v)

Prefill Example

This example uses the official MiniMax CSR prefill-facing name sparse_atten_func.

import torch
from kernels import get_kernel

msa = get_kernel("flashrt/MiniMaxAI-msa-blackwell", version=1, trust_remote_code=True)

T, Hq, Hkv, D = 512, 64, 4, 128
page_size = 128
topk = 16
device = "cuda"
dtype = torch.bfloat16

q = torch.randn(T, Hq, D, device=device, dtype=dtype)
k = torch.randn(T, Hkv, D, device=device, dtype=dtype)
v = torch.randn_like(k)
cu = torch.tensor([0, T], device=device, dtype=torch.int32)

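# Causal block selection: query qi may attend to KV blocks 0..qi // page_size.
# Unused top-k slots keep the padding value -1.
q2k = torch.full((Hkv, T, topk), -1, device=device, dtype=torch.int32)
for qi in range(T):
    blocks = torch.arange(qi // page_size + 1, device=device, dtype=torch.int32)
    q2k[:, qi, : min(topk, blocks.numel())] = blocks[:topk]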

k2q_row_ptr, k2q_q_indices = msa.build_k2q_csr(
    q2k, cu, cu, page_size, total_k=T
)
out = msa.sparse_atten_func(
    q,
    k,
    v,
    k2q_row_ptr,
    k2q_q_indices,
    topk,
    cu_seqlens_q=cu,
    cu_seqlens_k=cu,
    max_seqlen_q=T,
    max_seqlen_k=T,
    blk_kv=page_size,
)
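
The prefill path first builds a query-to-block selection (q2k) and then inverts it into a block-to-query CSR structure. The pure-Python sketch below shows that inversion for a single head and a single sequence. It is a conceptual illustration only: `causal_q2k` and `invert_q2k_to_csr` are hypothetical names, and the real build_k2q_csr may differ in per-head layout, ordering, and its handling of cu_seqlens.

```python
def causal_q2k(num_queries, page_size, topk, pad=-1):
    """Per-query block lists, mirroring the causal loop in the example above."""
    rows = []
    for qi in range(num_queries):
        blocks = list(range(qi // page_size + 1))[:topk]
        rows.append(blocks + [pad] * (topk - len(blocks)))
    return rows


def invert_q2k_to_csr(q2k, num_kv_blocks):
    """Invert query->blocks lists into CSR arrays (row_ptr, q_indices)."""
    counts = [0] * num_kv_blocks
    for blocks in q2k:
        for b in blocks:
            if b >= 0:  # skip padding
                counts[b] += 1
    row_ptr = [0]
    for c in counts:  # prefix sum gives each block's slice boundaries
        row_ptr.append(row_ptr[-1] + c)
    q_indices = [0] * row_ptr[-1]
    cursor = row_ptr[:-1].copy()  # next free slot per block
    for qi, blocks in enumerate(q2k):
        for b in blocks:
            if b >= 0:
                q_indices[cursor[b]] = qi
                cursor[b] += 1
    return row_ptr, q_indices


q2k = causal_q2k(num_queries=4, page_size=2, topk=2)
print(q2k)                        # [[0, -1], [0, -1], [0, 1], [0, 1]]
print(invert_q2k_to_csr(q2k, 2))  # ([0, 4, 6], [0, 1, 2, 3, 2, 3])
```

Queries attending to block b are then `q_indices[row_ptr[b]:row_ptr[b + 1]]`, which lets a kernel iterate over KV blocks and visit only the queries that selected them.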

Direct FlashRT Decode Path

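Use this lower-level path if you already have topk_idx in FlashRT's paged KV layout. The snippet reuses msa, B, Hq, Hkv, D, page_size, num_pages, topk, device, and dtype from the Decode Example.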

q = torch.randn(B, Hq, D, device=device, dtype=dtype)
k_cache = torch.randn(num_pages * page_size, Hkv, D, device=device, dtype=dtype)
v_cache = torch.randn_like(k_cache)
req_to_token = torch.arange(num_pages * page_size, device=device, dtype=torch.int32).view(B, -1)
seq_lens = torch.tensor([num_pages * page_size], device=device, dtype=torch.int32)
slot_ids = torch.zeros(B, device=device, dtype=torch.int64)
topk_idx = torch.arange(topk, device=device, dtype=torch.int32).view(1, 1, topk)
topk_idx = topk_idx.expand(Hkv, B, topk).contiguous()

out = msa.flash_decode_with_gqa_share_sparse(
    q, None, k_cache, v_cache, req_to_token, seq_lens, slot_ids, page_size, topk_idx
)
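
The decode examples use two indexing schemes: a page-level page_table of shape (B, num_pages) and a token-level req_to_token of shape (B, num_pages * page_size). One plausible relationship between them is sketched below, assuming physical token slot = page id * page_size + in-page offset. That mapping is an assumption made for illustration, and `page_table_to_token_slots` is a hypothetical helper, not a package function.

```python
def page_table_to_token_slots(page_table_row, page_size):
    """Expand one request's page ids into per-token physical slot ids."""
    slots = []
    for page_id in page_table_row:
        base = page_id * page_size  # assumed start of the page in a flat cache
        slots.extend(range(base, base + page_size))
    return slots


# Pages stored out of order: logical tokens 0..1 live in page 2, tokens 2..3 in page 0.
print(page_table_to_token_slots([2, 0], page_size=2))  # [4, 5, 0, 1]

# An identity page table reproduces the arange used for req_to_token above.
identity = page_table_to_token_slots(range(32), page_size=128)
print(identity == list(range(32 * 128)))  # True
```

Note that the two examples also order the cache dimensions differently, (num_pages, Hkv, page_size, D) versus (num_pages * page_size, Hkv, D), so converting actual cache tensors would need a permute as well as this index mapping.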

Scope

  • Target family: NVIDIA Blackwell CUDA compute capability 12.x.
  • Builder target: CUDA 12.8+ with cuda-capabilities = ["12.0", "12.1"].
  • Validated hardware: DGX Spark / GB10 / SM121.
  • Validated MiniMax shape: query heads 64, KV heads 4, head dim 128, sparse block/page size 128, top-k blocks 16.
  • End-to-end model validation: MiniMax-Spark runtime on GB10 at context lengths up to 32768.
  • Standalone kernel long-context validation: context lengths 128, 2048, 4096, 32768, 65536, and 131072.
  • Correctness gate: cosine similarity >= 0.999 against paged FP32 PyTorch references; official decode wrapper output matches the direct Blackwell decode kernel.
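
The cosine-similarity gate in the last bullet can be expressed in a few lines. This is a minimal sketch over flat Python lists; the function names are hypothetical, and the actual validation procedure is the one documented in VALIDATION.md, not this code.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length flat sequences of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        # Treat two all-zero vectors as matching and a lone zero vector as a miss.
        return 1.0 if norm_a == norm_b else 0.0
    return dot / (norm_a * norm_b)


def passes_gate(kernel_out, reference_out, threshold=0.999):
    """True when the kernel output clears the cosine-similarity threshold."""
    return cosine_similarity(kernel_out, reference_out) >= threshold


print(passes_gate([1.0, 2.0, 3.0], [1.0, 2.0, 3.001]))  # True
print(passes_gate([1.0, 0.0], [0.0, 1.0]))              # False
```

With tensors, the same check would typically flatten both outputs, cast them to float32, and use `torch.nn.functional.cosine_similarity` along dimension 0.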

Implementation Notes

This package contains:

  • native CUDA score-to-top-k helper;
  • native CUDA tensor-core sparse decode route for the MiniMax-M3 Blackwell shape;
  • native CUDA FP4 block-score indexer;
  • native CUDA swizzled NVFP4 -> BF16 dequantization for the W4A16 quality path;
  • Blackwell-validated sparse prefill attention wrapper;
  • MiniMaxAI/msa-compatible Python API layer for decode, prefill, CSR, NVFP4, and FP4 block-score helpers.

When loaded from Hub built artifacts, the decode, FP4 indexer, and NVFP4 dequant hot paths use compiled CUDA ops. The source-tree mode keeps reference paths so the API and correctness tests remain runnable before a wheel/shared object has been built.
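
The split between compiled hot paths and source-tree reference paths follows a common fallback-dispatch pattern, sketched here in generic form. This is not the package's dispatch code: `make_dispatcher` and its arguments are hypothetical, and how the package actually detects compiled ops (for example through has_native_ops) is not specified in this card.

```python
def make_dispatcher(native_fn, reference_fn, native_available):
    """Return a callable that prefers the native op and falls back to the reference."""
    def dispatch(*args, **kwargs):
        if native_available and native_fn is not None:
            return native_fn(*args, **kwargs)
        return reference_fn(*args, **kwargs)
    return dispatch


# Stand-in implementations that only report which path ran.
def native(x):
    return ("native", x)


def reference(x):
    return ("reference", x)


built = make_dispatcher(native, reference, native_available=True)
source_tree = make_dispatcher(None, reference, native_available=False)
print(built(1))        # ('native', 1)
print(source_tree(1))  # ('reference', 1)
```

Because both paths share one call signature, correctness tests can run against the reference path before any compiled artifact exists.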

Source provenance and validation details are documented in SYNC.md and VALIDATION.md.