Pranesh
deploy: sync staged DataForge Space
66b1c50
"""CLI subcommand: ``dataforge bench``."""
from __future__ import annotations
from pathlib import Path
from typing import Annotated
import typer
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from dataforge.bench.runner import run_agent_comparison
# Console bound to stderr; ``bench`` prints its "Benchmark Error" panel through it.
_console = Console(stderr=True)
def _parse_csv_list(raw_value: str) -> list[str]:
"""Parse a comma-separated CLI option into a list of strings."""
values = [item.strip() for item in raw_value.split(",")]
return [value for value in values if value]
def bench(
    methods: Annotated[
        str,
        typer.Option(
            "--methods",
            help="Comma-separated benchmark methods.",
        ),
    ] = "heuristic,llm_zeroshot",
    datasets: Annotated[
        str,
        typer.Option(
            "--datasets",
            help="Comma-separated benchmark datasets.",
        ),
    ] = "hospital",
    seeds: Annotated[
        int,
        typer.Option("--seeds", help="Number of seeds per method/dataset pair."),
    ] = 3,
    really_run_big_bench: Annotated[
        bool,
        typer.Option(
            "--really-run-big-bench",
            help="Override the free-tier benchmark quota guard when estimated calls exceed 500.",
        ),
    ] = False,
    output_json: Annotated[
        Path,
        typer.Option(
            "--output-json",
            help="Where to write eval/results/agent_comparison.json.",
        ),
    ] = Path("eval/results/agent_comparison.json"),
) -> None:
    """Run real-world benchmark methods across cached benchmark datasets."""
    # NOTE: the docstring above is kept to one line on purpose -- presumably
    # this function is registered as a Typer command elsewhere, in which case
    # the docstring is shown verbatim as the command's --help text.
    #
    # Flow:
    #   1. Split the comma-separated ``methods`` / ``datasets`` options and hand
    #      everything to ``run_agent_comparison``.
    #   2. On any failure, print a red "Benchmark Error" panel to stderr and
    #      exit with status 2.
    #   3. Otherwise print one summary row per aggregate to stdout, followed by
    #      a yellow warning panel if any aggregate has status "skipped".
    try:
        output = run_agent_comparison(
            methods=_parse_csv_list(methods),
            datasets=_parse_csv_list(datasets),
            seeds=seeds,
            output_json=output_json,
            really_run_big_bench=really_run_big_bench,
        )
    except Exception as exc:
        # Broad catch is deliberate: this is the CLI boundary, and every
        # failure is reported to the user and turned into exit code 2.
        # Local import: only needed on this error path.
        from rich.markup import escape

        # Escape the message before embedding it in Rich markup. Without this,
        # text such as "[foo]" is swallowed as a style tag and an unmatched
        # "[/...]" raises MarkupError, hiding the real error.
        _console.print(
            Panel(
                f"[bold red]{escape(str(exc))}[/bold red]",
                title="Benchmark Error",
                style="red",
            )
        )
        raise typer.Exit(code=2) from exc

    table = Table(title="DataForge Benchmark Summary")
    for heading in ("Method", "Dataset", "Status", "F1", "Avg Steps", "Quota"):
        table.add_column(heading)

    for aggregate in output.aggregates:
        # A metric of None is rendered as the literal cell text "Skipped".
        table.add_row(
            aggregate.method,
            aggregate.dataset,
            aggregate.status,
            "Skipped" if aggregate.f1_mean is None else f"{aggregate.f1_mean:.4f}",
            "Skipped" if aggregate.avg_steps_mean is None else f"{aggregate.avg_steps_mean:.2f}",
            "Skipped"
            if aggregate.quota_units_mean is None
            else f"{aggregate.quota_units_mean:.4f}",
        )

    # One stdout console is reused for both the table and the optional warning.
    stdout_console = Console()
    stdout_console.print(table)
    if any(aggregate.status == "skipped" for aggregate in output.aggregates):
        stdout_console.print(
            Panel(
                "Some LLM baselines were skipped. Set DATAFORGE_LLM_PROVIDER=groq and GROQ_API_KEY to enable them.",
                title="Benchmark Warning",
                style="yellow",
            )
        )