| |
| """Benchmark flashrt-spatiotemporal-layout against PyTorch eager references.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import ctypes |
| import ctypes.util |
| import json |
| import os |
| import sys |
| from dataclasses import asdict, dataclass |
| from pathlib import Path |
|
|
| import torch |
|
|
|
|
# Directory two levels above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[2]
# Source tree of the kernel package being benchmarked.
PACKAGE = ROOT.joinpath("flashrt-spatiotemporal-layout")
# kernel-builder registration include directory, located under a "kernels"
# directory that sits next to ROOT.
REGISTRATION_INCLUDE = ROOT.parent.joinpath(
    "kernels",
    "kernel-builder",
    "src",
    "pyproject",
    "templates",
    "torch",
)
|
|
# Benchmark tensor sizes keyed by label; run_shape unpacks each tuple as
# (b, c, t, h, w).
SHAPES = dict(
    small=(1, 8, 4, 8, 8),
    latent_16=(1, 16, 8, 32, 32),
    latent_64=(1, 64, 4, 32, 32),
)
# Named selections of SHAPES labels; the keys are the values accepted by the
# --shapes command-line option.
SHAPE_GROUPS = dict(
    smoke=["small"],
    headline=["latent_16", "latent_64"],
    all=[*SHAPES],
)
|
|
|
|
@dataclass
class Result:
    """One benchmark row: timings of a single kernel on a single named shape."""

    # Label of the benchmarked shape.
    shape: str
    # Name of the kernel that was timed.
    kernel: str
    # Printable form of the input tensor's dimensions.
    tensor_shape: str
    # Time per call of the FlashRT kernel, in microseconds.
    flashrt_us: float
    # Time per call of the PyTorch eager reference, in microseconds.
    torch_eager_us: float
    # Eager time divided by FlashRT time.
    speedup_vs_eager: float
    # Verification status text shown in the printed and written reports.
    verified: str
|
|
|
|
class SourceOps:
    """Adapter exposing the ops of one ``torch.ops`` namespace as methods.

    Every method forwards its arguments unchanged to the op of the same name
    and then returns one of the tensors it was given, so the result can be
    used directly by the caller.
    """

    def __init__(self, namespace: str) -> None:
        """Look up ``namespace`` on ``torch.ops`` and keep it for later calls."""
        self._ops = getattr(torch.ops, namespace)

    def ncdhw_to_blc_bf16(self, x, out):
        """Call the ``ncdhw_to_blc_bf16`` op with ``(x, out)`` and return ``out``."""
        kernel = self._ops.ncdhw_to_blc_bf16
        kernel(x, out)
        return out

    def time_unshuffle2_bf16(self, x, out):
        """Call the ``time_unshuffle2_bf16`` op with ``(x, out)`` and return ``out``."""
        kernel = self._ops.time_unshuffle2_bf16
        kernel(x, out)
        return out

    def add_bias_ncdhw_bf16(self, x, bias):
        """Call the ``add_bias_ncdhw_bf16`` op with ``(x, bias)`` and return ``x``.

        NOTE(review): returning ``x`` suggests the op updates it in place --
        confirm against the kernel.
        """
        kernel = self._ops.add_bias_ncdhw_bf16
        kernel(x, bias)
        return x

    def update_cache2_ncdhw_bf16(self, cur, prev, out):
        """Call the ``update_cache2_ncdhw_bf16`` op with ``(cur, prev, out)`` and return ``out``."""
        kernel = self._ops.update_cache2_ncdhw_bf16
        kernel(cur, prev, out)
        return out
|
|
|
|
def _preload_cublaslt() -> None:
    """Load the cuBLASLt shared library into the process with ``RTLD_GLOBAL``.

    Each ancestor directory of the ``torch`` package file is checked, nearest
    first, for ``nvidia/cublas/lib/libcublasLt.so.12``; the first one that
    exists is loaded. If none exists, the name reported by
    ``ctypes.util.find_library("cublasLt")`` is loaded instead. Nothing
    happens when neither lookup finds a library.

    NOTE(review): presumably this lets the extension built afterwards resolve
    cuBLASLt symbols -- confirm.
    """
    relative = Path("nvidia", "cublas", "lib", "libcublasLt.so.12")
    ancestors = Path(torch.__file__).resolve().parents
    candidates = (ancestor / relative for ancestor in ancestors)
    bundled = next((path for path in candidates if path.exists()), None)
    if bundled is not None:
        ctypes.CDLL(str(bundled), mode=ctypes.RTLD_GLOBAL)
        return

    system_library = ctypes.util.find_library("cublasLt")
    if system_library:
        ctypes.CDLL(system_library, mode=ctypes.RTLD_GLOBAL)
|
|
|
|
def _current_arch_list() -> str:
    """Return the compute capability of CUDA device 0 as ``"major.minor"``."""
    capability = torch.cuda.get_device_capability(0)
    return ".".join(str(part) for part in capability)
|
|
|
|
def load_source_ops() -> SourceOps:
    """Build the kernel extension from source and wrap its ops.

    The extension is compiled with ``torch.utils.cpp_extension.load`` from
    the package's binding and CUDA source files. ``TORCH_CUDA_ARCH_LIST`` is
    set to the capability of device 0 unless the environment already defines
    it.

    Returns:
        A ``SourceOps`` bound to the namespace the extension was built under.

    Raises:
        RuntimeError: If the kernel-builder registration include directory
            does not exist.
    """
    from torch.utils.cpp_extension import load as build_extension

    if not REGISTRATION_INCLUDE.is_dir():
        raise RuntimeError(f"missing kernel-builder registration include: {REGISTRATION_INCLUDE}")

    _preload_cublaslt()
    detected_arch = _current_arch_list()
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", detected_arch)

    namespace = "flashrt_spatiotemporal_layout_benchmark"
    csrc_dir = PACKAGE / "csrc"
    source_files = [
        PACKAGE / "torch-ext" / "torch_binding.cpp",
        csrc_dir / "spatiotemporal_layout.cu",
    ]
    include_dirs = [csrc_dir, REGISTRATION_INCLUDE]
    host_flags = ["-O3", "-DCUDA_KERNEL"]
    device_flags = ["-O3", "--expt-relaxed-constexpr", "-DCUDA_KERNEL"]

    build_extension(
        name=namespace,
        sources=[str(path) for path in source_files],
        extra_include_paths=[str(path) for path in include_dirs],
        extra_cflags=host_flags,
        extra_cuda_cflags=device_flags,
        verbose=False,
    )
    return SourceOps(namespace)
|
|
|
|
def load_installed_ops(artifact: str | None):
    """Import and return the ``flashrt_spatiotemporal_layout`` module.

    Args:
        artifact: Optional directory to search first. When given, it is put
            at the front of ``sys.path`` for the duration of the import and
            taken out again afterwards, whether or not the import succeeds.

    Returns:
        The imported module.

    Raises:
        ImportError: If the module cannot be imported.
    """
    # Imported here because the module's import block does not provide
    # importlib; without this the call below fails with NameError.
    import importlib

    if artifact:
        sys.path.insert(0, artifact)
    try:
        return importlib.import_module("flashrt_spatiotemporal_layout")
    finally:
        if artifact:
            try:
                sys.path.remove(artifact)
            except ValueError:
                # The entry is already gone (e.g. the imported code edited
                # sys.path); do not let cleanup mask the import outcome.
                pass
|
|
|
|
def time_us(fn, warmup, iters):
    """Return the average time of one ``fn()`` call in microseconds.

    ``fn`` is first called ``warmup`` times without being measured. It is
    then called ``iters`` times between two CUDA timing events, and the
    elapsed time between the events is divided by ``iters``.

    Args:
        fn: Zero-argument callable to measure.
        warmup: Number of unmeasured calls made before timing starts.
        iters: Number of measured calls; must be positive.

    Returns:
        Elapsed time per call, in microseconds.

    Raises:
        ValueError: If ``iters`` is not positive.
    """
    # Reject an empty measurement up front instead of failing with a
    # ZeroDivisionError (or a meaningless negative time) after the warm-up.
    if iters <= 0:
        raise ValueError(f"iters must be a positive integer, got {iters}")

    for _ in range(warmup):
        fn()
    # Make sure warm-up work has finished before the start event is recorded.
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    # Wait for the end event to complete so the elapsed time can be read.
    torch.cuda.synchronize()

    # elapsed_time reports milliseconds; convert to microseconds per call.
    return start.elapsed_time(end) * 1000.0 / iters
|
|
|
|
def _matches_reference(run_kernel, buffer, expected, exact=True):
    """Run a kernel once and report ``"yes"``/``"no"`` for agreement with ``expected``.

    The tensor returned by ``run_kernel`` is compared when there is one;
    otherwise ``buffer`` (the tensor the kernel was asked to write) is used.
    """
    returned = run_kernel()
    actual = returned if isinstance(returned, torch.Tensor) else buffer
    if actual.shape != expected.shape:
        return "no"
    if exact:
        same = torch.equal(actual, expected)
    else:
        # An arithmetic kernel may round differently from the fp32 eager
        # reference, so allow an error on the order of bf16 precision.
        same = torch.allclose(actual.float(), expected.float(), rtol=1.6e-2, atol=1e-2)
    return "yes" if same else "no"


def run_shape(ops, name, shape, args):
    """Benchmark the four layout kernels on one shape against eager PyTorch.

    For every kernel the output is first compared with the eager reference
    expression (so ``verified`` reflects an actual check rather than a fixed
    label), then the kernel and the reference are timed with ``time_us``.

    Args:
        ops: Object providing ``ncdhw_to_blc_bf16``, ``time_unshuffle2_bf16``,
            ``add_bias_ncdhw_bf16`` and ``update_cache2_ncdhw_bf16``.
        name: Label of the shape, copied into every returned row.
        shape: ``(b, c, t, h, w)`` sizes of the main input tensor.
        args: Object with ``warmup`` and ``iters`` attributes passed on to
            ``time_us``.

    Returns:
        A list with one ``Result`` per kernel, in the order listed above.
    """
    b, c, t, h, w = shape
    bf16_cuda = {"device": "cuda", "dtype": torch.bfloat16}
    # Random inputs are drawn in a fixed order so a seeded run is repeatable.
    x = torch.randn(shape, **bf16_cuda)
    x2 = torch.randn((b, 2 * c, t, h, w), **bf16_cuda)
    bias = torch.randn((c,), **bf16_cuda)
    prev = torch.randn((b, c, 2, h, w), **bf16_cuda)
    out_blc = torch.empty((b, t * h * w, c), **bf16_cuda)
    out_unshuffle = torch.empty((b, c, 2 * t, h, w), **bf16_cuda)
    out_cache = torch.empty((b, c, 2, h, w), **bf16_cuda)
    # The bias kernel is handed its own copies so ``x`` stays intact for the
    # other kernels and for the eager references.
    x_bias = x.clone()
    x_bias_check = x.clone()

    rows = []

    def bench(kernel, sized_like, flashrt_fn, eager_fn, check_buffer, check_fn=None, exact=True):
        """Verify one kernel, time it and its reference, and record the row."""
        verified = _matches_reference(check_fn or flashrt_fn, check_buffer, eager_fn(), exact)
        flashrt_us = time_us(flashrt_fn, args.warmup, args.iters)
        eager_us = time_us(eager_fn, args.warmup, args.iters)
        rows.append(
            Result(
                name,
                kernel,
                str(tuple(sized_like.shape)),
                flashrt_us,
                eager_us,
                eager_us / flashrt_us,
                verified,
            )
        )

    bench(
        "ncdhw_to_blc_bf16",
        x,
        lambda: ops.ncdhw_to_blc_bf16(x, out_blc),
        lambda: x.permute(0, 2, 3, 4, 1).contiguous().view(b, t * h * w, c),
        out_blc,
    )
    bench(
        "time_unshuffle2_bf16",
        x2,
        lambda: ops.time_unshuffle2_bf16(x2, out_unshuffle),
        lambda: torch.stack((x2[:, :c], x2[:, c:]), dim=3).flatten(2, 3),
        out_unshuffle,
    )
    bench(
        "add_bias_ncdhw_bf16",
        x,
        lambda: ops.add_bias_ncdhw_bf16(x_bias, bias),
        lambda: (x.float() + bias.float().view(1, c, 1, 1, 1)).to(torch.bfloat16),
        x_bias_check,
        # Checked on a fresh copy with a single call: the timed buffer
        # receives the bias once per iteration and drifts from the reference.
        check_fn=lambda: ops.add_bias_ncdhw_bf16(x_bias_check, bias),
        exact=False,
    )
    bench(
        "update_cache2_ncdhw_bf16",
        x,
        lambda: ops.update_cache2_ncdhw_bf16(x, prev, out_cache),
        lambda: x[:, :, -2:, :, :].contiguous(),
        out_cache,
    )
    return rows
|
|
|
|
def write_markdown(path: Path, results: list[Result], environment: str | None = None) -> None:
    """Write the benchmark rows to ``path`` as a Markdown report.

    Args:
        path: Destination file; existing content is replaced.
        results: Rows to render, one table line each, in the given order.
        environment: Text placed after ``Environment:`` in the report. When
            omitted, the name of the current CUDA device (or
            ``unknown device`` if CUDA is unavailable) is used, followed by
            ``local source-extension build.``.
    """
    if environment is None:
        # Report the device the run actually used rather than a fixed GPU name.
        device = torch.cuda.get_device_name() if torch.cuda.is_available() else "unknown device"
        environment = f"{device} local source-extension build."

    header = [
        "# Source Benchmark Results",
        "",
        f"Environment: {environment}",
        "Baseline: PyTorch eager tensor layout/reference operations.",
        "",
        "| Shape | Tensor | Kernel | FlashRT us | Eager us | vs eager | Verified |",
        "|---|---:|---|---:|---:|---:|---|",
    ]
    table_rows = [
        f"| {r.shape} | `{r.tensor_shape}` | {r.kernel} | {r.flashrt_us:.3f} | "
        f"{r.torch_eager_us:.3f} | {r.speedup_vs_eager:.2f}x | {r.verified} |"
        for r in results
    ]
    path.write_text("\n".join(header + table_rows) + "\n", encoding="utf-8")
|
|
|
|
def main():
    """Parse the command line, run the selected benchmarks and report them.

    Exits with "CUDA is required" when no CUDA device is available. Every
    result is printed; the results are additionally written as JSON when
    ``--output`` is given and as Markdown when ``--markdown`` is given.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--backend", choices=["source", "installed"], default="source")
    cli.add_argument("--artifact", default=None)
    cli.add_argument("--shapes", choices=sorted(SHAPE_GROUPS), default="smoke")
    cli.add_argument("--warmup", type=int, default=5)
    cli.add_argument("--iters", type=int, default=20)
    cli.add_argument("--output", default=None)
    cli.add_argument("--markdown", default=None)
    args = cli.parse_args()

    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required")
    # Fixed seed so the random benchmark inputs are the same on every run.
    torch.manual_seed(61)

    if args.backend == "source":
        ops = load_source_ops()
    else:
        ops = load_installed_ops(args.artifact)

    results = []
    for shape_name in SHAPE_GROUPS[args.shapes]:
        results += run_shape(ops, shape_name, SHAPES[shape_name], args)

    for row in results:
        print(
            f"{row.verified} {row.shape}/{row.kernel}: flashrt={row.flashrt_us:.3f}us "
            f"eager={row.torch_eager_us:.3f}us speedup={row.speedup_vs_eager:.2f}x"
        )

    if args.output:
        json_path = Path(args.output)
        json_path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps([asdict(row) for row in results], indent=2)
        json_path.write_text(payload + "\n")

    if args.markdown:
        markdown_path = Path(args.markdown)
        markdown_path.parent.mkdir(parents=True, exist_ok=True)
        write_markdown(markdown_path, results)


if __name__ == "__main__":
    main()
|
|