"""Top-level benchmark orchestration for agent comparison runs.""" from __future__ import annotations import os import sys from pathlib import Path from dotenv import load_dotenv from dataforge.bench.core import ( AggregateBenchmarkResult, BenchmarkRunOutput, SeedBenchmarkResult, aggregate_seed_results, build_benchmark_metadata, build_seed_list, dataset_evidence_from_loaded, estimate_llm_calls, validate_estimated_calls, write_run_output, ) from dataforge.bench.groq_client import GroqBenchClient from dataforge.bench.methods import ( run_heuristic_episode, run_llm_react_episode, run_llm_zeroshot_episode, run_random_episode, ) from dataforge.datasets.real_world import load_real_world_dataset from dataforge.datasets.registry import DATASET_REGISTRY _SUPPORTED_METHODS = frozenset({"random", "heuristic", "llm_zeroshot", "llm_react"}) def _validate_inputs(methods: list[str], datasets: list[str]) -> None: """Validate user-selected methods and datasets.""" unknown_methods = sorted(set(methods) - _SUPPORTED_METHODS) unknown_datasets = sorted(set(datasets) - set(DATASET_REGISTRY)) if unknown_methods: raise ValueError(f"Unknown benchmark methods: {unknown_methods}") if unknown_datasets: raise ValueError(f"Unknown benchmark datasets: {unknown_datasets}") def _reproduction_command( methods: list[str], datasets: list[str], *, seed_count: int, seed_list: list[int] | None, ) -> str: """Build the canonical command for reproducing a benchmark run.""" command = ( "dataforge bench " f"--methods {','.join(methods)} " f"--datasets {','.join(datasets)} " f"--seeds {seed_count}" ) if seed_list is not None: command += f" --seed-list {','.join(str(seed) for seed in seed_list)}" return command def _llm_skip_reason() -> str | None: """Return a skip reason when LLM methods cannot run.""" provider = os.environ.get("DATAFORGE_LLM_PROVIDER", "").strip().lower() if provider != "groq": return "DATAFORGE_LLM_PROVIDER must be set to groq." if not os.environ.get("GROQ_API_KEY"): return "GROQ_API_KEY is not set." return None def _skipped_result( *, method: str, dataset: str, seed: int, reason: str, reproduction_command: str, ) -> SeedBenchmarkResult: """Build a skipped seed result with a clear reason.""" return SeedBenchmarkResult( method=method, dataset=dataset, seed=seed, status="skipped", skip_reason=reason, llm_calls=0, prompt_tokens=0, completion_tokens=0, quota_units=0.0, runtime_s=0.0, provider=None, model=None, warnings=["provider_unset"], reproduction_command=reproduction_command, ) def run_agent_comparison( *, methods: list[str], datasets: list[str], seeds: int, output_json: Path, really_run_big_bench: bool, cache_root: Path | None = None, reproduction_command: str | None = None, seed_list: list[int] | None = None, verify_dataset_hashes: bool = True, ) -> BenchmarkRunOutput: """Run the selected benchmark methods across real-world datasets.""" load_dotenv() _validate_inputs(methods, datasets) resolved_seed_list = build_seed_list(seeds=seeds, seed_list=seed_list) estimated_calls = estimate_llm_calls( methods=methods, datasets=datasets, seeds=len(resolved_seed_list), ) # Validate call budget before any client instantiation or dataset loads that could # trigger network access in tests with environment variables set. validate_estimated_calls( estimated_calls=estimated_calls, really_run_big_bench=really_run_big_bench, ) reproduction_command = reproduction_command or _reproduction_command( methods, datasets, seed_count=len(resolved_seed_list), seed_list=seed_list, ) records: list[SeedBenchmarkResult] = [] loaded_datasets = { dataset_name: load_real_world_dataset( dataset_name, cache_root=cache_root, verify_hashes=verify_dataset_hashes, ) for dataset_name in datasets } llm_methods_requested = any(method.startswith("llm_") for method in methods) skip_reason = _llm_skip_reason() if llm_methods_requested else None client = None if llm_methods_requested and skip_reason is None: # Allow env-driven tuning for tiny CI checks. model = os.environ.get("DATAFORGE_GROQ_MODEL", "llama-3.3-70b-versatile") try: min_interval_s = float(os.environ.get("DATAFORGE_GROQ_MIN_INTERVAL_S", "1.0")) except ValueError: min_interval_s = 1.0 try: timeout_s = float(os.environ.get("DATAFORGE_GROQ_TIMEOUT_S", "30")) except ValueError: timeout_s = 30.0 try: max_tokens = int(os.environ.get("DATAFORGE_GROQ_MAX_TOKENS", "256")) except ValueError: max_tokens = 256 try: max_retries = int(os.environ.get("DATAFORGE_GROQ_MAX_RETRIES", "3")) except ValueError: max_retries = 3 client = GroqBenchClient( api_key=os.environ["GROQ_API_KEY"], model=model, min_interval_s=min_interval_s, max_tokens=max_tokens, max_retries=max_retries, timeout_s=timeout_s, ) for dataset_name in datasets: dataset = loaded_datasets[dataset_name] for method in methods: for seed in resolved_seed_list: if os.environ.get("DATAFORGE_BENCH_VERBOSE"): print( f"[dataforge bench] start method={method} dataset={dataset_name} seed={seed}", file=sys.stderr, flush=True, ) if method == "random": result = run_random_episode(dataset, seed=seed) elif method == "heuristic": result = run_heuristic_episode(dataset, seed=seed) elif method == "llm_zeroshot": if client is None or skip_reason is not None: result = _skipped_result( method=method, dataset=dataset_name, seed=seed, reason=skip_reason or "LLM client unavailable.", reproduction_command=reproduction_command, ) else: result = run_llm_zeroshot_episode(dataset, seed=seed, client=client) else: if client is None or skip_reason is not None: result = _skipped_result( method=method, dataset=dataset_name, seed=seed, reason=skip_reason or "LLM client unavailable.", reproduction_command=reproduction_command, ) else: result = run_llm_react_episode(dataset, seed=seed, client=client) if result.reproduction_command != reproduction_command: result = result.model_copy( update={"reproduction_command": reproduction_command} ) if method == "heuristic": result = result.model_copy(update={"seed": seed}) records.append(result) if os.environ.get("DATAFORGE_BENCH_VERBOSE"): print( f"[dataforge bench] done method={method} dataset={dataset_name} seed={seed} status={result.status}", file=sys.stderr, flush=True, ) aggregates: list[AggregateBenchmarkResult] = aggregate_seed_results( records, seeds_requested=len(resolved_seed_list) ) dataset_evidence = [ dataset_evidence_from_loaded(loaded_datasets[dataset_name]) for dataset_name in datasets ] metadata = build_benchmark_metadata( methods=methods, datasets=datasets, seed_list=resolved_seed_list, reproduction_command=reproduction_command, dataset_evidence=dataset_evidence, ) output = BenchmarkRunOutput( metadata=metadata.model_dump(mode="json"), records=records, aggregates=aggregates, ) write_run_output(output, output_json) return output