feather-a10-runtime / overlay /scripts /benchmark_runner.py
Jackoatmon's picture
Update Feather training runtime image
951f760 verified
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import re
import sys
from pathlib import Path
from typing import Any, Callable
# Directory one level above the folder that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Put that directory at the front of sys.path (once) before the `scripts.*`
# imports that follow; presumably this lets them resolve when the file is run
# directly as a script rather than as part of an installed package.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
# Template whose contents seed a new ledger file when the requested ledger
# path does not exist yet (see append_benchmark_run_record).
LEDGER_TEMPLATE_PATH = REPO_ROOT / "artifacts" / "benchmark_ledger.template.json"
from scripts.hydra_generation import build_hydra_generator
from scripts.benchmark_datasets import resolve_benchmark_dataset as resolve_canonical_dataset
from scripts.benchmark_suite import build_prompt, validate_sample
def load_jsonl_samples(path: Path) -> list[dict[str, Any]]:
    """Load a JSON Lines file into a list of decoded records.

    Every non-blank line of *path* is decoded with ``json.loads``; lines that
    are empty or contain only whitespace are skipped.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Returns:
        The decoded values in file order. The values are expected to be JSON
        objects, but this function does not enforce that.

    Raises:
        FileNotFoundError: If *path* does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Iterate the file object rather than ``read_text().splitlines()``:
    # ``str.splitlines`` also breaks on characters such as U+2028, U+2029 and
    # U+0085, which may legally appear unescaped inside a JSON string and
    # would cut a valid record in half. File iteration only splits on
    # "\n", "\r" and "\r\n".
    with path.open(encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]
def _score_mbpp(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
    """Return the fraction of MBPP samples whose generated code passes its tests.

    For each sample the prompt is built with ``build_prompt``, the generated
    text is executed as Python source, and then every entry of
    ``sample["tests"]`` is executed in the same namespace. A sample passes
    only if none of those executions raises.

    Args:
        samples: Benchmark records; each must provide a ``"tests"`` iterable
            of Python source strings.
        generate_fn: Maps a prompt to the generated source code.

    Returns:
        ``passed / len(samples)``, or ``0.0`` when *samples* is empty.

    Raises:
        Whatever ``validate_sample``, ``build_prompt`` or *generate_fn* raise,
        and ``KeyError`` for a sample without ``"tests"``. Those indicate a
        harness or dataset problem rather than a wrong answer, so they are
        not swallowed.
    """
    passed = 0
    for sample in samples:
        validate_sample("MBPP", sample)
        code = generate_fn(build_prompt("MBPP", sample))
        # Read the tests before entering the guarded section so a malformed
        # sample still surfaces as an error instead of a silent failure.
        tests = list(sample["tests"])
        namespace: dict[str, Any] = {}
        try:
            # SECURITY: this runs model-generated code in-process with no
            # sandbox, timeout or resource limits. Only use with trusted
            # checkpoints or inside an isolated environment.
            exec(code, namespace, namespace)
            for test in tests:
                exec(test, namespace, namespace)
        except (Exception, SystemExit):  # noqa: BLE001
            # Broken code or a failing test means this sample failed; it must
            # not abort the remaining samples. SystemExit is included because
            # generated code may call exit(); KeyboardInterrupt still
            # propagates.
            continue
        passed += 1
    return passed / len(samples) if samples else 0.0
def _extract_last_number(text: str) -> str | None:
matches = re.findall(r"-?\d+(?:\.\d+)?", text)
return matches[-1] if matches else None
def _score_gsm8k(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
    """Return the fraction of GSM8K samples answered correctly.

    The prediction is the last number found in the generated text (via
    ``_extract_last_number``). It counts as correct when it equals
    ``str(sample["answer"])`` exactly, or when both parse to the same finite
    decimal value, so ``"18.0"`` matches an expected ``"18"``.

    Args:
        samples: Benchmark records; each must provide an ``"answer"`` field.
        generate_fn: Maps a prompt to the generated answer text.

    Returns:
        ``passed / len(samples)``, or ``0.0`` when *samples* is empty.
    """
    # Local import so the block does not depend on the file's import header.
    from decimal import Decimal, InvalidOperation

    def _to_decimal(text: str) -> Decimal | None:
        """Parse *text* as a finite decimal, ignoring thousands separators."""
        try:
            value = Decimal(text.replace(",", "").strip())
        except InvalidOperation:
            return None
        return value if value.is_finite() else None

    passed = 0
    for sample in samples:
        validate_sample("GSM8K", sample)
        output = generate_fn(build_prompt("GSM8K", sample))
        pred = _extract_last_number(output)
        if pred is None:
            continue
        expected = str(sample["answer"])
        if pred == expected:
            passed += 1
            continue
        # Fall back to numeric equality: string comparison alone rejects
        # answers that differ only in formatting (trailing ".0", commas).
        pred_value = _to_decimal(pred)
        if pred_value is not None and pred_value == _to_decimal(expected):
            passed += 1
    return passed / len(samples) if samples else 0.0
def _score_humaneval(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
    """Return the fraction of HumanEval samples whose generated code passes.

    For each sample the generated text is executed as Python source and then
    ``sample["test"]`` is executed in the same namespace. A sample passes
    only if neither execution raises.

    Args:
        samples: Benchmark records; each must provide a ``"test"`` source
            string.
        generate_fn: Maps a prompt to the generated source code.

    Returns:
        ``passed / len(samples)``, or ``0.0`` when *samples* is empty.

    Raises:
        Whatever ``validate_sample``, ``build_prompt`` or *generate_fn* raise,
        and ``KeyError`` for a sample without ``"test"``. Those indicate a
        harness or dataset problem rather than a wrong answer, so they are
        not swallowed.
    """
    passed = 0
    for sample in samples:
        validate_sample("HumanEval", sample)
        code = generate_fn(build_prompt("HumanEval", sample))
        # Read the test source before the guarded section so a malformed
        # sample surfaces as an error instead of a silent failure.
        test_source = sample["test"]
        namespace: dict[str, Any] = {}
        try:
            # SECURITY: this runs model-generated code in-process with no
            # sandbox, timeout or resource limits. Only use with trusted
            # checkpoints or inside an isolated environment.
            exec(code, namespace, namespace)
            # NOTE(review): the test source is executed as-is. If it only
            # defines a checker function without calling it, nothing is
            # actually asserted -- confirm against the sample format.
            exec(test_source, namespace, namespace)
        except (Exception, SystemExit):  # noqa: BLE001
            # Broken code or a failing test means this sample failed; it must
            # not abort the remaining samples. KeyboardInterrupt still
            # propagates.
            continue
        passed += 1
    return passed / len(samples) if samples else 0.0
def _score_arc(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
    """Return the accuracy of *generate_fn* on ARC-Challenge samples.

    A sample is correct when the generated text, with surrounding whitespace
    stripped, is exactly equal to ``str(sample["answer"])``.

    Args:
        samples: Benchmark records; each must provide an ``"answer"`` field.
        generate_fn: Maps a prompt to the generated answer text.

    Returns:
        The share of correct samples, or ``0.0`` when *samples* is empty.
    """
    if not samples:
        return 0.0
    correct = 0
    for item in samples:
        validate_sample("ARC-Challenge", item)
        prompt = build_prompt("ARC-Challenge", item)
        reply = generate_fn(prompt).strip()
        correct += int(reply == str(item["answer"]))
    return correct / len(samples)
def run_benchmark(benchmark_name: str, path: Path, generate_fn: Callable[[str], str]) -> dict[str, Any]:
    """Score *generate_fn* on the samples stored at *path*.

    Args:
        benchmark_name: One of ``"MBPP"``, ``"GSM8K"``, ``"HumanEval"`` or
            ``"ARC-Challenge"``.
        path: JSONL file holding the benchmark samples.
        generate_fn: Maps a prompt to generated text.

    Returns:
        A dict with the keys ``benchmark``, ``primary_metric``, ``score`` and
        ``n_samples``.

    Raises:
        ValueError: If *benchmark_name* is not supported. The samples file is
            read before this check is made.
    """
    samples = load_jsonl_samples(path)
    # (benchmark label, name of its primary metric, scoring function)
    registry = (
        ("MBPP", "pass_at_1", _score_mbpp),
        ("GSM8K", "exact_match", _score_gsm8k),
        ("HumanEval", "pass_at_1", _score_humaneval),
        ("ARC-Challenge", "accuracy", _score_arc),
    )
    for label, metric, scorer in registry:
        if benchmark_name == label:
            return {
                "benchmark": label,
                "primary_metric": metric,
                "score": scorer(samples, generate_fn),
                "n_samples": len(samples),
            }
    raise ValueError(f"Unsupported runnable benchmark: {benchmark_name}")
def write_benchmark_result(path: Path, payload: dict[str, Any]) -> None:
    """Serialise *payload* as indented, key-sorted JSON into *path*.

    Missing parent directories are created first. An existing file is
    overwritten.

    Args:
        path: Destination file.
        payload: JSON-serialisable result data.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, indent=2, sort_keys=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.write(serialized)
def append_benchmark_run_record(
    ledger_path: Path,
    result: dict[str, Any],
    *,
    benchmark_name: str,
    variant: str,
    seed: int,
    samples_path: Path,
) -> None:
    """Append one run record for *result* to the JSON ledger at *ledger_path*.

    If the ledger file does not exist it is created from the contents of
    ``LEDGER_TEMPLATE_PATH``. If the ledger holds exactly one record whose
    ``run_id`` is ``"example-run-0001"`` (presumably the template's
    placeholder entry), that record is removed before the new one is added.
    The whole ledger is then rewritten as indented, key-sorted JSON.

    Args:
        ledger_path: Ledger file to update (created if missing).
        result: Benchmark result; ``result["score"]`` is read for the four
            supported benchmarks and ``result["run_id"]`` is used if present.
        benchmark_name: Benchmark label. ``MBPP``/``HumanEval`` scores are
            stored as ``coding_score``; ``GSM8K``/``ARC-Challenge`` scores as
            ``reasoning_score``.
        variant: Model variant name stored in the record.
        seed: Seed stored in the record and used in the default run id.
        samples_path: Path of the samples file, stored as a string.

    Raises:
        FileNotFoundError: If the ledger is missing and the template cannot
            be read.
        json.JSONDecodeError: If the ledger file is not valid JSON.
    """
    if not ledger_path.exists():
        ledger_path.parent.mkdir(parents=True, exist_ok=True)
        ledger_path.write_text(LEDGER_TEMPLATE_PATH.read_text(encoding="utf-8"), encoding="utf-8")
    payload = json.loads(ledger_path.read_text(encoding="utf-8"))
    run_records = payload.setdefault("run_records", [])
    if len(run_records) == 1 and run_records[0].get("run_id") == "example-run-0001":
        run_records.clear()

    # Tolerate a missing or null "benchmark_cycle" and a missing, null or
    # empty "budget_modes": indexing [0] on an empty list would raise
    # IndexError and lose the run record.
    cycle = payload.get("benchmark_cycle") or {}
    budget_modes = cycle.get("budget_modes") or [None]

    run_records.append(
        {
            "run_id": result.get("run_id", f"{benchmark_name.lower()}-{seed}"),
            "commit": "HEAD",
            "model_family": "hydra",
            "variant": variant,
            "seed": seed,
            "hardware": {
                "hardware_class": cycle.get("hardware_class", "unknown"),
            },
            "budget": {
                "budget_mode": budget_modes[0],
            },
            "capability": {
                "coding_score": result["score"] if benchmark_name in {"MBPP", "HumanEval"} else None,
                "reasoning_score": result["score"] if benchmark_name in {"GSM8K", "ARC-Challenge"} else None,
            },
            "artifacts": {
                "samples_path": str(samples_path),
            },
        }
    )
    ledger_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
def resolve_samples_path(benchmark_name: str, samples: Path | None, suite_path: Path) -> Path:
    """Work out which samples file to use for *benchmark_name*.

    Resolution order:
      1. *samples*, when it is given explicitly.
      2. The ``sample_path`` of the first matching suite entry, scanning the
         ``coding_benchmarks`` then ``reasoning_benchmarks`` sections and,
         inside each, the ``fast_iteration`` then ``milestone`` slots.
      3. Whatever ``resolve_canonical_dataset`` returns for the benchmark.

    Args:
        benchmark_name: Benchmark label to look up.
        samples: Explicit samples path, or ``None`` to look one up.
        suite_path: JSON suite definition; only read when *samples* is None.

    Returns:
        The resolved samples path.

    Raises:
        ValueError: If no source yields a path for the benchmark.
    """
    if samples is not None:
        return samples

    suite = json.loads(suite_path.read_text(encoding="utf-8"))
    # Lazily walk every (section, slot) entry in priority order.
    candidates = (
        suite[section].get(slot)
        for section in ("coding_benchmarks", "reasoning_benchmarks")
        if section in suite
        for slot in ("fast_iteration", "milestone")
    )
    for candidate in candidates:
        if not isinstance(candidate, dict):
            continue
        if candidate.get("name") == benchmark_name and "sample_path" in candidate:
            return Path(candidate["sample_path"])

    try:
        return resolve_canonical_dataset(benchmark_name, None)
    except ValueError:
        raise ValueError(f"No sample path found for benchmark: {benchmark_name}")
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the benchmark runner's command-line options.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The populated namespace. ``--benchmark`` is the only required option.
    """
    default_suite = REPO_ROOT / "artifacts" / "benchmark_suite.cycle1.json"
    parser = argparse.ArgumentParser(description="Run a local benchmark against JSONL samples")
    add = parser.add_argument

    # What to run and where inputs/outputs live.
    add("--benchmark", required=True, choices=["MBPP", "GSM8K", "HumanEval", "ARC-Challenge"])
    add("--samples", type=Path)
    add("--suite", type=Path, default=default_suite)
    add("--out", type=Path)
    add("--ledger", type=Path)

    # Metadata recorded with the run.
    add("--variant", default="hydra_full")
    add("--seed", type=int, default=42)

    # Generator selection and sampling settings.
    add("--generator-mode", choices=["stub", "hydra"], default="stub")
    add("--checkpoint", type=Path)
    add("--device")
    add("--max-new-tokens", type=int, default=256)
    add("--temperature", type=float, default=0.2)
    add("--top-p", type=float, default=0.95)

    return parser.parse_args(argv)
def main(argv: list[str] | None = None) -> int:
    """Run one benchmark from the command line and report the outcome.

    The samples path is resolved, a generator is built (the Hydra model for
    ``--generator-mode hydra``, otherwise a stub that echoes the prompt) and
    the benchmark is scored. Failures while building the generator or while
    scoring are turned into a ``status: failed`` result instead of a
    traceback. The result is written to ``--out`` when given, appended to
    ``--ledger`` on success only, and always printed to stdout as JSON.

    Args:
        argv: Argument list; ``None`` makes argparse read ``sys.argv[1:]``.

    Returns:
        ``0`` on success, ``1`` when the run failed.
    """
    args = parse_args(argv)
    sample_path = resolve_samples_path(args.benchmark, args.samples, args.suite)

    def _failure(failure_type: str, exc: Exception) -> dict[str, Any]:
        """Build the result payload for a failed run."""
        return {
            "benchmark": args.benchmark,
            "status": "failed",
            "failure_type": failure_type,
            "error": str(exc),
            "n_samples": 0,
        }

    generator_ready = False
    try:
        if args.generator_mode == "hydra":
            generator = build_hydra_generator(
                checkpoint_path=args.checkpoint,
                device=args.device,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
            )
        else:

            def generator(prompt: str) -> str:
                return prompt

        generator_ready = True
        result = run_benchmark(args.benchmark, sample_path, generator)
        exit_code = 0
    except FileNotFoundError as exc:
        # A missing file is not always the checkpoint: once the generator has
        # been built, a missing samples file is the other source of this
        # error (it is the only one in stub mode). Label that case
        # accurately; keep the legacy label everywhere else.
        samples_missing = generator_ready and not Path(sample_path).exists()
        result = _failure("missing_samples" if samples_missing else "missing_checkpoint", exc)
        exit_code = 1
    except Exception as exc:  # noqa: BLE001
        result = _failure(type(exc).__name__, exc)
        exit_code = 1

    if args.out is not None:
        write_benchmark_result(args.out, result)
    # Only successful runs are recorded in the ledger.
    if args.ledger is not None and exit_code == 0:
        append_benchmark_run_record(
            args.ledger,
            result,
            benchmark_name=args.benchmark,
            variant=args.variant,
            seed=args.seed,
            samples_path=sample_path,
        )
    print(json.dumps(result, indent=2, sort_keys=True))
    return exit_code
if __name__ == "__main__":
raise SystemExit(main())