feather-a10g-large-runtime / overlay /scripts /benchmark_hyena_stack.py
icarus112's picture
Update Feather a10g-large training runtime image
e5cf7c3 verified
from __future__ import annotations
"""Hyena stack benchmark — measure TPS under the four knob combinations.
Produces the table requested in Task 4:
| Config                   | TPS | BPB@500 | VRAM |
|--------------------------|-----|---------|------|
| B=8,  no flash, no cache | ... | ...     | ...  |  <-- baseline
| B=16, no flash, no cache | ... | ...     | ...  |
| B=16, no flash, cache on | ... | ...     | ...  |
| B=16, flash on, cache on | ... | ...     | ...  |  <-- best
Run ONE config by invoking with command-line args, then collate externally.
Each invocation runs train.py for the specified wall-clock time with the
given env overrides, captures its output in run_bench_<config>.log (or the
path given by --log), and emits a single summary line after a successful run.
Invocation:
cd /home/mikeb/work/feather
# On the RTX 3060 (local validation only — these numbers will NOT hit
# the 200k tps production floor):
.venv/bin/python scripts/benchmark_hyena_stack.py --config baseline --time 300
.venv/bin/python scripts/benchmark_hyena_stack.py --config b16 --time 300
.venv/bin/python scripts/benchmark_hyena_stack.py --config cache --time 300
# "kernel" config requires flashfftconv built — see kernels/cuda/flashfftconv/README.md
.venv/bin/python scripts/benchmark_hyena_stack.py --config kernel --time 300
# On A100/A10G (production cloud hardware), use time=900 (15 min) for
# stable steady-state numbers.
After each successful run the script prints:
BENCHMARK config=<name> tps_steady=<median> bpb_at_500=<val> vram_peak=<MiB>MiB steps=<n>
Collate those lines into the matrix table manually, then pick the winner
for the 6-hour production run (HYDRA_TIME_BUDGET=21600).
"""
import argparse
import os
import re
import subprocess
import sys
from pathlib import Path
# Repository root: the parent of the directory that contains this script.
REPO = Path(__file__).resolve().parents[1]

# Each configuration below is derived from the previous one, so the knobs
# that change from one benchmark row to the next are explicit.

# Baseline: B=8, no flash, no train-cache. Current reference point.
_BASELINE = {
    "HYDRA_BATCH_SIZE": "8",
    "HYDRA_HYENA_LAYERS": "3,7",
    "HYDRA_HYENA_FLASH_FFT": "0",
    "HYDRA_HYENA_TRAIN_CACHE": "0",
    "HYDRA_HYENA_FILTER_CACHE": "0",
}

# Baseline with the batch size raised to 16.
_B16 = {**_BASELINE, "HYDRA_BATCH_SIZE": "16"}

# b16 with the train cache and filter cache switched on.
_CACHE = {
    **_B16,
    "HYDRA_HYENA_TRAIN_CACHE": "1",
    "HYDRA_HYENA_FILTER_CACHE": "1",
}

# cache with the flash FFT knob switched on.
# Task 4 note: also bump HYDRA_HTM_SUBSAMPLE to 128 (from 64) in the
# best config to get more aggressive reclamation.
_KERNEL = {
    **_CACHE,
    "HYDRA_HYENA_FLASH_FFT": "1",
    "HYDRA_HTM_SUBSAMPLE": "128",
}

# Benchmark name (the --config choice) -> environment-variable overrides.
CONFIGS = {
    "baseline": _BASELINE,
    "b16": _B16,
    "cache": _CACHE,
    "kernel": _KERNEL,
}
def build_env(cfg_overrides: dict) -> dict:
    """Return a copy of the process environment with the overrides applied.

    ``HYDRA_HYENA_LAYERS`` is seeded with an empty string when the inherited
    environment does not define it. Every entry of ``cfg_overrides`` then
    replaces any inherited value of the same name. ``os.environ`` itself is
    left untouched because the work happens on a copy.
    """
    merged = os.environ.copy()
    if "HYDRA_HYENA_LAYERS" not in merged:
        # Keep the Hyena layer selection always present (defaults to off).
        merged["HYDRA_HYENA_LAYERS"] = ""
    merged.update(cfg_overrides)
    return merged
def parse_step_line(line: str) -> dict | None:
    """Extract the ``key=number`` pairs of a training step line.

    Returns ``None`` when ``line`` does not begin with ``step=`` or when any
    captured value cannot be converted to ``float``. Otherwise returns a dict
    mapping each key to its float value; if a key repeats, the last
    occurrence wins.
    """
    if not line.startswith("step="):
        return None
    metrics = {}
    for name, raw in re.findall(r"(\w+)=([0-9.eE+\-]+)", line):
        try:
            metrics[name] = float(raw)
        except ValueError:
            # One unparseable value invalidates the whole line.
            return None
    return metrics
def summarize(log_path: Path, warmup_steps: int = 50) -> dict:
    """Reduce a training log to steady-state TPS / BPB@500 / VRAM peak.

    Every line of ``log_path`` is passed through ``parse_step_line``; lines
    it rejects are ignored. Records whose step number is below
    ``warmup_steps`` are dropped for *all* metrics (the intent is to discard
    CUDA graph capture / autotune spikes).

    Args:
        log_path: Path of the training log to read.
        warmup_steps: Records with ``step`` below this value are ignored.

    Returns:
        A dict with the keys

        * ``tps_steady`` - median of the post-warmup ``tps`` values,
        * ``bpb_at_500`` - ``bpb`` of the first record with ``step == 500``;
          if there is none, the last post-warmup ``bpb``; if there is no
          ``bpb`` at all, ``0.0``,
        * ``vram_peak`` - largest post-warmup ``vram`` value (``0.0`` if none),
        * ``steps`` - number of ``tps`` samples plus ``warmup_steps``.

        When no post-warmup ``tps`` sample exists, every value is zero.
    """
    # Function-scope import so the block does not depend on a file-level one.
    from statistics import median

    tps_vals = []
    last_bpb = None
    bpb_at_500 = None
    vram_peak = 0.0
    # Decode leniently: a single undecodable byte in the log must not discard
    # the result of a long benchmark run with a UnicodeDecodeError.
    with log_path.open(encoding="utf-8", errors="replace") as f:
        for line in f:
            d = parse_step_line(line.strip())
            if d is None:
                continue
            # A step line without a parseable step number counts as warmup.
            step = int(d.get("step", -1))
            if step < warmup_steps:
                continue
            tps = d.get("tps")
            if tps is not None:
                tps_vals.append(tps)
            bpb = d.get("bpb")
            if bpb is not None:
                last_bpb = bpb
                if step == 500 and bpb_at_500 is None:
                    bpb_at_500 = bpb
            vram = d.get("vram")
            if vram is not None and vram > vram_peak:
                vram_peak = vram
    if not tps_vals:
        return {"tps_steady": 0.0, "bpb_at_500": 0.0, "vram_peak": 0.0, "steps": 0}
    # Compare against None explicitly so a recorded value of 0.0 at step 500
    # is reported rather than silently replaced by the fallback.
    if bpb_at_500 is None:
        bpb_at_500 = last_bpb if last_bpb is not None else 0.0
    return {
        # True median: averages the two middle samples for even counts.
        "tps_steady": median(tps_vals),
        "bpb_at_500": bpb_at_500,
        "vram_peak": vram_peak,
        "steps": len(tps_vals) + warmup_steps,
    }
def main() -> int:
    """Run one benchmark configuration and print its summary line.

    Parses ``--config``, ``--time`` and ``--log``, then launches ``train.py``
    from ``REPO`` with the selected overrides plus ``HYDRA_TIME_BUDGET`` set
    to ``--time``. The child's stdout and stderr are written to the log file.
    If the child exits with status 0, the log is summarized and a
    ``BENCHMARK`` line is printed.

    Returns:
        0 on success, otherwise the child's non-zero return code.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, choices=list(CONFIGS))
    ap.add_argument("--time", type=int, default=300, help="training seconds")
    ap.add_argument("--log", default=None, help="output log path (default: run_bench_<cfg>.log)")
    args = ap.parse_args()

    cfg = CONFIGS[args.config]
    log_path = Path(args.log or (REPO / f"run_bench_{args.config}.log"))
    env = build_env(cfg)
    env["HYDRA_TIME_BUDGET"] = str(args.time)

    # Make the config visible up-front so failed runs are debuggable.
    print(f"BENCH start config={args.config} time={args.time}s log={log_path}", flush=True)
    print(f" overrides: {cfg}", flush=True)

    # Use the interpreter that is running this script. A bare "python" is
    # resolved through PATH, which for an un-activated venv invocation
    # (.venv/bin/python scripts/...) may be a different interpreter or absent.
    interpreter = sys.executable or "python"
    with log_path.open("w") as logf:
        proc = subprocess.run(
            [interpreter, "-u", str(REPO / "train.py")],
            env=env,
            cwd=str(REPO),
            stdout=logf,
            stderr=subprocess.STDOUT,
        )
    print(f"BENCH wait_done exit={proc.returncode}", flush=True)
    if proc.returncode != 0:
        print(f"BENCH FAIL config={args.config}", flush=True)
        return proc.returncode

    summary = summarize(log_path)
    print(
        f"BENCHMARK config={args.config} "
        f"tps_steady={summary['tps_steady']:.0f} "
        f"bpb_at_500={summary['bpb_at_500']:.4f} "
        f"vram_peak={summary['vram_peak']:.0f}MiB "
        f"steps={summary['steps']}",
        flush=True,
    )
    return 0
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status.
    raise SystemExit(main())