#!/usr/bin/env python3
"""Benchmark fp4-gemm."""
from __future__ import annotations
import argparse
import importlib.util
import json
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
import torch
# Directory two levels above the one containing this file, after resolving
# symlinks (presumably the repository root -- confirm against the checkout layout).
ROOT = Path(__file__).resolve().parents[2]
# Test module that load_helpers() imports by file path at runtime.
TEST_FILE = ROOT / "fp4-gemm" / "tests" / "test_fp4_gemm.py"
@dataclass
class BenchResult:
    """One benchmark row: a single kernel variant measured on a single GEMM shape.

    Field order is significant: it fixes both the positional constructor
    signature and the key order produced by ``dataclasses.asdict``.
    """

    # --- Case identification -------------------------------------------
    shape: str  # label of the benchmarked case
    M: int
    N: int
    K: int
    variant: int  # variant index handed to the kernel under test

    # --- Timing (per-call averages) ------------------------------------
    flashrt_us: float
    torch_reference_us: float
    speedup_vs_reference: float  # torch_reference_us / flashrt_us

    # --- Accuracy of the kernel output against the expected result -----
    max_abs: float
    mean_abs: float
    p99_abs: float
    cosine: float

    # --- Outcome marker ------------------------------------------------
    status: str
def load_helpers():
    """Import the test module at ``TEST_FILE`` by path and return it.

    The module is registered in ``sys.modules`` under the name
    ``fp4_gemm_test_helpers`` before its code is executed.

    Raises:
        RuntimeError: if no import spec, or no loader, can be obtained for
            ``TEST_FILE``.
    """
    module_name = "fp4_gemm_test_helpers"
    spec = importlib.util.spec_from_file_location(module_name, TEST_FILE)
    loader = None if spec is None else spec.loader
    if loader is None:
        raise RuntimeError(f"cannot load helpers from {TEST_FILE}")

    helpers_module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = helpers_module
    loader.exec_module(helpers_module)
    return helpers_module
def measure(fn, warmup: int, iters: int) -> float:
    """Return the average duration of one ``fn()`` call, in microseconds.

    ``fn`` is first called ``warmup`` times untimed, then ``iters`` times
    between two CUDA timing events recorded on the current stream.

    Args:
        fn: Zero-argument callable that enqueues the work to be timed.
        warmup: Number of untimed calls made before measuring; values
            below 1 skip the warm-up entirely.
        iters: Number of timed calls; must be at least 1.

    Returns:
        Elapsed time between the two events divided by ``iters``,
        converted to microseconds.

    Raises:
        ValueError: if ``iters`` is less than 1. Checked before any work
            is launched, because zero would end in a division by zero and
            a negative count would yield a meaningless negative time.
    """
    if iters < 1:
        raise ValueError(f"iters must be at least 1, got {iters}")

    for _ in range(warmup):
        fn()
    # Drain warm-up work so it cannot leak into the timed region.
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    # Both events must have completed before elapsed_time() can be read.
    torch.cuda.synchronize()

    # Event.elapsed_time() reports milliseconds; scale to microseconds per call.
    return float(start.elapsed_time(end) * 1000.0 / iters)
def bench_case(helpers, ops, name: str, shape: tuple[int, int, int], warmup: int, iters: int) -> list[BenchResult]:
    """Benchmark every kernel variant on one ``(M, N, K)`` problem.

    Args:
        helpers: Module providing ``prepare_quantized`` and ``metrics``.
        ops: Object exposing ``dequantize_fp4_sfa_fp16`` and
            ``fp4_w4a16_linear_bf16``.
        name: Label stored in each result's ``shape`` field.
        shape: ``(M, N, K)`` problem dimensions.
        warmup: Untimed calls per measurement, forwarded to ``measure``.
        iters: Timed calls per measurement, forwarded to ``measure``.

    Returns:
        One ``BenchResult`` per variant, in the order 0, 1, 2. All share
        the same torch reference timing, which is measured once.
    """
    rows, cols, depth = shape
    a_packed, b_packed, sfa, sfb, expected = helpers.prepare_quantized(ops, rows, cols, depth)

    # Dequantize both operands once, outside the timed region, so the
    # torch reference only pays for the conversion/matmul in torch_ref.
    a_deq = torch.empty((rows, depth), device="cuda", dtype=torch.float16)
    b_deq = torch.empty((cols, depth), device="cuda", dtype=torch.float16)
    # NOTE(review): the trailing boolean presumably selects the A-side vs
    # B-side scale-factor layout -- confirm against the op's definition.
    ops.dequantize_fp4_sfa_fp16(a_packed, sfa, a_deq, False)
    ops.dequantize_fp4_sfa_fp16(b_packed, sfb, b_deq, True)
    torch.cuda.synchronize()

    def torch_ref():
        lhs = a_deq.float()
        rhs_transposed = b_deq.float().T
        return (lhs @ rhs_transposed).to(torch.bfloat16)

    reference_us = measure(torch_ref, warmup, iters)

    def run_variant(variant: int) -> BenchResult:
        """Check accuracy of one variant, then time it."""
        out = torch.empty((rows, cols), device="cuda", dtype=torch.bfloat16)

        def launch():
            return ops.fp4_w4a16_linear_bf16(a_packed, b_packed, sfa, sfb, out, 1.0, variant)

        # One untimed run fills ``out`` for the accuracy comparison.
        launch()
        torch.cuda.synchronize()
        max_abs, mean_abs, p99_abs, cosine = helpers.metrics(out, expected)

        kernel_us = measure(launch, warmup, iters)
        return BenchResult(
            shape=name,
            M=rows,
            N=cols,
            K=depth,
            variant=variant,
            flashrt_us=kernel_us,
            torch_reference_us=reference_us,
            speedup_vs_reference=reference_us / kernel_us,
            max_abs=max_abs,
            mean_abs=mean_abs,
            p99_abs=p99_abs,
            cosine=cosine,
            status="ok",
        )

    return [run_variant(variant) for variant in (0, 1, 2)]
def main() -> int:
    """Run the fp4-gemm benchmark from the command line.

    Options:
        --mode: ``smoke`` runs only the smallest shape; ``headline``
            (default) runs every shape.
        --warmup: Untimed calls per measurement (default 20, must be >= 0).
        --iterations: Timed calls per measurement (default 100, must be >= 1).
        --json-out: Optional path that also receives the JSON report.

    The JSON report is always printed to stdout.

    Returns:
        0 on success. Invalid ``--warmup`` / ``--iterations`` values end
        the process through ``argparse`` with exit status 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["smoke", "headline"], default="headline")
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--iterations", type=int, default=100)
    parser.add_argument("--json-out", default=None)
    args = parser.parse_args()

    # Reject unusable counts before loading helpers/ops: zero iterations
    # would end in a division by zero, and negative counts would produce
    # meaningless timings.
    if args.iterations < 1:
        parser.error(f"--iterations must be at least 1, got {args.iterations}")
    if args.warmup < 0:
        parser.error(f"--warmup must be non-negative, got {args.warmup}")

    helpers = load_helpers()
    ops = helpers.load_source_ops()

    shapes = {
        "small_m16_n128_k128": (16, 128, 128),
        "small_m32_n256_k256": (32, 256, 256),
        "mlp_tile_m64_n512_k512": (64, 512, 512),
    }
    if args.mode == "smoke":
        shapes = {"small_m16_n128_k128": shapes["small_m16_n128_k128"]}

    results: list[BenchResult] = []
    for name, shape in shapes.items():
        results.extend(bench_case(helpers, ops, name, shape, args.warmup, args.iterations))

    payload = {
        "mode": args.mode,
        "device": torch.cuda.get_device_name(),
        "torch": torch.__version__,
        "results": [asdict(item) for item in results],
    }

    # Serialise once so stdout and the optional file are guaranteed identical.
    report = json.dumps(payload, indent=2)
    print(report)
    if args.json_out:
        out = Path(args.json_out)
        out.parent.mkdir(parents=True, exist_ok=True)
        out.write_text(report + "\n", encoding="utf-8")
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|