| import argparse |
| import time |
|
|
| import torch |
|
|
| from minimaxai_msa_blackwell import ( |
| flash_decode_with_gqa_share_sparse, |
| has_native_ops, |
| native_topk_from_scores, |
| ) |
|
|
|
|
| def make_case(ctx: int, seed: int = 0): |
| g = torch.Generator(device="cuda").manual_seed(seed) |
| batch, hq, hkv, d = 1, 64, 4, 128 |
| block, topk = 128, 16 |
| q = torch.randn(batch, hq, d, device="cuda", dtype=torch.bfloat16, generator=g) |
| k_cache = torch.randn(ctx, hkv, d, device="cuda", dtype=torch.bfloat16, generator=g) |
| v_cache = torch.randn(ctx, hkv, d, device="cuda", dtype=torch.bfloat16, generator=g) |
| req_to_token = torch.arange(ctx, device="cuda", dtype=torch.int32).view(1, -1) |
| seq_lens = torch.tensor([ctx], device="cuda", dtype=torch.int32) |
| slot_ids = torch.zeros(batch, device="cuda", dtype=torch.int64) |
| nb = (ctx + block - 1) // block |
| n = min(topk, nb) |
| topk_idx = torch.full((hkv, batch, topk), -1, device="cuda", dtype=torch.int32) |
| topk_idx[:, :, :n] = torch.arange(n, device="cuda", dtype=torch.int32).view(1, 1, n) |
| return q, None, k_cache, v_cache, req_to_token, seq_lens, slot_ids, block, topk_idx |
|
|
|
|
def bench(ctx: int, warmup: int, iters: int) -> float:
    """Return the mean wall-clock time, in microseconds, of one attention call.

    Inputs come from ``make_case(ctx)`` and are reused for every call to
    ``flash_decode_with_gqa_share_sparse``.

    Args:
        ctx: Context length passed to ``make_case``.
        warmup: Number of untimed calls issued before measuring.
        iters: Number of timed calls; the elapsed time is divided by this.

    Returns:
        Elapsed ``time.perf_counter`` seconds over the timed loop, scaled to
        microseconds and averaged over ``iters``.
    """
    inputs = make_case(ctx)

    # Untimed calls first; the synchronize keeps any work they queued out of
    # the measured interval.
    for _warm in range(warmup):
        flash_decode_with_gqa_share_sparse(*inputs)
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _step in range(iters):
        flash_decode_with_gqa_share_sparse(*inputs)
    # Wait for the GPU to finish before reading the clock again.
    torch.cuda.synchronize()
    elapsed_s = time.perf_counter() - t0

    return elapsed_s * 1e6 / iters
|
|
|
|
def bench_native_topk(ctx: int, warmup: int, iters: int, seed: int = 0) -> float | None:
    """Return the mean wall-clock time, in microseconds, of one native top-k call.

    The scores are a random float32 tensor of shape
    ``(64, 1, ceil(ctx / 128))`` drawn from a seeded generator, so repeated
    runs time ``native_topk_from_scores`` on identical data.

    Args:
        ctx: Context length; determines the number of 128-token blocks and is
            the single entry of ``seq_lens``.
        warmup: Number of untimed calls issued before measuring.
        iters: Number of timed calls; the elapsed time is divided by this.
        seed: Seed for the score generator.

    Returns:
        The mean time per call in microseconds, or ``None`` when
        ``has_native_ops()`` reports that the native ops are unavailable.
    """
    if not has_native_ops():
        return None

    heads, batch, block, topk = 64, 1, 128, 16
    blocks = (ctx + block - 1) // block  # ceiling division

    # Seeded generator: keeps the benchmark input reproducible and leaves the
    # global RNG state untouched.
    gen = torch.Generator(device="cuda").manual_seed(seed)
    score = torch.randn(
        heads, batch, blocks, device="cuda", dtype=torch.float32, generator=gen
    )
    seq_lens = torch.tensor([ctx], device="cuda", dtype=torch.int32)

    # Untimed calls first; the synchronize keeps any work they queued out of
    # the measured interval.
    for _ in range(warmup):
        native_topk_from_scores(score, seq_lens, block, topk)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(iters):
        native_topk_from_scores(score, seq_lens, block, topk)
    # Wait for the GPU to finish before reading the clock again.
    torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1e6 / iters
|
|
|
|
def main() -> None:
    """Parse the command line and print one CSV row of timings per context length.

    Options:
        --ctx: One or more context lengths (default 2048, 4096, 32768, 65536,
            131072).
        --warmup: Untimed calls before each measurement (default 10).
        --iters: Timed calls per measurement (default 50); must be >= 1.

    Output is the GPU name, a CSV header, then for every ``--ctx`` value a row
    with the mean attention time and the mean native top-k time in
    microseconds. The top-k column is ``NA`` when ``bench_native_topk``
    returns ``None``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ctx", type=int, nargs="+", default=[2048, 4096, 32768, 65536, 131072])
    ap.add_argument("--warmup", type=int, default=10)
    ap.add_argument("--iters", type=int, default=50)
    args = ap.parse_args()

    # The benchmarks divide the elapsed time by --iters, so zero would raise
    # ZeroDivisionError after the warmup has already run and a negative value
    # would report a meaningless negative time. Reject both before touching
    # the GPU.
    if args.iters < 1:
        ap.error("--iters must be at least 1")

    print("gpu:", torch.cuda.get_device_name())
    print("ctx,attention_mean_us,native_topk_mean_us")
    for ctx in args.ctx:
        attn_us = bench(ctx, args.warmup, args.iters)
        topk_us = bench_native_topk(ctx, args.warmup, args.iters)
        topk_text = "NA" if topk_us is None else f"{topk_us:.3f}"
        print(f"{ctx},{attn_us:.3f},{topk_text}")
|
|
|
|
# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|