Praneshrajan15's picture
Deploy DataForge playground API
791c076 verified
"""Top-level benchmark orchestration for agent comparison runs."""
from __future__ import annotations
import os
import sys
from pathlib import Path
from dotenv import load_dotenv
from dataforge.bench.core import (
AggregateBenchmarkResult,
BenchmarkRunOutput,
SeedBenchmarkResult,
aggregate_seed_results,
build_benchmark_metadata,
build_seed_list,
dataset_evidence_from_loaded,
estimate_llm_calls,
validate_estimated_calls,
write_run_output,
)
from dataforge.bench.groq_client import GroqBenchClient
from dataforge.bench.methods import (
run_heuristic_episode,
run_llm_react_episode,
run_llm_zeroshot_episode,
run_random_episode,
)
from dataforge.datasets.real_world import load_real_world_dataset
from dataforge.datasets.registry import DATASET_REGISTRY
# Benchmark method names this module knows how to dispatch. `_validate_inputs`
# rejects any requested method that is not in this set.
_SUPPORTED_METHODS = frozenset({"random", "heuristic", "llm_zeroshot", "llm_react"})
def _validate_inputs(methods: list[str], datasets: list[str]) -> None:
    """Reject method or dataset names the benchmark does not recognise.

    Methods are compared against ``_SUPPORTED_METHODS`` and datasets against
    the names obtained by iterating ``DATASET_REGISTRY``. Both sets of unknown
    names are computed up front; unknown methods are reported first.

    Args:
        methods: Requested benchmark method names.
        datasets: Requested dataset names.

    Raises:
        ValueError: If any method or dataset name is unknown. The message
            lists the offending names in sorted order.
    """
    bad_methods = sorted(set(methods).difference(_SUPPORTED_METHODS))
    bad_datasets = sorted(set(datasets).difference(set(DATASET_REGISTRY)))

    if bad_methods:
        raise ValueError(f"Unknown benchmark methods: {bad_methods}")
    if bad_datasets:
        raise ValueError(f"Unknown benchmark datasets: {bad_datasets}")
def _reproduction_command(
methods: list[str],
datasets: list[str],
*,
seed_count: int,
seed_list: list[int] | None,
) -> str:
"""Build the canonical command for reproducing a benchmark run."""
command = (
"dataforge bench "
f"--methods {','.join(methods)} "
f"--datasets {','.join(datasets)} "
f"--seeds {seed_count}"
)
if seed_list is not None:
command += f" --seed-list {','.join(str(seed) for seed in seed_list)}"
return command
def _llm_skip_reason() -> str | None:
"""Return a skip reason when LLM methods cannot run."""
provider = os.environ.get("DATAFORGE_LLM_PROVIDER", "").strip().lower()
if provider != "groq":
return "DATAFORGE_LLM_PROVIDER must be set to groq."
if not os.environ.get("GROQ_API_KEY"):
return "GROQ_API_KEY is not set."
return None
def _skipped_result(
    *,
    method: str,
    dataset: str,
    seed: int,
    reason: str,
    reproduction_command: str,
) -> SeedBenchmarkResult:
    """Create the record for a seed that was skipped rather than executed.

    Args:
        method: Benchmark method name the seed belonged to.
        dataset: Dataset name the seed belonged to.
        seed: The seed that was skipped.
        reason: Explanation stored in ``skip_reason``.
        reproduction_command: Command stored on the record for reproduction.

    Returns:
        A ``SeedBenchmarkResult`` with ``status="skipped"``, zeroed usage
        counters, no provider or model, and a ``provider_unset`` warning.
    """
    # Nothing ran, so every usage and timing counter is zero.
    zero_usage = {
        "llm_calls": 0,
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "quota_units": 0.0,
        "runtime_s": 0.0,
    }
    return SeedBenchmarkResult(
        method=method,
        dataset=dataset,
        seed=seed,
        status="skipped",
        skip_reason=reason,
        provider=None,
        model=None,
        warnings=["provider_unset"],
        reproduction_command=reproduction_command,
        **zero_usage,
    )
def _env_float(name: str, default: float) -> float:
    """Read ``name`` from the environment as a float, or return ``default``."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError:
        # Unparseable overrides are ignored rather than aborting the run.
        return default


def _env_int(name: str, default: int) -> int:
    """Read ``name`` from the environment as an int, or return ``default``."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        # Unparseable overrides are ignored rather than aborting the run.
        return default


def _build_groq_client() -> GroqBenchClient:
    """Create the Groq client from ``GROQ_API_KEY`` and optional tuning variables."""
    # Allow env-driven tuning for tiny CI checks.
    return GroqBenchClient(
        api_key=os.environ["GROQ_API_KEY"],
        model=os.environ.get("DATAFORGE_GROQ_MODEL", "llama-3.3-70b-versatile"),
        min_interval_s=_env_float("DATAFORGE_GROQ_MIN_INTERVAL_S", 1.0),
        max_tokens=_env_int("DATAFORGE_GROQ_MAX_TOKENS", 256),
        max_retries=_env_int("DATAFORGE_GROQ_MAX_RETRIES", 3),
        timeout_s=_env_float("DATAFORGE_GROQ_TIMEOUT_S", 30.0),
    )


def run_agent_comparison(
    *,
    methods: list[str],
    datasets: list[str],
    seeds: int,
    output_json: Path,
    really_run_big_bench: bool,
    cache_root: Path | None = None,
    reproduction_command: str | None = None,
    seed_list: list[int] | None = None,
    verify_dataset_hashes: bool = True,
) -> BenchmarkRunOutput:
    """Run the selected benchmark methods across real-world datasets.

    The run proceeds in this order: load ``.env``, validate the requested
    names, resolve the seed list, check the estimated LLM call budget, load
    every dataset, optionally build the Groq client, execute one episode per
    (dataset, method, seed), aggregate, and write the output file.

    Args:
        methods: Benchmark method names; each must be in ``_SUPPORTED_METHODS``.
        datasets: Dataset names; each must be known to ``DATASET_REGISTRY``.
        seeds: Seed count, forwarded to ``build_seed_list``.
        output_json: Destination handed to ``write_run_output``.
        really_run_big_bench: Forwarded to ``validate_estimated_calls``.
            NOTE(review): presumably an opt-in for large call budgets - confirm.
        cache_root: Forwarded to ``load_real_world_dataset``.
        reproduction_command: Command recorded on every result and in the
            metadata. When falsy, one is generated by ``_reproduction_command``.
        seed_list: Optional explicit seeds, forwarded to ``build_seed_list``.
        verify_dataset_hashes: Forwarded to ``load_real_world_dataset`` as
            ``verify_hashes``.

    Returns:
        The ``BenchmarkRunOutput`` that was also written to ``output_json``.

    Raises:
        ValueError: If a method or dataset name is unknown. The helpers called
            here (budget validation, dataset loading, episode runners) may
            raise their own errors; those propagate unchanged.
    """
    load_dotenv()
    _validate_inputs(methods, datasets)
    resolved_seed_list = build_seed_list(seeds=seeds, seed_list=seed_list)
    estimated_calls = estimate_llm_calls(
        methods=methods,
        datasets=datasets,
        seeds=len(resolved_seed_list),
    )
    # Validate call budget before any client instantiation or dataset loads that could
    # trigger network access in tests with environment variables set.
    validate_estimated_calls(
        estimated_calls=estimated_calls,
        really_run_big_bench=really_run_big_bench,
    )
    reproduction_command = reproduction_command or _reproduction_command(
        methods,
        datasets,
        seed_count=len(resolved_seed_list),
        seed_list=seed_list,
    )

    loaded_datasets = {
        dataset_name: load_real_world_dataset(
            dataset_name,
            cache_root=cache_root,
            verify_hashes=verify_dataset_hashes,
        )
        for dataset_name in datasets
    }

    # LLM prerequisites are only checked when an LLM method was requested, so
    # purely local runs never need provider configuration.
    llm_methods_requested = any(method.startswith("llm_") for method in methods)
    skip_reason = _llm_skip_reason() if llm_methods_requested else None
    client = None
    if llm_methods_requested and skip_reason is None:
        client = _build_groq_client()

    # Read once: the flag is loop-invariant for the duration of the run.
    verbose = bool(os.environ.get("DATAFORGE_BENCH_VERBOSE"))

    records: list[SeedBenchmarkResult] = []
    for dataset_name in datasets:
        dataset = loaded_datasets[dataset_name]
        for method in methods:
            for seed in resolved_seed_list:
                if verbose:
                    print(
                        f"[dataforge bench] start method={method} dataset={dataset_name} seed={seed}",
                        file=sys.stderr,
                        flush=True,
                    )

                if method == "random":
                    result = run_random_episode(dataset, seed=seed)
                elif method == "heuristic":
                    result = run_heuristic_episode(dataset, seed=seed)
                elif client is None or skip_reason is not None:
                    # Both LLM methods are recorded as skipped in the same way
                    # when no client could be configured.
                    result = _skipped_result(
                        method=method,
                        dataset=dataset_name,
                        seed=seed,
                        reason=skip_reason or "LLM client unavailable.",
                        reproduction_command=reproduction_command,
                    )
                elif method == "llm_zeroshot":
                    result = run_llm_zeroshot_episode(dataset, seed=seed, client=client)
                else:
                    # Input validation leaves "llm_react" as the only other method.
                    result = run_llm_react_episode(dataset, seed=seed, client=client)

                # Every record carries the run-level reproduction command.
                if result.reproduction_command != reproduction_command:
                    result = result.model_copy(
                        update={"reproduction_command": reproduction_command}
                    )
                # NOTE(review): the heuristic result's seed is overwritten with the
                # requested seed; presumably the runner does not echo it - confirm.
                if method == "heuristic":
                    result = result.model_copy(update={"seed": seed})
                records.append(result)

                if verbose:
                    print(
                        f"[dataforge bench] done method={method} dataset={dataset_name} seed={seed} status={result.status}",
                        file=sys.stderr,
                        flush=True,
                    )

    aggregates: list[AggregateBenchmarkResult] = aggregate_seed_results(
        records, seeds_requested=len(resolved_seed_list)
    )
    dataset_evidence = [
        dataset_evidence_from_loaded(loaded_datasets[dataset_name]) for dataset_name in datasets
    ]
    metadata = build_benchmark_metadata(
        methods=methods,
        datasets=datasets,
        seed_list=resolved_seed_list,
        reproduction_command=reproduction_command,
        dataset_evidence=dataset_evidence,
    )
    output = BenchmarkRunOutput(
        metadata=metadata.model_dump(mode="json"),
        records=records,
        aggregates=aggregates,
    )
    write_run_output(output, output_json)
    return output