feather-a10-runtime / overlay /scripts /benchmark_suite.py
Jackoatmon's picture
Update Feather training runtime image
951f760 verified
#!/usr/bin/env python3
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@dataclass(frozen=True)
class BenchmarkSpec:
    """Immutable description of one supported benchmark.

    Being a frozen dataclass, instances are compared and hashed by value and
    cannot be mutated after construction.
    """

    # Canonical benchmark name (the same string is used as its REGISTRY key).
    name: str
    # Benchmark category, e.g. "coding" or "reasoning".
    family: str
    # Keys a sample dict must contain before a prompt can be built from it.
    required_fields: tuple[str, ...]
# Every benchmark this script can validate and build prompts for, keyed by its
# canonical name. The mapping is derived from ``spec.name`` so a key can never
# disagree with the spec it points to; insertion order follows the tuple below.
REGISTRY: dict[str, BenchmarkSpec] = {
    spec.name: spec
    for spec in (
        BenchmarkSpec("MBPP", "coding", ("task_id", "prompt", "tests")),
        BenchmarkSpec("HumanEval", "coding", ("task_id", "prompt", "test")),
        BenchmarkSpec("GSM8K", "reasoning", ("question", "answer")),
        BenchmarkSpec("ARC-Challenge", "reasoning", ("question", "choices", "answer")),
    )
}
def validate_sample(benchmark_name: str, sample: dict[str, Any]) -> None:
    """Check that *sample* carries every field its benchmark requires.

    Args:
        benchmark_name: Key into ``REGISTRY``.
        sample: One benchmark record.

    Raises:
        KeyError: if *benchmark_name* is not a key of ``REGISTRY``.
        ValueError: naming the first entry of the spec's ``required_fields``
            (in declaration order) that is absent from *sample*.
    """
    required = REGISTRY[benchmark_name].required_fields
    # Lazy search: stops at the first absent key, so only that one is reported.
    first_missing = next((key for key in required if key not in sample), None)
    if first_missing is None:
        return
    raise ValueError(f"{benchmark_name} sample missing required field: {first_missing}")
def _render_mbpp_tests(tests: Any) -> str:
    """Return MBPP ``tests`` as text, one test per line (a lone string is used verbatim)."""
    if isinstance(tests, str):
        # Joining a str would iterate it character by character and put every
        # character on its own line.
        return tests
    return "\n".join(str(test) for test in tests)


def _render_arc_choices(choices: Any) -> str:
    """Return ARC ``choices`` as a bulleted list, one option per line.

    Accepts a flat iterable of options, or the column-oriented mapping
    ``{"label": [...], "text": [...]}`` (the layout used by the Hugging Face
    ``allenai/ai2_arc`` dataset -- confirm against the actual data source),
    which is rendered as ``- <label>. <text>``.
    """
    if isinstance(choices, dict) and "label" in choices and "text" in choices:
        # Iterating such a dict directly would yield only its keys.
        return "\n".join(
            f"- {label}. {text}" for label, text in zip(choices["label"], choices["text"])
        )
    return "\n".join(f"- {choice}" for choice in choices)


def build_prompt(benchmark_name: str, sample: dict[str, Any]) -> str:
    """Build the model prompt for one benchmark sample.

    Args:
        benchmark_name: A key of ``REGISTRY`` (e.g. ``"MBPP"``).
        sample: One record; must contain the benchmark's ``required_fields``.

    Returns:
        The prompt text, always ending with a newline.

    Raises:
        ValueError: if *benchmark_name* is not registered, or if *sample*
            lacks a required field (reported by ``validate_sample``).
    """
    # Check membership first: validate_sample indexes REGISTRY directly and
    # would otherwise surface an unknown name as a bare KeyError.
    if benchmark_name not in REGISTRY:
        raise ValueError(f"Unknown benchmark: {benchmark_name}")
    validate_sample(benchmark_name, sample)
    if benchmark_name == "MBPP":
        rendered_tests = _render_mbpp_tests(sample["tests"])
        return (
            "Write a Python function that solves the task below.\n\n"
            f"Task:\n{sample['prompt']}\n\n"
            f"Tests:\n{rendered_tests}\n"
        )
    if benchmark_name == "HumanEval":
        return (
            "Complete the following Python function exactly as specified.\n\n"
            f"Prompt:\n{sample['prompt']}\n\n"
            f"Reference test:\n{sample['test']}\n"
        )
    if benchmark_name == "GSM8K":
        return (
            "Solve the following math word problem. Return only the final answer.\n\n"
            f"Question: {sample['question']}\n"
        )
    if benchmark_name == "ARC-Challenge":
        rendered_choices = _render_arc_choices(sample["choices"])
        return (
            "Answer the following multiple-choice science question. "
            "Return only the correct option text or label.\n\n"
            f"Question: {sample['question']}\nChoices:\n{rendered_choices}\n"
        )
    # Reached only if a benchmark is registered but has no template above.
    raise ValueError(f"Unknown benchmark: {benchmark_name}")
def load_cycle_benchmark_suite(path: Path) -> dict[str, dict[str, BenchmarkSpec]]:
    """Load a cycle's benchmark-suite JSON and resolve each slot to a registered spec.

    The file must be a JSON object with ``coding_benchmarks`` and
    ``reasoning_benchmarks`` objects. Each of those must have
    ``fast_iteration`` and ``milestone`` entries whose ``name`` is a key of
    ``REGISTRY`` and whose spec family matches the section.

    Args:
        path: JSON file to read (UTF-8).

    Returns:
        ``{section: {slot: BenchmarkSpec}}`` covering both sections and both slots.

    Raises:
        ValueError: on malformed JSON (``json.JSONDecodeError`` is a
            ValueError), a missing or non-object section/slot, a missing or
            unregistered benchmark name, or a family/section mismatch.
        OSError: if *path* cannot be read.
    """
    # Each suite section may only reference benchmarks of the matching family.
    section_families = {"coding_benchmarks": "coding", "reasoning_benchmarks": "reasoning"}
    payload = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        # e.g. a top-level string would turn the `in` checks below into substring tests.
        raise ValueError(f"benchmark suite must be a JSON object: {path}")
    out: dict[str, dict[str, BenchmarkSpec]] = {}
    for section, family in section_families.items():
        if section not in payload:
            raise ValueError(f"missing benchmark section: {section}")
        slots = payload[section]
        if not isinstance(slots, dict):
            raise ValueError(f"benchmark section must be a JSON object: {section}")
        resolved: dict[str, BenchmarkSpec] = {}
        for slot in ("fast_iteration", "milestone"):
            if slot not in slots:
                raise ValueError(f"missing benchmark slot: {section}.{slot}")
            entry = slots[slot]
            if not isinstance(entry, dict) or "name" not in entry:
                raise ValueError(f"missing benchmark name: {section}.{slot}.name")
            name = entry["name"]
            # isinstance guard: an unhashable name (e.g. a list) would make the
            # REGISTRY membership test raise TypeError rather than ValueError.
            if not isinstance(name, str) or name not in REGISTRY:
                raise ValueError(f"unsupported benchmark: {name}")
            spec = REGISTRY[name]
            if spec.family != family:
                raise ValueError(
                    f"benchmark {name} is a {spec.family} benchmark but is listed under {section}.{slot}"
                )
            resolved[slot] = spec
        out[section] = resolved
    return out
def main(argv: list[str] | None = None) -> int:
    """Load a cycle benchmark suite and print the selected benchmark names as JSON.

    Args:
        argv: Command-line arguments without the program name; ``None`` means
            ``sys.argv[1:]``. An optional positional argument overrides the
            default suite path.

    Returns:
        Process exit status: 0 on success, 1 if the suite cannot be read or
        fails validation (the error is printed to stderr).
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(
        description="Print the benchmarks selected by a cycle benchmark suite."
    )
    parser.add_argument(
        "path",
        nargs="?",
        type=Path,
        # Relative to the current working directory, as before.
        default=Path("artifacts/benchmark_suite.cycle1.json"),
        help="suite JSON file (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    try:
        suite = load_cycle_benchmark_suite(args.path)
    except (OSError, ValueError) as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 1
    summary = {
        section: {slot: spec.name for slot, spec in slots.items()}
        for section, slots in suite.items()
    }
    print(json.dumps(summary, indent=2))
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)