"""Top-level benchmark orchestration for agent comparison runs."""

from __future__ import annotations

import os
from pathlib import Path
from typing import Any

from dotenv import load_dotenv

from dataforge.bench.core import (
    AggregateBenchmarkResult,
    BenchmarkRunOutput,
    SeedBenchmarkResult,
    aggregate_seed_results,
    estimate_llm_calls,
    validate_estimated_calls,
    write_run_output,
)
from dataforge.bench.groq_client import GroqBenchClient
from dataforge.bench.methods import (
    run_heuristic_episode,
    run_llm_react_episode,
    run_llm_zeroshot_episode,
    run_random_episode,
)
from dataforge.datasets.real_world import load_real_world_dataset
from dataforge.datasets.registry import DATASET_REGISTRY

# Method names accepted by ``run_agent_comparison``.
_SUPPORTED_METHODS = frozenset({"random", "heuristic", "llm_zeroshot", "llm_react"})


def _validate_inputs(methods: list[str], datasets: list[str], seeds: int) -> None:
    """Validate user-selected methods, datasets, and seed count.

    Args:
        methods: Requested method names; each must be in ``_SUPPORTED_METHODS``.
        datasets: Requested dataset names; each must be a key of
            ``DATASET_REGISTRY``.
        seeds: Number of seeds to run per (dataset, method) pair.

    Raises:
        ValueError: If any method or dataset is unknown, or ``seeds`` is
            not positive. Methods are checked first, then datasets, then seeds.
    """
    unknown_methods = sorted(set(methods) - _SUPPORTED_METHODS)
    unknown_datasets = sorted(set(datasets) - set(DATASET_REGISTRY))
    if unknown_methods:
        raise ValueError(f"Unknown benchmark methods: {unknown_methods}")
    if unknown_datasets:
        raise ValueError(f"Unknown benchmark datasets: {unknown_datasets}")
    if seeds <= 0:
        raise ValueError("Benchmark seeds must be >= 1.")


def _reproduction_command(methods: list[str], datasets: list[str], seeds: int) -> str:
    """Build the canonical command for reproducing a benchmark run.

    Args:
        methods: Method names, joined with commas in the given order.
        datasets: Dataset names, joined with commas in the given order.
        seeds: Seed count passed to ``--seeds``.

    Returns:
        A ``dataforge bench ...`` command line string.
    """
    # NOTE(review): only methods, datasets and seeds are encoded; the output
    # path, cache root and big-bench opt-in are not part of the command.
    return (
        "dataforge bench "
        f"--methods {','.join(methods)} "
        f"--datasets {','.join(datasets)} "
        f"--seeds {seeds}"
    )


def _llm_skip_reason() -> str | None:
    """Return a skip reason when LLM methods cannot run.

    Reads ``DATAFORGE_LLM_PROVIDER`` and ``GROQ_API_KEY`` from the
    environment.

    Returns:
        A human-readable reason string, or ``None`` when the provider is
        ``groq`` (case-insensitive, surrounding whitespace ignored) and a
        non-blank API key is present.
    """
    provider = os.environ.get("DATAFORGE_LLM_PROVIDER", "").strip().lower()
    if provider != "groq":
        return "DATAFORGE_LLM_PROVIDER must be set to groq."
    # Strip like the provider value above, so a whitespace-only key counts
    # as unset rather than being passed on as a credential.
    if not os.environ.get("GROQ_API_KEY", "").strip():
        return "GROQ_API_KEY is not set."
    return None


def _skipped_result(
    *,
    method: str,
    dataset: str,
    seed: int,
    reason: str,
    reproduction_command: str,
) -> SeedBenchmarkResult:
    """Build a skipped seed result with a clear reason.

    Args:
        method: Method name that was skipped.
        dataset: Dataset name the skipped run targeted.
        seed: Seed index of the skipped run.
        reason: Text stored as the record's ``skip_reason``.
        reproduction_command: Command string stored on the record.

    Returns:
        A ``SeedBenchmarkResult`` with ``status="skipped"``, zeroed usage
        counters, and no provider or model.
    """
    # NOTE(review): the "provider_unset" warning is attached for every skip
    # reason, including a missing API key -- confirm that is intended.
    return SeedBenchmarkResult(
        method=method,
        dataset=dataset,
        seed=seed,
        status="skipped",
        skip_reason=reason,
        llm_calls=0,
        prompt_tokens=0,
        completion_tokens=0,
        quota_units=0.0,
        runtime_s=0.0,
        provider=None,
        model=None,
        warnings=["provider_unset"],
        reproduction_command=reproduction_command,
    )


def _run_episode(
    *,
    method: str,
    dataset_name: str,
    dataset: Any,
    seed: int,
    client: GroqBenchClient | None,
    skip_reason: str | None,
    reproduction_command: str,
) -> SeedBenchmarkResult:
    """Run, or skip, one (method, dataset, seed) episode.

    ``random`` and ``heuristic`` never need a client. Any other method is
    treated as an LLM method: it is skipped when no client is available or a
    skip reason is set, otherwise ``llm_zeroshot`` uses the zero-shot runner
    and every remaining method uses the ReAct runner.
    """
    if method == "random":
        return run_random_episode(dataset, seed=seed)
    if method == "heuristic":
        return run_heuristic_episode(dataset, seed=seed)

    if client is None or skip_reason is not None:
        return _skipped_result(
            method=method,
            dataset=dataset_name,
            seed=seed,
            reason=skip_reason or "LLM client unavailable.",
            reproduction_command=reproduction_command,
        )
    if method == "llm_zeroshot":
        return run_llm_zeroshot_episode(dataset, seed=seed, client=client)
    return run_llm_react_episode(dataset, seed=seed, client=client)


def run_agent_comparison(
    *,
    methods: list[str],
    datasets: list[str],
    seeds: int,
    output_json: Path,
    really_run_big_bench: bool,
    cache_root: Path | None = None,
) -> BenchmarkRunOutput:
    """Run the selected benchmark methods across real-world datasets.

    Args:
        methods: Method names to run; see ``_SUPPORTED_METHODS``.
        datasets: Dataset names to run; must be keys of ``DATASET_REGISTRY``.
        seeds: Number of seeds per (dataset, method) pair; seeds are
            ``0 .. seeds - 1``.
        output_json: Path handed to ``write_run_output`` for the run output.
        really_run_big_bench: Forwarded to ``validate_estimated_calls``
            together with the estimated LLM call count.
        cache_root: Optional cache directory forwarded to the dataset loader.

    Returns:
        The ``BenchmarkRunOutput`` holding run metadata, one record per
        (dataset, method, seed) in that nesting order, and the aggregates.

    Raises:
        ValueError: If methods, datasets, or ``seeds`` fail validation.
    """
    # Pull settings from a .env file before any environment lookups below.
    load_dotenv()
    _validate_inputs(methods, datasets, seeds)

    estimated_calls = estimate_llm_calls(methods=methods, datasets=datasets, seeds=seeds)
    validate_estimated_calls(
        estimated_calls=estimated_calls,
        really_run_big_bench=really_run_big_bench,
    )
    reproduction_command = _reproduction_command(methods, datasets, seeds)

    # Load every dataset up front so loading problems surface before any
    # episode runs.
    loaded_datasets = {
        dataset_name: load_real_world_dataset(dataset_name, cache_root=cache_root)
        for dataset_name in datasets
    }

    # Only inspect the LLM environment, and build a client, when an LLM
    # method was actually requested.
    llm_methods_requested = any(method.startswith("llm_") for method in methods)
    skip_reason = _llm_skip_reason() if llm_methods_requested else None
    client = (
        GroqBenchClient(api_key=os.environ["GROQ_API_KEY"])
        if llm_methods_requested and skip_reason is None
        else None
    )

    records: list[SeedBenchmarkResult] = []
    for dataset_name in datasets:
        dataset = loaded_datasets[dataset_name]
        for method in methods:
            for seed in range(seeds):
                result = _run_episode(
                    method=method,
                    dataset_name=dataset_name,
                    dataset=dataset,
                    seed=seed,
                    client=client,
                    skip_reason=skip_reason,
                    reproduction_command=reproduction_command,
                )

                # Normalise fields the individual runners may not fill in.
                overrides: dict[str, Any] = {}
                if result.reproduction_command != reproduction_command:
                    overrides["reproduction_command"] = reproduction_command
                if method == "heuristic":
                    # Presumably the heuristic runner does not echo the
                    # requested seed on its result -- TODO confirm.
                    overrides["seed"] = seed
                if overrides:
                    result = result.model_copy(update=overrides)

                records.append(result)

    aggregates: list[AggregateBenchmarkResult] = aggregate_seed_results(
        records, seeds_requested=seeds
    )
    output = BenchmarkRunOutput(
        metadata={
            "methods": methods,
            "datasets": datasets,
            "seeds": seeds,
            "reproduction_command": reproduction_command,
        },
        records=records,
        aggregates=aggregates,
    )
    write_run_output(output, output_json)
    return output