"""Hyena stack benchmark — measure TPS under the four knob combinations.

Produces the table requested in Task 4:

    | Config                     | TPS  | BPB@500 | VRAM |
    |----------------------------|------|---------|------|
    | B=8,  no flash, no cache   | ...  | ...     | ...  |  <-- baseline
    | B=16, no flash, no cache   | ...  |
    | B=16, no flash, cache on   | ...  |
    | B=16, flash on, cache on   | ...  | ...     | ...  |  <-- best

Run ONE config by invoking with command-line args, then collate externally.
Each invocation runs train.py for the specified wall-clock time with the
given env overrides, tails run.log, and emits a single summary line.

Invocation:

    cd /home/mikeb/work/feather

    # On the RTX 3060 (local validation only — these numbers will NOT hit
    # the 200k tps production floor):
    .venv/bin/python scripts/benchmark_hyena_stack.py --config baseline --time 300
    .venv/bin/python scripts/benchmark_hyena_stack.py --config b16 --time 300
    .venv/bin/python scripts/benchmark_hyena_stack.py --config cache --time 300
    # "kernel" config requires flashfftconv built — see kernels/cuda/flashfftconv/README.md
    .venv/bin/python scripts/benchmark_hyena_stack.py --config kernel --time 300

    # On A100/A10G (production cloud hardware), use time=900 (15 min) for
    # stable steady-state numbers.

After each run the script prints:

    BENCHMARK config= tps_steady= bpb_at_500= vram_peak=

Collate those lines into the matrix table manually, then pick the winner
for the 6-hour production run (HYDRA_TIME_BUDGET=21600).
"""
from __future__ import annotations

import argparse
import os
import re
import subprocess
import sys
from pathlib import Path

# Repository root: this file lives in <repo>/scripts/, so parents[1] is <repo>.
REPO = Path(__file__).resolve().parents[1]

# Env-var override sets for each benchmark config. Keys/values are strings
# because they are injected verbatim into the child process environment.
CONFIGS = {
    # Baseline: B=8, no flash, no train-cache. Current reference point.
    "baseline": {
        "HYDRA_BATCH_SIZE": "8",
        "HYDRA_HYENA_LAYERS": "3,7",
        "HYDRA_HYENA_FLASH_FFT": "0",
        "HYDRA_HYENA_TRAIN_CACHE": "0",
        "HYDRA_HYENA_FILTER_CACHE": "0",
    },
    "b16": {
        "HYDRA_BATCH_SIZE": "16",
        "HYDRA_HYENA_LAYERS": "3,7",
        "HYDRA_HYENA_FLASH_FFT": "0",
        "HYDRA_HYENA_TRAIN_CACHE": "0",
        "HYDRA_HYENA_FILTER_CACHE": "0",
    },
    "cache": {
        "HYDRA_BATCH_SIZE": "16",
        "HYDRA_HYENA_LAYERS": "3,7",
        "HYDRA_HYENA_FLASH_FFT": "0",
        "HYDRA_HYENA_TRAIN_CACHE": "1",
        "HYDRA_HYENA_FILTER_CACHE": "1",
    },
    "kernel": {
        "HYDRA_BATCH_SIZE": "16",
        "HYDRA_HYENA_LAYERS": "3,7",
        "HYDRA_HYENA_FLASH_FFT": "1",
        "HYDRA_HYENA_TRAIN_CACHE": "1",
        "HYDRA_HYENA_FILTER_CACHE": "1",
        # Task 4 note: also bump HYDRA_HTM_SUBSAMPLE to 128 (from 64) in the
        # best config to get more aggressive reclamation.
        "HYDRA_HTM_SUBSAMPLE": "128",
    },
}


def build_env(cfg_overrides: dict) -> dict:
    """Compose a full env dict from the inherited env + config overrides.

    The inherited environment is copied (never mutated in place), the Hyena
    layer selection is guaranteed present (empty string means "off"), and the
    config's overrides win over anything inherited.
    """
    env = os.environ.copy()
    # Ensure the Hyena layer selection is always present (defaults to off).
    env.setdefault("HYDRA_HYENA_LAYERS", "")
    for k, v in cfg_overrides.items():
        env[k] = v
    return env


def parse_step_line(line: str) -> dict | None:
    """Parse a single step=... line into a dict of metrics, or None.

    Lines not starting with "step=" are rejected outright. On a step line,
    every `key=numeric` pair is extracted; any value that fails float()
    conversion makes the whole line unparseable (returns None).
    """
    if not line.startswith("step="):
        return None
    parts = re.findall(r"(\w+)=([0-9.eE+\-]+)", line)
    try:
        return {k: float(v) for k, v in parts}
    except ValueError:
        return None


def summarize(log_path: Path, warmup_steps: int = 50) -> dict:
    """Tail log_path, compute steady-state TPS / BPB@500 / VRAM peak.

    Skips the first `warmup_steps` to discard CUDA graph capture / autotune
    spikes; takes the median of the rest.

    Returns a dict with keys: tps_steady (median post-warmup tps),
    bpb_at_500 (bpb logged exactly at step 500, else the last bpb seen,
    else 0.0), vram_peak (max vram value seen), steps (approximate step
    count: post-warmup samples + warmup_steps).
    """
    tps_vals = []
    bpbs = []
    vram_peak = 0.0
    bpb_at_500 = None
    with log_path.open() as f:
        for line in f:
            d = parse_step_line(line.strip())
            if d is None:
                continue
            step = int(d.get("step", -1))
            if step < warmup_steps:
                continue
            tps = d.get("tps")
            if tps is not None:
                tps_vals.append(tps)
            bpb = d.get("bpb")
            if bpb is not None:
                bpbs.append(bpb)
                if step == 500 and bpb_at_500 is None:
                    bpb_at_500 = bpb
            vram = d.get("vram")
            if vram is not None and vram > vram_peak:
                vram_peak = vram
    if not tps_vals:
        return {"tps_steady": 0.0, "bpb_at_500": 0.0, "vram_peak": 0.0, "steps": 0}
    tps_sorted = sorted(tps_vals)
    tps_steady = tps_sorted[len(tps_sorted) // 2]  # median
    # BUGFIX: use an explicit None test rather than `or`, so a legitimate
    # bpb of 0.0 logged at step 500 is not silently replaced by the last
    # bpb seen (0.0 is falsy).
    if bpb_at_500 is None:
        bpb_at_500 = bpbs[-1] if bpbs else 0.0
    return {
        "tps_steady": tps_steady,
        "bpb_at_500": bpb_at_500,
        "vram_peak": vram_peak,
        "steps": len(tps_vals) + warmup_steps,
    }


def main() -> int:
    """Run one benchmark config, then print a single BENCHMARK summary line.

    Returns the exit code to pass to sys.exit: 0 on success, otherwise the
    child train.py process's nonzero return code.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, choices=list(CONFIGS))
    ap.add_argument("--time", type=int, default=300, help="training seconds")
    ap.add_argument(
        "--log",
        default=None,
        help="output log path (default: run_bench_<config>.log)",
    )
    args = ap.parse_args()
    cfg = CONFIGS[args.config]
    log_path = Path(args.log or (REPO / f"run_bench_{args.config}.log"))
    env = build_env(cfg)
    env["HYDRA_TIME_BUDGET"] = str(args.time)
    # Make the config visible up-front so failed runs are debuggable.
    print(f"BENCH start config={args.config} time={args.time}s log={log_path}", flush=True)
    print(f"  overrides: {cfg}", flush=True)
    with log_path.open("w") as logf:
        # BUGFIX: launch with the *same* interpreter running this script
        # (sys.executable) instead of whatever "python" is on PATH — the
        # docstring prescribes the .venv interpreter, and benchmarking a
        # different environment would silently skew every number.
        proc = subprocess.Popen(
            [sys.executable, "-u", str(REPO / "train.py")],
            env=env,
            cwd=str(REPO),
            stdout=logf,
            stderr=subprocess.STDOUT,
        )
        proc.wait()
    print(f"BENCH wait_done exit={proc.returncode}", flush=True)
    if proc.returncode != 0:
        print(f"BENCH FAIL config={args.config}", flush=True)
        return proc.returncode
    summary = summarize(log_path)
    print(
        f"BENCHMARK config={args.config} "
        f"tps_steady={summary['tps_steady']:.0f} "
        f"bpb_at_500={summary['bpb_at_500']:.4f} "
        f"vram_peak={summary['vram_peak']:.0f}MiB "
        f"steps={summary['steps']}",
        flush=True,
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())