| """Command-line interface for the benchmarking pipeline.""" | |
import argparse
import sys

from .llm_providers.base import LLMProvider
from .benchmarks import Benchmark, ReXVQABenchmark, ChestAgentBenchBenchmark
from .runner import BenchmarkRunner, BenchmarkRunConfig


def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
    """Create a benchmark based on the benchmark name.

    Args:
        benchmark_name (str): Name of the benchmark
        data_dir (str): Directory containing benchmark data
        **kwargs: Additional configuration parameters

    Returns:
        Benchmark: The configured benchmark
    """
    benchmark_map = {
        "rexvqa": ReXVQABenchmark,
        "chestagentbench": ChestAgentBenchBenchmark,
    }
    if benchmark_name not in benchmark_map:
        raise ValueError(
            f"Unknown benchmark: {benchmark_name}. Available: {list(benchmark_map.keys())}"
        )
    benchmark_class = benchmark_map[benchmark_name]
    return benchmark_class(data_dir, **kwargs)
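

# Sketch of programmatic use of create_benchmark (argument values are illustrative, not defaults):
#   benchmark = create_benchmark("rexvqa", data_dir="./data", max_questions=100)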


def create_llm_provider(provider_type: str, model_name: str, system_prompt: str, **kwargs) -> LLMProvider:
    """Create an LLM provider based on the model name and type.

    Args:
        provider_type (str): Type of provider (openai, google, openrouter, medrax, medgemma)
        model_name (str): Name of the model
        system_prompt (str): System prompt identifier to load from file
        **kwargs: Additional configuration parameters

    Returns:
        LLMProvider: The configured LLM provider
    """
    # Lazy imports to avoid slow startup
    if provider_type == "openai":
        from .llm_providers.openai_provider import OpenAIProvider
        provider_class = OpenAIProvider
    elif provider_type == "google":
        from .llm_providers.google_provider import GoogleProvider
        provider_class = GoogleProvider
    elif provider_type == "openrouter":
        from .llm_providers.openrouter_provider import OpenRouterProvider
        provider_class = OpenRouterProvider
    elif provider_type == "medrax":
        from .llm_providers.medrax_provider import MedRAXProvider
        provider_class = MedRAXProvider
    elif provider_type == "medgemma":
        from .llm_providers.medgemma_provider import MedGemmaProvider
        provider_class = MedGemmaProvider
    else:
        raise ValueError(
            f"Unknown provider type: {provider_type}. "
            "Available: openai, google, openrouter, medrax, medgemma"
        )
    return provider_class(model_name, system_prompt, **kwargs)
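

# Sketch of programmatic use of create_llm_provider (argument values are illustrative, not defaults):
#   provider = create_llm_provider("openai", "gpt-4o", "MEDICAL_ASSISTANT", temperature=0.7)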


def run_benchmark_command(args) -> None:
    """Run a benchmark."""
    print(f"Running benchmark: {args.benchmark} with provider: {args.provider}, model: {args.model}")

    # Create benchmark
    benchmark_kwargs = {
        "max_questions": args.max_questions,
        "random_seed": args.random_seed,
    }
    benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)

    # Create LLM provider
    provider_kwargs = {
        "temperature": args.temperature,
        "top_p": args.top_p,
        "max_tokens": args.max_tokens,
    }
    llm_provider = create_llm_provider(
        provider_type=args.provider,
        model_name=args.model,
        system_prompt=args.system_prompt,
        **provider_kwargs,
    )

    # Create runner config
    config = BenchmarkRunConfig(
        benchmark_name=args.benchmark,
        provider_name=args.provider,
        model_name=args.model,
        output_dir=args.output_dir,
        max_questions=args.max_questions,
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
        concurrency=args.concurrency,
        random_seed=args.random_seed,
    )

    # Run benchmark
    runner = BenchmarkRunner(config)
    summary = runner.run_benchmark(benchmark, llm_provider)
    print(summary)


def main():
    """Main CLI entry point."""
    parser = argparse.ArgumentParser(description="MedRAX Benchmarking Pipeline")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Run benchmark command
    run_parser = subparsers.add_parser("run", help="Run a benchmark evaluation")
    run_parser.add_argument("--benchmark", required=True,
                            choices=["rexvqa", "chestagentbench"],
                            help="Benchmark dataset: rexvqa (radiology VQA) or chestagentbench (chest X-ray reasoning)")
    run_parser.add_argument("--provider", required=True,
                            choices=["openai", "google", "openrouter", "medrax", "medgemma"],
                            help="LLM provider to use")
    run_parser.add_argument("--model", required=True,
                            help="Model name (e.g., gpt-4o, gpt-4.1-2025-04-14, gemini-2.5-pro)")
    run_parser.add_argument("--system-prompt", required=True,
                            choices=["MEDICAL_ASSISTANT", "CHESTAGENTBENCH_PROMPT", "MEDGEMMA_PROMPT"],
                            help="System prompt: MEDICAL_ASSISTANT (general), CHESTAGENTBENCH_PROMPT (benchmarks), or MEDGEMMA_PROMPT (MedGemma)")
    run_parser.add_argument("--data-dir", required=True,
                            help="Directory containing benchmark data files")
    run_parser.add_argument("--output-dir", default="benchmark_results",
                            help="Output directory for results (default: benchmark_results)")
    run_parser.add_argument("--max-questions", type=int,
                            help="Maximum number of questions to process (default: all)")
    run_parser.add_argument("--temperature", type=float, default=1.0,
                            help="Model temperature for response generation (default: 1.0)")
    run_parser.add_argument("--top-p", type=float, default=0.95,
                            help="Top-p nucleus sampling parameter (default: 0.95)")
    run_parser.add_argument("--max-tokens", type=int, default=5000,
                            help="Maximum tokens per model response (default: 5000)")
    run_parser.add_argument("--concurrency", type=int, default=1,
                            help="Number of datapoints to process in parallel (default: 1)")
    run_parser.add_argument("--random-seed", type=int, default=42,
                            help="Random seed for shuffling benchmark data (enables reproducible runs, default: 42)")
    run_parser.set_defaults(func=run_benchmark_command)

    args = parser.parse_args()
    if args.command is None:
        parser.print_help()
        return

    try:
        args.func(args)
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()