#!/usr/bin/env python3
"""
DR-Bench CLI - Run a benchmark for an agent.

Usage:
    python scripts/run_benchmark.py --agent agents/baseline_agent.py
    python scripts/run_benchmark.py --agent agents/iterative_agent.py --prompts data/phi_prompts/
    python scripts/run_benchmark.py --agent path/to/my_agent.py --limit 10

Environment variables (or .env file):
    LLM_API_KEY             - API key for the LLM provider
    LLM_MODEL               - Model name (default: gpt-4o)
    LLM_PROVIDER            - Provider: openai, azure, custom (default: openai)
    LLM_BASE_URL            - API base URL (default: https://api.openai.com/v1)
    BENCHMARK_WEBSEARCH_URL - WebSearch service URL (default: http://localhost:8002)
"""

import argparse
import asyncio
import logging
import os
import sys
from pathlib import Path

# Add project root to sys.path so the `benchmark` package imports resolve
# regardless of the working directory
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from dotenv import load_dotenv

# Load .env early so environment-derived defaults (LLM_*, BENCHMARK_*) are
# visible to everything below
load_dotenv(PROJECT_ROOT / ".env")

from benchmark.llm import LLMConfig
from benchmark.prompts import load_prompts
from benchmark.runner import BenchmarkRunner, load_agent_from_file

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%H:%M:%S",
)
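# With this format, logger.info("Loaded 25 prompts") renders as, e.g.:
#   14:03:07 | INFO     | Loaded 25 prompts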
logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser(
        description="DR-Bench: Run a benchmark for a research agent",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Required
    parser.add_argument(
        "--agent", required=True,
        help="Path to the agent Python file (must contain a BaseResearchAgent subclass)",
    )

    # Prompts
    parser.add_argument(
        "--prompts", nargs="*",
        default=[str(PROJECT_ROOT / "data" / "phi_prompts")],
        help="Path(s) to prompt JSON files or directories",
    )

    # LLM config
    parser.add_argument("--model", default=None, help="Override LLM model name")
    parser.add_argument("--provider", default=None, help="LLM provider: openai, azure, custom")
    parser.add_argument("--api-key", default=None, help="LLM API key (or set LLM_API_KEY env)")

    # WebSearch
    parser.add_argument(
        "--websearch-url",
        default=os.getenv("BENCHMARK_WEBSEARCH_URL", "http://localhost:8002"),
        help="WebSearch service URL",
    )
    parser.add_argument(
        "--websearch-provider", default="brightdata",
        help="Search provider (default: brightdata)",
    )
    parser.add_argument(
        "--max-search-results", type=int, default=5,
        help="Max results per search query",
    )

    # Execution
    parser.add_argument("--concurrent", type=int, default=10, help="Max concurrent prompts")
    parser.add_argument("--timeout", type=float, default=300.0, help="Timeout per prompt (seconds)")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of prompts")

    # Output
    parser.add_argument("-o", "--output", default=None, help="Output results JSON path")
    parser.add_argument("--checkpoint", default=None, help="Checkpoint file path")
    parser.add_argument("--no-resume", action="store_true", help="Don't resume from checkpoint")
    parser.add_argument(
        "--update-leaderboard", action="store_true",
        help="Update data/leaderboard.json with results",
    )

    # Evaluation
    parser.add_argument(
        "--evaluate", action="store_true",
        help="Run coverage evaluation after benchmark (requires ground truth data and eval LLM)",
    )
    parser.add_argument(
        "--no-evaluate", action="store_true",
        help="Skip evaluation even if ground truth is available",
    )
    parser.add_argument(
        "--ground-truth-dir",
        default=str(PROJECT_ROOT / "data" / "ground_truth"),
        help="Path to ground truth pitch points directory",
    )
    parser.add_argument(
        "--eval-workers", type=int, default=10,
        help="Max parallel workers for evaluation (default: 10)",
    )

    args = parser.parse_args()

    # --- Load LLM config (CLI flags take precedence over env/.env values) ---
    llm_config = LLMConfig.from_env()
    if args.model:
        llm_config.model = args.model
    if args.provider:
        llm_config.provider = args.provider
    if args.api_key:
        llm_config.api_key = args.api_key

    if not llm_config.api_key:
        print("ERROR: No LLM API key provided. Set LLM_API_KEY or use --api-key.", file=sys.stderr)
        sys.exit(1)

    # --- Load agent ---
    logger.info(f"Loading agent from {args.agent}")
    agent = load_agent_from_file(args.agent, model_name=llm_config.model)
    logger.info(f"Agent loaded: {agent.name} by {agent.author}")

    # --- Load prompts ---
    prompt_files = []
    for p in args.prompts:
        path = Path(p)
        if path.is_dir():
            prompt_files.extend(sorted(path.glob("*.json")))
        elif path.exists():
            prompt_files.append(path)
        else:
            logger.warning(f"Prompt path not found: {p}")

    if not prompt_files:
        print("ERROR: No prompt files found. Provide --prompts with valid paths.", file=sys.stderr)
        sys.exit(1)

    prompts = load_prompts([str(f) for f in prompt_files])
    if args.limit:
        prompts = prompts[:args.limit]
    logger.info(f"Loaded {len(prompts)} prompts from {len(prompt_files)} file(s)")

    # --- Setup checkpoint ---
    checkpoint = args.checkpoint
    if checkpoint is None and not args.no_resume:
        safe_name = "".join(c if c.isalnum() or c == "-" else "_" for c in agent.name)
        checkpoint = str(PROJECT_ROOT / "data" / "results" / f"{safe_name}_checkpoint.json")
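    # e.g. an agent named "My Agent v1.0" checkpoints to
    # data/results/My_Agent_v1_0_checkpoint.json (every character other than
    # alphanumerics and "-" becomes "_")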

    # --- Create runner ---
    runner = BenchmarkRunner(
        agent=agent,
        llm_config=llm_config,
        websearch_url=args.websearch_url,
        websearch_provider=args.websearch_provider,
        max_search_results=args.max_search_results,
        max_concurrent=args.concurrent,
        timeout_per_prompt=args.timeout,
        output_dir=str(PROJECT_ROOT / "data" / "results"),
        checkpoint_file=checkpoint if not args.no_resume else None,
    )

    # --- Print config ---
    print("\n DR-Bench Runner")
    print(f" {'='*50}")
    print(f" Agent: {agent.name}")
    print(f" Author: {agent.author}")
    print(f" Model: {llm_config.model}")
    print(f" Provider: {llm_config.provider}")
    print(f" Prompts: {len(prompts)}")
    print(f" Concurrent: {args.concurrent}")
    print(f" WebSearch: {args.websearch_url}")
    print(f" {'='*50}\n")

    # --- Run ---
    try:
        asyncio.run(runner.run(prompts, resume=not args.no_resume))
    except KeyboardInterrupt:
        logger.warning("Interrupted. Saving progress...")
    finally:
        # Run evaluation if requested
        if args.evaluate and not args.no_evaluate:
            gt_dir = Path(args.ground_truth_dir)
            if gt_dir.exists() and any(gt_dir.glob("*.json")):
                print("\n Running coverage evaluation...")
                print(f" Ground truth: {gt_dir}")
                runner.evaluate_results(
                    prompts=prompts,
                    ground_truth_dir=str(gt_dir),
                    max_workers=args.eval_workers,
                )
            else:
                logger.warning(
                    f"Ground truth directory not found or empty: {gt_dir}. "
                    "Skipping evaluation."
                )

        # Save results
        output_path = runner.save_results(args.output)
        logger.info(f"Results saved to {output_path}")
        runner.print_summary()

        # Update leaderboard if requested
        if args.update_leaderboard:
            leaderboard_path = str(PROJECT_ROOT / "data" / "leaderboard.json")
            runner.update_leaderboard(leaderboard_path)
            print(f"\n Leaderboard updated: {leaderboard_path}")

        # Cleanup
        asyncio.run(runner.cleanup())


if __name__ == "__main__":
    main()