"""CLI subcommand: ``dataforge bench``.""" from __future__ import annotations from pathlib import Path from typing import Annotated import typer from rich.console import Console from rich.panel import Panel from rich.table import Table from dataforge.bench.runner import run_agent_comparison _console = Console(stderr=True) def _parse_csv_list(raw_value: str) -> list[str]: """Parse a comma-separated CLI option into a list of strings.""" values = [item.strip() for item in raw_value.split(",")] return [value for value in values if value] def bench( methods: Annotated[ str, typer.Option( "--methods", help="Comma-separated benchmark methods.", ), ] = "heuristic,llm_zeroshot", datasets: Annotated[ str, typer.Option( "--datasets", help="Comma-separated benchmark datasets.", ), ] = "hospital", seeds: Annotated[ int, typer.Option("--seeds", help="Number of seeds per method/dataset pair."), ] = 3, really_run_big_bench: Annotated[ bool, typer.Option( "--really-run-big-bench", help="Override the free-tier benchmark quota guard when estimated calls exceed 500.", ), ] = False, output_json: Annotated[ Path, typer.Option( "--output-json", help="Where to write eval/results/agent_comparison.json.", ), ] = Path("eval/results/agent_comparison.json"), ) -> None: """Run real-world benchmark methods across cached benchmark datasets.""" try: output = run_agent_comparison( methods=_parse_csv_list(methods), datasets=_parse_csv_list(datasets), seeds=seeds, output_json=output_json, really_run_big_bench=really_run_big_bench, ) except Exception as exc: _console.print( Panel( f"[bold red]{exc}[/bold red]", title="Benchmark Error", style="red", ) ) raise typer.Exit(code=2) from exc table = Table(title="DataForge Benchmark Summary") table.add_column("Method") table.add_column("Dataset") table.add_column("Status") table.add_column("F1") table.add_column("Avg Steps") table.add_column("Quota") for aggregate in output.aggregates: table.add_row( aggregate.method, aggregate.dataset, aggregate.status, "Skipped" if aggregate.f1_mean is None else f"{aggregate.f1_mean:.4f}", "Skipped" if aggregate.avg_steps_mean is None else f"{aggregate.avg_steps_mean:.2f}", "Skipped" if aggregate.quota_units_mean is None else f"{aggregate.quota_units_mean:.4f}", ) Console().print(table) if any(aggregate.status == "skipped" for aggregate in output.aggregates): Console().print( Panel( "Some LLM baselines were skipped. Set DATAFORGE_LLM_PROVIDER=groq and GROQ_API_KEY to enable them.", title="Benchmark Warning", style="yellow", ) )