"""Command-line interface for the benchmarking pipeline."""
import argparse
import sys
from .llm_providers.base import LLMProvider
from .benchmarks import Benchmark, ChestAgentBenchBenchmark, ReXVQABenchmark
from .runner import BenchmarkRunner, BenchmarkRunConfig
def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
"""Create a benchmark based on the benchmark name.
Args:
benchmark_name (str): Name of the benchmark
data_dir (str): Directory containing benchmark data
**kwargs: Additional configuration parameters
Returns:
Benchmark: The configured benchmark
"""
benchmark_map = {
"rexvqa": ReXVQABenchmark,
"chestagentbench": ChestAgentBenchBenchmark,
}
if benchmark_name not in benchmark_map:
raise ValueError(f"Unknown benchmark: {benchmark_name}. Available: {list(benchmark_map.keys())}")
benchmark_class = benchmark_map[benchmark_name]
return benchmark_class(data_dir, **kwargs)
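
# A minimal usage sketch (the data directory is a hypothetical path; the
# keyword arguments mirror the ones the CLI forwards below):
#   benchmark = create_benchmark("rexvqa", data_dir="data/rexvqa",
#                                max_questions=100, random_seed=42)
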
def create_llm_provider(provider_type: str, model_name: str, system_prompt: str, **kwargs) -> LLMProvider:
"""Create an LLM provider based on the model name and type.
Args:
provider_type (str): Type of provider (openai, google, openrouter, medrax, medgemma)
model_name (str): Name of the model
system_prompt (str): System prompt identifier to load from file
**kwargs: Additional configuration parameters
Returns:
LLMProvider: The configured LLM provider
"""
    # Import the selected provider lazily so the CLI starts quickly and each
    # provider's dependencies are only loaded when that provider is used.
if provider_type == "openai":
from .llm_providers.openai_provider import OpenAIProvider
provider_class = OpenAIProvider
elif provider_type == "google":
from .llm_providers.google_provider import GoogleProvider
provider_class = GoogleProvider
elif provider_type == "openrouter":
from .llm_providers.openrouter_provider import OpenRouterProvider
provider_class = OpenRouterProvider
elif provider_type == "medrax":
from .llm_providers.medrax_provider import MedRAXProvider
provider_class = MedRAXProvider
elif provider_type == "medgemma":
from .llm_providers.medgemma_provider import MedGemmaProvider
provider_class = MedGemmaProvider
else:
raise ValueError(f"Unknown provider type: {provider_type}. Available: openai, google, openrouter, medrax, medgemma")
return provider_class(model_name, system_prompt, **kwargs)
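
# A minimal usage sketch (model name and sampling values are illustrative;
# credential handling is assumed to live inside the provider classes):
#   provider = create_llm_provider("openai", model_name="gpt-4o",
#                                  system_prompt="MEDICAL_ASSISTANT",
#                                  temperature=1.0, top_p=0.95, max_tokens=5000)
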
def run_benchmark_command(args) -> None:
"""Run a benchmark."""
print(f"Running benchmark: {args.benchmark} with provider: {args.provider}, model: {args.model}")
# Create benchmark
    benchmark_kwargs = {
        "max_questions": args.max_questions,
        "random_seed": args.random_seed,
    }
benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)
# Create LLM provider
    provider_kwargs = {
        "temperature": args.temperature,
        "top_p": args.top_p,
        "max_tokens": args.max_tokens,
    }
llm_provider = create_llm_provider(provider_type=args.provider, model_name=args.model, system_prompt=args.system_prompt, **provider_kwargs)
# Create runner config
config = BenchmarkRunConfig(
benchmark_name=args.benchmark,
provider_name=args.provider,
model_name=args.model,
output_dir=args.output_dir,
max_questions=args.max_questions,
temperature=args.temperature,
top_p=args.top_p,
max_tokens=args.max_tokens,
concurrency=args.concurrency,
random_seed=args.random_seed
)
# Run benchmark
runner = BenchmarkRunner(config)
summary = runner.run_benchmark(benchmark, llm_provider)
print(summary)
def main():
"""Main CLI entry point."""
parser = argparse.ArgumentParser(description="MedRAX Benchmarking Pipeline")
subparsers = parser.add_subparsers(dest="command", help="Available commands")
# Run benchmark command
run_parser = subparsers.add_parser("run", help="Run a benchmark evaluation")
run_parser.add_argument("--benchmark", required=True,
choices=["rexvqa", "chestagentbench"],
help="Benchmark dataset: rexvqa (radiology VQA) or chestagentbench (chest X-ray reasoning)")
run_parser.add_argument("--provider", required=True,
choices=["openai", "google", "openrouter", "medrax", "medgemma"],
help="LLM provider to use")
run_parser.add_argument("--model", required=True,
help="Model name (e.g., gpt-4o, gpt-4.1-2025-04-14, gemini-2.5-pro)")
run_parser.add_argument("--system-prompt", required=True,
choices=["MEDICAL_ASSISTANT", "CHESTAGENTBENCH_PROMPT", "MEDGEMMA_PROMPT"],
help="System prompt: MEDICAL_ASSISTANT (general) or CHESTAGENTBENCH_PROMPT (benchmarks)")
run_parser.add_argument("--data-dir", required=True,
help="Directory containing benchmark data files")
run_parser.add_argument("--output-dir", default="benchmark_results",
help="Output directory for results (default: benchmark_results)")
run_parser.add_argument("--max-questions", type=int,
help="Maximum number of questions to process (default: all)")
run_parser.add_argument("--temperature", type=float, default=1,
help="Model temperature for response generation (default: 0.7)")
run_parser.add_argument("--top-p", type=float, default=0.95,
help="Top-p nucleus sampling parameter (default: 0.95)")
run_parser.add_argument("--max-tokens", type=int, default=5000,
help="Maximum tokens per model response (default: 5000)")
run_parser.add_argument("--concurrency", type=int, default=1,
help="Number of datapoints to process in parallel (default: 1)")
run_parser.add_argument("--random-seed", type=int, default=42,
help="Random seed for shuffling benchmark data (enables reproducible runs, default: 42)")
run_parser.set_defaults(func=run_benchmark_command)
args = parser.parse_args()
if args.command is None:
parser.print_help()
return
try:
args.func(args)
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()