| |
| """Benchmark flashrt-residual-norm-quant against PyTorch eager references.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import ctypes |
| import ctypes.util |
| import importlib |
| import json |
| import math |
| import os |
| import sys |
| from dataclasses import asdict, dataclass |
| from pathlib import Path |
|
|
| import torch |
|
|
|
|
# Grandparent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[2]
# Directory of the kernel package whose sources are benchmarked.
PACKAGE = ROOT.joinpath("flashrt-residual-norm-quant")
# Extra include directory passed to the JIT build; it is resolved inside a
# "kernels" directory that sits next to ROOT.
REGISTRATION_INCLUDE = ROOT.parent.joinpath(
    "kernels",
    "kernel-builder",
    "src",
    "pyproject",
    "templates",
    "torch",
)
|
|
# Benchmark problem sizes keyed by workload name; each value is (rows, dim).
SHAPES = dict(
    pi05_decoder=(10, 1024),
    pi05_vision=(512, 1152),
    groot_vl=(1024, 2048),
    video_prefill=(2520, 2048),
)
# Named selections of SHAPES keys that can be requested from the command line.
SHAPE_GROUPS = dict(
    smoke=["pi05_decoder"],
    headline=["pi05_decoder", "pi05_vision", "groot_vl"],
    all=[*SHAPES],
)
|
|
|
|
@dataclass
class Result:
    """One benchmark record: timing and accuracy of a single kernel on one shape.

    Instances are serialized with ``asdict`` for the JSON report and rendered
    as one table row in the Markdown report.
    """

    shape: str  # workload name (a key of SHAPES)
    rows: int  # first dimension of the benchmarked tensors
    dim: int  # second dimension of the benchmarked tensors
    kernel: str  # name of the op that was measured
    flashrt_us: float  # average time per call of the FlashRT op, in microseconds
    torch_eager_us: float  # average time per call of the PyTorch reference, in microseconds
    speedup_vs_eager: float  # torch_eager_us divided by flashrt_us
    max_abs: float  # largest absolute difference between kernel and reference output
    mean_abs: float  # mean absolute difference between kernel and reference output
    p99_abs: float  # 99th-percentile absolute difference between kernel and reference output
    cosine: float  # cosine similarity of the flattened kernel and reference outputs
    status: str  # "PASS" when p99_abs is within the configured limit, otherwise "FAIL"
|
|
|
|
| class SourceOps: |
| def __init__(self, namespace: str) -> None: |
| self._ops = getattr(torch.ops, namespace) |
|
|
| def rms_norm_quant_fp8_static_bf16(self, x, weight, scale, eps=1e-6, out=None): |
| if out is None: |
| out = torch.empty_like(x, dtype=torch.float8_e4m3fn) |
| self._ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, float(eps), out) |
| return out |
|
|
| def residual_add_rms_norm_quant_fp8_static_bf16( |
| self, residual, x, weight, scale, eps=1e-6, out=None |
| ): |
| if out is None: |
| out = torch.empty_like(x, dtype=torch.float8_e4m3fn) |
| self._ops.residual_add_rms_norm_quant_fp8_static_bf16( |
| residual, x, weight, scale, float(eps), out |
| ) |
| return out |
|
|
|
|
def _preload_cublaslt() -> None:
    """Load cuBLASLt into the process with ``RTLD_GLOBAL`` visibility.

    Every ancestor directory of the installed ``torch`` package is probed for
    ``nvidia/cublas/lib/libcublasLt.so.12`` and the first existing file is
    loaded. If there is none, the name resolved by
    ``ctypes.util.find_library("cublasLt")`` is loaded instead; when that
    lookup also comes back empty, nothing is loaded.
    """
    relative = Path("nvidia", "cublas", "lib", "libcublasLt.so.12")
    candidates = (ancestor / relative for ancestor in Path(torch.__file__).resolve().parents)
    bundled = next((path for path in candidates if path.exists()), None)

    # Only consult the system lookup when no bundled copy was found.
    target = str(bundled) if bundled is not None else ctypes.util.find_library("cublasLt")
    if target:
        ctypes.CDLL(target, mode=ctypes.RTLD_GLOBAL)
|
|
|
|
def _current_arch_list() -> str:
    """Return the compute capability of CUDA device 0 as ``"<major>.<minor>"``."""
    capability = torch.cuda.get_device_capability(0)
    return ".".join(str(part) for part in capability)
|
|
|
|
def load_source_ops() -> SourceOps:
    """JIT-build the kernel sources and return a ``SourceOps`` over the new namespace.

    Raises:
        RuntimeError: if the ``REGISTRATION_INCLUDE`` directory does not exist.
    """
    from torch.utils.cpp_extension import load as jit_load

    if not REGISTRATION_INCLUDE.is_dir():
        raise RuntimeError(f"missing kernel-builder registration include: {REGISTRATION_INCLUDE}")

    _preload_cublaslt()
    # Keep an arch list that is already set in the environment; otherwise
    # build only for the compute capability of device 0.
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", _current_arch_list())

    namespace = "flashrt_residual_norm_quant_benchmark"
    csrc_dir = PACKAGE / "csrc"
    source_files = [
        str(PACKAGE / "torch-ext" / "torch_binding.cpp"),
        str(csrc_dir / "residual_norm_quant.cu"),
    ]
    include_dirs = [str(csrc_dir), str(REGISTRATION_INCLUDE)]
    host_flags = ["-O3", "-DCUDA_KERNEL"]
    device_flags = ["-O3", "--expt-relaxed-constexpr", "-DCUDA_KERNEL"]

    jit_load(
        name=namespace,
        sources=source_files,
        extra_include_paths=include_dirs,
        extra_cflags=host_flags,
        extra_cuda_cflags=device_flags,
        verbose=False,
    )
    return SourceOps(namespace)
|
|
|
|
def load_installed_ops(artifact: str | None):
    """Import and return the ``flashrt_residual_norm_quant`` module.

    When ``artifact`` is truthy it is put at the front of ``sys.path`` for the
    duration of the import and taken out again afterwards, whether or not the
    import succeeds. Otherwise the module is imported from the current path.
    """
    module_name = "flashrt_residual_norm_quant"
    if not artifact:
        return importlib.import_module(module_name)

    sys.path.insert(0, artifact)
    try:
        return importlib.import_module(module_name)
    finally:
        sys.path.remove(artifact)
|
|
|
|
| def quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: |
| return torch.clamp(x.float() / scale.float(), -448.0, 448.0).to(torch.float8_e4m3fn) |
|
|
|
|
def torch_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Reference RMSNorm computed in float32.

    Each row of ``x`` (statistics are taken over dim 1) is scaled by
    ``rsqrt(mean(x**2) + eps)`` and then multiplied by ``weight``.
    """
    x32 = x.float()
    mean_square = (x32 * x32).mean(dim=1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    return x32 * inv_rms * weight.float()
|
|
|
|
def torch_rms_norm_quant(x, weight, scale, eps) -> torch.Tensor:
    """Eager reference: float32 RMSNorm of ``x`` followed by FP8 quantization."""
    normalized = torch_rms_norm(x, weight, eps)
    return quantize_fp8(normalized, scale)
|
|
|
|
def torch_residual_add_rms_norm_quant(residual, x, weight, scale, eps) -> torch.Tensor:
    """Eager reference for residual add, RMSNorm and FP8 quantization.

    ``residual`` is updated in place with the bfloat16-rounded sum
    ``residual + x``; the quantized normalized result is returned.
    """
    total = residual.float() + x.float()
    # Write the sum back into the caller's tensor, rounded to bfloat16.
    residual.copy_(total.to(torch.bfloat16))
    # The scaling factor is derived from the unrounded float32 sum ...
    inv_rms = torch.rsqrt((total * total).mean(dim=1, keepdim=True) + eps)
    # ... while the values being normalized are the rounded, stored residual.
    normalized = residual.float() * inv_rms * weight.float()
    return quantize_fp8(normalized, scale)
|
|
|
|
| def make_case(rows: int, dim: int): |
| x = torch.randn((rows, dim), device="cuda", dtype=torch.bfloat16) |
| residual = torch.randn((rows, dim), device="cuda", dtype=torch.bfloat16) |
| weight = (1.0 + 0.1 * torch.randn((dim,), device="cuda", dtype=torch.bfloat16)).contiguous() |
| scale = torch.tensor([0.04], device="cuda", dtype=torch.float32) |
| out = torch.empty((rows, dim), device="cuda", dtype=torch.float8_e4m3fn) |
| return x, residual, weight, scale, out |
|
|
|
|
def time_us(fn, warmup: int, iters: int) -> float:
    """Return the average duration of one ``fn()`` call in microseconds.

    ``fn`` is first called ``warmup`` times without timing. After a device
    synchronize, ``iters`` back-to-back calls are bracketed by two CUDA
    events and the elapsed time between them is averaged.

    Args:
        fn: zero-argument callable that enqueues the work to measure.
        warmup: number of untimed calls made before measuring.
        iters: number of timed calls; must be positive.

    Raises:
        ValueError: if ``iters`` is not positive. Checked before any work is
            run so a bad setting fails fast instead of ending in a division
            by zero (or a negative time) after the run.
    """
    if iters <= 0:
        raise ValueError(f"iters must be a positive integer, got {iters}")

    for _ in range(warmup):
        fn()
    # Make sure warmup work has finished before the start event is recorded.
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    # elapsed_time() needs both events to have completed.
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds; convert to microseconds per call.
    return start.elapsed_time(end) * 1000.0 / iters
|
|
|
|
def percentile(x: torch.Tensor, q: float) -> torch.Tensor:
    """Return the ``q``-quantile of ``x`` by nearest-rank selection.

    The tensor is flattened and the element of rank ``ceil(q * n)`` (1-based,
    clamped to ``[1, n]``) in ascending order is returned as a 0-dim tensor;
    no interpolation is performed.
    """
    values = x.flatten()
    count = values.numel()
    rank = math.ceil(q * count)
    rank = min(count, rank)  # never beyond the largest element
    rank = max(1, rank)  # kthvalue ranks start at 1
    return values.kthvalue(rank).values
|
|
|
|
def metrics(got: torch.Tensor, expected: torch.Tensor):
    """Compare two tensors in float32 and return summary error statistics.

    Returns:
        A dict of Python floats with keys ``max_abs``, ``mean_abs`` and
        ``p99_abs`` (maximum, mean and 99th-percentile absolute difference)
        and ``cosine`` (cosine similarity of the flattened tensors).
    """
    got32 = got.float()
    expected32 = expected.float()
    abs_err = (got32 - expected32).abs().flatten()
    similarity = torch.nn.functional.cosine_similarity(
        got32.flatten(), expected32.flatten(), dim=0
    )
    stats = {
        "max_abs": abs_err.max(),
        "mean_abs": abs_err.mean(),
        "p99_abs": percentile(abs_err, 0.99),
        "cosine": similarity,
    }
    # Convert the 0-dim tensors into plain floats for reporting.
    return {key: float(value.item()) for key, value in stats.items()}
|
|
|
|
def run_one(ops, name: str, rows: int, dim: int, args) -> list[Result]:
    """Check and time both kernels on one shape.

    For each kernel the op is run once, compared against the eager PyTorch
    reference, and then both are timed.

    Args:
        ops: object exposing ``rms_norm_quant_fp8_static_bf16`` and
            ``residual_add_rms_norm_quant_fp8_static_bf16``.
        name: workload name stored in the results.
        rows: first dimension of the generated inputs.
        dim: second dimension of the generated inputs.
        args: parsed options; reads ``eps``, ``warmup``, ``iters`` and
            ``p99_abs_limit``.

    Returns:
        Two ``Result`` records: the plain kernel first, then the residual one.
    """
    x, residual, weight, scale, out = make_case(rows, dim)
    eps = args.eps

    def record(kernel: str, accuracy: dict, kernel_fn, eager_fn) -> Result:
        # Time the kernel first, then the eager reference, and package the row.
        flashrt_us = time_us(kernel_fn, args.warmup, args.iters)
        eager_us = time_us(eager_fn, args.warmup, args.iters)
        verdict = "PASS" if accuracy["p99_abs"] <= args.p99_abs_limit else "FAIL"
        return Result(
            shape=name,
            rows=rows,
            dim=dim,
            kernel=kernel,
            flashrt_us=flashrt_us,
            torch_eager_us=eager_us,
            speedup_vs_eager=eager_us / flashrt_us,
            status=verdict,
            **accuracy,
        )

    # RMSNorm + quantization without a residual.
    got = ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, eps, out)
    expected = torch_rms_norm_quant(x, weight, scale, eps)
    plain = record(
        "rms_norm_quant_fp8_static_bf16",
        metrics(got, expected),
        lambda: ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, eps, out),
        lambda: torch_rms_norm_quant(x, weight, scale, eps),
    )

    # Residual variant: the reference writes into its residual argument, so
    # the kernel and the reference each work on their own copy.
    pristine = residual.clone()
    kernel_residual = pristine.clone()
    got = ops.residual_add_rms_norm_quant_fp8_static_bf16(
        kernel_residual, x, weight, scale, eps, out
    )
    reference_residual = pristine.clone()
    expected = torch_residual_add_rms_norm_quant(reference_residual, x, weight, scale, eps)
    accuracy = metrics(got, expected)

    # Start the timed runs from fresh copies of the untouched residual.
    kernel_residual = pristine.clone()
    reference_residual = pristine.clone()
    fused = record(
        "residual_add_rms_norm_quant_fp8_static_bf16",
        accuracy,
        lambda: ops.residual_add_rms_norm_quant_fp8_static_bf16(
            kernel_residual, x, weight, scale, eps, out
        ),
        lambda: torch_residual_add_rms_norm_quant(reference_residual, x, weight, scale, eps),
    )
    return [plain, fused]
|
|
|
|
def write_markdown(path: Path, results: list[Result]) -> None:
    """Write ``results`` to ``path`` as a Markdown table, one row per result.

    The file holds a header row, an alignment row and the data rows, and ends
    with a newline. Times are printed with 3 decimals, the speedup with 2,
    the absolute-error columns with 6 and the cosine with 8.
    """
    header = "| Shape | Rows,Dim | Kernel | FlashRT us | Eager us | vs eager | Max abs | Mean abs | P99 abs | Cosine | Status |"
    alignment = "|---|---:|---|---:|---:|---:|---:|---:|---:|---:|---|"
    rows = [
        f"| {r.shape} | {r.rows},{r.dim} | {r.kernel} | {r.flashrt_us:.3f} | "
        f"{r.torch_eager_us:.3f} | {r.speedup_vs_eager:.2f}x | "
        f"{r.max_abs:.6f} | {r.mean_abs:.6f} | {r.p99_abs:.6f} | "
        f"{r.cosine:.8f} | {r.status} |"
        for r in results
    ]
    path.write_text("\n".join([header, alignment, *rows]) + "\n")
|
|
|
|
def _parse_args():
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", choices=["source", "installed"], default="source")
    parser.add_argument("--artifact", default=None)
    parser.add_argument("--shapes", choices=sorted(SHAPE_GROUPS), default="smoke")
    parser.add_argument("--warmup", type=int, default=5)
    parser.add_argument("--iters", type=int, default=20)
    parser.add_argument("--eps", type=float, default=1e-6)
    parser.add_argument("--p99-abs-limit", type=float, default=0.5)
    parser.add_argument("--output", default=None)
    parser.add_argument("--markdown", default=None)
    return parser.parse_args()


def main() -> None:
    """Run the benchmark from the command line.

    Parses the options, loads the ops from the chosen backend, benchmarks
    every shape of the selected group, prints one summary line per result and
    optionally writes JSON and Markdown reports.

    Raises:
        SystemExit: with the message "CUDA is required" when CUDA is not
            available, or with status 1 when any result is not "PASS".
    """
    args = _parse_args()

    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required")
    # Fixed seed so the generated inputs are the same on every run.
    torch.manual_seed(29)

    if args.backend == "source":
        ops = load_source_ops()
    else:
        ops = load_installed_ops(args.artifact)

    results = [
        result
        for name in SHAPE_GROUPS[args.shapes]
        for result in run_one(ops, name, *SHAPES[name], args)
    ]

    for r in results:
        print(
            f"{r.status} {r.shape}/{r.kernel}: flashrt={r.flashrt_us:.3f}us "
            f"eager={r.torch_eager_us:.3f}us speedup={r.speedup_vs_eager:.2f}x "
            f"p99_abs={r.p99_abs:.6f} cosine={r.cosine:.8f}"
        )

    if args.output:
        json_path = Path(args.output)
        json_path.parent.mkdir(parents=True, exist_ok=True)
        json_path.write_text(json.dumps([asdict(r) for r in results], indent=2) + "\n")
    if args.markdown:
        markdown_path = Path(args.markdown)
        markdown_path.parent.mkdir(parents=True, exist_ok=True)
        write_markdown(markdown_path, results)

    # Report failure through the exit status after all reports are written.
    if not all(r.status == "PASS" for r in results):
        raise SystemExit(1)
|
|
|
|
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|