liangsu9988's picture
Uploaded using `kernel-builder`.
2adb8f7 verified
Raw
History Blame
3.23 kB
#!/usr/bin/env python3
"""Benchmark vl-transformer-primitives against PyTorch eager references."""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import torch
# Parent of the directory that contains this script.
PACKAGE = Path(__file__).resolve().parents[1]
# Directory holding the test module whose helpers are imported below.
TESTS = PACKAGE.joinpath("tests")
# Prepend it to the import search path so that import can be resolved.
sys.path.insert(0, str(TESTS))
from test_vl_transformer_primitives import ( # noqa: E402
load_installed_ops,
load_source_ops,
make_decode_case,
ref_avg_pool,
ref_norm_rope,
)
def bench(fn, warmup: int, iters: int) -> float:
    """Return the mean time of one ``fn()`` call, in microseconds.

    ``fn`` is called ``warmup`` times untimed, the device is synchronized, and
    then ``fn`` is called ``iters`` times between two CUDA timing events.

    Args:
        fn: Zero-argument callable that launches the work to measure.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls; must be at least 1.

    Returns:
        Time elapsed between the two events divided by ``iters``, in
        microseconds.

    Raises:
        ValueError: If ``iters`` is less than 1.
    """
    # Reject a non-positive count up front: zero would otherwise surface as a
    # ZeroDivisionError after the GPU work has run, and a negative value would
    # yield a meaningless negative duration.
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    for _ in range(warmup):
        fn()
    # Drain the warmup work so it is not attributed to the timed region.
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    # Both events must have completed before elapsed_time() can be read.
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds; scale to microseconds per call.
    return start.elapsed_time(end) * 1000.0 / iters
def main() -> int:
    """Time each fused op against its eager reference and print a Markdown table.

    Command-line flags select the op backend (``--backend``, ``--artifact``)
    and the warmup / timed call counts passed to ``bench`` (``--warmup``,
    ``--iters``).

    Returns:
        ``0`` after every table row has been printed.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--backend", choices=["source", "installed"], default="source")
    cli.add_argument("--artifact", default=None)
    cli.add_argument("--warmup", type=int, default=50)
    cli.add_argument("--iters", type=int, default=500)
    opts = cli.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")

    # Fixed seed for the random tensors generated below.
    torch.manual_seed(1234)

    # --artifact is only passed along when the "installed" backend is chosen.
    if opts.backend == "source":
        ops = load_source_ops()
    else:
        ops = load_installed_ops(opts.artifact)

    def timed(fn) -> float:
        """Run ``bench`` on ``fn`` with the counts given on the command line."""
        return bench(fn, opts.warmup, opts.iters)

    print("| Workload | Shape | FlashRT us | PyTorch eager us | Speedup |")
    print("|---|---:|---:|---:|---:|")

    # NOTE(review): the "d=128" in the row labels is hard-coded; presumably it
    # matches what make_decode_case produces -- confirm against the test module.
    head_counts = (1, 4, 8, 16, 32, 40)
    for heads in head_counts:
        q, k, v, q_w, k_w, cos, sin = make_decode_case(heads)

        fused = timed(lambda: ops.qwen3_q_norm_rope_qstage_bf16(q, q_w, cos, sin))
        eager = timed(lambda: ref_norm_rope(q, q_w, cos, sin, 1e-6))
        print(f"| q_norm_rope | heads={heads}, d=128 | {fused:.3f} | {eager:.3f} | {eager / fused:.2f}x |")

        fused = timed(lambda: ops.qwen3_k_norm_rope_kvwrite_bf16(k, v, k_w, cos, sin))
        # The eager side also clones v alongside the reference norm/rope call.
        eager = timed(lambda: (ref_norm_rope(k, k_w, cos, sin, 1e-6), v.clone()))
        print(f"| k_norm_rope_vwrite | heads={heads}, d=128 | {fused:.3f} | {eager:.3f} | {eager / fused:.2f}x |")

    pool_cases = (
        (1, 16, 16, 1024, 2),
        (2, 16, 16, 1152, 2),
        (4, 16, 16, 2048, 4),
        (2, 32, 32, 1024, 4),
    )
    for nv, h, w, dim, pool in pool_cases:
        # One row per token: nv * h * w rows of width dim.
        x = torch.randn((nv * h * w, dim), device="cuda", dtype=torch.bfloat16)

        fused = timed(lambda: ops.avg_pool_vision_tokens_bf16(x, nv, h, w, pool))
        eager = timed(lambda: ref_avg_pool(x, nv, h, w, pool))
        print(
            f"| avg_pool_vision | nv={nv}, h={h}, w={w}, dim={dim}, pool={pool} "
            f"| {fused:.3f} | {eager:.3f} | {eager / fused:.2f}x |"
        )

    return 0
if __name__ == "__main__":
    # Use main()'s integer result as the process exit status.
    sys.exit(main())