#!/usr/bin/env python3
"""
DR-Bench CLI - Run a benchmark for an agent.

Usage:
    python scripts/run_benchmark.py --agent agents/baseline_agent.py
    python scripts/run_benchmark.py --agent agents/iterative_agent.py --prompts data/phi_prompts/
    python scripts/run_benchmark.py --agent path/to/my_agent.py --limit 10

Environment variables (or .env file):
    LLM_API_KEY             - API key for the LLM provider
    LLM_MODEL               - Model name (default: gpt-4o)
    LLM_PROVIDER            - Provider: openai, azure, custom (default: openai)
    LLM_BASE_URL            - API base URL (default: https://api.openai.com/v1)
    BENCHMARK_WEBSEARCH_URL - WebSearch service URL (default: http://localhost:8002)
"""
import argparse
import asyncio
import logging
import os
import sys
from pathlib import Path
# Add project root to path
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
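
# Load the project-root .env (if present) before the benchmark imports below,
# so that LLMConfig.from_env() can pick up variables defined there.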
from dotenv import load_dotenv
load_dotenv(PROJECT_ROOT / ".env")
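
# Example .env (illustrative placeholder values, not real credentials):
#
#   LLM_API_KEY=sk-...
#   LLM_MODEL=gpt-4o
#   LLM_PROVIDER=openai
#   LLM_BASE_URL=https://api.openai.com/v1
#   BENCHMARK_WEBSEARCH_URL=http://localhost:8002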
from benchmark.llm import LLMConfig
from benchmark.prompts import load_prompts
from benchmark.runner import BenchmarkRunner, load_agent_from_file
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%H:%M:%S",
)
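
# With this config, log lines look like (timestamp and message illustrative):
#   14:02:31 | INFO     | Loaded 50 prompts from 3 file(s)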
logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser(
        description="DR-Bench: Run a benchmark for a research agent",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Required
    parser.add_argument(
        "--agent", required=True,
        help="Path to the agent Python file (must contain a BaseResearchAgent subclass)",
    )
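
    # Illustrative sketch (assumptions flagged) of the file passed via --agent:
    # this script only requires a BaseResearchAgent subclass and reads its
    # `name` and `author` attributes; the import path for BaseResearchAgent and
    # the actual research entrypoint are defined elsewhere in the benchmark
    # package and are not reproduced here.
    #
    #   class MyAgent(BaseResearchAgent):
    #       name = "my-agent"
    #       author = "your-name"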

    # Prompts
    parser.add_argument(
        "--prompts", nargs="*",
        default=[str(PROJECT_ROOT / "data" / "phi_prompts")],
        help="Path(s) to prompt JSON files or directories",
    )

    # LLM config
    parser.add_argument("--model", default=None, help="Override LLM model name")
    parser.add_argument("--provider", default=None, help="LLM provider: openai, azure, custom")
    parser.add_argument("--api-key", default=None, help="LLM API key (or set LLM_API_KEY env)")

    # WebSearch
    parser.add_argument(
        "--websearch-url",
        default=os.getenv("BENCHMARK_WEBSEARCH_URL", "http://localhost:8002"),
        help="WebSearch service URL",
    )
    parser.add_argument(
        "--websearch-provider", default="brightdata",
        help="Search provider (default: brightdata)",
    )
    parser.add_argument(
        "--max-search-results", type=int, default=5,
        help="Max results per search query",
    )

    # Execution
    parser.add_argument("--concurrent", type=int, default=10, help="Max concurrent prompts")
    parser.add_argument("--timeout", type=float, default=300.0, help="Timeout per prompt (seconds)")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of prompts")

    # Output
    parser.add_argument("-o", "--output", default=None, help="Output results JSON path")
    parser.add_argument("--checkpoint", default=None, help="Checkpoint file path")
    parser.add_argument("--no-resume", action="store_true", help="Don't resume from checkpoint")
    parser.add_argument(
        "--update-leaderboard", action="store_true",
        help="Update data/leaderboard.json with results",
    )

    # Evaluation
    parser.add_argument(
        "--evaluate", action="store_true",
        help="Run coverage evaluation after benchmark (requires ground truth data and eval LLM)",
    )
    parser.add_argument(
        "--no-evaluate", action="store_true",
        help="Skip evaluation even if ground truth is available",
    )
    parser.add_argument(
        "--ground-truth-dir",
        default=str(PROJECT_ROOT / "data" / "ground_truth"),
        help="Path to ground truth pitch points directory",
    )
    parser.add_argument(
        "--eval-workers", type=int, default=10,
        help="Max parallel workers for evaluation (default: 10)",
    )

    args = parser.parse_args()

    # --- Load LLM config ---
    llm_config = LLMConfig.from_env()
    if args.model:
        llm_config.model = args.model
    if args.provider:
        llm_config.provider = args.provider
    if args.api_key:
        llm_config.api_key = args.api_key
    if not llm_config.api_key:
        print("ERROR: No LLM API key provided. Set LLM_API_KEY or use --api-key.")
        sys.exit(1)

    # --- Load agent ---
    logger.info(f"Loading agent from {args.agent}")
    agent = load_agent_from_file(args.agent, model_name=llm_config.model)
    logger.info(f"Agent loaded: {agent.name} by {agent.author}")

    # --- Load prompts ---
    prompt_files = []
    for p in args.prompts:
        path = Path(p)
        if path.is_dir():
            prompt_files.extend(sorted(path.glob("*.json")))
        elif path.exists():
            prompt_files.append(path)
        else:
            logger.warning(f"Prompt path not found: {p}")
    if not prompt_files:
        print("ERROR: No prompt files found. Provide --prompts with valid paths.")
        sys.exit(1)
    prompts = load_prompts([str(f) for f in prompt_files])
    if args.limit:
        prompts = prompts[:args.limit]
    logger.info(f"Loaded {len(prompts)} prompts from {len(prompt_files)} file(s)")

    # --- Setup checkpoint ---
    checkpoint = args.checkpoint
    if checkpoint is None and not args.no_resume:
        safe_name = "".join(c if c.isalnum() or c == "-" else "_" for c in agent.name)
        checkpoint = str(PROJECT_ROOT / "data" / "results" / f"{safe_name}_checkpoint.json")
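    # e.g. an agent named "My Agent v2.0" checkpoints to
    # data/results/My_Agent_v2_0_checkpoint.json (alphanumerics and "-" are
    # kept; every other character becomes "_").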

    # --- Create runner ---
    runner = BenchmarkRunner(
        agent=agent,
        llm_config=llm_config,
        websearch_url=args.websearch_url,
        websearch_provider=args.websearch_provider,
        max_search_results=args.max_search_results,
        max_concurrent=args.concurrent,
        timeout_per_prompt=args.timeout,
        output_dir=str(PROJECT_ROOT / "data" / "results"),
        checkpoint_file=checkpoint if not args.no_resume else None,
    )
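
    # Results and the checkpoint both default to data/results/ (see output_dir
    # and the checkpoint default above); -o/--output and --checkpoint redirect them.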

    # --- Print config ---
    print("\n DR-Bench Runner")
    print(f" {'='*50}")
    print(f" Agent:      {agent.name}")
    print(f" Author:     {agent.author}")
    print(f" Model:      {llm_config.model}")
    print(f" Provider:   {llm_config.provider}")
    print(f" Prompts:    {len(prompts)}")
    print(f" Concurrent: {args.concurrent}")
    print(f" WebSearch:  {args.websearch_url}")
    print(f" {'='*50}\n")

    # --- Run ---
    try:
        asyncio.run(runner.run(prompts, resume=not args.no_resume))
    except KeyboardInterrupt:
        logger.warning("Interrupted. Saving progress...")
    finally:
        # Run evaluation if requested
        if args.evaluate and not args.no_evaluate:
            gt_dir = Path(args.ground_truth_dir)
            if gt_dir.exists() and any(gt_dir.glob("*.json")):
                print("\n Running coverage evaluation...")
                print(f" Ground truth: {gt_dir}")
                runner.evaluate_results(
                    prompts=prompts,
                    ground_truth_dir=str(gt_dir),
                    max_workers=args.eval_workers,
                )
            else:
                logger.warning(
                    f"Ground truth directory not found or empty: {gt_dir}. "
                    "Skipping evaluation."
                )

        # Save results
        runner.save_results(args.output)
        runner.print_summary()

        # Update leaderboard if requested
        if args.update_leaderboard:
            leaderboard_path = str(PROJECT_ROOT / "data" / "leaderboard.json")
            runner.update_leaderboard(leaderboard_path)
            print(f"\n Leaderboard updated: {leaderboard_path}")

        # Cleanup
        asyncio.run(runner.cleanup())


if __name__ == "__main__":
    main()
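
# Example end-to-end invocation (all flags are defined above; paths illustrative):
#   python scripts/run_benchmark.py --agent agents/baseline_agent.py \
#       --limit 5 --evaluate --update-leaderboard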