Spaces:
Sleeping
Sleeping
| """CLI subcommand: ``dataforge bench``.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Annotated | |
| import typer | |
| from rich.console import Console | |
| from rich.panel import Panel | |
| from rich.table import Table | |
| from dataforge.bench.runner import run_agent_comparison | |
# Console bound to stderr; ``bench`` prints its "Benchmark Error" panel through it.
_console = Console(stderr=True)
def _parse_csv_list(raw_value: str) -> list[str]:
    """Split a comma-delimited option value into its trimmed, non-empty entries."""
    entries = []
    for piece in raw_value.split(","):
        trimmed = piece.strip()
        # Blank segments (",," or a trailing comma) carry no entry and are dropped.
        if trimmed:
            entries.append(trimmed)
    return entries
def _format_metric(value: float | None, precision: int) -> str:
    """Format an aggregate metric to ``precision`` decimals; ``None`` renders as ``Skipped``."""
    return "Skipped" if value is None else f"{value:.{precision}f}"


def bench(
    methods: Annotated[
        str,
        typer.Option(
            "--methods",
            help="Comma-separated benchmark methods.",
        ),
    ] = "heuristic,llm_zeroshot",
    datasets: Annotated[
        str,
        typer.Option(
            "--datasets",
            help="Comma-separated benchmark datasets.",
        ),
    ] = "hospital",
    seeds: Annotated[
        int,
        typer.Option("--seeds", help="Number of seeds per method/dataset pair."),
    ] = 3,
    really_run_big_bench: Annotated[
        bool,
        typer.Option(
            "--really-run-big-bench",
            help="Override the free-tier benchmark quota guard when estimated calls exceed 500.",
        ),
    ] = False,
    output_json: Annotated[
        Path,
        typer.Option(
            "--output-json",
            help="Where to write eval/results/agent_comparison.json.",
        ),
    ] = Path("eval/results/agent_comparison.json"),
) -> None:
    """Run real-world benchmark methods across cached benchmark datasets."""
    # NOTE: the docstring is intentionally left as the original single line, because
    # Typer surfaces a command function's docstring as its ``--help`` text.
    #
    # Flow: parse the comma-separated ``methods``/``datasets`` options, hand everything
    # to ``run_agent_comparison``, then print one summary row per aggregate to stdout.
    # Any exception from the run is shown as a panel on stderr and ends the command
    # with exit code 2.
    from rich.markup import escape

    try:
        output = run_agent_comparison(
            methods=_parse_csv_list(methods),
            datasets=_parse_csv_list(datasets),
            seeds=seeds,
            output_json=output_json,
            really_run_big_bench=really_run_big_bench,
        )
    except Exception as exc:
        # CLI boundary: every failure becomes a readable panel plus a non-zero exit.
        # The message is escaped so bracketed text inside it (paths, tag-like tokens)
        # is printed literally instead of being parsed as Rich markup, which could
        # otherwise garble the panel or raise a markup error while reporting.
        _console.print(
            Panel(
                f"[bold red]{escape(str(exc))}[/bold red]",
                title="Benchmark Error",
                style="red",
            )
        )
        raise typer.Exit(code=2) from exc

    table = Table(title="DataForge Benchmark Summary")
    for heading in ("Method", "Dataset", "Status", "F1", "Avg Steps", "Quota"):
        table.add_column(heading)

    for aggregate in output.aggregates:
        table.add_row(
            aggregate.method,
            aggregate.dataset,
            aggregate.status,
            _format_metric(aggregate.f1_mean, 4),
            _format_metric(aggregate.avg_steps_mean, 2),
            _format_metric(aggregate.quota_units_mean, 4),
        )

    # Results go to stdout (errors above go to stderr); one console serves both prints.
    stdout_console = Console()
    stdout_console.print(table)

    if any(aggregate.status == "skipped" for aggregate in output.aggregates):
        stdout_console.print(
            Panel(
                "Some LLM baselines were skipped. Set DATAFORGE_LLM_PROVIDER=groq and GROQ_API_KEY to enable them.",
                title="Benchmark Warning",
                style="yellow",
            )
        )