liangsu9988's picture
Uploaded using `kernel-builder`.
ed53df0 verified
Raw
History Blame
9.33 kB
#!/usr/bin/env python3
"""Benchmark flashrt-adaptive-norms against PyTorch eager references."""
from __future__ import annotations
import argparse
import ctypes
import ctypes.util
import json
import math
import os
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
import torch
# Directory two levels above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[2]
# Source tree of the kernel package being benchmarked.
PACKAGE = ROOT.joinpath("flashrt-adaptive-norms")
# Torch registration templates, expected in a `kernels/kernel-builder`
# checkout located next to ROOT.
REGISTRATION_INCLUDE = ROOT.parent.joinpath(
    "kernels", "kernel-builder", "src", "pyproject", "templates", "torch"
)
# Named benchmark problem sizes, each given as (rows, dim).
SHAPES = dict(
    small=(64, 1024),
    vla_2k=(2520, 3072),
    vla_4k=(4096, 3072),
)
# Selectable groups of shape names; "all" covers every entry of SHAPES
# in declaration order.
SHAPE_GROUPS = {
    "smoke": ["small"],
    "headline": ["vla_2k"],
    "all": [*SHAPES],
}
@dataclass
class Result:
    """One benchmark row: timing and accuracy of one kernel on one shape."""

    # Label of the benchmarked shape and its two dimensions.
    shape: str
    rows: int
    dim: int
    # Name of the op that was measured.
    kernel: str
    # Per-call timings in microseconds for the kernel and the eager reference.
    flashrt_us: float
    torch_eager_us: float
    # Eager time divided by kernel time.
    speedup_vs_eager: float
    # Accuracy of the kernel output against the reference output.
    p99_abs: float
    cosine: float
    # Outcome label shown in the printed and written reports.
    status: str
class SourceOps:
    """Adapter that exposes the ops of one ``torch.ops`` namespace as methods.

    Every method forwards its arguments to the op of the same name, passing
    ``eps`` as a Python ``float``, and hands back the caller-supplied buffers.
    """

    def __init__(self, namespace: str) -> None:
        """Look up ``namespace`` on ``torch.ops`` and keep it for dispatch."""
        self._ops = getattr(torch.ops, namespace)

    def ada_rms_norm_style_bf16(self, x, weight, style, eps, out, gate_out):
        """Call the op, then return ``(out, gate_out)``."""
        kernel = self._ops.ada_rms_norm_style_bf16
        kernel(x, weight, style, float(eps), out, gate_out)
        return out, gate_out

    def gate_residual_ada_norm_fp8_static_bf16(self, residual, x, gate, weight, style, scale, eps, out, gate_out):
        """Call the op, then return ``(residual, out, gate_out)``."""
        kernel = self._ops.gate_residual_ada_norm_fp8_static_bf16
        kernel(residual, x, gate, weight, style, scale, float(eps), out, gate_out)
        return residual, out, gate_out
def _preload_cublaslt() -> None:
    """Load cuBLASLt into the process with ``RTLD_GLOBAL`` when it can be found.

    Every ancestor directory of the installed ``torch`` package is probed for
    ``nvidia/cublas/lib/libcublasLt.so.12``; the first hit is loaded. If none
    exists, whatever ``ctypes.util.find_library`` reports is loaded instead,
    and nothing happens when that lookup also comes back empty.
    """
    relative = Path("nvidia", "cublas", "lib", "libcublasLt.so.12")
    for ancestor in Path(torch.__file__).resolve().parents:
        bundled = ancestor / relative
        if not bundled.exists():
            continue
        ctypes.CDLL(str(bundled), mode=ctypes.RTLD_GLOBAL)
        return
    fallback = ctypes.util.find_library("cublasLt")
    if fallback:
        ctypes.CDLL(fallback, mode=ctypes.RTLD_GLOBAL)
def _current_arch_list() -> str:
    """Return the compute capability of CUDA device 0 as ``"major.minor"``."""
    major, minor = torch.cuda.get_device_capability(0)
    return ".".join((str(major), str(minor)))
def load_source_ops() -> SourceOps:
    """Build the extension from the package sources and wrap its ops.

    Returns:
        A ``SourceOps`` bound to the namespace the extension is built under.

    Raises:
        RuntimeError: If ``REGISTRATION_INCLUDE`` is not an existing directory.
    """
    from torch.utils import cpp_extension

    if not REGISTRATION_INCLUDE.is_dir():
        raise RuntimeError(f"missing kernel-builder registration include: {REGISTRATION_INCLUDE}")
    _preload_cublaslt()
    # An arch list already present in the environment takes precedence over
    # the one detected from the device.
    detected_arch = _current_arch_list()
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", detected_arch)

    namespace = "flashrt_adaptive_norms_benchmark"
    binding_source = PACKAGE / "torch-ext" / "torch_binding.cpp"
    kernel_source = PACKAGE / "csrc" / "adaptive_norms.cu"
    include_dirs = [str(PACKAGE / "csrc"), str(REGISTRATION_INCLUDE)]
    cpp_extension.load(
        name=namespace,
        sources=[str(binding_source), str(kernel_source)],
        extra_include_paths=include_dirs,
        extra_cflags=["-O3", "-DCUDA_KERNEL"],
        extra_cuda_cflags=["-O3", "--expt-relaxed-constexpr", "-DCUDA_KERNEL"],
        verbose=False,
    )
    return SourceOps(namespace)
def load_installed_ops(artifact: str | None):
    """Import and return the ``flashrt_adaptive_norms`` module.

    Args:
        artifact: Optional directory that is placed at the front of
            ``sys.path`` for the duration of the import and removed again
            afterwards, whether or not the import succeeds. When falsy, the
            module is resolved through the existing import path.

    Returns:
        The imported ``flashrt_adaptive_norms`` module.

    Raises:
        ImportError: If the module cannot be imported.
    """
    # Imported here because the file's top-level imports do not provide
    # importlib; without it this function raised NameError on every call.
    import importlib

    if artifact:
        sys.path.insert(0, artifact)
    try:
        return importlib.import_module("flashrt_adaptive_norms")
    finally:
        if artifact:
            sys.path.remove(artifact)
def make_case(rows: int, dim: int, device: str | torch.device = "cuda"):
    """Create the random input tensors for one benchmark shape.

    Args:
        rows: Number of rows of the activation tensors.
        dim: Number of columns of the activation tensors.
        device: Device on which every tensor is allocated. Defaults to
            ``"cuda"``, which is what the benchmark itself uses; another
            device can be passed to build the same inputs elsewhere.

    Returns:
        ``(x, residual, gate, weight, style, scale)`` where ``x``,
        ``residual`` and ``gate`` are bfloat16 ``(rows, dim)`` normal samples,
        ``weight`` is bfloat16 ``(dim,)`` equal to ``1 + 0.1 * noise``,
        ``style`` is bfloat16 ``(rows, 3 * dim)`` equal to ``0.05 * noise``,
        and ``scale`` is a one-element float32 tensor holding ``0.04``.
    """
    # The draw order is kept fixed so a given seed always yields the same case.
    x = torch.randn((rows, dim), device=device, dtype=torch.bfloat16)
    residual = torch.randn_like(x)
    gate = torch.randn_like(x)
    weight = (1.0 + 0.1 * torch.randn((dim,), device=device, dtype=torch.bfloat16)).contiguous()
    style = (0.05 * torch.randn((rows, 3 * dim), device=device, dtype=torch.bfloat16)).contiguous()
    scale = torch.tensor([0.04], device=device, dtype=torch.float32)
    return x, residual, gate, weight, style, scale
def rms_norm(x, weight, eps):
    """Apply weighted RMS normalisation over the last axis in float32.

    Each row of ``x`` is multiplied by ``rsqrt(mean(x**2) + eps)`` taken over
    its last axis and then by ``weight``; the result stays in float32.
    """
    x32 = x.float()
    mean_square = torch.mean(x32 * x32, dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    return x32 * inv_rms * weight.float()
def ref_ada(x, weight, style, eps):
    """Eager reference for the adaptive RMS-norm op.

    With ``dim = x.shape[1]``, the first ``dim`` columns of ``style`` scale
    the normalised input as ``1 + s``, the next ``dim`` columns are added to
    it, and the remaining columns are returned as the second output.

    Returns:
        ``(y, gate)``, both converted to bfloat16.
    """
    width = x.shape[1]
    multiplier = style[:, :width].float()
    offset = style[:, width : 2 * width].float()
    gate_chunk = style[:, 2 * width :]
    modulated = rms_norm(x, weight, eps) * (1.0 + multiplier) + offset
    return modulated.to(torch.bfloat16), gate_chunk.contiguous().to(torch.bfloat16)
def ref_fused(residual, x, gate, weight, style, scale, eps):
    """Eager reference for the fused gate-residual adaptive-norm FP8 op.

    The residual is first updated as ``residual + x * gate`` (computed in
    float32, stored as bfloat16). The update is then RMS-normalised, modulated
    by the first two ``dim``-wide column groups of ``style``, divided by the
    single value held in ``scale`` and converted to ``float8_e4m3fn``.

    Returns:
        ``(updated_residual, y_fp8, gate)`` where ``gate`` is the remaining
        columns of ``style`` as bfloat16.
    """
    gated_sum = residual.float() + x.float() * gate.float()
    updated = gated_sum.to(torch.bfloat16)
    width = updated.shape[1]
    multiplier = style[:, :width].float()
    offset = style[:, width : 2 * width].float()
    gate_chunk = style[:, 2 * width :]
    modulated = rms_norm(updated, weight, eps) * (1.0 + multiplier) + offset
    # reshape(()) turns the one-element scale tensor into a 0-d divisor.
    quant_input = modulated / scale.float().reshape(())
    return updated, quant_input.to(torch.float8_e4m3fn), gate_chunk.contiguous().to(torch.bfloat16)
def time_us(fn, warmup, iters):
    """Return the average duration of one ``fn()`` call in microseconds.

    ``fn`` is called ``warmup`` times untimed, then ``iters`` times between
    two CUDA events; the elapsed milliseconds reported by the events are
    converted to microseconds and divided by ``iters``.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    begin = torch.cuda.Event(enable_timing=True)
    finish = torch.cuda.Event(enable_timing=True)
    begin.record()
    for _ in range(iters):
        fn()
    finish.record()
    # The events must have completed before their elapsed time is read.
    torch.cuda.synchronize()

    elapsed_ms = begin.elapsed_time(finish)
    return elapsed_ms * 1000.0 / iters
def percentile(x, q):
    """Return the ``q``-quantile of all elements of ``x`` as a 0-d tensor.

    The value returned is the ``ceil(q * n)``-th smallest element, with that
    rank clamped to the range ``1..n``.
    """
    values = x.flatten()
    count = values.numel()
    rank = math.ceil(q * count)
    if rank > count:
        rank = count
    if rank < 1:
        rank = 1
    return torch.kthvalue(values, rank).values
def metrics(got, expected):
    """Compare two tensors in float32.

    Returns:
        ``(p99_abs, cosine)`` as Python floats: the 0.99 percentile of the
        element-wise absolute difference, and the cosine similarity of the
        two flattened tensors.
    """
    got32 = got.float()
    expected32 = expected.float()
    abs_err = torch.abs(got32 - expected32).flatten()
    p99_abs = percentile(abs_err, 0.99).item()
    cosine = torch.nn.functional.cosine_similarity(
        got32.flatten(), expected32.flatten(), dim=0
    ).item()
    return float(p99_abs), float(cosine)
def _accuracy_status(p99_abs, cosine, min_cosine):
    """Return "PASS" if both metrics are finite and cosine >= min_cosine, else "FAIL"."""
    finite = math.isfinite(p99_abs) and math.isfinite(cosine)
    return "PASS" if finite and cosine >= min_cosine else "FAIL"


def run_one(ops, name, shape, args):
    """Check and time both kernels on one shape.

    Args:
        ops: Object exposing ``ada_rms_norm_style_bf16`` and
            ``gate_residual_ada_norm_fp8_static_bf16``.
        name: Label of the shape, copied into the results.
        shape: ``(rows, dim)`` of the problem.
        args: Parsed options providing ``eps``, ``warmup`` and ``iters``. An
            optional ``min_cosine`` attribute (default 0.99) sets the lowest
            cosine similarity that still counts as a pass.

    Returns:
        Two ``Result`` entries, one per kernel. ``status`` is "PASS" only when
        the accuracy metrics are finite and the cosine similarity reaches the
        threshold; it used to be "PASS" unconditionally, so a wrong or NaN
        output was still reported as passing.
    """
    rows, dim = shape
    min_cosine = getattr(args, "min_cosine", 0.99)
    x, residual, gate, weight, style, scale = make_case(rows, dim)
    out = torch.empty_like(x)
    gate_out = torch.empty_like(x)
    fp8_out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    fused_gate_out = torch.empty_like(x)

    # Adaptive RMS norm: accuracy first, then timing of kernel and reference.
    ada_got, _ = ops.ada_rms_norm_style_bf16(x, weight, style, args.eps, out, gate_out)
    ada_exp, _ = ref_ada(x, weight, style, args.eps)
    ada_p99, ada_cos = metrics(ada_got, ada_exp)
    ada_flashrt_us = time_us(
        lambda: ops.ada_rms_norm_style_bf16(x, weight, style, args.eps, out, gate_out),
        args.warmup,
        args.iters,
    )
    ada_eager_us = time_us(lambda: ref_ada(x, weight, style, args.eps), args.warmup, args.iters)

    # The kernel receives a copy of the residual while the reference uses the
    # untouched one -- presumably because the op updates its residual argument
    # in place; confirm against the kernel.
    residual_work = residual.clone()
    fused_got = ops.gate_residual_ada_norm_fp8_static_bf16(
        residual_work, x, gate, weight, style, scale, args.eps, fp8_out, fused_gate_out
    )[1]
    _, fused_exp, _ = ref_fused(residual, x, gate, weight, style, scale, args.eps)
    fused_p99, fused_cos = metrics(fused_got.float(), fused_exp.float())
    fused_flashrt_us = time_us(
        lambda: ops.gate_residual_ada_norm_fp8_static_bf16(
            residual_work, x, gate, weight, style, scale, args.eps, fp8_out, fused_gate_out
        ),
        args.warmup,
        args.iters,
    )
    fused_eager_us = time_us(
        lambda: ref_fused(residual, x, gate, weight, style, scale, args.eps),
        args.warmup,
        args.iters,
    )

    return [
        Result(
            name,
            rows,
            dim,
            "ada_rms_norm_style_bf16",
            ada_flashrt_us,
            ada_eager_us,
            ada_eager_us / ada_flashrt_us,
            ada_p99,
            ada_cos,
            _accuracy_status(ada_p99, ada_cos, min_cosine),
        ),
        Result(
            name,
            rows,
            dim,
            "gate_residual_ada_norm_fp8_static_bf16",
            fused_flashrt_us,
            fused_eager_us,
            fused_eager_us / fused_flashrt_us,
            fused_p99,
            fused_cos,
            _accuracy_status(fused_p99, fused_cos, min_cosine),
        ),
    ]
def write_markdown(path: Path, results: list[Result], environment: str | None = None) -> None:
    """Write the benchmark results to ``path`` as a Markdown table.

    Args:
        path: Destination file; it is overwritten.
        results: Rows to render, one table line per entry.
        environment: Text for the "Environment:" line, without the trailing
            period. When omitted it is built from the name of CUDA device 0
            (or "unknown device" if CUDA is unavailable), so the report names
            the GPU that actually ran instead of a fixed model.
    """
    if environment is None:
        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0)
        else:
            device_name = "unknown device"
        environment = f"{device_name} local source-extension build"

    lines = [
        "# Source Benchmark Results",
        "",
        f"Environment: {environment}.",
        "Baseline: PyTorch eager tensor reference with matching BF16/FP8 math contract.",
        "",
        "| Shape | Rows,Dim | Kernel | FlashRT us | Eager us | vs eager | p99 abs | Cosine | Status |",
        "|---|---:|---|---:|---:|---:|---:|---:|---|",
    ]
    lines.extend(
        f"| {r.shape} | {r.rows},{r.dim} | {r.kernel} | {r.flashrt_us:.3f} | "
        f"{r.torch_eager_us:.3f} | {r.speedup_vs_eager:.2f}x | {r.p99_abs:.6f} | "
        f"{r.cosine:.8f} | {r.status} |"
        for r in results
    )
    path.write_text("\n".join(lines) + "\n")
def main():
    """Parse the command line, run the selected benchmarks and report them.

    Exits with "CUDA is required" when no CUDA device is available. Results
    are always printed; they are also written as JSON and/or Markdown when
    ``--output`` / ``--markdown`` are given, creating parent directories as
    needed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--backend", choices=["source", "installed"], default="source")
    cli.add_argument("--artifact", default=None)
    cli.add_argument("--shapes", choices=sorted(SHAPE_GROUPS), default="smoke")
    cli.add_argument("--warmup", type=int, default=5)
    cli.add_argument("--iters", type=int, default=20)
    cli.add_argument("--eps", type=float, default=1e-6)
    cli.add_argument("--output", default=None)
    cli.add_argument("--markdown", default=None)
    args = cli.parse_args()

    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required")
    # Fixed seed so the random inputs are the same on every run.
    torch.manual_seed(53)

    if args.backend == "source":
        ops = load_source_ops()
    else:
        ops = load_installed_ops(args.artifact)

    results = [
        row
        for shape_name in SHAPE_GROUPS[args.shapes]
        for row in run_one(ops, shape_name, SHAPES[shape_name], args)
    ]

    for r in results:
        print(
            f"{r.status} {r.shape}/{r.kernel}: flashrt={r.flashrt_us:.3f}us "
            f"eager={r.torch_eager_us:.3f}us speedup={r.speedup_vs_eager:.2f}x p99={r.p99_abs:.6f}"
        )

    if args.output:
        json_path = Path(args.output)
        json_path.parent.mkdir(parents=True, exist_ok=True)
        payload = [asdict(r) for r in results]
        json_path.write_text(json.dumps(payload, indent=2) + "\n")
    if args.markdown:
        markdown_path = Path(args.markdown)
        markdown_path.parent.mkdir(parents=True, exist_ok=True)
        write_markdown(markdown_path, results)


if __name__ == "__main__":
    main()