| |
| """Benchmark fp4-gemm.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import importlib.util |
| import json |
| import sys |
| from dataclasses import asdict, dataclass |
| from pathlib import Path |
|
|
| import torch |
|
|
|
|
# Grandparent of the directory that contains this file (presumably the
# repository root -- confirm against the checkout layout).
ROOT = Path(__file__).resolve().parents[2]
# Test module whose helper functions the benchmark loads at run time.
TEST_FILE = ROOT.joinpath("fp4-gemm", "tests", "test_fp4_gemm.py")
|
|
|
|
@dataclass
class BenchResult:
    """One benchmark row: a single kernel variant timed on a single shape.

    Instances are built by ``bench_case`` and serialised with
    ``dataclasses.asdict`` for the JSON report, so the field names double as
    the report's keys.
    """

    # --- problem identification ---
    shape: str  # human-readable label of the benchmarked shape
    M: int
    N: int
    K: int
    variant: int  # variant index passed to the kernel under test

    # --- timings (microseconds per call, averaged by ``measure``) ---
    flashrt_us: float
    torch_reference_us: float
    speedup_vs_reference: float  # torch_reference_us / flashrt_us

    # --- accuracy statistics as returned by ``helpers.metrics`` ---
    max_abs: float
    mean_abs: float
    p99_abs: float
    cosine: float

    status: str
|
|
|
|
def load_helpers():
    """Import the test module at ``TEST_FILE`` and return it.

    The module is loaded straight from its file path and registered in
    ``sys.modules`` under the name ``fp4_gemm_test_helpers``.

    Returns:
        The executed module object.

    Raises:
        RuntimeError: If no import spec or loader can be created for
            ``TEST_FILE``.
        Exception: Whatever the module's top-level code raises; it is
            re-raised unchanged after the partially initialised module has
            been removed from ``sys.modules``.
    """
    module_name = "fp4_gemm_test_helpers"
    spec = importlib.util.spec_from_file_location(module_name, TEST_FILE)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"cannot load helpers from {TEST_FILE}")

    module = importlib.util.module_from_spec(spec)
    # Register before executing so code run during the import can already find
    # the module by name in sys.modules.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Never leave a half-initialised module importable after a failure.
        if sys.modules.get(module_name) is module:
            del sys.modules[module_name]
        raise
    return module
|
|
|
|
def measure(fn, warmup: int, iters: int) -> float:
    """Return the mean duration of one ``fn()`` call in microseconds.

    ``fn`` is first called ``warmup`` times untimed. It is then called
    ``iters`` times between two CUDA events, and the elapsed time between the
    events is averaged over those calls.

    Args:
        fn: Zero-argument callable that launches the work to be timed.
        warmup: Number of untimed calls made before the measurement.
        iters: Number of timed calls; must be at least 1.

    Returns:
        Average microseconds per call.

    Raises:
        ValueError: If ``iters`` is less than 1. Zero would divide by zero
            and a negative count would yield a meaningless negative timing.
    """
    # Validate before touching the GPU so a bad count fails fast and clearly.
    if iters < 1:
        raise ValueError(f"iters must be at least 1, got {iters}")

    for _ in range(warmup):
        fn()
    # Drain the warm-up work so it cannot leak into the timed region.
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    # Both events must have completed before elapsed_time() may be queried.
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds; scale to microseconds per call.
    return float(start.elapsed_time(end) * 1000.0 / iters)
|
|
|
|
def bench_case(helpers, ops, name: str, shape: tuple[int, int, int], warmup: int, iters: int) -> list[BenchResult]:
    """Benchmark kernel variants 0, 1 and 2 on one (M, N, K) problem.

    Quantized operands come from ``helpers.prepare_quantized``. They are passed
    through ``ops.dequantize_fp4_sfa_fp16`` into fp16 buffers once, so that an
    fp32 matmul cast to bf16 can be timed as the PyTorch baseline. Each variant
    of ``ops.fp4_w4a16_linear_bf16`` is then run once, scored against the
    expected output with ``helpers.metrics``, and timed with ``measure``.

    Args:
        helpers: Object providing ``prepare_quantized`` and ``metrics``.
        ops: Object providing ``dequantize_fp4_sfa_fp16`` and
            ``fp4_w4a16_linear_bf16``.
        name: Label stored in the ``shape`` field of every result.
        shape: The ``(M, N, K)`` problem dimensions.
        warmup: Untimed calls made before each measurement.
        iters: Timed calls averaged per measurement.

    Returns:
        One ``BenchResult`` per variant, in variant order.
    """
    rows, cols, depth = shape
    packed_a, packed_b, scales_a, scales_b, reference = helpers.prepare_quantized(ops, rows, cols, depth)

    # fp16 buffers handed to the dequantize op as its output argument.
    # NOTE(review): the trailing bool presumably selects the A/B operand
    # layout -- confirm against the op's definition.
    fp16_a = torch.empty((rows, depth), device="cuda", dtype=torch.float16)
    fp16_b = torch.empty((cols, depth), device="cuda", dtype=torch.float16)
    ops.dequantize_fp4_sfa_fp16(packed_a, scales_a, fp16_a, False)
    ops.dequantize_fp4_sfa_fp16(packed_b, scales_b, fp16_b, True)
    torch.cuda.synchronize()

    def baseline():
        # The fp16 -> fp32 casts are part of the timed reference on purpose:
        # they run on every call, exactly as in the measured expression.
        lhs = fp16_a.float()
        rhs = fp16_b.float().T
        return (lhs @ rhs).to(torch.bfloat16)

    # The baseline does not depend on the variant, so it is timed only once.
    baseline_us = measure(baseline, warmup, iters)

    def run_variant(variant: int) -> BenchResult:
        result = torch.empty((rows, cols), device="cuda", dtype=torch.bfloat16)

        def launch():
            # NOTE(review): 1.0 is presumably a scaling factor -- confirm.
            return ops.fp4_w4a16_linear_bf16(packed_a, packed_b, scales_a, scales_b, result, 1.0, variant)

        # One untimed run to produce the output that gets scored.
        launch()
        torch.cuda.synchronize()
        worst, average, tail, similarity = helpers.metrics(result, reference)

        kernel_us = measure(launch, warmup, iters)
        return BenchResult(
            shape=name,
            M=rows,
            N=cols,
            K=depth,
            variant=variant,
            flashrt_us=kernel_us,
            torch_reference_us=baseline_us,
            speedup_vs_reference=baseline_us / kernel_us,
            max_abs=worst,
            mean_abs=average,
            p99_abs=tail,
            cosine=similarity,
            status="ok",
        )

    return [run_variant(variant) for variant in (0, 1, 2)]
|
|
|
|
def main() -> int:
    """Run the fp4-gemm benchmark from the command line and emit a JSON report.

    Command-line options:
        --mode        ``smoke`` benchmarks only ``small_m16_n128_k128``;
                      ``headline`` (the default) benchmarks every listed shape.
        --warmup      Untimed calls made before each measurement (default 20).
        --iterations  Timed calls averaged per measurement (default 100);
                      must be at least 1.
        --json-out    Optional path that also receives the JSON report. Missing
                      parent directories are created.

    Returns:
        ``0`` once the report has been printed (and written, if requested).

    Raises:
        SystemExit: Raised by argparse with exit status 2 for invalid options,
            including ``--iterations`` below 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["smoke", "headline"], default="headline")
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--iterations", type=int, default=100)
    parser.add_argument("--json-out", default=None)
    args = parser.parse_args()

    # Fail fast, before any helper loading: zero iterations would divide by
    # zero in the timing code and a negative count would report nonsense.
    if args.iterations < 1:
        parser.error("--iterations must be at least 1")

    helpers = load_helpers()
    ops = helpers.load_source_ops()

    shapes = {
        "small_m16_n128_k128": (16, 128, 128),
        "small_m32_n256_k256": (32, 256, 256),
        "mlp_tile_m64_n512_k512": (64, 512, 512),
    }
    if args.mode == "smoke":
        shapes = {"small_m16_n128_k128": shapes["small_m16_n128_k128"]}

    results: list[BenchResult] = []
    for name, shape in shapes.items():
        results.extend(bench_case(helpers, ops, name, shape, args.warmup, args.iterations))

    payload = {
        "mode": args.mode,
        "device": torch.cuda.get_device_name(),
        "torch": torch.__version__,
        "results": [asdict(item) for item in results],
    }
    # Serialise once so stdout and the optional file carry identical text.
    report = json.dumps(payload, indent=2)
    print(report)
    if args.json_out:
        out = Path(args.json_out)
        out.parent.mkdir(parents=True, exist_ok=True)
        out.write_text(report + "\n", encoding="utf-8")
    return 0
|
|
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|
|