#!/usr/bin/env python3 """Benchmark turboquant-kv source or installed artifacts.""" from __future__ import annotations import argparse import sys from pathlib import Path import torch PACKAGE = Path(__file__).resolve().parents[1] sys.path.insert(0, str(PACKAGE / "tests")) from test_turboquant_kv import load_installed_ops, load_source_ops, ref_unpack # noqa: E402 def bench(fn, warmup: int, iters: int) -> float: for _ in range(warmup): fn() torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() for _ in range(iters): fn() end.record() torch.cuda.synchronize() return start.elapsed_time(end) * 1000.0 / iters def make_packed(m: int): k_idx = torch.randint(0, 256, (m, 128), device="cuda", dtype=torch.uint8) k_qjl = torch.randint(0, 256, (m, 32), device="cuda", dtype=torch.uint8) v_idx = torch.randint(0, 256, (m, 128), device="cuda", dtype=torch.uint8) cb_k = torch.randn((16,), device="cuda", dtype=torch.float32) cb_v = torch.randn((16,), device="cuda", dtype=torch.float32) return k_idx, k_qjl, v_idx, cb_k, cb_v def main() -> int: parser = argparse.ArgumentParser() parser.add_argument("--backend", choices=["source", "installed"], default="source") parser.add_argument("--artifact", default=None) parser.add_argument("--warmup", type=int, default=100) parser.add_argument("--iters", type=int, default=1000) args = parser.parse_args() if not torch.cuda.is_available(): raise RuntimeError("CUDA is required") torch.manual_seed(1234) ops = load_source_ops() if args.backend == "source" else load_installed_ops(args.artifact) print("| Workload | Shape | FlashRT us | PyTorch eager us | Speedup |") print("|---|---:|---:|---:|---:|") for m in [1, 4, 128, 1024, 4096, 32768]: k_idx, k_qjl, v_idx, cb_k, cb_v = make_packed(m) b_k, b_v = 3, 4 fused = bench( lambda: ops.unpack_packed_bf16(k_idx, k_qjl, v_idx, cb_k, cb_v, b_k, b_v), args.warmup, args.iters, ) eager = bench( lambda: ref_unpack(k_idx, k_qjl, v_idx, cb_k, cb_v, b_k, b_v, torch.bfloat16), args.warmup, args.iters, ) print(f"| unpack_packed_bf16 | M={m}, D=256, bits=3/4 | {fused:.3f} | {eager:.3f} | {eager / fused:.2f}x |") fused = bench( lambda: ops.unpack_packed_mixed(k_idx, k_qjl, v_idx, cb_k, cb_v, b_k, b_v), args.warmup, args.iters, ) eager = bench( lambda: ref_unpack(k_idx, k_qjl, v_idx, cb_k, cb_v, b_k, b_v, torch.float32), args.warmup, args.iters, ) print(f"| unpack_packed_mixed | M={m}, D=256, bits=3/4 | {fused:.3f} | {eager:.3f} | {eager / fused:.2f}x |") for m in [1, 4, 128, 1024, 4096, 32768]: k_mse = torch.randn((m, 256), device="cuda", dtype=torch.bfloat16) k_qjl = torch.randn((m, 256), device="cuda", dtype=torch.bfloat16) v_unit = torch.randn((m, 256), device="cuda", dtype=torch.bfloat16) k_norm = torch.rand((m,), device="cuda", dtype=torch.float16) + 0.5 k_rnorm = torch.rand((m,), device="cuda", dtype=torch.float16) + 0.5 v_norm = torch.rand((m,), device="cuda", dtype=torch.float16) + 0.5 coef = 0.125 fused = bench( lambda: ops.combine_kv_bf16(k_mse, k_qjl, v_unit, k_norm, k_rnorm, v_norm, coef), args.warmup, args.iters, ) eager = bench( lambda: ( (k_norm.float().unsqueeze(1) * (k_mse.float() + coef * k_rnorm.float().unsqueeze(1) * k_qjl.float())).to(torch.bfloat16), (v_norm.float().unsqueeze(1) * v_unit.float()).to(torch.bfloat16), ), args.warmup, args.iters, ) print(f"| combine_kv_bf16 | M={m}, D=256 | {fused:.3f} | {eager:.3f} | {eager / fused:.2f}x |") return 0 if __name__ == "__main__": raise SystemExit(main())