#!/usr/bin/env python3
"""
DR-Bench CLI - Run a benchmark for an agent.

Usage:
    python scripts/run_benchmark.py --agent agents/baseline_agent.py
    python scripts/run_benchmark.py --agent agents/iterative_agent.py --prompts data/phi_prompts/
    python scripts/run_benchmark.py --agent path/to/my_agent.py --limit 10

Environment variables (or .env file):
    LLM_API_KEY              - API key for the LLM provider
    LLM_MODEL                - Model name (default: gpt-4o)
    LLM_PROVIDER             - Provider: openai, azure, custom (default: openai)
    LLM_BASE_URL             - API base URL (default: https://api.openai.com/v1)
    BENCHMARK_WEBSEARCH_URL  - WebSearch service URL (default: http://localhost:8002)
"""

import argparse
import asyncio
import logging
import os
import sys
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from dotenv import load_dotenv

load_dotenv(PROJECT_ROOT / ".env")

from benchmark.llm import LLMConfig
from benchmark.prompts import load_prompts
from benchmark.runner import BenchmarkRunner, load_agent_from_file

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser(
        description="DR-Bench: Run a benchmark for a research agent",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Required
    parser.add_argument(
        "--agent",
        required=True,
        help="Path to the agent Python file (must contain a BaseResearchAgent subclass)",
    )

    # Prompts
    parser.add_argument(
        "--prompts",
        nargs="*",
        default=[str(PROJECT_ROOT / "data" / "phi_prompts")],
        help="Path(s) to prompt JSON files or directories",
    )

    # LLM config
    parser.add_argument("--model", default=None, help="Override LLM model name")
    parser.add_argument("--provider", default=None, help="LLM provider: openai, azure, custom")
    parser.add_argument("--api-key", default=None, help="LLM API key (or set LLM_API_KEY env)")

    # WebSearch
    parser.add_argument(
        "--websearch-url",
        default=os.getenv("BENCHMARK_WEBSEARCH_URL", "http://localhost:8002"),
        help="WebSearch service URL",
    )
    parser.add_argument(
        "--websearch-provider",
        default="brightdata",
        help="Search provider (default: brightdata)",
    )
    parser.add_argument(
        "--max-search-results",
        type=int,
        default=5,
        help="Max results per search query",
    )

    # Execution
    parser.add_argument("--concurrent", type=int, default=10, help="Max concurrent prompts")
    parser.add_argument("--timeout", type=float, default=300.0, help="Timeout per prompt (seconds)")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of prompts")

    # Output
    parser.add_argument("-o", "--output", default=None, help="Output results JSON path")
    parser.add_argument("--checkpoint", default=None, help="Checkpoint file path")
    parser.add_argument("--no-resume", action="store_true", help="Don't resume from checkpoint")
    parser.add_argument(
        "--update-leaderboard",
        action="store_true",
        help="Update data/leaderboard.json with results",
    )

    # Evaluation
    parser.add_argument(
        "--evaluate",
        action="store_true",
        help="Run coverage evaluation after benchmark (requires ground truth data and eval LLM)",
    )
    parser.add_argument(
        "--no-evaluate",
        action="store_true",
        help="Skip evaluation even if ground truth is available",
    )
    parser.add_argument(
        "--ground-truth-dir",
        default=str(PROJECT_ROOT / "data" / "ground_truth"),
        help="Path to ground truth pitch points directory",
    )
    parser.add_argument(
        "--eval-workers",
        type=int,
        default=10,
        help="Max parallel workers for evaluation (default: 10)",
    )

    args = parser.parse_args()

    # --- Load LLM config ---
    llm_config = LLMConfig.from_env()
    if args.model:
        llm_config.model = args.model
    if args.provider:
        llm_config.provider = args.provider
    if args.api_key:
        llm_config.api_key = args.api_key

    if not llm_config.api_key:
        print("ERROR: No LLM API key provided. Set LLM_API_KEY or use --api-key.")
        sys.exit(1)

    # --- Load agent ---
    logger.info(f"Loading agent from {args.agent}")
    agent = load_agent_from_file(args.agent, model_name=llm_config.model)
    logger.info(f"Agent loaded: {agent.name} by {agent.author}")

    # --- Load prompts ---
    prompt_files = []
    for p in args.prompts:
        path = Path(p)
        if path.is_dir():
            prompt_files.extend(sorted(path.glob("*.json")))
        elif path.exists():
            prompt_files.append(path)
        else:
            logger.warning(f"Prompt path not found: {p}")

    if not prompt_files:
        print("ERROR: No prompt files found. Provide --prompts with valid paths.")
        sys.exit(1)

    prompts = load_prompts([str(f) for f in prompt_files])
    if args.limit:
        prompts = prompts[:args.limit]
    logger.info(f"Loaded {len(prompts)} prompts from {len(prompt_files)} file(s)")

    # --- Setup checkpoint ---
    checkpoint = args.checkpoint
    if checkpoint is None and not args.no_resume:
        safe_name = "".join(c if c.isalnum() or c == "-" else "_" for c in agent.name)
        checkpoint = str(PROJECT_ROOT / "data" / "results" / f"{safe_name}_checkpoint.json")

    # --- Create runner ---
    runner = BenchmarkRunner(
        agent=agent,
        llm_config=llm_config,
        websearch_url=args.websearch_url,
        websearch_provider=args.websearch_provider,
        max_search_results=args.max_search_results,
        max_concurrent=args.concurrent,
        timeout_per_prompt=args.timeout,
        output_dir=str(PROJECT_ROOT / "data" / "results"),
        checkpoint_file=checkpoint if not args.no_resume else None,
    )

    # --- Print config ---
    print(f"\n DR-Bench Runner")
    print(f" {'='*50}")
    print(f" Agent:      {agent.name}")
    print(f" Author:     {agent.author}")
    print(f" Model:      {llm_config.model}")
    print(f" Provider:   {llm_config.provider}")
    print(f" Prompts:    {len(prompts)}")
    print(f" Concurrent: {args.concurrent}")
    print(f" WebSearch:  {args.websearch_url}")
    print(f" {'='*50}\n")

    # --- Run ---
    try:
        asyncio.run(runner.run(prompts, resume=not args.no_resume))
    except KeyboardInterrupt:
        logger.warning("Interrupted. Saving progress...")
    finally:
        # Run evaluation if requested
        if args.evaluate and not args.no_evaluate:
            gt_dir = Path(args.ground_truth_dir)
            if gt_dir.exists() and any(gt_dir.glob("*.json")):
                print(f"\n Running coverage evaluation...")
                print(f" Ground truth: {gt_dir}")
                runner.evaluate_results(
                    prompts=prompts,
                    ground_truth_dir=str(gt_dir),
                    max_workers=args.eval_workers,
                )
            else:
                logger.warning(
                    f"Ground truth directory not found or empty: {gt_dir}. "
                    "Skipping evaluation."
                )

        # Save results
        output_path = runner.save_results(args.output)
        runner.print_summary()

        # Update leaderboard if requested
        if args.update_leaderboard:
            leaderboard_path = str(PROJECT_ROOT / "data" / "leaderboard.json")
            runner.update_leaderboard(leaderboard_path)
            print(f"\n Leaderboard updated: {leaderboard_path}")

        # Cleanup
        asyncio.run(runner.cleanup())


if __name__ == "__main__":
    main()
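
# --- Example agent file (hypothetical sketch, not part of this script) ---
# --agent expects a Python file defining a BaseResearchAgent subclass, which
# load_agent_from_file() instantiates with model_name=<resolved LLM model>.
# Only the attributes this script actually reads (name, author) are shown;
# the import path and the research entrypoint required by BaseResearchAgent
# are assumptions and should be taken from benchmark/runner.py in this repo.
#
#     from benchmark.runner import BaseResearchAgent  # import path is an assumption
#
#     class MyAgent(BaseResearchAgent):
#         name = "my-agent"
#         author = "Your Name"
#         # ...implement the abstract method(s) BaseResearchAgent requires...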