liangsu9988's picture
Uploaded using `kernel-builder`.
a6a49dc verified
Raw
History Blame
10.3 kB
#!/usr/bin/env python3
"""Benchmark flashrt-residual-norm-quant against PyTorch eager references."""
from __future__ import annotations
import argparse
import ctypes
import ctypes.util
import importlib
import json
import math
import os
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
import torch
# Directory two levels above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[2]
# Source tree of the kernel package being benchmarked.
PACKAGE = ROOT.joinpath("flashrt-residual-norm-quant")
# Directory handed to the source build as an extra include path; it is
# looked up under a "kernels" folder located next to ROOT.
REGISTRATION_INCLUDE = ROOT.parent.joinpath(
    "kernels",
    "kernel-builder",
    "src",
    "pyproject",
    "templates",
    "torch",
)
# Benchmark input shapes as (rows, dim) pairs, keyed by a short label.
SHAPES = {
    "pi05_decoder": (10, 1024),
    "pi05_vision": (512, 1152),
    "groot_vl": (1024, 2048),
    "video_prefill": (2520, 2048),
}
# Named selections of SHAPES keys; the keys of this dict are the valid
# values of the --shapes command-line option.
SHAPE_GROUPS = {
    "smoke": ["pi05_decoder"],
    "headline": ["pi05_decoder", "pi05_vision", "groot_vl"],
    "all": [*SHAPES],
}
@dataclass
class Result:
    """One benchmark row: timings and accuracy of a single kernel on one shape."""

    # Label of the benchmarked shape and its two dimensions.
    shape: str
    rows: int
    dim: int
    # Name of the kernel entry point that was measured.
    kernel: str
    # Average time per call, in microseconds, for the custom kernel and for
    # the PyTorch eager reference respectively.
    flashrt_us: float
    torch_eager_us: float
    # torch_eager_us divided by flashrt_us.
    speedup_vs_eager: float
    # Element-wise absolute error statistics between kernel output and the
    # reference, plus the cosine similarity of the flattened outputs.
    max_abs: float
    mean_abs: float
    p99_abs: float
    cosine: float
    # "PASS" or "FAIL", decided from p99_abs against the configured limit.
    status: str
class SourceOps:
def __init__(self, namespace: str) -> None:
self._ops = getattr(torch.ops, namespace)
def rms_norm_quant_fp8_static_bf16(self, x, weight, scale, eps=1e-6, out=None):
if out is None:
out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
self._ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, float(eps), out)
return out
def residual_add_rms_norm_quant_fp8_static_bf16(
self, residual, x, weight, scale, eps=1e-6, out=None
):
if out is None:
out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
self._ops.residual_add_rms_norm_quant_fp8_static_bf16(
residual, x, weight, scale, float(eps), out
)
return out
def _preload_cublaslt() -> None:
    """Load cuBLASLt into the process with ``RTLD_GLOBAL``.

    The first ``nvidia/cublas/lib/libcublasLt.so.12`` found while walking up
    from the resolved location of the ``torch`` package is preferred. If none
    exists, whatever ``ctypes.util.find_library("cublasLt")`` reports is
    loaded instead. Nothing is loaded when neither lookup succeeds.
    """
    bundled_rel = Path("nvidia", "cublas", "lib", "libcublasLt.so.12")
    torch_location = Path(torch.__file__).resolve()
    bundled = next(
        (
            ancestor / bundled_rel
            for ancestor in torch_location.parents
            if (ancestor / bundled_rel).exists()
        ),
        None,
    )
    if bundled is not None:
        ctypes.CDLL(str(bundled), mode=ctypes.RTLD_GLOBAL)
        return
    system_library = ctypes.util.find_library("cublasLt")
    if system_library:
        ctypes.CDLL(system_library, mode=ctypes.RTLD_GLOBAL)
def _current_arch_list() -> str:
major, minor = torch.cuda.get_device_capability(0)
return f"{major}.{minor}"
def load_source_ops() -> SourceOps:
    """Build the kernel from the package sources and return a wrapper for its ops.

    Uses ``torch.utils.cpp_extension.load`` on the binding and CUDA sources
    under ``PACKAGE``. ``TORCH_CUDA_ARCH_LIST`` is set to the capability of
    device 0 only when the environment does not already define it.

    Raises:
        RuntimeError: if ``REGISTRATION_INCLUDE`` is not an existing directory.
    """
    from torch.utils.cpp_extension import load as build_extension

    if not REGISTRATION_INCLUDE.is_dir():
        raise RuntimeError(f"missing kernel-builder registration include: {REGISTRATION_INCLUDE}")

    _preload_cublaslt()
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", _current_arch_list())

    namespace = "flashrt_residual_norm_quant_benchmark"
    csrc_dir = PACKAGE / "csrc"
    source_files = [
        PACKAGE / "torch-ext" / "torch_binding.cpp",
        csrc_dir / "residual_norm_quant.cu",
    ]
    include_dirs = [csrc_dir, REGISTRATION_INCLUDE]

    # The return value is not needed: the ops are reached afterwards through
    # torch.ops.<namespace> by SourceOps.
    build_extension(
        name=namespace,
        sources=[str(path) for path in source_files],
        extra_include_paths=[str(path) for path in include_dirs],
        extra_cflags=["-O3", "-DCUDA_KERNEL"],
        extra_cuda_cflags=["-O3", "--expt-relaxed-constexpr", "-DCUDA_KERNEL"],
        verbose=False,
    )
    return SourceOps(namespace)
def load_installed_ops(artifact: str | None):
    """Import and return the ``flashrt_residual_norm_quant`` module.

    Args:
        artifact: Optional directory to search first. When given, it is put at
            the front of ``sys.path`` for the duration of the import and
            removed again afterwards, whether or not the import succeeds.

    Returns:
        The imported module object.
    """
    module_name = "flashrt_residual_norm_quant"
    if not artifact:
        # No override: rely on the interpreter's existing search path.
        return importlib.import_module(module_name)

    sys.path.insert(0, artifact)
    try:
        return importlib.import_module(module_name)
    finally:
        sys.path.remove(artifact)
def quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
return torch.clamp(x.float() / scale.float(), -448.0, 448.0).to(torch.float8_e4m3fn)
def torch_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Reference RMS normalization over dimension 1, computed in float32.

    Each row of ``x`` is multiplied by ``rsqrt(mean(x**2) + eps)`` taken over
    that row and then by ``weight``. The result is a float32 tensor.
    """
    x32 = x.float()
    mean_square = (x32 * x32).mean(dim=1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    return x32 * inv_rms * weight.float()
def torch_rms_norm_quant(x, weight, scale, eps) -> torch.Tensor:
    """Eager reference: RMS-normalize ``x`` and quantize the result to FP8."""
    normalized = torch_rms_norm(x, weight, eps)
    return quantize_fp8(normalized, scale)
def torch_residual_add_rms_norm_quant(residual, x, weight, scale, eps) -> torch.Tensor:
    """Eager reference for residual-add + RMS norm + FP8 quantization.

    Side effect: ``residual`` is overwritten in place with ``residual + x``
    rounded to bfloat16.

    The normalization factor is derived from the float32 sum, while the
    values being normalized are read back from the updated ``residual``
    (i.e. after the bfloat16 rounding).
    """
    summed = residual.float() + x.float()
    residual.copy_(summed.to(torch.bfloat16))
    mean_square = (summed * summed).mean(dim=1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    normalized = residual.float() * inv_rms * weight.float()
    return quantize_fp8(normalized, scale)
def make_case(rows: int, dim: int):
    """Create the CUDA tensors for one benchmark case.

    Returns:
        A tuple ``(x, residual, weight, scale, out)``:
        ``x`` and ``residual`` are random bfloat16 tensors of shape
        ``(rows, dim)``; ``weight`` is a bfloat16 vector of length ``dim``
        drawn around 1.0; ``scale`` is a one-element float32 tensor holding
        0.04; ``out`` is an uninitialized ``float8_e4m3fn`` buffer of shape
        ``(rows, dim)``.
    """
    shape = (rows, dim)
    bf16_on_gpu = {"device": "cuda", "dtype": torch.bfloat16}

    # Draw order (x, residual, weight) determines the values for a given seed.
    x = torch.randn(shape, **bf16_on_gpu)
    residual = torch.randn(shape, **bf16_on_gpu)
    weight_noise = torch.randn((dim,), **bf16_on_gpu)
    weight = (1.0 + 0.1 * weight_noise).contiguous()

    scale = torch.tensor([0.04], device="cuda", dtype=torch.float32)
    out = torch.empty(shape, device="cuda", dtype=torch.float8_e4m3fn)
    return x, residual, weight, scale, out
def time_us(fn, warmup: int, iters: int) -> float:
    """Measure the average time of ``fn()`` in microseconds using CUDA events.

    Args:
        fn: Zero-argument callable that enqueues the work to be timed.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls; must be at least 1.

    Returns:
        Elapsed time between the start and end events divided by ``iters``,
        converted from milliseconds to microseconds.

    Raises:
        ValueError: if ``iters`` is less than 1. Without this check a zero
            count ends in a ZeroDivisionError after the warmup has already
            run, and a negative count yields a meaningless non-positive time.
    """
    if iters < 1:
        raise ValueError(f"iters must be a positive integer, got {iters}")

    for _ in range(warmup):
        fn()
    # Drain the warmup work so it is not attributed to the timed region.
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    # elapsed_time needs both events to have completed.
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1000.0 / iters
def percentile(x: torch.Tensor, q: float) -> torch.Tensor:
    """Return the ``q``-quantile of ``x`` by rank, without interpolation.

    ``x`` is flattened and the element of rank ``ceil(q * n)`` (1-based,
    ascending) is returned, with the rank limited to the range ``[1, n]``.
    """
    values = torch.flatten(x)
    count = values.numel()
    rank = math.ceil(q * count)
    rank = min(rank, count)
    rank = max(rank, 1)
    return torch.kthvalue(values, rank).values
def metrics(got: torch.Tensor, expected: torch.Tensor):
    """Summarize how closely ``got`` matches ``expected``.

    Both tensors are compared in float32. Returns a dict of Python floats
    with the keys ``max_abs``, ``mean_abs`` and ``p99_abs`` (statistics of the
    element-wise absolute difference) and ``cosine`` (cosine similarity of
    the two flattened tensors).
    """
    got32 = got.float()
    expected32 = expected.float()
    abs_err = torch.abs(got32 - expected32).flatten()
    similarity = torch.nn.functional.cosine_similarity(
        got32.flatten(), expected32.flatten(), dim=0
    )
    summary = {}
    summary["max_abs"] = float(abs_err.max().item())
    summary["mean_abs"] = float(abs_err.mean().item())
    summary["p99_abs"] = float(percentile(abs_err, 0.99).item())
    summary["cosine"] = float(similarity.item())
    return summary
def run_one(ops, name: str, rows: int, dim: int, args) -> list[Result]:
    """Validate and time both kernels on a single shape.

    Args:
        ops: Object exposing ``rms_norm_quant_fp8_static_bf16`` and
            ``residual_add_rms_norm_quant_fp8_static_bf16``.
        name: Label stored in each result's ``shape`` field.
        rows: First dimension of the generated inputs.
        dim: Second dimension of the generated inputs.
        args: Parsed options; ``eps``, ``warmup``, ``iters`` and
            ``p99_abs_limit`` are read.

    Returns:
        Two ``Result`` entries, the plain RMS-norm kernel first and the
        residual-add variant second.
    """
    x, residual, weight, scale, out = make_case(rows, dim)
    eps = args.eps

    def bench(fn) -> float:
        # Every measurement uses the same warmup and iteration counts.
        return time_us(fn, args.warmup, args.iters)

    def build_result(kernel_name, accuracy, kernel_us, torch_us) -> Result:
        # A case passes when the 99th-percentile absolute error is in bounds.
        verdict = "PASS" if accuracy["p99_abs"] <= args.p99_abs_limit else "FAIL"
        return Result(
            shape=name,
            rows=rows,
            dim=dim,
            kernel=kernel_name,
            flashrt_us=kernel_us,
            torch_eager_us=torch_us,
            speedup_vs_eager=torch_us / kernel_us,
            status=verdict,
            **accuracy,
        )

    # ---- RMS norm + FP8 quantization ----
    # Accuracy is computed before timing because the timed calls reuse `out`.
    kernel_out = ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, eps, out)
    reference = torch_rms_norm_quant(x, weight, scale, eps)
    accuracy = metrics(kernel_out, reference)
    kernel_us = bench(lambda: ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, eps, out))
    torch_us = bench(lambda: torch_rms_norm_quant(x, weight, scale, eps))
    rms_result = build_result("rms_norm_quant_fp8_static_bf16", accuracy, kernel_us, torch_us)

    # ---- residual add + RMS norm + FP8 quantization ----
    # The kernel and the reference each receive their own copy of the same
    # starting residual, since the reference updates its residual in place.
    pristine = residual.clone()
    check_kernel_residual = pristine.clone()
    kernel_out = ops.residual_add_rms_norm_quant_fp8_static_bf16(
        check_kernel_residual, x, weight, scale, eps, out
    )
    check_ref_residual = pristine.clone()
    reference = torch_residual_add_rms_norm_quant(check_ref_residual, x, weight, scale, eps)
    accuracy = metrics(kernel_out, reference)

    # Fresh copies for the timing loops.
    timed_kernel_residual = pristine.clone()
    timed_ref_residual = pristine.clone()
    kernel_us = bench(
        lambda: ops.residual_add_rms_norm_quant_fp8_static_bf16(
            timed_kernel_residual, x, weight, scale, eps, out
        )
    )
    torch_us = bench(
        lambda: torch_residual_add_rms_norm_quant(timed_ref_residual, x, weight, scale, eps)
    )
    residual_result = build_result(
        "residual_add_rms_norm_quant_fp8_static_bf16", accuracy, kernel_us, torch_us
    )

    return [rms_result, residual_result]
def write_markdown(path: Path, results: list[Result]) -> None:
    """Write ``results`` to ``path`` as a Markdown table.

    The file holds a header row, an alignment row and one row per result,
    and ends with a newline.
    """
    header = (
        "| Shape | Rows,Dim | Kernel | FlashRT us | Eager us | vs eager | "
        "Max abs | Mean abs | P99 abs | Cosine | Status |"
    )
    alignment = "|---|---:|---|---:|---:|---:|---:|---:|---:|---:|---|"
    body = [
        f"| {r.shape} | {r.rows},{r.dim} | {r.kernel} | {r.flashrt_us:.3f} | "
        f"{r.torch_eager_us:.3f} | {r.speedup_vs_eager:.2f}x | "
        f"{r.max_abs:.6f} | {r.mean_abs:.6f} | {r.p99_abs:.6f} | "
        f"{r.cosine:.8f} | {r.status} |"
        for r in results
    ]
    path.write_text("\n".join([header, alignment, *body]) + "\n")
def main() -> None:
    """Command-line entry point: run the selected benchmarks and report them.

    Prints one summary line per result, optionally writes a JSON file
    (``--output``) and a Markdown table (``--markdown``), and exits with
    status 1 when any result is not ``PASS``. Exits immediately with a
    message when CUDA is unavailable.
    """
    option_specs = [
        ("--backend", {"choices": ["source", "installed"], "default": "source"}),
        ("--artifact", {"default": None}),
        ("--shapes", {"choices": sorted(SHAPE_GROUPS), "default": "smoke"}),
        ("--warmup", {"type": int, "default": 5}),
        ("--iters", {"type": int, "default": 20}),
        ("--eps", {"type": float, "default": 1e-6}),
        ("--p99-abs-limit", {"type": float, "default": 0.5}),
        ("--output", {"default": None}),
        ("--markdown", {"default": None}),
    ]
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required")
    # Fixed seed so the generated inputs are the same on every run.
    torch.manual_seed(29)

    if args.backend == "source":
        ops = load_source_ops()
    else:
        ops = load_installed_ops(args.artifact)

    results = []
    for shape_name in SHAPE_GROUPS[args.shapes]:
        rows, dim = SHAPES[shape_name]
        results.extend(run_one(ops, shape_name, rows, dim, args))

    for r in results:
        print(
            f"{r.status} {r.shape}/{r.kernel}: flashrt={r.flashrt_us:.3f}us "
            f"eager={r.torch_eager_us:.3f}us speedup={r.speedup_vs_eager:.2f}x "
            f"p99_abs={r.p99_abs:.6f} cosine={r.cosine:.8f}"
        )

    if args.output:
        json_path = Path(args.output)
        json_path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps([asdict(r) for r in results], indent=2)
        json_path.write_text(payload + "\n")

    if args.markdown:
        markdown_path = Path(args.markdown)
        markdown_path.parent.mkdir(parents=True, exist_ok=True)
        write_markdown(markdown_path, results)

    # A non-zero exit status signals that at least one case failed.
    if not all(r.status == "PASS" for r in results):
        raise SystemExit(1)
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()