fp4-gemm / benchmarks /benchmark.py
liangsu9988's picture
Uploaded using `kernel-builder`.
4fb53e5 verified
Raw
History Blame
4.36 kB
#!/usr/bin/env python3
"""Benchmark fp4-gemm."""
from __future__ import annotations
import argparse
import importlib.util
import json
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
import torch
ROOT = Path(__file__).resolve().parents[2]
TEST_FILE = ROOT / "fp4-gemm" / "tests" / "test_fp4_gemm.py"
@dataclass
class BenchResult:
shape: str
M: int
N: int
K: int
variant: int
flashrt_us: float
torch_reference_us: float
speedup_vs_reference: float
max_abs: float
mean_abs: float
p99_abs: float
cosine: float
status: str
def load_helpers():
spec = importlib.util.spec_from_file_location("fp4_gemm_test_helpers", TEST_FILE)
if spec is None or spec.loader is None:
raise RuntimeError(f"cannot load helpers from {TEST_FILE}")
module = importlib.util.module_from_spec(spec)
sys.modules["fp4_gemm_test_helpers"] = module
spec.loader.exec_module(module)
return module
def measure(fn, warmup: int, iters: int) -> float:
for _ in range(warmup):
fn()
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(iters):
fn()
end.record()
torch.cuda.synchronize()
return float(start.elapsed_time(end) * 1000.0 / iters)
def bench_case(helpers, ops, name: str, shape: tuple[int, int, int], warmup: int, iters: int) -> list[BenchResult]:
m, n, k = shape
a_packed, b_packed, sfa, sfb, expected = helpers.prepare_quantized(ops, m, n, k)
a_deq = torch.empty((m, k), device="cuda", dtype=torch.float16)
b_deq = torch.empty((n, k), device="cuda", dtype=torch.float16)
ops.dequantize_fp4_sfa_fp16(a_packed, sfa, a_deq, False)
ops.dequantize_fp4_sfa_fp16(b_packed, sfb, b_deq, True)
torch.cuda.synchronize()
def torch_ref():
return (a_deq.float() @ b_deq.float().T).to(torch.bfloat16)
torch_us = measure(torch_ref, warmup, iters)
results: list[BenchResult] = []
for variant in (0, 1, 2):
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
ops.fp4_w4a16_linear_bf16(a_packed, b_packed, sfa, sfb, out, 1.0, variant)
torch.cuda.synchronize()
max_abs, mean_abs, p99_abs, cosine = helpers.metrics(out, expected)
flashrt_us = measure(
lambda: ops.fp4_w4a16_linear_bf16(a_packed, b_packed, sfa, sfb, out, 1.0, variant),
warmup,
iters,
)
results.append(
BenchResult(
shape=name,
M=m,
N=n,
K=k,
variant=variant,
flashrt_us=flashrt_us,
torch_reference_us=torch_us,
speedup_vs_reference=torch_us / flashrt_us,
max_abs=max_abs,
mean_abs=mean_abs,
p99_abs=p99_abs,
cosine=cosine,
status="ok",
)
)
return results
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=["smoke", "headline"], default="headline")
parser.add_argument("--warmup", type=int, default=20)
parser.add_argument("--iterations", type=int, default=100)
parser.add_argument("--json-out", default=None)
args = parser.parse_args()
helpers = load_helpers()
ops = helpers.load_source_ops()
shapes = {
"small_m16_n128_k128": (16, 128, 128),
"small_m32_n256_k256": (32, 256, 256),
"mlp_tile_m64_n512_k512": (64, 512, 512),
}
if args.mode == "smoke":
shapes = {"small_m16_n128_k128": shapes["small_m16_n128_k128"]}
results: list[BenchResult] = []
for name, shape in shapes.items():
results.extend(bench_case(helpers, ops, name, shape, args.warmup, args.iterations))
payload = {
"mode": args.mode,
"device": torch.cuda.get_device_name(),
"torch": torch.__version__,
"results": [asdict(item) for item in results],
}
print(json.dumps(payload, indent=2))
if args.json_out:
out = Path(args.json_out)
out.parent.mkdir(parents=True, exist_ok=True)
out.write_text(json.dumps(payload, indent=2) + "\n")
return 0
if __name__ == "__main__":
raise SystemExit(main())