| |
| """Benchmark causal-conv1d-state.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import sys |
| import torch |
|
|
| TESTS = Path(__file__).resolve().parents[1] / "tests" |
| sys.path.insert(0, str(TESTS)) |
| from test_causal_conv1d_state import MODES, SHAPES, load_installed_ops, load_source_ops, make_inputs, ref_chunk |
|
|
|
|
| def time_cuda(fn, warmup: int, iters: int) -> float: |
| for _ in range(warmup): |
| fn() |
| torch.cuda.synchronize() |
| start = torch.cuda.Event(enable_timing=True) |
| end = torch.cuda.Event(enable_timing=True) |
| start.record() |
| for _ in range(iters): |
| fn() |
| end.record() |
| torch.cuda.synchronize() |
| return float(start.elapsed_time(end) * 1000.0 / iters) |
|
|
|
|
| def main() -> int: |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--backend", choices=["source", "installed"], default="source") |
| parser.add_argument("--artifact", default=None) |
| parser.add_argument("--mode", choices=sorted(MODES), default="headline") |
| parser.add_argument("--warmup", type=int, default=20) |
| parser.add_argument("--iters", type=int, default=200) |
| parser.add_argument("--json-out", default=None) |
| args = parser.parse_args() |
| ops = load_source_ops() if args.backend == "source" else load_installed_ops(args.artifact) |
| rows = [] |
| for name in MODES[args.mode]: |
| kind, B, S, C, K = SHAPES[name] |
| x, w, bias, state = make_inputs(B, S, C, K, seed=9900 + C + S) |
| if kind in {"chunk", "parallel", "gqa"}: |
| if kind == "gqa": |
| def kernel_call(): |
| state_work = state.clone() |
| return ops.gqa(x, w, state_work, bias) |
| elif kind == "parallel": |
| def kernel_call(): |
| state_work = state.clone() |
| return ops.parallel(x, w, state_work, bias) |
| else: |
| def kernel_call(): |
| state_work = state.clone() |
| return ops.chunk(x, w, state_work, bias) |
|
|
| def ref_call(): |
| return ref_chunk(x, w, bias, state) |
| else: |
| continue |
| kernel_us = time_cuda(kernel_call, args.warmup, args.iters) |
| ref_us = time_cuda(ref_call, max(2, args.warmup // 5), max(5, args.iters // 20)) |
| rows.append({"shape": name, "kind": kind, "B": B, "S": S, "C": C, "K": K, "kernel_us": kernel_us, "torch_reference_us": ref_us, "speedup": ref_us / kernel_us}) |
| print(f"{name}: kernel={kernel_us:.3f}us ref={ref_us:.3f}us speedup={ref_us / kernel_us:.2f}x") |
| if args.json_out: |
| Path(args.json_out).parent.mkdir(parents=True, exist_ok=True) |
| Path(args.json_out).write_text(json.dumps(rows, indent=2) + "\n") |
| return 0 |
|
|
|
|
| if __name__ == "__main__": |
| raise SystemExit(main()) |
|
|