fp8-gemm / benchmarks /benchmark.py
liangsu9988's picture
Uploaded using `kernel-builder`.
166f09b verified
Raw
History Blame
8.79 kB
#!/usr/bin/env python3
"""Benchmark fp8-gemm."""
from __future__ import annotations
import argparse
import importlib
import json
import os
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
import torch
ROOT = Path(__file__).resolve().parents[2]
PACKAGE = ROOT / "fp8-gemm"
REGISTRATION_INCLUDE = (
ROOT.parent
/ "kernels"
/ "kernel-builder"
/ "src"
/ "pyproject"
/ "templates"
/ "torch"
)
SHAPES = {
"decode_m1_k4096_n2048": (1, 4096, 2048),
"decode_m1_k4096_n8192": (1, 4096, 8192),
"small_m16_k4096_n4096": (16, 4096, 4096),
"small_m32_k4096_n8192": (32, 4096, 8192),
"small_m64_k512_n1024": (64, 512, 1024),
}
MODES = {
"smoke": ["decode_m1_k4096_n2048", "small_m16_k4096_n4096"],
"headline": list(SHAPES),
}
@dataclass
class Result:
shape: str
M: int
K: int
N: int
variant: int
tile: str
flashrt_us: float
torch_eager_us: float
torch_compile_us: float | None
speedup_vs_eager: float
speedup_vs_compile: float | None
max_abs: float
mean_abs: float
p99_abs: float
cosine: float
status: str
class SourceOps:
def __init__(self, namespace: str) -> None:
self._ops = getattr(torch.ops, namespace)
@staticmethod
def select_fp8_linear_tile(m: int, n: int, k: int, variant: int = 0) -> str:
return select_tile(m, n, k, variant)
def fp8_linear_bf16(self, x, w, alpha=1.0, out=None, variant=0):
if out is None:
out = torch.empty((x.shape[0], w.shape[0]), device=x.device, dtype=torch.bfloat16)
self._ops.fp8_linear_bf16(x, w, float(alpha), int(variant), out)
return out
def _current_arch_list() -> str:
major, minor = torch.cuda.get_device_capability(0)
if major >= 12:
return "12.0a"
return f"{major}.{minor}"
def load_source_ops() -> SourceOps:
from torch.utils.cpp_extension import load
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", _current_arch_list())
namespace = "fp8_gemm_source_bench"
load(
name=namespace,
sources=[
str(PACKAGE / "torch-ext" / "torch_binding.cpp"),
str(PACKAGE / "csrc" / "fp8_gemv_m1_sm120.cu"),
str(PACKAGE / "csrc" / "fp8_smallM_handtuned_sm120.cu"),
str(PACKAGE / "csrc" / "fp8_smallM_handtuned_ldmatrix_sm120.cu"),
],
extra_include_paths=[str(PACKAGE / "csrc"), str(REGISTRATION_INCLUDE)],
extra_cflags=["-O3", "-DCUDA_KERNEL"],
extra_cuda_cflags=["-O3", "--expt-relaxed-constexpr", "-DCUDA_KERNEL"],
verbose=False,
)
return SourceOps(namespace)
def load_installed_ops(artifact: str | None):
if artifact:
sys.path.insert(0, artifact)
try:
return importlib.import_module("fp8_gemm")
finally:
if artifact:
sys.path.remove(artifact)
def select_tile(m: int, n: int, k: int, variant: int = 0) -> str:
if m == 1:
if variant == 4:
return "gemv_fp8_m1_w4"
if variant == 8:
return "gemv_fp8_m1_w8"
if variant == 16:
return "gemv_fp8_m1_w16"
if n <= 2048:
return "gemv_fp8_m1_w4"
if n <= 8192:
return "gemv_fp8_m1_w8"
return "gemv_fp8_m1_w16"
if m <= 16:
if k % 256 == 0:
return "ld_fp8_gemm_16x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_16x64x256_w4"
if n % 256 == 0:
return "ld_fp8_gemm_16x256x128_w8"
if n % 192 == 0:
return "ld_fp8_gemm_16x192x128_w4"
if n % 128 == 0:
return "ld_fp8_gemm_16x128x128_w4"
return "ld_fp8_gemm_16x64x128_w4"
if m <= 32:
if k % 256 == 0:
return "ld_fp8_gemm_32x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_32x64x256_w4"
if n % 192 == 0:
return "ld_fp8_gemm_32x192x128_w4"
if n % 128 == 0:
return "ld_fp8_gemm_32x128x128_w4"
return "ld_fp8_gemm_32x64x128_w4"
if m <= 64:
if k % 256 == 0:
return "ld_fp8_gemm_64x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_64x64x256_w4"
if n % 128 == 0:
return "ld_fp8_gemm_64x128x128_w4"
return "ld_fp8_gemm_64x64x128_w4"
if m <= 64:
if k % 256 == 0:
return "ld_fp8_gemm_64x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_64x64x256_w4"
if n % 128 == 0:
return "ld_fp8_gemm_64x128x128_w4"
return "ld_fp8_gemm_64x64x128_w4"
raise RuntimeError("unsupported M")
def make_inputs(m: int, k: int, n: int, seed: int):
gen = torch.Generator(device="cuda")
gen.manual_seed(seed)
x = (torch.randn((m, k), device="cuda", generator=gen) * 0.25).to(torch.bfloat16).to(torch.float8_e4m3fn)
w = (torch.randn((n, k), device="cuda", generator=gen) * 0.25).to(torch.bfloat16).to(torch.float8_e4m3fn)
return x, w
def ref_fn(x, w):
return (x.float() @ w.float().T).to(torch.bfloat16)
def measure(fn, warmup: int, iters: int) -> float:
for _ in range(warmup):
fn()
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(iters):
fn()
end.record()
torch.cuda.synchronize()
return float(start.elapsed_time(end) * 1000.0 / iters)
def metrics(got, expected):
diff = (got.float() - expected.float()).abs().flatten()
return (
float(diff.max().item()),
float(diff.mean().item()),
float(torch.quantile(diff, 0.99).item()),
float(torch.nn.functional.cosine_similarity(got.float().flatten(), expected.float().flatten(), dim=0).item()),
)
def bench_case(ops, name: str, shape: tuple[int, int, int], variant: int, warmup: int, iters: int, compile_ref: bool):
m, k, n = shape
x, w = make_inputs(m, k, n, seed=3000 + m + k + n + variant)
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
expected = ref_fn(x, w)
got = ops.fp8_linear_bf16(x, w, out=out, variant=variant)
torch.cuda.synchronize()
max_abs, mean_abs, p99_abs, cos = metrics(got, expected)
tile = ops.select_fp8_linear_tile(m, n, k, variant)
flashrt_us = measure(lambda: ops.fp8_linear_bf16(x, w, out=out, variant=variant), warmup, iters)
eager_us = measure(lambda: ref_fn(x, w), warmup, iters)
compile_us = None
if compile_ref:
try:
compiled = torch.compile(ref_fn, fullgraph=True)
compiled(x, w)
torch.cuda.synchronize()
compile_us = measure(lambda: compiled(x, w), warmup, iters)
except Exception:
compile_us = None
return Result(
shape=name,
M=m,
K=k,
N=n,
variant=variant,
tile=tile,
flashrt_us=flashrt_us,
torch_eager_us=eager_us,
torch_compile_us=compile_us,
speedup_vs_eager=eager_us / flashrt_us,
speedup_vs_compile=(compile_us / flashrt_us) if compile_us else None,
max_abs=max_abs,
mean_abs=mean_abs,
p99_abs=p99_abs,
cosine=cos,
status="pass" if max_abs <= 0.5 and p99_abs <= 0.25 and cos >= 0.999 else "fail",
)
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--backend", choices=["source", "installed"], default="source")
parser.add_argument("--artifact", default=None)
parser.add_argument("--mode", choices=sorted(MODES), default="smoke")
parser.add_argument("--warmup", type=int, default=20)
parser.add_argument("--iterations", type=int, default=100)
parser.add_argument("--compile-ref", action="store_true")
parser.add_argument("--json-out", default=None)
args = parser.parse_args()
if not torch.cuda.is_available():
raise SystemExit("CUDA is required")
major, _minor = torch.cuda.get_device_capability(0)
if major < 12:
raise SystemExit("fp8-gemm requires Blackwell/SM120 for this package")
ops = load_source_ops() if args.backend == "source" else load_installed_ops(args.artifact)
rows: list[Result] = []
for name in MODES[args.mode]:
shape = SHAPES[name]
variants = [0]
if shape[0] == 1:
variants = [0, 4, 8, 16]
for variant in variants:
rows.append(bench_case(ops, name, shape, variant, args.warmup, args.iterations, args.compile_ref))
payload = {"rows": [asdict(row) for row in rows]}
print(json.dumps(payload, indent=2, sort_keys=True))
if args.json_out:
Path(args.json_out).write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n")
if any(row.status != "pass" for row in rows):
raise SystemExit(1)
if __name__ == "__main__":
main()