| |
| """Benchmark flashrt-fp8-ffn against PyTorch eager/compile references.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import ctypes |
| import ctypes.util |
| import importlib |
| import json |
| import math |
| import os |
| import sys |
| from dataclasses import asdict, dataclass |
| from pathlib import Path |
|
|
| import torch |
|
|
|
|
| ROOT = Path(__file__).resolve().parents[2] |
| PACKAGE = ROOT / "flashrt-fp8-ffn" |
| REGISTRATION_INCLUDE = ( |
| ROOT.parent |
| / "kernels" |
| / "kernel-builder" |
| / "src" |
| / "pyproject" |
| / "templates" |
| / "torch" |
| ) |
|
|
|
|
| SHAPES = { |
| |
| "pi05_decoder_ffn_m1": (1, 1024, 4096, 1024, 18), |
| "pi05_decoder_ffn_m8": (8, 1024, 4096, 1024, 18), |
| "pi05_decoder_ffn_m10": (10, 1024, 4096, 1024, 18), |
| "pi05_decoder_ffn_m16": (16, 1024, 4096, 1024, 18), |
| |
| "pi05_decoder_ffn": (10, 1024, 4096, 1024, 18), |
| |
| "pi05_vision_ffn_1view": (256, 1152, 4304, 1152, 27), |
| "pi05_vision_ffn_2view": (512, 1152, 4304, 1152, 27), |
| "pi05_vision_ffn_3view": (768, 1152, 4304, 1152, 27), |
| |
| "groot_vit_ffn_1view": (256, 1024, 4096, 1024, 24), |
| "groot_vit_ffn_2view": (512, 1024, 4096, 1024, 24), |
| "groot_vit_ffn_4view": (1024, 1024, 4096, 1024, 24), |
| |
| "groot_deepstack_merge_2view": (128, 4096, 4096, 2048, 3), |
| |
| "groot_vl_self_attn_ffn_seq512": (512, 2048, 8192, 2048, 4), |
| "groot_vl_self_attn_ffn_seq1024": (1024, 2048, 8192, 2048, 4), |
| "groot_vl_self_attn_ffn_seq2520": (2520, 2048, 8192, 2048, 4), |
| |
| "groot_vl_self_attn_ffn": (1024, 2048, 8192, 2048, 4), |
| |
| |
| |
| "groot_action_dit_ffn": (41, 1536, 6144, 1536, 32), |
| } |
|
|
| SHAPE_GROUPS = { |
| "headline": [ |
| "pi05_decoder_ffn_m10", |
| "pi05_vision_ffn_2view", |
| "groot_vit_ffn_2view", |
| "groot_vl_self_attn_ffn_seq1024", |
| ], |
| "pi05": [ |
| "pi05_decoder_ffn_m1", |
| "pi05_decoder_ffn_m8", |
| "pi05_decoder_ffn_m10", |
| "pi05_decoder_ffn_m16", |
| "pi05_vision_ffn_1view", |
| "pi05_vision_ffn_2view", |
| "pi05_vision_ffn_3view", |
| ], |
| "groot": [ |
| "groot_vit_ffn_1view", |
| "groot_vit_ffn_2view", |
| "groot_vit_ffn_4view", |
| "groot_deepstack_merge_2view", |
| "groot_vl_self_attn_ffn_seq512", |
| "groot_vl_self_attn_ffn_seq1024", |
| "groot_vl_self_attn_ffn_seq2520", |
| "groot_action_dit_ffn", |
| ], |
| } |
| SHAPE_GROUPS["all"] = SHAPE_GROUPS["pi05"] + SHAPE_GROUPS["groot"] |
|
|
|
|
| @dataclass |
| class Result: |
| shape: str |
| M: int |
| K: int |
| H: int |
| N: int |
| layers: int |
| flashrt_us: float |
| torch_eager_us: float |
| torch_compile_us: float | None |
| speedup_vs_eager: float |
| speedup_vs_compile: float | None |
| compile_status: str |
| max_abs: float |
| p99_abs: float |
| p99_rel_floor1: float |
| max_rel_floor1: float |
| status: str |
|
|
|
|
| class SourceOps: |
| def __init__(self, namespace: str) -> None: |
| self._ops = getattr(torch.ops, namespace) |
|
|
| def fp8_gelu_mlp_bf16( |
| self, |
| x, |
| up_w, |
| up_b, |
| down_w, |
| down_b, |
| x_scale, |
| up_w_scale, |
| hidden_scale, |
| down_w_scale, |
| hidden=None, |
| hidden_fp8=None, |
| out=None, |
| ): |
| if hidden is None: |
| hidden = torch.empty((x.shape[0], up_w.shape[0]), device=x.device, dtype=torch.bfloat16) |
| if hidden_fp8 is None: |
| hidden_fp8 = torch.empty_like(hidden, dtype=fp8_dtype()) |
| if out is None: |
| out = torch.empty((x.shape[0], down_w.shape[0]), device=x.device, dtype=torch.bfloat16) |
| self._ops.fp8_gelu_mlp_bf16( |
| x, |
| up_w, |
| up_b, |
| down_w, |
| down_b, |
| x_scale, |
| up_w_scale, |
| hidden_scale, |
| down_w_scale, |
| hidden, |
| hidden_fp8, |
| out, |
| ) |
| return out |
|
|
|
|
| def _preload_cublaslt() -> None: |
| for parent in Path(torch.__file__).resolve().parents: |
| candidate = parent / "nvidia" / "cublas" / "lib" / "libcublasLt.so.12" |
| if candidate.exists(): |
| ctypes.CDLL(str(candidate), mode=ctypes.RTLD_GLOBAL) |
| return |
| library = ctypes.util.find_library("cublasLt") |
| if library: |
| ctypes.CDLL(library, mode=ctypes.RTLD_GLOBAL) |
|
|
|
|
| def _current_arch_list() -> str: |
| major, minor = torch.cuda.get_device_capability(0) |
| return f"{major}.{minor}" |
|
|
|
|
| def load_source_ops() -> SourceOps: |
| from torch.utils.cpp_extension import load |
|
|
| if not REGISTRATION_INCLUDE.is_dir(): |
| raise RuntimeError(f"missing kernel-builder registration include: {REGISTRATION_INCLUDE}") |
| _preload_cublaslt() |
| os.environ.setdefault("TORCH_CUDA_ARCH_LIST", _current_arch_list()) |
| namespace = "flashrt_fp8_ffn_benchmark" |
| load( |
| name=namespace, |
| sources=[ |
| str(PACKAGE / "torch-ext" / "torch_binding.cpp"), |
| str(PACKAGE / "csrc" / "fp8_ffn.cu"), |
| ], |
| extra_include_paths=[str(PACKAGE / "csrc"), str(REGISTRATION_INCLUDE)], |
| extra_cflags=["-O3", "-DCUDA_KERNEL"], |
| extra_cuda_cflags=["-O3", "--expt-relaxed-constexpr", "-DCUDA_KERNEL"], |
| verbose=False, |
| ) |
| return SourceOps(namespace) |
|
|
|
|
| def load_installed_ops(artifact: str | None): |
| if artifact: |
| sys.path.insert(0, artifact) |
| try: |
| return importlib.import_module("flashrt_fp8_ffn") |
| finally: |
| if artifact: |
| sys.path.remove(artifact) |
|
|
|
|
| def load_hub_ops(repo_id: str, version: int): |
| from kernels import get_kernel |
|
|
| return get_kernel(repo_id, version=version) |
|
|
|
|
| def fp8_dtype() -> torch.dtype: |
| if torch.version.hip is not None and hasattr(torch, "float8_e4m3fnuz"): |
| return torch.float8_e4m3fnuz |
| return torch.float8_e4m3fn |
|
|
|
|
| def fp8_max() -> float: |
| return 240.0 if torch.version.hip is not None else 448.0 |
|
|
|
|
| def quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: |
| limit = fp8_max() |
| return torch.clamp(x.float() / scale.float(), -limit, limit).to(fp8_dtype()) |
|
|
|
|
| def dequant_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: |
| return x.float() * scale.float() |
|
|
|
|
| def compiler_disable(fn): |
| compiler = getattr(torch, "compiler", None) |
| if compiler is not None and hasattr(compiler, "disable"): |
| return compiler.disable(fn) |
| return torch._dynamo.disable(fn) |
|
|
|
|
| def gelu_quantize_fp8_boundary( |
| hidden: torch.Tensor, bias: torch.Tensor, scale: torch.Tensor |
| ) -> torch.Tensor: |
| hidden = torch.nn.functional.gelu( |
| hidden.float() + bias.float(), approximate="tanh" |
| ) |
| return quantize_fp8(hidden, scale) |
|
|
|
|
| def bf16_bias_add_boundary(out: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: |
| return (out.float() + bias.float()).to(torch.bfloat16) |
|
|
|
|
| stable_gelu_quantize_fp8 = compiler_disable(gelu_quantize_fp8_boundary) |
| stable_bf16_bias_add = compiler_disable(bf16_bias_add_boundary) |
|
|
|
|
| def torch_mlp(x, up_w, up_b, down_w, down_b, x_s, up_s, hid_s, dn_s): |
| hidden = (dequant_fp8(x, x_s) @ dequant_fp8(up_w, up_s).T).to(torch.bfloat16) |
| hidden = torch.nn.functional.gelu(hidden + up_b.float(), approximate="tanh") |
| limit = fp8_max() |
| hidden_fp8 = torch.clamp(hidden / hid_s.float(), -limit, limit).to(fp8_dtype()) |
| out = (dequant_fp8(hidden_fp8, hid_s) @ dequant_fp8(down_w, dn_s).T).to(torch.bfloat16) |
| return (out + down_b.float()).to(torch.bfloat16) |
|
|
|
|
| def torch_mlp_compile_stable(x, up_w, up_b, down_w, down_b, x_s, up_s, hid_s, dn_s): |
| hidden = (dequant_fp8(x, x_s) @ dequant_fp8(up_w, up_s).T).to(torch.bfloat16) |
| hidden_fp8 = stable_gelu_quantize_fp8(hidden, up_b, hid_s) |
| out = (dequant_fp8(hidden_fp8, hid_s) @ dequant_fp8(down_w, dn_s).T).to(torch.bfloat16) |
| return stable_bf16_bias_add(out, down_b) |
|
|
|
|
| def make_inputs(M: int, K: int, H: int, N: int, layers: int): |
| x_scale = torch.tensor([0.05], device="cuda", dtype=torch.float32) |
| up_scale = torch.tensor([0.04], device="cuda", dtype=torch.float32) |
| hidden_scale = torch.tensor([0.25], device="cuda", dtype=torch.float32) |
| down_scale = torch.tensor([0.04], device="cuda", dtype=torch.float32) |
| xs = [ |
| quantize_fp8(torch.randn((M, K), device="cuda", dtype=torch.bfloat16), x_scale) |
| for _ in range(layers) |
| ] |
| up_ws = [ |
| quantize_fp8(torch.randn((H, K), device="cuda", dtype=torch.bfloat16), up_scale) |
| for _ in range(layers) |
| ] |
| down_ws = [ |
| quantize_fp8(torch.randn((N, H), device="cuda", dtype=torch.bfloat16), down_scale) |
| for _ in range(layers) |
| ] |
| up_bs = [torch.randn((H,), device="cuda", dtype=torch.bfloat16) for _ in range(layers)] |
| down_bs = [torch.randn((N,), device="cuda", dtype=torch.bfloat16) for _ in range(layers)] |
| hidden = [torch.empty((M, H), device="cuda", dtype=torch.bfloat16) for _ in range(layers)] |
| hidden_fp8 = [torch.empty((M, H), device="cuda", dtype=fp8_dtype()) for _ in range(layers)] |
| outs = [torch.empty((M, N), device="cuda", dtype=torch.bfloat16) for _ in range(layers)] |
| return xs, up_ws, up_bs, down_ws, down_bs, x_scale, up_scale, hidden_scale, down_scale, hidden, hidden_fp8, outs |
|
|
|
|
| def time_us(fn, *, warmup: int, iters: int) -> float: |
| for _ in range(warmup): |
| fn() |
| torch.cuda.synchronize() |
| start = torch.cuda.Event(enable_timing=True) |
| end = torch.cuda.Event(enable_timing=True) |
| start.record() |
| for _ in range(iters): |
| fn() |
| end.record() |
| torch.cuda.synchronize() |
| return start.elapsed_time(end) * 1000.0 / iters |
|
|
|
|
| def percentile(x: torch.Tensor, q: float) -> torch.Tensor: |
| flat = x.flatten() |
| k = max(1, min(flat.numel(), math.ceil(q * flat.numel()))) |
| return flat.kthvalue(k).values |
|
|
|
|
| def _outputs_close(got, expected) -> bool: |
| if isinstance(got, (tuple, list)) and isinstance(expected, (tuple, list)): |
| return len(got) == len(expected) and all( |
| _outputs_close(g, e) for g, e in zip(got, expected) |
| ) |
| return bool(torch.allclose(got, expected, rtol=3e-2, atol=1.25e-1)) |
|
|
|
|
| def compile_time_us(fn, expected, *, warmup: int, iters: int) -> tuple[float | None, str]: |
| try: |
| compiled = torch.compile(fn, fullgraph=False, mode="reduce-overhead") |
| compiled_out = compiled() |
| torch.cuda.synchronize() |
| if not _outputs_close(compiled_out, expected): |
| return None, "unsupported: compiled reference output mismatch" |
| return time_us(compiled, warmup=warmup, iters=iters), "ok" |
| except Exception as exc: |
| return None, f"unsupported: {type(exc).__name__}: {exc}" |
|
|
|
|
| def run_shape(ops, name: str, shape, args) -> Result: |
| M, K, H, N, layers = shape |
| xs, up_ws, up_bs, down_ws, down_bs, x_s, up_s, hid_s, dn_s, hidden, hidden_fp8, outs = make_inputs( |
| M, K, H, N, layers |
| ) |
|
|
| def flashrt_stack(): |
| result = [] |
| for i in range(layers): |
| result.append( |
| ops.fp8_gelu_mlp_bf16( |
| xs[i], |
| up_ws[i], |
| up_bs[i], |
| down_ws[i], |
| down_bs[i], |
| x_s, |
| up_s, |
| hid_s, |
| dn_s, |
| hidden[i], |
| hidden_fp8[i], |
| outs[i], |
| ) |
| ) |
| return tuple(result) |
|
|
| def torch_stack(): |
| return tuple( |
| torch_mlp(xs[i], up_ws[i], up_bs[i], down_ws[i], down_bs[i], x_s, up_s, hid_s, dn_s) |
| for i in range(layers) |
| ) |
|
|
| def torch_stack_compile_stable(): |
| return tuple( |
| torch_mlp_compile_stable( |
| xs[i], up_ws[i], up_bs[i], down_ws[i], down_bs[i], x_s, up_s, hid_s, dn_s |
| ) |
| for i in range(layers) |
| ) |
|
|
| flashrt_stack() |
| expected0 = torch_mlp(xs[0], up_ws[0], up_bs[0], down_ws[0], down_bs[0], x_s, up_s, hid_s, dn_s) |
| diff = (outs[0].float() - expected0.float()).abs().flatten() |
| rel = diff / expected0.float().abs().flatten().clamp_min(1.0) |
| max_abs = float(diff.max().item()) |
| p99_abs = float(percentile(diff, 0.99).item()) |
| p99_rel = float(percentile(rel, 0.99).item()) |
| max_rel = float(rel.max().item()) |
| status = ( |
| "PASS" |
| if p99_abs <= args.p99_abs_limit and p99_rel <= args.p99_rel_floor1_limit |
| else "FAIL" |
| ) |
|
|
| flashrt_us = time_us(flashrt_stack, warmup=args.warmup, iters=args.iters) |
| eager_us = time_us(torch_stack, warmup=args.warmup, iters=args.iters) |
| compile_us = None |
| compile_status = "not_requested" |
| if args.compile_baseline: |
| eager_expected = torch_stack() |
| stable_expected = torch_stack_compile_stable() |
| torch.cuda.synchronize() |
| if not _outputs_close(stable_expected, eager_expected): |
| compile_status = "unsupported: stable compile reference differs from eager" |
| else: |
| compile_us, compile_status = compile_time_us( |
| torch_stack_compile_stable, |
| eager_expected, |
| warmup=args.warmup, |
| iters=args.iters, |
| ) |
|
|
| return Result( |
| shape=name, |
| M=M, |
| K=K, |
| H=H, |
| N=N, |
| layers=layers, |
| flashrt_us=flashrt_us, |
| torch_eager_us=eager_us, |
| torch_compile_us=compile_us, |
| speedup_vs_eager=eager_us / flashrt_us, |
| speedup_vs_compile=compile_us / flashrt_us if compile_us is not None else None, |
| compile_status=compile_status, |
| max_abs=max_abs, |
| p99_abs=p99_abs, |
| p99_rel_floor1=p99_rel, |
| max_rel_floor1=max_rel, |
| status=status, |
| ) |
|
|
|
|
| def write_markdown(path: Path, results: list[Result], args) -> None: |
| lines = [ |
| "# Benchmark Results: flashrt-fp8-ffn", |
| "", |
| f"- Backend: `{args.backend}`", |
| f"- Device: `{torch.cuda.get_device_name(0)}`", |
| f"- Torch: `{torch.__version__}`", |
| f"- Warmup/iters: `{args.warmup}/{args.iters}`", |
| f"- Precision gate: p99_abs <= `{args.p99_abs_limit}` and " |
| f"p99_rel_floor1 <= `{args.p99_rel_floor1_limit}`", |
| "- Compile baseline: reported only when compiled reference output " |
| "matches eager reference output.", |
| "", |
| "| Shape | M,K,H,N | Layers | FlashRT us | Eager us | vs eager | Compile us | vs compile | Compile status | P99 abs | P99 rel | Max abs | Status |", |
| "|---|---:|---:|---:|---:|---:|---:|---:|---|---:|---:|---:|---:|", |
| ] |
| for r in results: |
| compile_us = f"{r.torch_compile_us:.3f}" if r.torch_compile_us is not None else "n/a" |
| compile_speedup = f"{r.speedup_vs_compile:.2f}x" if r.speedup_vs_compile is not None else "n/a" |
| lines.append( |
| f"| {r.shape} | {r.M},{r.K},{r.H},{r.N} | {r.layers} | " |
| f"{r.flashrt_us:.3f} | {r.torch_eager_us:.3f} | {r.speedup_vs_eager:.2f}x | " |
| f"{compile_us} | {compile_speedup} | {r.compile_status} | {r.p99_abs:.4f} | " |
| f"{r.p99_rel_floor1:.6f} | {r.max_abs:.4f} | {r.status} |" |
| ) |
| lines.append("") |
| path.write_text("\n".join(lines), encoding="utf-8") |
|
|
|
|
| def main() -> None: |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--backend", choices=["source", "installed", "hub"], default="source") |
| parser.add_argument("--artifact", default=None) |
| parser.add_argument("--repo-id", default="flashrt/flashrt-fp8-ffn") |
| parser.add_argument("--version", type=int, default=1) |
| parser.add_argument("--shapes", default="all") |
| parser.add_argument("--warmup", type=int, default=5) |
| parser.add_argument("--iters", type=int, default=20) |
| parser.add_argument("--compile-baseline", action="store_true") |
| parser.add_argument("--p99-abs-limit", type=float, default=1.0) |
| parser.add_argument("--p99-rel-floor1-limit", type=float, default=0.05) |
| parser.add_argument("--output", type=Path, default=None) |
| parser.add_argument("--markdown", type=Path, default=None) |
| parser.add_argument("--list-shapes", action="store_true") |
| args = parser.parse_args() |
|
|
| if args.list_shapes: |
| print("Shape groups:") |
| for group, names in SHAPE_GROUPS.items(): |
| print(f" {group}: {','.join(names)}") |
| print("\nShapes:") |
| for name, shape in SHAPES.items(): |
| print(f" {name}: M,K,H,N,layers={shape}") |
| return |
|
|
| if not torch.cuda.is_available(): |
| raise SystemExit("CUDA is required") |
| torch.manual_seed(17) |
| if args.backend == "source": |
| ops = load_source_ops() |
| elif args.backend == "installed": |
| ops = load_installed_ops(args.artifact) |
| else: |
| ops = load_hub_ops(args.repo_id, args.version) |
| requested = [s.strip() for s in args.shapes.split(",")] |
| names: list[str] = [] |
| for item in requested: |
| if item in SHAPE_GROUPS: |
| names.extend(SHAPE_GROUPS[item]) |
| else: |
| names.append(item) |
| unknown = [name for name in names if name not in SHAPES] |
| if unknown: |
| raise SystemExit(f"unknown shapes/groups: {unknown}") |
|
|
| results = [] |
| for name in names: |
| results.append(run_shape(ops, name, SHAPES[name], args)) |
| torch.cuda.empty_cache() |
|
|
| for r in results: |
| compile_part = ( |
| f", compile={r.torch_compile_us:.3f}us, vs_compile={r.speedup_vs_compile:.2f}x" |
| if r.torch_compile_us is not None |
| else f", compile={r.compile_status}" |
| ) |
| print( |
| f"{r.shape}: flashrt={r.flashrt_us:.3f}us, eager={r.torch_eager_us:.3f}us, " |
| f"vs_eager={r.speedup_vs_eager:.2f}x{compile_part}, " |
| f"p99_abs={r.p99_abs:.4f}, max_abs={r.max_abs:.4f}, {r.status}" |
| ) |
|
|
| payload = { |
| "backend": args.backend, |
| "torch": torch.__version__, |
| "device": torch.cuda.get_device_name(0), |
| "results": [asdict(r) for r in results], |
| } |
| if args.output: |
| args.output.parent.mkdir(parents=True, exist_ok=True) |
| args.output.write_text(json.dumps(payload, indent=2), encoding="utf-8") |
| if args.markdown: |
| args.markdown.parent.mkdir(parents=True, exist_ok=True) |
| write_markdown(args.markdown, results, args) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|