| |
| """Benchmark vl-transformer-primitives against PyTorch eager references.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| import torch |
|
|
|
|
# Directory one level above the folder that contains this script.
PACKAGE = Path(__file__).resolve().parents[1]
TESTS = PACKAGE.joinpath("tests")

# Put the tests directory first on the import path; the import below relies on
# it, so this statement has to stay ahead of that import.
sys.path.insert(0, str(TESTS))

from test_vl_transformer_primitives import (  # noqa: E402
    load_installed_ops,
    load_source_ops,
    make_decode_case,
    ref_avg_pool,
    ref_norm_rope,
)
|
|
|
|
def bench(fn, warmup: int, iters: int) -> float:
    """Return the average time of one ``fn()`` call in microseconds.

    ``fn`` is first called ``warmup`` times without being measured. It is then
    called ``iters`` times between two CUDA timing events, and the elapsed
    time between those events is averaged over ``iters``.

    Args:
        fn: Zero-argument callable to time.
        warmup: Number of untimed calls made before measuring. Values below
            one simply skip the warmup phase.
        iters: Number of timed calls. Must be at least 1.

    Returns:
        Average time per call in microseconds.

    Raises:
        ValueError: If ``iters`` is less than 1.
    """
    # Reject a non-positive count up front: zero would otherwise fail with a
    # ZeroDivisionError only after all the GPU work has run, and a negative
    # count would silently produce a meaningless negative time.
    if iters <= 0:
        raise ValueError(f"iters must be a positive integer, got {iters}")

    for _ in range(warmup):
        fn()
    # Drain the warmup work so it is not attributed to the timed region.
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    # Both events must have completed before elapsed_time() can be read.
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds; scale to microseconds per call.
    return start.elapsed_time(end) * 1000.0 / iters
|
|
|
|
def main() -> int:
    """Run every benchmark and print the results as a Markdown table.

    Command-line options select the op backend (``--backend``, with
    ``--artifact`` forwarded to the installed-ops loader) and the
    ``--warmup`` / ``--iters`` counts handed to ``bench``.

    Returns:
        ``0`` once all rows have been printed.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--backend", choices=["source", "installed"], default="source")
    cli.add_argument("--artifact", default=None)
    cli.add_argument("--warmup", type=int, default=50)
    cli.add_argument("--iters", type=int, default=500)
    args = cli.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    torch.manual_seed(1234)

    if args.backend == "source":
        ops = load_source_ops()
    else:
        ops = load_installed_ops(args.artifact)

    def timed(fn) -> float:
        # Every measurement shares the warmup/iteration counts from the CLI.
        return bench(fn, args.warmup, args.iters)

    print("| Workload | Shape | FlashRT us | PyTorch eager us | Speedup |")
    print("|---|---:|---:|---:|---:|")

    # Decode-path kernels, measured once per head count.
    for heads in (1, 4, 8, 16, 32, 40):
        q, k, v, q_w, k_w, cos, sin = make_decode_case(heads)

        fused = timed(lambda: ops.qwen3_q_norm_rope_qstage_bf16(q, q_w, cos, sin))
        eager = timed(lambda: ref_norm_rope(q, q_w, cos, sin, 1e-6))
        print(f"| q_norm_rope | heads={heads}, d=128 | {fused:.3f} | {eager:.3f} | {eager / fused:.2f}x |")

        fused = timed(lambda: ops.qwen3_k_norm_rope_kvwrite_bf16(k, v, k_w, cos, sin))
        # The eager baseline pairs the reference norm+rope with a copy of v.
        eager = timed(lambda: (ref_norm_rope(k, k_w, cos, sin, 1e-6), v.clone()))
        print(f"| k_norm_rope_vwrite | heads={heads}, d=128 | {fused:.3f} | {eager:.3f} | {eager / fused:.2f}x |")

    # Vision-token average pooling over (nv, h, w, dim, pool) configurations.
    pool_cases = (
        (1, 16, 16, 1024, 2),
        (2, 16, 16, 1152, 2),
        (4, 16, 16, 2048, 4),
        (2, 32, 32, 1024, 4),
    )
    for nv, h, w, dim, pool in pool_cases:
        x = torch.randn((nv * h * w, dim), device="cuda", dtype=torch.bfloat16)

        fused = timed(lambda: ops.avg_pool_vision_tokens_bf16(x, nv, h, w, pool))
        eager = timed(lambda: ref_avg_pool(x, nv, h, w, pool))
        print(
            f"| avg_pool_vision | nv={nv}, h={h}, w={w}, dim={dim}, pool={pool} "
            f"| {fused:.3f} | {eager:.3f} | {eager / fused:.2f}x |"
        )

    return 0
|
|
|
|
# Script entry point: the integer returned by main() becomes the exit status.
if __name__ == "__main__":
    sys.exit(main())
|
|