#!/usr/bin/env python3
"""
DR-Bench CLI - Run a benchmark for an agent.

Usage:
    python scripts/run_benchmark.py --agent agents/baseline_agent.py
    python scripts/run_benchmark.py --agent agents/iterative_agent.py --prompts data/phi_prompts/
    python scripts/run_benchmark.py --agent path/to/my_agent.py --limit 10

Environment variables (or .env file):
    LLM_API_KEY         - API key for the LLM provider
    LLM_MODEL           - Model name (default: gpt-4o)
    LLM_PROVIDER        - Provider: openai, azure, custom (default: openai)
    LLM_BASE_URL        - API base URL (default: https://api.openai.com/v1)
    BENCHMARK_WEBSEARCH_URL - WebSearch service URL (default: http://localhost:8002)
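
Example .env (values below are illustrative placeholders, not required settings):
    LLM_API_KEY=sk-your-key-here
    LLM_MODEL=gpt-4o
    LLM_PROVIDER=openai
    LLM_BASE_URL=https://api.openai.com/v1
    BENCHMARK_WEBSEARCH_URL=http://localhost:8002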
"""

import argparse
import asyncio
import logging
import os
import sys
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

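# Load environment variables from a .env file at the project root, if present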
from dotenv import load_dotenv
load_dotenv(PROJECT_ROOT / ".env")

from benchmark.llm import LLMConfig
from benchmark.prompts import load_prompts
from benchmark.runner import BenchmarkRunner, load_agent_from_file


logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser(
        description="DR-Bench: Run a benchmark for a research agent",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Required
    parser.add_argument(
        "--agent", required=True,
        help="Path to the agent Python file (must contain a BaseResearchAgent subclass)",
    )

    # Prompts
    parser.add_argument(
        "--prompts", nargs="*",
        default=[str(PROJECT_ROOT / "data" / "phi_prompts")],
        help="Path(s) to prompt JSON files or directories",
    )

    # LLM config
    parser.add_argument("--model", default=None, help="Override LLM model name")
    parser.add_argument("--provider", default=None, help="LLM provider: openai, azure, custom")
    parser.add_argument("--api-key", default=None, help="LLM API key (or set LLM_API_KEY env)")

    # WebSearch
    parser.add_argument(
        "--websearch-url",
        default=os.getenv("BENCHMARK_WEBSEARCH_URL", "http://localhost:8002"),
        help="WebSearch service URL",
    )
    parser.add_argument(
        "--websearch-provider", default="brightdata",
        help="Search provider (default: brightdata)",
    )
    parser.add_argument(
        "--max-search-results", type=int, default=5,
        help="Max results per search query",
    )

    # Execution
    parser.add_argument("--concurrent", type=int, default=10, help="Max concurrent prompts")
    parser.add_argument("--timeout", type=float, default=300.0, help="Timeout per prompt (seconds)")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of prompts")

    # Output
    parser.add_argument("-o", "--output", default=None, help="Output results JSON path")
    parser.add_argument("--checkpoint", default=None, help="Checkpoint file path")
    parser.add_argument("--no-resume", action="store_true", help="Don't resume from checkpoint")
    parser.add_argument(
        "--update-leaderboard", action="store_true",
        help="Update data/leaderboard.json with results",
    )

    # Evaluation
    parser.add_argument(
        "--evaluate", action="store_true",
        help="Run coverage evaluation after benchmark (requires ground truth data and eval LLM)",
    )
    parser.add_argument(
        "--no-evaluate", action="store_true",
        help="Skip evaluation even if ground truth is available",
    )
    parser.add_argument(
        "--ground-truth-dir",
        default=str(PROJECT_ROOT / "data" / "ground_truth"),
        help="Path to ground truth pitch points directory",
    )
    parser.add_argument(
        "--eval-workers", type=int, default=10,
        help="Max parallel workers for evaluation (default: 10)",
    )

    args = parser.parse_args()

    # --- Load LLM config ---
    llm_config = LLMConfig.from_env()
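    # Command-line flags take precedence over environment / .env values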
    if args.model:
        llm_config.model = args.model
    if args.provider:
        llm_config.provider = args.provider
    if args.api_key:
        llm_config.api_key = args.api_key

    if not llm_config.api_key:
        print("ERROR: No LLM API key provided. Set LLM_API_KEY or use --api-key.")
        sys.exit(1)

    # --- Load agent ---
    logger.info(f"Loading agent from {args.agent}")
    agent = load_agent_from_file(args.agent, model_name=llm_config.model)
    logger.info(f"Agent loaded: {agent.name} by {agent.author}")

    # --- Load prompts ---
    prompt_files = []
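    # Expand directories into their *.json prompt files; individual files are used as-is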
    for p in args.prompts:
        path = Path(p)
        if path.is_dir():
            prompt_files.extend(sorted(path.glob("*.json")))
        elif path.exists():
            prompt_files.append(path)
        else:
            logger.warning(f"Prompt path not found: {p}")

    if not prompt_files:
        print("ERROR: No prompt files found. Provide --prompts with valid paths.")
        sys.exit(1)

    prompts = load_prompts([str(f) for f in prompt_files])
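    # Optionally cap the number of prompts (--limit)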
    if args.limit:
        prompts = prompts[:args.limit]

    logger.info(f"Loaded {len(prompts)} prompts from {len(prompt_files)} file(s)")

    # --- Setup checkpoint ---
    checkpoint = args.checkpoint
    if checkpoint is None and not args.no_resume:
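        # Sanitize the agent name so it is safe to use in a checkpoint filename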
        safe_name = "".join(c if c.isalnum() or c == "-" else "_" for c in agent.name)
        checkpoint = str(PROJECT_ROOT / "data" / "results" / f"{safe_name}_checkpoint.json")

    # --- Create runner ---
    runner = BenchmarkRunner(
        agent=agent,
        llm_config=llm_config,
        websearch_url=args.websearch_url,
        websearch_provider=args.websearch_provider,
        max_search_results=args.max_search_results,
        max_concurrent=args.concurrent,
        timeout_per_prompt=args.timeout,
        output_dir=str(PROJECT_ROOT / "data" / "results"),
        checkpoint_file=checkpoint if not args.no_resume else None,
    )

    # --- Print config ---
    print(f"\n  DR-Bench Runner")
    print(f"  {'='*50}")
    print(f"  Agent:     {agent.name}")
    print(f"  Author:    {agent.author}")
    print(f"  Model:     {llm_config.model}")
    print(f"  Provider:  {llm_config.provider}")
    print(f"  Prompts:   {len(prompts)}")
    print(f"  Concurrent: {args.concurrent}")
    print(f"  WebSearch: {args.websearch_url}")
    print(f"  {'='*50}\n")

    # --- Run ---
    try:
        asyncio.run(runner.run(prompts, resume=not args.no_resume))
    except KeyboardInterrupt:
        logger.warning("Interrupted. Saving progress...")
    finally:
        # Run evaluation if requested
        if args.evaluate and not args.no_evaluate:
            gt_dir = Path(args.ground_truth_dir)
            if gt_dir.exists() and any(gt_dir.glob("*.json")):
                print(f"\n  Running coverage evaluation...")
                print(f"  Ground truth: {gt_dir}")
                runner.evaluate_results(
                    prompts=prompts,
                    ground_truth_dir=str(gt_dir),
                    max_workers=args.eval_workers,
                )
            else:
                logger.warning(
                    f"Ground truth directory not found or empty: {gt_dir}. "
                    "Skipping evaluation."
                )

        # Save results
        output_path = runner.save_results(args.output)
        runner.print_summary()
        print(f"\n  Results saved to: {output_path}")

        # Update leaderboard if requested
        if args.update_leaderboard:
            leaderboard_path = str(PROJECT_ROOT / "data" / "leaderboard.json")
            runner.update_leaderboard(leaderboard_path)
            print(f"\n  Leaderboard updated: {leaderboard_path}")

        # Cleanup
        asyncio.run(runner.cleanup())


if __name__ == "__main__":
    main()