Pranesh
deploy: sync staged DataForge Space
66b1c50
"""Top-level benchmark orchestration for agent comparison runs."""
from __future__ import annotations
import os
from pathlib import Path
from dotenv import load_dotenv
from dataforge.bench.core import (
AggregateBenchmarkResult,
BenchmarkRunOutput,
SeedBenchmarkResult,
aggregate_seed_results,
estimate_llm_calls,
validate_estimated_calls,
write_run_output,
)
from dataforge.bench.groq_client import GroqBenchClient
from dataforge.bench.methods import (
run_heuristic_episode,
run_llm_react_episode,
run_llm_zeroshot_episode,
run_random_episode,
)
from dataforge.datasets.real_world import load_real_world_dataset
from dataforge.datasets.registry import DATASET_REGISTRY
_SUPPORTED_METHODS = frozenset({"random", "heuristic", "llm_zeroshot", "llm_react"})
def _validate_inputs(methods: list[str], datasets: list[str], seeds: int) -> None:
    """Reject unsupported methods, unregistered datasets, and a non-positive seed count.

    Checks run in this order: methods, datasets, seeds. The first failing
    check raises.

    Raises:
        ValueError: If any method is not in ``_SUPPORTED_METHODS``, any dataset
            is not a key of ``DATASET_REGISTRY``, or ``seeds <= 0``. Offending
            names are reported as a sorted, de-duplicated list.
    """
    unsupported = sorted(set(methods).difference(_SUPPORTED_METHODS))
    unregistered = sorted(set(datasets).difference(DATASET_REGISTRY))

    if unsupported:
        raise ValueError(f"Unknown benchmark methods: {unsupported}")
    if unregistered:
        raise ValueError(f"Unknown benchmark datasets: {unregistered}")
    if seeds <= 0:
        raise ValueError("Benchmark seeds must be >= 1.")
def _reproduction_command(methods: list[str], datasets: list[str], seeds: int) -> str:
"""Build the canonical command for reproducing a benchmark run."""
return (
"dataforge bench "
f"--methods {','.join(methods)} "
f"--datasets {','.join(datasets)} "
f"--seeds {seeds}"
)
def _llm_skip_reason() -> str | None:
"""Return a skip reason when LLM methods cannot run."""
provider = os.environ.get("DATAFORGE_LLM_PROVIDER", "").strip().lower()
if provider != "groq":
return "DATAFORGE_LLM_PROVIDER must be set to groq."
if not os.environ.get("GROQ_API_KEY"):
return "GROQ_API_KEY is not set."
return None
def _skipped_result(
    *,
    method: str,
    dataset: str,
    seed: int,
    reason: str,
    reproduction_command: str,
) -> SeedBenchmarkResult:
    """Return a placeholder ``SeedBenchmarkResult`` for an episode that did not run.

    The result has ``status="skipped"``, ``reason`` stored as ``skip_reason``,
    every usage/cost counter zeroed, and ``provider``/``model`` left as ``None``.
    """
    # Nothing was executed, so every usage and timing counter is zero.
    no_usage = {
        "llm_calls": 0,
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "quota_units": 0.0,
        "runtime_s": 0.0,
    }
    # NOTE(review): the warning is always "provider_unset", even when `reason`
    # says GROQ_API_KEY is missing — confirm this is intended.
    return SeedBenchmarkResult(
        method=method,
        dataset=dataset,
        seed=seed,
        status="skipped",
        skip_reason=reason,
        provider=None,
        model=None,
        warnings=["provider_unset"],
        reproduction_command=reproduction_command,
        **no_usage,
    )
def _run_episode(
    *,
    method: str,
    dataset_name: str,
    dataset,
    seed: int,
    client: GroqBenchClient | None,
    skip_reason: str | None,
    reproduction_command: str,
) -> SeedBenchmarkResult:
    """Run one seed of ``method`` on ``dataset`` and normalise the result.

    Args:
        method: A name from ``_SUPPORTED_METHODS``. Any name other than
            ``random``, ``heuristic`` or ``llm_zeroshot`` is run as ``llm_react``.
        dataset_name: Name of ``dataset`` as selected by the caller; used only
            when building a skipped result.
        dataset: The object returned by ``load_real_world_dataset``.
        seed: Seed index forwarded to the episode runner.
        client: Shared LLM client, or ``None`` when LLM methods cannot run.
        skip_reason: Why LLM methods cannot run, or ``None``.
        reproduction_command: Command stamped onto the returned result.

    Returns:
        The episode result, or a skipped placeholder for an LLM method without
        a usable client. In both cases ``reproduction_command`` is set to the
        run-level command.
    """
    if method == "random":
        result = run_random_episode(dataset, seed=seed)
    elif method == "heuristic":
        result = run_heuristic_episode(dataset, seed=seed)
    elif client is None or skip_reason is not None:
        # Every remaining method is an LLM method. Without a client, record an
        # explicit skip so the row still appears in the output.
        result = _skipped_result(
            method=method,
            dataset=dataset_name,
            seed=seed,
            reason=skip_reason or "LLM client unavailable.",
            reproduction_command=reproduction_command,
        )
    elif method == "llm_zeroshot":
        result = run_llm_zeroshot_episode(dataset, seed=seed, client=client)
    else:
        result = run_llm_react_episode(dataset, seed=seed, client=client)

    # Runners may stamp their own command. Overwrite it with the run-level one
    # so every record reproduces the whole benchmark invocation.
    if result.reproduction_command != reproduction_command:
        result = result.model_copy(update={"reproduction_command": reproduction_command})
    # NOTE(review): heuristic results are re-stamped with the loop seed, which
    # suggests run_heuristic_episode may not set it itself — confirm.
    if method == "heuristic":
        result = result.model_copy(update={"seed": seed})
    return result


def run_agent_comparison(
    *,
    methods: list[str],
    datasets: list[str],
    seeds: int,
    output_json: Path,
    really_run_big_bench: bool,
    cache_root: Path | None = None,
) -> BenchmarkRunOutput:
    """Run the selected benchmark methods across real-world datasets.

    Every (dataset, method, seed) combination yields exactly one
    ``SeedBenchmarkResult``. Records are ordered dataset-major, then by method,
    then by seed. LLM methods are recorded as skipped, not omitted, when the
    LLM environment is not configured.

    Args:
        methods: Method names; each must be in ``_SUPPORTED_METHODS``.
        datasets: Dataset names; each must be a key of ``DATASET_REGISTRY``.
        seeds: Seeds per (dataset, method) pair; seeds ``0 .. seeds - 1`` are run.
        output_json: Destination passed to ``write_run_output``.
        really_run_big_bench: Forwarded to ``validate_estimated_calls``.
        cache_root: Optional cache directory forwarded to
            ``load_real_world_dataset``.

    Returns:
        The ``BenchmarkRunOutput`` that was also written to ``output_json``.

    Raises:
        ValueError: From ``_validate_inputs`` on unknown methods or datasets,
            or when ``seeds <= 0``. Errors from ``validate_estimated_calls``,
            dataset loading, or the episode runners propagate unchanged.
    """
    # Load .env first so the provider/API-key lookups below can see its values.
    load_dotenv()
    _validate_inputs(methods, datasets, seeds)
    estimated_calls = estimate_llm_calls(methods=methods, datasets=datasets, seeds=seeds)
    validate_estimated_calls(
        estimated_calls=estimated_calls,
        really_run_big_bench=really_run_big_bench,
    )
    reproduction_command = _reproduction_command(methods, datasets, seeds)

    # Load every dataset up front so loading errors surface before any episode runs.
    loaded_datasets = {
        dataset_name: load_real_world_dataset(dataset_name, cache_root=cache_root)
        for dataset_name in datasets
    }

    # One client is shared by all LLM episodes. It is built only when an LLM
    # method was requested and the environment check passed.
    # NOTE(review): the client is never explicitly closed; if GroqBenchClient
    # holds network resources, consider closing it — confirm its API.
    llm_methods_requested = any(method.startswith("llm_") for method in methods)
    skip_reason = _llm_skip_reason() if llm_methods_requested else None
    client = (
        GroqBenchClient(api_key=os.environ["GROQ_API_KEY"])
        if llm_methods_requested and skip_reason is None
        else None
    )

    records: list[SeedBenchmarkResult] = []
    for dataset_name in datasets:
        dataset = loaded_datasets[dataset_name]
        for method in methods:
            for seed in range(seeds):
                records.append(
                    _run_episode(
                        method=method,
                        dataset_name=dataset_name,
                        dataset=dataset,
                        seed=seed,
                        client=client,
                        skip_reason=skip_reason,
                        reproduction_command=reproduction_command,
                    )
                )

    aggregates: list[AggregateBenchmarkResult] = aggregate_seed_results(
        records, seeds_requested=seeds
    )
    output = BenchmarkRunOutput(
        metadata={
            "methods": methods,
            "datasets": datasets,
            "seeds": seeds,
            "reproduction_command": reproduction_command,
        },
        records=records,
        aggregates=aggregates,
    )
    write_run_output(output, output_json)
    return output