"""Command-line interface for the benchmarking pipeline."""
import argparse
import sys
from .llm_providers.base import LLMProvider
from .benchmarks import Benchmark, ChestAgentBenchBenchmark, ReXVQABenchmark
from .runner import BenchmarkRunner, BenchmarkRunConfig


def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
"""Create a benchmark based on the benchmark name.
Args:
benchmark_name (str): Name of the benchmark
data_dir (str): Directory containing benchmark data
**kwargs: Additional configuration parameters
Returns:
Benchmark: The configured benchmark
"""
benchmark_map = {
"rexvqa": ReXVQABenchmark,
"chestagentbench": ChestAgentBenchBenchmark,
}
if benchmark_name not in benchmark_map:
raise ValueError(f"Unknown benchmark: {benchmark_name}. Available: {list(benchmark_map.keys())}")
benchmark_class = benchmark_map[benchmark_name]
return benchmark_class(data_dir, **kwargs)
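
# Example (the data path is illustrative; the kwargs mirror what
# run_benchmark_command passes through):
#   benchmark = create_benchmark("rexvqa", "data/rexvqa",
#                                max_questions=100, random_seed=42)
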
def create_llm_provider(provider_type: str, model_name: str, system_prompt: str, **kwargs) -> LLMProvider:
"""Create an LLM provider based on the model name and type.
Args:
provider_type (str): Type of provider (openai, google, openrouter, medrax, medgemma)
model_name (str): Name of the model
system_prompt (str): System prompt identifier to load from file
**kwargs: Additional configuration parameters
Returns:
LLMProvider: The configured LLM provider
"""
# Lazy imports to avoid slow startup
if provider_type == "openai":
from .llm_providers.openai_provider import OpenAIProvider
provider_class = OpenAIProvider
elif provider_type == "google":
from .llm_providers.google_provider import GoogleProvider
provider_class = GoogleProvider
elif provider_type == "openrouter":
from .llm_providers.openrouter_provider import OpenRouterProvider
provider_class = OpenRouterProvider
elif provider_type == "medrax":
from .llm_providers.medrax_provider import MedRAXProvider
provider_class = MedRAXProvider
elif provider_type == "medgemma":
from .llm_providers.medgemma_provider import MedGemmaProvider
provider_class = MedGemmaProvider
else:
raise ValueError(f"Unknown provider type: {provider_type}. Available: openai, google, openrouter, medrax, medgemma")
return provider_class(model_name, system_prompt, **kwargs)
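
# Example (model and sampling values are illustrative; the prompt name matches
# the CLI choices below):
#   provider = create_llm_provider("openai", "gpt-4o", "MEDICAL_ASSISTANT",
#                                  temperature=0.7, top_p=0.95, max_tokens=5000)
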
def run_benchmark_command(args) -> None:
"""Run a benchmark."""
print(f"Running benchmark: {args.benchmark} with provider: {args.provider}, model: {args.model}")
    # Create benchmark
    benchmark_kwargs = {
        "max_questions": args.max_questions,
        "random_seed": args.random_seed,
    }
    benchmark = create_benchmark(
        benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs
    )
    # Create LLM provider
    provider_kwargs = {
        "temperature": args.temperature,
        "top_p": args.top_p,
        "max_tokens": args.max_tokens,
    }
    llm_provider = create_llm_provider(
        provider_type=args.provider,
        model_name=args.model,
        system_prompt=args.system_prompt,
        **provider_kwargs,
    )
# Create runner config
config = BenchmarkRunConfig(
benchmark_name=args.benchmark,
provider_name=args.provider,
model_name=args.model,
output_dir=args.output_dir,
max_questions=args.max_questions,
temperature=args.temperature,
top_p=args.top_p,
max_tokens=args.max_tokens,
concurrency=args.concurrency,
random_seed=args.random_seed
)
# Run benchmark
runner = BenchmarkRunner(config)
summary = runner.run_benchmark(benchmark, llm_provider)
    print(summary)


def main():
"""Main CLI entry point."""
parser = argparse.ArgumentParser(description="MedRAX Benchmarking Pipeline")
subparsers = parser.add_subparsers(dest="command", help="Available commands")
# Run benchmark command
run_parser = subparsers.add_parser("run", help="Run a benchmark evaluation")
run_parser.add_argument("--benchmark", required=True,
choices=["rexvqa", "chestagentbench"],
help="Benchmark dataset: rexvqa (radiology VQA) or chestagentbench (chest X-ray reasoning)")
run_parser.add_argument("--provider", required=True,
choices=["openai", "google", "openrouter", "medrax", "medgemma"],
help="LLM provider to use")
run_parser.add_argument("--model", required=True,
help="Model name (e.g., gpt-4o, gpt-4.1-2025-04-14, gemini-2.5-pro)")
run_parser.add_argument("--system-prompt", required=True,
choices=["MEDICAL_ASSISTANT", "CHESTAGENTBENCH_PROMPT", "MEDGEMMA_PROMPT"],
help="System prompt: MEDICAL_ASSISTANT (general) or CHESTAGENTBENCH_PROMPT (benchmarks)")
run_parser.add_argument("--data-dir", required=True,
help="Directory containing benchmark data files")
run_parser.add_argument("--output-dir", default="benchmark_results",
help="Output directory for results (default: benchmark_results)")
run_parser.add_argument("--max-questions", type=int,
help="Maximum number of questions to process (default: all)")
run_parser.add_argument("--temperature", type=float, default=1,
help="Model temperature for response generation (default: 0.7)")
run_parser.add_argument("--top-p", type=float, default=0.95,
help="Top-p nucleus sampling parameter (default: 0.95)")
run_parser.add_argument("--max-tokens", type=int, default=5000,
help="Maximum tokens per model response (default: 5000)")
run_parser.add_argument("--concurrency", type=int, default=1,
help="Number of datapoints to process in parallel (default: 1)")
run_parser.add_argument("--random-seed", type=int, default=42,
help="Random seed for shuffling benchmark data (enables reproducible runs, default: 42)")
run_parser.set_defaults(func=run_benchmark_command)
args = parser.parse_args()
if args.command is None:
parser.print_help()
return
try:
args.func(args)
except Exception as e:
print(f"Error: {e}")
sys.exit(1)


if __name__ == "__main__":
    main()