hydra / benchmarks /benchmark_hydra_decode.py
Frosty40's picture
Publish Hydra kernel source packet
7298fd0 verified
from __future__ import annotations
import argparse
import importlib
import sys
import time
from pathlib import Path
import torch
from kernels import get_local_kernel
def load_local_kernel(repo: Path):
for variant in (repo / "build", repo):
if (variant / "metadata.json").exists():
return get_local_kernel(variant)
sys.path.insert(0, str(repo / "torch-ext"))
return importlib.import_module("hydra")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Hydra decode microbenchmark")
parser.add_argument("--repo", default=".", help="Path to the hydra kernel directory")
parser.add_argument("--batch", type=int, default=1)
parser.add_argument("--heads", type=int, default=32)
parser.add_argument("--kv-heads", type=int, default=8)
parser.add_argument("--tokens", type=int, default=8192)
parser.add_argument("--head-dim", type=int, default=128)
parser.add_argument("--iters", type=int, default=100)
parser.add_argument("--warmup", type=int, default=10)
parser.add_argument("--window", type=int, default=0)
return parser.parse_args()
def main() -> None:
args = parse_args()
if not torch.cuda.is_available():
raise SystemExit("CUDA is required")
kernel = load_local_kernel(Path(args.repo))
q = torch.randn(
args.batch,
args.heads,
1,
args.head_dim,
device="cuda",
dtype=torch.bfloat16,
)
k = torch.randn(
args.batch,
args.kv_heads,
args.tokens,
args.head_dim,
device="cuda",
dtype=torch.bfloat16,
)
v = torch.randn_like(k)
window = None if args.window <= 0 else args.window
for _ in range(args.warmup):
kernel.hydra(q, k, v, sliding_window=window)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(args.iters):
kernel.hydra(q, k, v, sliding_window=window)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
ms = elapsed * 1000.0 / args.iters
print(
"hydra_decode "
f"B={args.batch} H={args.heads} Hkv={args.kv_heads} "
f"Tkv={args.tokens} D={args.head_dim} window={window or 0} "
f"iters={args.iters} ms_per_iter={ms:.4f}"
)
if __name__ == "__main__":
main()