"""Command-line interface for the benchmarking pipeline."""

import argparse
import sys

from .llm_providers.base import LLMProvider
from .benchmarks import Benchmark, ReXVQABenchmark, ChestAgentBenchBenchmark
from .runner import BenchmarkRunner, BenchmarkRunConfig
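
# Example invocation of this CLI (an illustrative sketch: the module path,
# data directory, and model name below are placeholders, not values defined
# by this file):
#
#   python -m <package>.cli run \
#       --benchmark rexvqa \
#       --provider openai \
#       --model gpt-4o \
#       --system-prompt MEDICAL_ASSISTANT \
#       --data-dir ./data/rexvqa \
#       --max-questions 100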


def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
    """Create a benchmark based on the benchmark name.
    
    Args:
        benchmark_name (str): Name of the benchmark
        data_dir (str): Directory containing benchmark data
        **kwargs: Additional configuration parameters
        
    Returns:
        Benchmark: The configured benchmark
    """
    benchmark_map = {
        "rexvqa": ReXVQABenchmark,
        "chestagentbench": ChestAgentBenchBenchmark,
    }
    
    if benchmark_name not in benchmark_map:
        raise ValueError(f"Unknown benchmark: {benchmark_name}. Available: {list(benchmark_map.keys())}")
    
    benchmark_class = benchmark_map[benchmark_name]
    return benchmark_class(data_dir, **kwargs)
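
# Illustrative usage of create_benchmark (a sketch; the data path and keyword
# values are assumptions, not defaults defined here):
#
#   benchmark = create_benchmark("rexvqa", data_dir="./data/rexvqa",
#                                max_questions=50, random_seed=42)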


def create_llm_provider(provider_type: str, model_name: str, system_prompt: str, **kwargs) -> LLMProvider:
    """Create an LLM provider based on the model name and type.
    
    Args:
        provider_type (str): Type of provider (openai, google, openrouter, medrax, medgemma)
        model_name (str): Name of the model
        system_prompt (str): System prompt identifier to load from file
        **kwargs: Additional configuration parameters
        
    Returns:
        LLMProvider: The configured LLM provider
    """
    # Lazy imports to avoid slow startup
    if provider_type == "openai":
        from .llm_providers.openai_provider import OpenAIProvider
        provider_class = OpenAIProvider
    elif provider_type == "google":
        from .llm_providers.google_provider import GoogleProvider
        provider_class = GoogleProvider
    elif provider_type == "openrouter":
        from .llm_providers.openrouter_provider import OpenRouterProvider
        provider_class = OpenRouterProvider
    elif provider_type == "medrax":
        from .llm_providers.medrax_provider import MedRAXProvider
        provider_class = MedRAXProvider
    elif provider_type == "medgemma":
        from .llm_providers.medgemma_provider import MedGemmaProvider
        provider_class = MedGemmaProvider
    else:
        raise ValueError(f"Unknown provider type: {provider_type}. Available: openai, google, openrouter, medrax, medgemma")
    
    return provider_class(model_name, system_prompt, **kwargs)
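
# Illustrative usage of create_llm_provider (a sketch; the model name and
# sampling values are placeholders, not defaults defined here):
#
#   provider = create_llm_provider("openai", model_name="gpt-4o",
#                                  system_prompt="MEDICAL_ASSISTANT",
#                                  temperature=0.7, top_p=0.95, max_tokens=5000)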


def run_benchmark_command(args) -> None:
    """Run a benchmark."""
    print(f"Running benchmark: {args.benchmark} with provider: {args.provider}, model: {args.model}")
    
    # Create benchmark
    benchmark_kwargs = {
        "max_questions": args.max_questions,
        "random_seed": args.random_seed,
    }
    benchmark = create_benchmark(
        benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs
    )

    # Create LLM provider
    provider_kwargs = {
        "temperature": args.temperature,
        "top_p": args.top_p,
        "max_tokens": args.max_tokens,
    }
    llm_provider = create_llm_provider(
        provider_type=args.provider, model_name=args.model,
        system_prompt=args.system_prompt, **provider_kwargs
    )
    
    # Create runner config
    config = BenchmarkRunConfig(
        benchmark_name=args.benchmark,
        provider_name=args.provider,
        model_name=args.model,
        output_dir=args.output_dir,
        max_questions=args.max_questions,
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
        concurrency=args.concurrency,
        random_seed=args.random_seed
    )
    
    # Run benchmark
    runner = BenchmarkRunner(config)
    summary = runner.run_benchmark(benchmark, llm_provider)
    print(summary)


def main():
    """Main CLI entry point."""
    parser = argparse.ArgumentParser(description="MedRAX Benchmarking Pipeline")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")
    
    # Run benchmark command
    run_parser = subparsers.add_parser("run", help="Run a benchmark evaluation")
    run_parser.add_argument("--benchmark", required=True, 
                           choices=["rexvqa", "chestagentbench"], 
                           help="Benchmark dataset: rexvqa (radiology VQA) or chestagentbench (chest X-ray reasoning)")
    run_parser.add_argument("--provider", required=True, 
                           choices=["openai", "google", "openrouter", "medrax", "medgemma"], 
                           help="LLM provider to use")
    run_parser.add_argument("--model", required=True, 
                           help="Model name (e.g., gpt-4o, gpt-4.1-2025-04-14, gemini-2.5-pro)")
    run_parser.add_argument("--system-prompt", required=True, 
                           choices=["MEDICAL_ASSISTANT", "CHESTAGENTBENCH_PROMPT", "MEDGEMMA_PROMPT"], 
                           help="System prompt: MEDICAL_ASSISTANT (general) or CHESTAGENTBENCH_PROMPT (benchmarks)")
    run_parser.add_argument("--data-dir", required=True, 
                           help="Directory containing benchmark data files")
    run_parser.add_argument("--output-dir", default="benchmark_results", 
                           help="Output directory for results (default: benchmark_results)")
    run_parser.add_argument("--max-questions", type=int, 
                           help="Maximum number of questions to process (default: all)")
    run_parser.add_argument("--temperature", type=float, default=1, 
                           help="Model temperature for response generation (default: 0.7)")
    run_parser.add_argument("--top-p", type=float, default=0.95, 
                           help="Top-p nucleus sampling parameter (default: 0.95)")
    run_parser.add_argument("--max-tokens", type=int, default=5000, 
                           help="Maximum tokens per model response (default: 5000)")
    run_parser.add_argument("--concurrency", type=int, default=1,
                            help="Number of datapoints to process in parallel (default: 1)")
    run_parser.add_argument("--random-seed", type=int, default=42, 
                           help="Random seed for shuffling benchmark data (enables reproducible runs, default: 42)")
    
    run_parser.set_defaults(func=run_benchmark_command)
    
    args = parser.parse_args()
    
    if args.command is None:
        parser.print_help()
        return
    
    try:
        args.func(args)
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()