liangsu9988's picture
Uploaded using `kernel-builder`.
9d82137 verified
Raw
History Blame
7.88 kB
#!/usr/bin/env python3
"""Benchmark flashrt-spatiotemporal-layout against PyTorch eager references."""
from __future__ import annotations
import argparse
import ctypes
import ctypes.util
import json
import os
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
import torch
# Third ancestor directory of this script; the other paths are derived from it.
ROOT = Path(__file__).resolve().parents[2]

# Directory of the kernel package whose sources load_source_ops() compiles.
PACKAGE = ROOT.joinpath("flashrt-spatiotemporal-layout")

# Extra include directory handed to the source build. It lives in a sibling
# "kernels/kernel-builder" tree next to ROOT, not inside ROOT itself.
REGISTRATION_INCLUDE = ROOT.parent.joinpath(
    "kernels",
    "kernel-builder",
    "src",
    "pyproject",
    "templates",
    "torch",
)
# Named benchmark input shapes. run_shape() unpacks each tuple as (b, c, t, h, w).
SHAPES = dict(
    small=(1, 8, 4, 8, 8),
    latent_16=(1, 16, 8, 32, 32),
    latent_64=(1, 64, 4, 32, 32),
)

# Values accepted by --shapes, each mapped to the SHAPES keys it runs.
SHAPE_GROUPS = {
    "smoke": ["small"],
    "headline": ["latent_16", "latent_64"],
    "all": [*SHAPES],
}
@dataclass
class Result:
    """One benchmark row: a single kernel measured on a single named shape."""

    # Key of the shape in SHAPES (e.g. "small").
    shape: str
    # Name of the benchmarked kernel.
    kernel: str
    # String form of the input tensor's shape tuple.
    tensor_shape: str
    # Mean time per call of the FlashRT kernel, in microseconds.
    flashrt_us: float
    # Mean time per call of the PyTorch eager reference, in microseconds.
    torch_eager_us: float
    # torch_eager_us divided by flashrt_us.
    speedup_vs_eager: float
    # Verification status label as shown in the reports (e.g. "yes").
    verified: str
class SourceOps:
    """Adapter exposing the ops of one ``torch.ops`` namespace as methods.

    Every method forwards its arguments unchanged to the op of the same name
    and then returns the tensor argument that is handed back to the caller
    (``out`` for the out-parameter ops, ``x`` for ``add_bias_ncdhw_bf16``).
    """

    def __init__(self, namespace: str) -> None:
        """Look up ``torch.ops.<namespace>`` and keep it for later dispatch."""
        self._ops = getattr(torch.ops, namespace)

    def ncdhw_to_blc_bf16(self, x, out):
        """Call the ``ncdhw_to_blc_bf16`` op with ``(x, out)``; return ``out``."""
        kernel = self._ops.ncdhw_to_blc_bf16
        kernel(x, out)
        return out

    def time_unshuffle2_bf16(self, x, out):
        """Call the ``time_unshuffle2_bf16`` op with ``(x, out)``; return ``out``."""
        kernel = self._ops.time_unshuffle2_bf16
        kernel(x, out)
        return out

    def add_bias_ncdhw_bf16(self, x, bias):
        """Call the ``add_bias_ncdhw_bf16`` op with ``(x, bias)``; return ``x``."""
        kernel = self._ops.add_bias_ncdhw_bf16
        kernel(x, bias)
        return x

    def update_cache2_ncdhw_bf16(self, cur, prev, out):
        """Call the ``update_cache2_ncdhw_bf16`` op with ``(cur, prev, out)``; return ``out``."""
        kernel = self._ops.update_cache2_ncdhw_bf16
        kernel(cur, prev, out)
        return out
def _preload_cublaslt() -> None:
    """Load ``libcublasLt`` into the process with ``RTLD_GLOBAL``, if it can be found.

    A copy at ``nvidia/cublas/lib/libcublasLt.so.12`` under any ancestor
    directory of the installed ``torch`` package is preferred (nearest
    ancestor first). Otherwise the library located by
    ``ctypes.util.find_library`` is loaded. If neither exists the function
    returns without doing anything.
    """
    relative = Path("nvidia") / "cublas" / "lib" / "libcublasLt.so.12"
    ancestors = Path(torch.__file__).resolve().parents
    bundled = next(
        (path for path in (ancestor / relative for ancestor in ancestors) if path.exists()),
        None,
    )
    if bundled is not None:
        ctypes.CDLL(str(bundled), mode=ctypes.RTLD_GLOBAL)
        return

    system_library = ctypes.util.find_library("cublasLt")
    if system_library:
        ctypes.CDLL(system_library, mode=ctypes.RTLD_GLOBAL)
def _current_arch_list() -> str:
    """Return the compute capability of CUDA device 0 formatted as ``"major.minor"``."""
    capability = torch.cuda.get_device_capability(0)
    return "{}.{}".format(*capability)
def load_source_ops() -> SourceOps:
    """Build the kernel sources with ``torch.utils.cpp_extension.load`` and wrap the ops.

    Returns:
        A ``SourceOps`` bound to the namespace used as the extension name.
        NOTE(review): this presumes the binding registers its ops under that
        same name - confirm against ``torch_binding.cpp``.

    Raises:
        RuntimeError: if the ``REGISTRATION_INCLUDE`` directory does not exist.
    """
    from torch.utils.cpp_extension import load as build_extension

    if not REGISTRATION_INCLUDE.is_dir():
        raise RuntimeError(f"missing kernel-builder registration include: {REGISTRATION_INCLUDE}")

    _preload_cublaslt()
    # Default the build to the local GPU's architecture unless the caller
    # already exported TORCH_CUDA_ARCH_LIST.
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", _current_arch_list())

    namespace = "flashrt_spatiotemporal_layout_benchmark"
    csrc_dir = PACKAGE / "csrc"
    source_files = (
        PACKAGE / "torch-ext" / "torch_binding.cpp",
        csrc_dir / "spatiotemporal_layout.cu",
    )
    include_dirs = (csrc_dir, REGISTRATION_INCLUDE)

    build_extension(
        name=namespace,
        sources=[str(source) for source in source_files],
        extra_include_paths=[str(directory) for directory in include_dirs],
        extra_cflags=["-O3", "-DCUDA_KERNEL"],
        extra_cuda_cflags=["-O3", "--expt-relaxed-constexpr", "-DCUDA_KERNEL"],
        verbose=False,
    )
    return SourceOps(namespace)
def load_installed_ops(artifact: str | None):
    """Import and return the installed ``flashrt_spatiotemporal_layout`` module.

    Args:
        artifact: Optional directory to put at the front of ``sys.path`` for
            the duration of the import. When falsy, the module is resolved
            from the existing import path.

    Returns:
        The imported ``flashrt_spatiotemporal_layout`` module.

    Raises:
        ImportError: if the module cannot be imported.
    """
    # Imported here because the file's top-level import block does not bring
    # in importlib; without this the function fails with NameError.
    import importlib

    if artifact:
        sys.path.insert(0, artifact)
    try:
        return importlib.import_module("flashrt_spatiotemporal_layout")
    finally:
        # Undo the temporary path entry. The membership test keeps the cleanup
        # from raising (and masking the import outcome) if the entry is gone.
        if artifact and artifact in sys.path:
            sys.path.remove(artifact)
def time_us(fn, warmup, iters):
    """Return the mean duration of one ``fn()`` call in microseconds.

    ``fn`` is first called ``warmup`` times untimed. It is then called
    ``iters`` times between two CUDA events, and the elapsed event time is
    averaged over those calls.

    Args:
        fn: Zero-argument callable to measure.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls; must be positive.

    Returns:
        Mean time per call in microseconds, as a float.

    Raises:
        ValueError: if ``iters`` is not positive.
    """
    # Reject a non-positive count up front instead of failing with
    # ZeroDivisionError (or returning a meaningless value) after the GPU work.
    if iters <= 0:
        raise ValueError(f"iters must be a positive integer, got {iters}")

    for _ in range(warmup):
        fn()
    # Drain the warmup work so it is not attributed to the timed region.
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    # The end event must have completed before elapsed_time can be read.
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds; scale to microseconds per call.
    return start.elapsed_time(end) * 1000.0 / iters
def _op_output(returned, fallback):
    """Return the op's result tensor: its return value if that is a tensor, else ``fallback``."""
    return returned if isinstance(returned, torch.Tensor) else fallback


def _verified_label(got, expected, exact=True):
    """Return ``"yes"`` if ``got`` matches ``expected`` (bit-exact or within tolerance), else ``"no"``."""
    if got.shape != expected.shape:
        return "no"
    if exact:
        matches = torch.equal(got, expected)
    else:
        # Compare in float32 with a tolerance of a couple of bf16 ulps.
        matches = torch.allclose(got.float(), expected.float(), rtol=2e-2, atol=1e-2)
    return "yes" if matches else "no"


def _bench_row(name, kernel, tensor, flashrt_fn, eager_fn, verified, args):
    """Time ``flashrt_fn`` and then ``eager_fn``, and build the ``Result`` row for ``kernel``."""
    flashrt_us = time_us(flashrt_fn, args.warmup, args.iters)
    eager_us = time_us(eager_fn, args.warmup, args.iters)
    return Result(
        name,
        kernel,
        str(tuple(tensor.shape)),
        flashrt_us,
        eager_us,
        eager_us / flashrt_us,
        verified,
    )


def run_shape(ops, name, shape, args):
    """Benchmark and verify the four layout kernels on one input shape.

    For each kernel, the output is first compared against the PyTorch eager
    expression it is benchmarked against, and the outcome is stored in the
    row's ``verified`` field ("yes" or "no"). The kernel and the eager
    expression are then timed with ``time_us``.

    Args:
        ops: Object exposing ``ncdhw_to_blc_bf16``, ``time_unshuffle2_bf16``,
            ``add_bias_ncdhw_bf16`` and ``update_cache2_ncdhw_bf16``.
        name: Label of the shape, stored in each row.
        shape: Tuple unpacked as ``(b, c, t, h, w)``.
        args: Namespace providing ``warmup`` and ``iters``.

    Returns:
        A list of four ``Result`` rows, one per kernel, in the order
        ncdhw_to_blc, time_unshuffle2, add_bias, update_cache2.
    """
    b, c, t, h, w = shape
    bf16_cuda = {"device": "cuda", "dtype": torch.bfloat16}

    # Random inputs are drawn in a fixed order so a seeded run is reproducible.
    x = torch.randn(shape, **bf16_cuda)
    x2 = torch.randn((b, 2 * c, t, h, w), **bf16_cuda)
    bias = torch.randn((c,), **bf16_cuda)
    prev = torch.randn((b, c, 2, h, w), **bf16_cuda)
    out_blc = torch.empty((b, t * h * w, c), **bf16_cuda)
    out_unshuffle = torch.empty((b, c, 2 * t, h, w), **bf16_cuda)
    out_cache = torch.empty((b, c, 2, h, w), **bf16_cuda)
    # The bias kernel is timed on a private copy so x itself stays unmodified
    # for the eager baselines.
    x_bias = x.clone()

    def eager_blc():
        return x.permute(0, 2, 3, 4, 1).contiguous().view(b, t * h * w, c)

    def eager_unshuffle():
        return torch.stack((x2[:, :c], x2[:, c:]), dim=3).flatten(2, 3)

    def eager_bias():
        return (x.float() + bias.float().view(1, c, 1, 1, 1)).to(torch.bfloat16)

    def eager_cache():
        # NOTE(review): this reference never reads `prev`; it can only agree
        # with the kernel when cur holds at least two frames - confirm the
        # kernel's behaviour for t < 2.
        return x[:, :, -2:, :, :].contiguous()

    # Verify every kernel once before timing. The bias check runs on a fresh
    # clone because the timed loop keeps accumulating into x_bias.
    blc_ok = _verified_label(
        _op_output(ops.ncdhw_to_blc_bf16(x, out_blc), out_blc), eager_blc()
    )
    unshuffle_ok = _verified_label(
        _op_output(ops.time_unshuffle2_bf16(x2, out_unshuffle), out_unshuffle),
        eager_unshuffle(),
    )
    bias_probe = x.clone()
    bias_ok = _verified_label(
        _op_output(ops.add_bias_ncdhw_bf16(bias_probe, bias), bias_probe),
        eager_bias(),
        exact=False,
    )
    cache_ok = _verified_label(
        _op_output(ops.update_cache2_ncdhw_bf16(x, prev, out_cache), out_cache),
        eager_cache(),
    )

    return [
        _bench_row(
            name,
            "ncdhw_to_blc_bf16",
            x,
            lambda: ops.ncdhw_to_blc_bf16(x, out_blc),
            eager_blc,
            blc_ok,
            args,
        ),
        _bench_row(
            name,
            "time_unshuffle2_bf16",
            x2,
            lambda: ops.time_unshuffle2_bf16(x2, out_unshuffle),
            eager_unshuffle,
            unshuffle_ok,
            args,
        ),
        _bench_row(
            name,
            "add_bias_ncdhw_bf16",
            x,
            lambda: ops.add_bias_ncdhw_bf16(x_bias, bias),
            eager_bias,
            bias_ok,
            args,
        ),
        _bench_row(
            name,
            "update_cache2_ncdhw_bf16",
            x,
            lambda: ops.update_cache2_ncdhw_bf16(x, prev, out_cache),
            eager_cache,
            cache_ok,
            args,
        ),
    ]
def write_markdown(path: Path, results: list[Result], environment: str | None = None) -> None:
    """Write the benchmark rows to ``path`` as a Markdown report.

    Args:
        path: Destination file; it is overwritten.
        results: Rows to render, one table line each, in the given order.
        environment: Text placed after ``"Environment: "`` in the header.
            When omitted, it is built from the name of the current CUDA
            device (or "unknown device" if CUDA is unavailable) rather than
            a fixed GPU model, so the report matches the hardware used.
    """
    if environment is None:
        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name()
        else:
            device_name = "unknown device"
        environment = f"{device_name} local source-extension build."

    header = [
        "# Source Benchmark Results",
        "",
        f"Environment: {environment}",
        "Baseline: PyTorch eager tensor layout/reference operations.",
        "",
        "| Shape | Tensor | Kernel | FlashRT us | Eager us | vs eager | Verified |",
        "|---|---:|---|---:|---:|---:|---|",
    ]
    table_rows = [
        f"| {r.shape} | `{r.tensor_shape}` | {r.kernel} | {r.flashrt_us:.3f} | "
        f"{r.torch_eager_us:.3f} | {r.speedup_vs_eager:.2f}x | {r.verified} |"
        for r in results
    ]
    # Explicit encoding keeps the output independent of the platform locale.
    path.write_text("\n".join(header + table_rows) + "\n", encoding="utf-8")
def main():
    """Parse the command line, run the selected benchmarks and emit the results.

    Every row is printed to stdout. The rows are also written as JSON when
    ``--output`` is given and as a Markdown table when ``--markdown`` is
    given; missing parent directories are created for both.

    Raises:
        SystemExit: if CUDA is not available.
    """
    option_specs = (
        ("--backend", {"choices": ["source", "installed"], "default": "source"}),
        ("--artifact", {"default": None}),
        ("--shapes", {"choices": sorted(SHAPE_GROUPS), "default": "smoke"}),
        ("--warmup", {"type": int, "default": 5}),
        ("--iters", {"type": int, "default": 20}),
        ("--output", {"default": None}),
        ("--markdown", {"default": None}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required")
    # Fixed seed so the random inputs are the same on every run.
    torch.manual_seed(61)

    if args.backend == "source":
        ops = load_source_ops()
    else:
        ops = load_installed_ops(args.artifact)

    results = [
        row
        for shape_name in SHAPE_GROUPS[args.shapes]
        for row in run_shape(ops, shape_name, SHAPES[shape_name], args)
    ]

    for row in results:
        print(
            f"{row.verified} {row.shape}/{row.kernel}: flashrt={row.flashrt_us:.3f}us "
            f"eager={row.torch_eager_us:.3f}us speedup={row.speedup_vs_eager:.2f}x"
        )

    if args.output:
        json_path = Path(args.output)
        json_path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps([asdict(row) for row in results], indent=2)
        json_path.write_text(payload + "\n")

    if args.markdown:
        markdown_path = Path(args.markdown)
        markdown_path.parent.mkdir(parents=True, exist_ok=True)
        write_markdown(markdown_path, results)


if __name__ == "__main__":
    main()