liangsu9988's picture
Uploaded using `kernel-builder`.
2e95613 verified
Raw
History Blame
2.85 kB
#!/usr/bin/env python3
"""Benchmark causal-conv1d-state."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import sys
import torch
TESTS = Path(__file__).resolve().parents[1] / "tests"
sys.path.insert(0, str(TESTS))
from test_causal_conv1d_state import MODES, SHAPES, load_installed_ops, load_source_ops, make_inputs, ref_chunk # noqa: E402
def time_cuda(fn, warmup: int, iters: int) -> float:
for _ in range(warmup):
fn()
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(iters):
fn()
end.record()
torch.cuda.synchronize()
return float(start.elapsed_time(end) * 1000.0 / iters)
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument("--backend", choices=["source", "installed"], default="source")
parser.add_argument("--artifact", default=None)
parser.add_argument("--mode", choices=sorted(MODES), default="headline")
parser.add_argument("--warmup", type=int, default=20)
parser.add_argument("--iters", type=int, default=200)
parser.add_argument("--json-out", default=None)
args = parser.parse_args()
ops = load_source_ops() if args.backend == "source" else load_installed_ops(args.artifact)
rows = []
for name in MODES[args.mode]:
kind, B, S, C, K = SHAPES[name]
x, w, bias, state = make_inputs(B, S, C, K, seed=9900 + C + S)
if kind in {"chunk", "parallel", "gqa"}:
if kind == "gqa":
def kernel_call():
state_work = state.clone()
return ops.gqa(x, w, state_work, bias)
elif kind == "parallel":
def kernel_call():
state_work = state.clone()
return ops.parallel(x, w, state_work, bias)
else:
def kernel_call():
state_work = state.clone()
return ops.chunk(x, w, state_work, bias)
def ref_call():
return ref_chunk(x, w, bias, state)
else:
continue
kernel_us = time_cuda(kernel_call, args.warmup, args.iters)
ref_us = time_cuda(ref_call, max(2, args.warmup // 5), max(5, args.iters // 20))
rows.append({"shape": name, "kind": kind, "B": B, "S": S, "C": C, "K": K, "kernel_us": kernel_us, "torch_reference_us": ref_us, "speedup": ref_us / kernel_us})
print(f"{name}: kernel={kernel_us:.3f}us ref={ref_us:.3f}us speedup={ref_us / kernel_us:.2f}x")
if args.json_out:
Path(args.json_out).parent.mkdir(parents=True, exist_ok=True)
Path(args.json_out).write_text(json.dumps(rows, indent=2) + "\n")
return 0
if __name__ == "__main__":
raise SystemExit(main())