Junzhe Li committed
Commit 89321e2 · 1 Parent(s): b93ad3f

revamped benchmarking suite

benchmarking/benchmarks/base.py CHANGED
@@ -39,7 +39,7 @@ class Benchmark(ABC):
39
  self._load_data()
40
  self._shuffle_data()
41
 
42
- self.max_questions = kwargs.get("max_questions", None)
43
  if self.max_questions:
44
  self.data_points = self.data_points[:self.max_questions]
45
  print(f"Randomly sampled {self.max_questions} questions from {self.__class__.__name__}")
@@ -51,12 +51,13 @@ class Benchmark(ABC):
51
  """Load benchmark data from the data directory."""
52
  pass
53
 
54
- def _shuffle_data(self, random_seed: Optional[int]=42) -> None:
55
  """Shuffle the data points if a random seed is provided. If no random seed is provided, use 42 as default.
56
 
57
  This method is called automatically after data loading to ensure
58
  reproducible benchmark runs when a random seed is specified.
59
  """
 
60
  random.seed(random_seed)
61
  random.shuffle(self.data_points)
62
  print(f"Shuffled {len(self.data_points)} data points with seed {random_seed}")
@@ -99,21 +100,3 @@ class Benchmark(ABC):
99
  for i in range(len(self)):
100
  yield self.get_data_point(i)
101
 
102
- def validate_images(self) -> Tuple[List[str], List[str]]:
103
- """Validate that all image paths exist.
104
-
105
- Returns:
106
- Tuple[List[str], List[str]]: Tuple of (valid_image_paths, invalid_image_paths)
107
- """
108
- valid_images = []
109
- invalid_images = []
110
-
111
- for dp in self:
112
- if dp.images:
113
- for image_path in dp.images:
114
- if Path(image_path).exists():
115
- valid_images.append(image_path)
116
- else:
117
- invalid_images.append(image_path)
118
-
119
- return valid_images, invalid_images
 
39
  self._load_data()
40
  self._shuffle_data()
41
 
42
+ self.max_questions = self.config.get("max_questions", None)
43
  if self.max_questions:
44
  self.data_points = self.data_points[:self.max_questions]
45
  print(f"Randomly sampled {self.max_questions} questions from {self.__class__.__name__}")
 
51
  """Load benchmark data from the data directory."""
52
  pass
53
 
54
+ def _shuffle_data(self) -> None:
55
  """Shuffle the data points if a random seed is provided. If no random seed is provided, use 42 as default.
56
 
57
  This method is called automatically after data loading to ensure
58
  reproducible benchmark runs when a random seed is specified.
59
  """
60
+ random_seed = self.config.get("random_seed", 42)
61
  random.seed(random_seed)
62
  random.shuffle(self.data_points)
63
  print(f"Shuffled {len(self.data_points)} data points with seed {random_seed}")
 
100
  for i in range(len(self)):
101
  yield self.get_data_point(i)
102
 
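Benchmark sizing and shuffling now flow through a single kwargs-backed config: max_questions and random_seed are both read from self.config instead of being looked up ad hoc. A minimal sketch of how a subclass is expected to consume this, assuming the base constructor stores **kwargs as self.config and then calls _load_data() followed by _shuffle_data() (the subclass and its data are purely illustrative):

    class ToyBenchmark(Benchmark):
        def _load_data(self) -> None:
            # stand-in for a real loader
            self.data_points = [{"id": i} for i in range(100)]

    # shuffles with seed 7, then keeps the first 10 shuffled points
    bench = ToyBenchmark("data/toy", max_questions=10, random_seed=7)
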
benchmarking/benchmarks/rexvqa_benchmark.py CHANGED
@@ -48,11 +48,11 @@ class ReXVQABenchmark(Benchmark):
48
  max_questions (int): Maximum number of questions to load (default: None, load all)
49
  images_dir (str): Directory containing extracted PNG images (default: None)
50
  """
51
- super().__init__(data_dir, **kwargs)
52
-
53
  self.split = kwargs.get("split", "test")
54
  self.images_dir = f"{data_dir}/images/deid_png"
55
 
 
 
56
  def _load_data(self) -> None:
57
  """Load ReXVQA data from HuggingFace."""
58
  try:
 
48
  max_questions (int): Maximum number of questions to load (default: None, load all)
49
  images_dir (str): Directory containing extracted PNG images (default: None)
50
  """
 
 
51
  self.split = kwargs.get("split", "test")
52
  self.images_dir = f"{data_dir}/images/deid_png"
53
 
54
+ super().__init__(data_dir, **kwargs)
55
+
56
  def _load_data(self) -> None:
57
  """Load ReXVQA data from HuggingFace."""
58
  try:
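The only substantive change in this file is the ordering: super().__init__() now runs after split and images_dir are set, because the base constructor immediately calls _load_data(), which needs those attributes. A minimal sketch of the assumed initialization flow (structure taken from the hunk above):

    class ReXVQABenchmark(Benchmark):
        def __init__(self, data_dir: str, **kwargs):
            # attributes consumed by _load_data() must exist first
            self.split = kwargs.get("split", "test")
            self.images_dir = f"{data_dir}/images/deid_png"
            # base __init__ triggers _load_data() -> _shuffle_data() -> truncation
            super().__init__(data_dir, **kwargs)
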
benchmarking/cli.py CHANGED
@@ -8,12 +8,35 @@ from .benchmarks import *
8
  from .runner import BenchmarkRunner, BenchmarkRunConfig
9
 
10
 
11
- def create_llm_provider(model_name: str, provider_type: str, system_prompt: str, **kwargs) -> LLMProvider:
12
  """Create an LLM provider based on the model name and type.
13
 
14
  Args:
 
15
  model_name (str): Name of the model
16
- provider_type (str): Type of provider (openai, google, openrouter, medrax)
17
  system_prompt (str): System prompt identifier to load from file
18
  **kwargs: Additional configuration parameters
19
 
@@ -33,85 +56,50 @@ def create_llm_provider(model_name: str, provider_type: str, system_prompt: str,
33
  elif provider_type == "medrax":
34
  from .llm_providers.medrax_provider import MedRAXProvider
35
  provider_class = MedRAXProvider
 
 
 
36
  else:
37
- raise ValueError(f"Unknown provider type: {provider_type}. Available: openai, google, openrouter, medrax")
38
 
39
  return provider_class(model_name, system_prompt, **kwargs)
40
 
41
 
42
- def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
43
- """Create a benchmark based on the benchmark name.
44
-
45
- Args:
46
- benchmark_name (str): Name of the benchmark
47
- data_dir (str): Directory containing benchmark data
48
- **kwargs: Additional configuration parameters
49
-
50
- Returns:
51
- Benchmark: The configured benchmark
52
- """
53
- benchmark_map = {
54
- "rexvqa": ReXVQABenchmark,
55
- "chestagentbench": ChestAgentBenchBenchmark,
56
- }
57
-
58
- if benchmark_name not in benchmark_map:
59
- raise ValueError(f"Unknown benchmark: {benchmark_name}. Available: {list(benchmark_map.keys())}")
60
-
61
- benchmark_class = benchmark_map[benchmark_name]
62
- return benchmark_class(data_dir, **kwargs)
63
-
64
-
65
  def run_benchmark_command(args) -> None:
66
  """Run a benchmark."""
67
- print(f"Running benchmark: {args.benchmark} with model: {args.model}")
68
-
69
- # Create LLM provider
70
- provider_kwargs = {}
71
-
72
- llm_provider = create_llm_provider(model_name=args.model, provider_type=args.provider, system_prompt=args.system_prompt, **provider_kwargs)
73
 
74
  # Create benchmark
75
  benchmark_kwargs = {}
76
- if args.random_seed is not None:
77
- benchmark_kwargs["random_seed"] = args.random_seed
78
-
79
  benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)
80
 
81
  # Create runner config
82
  config = BenchmarkRunConfig(
 
83
  provider_name=args.provider,
84
  model_name=args.model,
85
- benchmark_name=args.benchmark,
86
  output_dir=args.output_dir,
87
  max_questions=args.max_questions,
88
  temperature=args.temperature,
89
  top_p=args.top_p,
90
  max_tokens=args.max_tokens,
91
- concurrency=args.concurrency
 
92
  )
93
 
94
  # Run benchmark
95
  runner = BenchmarkRunner(config)
96
- summary = runner.run_benchmark(llm_provider, benchmark)
97
-
98
- print("\n" + "="*50)
99
- print("BENCHMARK COMPLETED")
100
- print("="*50)
101
-
102
- # Check if benchmark run was successful
103
- if "error" in summary:
104
- print(f"Error: {summary['error']}")
105
- return
106
-
107
- # Print results
108
- print(f"Model: {args.model}")
109
- print(f"Benchmark: {args.benchmark}")
110
- print(f"Total Questions: {summary['results']['total_questions']}")
111
- print(f"Correct Answers: {summary['results']['correct_answers']}")
112
- print(f"Overall Accuracy: {summary['results']['accuracy']:.2f}%")
113
- print(f"Total Duration: {summary['results']['total_duration']:.2f}s")
114
- print(f"Results saved to: {summary['results_file']}")
115
 
116
 
117
  def main():
@@ -121,17 +109,17 @@ def main():
121
 
122
  # Run benchmark command
123
  run_parser = subparsers.add_parser("run", help="Run a benchmark evaluation")
124
- run_parser.add_argument("--model", required=True,
125
- help="Model name (e.g., gpt-4o, gpt-4.1-2025-04-14, gemini-2.5-pro)")
 
126
  run_parser.add_argument("--provider", required=True,
127
- choices=["openai", "google", "openrouter", "medrax"],
128
  help="LLM provider to use")
 
 
129
  run_parser.add_argument("--system-prompt", required=True,
130
  choices=["MEDICAL_ASSISTANT", "CHESTAGENTBENCH_PROMPT"],
131
  help="System prompt: MEDICAL_ASSISTANT (general) or CHESTAGENTBENCH_PROMPT (benchmarks)")
132
- run_parser.add_argument("--benchmark", required=True,
133
- choices=["rexvqa", "chestagentbench"],
134
- help="Benchmark dataset: rexvqa (radiology VQA) or chestagentbench (chest X-ray reasoning)")
135
  run_parser.add_argument("--data-dir", required=True,
136
  help="Directory containing benchmark data files")
137
  run_parser.add_argument("--output-dir", default="benchmark_results",
@@ -144,10 +132,10 @@ def main():
144
  help="Top-p nucleus sampling parameter (default: 0.95)")
145
  run_parser.add_argument("--max-tokens", type=int, default=5000,
146
  help="Maximum tokens per model response (default: 5000)")
147
- run_parser.add_argument("--random-seed", type=int, default=42,
148
- help="Random seed for shuffling benchmark data (enables reproducible runs, default: None)")
149
  run_parser.add_argument("--concurrency", type=int, default=1,
150
  help="Number of datapoints to process in parallel (default: 1)")
 
 
151
 
152
  run_parser.set_defaults(func=run_benchmark_command)
153
 
 
8
  from .runner import BenchmarkRunner, BenchmarkRunConfig
9
 
10
 
11
+ def create_benchmark(benchmark_name: str, data_dir: str, **kwargs) -> Benchmark:
12
+ """Create a benchmark based on the benchmark name.
13
+
14
+ Args:
15
+ benchmark_name (str): Name of the benchmark
16
+ data_dir (str): Directory containing benchmark data
17
+ **kwargs: Additional configuration parameters
18
+
19
+ Returns:
20
+ Benchmark: The configured benchmark
21
+ """
22
+ benchmark_map = {
23
+ "rexvqa": ReXVQABenchmark,
24
+ "chestagentbench": ChestAgentBenchBenchmark,
25
+ }
26
+
27
+ if benchmark_name not in benchmark_map:
28
+ raise ValueError(f"Unknown benchmark: {benchmark_name}. Available: {list(benchmark_map.keys())}")
29
+
30
+ benchmark_class = benchmark_map[benchmark_name]
31
+ return benchmark_class(data_dir, **kwargs)
32
+
33
+
34
+ def create_llm_provider(provider_type: str, model_name: str, system_prompt: str, **kwargs) -> LLMProvider:
35
  """Create an LLM provider based on the model name and type.
36
 
37
  Args:
38
+ provider_type (str): Type of provider (openai, google, openrouter, medrax, medgemma)
39
  model_name (str): Name of the model
 
40
  system_prompt (str): System prompt identifier to load from file
41
  **kwargs: Additional configuration parameters
42
 
 
56
  elif provider_type == "medrax":
57
  from .llm_providers.medrax_provider import MedRAXProvider
58
  provider_class = MedRAXProvider
59
+ elif provider_type == "medgemma":
60
+ from .llm_providers.medgemma_provider import MedGemmaProvider
61
+ provider_class = MedGemmaProvider
62
  else:
63
+ raise ValueError(f"Unknown provider type: {provider_type}. Available: openai, google, openrouter, medrax, medgemma")
64
 
65
  return provider_class(model_name, system_prompt, **kwargs)
66
 
67
 
68
  def run_benchmark_command(args) -> None:
69
  """Run a benchmark."""
70
+ print(f"Running benchmark: {args.benchmark} with provider: {args.provider}, model: {args.model}")
71
 
72
  # Create benchmark
73
  benchmark_kwargs = {}
74
+ benchmark_kwargs["max_questions"] = args.max_questions
75
+ benchmark_kwargs["random_seed"] = args.random_seed
 
76
  benchmark = create_benchmark(benchmark_name=args.benchmark, data_dir=args.data_dir, **benchmark_kwargs)
77
+
78
+ # Create LLM provider
79
+ provider_kwargs = {}
80
+ provider_kwargs["temperature"] = args.temperature
81
+ provider_kwargs["top_p"] = args.top_p
82
+ provider_kwargs["max_tokens"] = args.max_tokens
83
+ llm_provider = create_llm_provider(provider_type=args.provider, model_name=args.model, system_prompt=args.system_prompt, **provider_kwargs)
84
 
85
  # Create runner config
86
  config = BenchmarkRunConfig(
87
+ benchmark_name=args.benchmark,
88
  provider_name=args.provider,
89
  model_name=args.model,
 
90
  output_dir=args.output_dir,
91
  max_questions=args.max_questions,
92
  temperature=args.temperature,
93
  top_p=args.top_p,
94
  max_tokens=args.max_tokens,
95
+ concurrency=args.concurrency,
96
+ random_seed=args.random_seed
97
  )
98
 
99
  # Run benchmark
100
  runner = BenchmarkRunner(config)
101
+ summary = runner.run_benchmark(benchmark, llm_provider)
102
+ print(summary)
103
 
104
 
105
  def main():
 
109
 
110
  # Run benchmark command
111
  run_parser = subparsers.add_parser("run", help="Run a benchmark evaluation")
112
+ run_parser.add_argument("--benchmark", required=True,
113
+ choices=["rexvqa", "chestagentbench"],
114
+ help="Benchmark dataset: rexvqa (radiology VQA) or chestagentbench (chest X-ray reasoning)")
115
  run_parser.add_argument("--provider", required=True,
116
+ choices=["openai", "google", "openrouter", "medrax", "medgemma"],
117
  help="LLM provider to use")
118
+ run_parser.add_argument("--model", required=True,
119
+ help="Model name (e.g., gpt-4o, gpt-4.1-2025-04-14, gemini-2.5-pro)")
120
  run_parser.add_argument("--system-prompt", required=True,
121
  choices=["MEDICAL_ASSISTANT", "CHESTAGENTBENCH_PROMPT"],
122
  help="System prompt: MEDICAL_ASSISTANT (general) or CHESTAGENTBENCH_PROMPT (benchmarks)")
 
 
 
123
  run_parser.add_argument("--data-dir", required=True,
124
  help="Directory containing benchmark data files")
125
  run_parser.add_argument("--output-dir", default="benchmark_results",
 
132
  help="Top-p nucleus sampling parameter (default: 0.95)")
133
  run_parser.add_argument("--max-tokens", type=int, default=5000,
134
  help="Maximum tokens per model response (default: 5000)")
 
 
135
  run_parser.add_argument("--concurrency", type=int, default=1,
136
  help="Number of datapoints to process in parallel (default: 1)")
137
+ run_parser.add_argument("--random-seed", type=int, default=42,
138
+ help="Random seed for shuffling benchmark data (enables reproducible runs, default: 42)")
139
 
140
  run_parser.set_defaults(func=run_benchmark_command)
141
 
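Taken together, the reworked run command wires benchmark, provider, and runner roughly as sketched below; this mirrors what run_benchmark_command now does, with illustrative argument values rather than defaults:

    benchmark = create_benchmark("rexvqa", data_dir="data/rexvqa",
                                 max_questions=100, random_seed=42)
    llm_provider = create_llm_provider(provider_type="openai", model_name="gpt-4o",
                                       system_prompt="MEDICAL_ASSISTANT",
                                       temperature=0.2, top_p=0.95, max_tokens=5000)
    config = BenchmarkRunConfig(benchmark_name="rexvqa", provider_name="openai",
                                model_name="gpt-4o", output_dir="benchmark_results",
                                max_questions=100, temperature=0.2, top_p=0.95,
                                max_tokens=5000, concurrency=4, random_seed=42)
    summary = BenchmarkRunner(config).run_benchmark(benchmark, llm_provider)
    print(summary)
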
benchmarking/llm_providers/__init__.py CHANGED
@@ -5,6 +5,17 @@ from .openai_provider import OpenAIProvider
5
  from .google_provider import GoogleProvider
6
  from .medrax_provider import MedRAXProvider
7
  from .openrouter_provider import OpenRouterProvider
8
 
9
  __all__ = [
10
  "LLMProvider",
@@ -14,4 +25,7 @@ __all__ = [
14
  "GoogleProvider",
15
  "MedRAXProvider",
16
  "OpenRouterProvider",
 
 
 
17
  ]
 
5
  from .google_provider import GoogleProvider
6
  from .medrax_provider import MedRAXProvider
7
  from .openrouter_provider import OpenRouterProvider
8
+ from .medgemma_provider import MedGemmaProvider
9
+
10
+ # QwenProvider is optional - only import if dependencies are compatible
11
+ try:
12
+ from .qwen_provider import QwenProvider
13
+ QWEN_AVAILABLE = True
14
+ except ImportError as e:
15
+ QWEN_AVAILABLE = False
16
+ QwenProvider = None
17
+ print(f"QwenProvider not available: {e}")
18
+ print("To use Qwen models, upgrade transformers: pip install --upgrade git+https://github.com/huggingface/transformers")
19
 
20
  __all__ = [
21
  "LLMProvider",
 
25
  "GoogleProvider",
26
  "MedRAXProvider",
27
  "OpenRouterProvider",
28
+ "MedGemmaProvider",
29
+ "QwenProvider",
30
+ "QWEN_AVAILABLE",
31
  ]
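Because the Qwen import is optional, callers are expected to gate on QWEN_AVAILABLE before instantiating it. A small sketch of that guard (the model name is illustrative and the constructor signature is assumed to match the other providers):

    from benchmarking.llm_providers import QWEN_AVAILABLE, QwenProvider

    if QWEN_AVAILABLE:
        provider = QwenProvider("Qwen2.5-VL-7B-Instruct", "MEDICAL_ASSISTANT")
    else:
        print("Qwen support not installed; falling back to another provider")
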
benchmarking/llm_providers/base.py CHANGED
@@ -13,10 +13,6 @@ class LLMRequest:
13
  """Request to an LLM provider."""
14
  text: str
15
  images: Optional[List[str]] = None # List of image paths
16
- temperature: float = 0.7
17
- top_p: float = 0.95
18
- max_tokens: int = 5000
19
- additional_params: Optional[Dict[str, Any]] = None
20
 
21
 
22
  @dataclass
@@ -44,15 +40,17 @@ class LLMProvider(ABC):
44
  **kwargs: Additional configuration parameters
45
  """
46
  self.model_name = model_name
47
- self.config = kwargs
48
- self.prompt_name = system_prompt # Store the original prompt identifier
 
 
49
 
50
  # Load system prompt content from file
51
  try:
52
- prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
53
- self.system_prompt = prompts.get(system_prompt, None)
54
  if self.system_prompt is None:
55
- print(f"Warning: System prompt '{system_prompt}' not found in medrax/docs/system_prompts.txt.")
56
  except Exception as e:
57
  print(f"Error loading system prompt: {e}")
58
  self.system_prompt = None
@@ -85,9 +83,7 @@ class LLMProvider(ABC):
85
  try:
86
  # Simple test request
87
  test_request = LLMRequest(
88
- text="Hello! What model are you? Tell me your full specification.",
89
- temperature=0.5,
90
- max_tokens=1000
91
  )
92
  response = self.generate_response(test_request)
93
  return response.content is not None and len(response.content.strip()) > 0
@@ -95,6 +91,23 @@ class LLMProvider(ABC):
95
  print(f"Connection test failed: {e}")
96
  return False
97
 
98
  def _encode_image(self, image_path: str) -> str:
99
  """Encode image to base64 string.
100
 
@@ -110,23 +123,30 @@ class LLMProvider(ABC):
110
  except Exception as e:
111
  print(f"ERROR: _encode_image failed for {image_path} (type: {type(image_path)}): {e}")
112
  raise
113
-
114
- def _validate_image_paths(self, image_paths: List[str]) -> List[str]:
115
- """Validate that image paths exist and are readable.
116
 
117
  Args:
118
- image_paths (List[str]): List of image paths to validate
119
 
120
  Returns:
121
- List[str]: List of valid image paths
122
  """
123
- valid_paths = []
124
- for path in image_paths:
125
- if Path(path).exists() and Path(path).is_file():
126
- valid_paths.append(path)
127
- else:
128
- print(f"Warning: Image path does not exist: {path}")
129
- return valid_paths
130
 
131
  def __str__(self) -> str:
132
  """String representation of the provider."""
 
13
  """Request to an LLM provider."""
14
  text: str
15
  images: Optional[List[str]] = None # List of image paths
 
 
 
 
16
 
17
 
18
  @dataclass
 
40
  **kwargs: Additional configuration parameters
41
  """
42
  self.model_name = model_name
43
+ self.temperature = kwargs.get("temperature", 0.7)
44
+ self.top_p = kwargs.get("top_p", 0.95)
45
+ self.max_tokens = kwargs.get("max_tokens", 5000)
46
+ self.prompt_name = system_prompt
47
 
48
  # Load system prompt content from file
49
  try:
50
+ prompts = load_prompts_from_file("benchmarking/system_prompts.txt")
51
+ self.system_prompt = prompts.get(self.prompt_name, None)
52
  if self.system_prompt is None:
53
+ print(f"Warning: System prompt '{system_prompt}' not found in benchmarking/system_prompts.txt.")
54
  except Exception as e:
55
  print(f"Error loading system prompt: {e}")
56
  self.system_prompt = None
 
83
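The net effect of the LLMRequest and __init__ changes is that temperature, top_p, and max_tokens are now provider-level settings instead of per-request fields, so requests only carry text and image paths. A minimal sketch of the new call pattern (the concrete provider class and paths are illustrative):

    provider = OpenAIProvider("gpt-4o", "MEDICAL_ASSISTANT",
                              temperature=0.2, top_p=0.95, max_tokens=5000)
    request = LLMRequest(
        text="Describe the key findings in this chest X-ray.",
        images=["data/rexvqa/images/deid_png/example.png"],  # hypothetical path
    )
    response = provider.generate_response(request)
    print(response.content, response.duration)
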
  try:
84
  # Simple test request
85
  test_request = LLMRequest(
86
+ text="Hello! What model are you? Tell me your full specification."
 
 
87
  )
88
  response = self.generate_response(test_request)
89
  return response.content is not None and len(response.content.strip()) > 0
 
91
  print(f"Connection test failed: {e}")
92
  return False
93
 
94
+ def _validate_image_paths(self, image_paths: List[str]) -> List[str]:
95
+ """Validate that image paths exist and are readable.
96
+
97
+ Args:
98
+ image_paths (List[str]): List of image paths to validate
99
+
100
+ Returns:
101
+ List[str]: List of valid image paths
102
+ """
103
+ valid_paths = []
104
+ for path in image_paths:
105
+ if Path(path).exists() and Path(path).is_file():
106
+ valid_paths.append(path)
107
+ else:
108
+ print(f"Warning: Image path does not exist: {path}")
109
+ return valid_paths
110
+
111
  def _encode_image(self, image_path: str) -> str:
112
  """Encode image to base64 string.
113
 
 
123
  except Exception as e:
124
  print(f"ERROR: _encode_image failed for {image_path} (type: {type(image_path)}): {e}")
125
  raise
126
+
127
+ def _get_image_mime_type(self, image_path: str) -> str:
128
+ """Detect the MIME type of an image file.
129
 
130
  Args:
131
+ image_path (str): Path to the image file
132
 
133
  Returns:
134
+ str: MIME type (e.g., 'image/png', 'image/jpeg')
135
  """
136
+ # Get file extension
137
+ ext = Path(image_path).suffix.lower()
138
+
139
+ # Map extensions to MIME types
140
+ mime_types = {
141
+ '.png': 'image/png',
142
+ '.jpg': 'image/jpeg',
143
+ '.jpeg': 'image/jpeg',
144
+ '.gif': 'image/gif',
145
+ '.webp': 'image/webp',
146
+ '.bmp': 'image/bmp',
147
+ }
148
+
149
+ return mime_types.get(ext, 'image/png') # Default to PNG for medical images
150
 
151
  def __str__(self) -> str:
152
  """String representation of the provider."""
benchmarking/llm_providers/google_provider.py CHANGED
@@ -14,6 +14,10 @@ class GoogleProvider(LLMProvider):
14
 
15
  def _setup(self) -> None:
16
  """Set up Google langchain client."""
17
  api_key = os.getenv("GOOGLE_API_KEY")
18
  if not api_key:
19
  raise ValueError("GOOGLE_API_KEY environment variable is required")
@@ -21,7 +25,10 @@ class GoogleProvider(LLMProvider):
21
  # Create ChatGoogleGenerativeAI instance
22
  self.client = ChatGoogleGenerativeAI(
23
  model=self.model_name,
24
- google_api_key=api_key
 
 
 
25
  )
26
 
27
  @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
@@ -54,9 +61,10 @@ class GoogleProvider(LLMProvider):
54
  try:
55
  # For langchain Google, pass image data as base64
56
  image_b64 = self._encode_image(image_path)
 
57
  content_parts.append({
58
  "type": "image_url",
59
- "image_url": f"data:image/jpeg;base64,{image_b64}"
60
  })
61
  except Exception as e:
62
  print(f"Error reading image {image_path}: {e}")
@@ -68,18 +76,13 @@ class GoogleProvider(LLMProvider):
68
 
69
  # Make API call using langchain
70
  try:
71
- # Update client parameters for this request
72
- self.client.temperature = request.temperature
73
- self.client.max_output_tokens = request.max_tokens
74
- self.client.top_p = request.top_p
75
-
76
  response = self.client.invoke(messages)
 
77
 
 
78
  duration = time.time() - start_time
79
 
80
- # Extract response content
81
- content = response.content if response.content else ""
82
-
83
  # Get usage information if available
84
  usage = {}
85
  if hasattr(response, 'usage_metadata') and response.usage_metadata:
 
14
 
15
  def _setup(self) -> None:
16
  """Set up Google langchain client."""
17
+ # Set provider name
18
+ self.provider_name = "google"
19
+
20
+ # Get API key from environment variable
21
  api_key = os.getenv("GOOGLE_API_KEY")
22
  if not api_key:
23
  raise ValueError("GOOGLE_API_KEY environment variable is required")
 
25
  # Create ChatGoogleGenerativeAI instance
26
  self.client = ChatGoogleGenerativeAI(
27
  model=self.model_name,
28
+ google_api_key=api_key,
29
+ temperature=self.temperature,
30
+ max_output_tokens=self.max_tokens,
31
+ top_p=self.top_p
32
  )
33
 
34
  @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
 
61
  try:
62
  # For langchain Google, pass image data as base64
63
  image_b64 = self._encode_image(image_path)
64
+ mime_type = self._get_image_mime_type(image_path)
65
  content_parts.append({
66
  "type": "image_url",
67
+ "image_url": f"data:{mime_type};base64,{image_b64}"
68
  })
69
  except Exception as e:
70
  print(f"Error reading image {image_path}: {e}")
 
76
 
77
  # Make API call using langchain
78
  try:
79
+ # Make API call
80
  response = self.client.invoke(messages)
81
+ content = response.content if response.content else ""
82
 
83
+ # Calculate duration
84
  duration = time.time() - start_time
85
 
 
 
 
86
  # Get usage information if available
87
  usage = {}
88
  if hasattr(response, 'usage_metadata') and response.usage_metadata:
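With _get_image_mime_type added to the base class, PNG chest X-rays are no longer mislabeled as JPEG in the data URL. A small sketch of the assumed helper behavior (the file name is illustrative):

    mime_type = provider._get_image_mime_type("study_001.png")   # -> "image/png"
    image_b64 = provider._encode_image("study_001.png")
    image_url = f"data:{mime_type};base64,{image_b64}"
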
benchmarking/llm_providers/medgemma_provider.py ADDED
@@ -0,0 +1,222 @@
1
+ """MedGemma LLM provider implementation using the MedGemma FastAPI service."""
2
+
3
+ import os
4
+ import time
5
+ import httpx
6
+ from typing import Optional
7
+ from pathlib import Path
8
+ from tenacity import retry, wait_exponential, stop_after_attempt
9
+
10
+ from .base import LLMProvider, LLMRequest, LLMResponse
11
+
12
+
13
+ class MedGemmaProvider(LLMProvider):
14
+ """MedGemma LLM provider that communicates with the MedGemma FastAPI service.
15
+
16
+ This provider wraps Google's MedGemma 4B model as an LLMProvider for benchmarking.
17
+ It communicates with a running MedGemma FastAPI service on localhost:8002.
18
+
19
+ MedGemma is a specialized multimodal AI model trained on medical images and text.
20
+ It provides expert-level analysis for chest X-rays, dermatology images,
21
+ ophthalmology images, and histopathology slides.
22
+
23
+ Requirements:
24
+ - MedGemma FastAPI service must be running on the configured API URL
25
+ - Default URL: http://localhost:8002
26
+ - Can be overridden via MEDGEMMA_API_URL environment variable
27
+ """
28
+
29
+ def __init__(self, model_name: str, system_prompt: str, **kwargs):
30
+ """Initialize MedGemma provider.
31
+
32
+ Args:
33
+ model_name (str): Model name (for consistency with other providers)
34
+ system_prompt (str): System prompt identifier to load from file
35
+ **kwargs: Additional configuration parameters
36
+ - api_url: URL of the MedGemma FastAPI service
37
+ - max_new_tokens: Maximum tokens to generate (default: 300)
38
+ """
39
+ # Extract MedGemma-specific config before calling super().__init__
40
+ self.api_url = kwargs.pop('api_url', None) or os.getenv('MEDGEMMA_API_URL', 'http://localhost:8002')
41
+ self.max_new_tokens = kwargs.pop('max_new_tokens', 300)
42
+ self.client = None
43
+
44
+ # Call parent constructor
45
+ super().__init__(model_name, system_prompt, **kwargs)
46
+
47
+ def _setup(self) -> None:
48
+ """Set up httpx client for communicating with MedGemma API."""
49
+ # Create httpx client with reasonable timeouts
50
+ timeout_config = httpx.Timeout(
51
+ timeout=300.0, # 5 minutes for inference
52
+ connect=10.0 # 10 seconds to establish connection
53
+ )
54
+ self.client = httpx.Client(timeout=timeout_config)
55
+
56
+ # Test connection to MedGemma service
57
+ try:
58
+ response = self.client.get(f"{self.api_url}/docs")
59
+ if response.status_code != 200:
60
+ print(f"Warning: MedGemma API at {self.api_url} may not be running (status: {response.status_code})")
61
+ except httpx.ConnectError:
62
+ print(f"Warning: Could not connect to MedGemma API at {self.api_url}")
63
+ print("Please ensure the MedGemma FastAPI service is running:")
64
+ print(f" python medrax/tools/vqa/medgemma/medgemma.py")
65
+
66
+ @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
67
+ def generate_response(self, request: LLMRequest) -> LLMResponse:
68
+ """Generate response using MedGemma API.
69
+
70
+ Args:
71
+ request (LLMRequest): The request containing text, images, and parameters
72
+
73
+ Returns:
74
+ LLMResponse: The response from MedGemma
75
+ """
76
+ start_time = time.time()
77
+
78
+ if self.client is None:
79
+ return LLMResponse(
80
+ content="Error: MedGemma client not initialized",
81
+ duration=time.time() - start_time
82
+ )
83
+
84
+ try:
85
+ # Validate and prepare images
86
+ if not request.images:
87
+ return LLMResponse(
88
+ content="Error: MedGemma requires at least one image",
89
+ duration=time.time() - start_time
90
+ )
91
+
92
+ valid_images = self._validate_image_paths(request.images)
93
+ if not valid_images:
94
+ return LLMResponse(
95
+ content="Error: No valid image paths provided",
96
+ duration=time.time() - start_time
97
+ )
98
+
99
+ # Prepare multipart form data
100
+ files_to_send = []
101
+ for image_path in valid_images:
102
+ try:
103
+ # Detect correct MIME type based on file extension
104
+ ext = Path(image_path).suffix.lower()
105
+ mime_type = "image/png" if ext == ".png" else "image/jpeg"
106
+
107
+ # Read image file
108
+ with open(image_path, "rb") as f:
109
+ image_data = f.read()
110
+
111
+ # Add to files list
112
+ files_to_send.append(
113
+ ("images", (os.path.basename(image_path), image_data, mime_type))
114
+ )
115
+ except Exception as e:
116
+ print(f"Error reading image {image_path}: {e}")
117
+ continue
118
+
119
+ if not files_to_send:
120
+ return LLMResponse(
121
+ content="Error: Failed to read any image files",
122
+ duration=time.time() - start_time
123
+ )
124
+
125
+ # Prepare form data
126
+ # Use system_prompt if provided, otherwise use default
127
+ system_prompt_text = self.system_prompt if self.system_prompt else "You are an expert radiologist."
128
+
129
+ # Override max_new_tokens if provided in request
130
+ max_tokens = getattr(request, 'max_tokens', self.max_new_tokens)
131
+
132
+ data = {
133
+ "prompt": request.text,
134
+ "system_prompt": system_prompt_text,
135
+ "max_new_tokens": max_tokens,
136
+ }
137
+
138
+ # Make API request
139
+ response = self.client.post(
140
+ f"{self.api_url}/analyze-images/",
141
+ data=data,
142
+ files=files_to_send,
143
+ )
144
+
145
+ # Check for errors
146
+ response.raise_for_status()
147
+
148
+ # Parse response
149
+ response_data = response.json()
150
+ content = response_data.get("response", "")
151
+ metadata = response_data.get("metadata", {})
152
+
153
+ duration = time.time() - start_time
154
+
155
+ # MedGemma doesn't provide token usage, but we can include request info
156
+ usage = {
157
+ "num_images": len(valid_images),
158
+ "max_new_tokens": max_tokens,
159
+ }
160
+
161
+ return LLMResponse(
162
+ content=content,
163
+ usage=usage,
164
+ duration=duration
165
+ )
166
+
167
+ except httpx.TimeoutException as e:
168
+ duration = time.time() - start_time
169
+ error_msg = f"MedGemma API request timed out after {duration:.1f}s. The server might be overloaded or the model is taking too long to process."
170
+ print(f"Error: {error_msg}")
171
+ return LLMResponse(
172
+ content=f"Error: {error_msg}",
173
+ duration=duration
174
+ )
175
+
176
+ except httpx.ConnectError as e:
177
+ duration = time.time() - start_time
178
+ error_msg = f"Could not connect to MedGemma API at {self.api_url}. Please ensure the service is running."
179
+ print(f"Error: {error_msg}")
180
+ return LLMResponse(
181
+ content=f"Error: {error_msg}",
182
+ duration=duration
183
+ )
184
+
185
+ except httpx.HTTPStatusError as e:
186
+ duration = time.time() - start_time
187
+ error_msg = f"MedGemma API returned error {e.response.status_code}: {e.response.text}"
188
+ print(f"Error: {error_msg}")
189
+ return LLMResponse(
190
+ content=f"Error: {error_msg}",
191
+ duration=duration
192
+ )
193
+
194
+ except Exception as e:
195
+ duration = time.time() - start_time
196
+ error_msg = f"Unexpected error calling MedGemma API: {str(e)}"
197
+ print(f"Error: {error_msg}")
198
+ return LLMResponse(
199
+ content=f"Error: {error_msg}",
200
+ duration=duration
201
+ )
202
+
203
+ def test_connection(self) -> bool:
204
+ """Test the connection to the MedGemma API service.
205
+
206
+ Returns:
207
+ bool: True if connection is successful and service is responding
208
+ """
209
+ try:
210
+ # Try to access the API docs endpoint
211
+ response = self.client.get(f"{self.api_url}/docs")
212
+ return response.status_code == 200
213
+ except Exception as e:
214
+ print(f"MedGemma connection test failed: {e}")
215
+ return False
216
+
217
+ def __del__(self):
218
+ """Clean up httpx client on deletion."""
219
+ if self.client is not None:
220
+ self.client.close()
221
+
222
+
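A minimal usage sketch for the new provider, assuming the MedGemma FastAPI service is already running on localhost:8002 (the model name and image path are illustrative):

    provider = MedGemmaProvider("medgemma-4b-it", "MEDICAL_ASSISTANT",
                                api_url="http://localhost:8002", max_new_tokens=300)
    request = LLMRequest(text="Is there evidence of pneumothorax?",
                         images=["temp/example_cxr.png"])  # hypothetical image
    response = provider.generate_response(request)
    print(response.content, response.usage)
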
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -21,7 +21,9 @@ class MedRAXProvider(LLMProvider):
21
  system_prompt (str): System prompt to use
22
  **kwargs: Additional configuration parameters
23
  """
24
- self.model_name = model_name
 
 
25
  self.agent = None
26
  self.tools_dict = None
27
 
@@ -33,15 +35,15 @@ class MedRAXProvider(LLMProvider):
33
  print("Starting server...")
34
 
35
  selected_tools = [
36
- "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
37
- "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
38
- "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
39
- "XRayPhraseGroundingTool", # For locating described features in X-rays
40
- "MedGemmaVQATool", # Google MedGemma VQA tool
41
- "XRayVQATool", # For visual question answering on X-rays
42
- "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
43
- "WebBrowserTool", # For web browsing and search capabilities
44
- "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
45
  ]
46
 
47
  rag_config = RAGConfig(
@@ -62,14 +64,15 @@ class MedRAXProvider(LLMProvider):
62
  model_kwargs = {}
63
 
64
  agent, tools_dict = initialize_agent(
65
- prompt_file="medrax/docs/system_prompts.txt",
66
  tools_to_use=selected_tools,
67
- model_dir="/scratch/ssd004/scratch/victorli/model-weights",
68
  temp_dir="temp", # Change this to the path of the temporary directory
69
  device="cuda:0",
70
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
71
- temperature=1.0,
72
- top_p=0.95,
 
73
  model_kwargs=model_kwargs,
74
  rag_config=rag_config,
75
  system_prompt=self.prompt_name,
@@ -107,32 +110,34 @@ class MedRAXProvider(LLMProvider):
107
  thread_id = str(int(time.time() * 1000)) # Unique thread ID
108
 
109
  if request.images:
110
  valid_images = self._validate_image_paths(request.images)
111
  print(f"Processing {len(valid_images)} images")
112
- for i, image_path in enumerate(valid_images):
113
- # Add image path message for tools
114
- messages.append(HumanMessage(content=f"image_path: {image_path}"))
115
-
116
- # Add image content for multimodal LLM
 
 
117
  try:
118
- with open(image_path, "rb") as img_file:
119
- img_base64 = self._encode_image(image_path)
120
 
121
- messages.append(HumanMessage(content=[{
122
  "type": "image_url",
123
- "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}
124
- }]))
125
  except Exception as e:
126
  print(f"ERROR: Image encoding failed for {image_path}: {e}")
127
  raise
128
-
129
- # Add text message
130
- if request.images:
131
- # If there are images, add text as part of multimodal content
132
- messages.append(HumanMessage(content=[{
133
- "type": "text",
134
- "text": request.text
135
- }]))
136
  else:
137
  # If no images, add text as simple string
138
  messages.append(HumanMessage(content=request.text))
@@ -216,8 +221,67 @@ class MedRAXProvider(LLMProvider):
216
  "type": type(msg).__name__,
217
  "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
218
  }
219
  chunk_messages.append(msg_info)
220
- print(f"Message in chunk: {msg_info}")
221
  serializable_chunk["messages"] = chunk_messages
222
 
223
  return serializable_chunk
 
21
  system_prompt (str): System prompt to use
22
  **kwargs: Additional configuration parameters
23
  """
24
+ # Set provider name
25
+ self.provider_name = "medrax"
26
+
27
  self.agent = None
28
  self.tools_dict = None
29
 
 
35
  print("Starting server...")
36
 
37
  selected_tools = [
38
+ # "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
39
+ # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
40
+ # "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
41
+ # "XRayPhraseGroundingTool", # For locating described features in X-rays
42
+ # "MedGemmaVQATool", # Google MedGemma VQA tool
43
+ # "XRayVQATool", # For visual question answering on X-rays
44
+ # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
45
+ # "WebBrowserTool", # For web browsing and search capabilities
46
+ # "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
47
  ]
48
 
49
  rag_config = RAGConfig(
 
64
  model_kwargs = {}
65
 
66
  agent, tools_dict = initialize_agent(
67
+ prompt_file="benchmarking/system_prompts.txt",
68
  tools_to_use=selected_tools,
69
+ model_dir="/home/lijunzh3/scratch/MedRAX2/model-weights",
70
  temp_dir="temp", # Change this to the path of the temporary directory
71
  device="cuda:0",
72
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
73
+ temperature=self.temperature,
74
+ top_p=self.top_p,
75
+ max_tokens=self.max_tokens,
76
  model_kwargs=model_kwargs,
77
  rag_config=rag_config,
78
  system_prompt=self.prompt_name,
 
110
  thread_id = str(int(time.time() * 1000)) # Unique thread ID
111
 
112
  if request.images:
113
+ # Build multimodal content with text and images
114
+ content = [{"type": "text", "text": request.text}]
115
+
116
+ # Validate image paths
117
  valid_images = self._validate_image_paths(request.images)
118
  print(f"Processing {len(valid_images)} images")
119
+
120
+ # Add image paths for tools
121
+ for image_path in valid_images:
122
+ content.append({"type": "text", "text": f"image_path: {image_path}"})
123
+
124
+ # Add image content for multimodal LLM
125
+ for image_path in valid_images:
126
  try:
127
+ img_base64 = self._encode_image(image_path)
128
+ mime_type = self._get_image_mime_type(image_path)
129
 
130
+ content.append({
131
  "type": "image_url",
132
+ "image_url": {"url": f"data:{mime_type};base64,{img_base64}"}
133
+ })
134
  except Exception as e:
135
  print(f"ERROR: Image encoding failed for {image_path}: {e}")
136
  raise
137
+
138
+ # Create single multimodal message
139
+ messages.append(HumanMessage(content=content))
140
+
141
  else:
142
  # If no images, add text as simple string
143
  messages.append(HumanMessage(content=request.text))
 
221
  "type": type(msg).__name__,
222
  "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
223
  }
224
+
225
+ # Extract response metadata (reasoning/thinking traces)
226
+ if hasattr(msg, 'response_metadata') and msg.response_metadata:
227
+ try:
228
+ msg_info["response_metadata"] = dict(msg.response_metadata)
229
+
230
+ # Extract specific reasoning fields for easier access
231
+ # Gemini 2.0 Flash Thinking uses 'thoughts'
232
+ if "thoughts" in msg.response_metadata:
233
+ msg_info["thinking"] = msg.response_metadata["thoughts"]
234
+
235
+ # DeepSeek-R1 and similar models use 'reasoning_content'
236
+ if "reasoning_content" in msg.response_metadata:
237
+ msg_info["reasoning"] = msg.response_metadata["reasoning_content"]
238
+
239
+ # Some models expose thinking in other fields
240
+ if "extended_thinking" in msg.response_metadata:
241
+ msg_info["extended_thinking"] = msg.response_metadata["extended_thinking"]
242
+ except Exception as e:
243
+ print(f"Warning: Could not serialize response_metadata: {e}")
244
+
245
+ # Extract usage metadata (reasoning tokens for o1/o3 models)
246
+ if hasattr(msg, 'usage_metadata') and msg.usage_metadata:
247
+ try:
248
+ msg_info["usage_metadata"] = dict(msg.usage_metadata)
249
+
250
+ # Highlight reasoning tokens if present
251
+ if isinstance(msg.usage_metadata, dict) and "reasoning_tokens" in msg.usage_metadata:
252
+ msg_info["reasoning_tokens"] = msg.usage_metadata["reasoning_tokens"]
253
+ except Exception as e:
254
+ print(f"Warning: Could not serialize usage_metadata: {e}")
255
+
256
+ # Extract additional kwargs (some models put reasoning here)
257
+ if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs:
258
+ try:
259
+ # Filter for reasoning-related fields
260
+ reasoning_kwargs = {}
261
+ for key in ['thinking', 'reasoning', 'thoughts', 'chain_of_thought']:
262
+ if key in msg.additional_kwargs:
263
+ reasoning_kwargs[key] = msg.additional_kwargs[key]
264
+
265
+ if reasoning_kwargs:
266
+ msg_info["additional_reasoning"] = reasoning_kwargs
267
+
268
+ # Include full additional_kwargs for completeness (may contain other useful info)
269
+ msg_info["additional_kwargs"] = dict(msg.additional_kwargs)
270
+ except Exception as e:
271
+ print(f"Warning: Could not serialize additional_kwargs: {e}")
272
+
273
  chunk_messages.append(msg_info)
274
+
275
+ # Enhanced logging for debugging
276
+ log_msg = f"Message in chunk: type={msg_info['type']}"
277
+ if "thinking" in msg_info:
278
+ log_msg += f", has_thinking=True (length={len(str(msg_info['thinking']))})"
279
+ if "reasoning" in msg_info:
280
+ log_msg += f", has_reasoning=True (length={len(str(msg_info['reasoning']))})"
281
+ if "reasoning_tokens" in msg_info:
282
+ log_msg += f", reasoning_tokens={msg_info['reasoning_tokens']}"
283
+ print(log_msg)
284
+
285
  serializable_chunk["messages"] = chunk_messages
286
 
287
  return serializable_chunk
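The extra serialization above surfaces reasoning traces in the chunk history. A sketch of how a results consumer might read them back (field names as written in the hunk; the chunk structure is assumed):

    for chunk in (response.chunk_history or []):
        for msg in chunk.get("messages", []):
            if "thinking" in msg:
                print("thoughts:", str(msg["thinking"])[:200])
            if "reasoning" in msg:
                print("reasoning:", str(msg["reasoning"])[:200])
            if "reasoning_tokens" in msg:
                print("reasoning tokens:", msg["reasoning_tokens"])
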
benchmarking/llm_providers/openai_provider.py CHANGED
@@ -14,21 +14,28 @@ class OpenAIProvider(LLMProvider):
14
 
15
  def _setup(self) -> None:
16
  """Set up OpenAI langchain client."""
17
  api_key = os.getenv("OPENAI_API_KEY")
18
- if not api_key:
19
- raise ValueError("OPENAI_API_KEY environment variable is required")
20
-
21
  base_url = os.getenv("OPENAI_BASE_URL")
 
 
22
 
23
- # Create ChatOpenAI instance
24
  kwargs = {
25
  "model": self.model_name,
26
  "api_key": api_key,
 
 
27
  }
28
-
29
  if base_url:
30
  kwargs["base_url"] = base_url
31
-
 
 
 
32
  self.client = ChatOpenAI(**kwargs)
33
 
34
  @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
@@ -63,10 +70,11 @@ class OpenAIProvider(LLMProvider):
63
  for image_path in valid_images:
64
  try:
65
  image_b64 = self._encode_image(image_path)
 
66
  user_content.append({
67
  "type": "image_url",
68
  "image_url": {
69
- "url": f"data:image/jpeg;base64,{image_b64}",
70
  "detail": "high"
71
  }
72
  })
@@ -75,13 +83,8 @@ class OpenAIProvider(LLMProvider):
75
 
76
  messages.append(HumanMessage(content=user_content))
77
 
78
- # Make API call using langchain
79
  try:
80
- # Update client parameters for this request
81
- self.client.temperature = request.temperature
82
- self.client.max_tokens = request.max_tokens
83
- self.client.top_p = request.top_p
84
-
85
  response = self.client.invoke(messages)
86
 
87
  duration = time.time() - start_time
 
14
 
15
  def _setup(self) -> None:
16
  """Set up OpenAI langchain client."""
17
+ # Set provider name
18
+ self.provider_name = "openai"
19
+
20
+ # Get API key and base URL from environment variables
21
  api_key = os.getenv("OPENAI_API_KEY")
 
 
 
22
  base_url = os.getenv("OPENAI_BASE_URL")
23
+ if not api_key or not base_url:
24
+ raise ValueError("OPENAI_API_KEY and OPENAI_BASE_URL environment variables are required")
25
 
26
+ # Construct kwargs for ChatOpenAI instance
27
  kwargs = {
28
  "model": self.model_name,
29
  "api_key": api_key,
30
+ "temperature": self.temperature,
31
+ "max_tokens": self.max_tokens
32
  }
 
33
  if base_url:
34
  kwargs["base_url"] = base_url
35
+ if self.model_name.startswith("gpt-5") or self.model_name.startswith("o1") or self.model_name.startswith("o3"):
36
+ kwargs["reasoning_effort"] = "high"
37
+
38
+ # Create ChatOpenAI instance
39
  self.client = ChatOpenAI(**kwargs)
40
 
41
  @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
 
70
  for image_path in valid_images:
71
  try:
72
  image_b64 = self._encode_image(image_path)
73
+ mime_type = self._get_image_mime_type(image_path)
74
  user_content.append({
75
  "type": "image_url",
76
  "image_url": {
77
+ "url": f"data:{mime_type};base64,{image_b64}",
78
  "detail": "high"
79
  }
80
  })
 
83
 
84
  messages.append(HumanMessage(content=user_content))
85
 
86
+ # Make API call
87
  try:
 
 
 
 
 
88
  response = self.client.invoke(messages)
89
 
90
  duration = time.time() - start_time
benchmarking/llm_providers/openrouter_provider.py CHANGED
@@ -13,11 +13,16 @@ class OpenRouterProvider(LLMProvider):
13
 
14
  def _setup(self) -> None:
15
  """Set up OpenRouter client models."""
 
 
 
 
16
  api_key = os.getenv("OPENROUTER_API_KEY")
17
- if not api_key:
18
- raise ValueError("OPENROUTER_API_KEY environment variable is required for xAI Grok via OpenRouter.")
19
  base_url = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")
20
- # Use OpenAI SDK with OpenRouter endpoint
 
 
 
21
  self.client = OpenAI(api_key=api_key, base_url=base_url)
22
 
23
  @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
@@ -45,10 +50,11 @@ class OpenRouterProvider(LLMProvider):
45
  for image_path in valid_images:
46
  try:
47
  image_b64 = self._encode_image(image_path)
 
48
  user_content.append({
49
  "type": "image_url",
50
  "image_url": {
51
- "url": f"data:image/jpeg;base64,{image_b64}",
52
  "detail": "high"
53
  }
54
  })
@@ -57,14 +63,14 @@ class OpenRouterProvider(LLMProvider):
57
 
58
  messages.append({"role": "user", "content": user_content})
59
 
 
60
  try:
61
  response = self.client.chat.completions.create(
62
  model=self.model_name,
63
  messages=messages,
64
- temperature=request.temperature,
65
- top_p=request.top_p,
66
- max_tokens=request.max_tokens,
67
- **(request.additional_params or {})
68
  )
69
  duration = time.time() - start_time
70
  content = response.choices[0].message.content if response.choices else ""
 
13
 
14
  def _setup(self) -> None:
15
  """Set up OpenRouter client models."""
16
+ # Set provider name
17
+ self.provider_name = "openrouter"
18
+
19
+ # Get API key and base URL from environment variables
20
  api_key = os.getenv("OPENROUTER_API_KEY")
 
 
21
  base_url = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")
22
+ if not api_key or not base_url:
23
+ raise ValueError("OPENROUTER_API_KEY and OPENROUTER_BASE_URL environment variables are required")
24
+
25
+ # Create OpenAI client with OpenRouter endpoint
26
  self.client = OpenAI(api_key=api_key, base_url=base_url)
27
 
28
  @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
 
50
  for image_path in valid_images:
51
  try:
52
  image_b64 = self._encode_image(image_path)
53
+ mime_type = self._get_image_mime_type(image_path)
54
  user_content.append({
55
  "type": "image_url",
56
  "image_url": {
57
+ "url": f"data:{mime_type};base64,{image_b64}",
58
  "detail": "high"
59
  }
60
  })
 
63
 
64
  messages.append({"role": "user", "content": user_content})
65
 
66
+ # Make API call
67
  try:
68
  response = self.client.chat.completions.create(
69
  model=self.model_name,
70
  messages=messages,
71
+ temperature=self.temperature,
72
+ max_tokens=self.max_tokens,
73
+ top_p=self.top_p
 
74
  )
75
  duration = time.time() - start_time
76
  content = response.choices[0].message.content if response.choices else ""
benchmarking/runner.py CHANGED
@@ -32,16 +32,17 @@ class BenchmarkResult:
32
  @dataclass
33
  class BenchmarkRunConfig:
34
  """Configuration for a benchmark run."""
 
35
  provider_name: str
36
  model_name: str
37
- benchmark_name: str
38
  output_dir: str
39
  max_questions: Optional[int] = None
40
  temperature: float = 0.7
41
  top_p: float = 0.95
42
  max_tokens: int = 5000
43
- additional_params: Optional[Dict[str, Any]] = None
44
  concurrency: int = 1
 
 
45
 
46
 
47
  class BenchmarkRunner:
@@ -59,11 +60,10 @@ class BenchmarkRunner:
59
  self.output_dir.mkdir(parents=True, exist_ok=True)
60
 
61
  # Generate unique run ID
62
- self.run_id = f"{config.benchmark_name}_{config.provider_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
63
 
64
  # Set up logging
65
  self._setup_logging()
66
-
67
  self.logger.info(f"Initialized benchmark runner with ID: {self.run_id}")
68
 
69
  def _setup_logging(self) -> None:
@@ -91,34 +91,28 @@ class BenchmarkRunner:
91
 
92
  def run_benchmark(
93
  self,
94
- llm_provider: LLMProvider,
95
  benchmark: Benchmark,
 
96
  ) -> Dict[str, Any]:
97
  """Run a benchmark against an LLM provider.
98
 
99
  Args:
100
- llm_provider (LLMProvider): The LLM provider to test
101
  benchmark (Benchmark): The benchmark to run
 
102
 
103
  Returns:
104
  Dict[str, Any]: Summary of benchmark results
105
  """
106
  self.logger.info(f"Starting benchmark run: {self.run_id}")
107
- self.logger.info(f"Model: {llm_provider.model_name}")
108
  self.logger.info(f"Benchmark: {benchmark}")
 
 
109
 
110
  # Test provider connection
111
  if not llm_provider.test_connection():
112
  self.logger.error("LLM provider connection test failed")
113
  return {"error": "LLM provider connection test failed"}
114
 
115
- # Get data points to process
116
- total_questions = len(benchmark)
117
- max_questions = self.config.max_questions or total_questions
118
- end_index = min(max_questions, total_questions)
119
-
120
- self.logger.info(f"Processing questions {0} to {end_index-1} of {total_questions}")
121
-
122
  # Initialize counters
123
  processed = 0
124
  correct = 0
@@ -127,29 +121,10 @@ class BenchmarkRunner:
127
  # Determine concurrency
128
  max_workers = max(1, int(getattr(self.config, "concurrency", 1) or 1))
129
 
130
- # Prefetch data points to avoid potential thread-safety issues inside benchmark access
131
- data_points = []
132
- for i in range(0, end_index):
133
- try:
134
- data_points.append(benchmark.get_data_point(i))
135
- except Exception as e:
136
- self.logger.error(f"Error fetching data point {i}: {e}")
137
- error_result = BenchmarkResult(
138
- data_point_id=f"error_{i}",
139
- question="",
140
- model_answer="",
141
- correct_answer="",
142
- is_correct=False,
143
- duration=0.0,
144
- error=str(e)
145
- )
146
- self.results.append(error_result)
147
- self._save_individual_result(error_result)
148
-
149
  # Process data points in parallel using a bounded thread pool
150
- with tqdm(total=end_index, desc="Processing questions") as pbar:
151
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
152
- future_to_index = {executor.submit(self._process_data_point, llm_provider, dp): idx for idx, dp in enumerate(data_points)}
153
  for future in as_completed(future_to_index):
154
  idx = future_to_index[future]
155
  try:
@@ -184,30 +159,29 @@ class BenchmarkRunner:
184
  accuracy = (correct / processed) * 100
185
  avg_duration = total_duration / processed if processed > 0 else 0.0
186
  self.logger.info(
187
- f"Progress: {processed}/{end_index} | "
188
  f"Accuracy: {accuracy:.2f}% | "
189
  f"Avg Duration: {avg_duration:.2f}s"
190
  )
191
 
192
  # Save final results
193
  summary = self._save_final_results(benchmark)
194
-
195
  self.logger.info(f"Benchmark run completed: {self.run_id}")
196
- self.logger.info(f"Final accuracy: {summary['results']['accuracy']:.2f}%")
197
- self.logger.info(f"Total duration: {summary['results']['total_duration']:.2f}s")
198
-
199
  return summary
200
 
201
  def _process_data_point(
202
  self,
203
- llm_provider: LLMProvider,
204
  data_point: BenchmarkDataPoint,
 
205
  ) -> BenchmarkResult:
206
  """Process a single data point.
207
 
208
  Args:
209
- llm_provider (LLMProvider): The LLM provider to use
210
  data_point (BenchmarkDataPoint): The data point to process
 
211
 
212
  Returns:
213
  BenchmarkResult: Result of processing the data point
@@ -215,14 +189,10 @@ class BenchmarkRunner:
215
  start_time = time.time()
216
 
217
  try:
218
- # Create request
219
  request = LLMRequest(
220
  text=data_point.text,
221
- images=data_point.images,
222
- temperature=self.config.temperature,
223
- top_p=self.config.top_p,
224
- max_tokens=self.config.max_tokens,
225
- additional_params=self.config.additional_params
226
  )
227
 
228
  # Get response from LLM
@@ -232,10 +202,12 @@ class BenchmarkRunner:
232
  model_answer = self._extract_answer(response.content)
233
 
234
  # Check if correct
235
- is_correct = self._is_correct_answer(model_answer, data_point.correct_answer)
236
 
 
237
  duration = time.time() - start_time
238
 
 
239
  return BenchmarkResult(
240
  data_point_id=data_point.id,
241
  question=data_point.text,
@@ -247,8 +219,6 @@ class BenchmarkRunner:
247
  chunk_history=response.chunk_history,
248
  metadata={
249
  "data_point_metadata": data_point.metadata,
250
- "case_id": data_point.case_id,
251
- "category": data_point.category,
252
  "raw_response": response.content,
253
  }
254
  )
@@ -265,9 +235,7 @@ class BenchmarkRunner:
265
  error=str(e),
266
  chunk_history=None,
267
  metadata={
268
- "data_point_metadata": data_point.metadata,
269
- "case_id": data_point.case_id,
270
- "category": data_point.category,
271
  }
272
  )
273
 
@@ -289,29 +257,6 @@ class BenchmarkRunner:
289
  # If no pattern matches, return the full response
290
  return response_text.strip()
291
 
292
- def _is_correct_answer(self, model_answer: str, correct_answer: str) -> bool:
293
- """Check if the model answer is correct.
294
-
295
- Args:
296
- model_answer (str): The model's answer
297
- correct_answer (str): The correct answer
298
-
299
- Returns:
300
- bool: True if the answer is correct
301
- """
302
- if not model_answer or not correct_answer:
303
- return False
304
-
305
- # For multiple choice, compare just the letter
306
- model_clean = model_answer.strip().upper()
307
- correct_clean = correct_answer.strip().upper()
308
-
309
- # Extract just the first letter for comparison
310
- model_letter = model_clean[0] if model_clean else ""
311
- correct_letter = correct_clean[0] if correct_clean else ""
312
-
313
- return model_letter == correct_letter
314
-
315
  def _save_individual_result(self, result: BenchmarkResult) -> None:
316
  """Save a single result to its own JSON file.
317
 
@@ -321,12 +266,14 @@ class BenchmarkRunner:
321
  # Sanitize data_point_id for filename (remove invalid characters)
322
  safe_id = re.sub(r'[^\w\-_.]', '_', result.data_point_id)
323
 
324
  # Create filename with benchmark name and data point ID
325
  filename = f"{self.config.benchmark_name}_{safe_id}.json"
326
- result_file = self.output_dir / "individual_results" / filename
327
-
328
- # Create individual_results directory if it doesn't exist
329
- result_file.parent.mkdir(exist_ok=True)
330
 
331
  # Convert result to serializable format
332
  result_data = {
@@ -341,7 +288,7 @@ class BenchmarkRunner:
341
  "usage": result.usage,
342
  "error": result.error,
343
  "chunk_history": result.chunk_history,
344
- "metadata": result.metadata
345
  }
346
 
347
  # Save to file
@@ -357,8 +304,13 @@ class BenchmarkRunner:
357
  Returns:
358
  Dict[str, Any]: Summary of results
359
  """
360
  # Save detailed results
361
- results_file = self.output_dir / f"{self.run_id}_results.json"
362
 
363
  # Convert results to serializable format for final file
364
  results_data = []
@@ -385,29 +337,14 @@ class BenchmarkRunner:
385
 
386
  accuracy = (correct_answers / total_questions) * 100 if total_questions > 0 else 0
387
 
388
- # Calculate per-category accuracy
389
- category_stats = {}
390
- for result in self.results:
391
- if result.metadata and result.metadata.get("category"):
392
- category = result.metadata["category"]
393
- if category not in category_stats:
394
- category_stats[category] = {"correct": 0, "total": 0}
395
- category_stats[category]["total"] += 1
396
- if result.is_correct:
397
- category_stats[category]["correct"] += 1
398
-
399
- # Calculate accuracy for each category
400
- category_accuracies = {}
401
- for category, stats in category_stats.items():
402
- category_accuracies[category] = (stats["correct"] / stats["total"]) * 100
403
-
404
  # Create summary
405
  summary = {
406
  "run_id": self.run_id,
407
  "timestamp": datetime.now().isoformat(),
408
  "config": {
409
- "model_name": self.config.model_name,
410
  "benchmark_name": self.config.benchmark_name,
 
 
411
  "temperature": self.config.temperature,
412
  "top_p": self.config.top_p,
413
  "max_tokens": self.config.max_tokens,
@@ -422,13 +359,12 @@ class BenchmarkRunner:
422
  "total_questions": total_questions,
423
  "total_duration": total_duration,
424
  "avg_duration_per_question": total_duration / total_questions if total_questions > 0 else 0,
425
- "category_accuracies": category_accuracies,
426
  },
427
  "results_file": str(results_file),
428
  }
429
 
430
  # Save summary
431
- summary_file = self.output_dir / f"{self.run_id}_summary.json"
432
  with open(summary_file, 'w') as f:
433
  json.dump(summary, f, indent=2)
434
 
 
32
  @dataclass
33
  class BenchmarkRunConfig:
34
  """Configuration for a benchmark run."""
35
+ benchmark_name: str
36
  provider_name: str
37
  model_name: str
 
38
  output_dir: str
39
  max_questions: Optional[int] = None
40
  temperature: float = 0.7
41
  top_p: float = 0.95
42
  max_tokens: int = 5000
 
43
  concurrency: int = 1
44
+ random_seed: Optional[int] = None
45
+
46
 
47
 
48
  class BenchmarkRunner:
 
60
  self.output_dir.mkdir(parents=True, exist_ok=True)
61
 
62
  # Generate unique run ID
63
+ self.run_id = f"{config.benchmark_name}_{config.provider_name}_{config.model_name}_{config.max_questions}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
64
 
65
  # Set up logging
66
  self._setup_logging()
67
  self.logger.info(f"Initialized benchmark runner with ID: {self.run_id}")
68
 
69
  def _setup_logging(self) -> None:
 
91
 
92
  def run_benchmark(
93
  self,
94
  benchmark: Benchmark,
95
+ llm_provider: LLMProvider,
96
  ) -> Dict[str, Any]:
97
  """Run a benchmark against an LLM provider.
98
 
99
  Args:
100
  benchmark (Benchmark): The benchmark to run
101
+ llm_provider (LLMProvider): The LLM provider to test
102
 
103
  Returns:
104
  Dict[str, Any]: Summary of benchmark results
105
  """
106
  self.logger.info(f"Starting benchmark run: {self.run_id}")
107
  self.logger.info(f"Benchmark: {benchmark}")
108
+ self.logger.info(f"Provider: {llm_provider.provider_name}")
109
+ self.logger.info(f"Model: {llm_provider.model_name}")
110
 
111
  # Test provider connection
112
  if not llm_provider.test_connection():
113
  self.logger.error("LLM provider connection test failed")
114
  return {"error": "LLM provider connection test failed"}
115
 
116
  # Initialize counters
117
  processed = 0
118
  correct = 0
 
121
  # Determine concurrency
122
  max_workers = max(1, int(getattr(self.config, "concurrency", 1) or 1))
123
 
124
  # Process data points in parallel using a bounded thread pool
125
+ with tqdm(total=len(benchmark), desc="Processing questions") as pbar:
126
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
127
+ future_to_index = {executor.submit(self._process_data_point, dp, llm_provider): idx for idx, dp in enumerate(benchmark)}
128
  for future in as_completed(future_to_index):
129
  idx = future_to_index[future]
130
  try:
 
159
  accuracy = (correct / processed) * 100
160
  avg_duration = total_duration / processed if processed > 0 else 0.0
161
  self.logger.info(
162
+ f"Progress: {processed}/{len(benchmark)} | "
163
  f"Accuracy: {accuracy:.2f}% | "
164
  f"Avg Duration: {avg_duration:.2f}s"
165
  )
166
 
167
  # Save final results
168
  summary = self._save_final_results(benchmark)
169
+
170
  self.logger.info(f"Benchmark run completed: {self.run_id}")
171
+ self.logger.info(f"Summary: {summary}")
172
+
173
  return summary
174
 
175
  def _process_data_point(
176
  self,
177
  data_point: BenchmarkDataPoint,
178
+ llm_provider: LLMProvider
179
  ) -> BenchmarkResult:
180
  """Process a single data point.
181
 
182
  Args:
183
  data_point (BenchmarkDataPoint): The data point to process
184
+ llm_provider (LLMProvider): The LLM provider to use
185
 
186
  Returns:
187
  BenchmarkResult: Result of processing the data point
 
189
  start_time = time.time()
190
 
191
  try:
192
+ # Create request for LLM
193
  request = LLMRequest(
194
  text=data_point.text,
195
+ images=data_point.images
196
  )
197
 
198
  # Get response from LLM
 
202
  model_answer = self._extract_answer(response.content)
203
 
204
  # Check if correct
205
+ is_correct = model_answer == data_point.correct_answer
206
 
207
+ # Calculate duration
208
  duration = time.time() - start_time
209
 
210
+ # Return result
211
  return BenchmarkResult(
212
  data_point_id=data_point.id,
213
  question=data_point.text,
 
219
  chunk_history=response.chunk_history,
220
  metadata={
221
  "data_point_metadata": data_point.metadata,
222
  "raw_response": response.content,
223
  }
224
  )
 
235
  error=str(e),
236
  chunk_history=None,
237
  metadata={
238
+ "data_point_metadata": data_point.metadata
239
  }
240
  )
241
 
 
257
  # If no pattern matches, return the full response
258
  return response_text.strip()
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  def _save_individual_result(self, result: BenchmarkResult) -> None:
261
  """Save a single result to its own JSON file.
262
 
 
266
  # Sanitize data_point_id for filename (remove invalid characters)
267
  safe_id = re.sub(r'[^\w\-_.]', '_', result.data_point_id)
268
 
269
+ # Create run_id directory and individual_results subdirectory
270
+ run_dir = self.output_dir / self.run_id
271
+ individual_results_dir = run_dir / "individual_results"
272
+ individual_results_dir.mkdir(parents=True, exist_ok=True)
273
+
274
  # Create filename with benchmark name and data point ID
275
  filename = f"{self.config.benchmark_name}_{safe_id}.json"
276
+ result_file = individual_results_dir / filename
277
 
278
  # Convert result to serializable format
279
  result_data = {
 
288
  "usage": result.usage,
289
  "error": result.error,
290
  "chunk_history": result.chunk_history,
291
+ "metadata": result.metadata,
292
  }
293
 
294
  # Save to file
 
304
  Returns:
305
  Dict[str, Any]: Summary of results
306
  """
307
+ # Create run_id directory and final_results subdirectory
308
+ run_dir = self.output_dir / self.run_id
309
+ final_results_dir = run_dir / "final_results"
310
+ final_results_dir.mkdir(parents=True, exist_ok=True)
311
+
312
  # Save detailed results
313
+ results_file = final_results_dir / f"{self.run_id}_results.json"
314
 
315
  # Convert results to serializable format for final file
316
  results_data = []
 
337
 
338
  accuracy = (correct_answers / total_questions) * 100 if total_questions > 0 else 0
339
 
340
  # Create summary
341
  summary = {
342
  "run_id": self.run_id,
343
  "timestamp": datetime.now().isoformat(),
344
  "config": {
345
  "benchmark_name": self.config.benchmark_name,
346
+ "provider_name": self.config.provider_name,
347
+ "model_name": self.config.model_name,
348
  "temperature": self.config.temperature,
349
  "top_p": self.config.top_p,
350
  "max_tokens": self.config.max_tokens,
 
359
  "total_questions": total_questions,
360
  "total_duration": total_duration,
361
  "avg_duration_per_question": total_duration / total_questions if total_questions > 0 else 0,
362
  },
363
  "results_file": str(results_file),
364
  }
365
 
366
  # Save summary
367
+ summary_file = final_results_dir / f"{self.run_id}_summary.json"
368
  with open(summary_file, 'w') as f:
369
  json.dump(summary, f, indent=2)
370
 
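Note: after this change every run writes under its own directory: <output_dir>/<run_id>/individual_results/ holds one JSON file per question, and <output_dir>/<run_id>/final_results/ holds <run_id>_results.json plus <run_id>_summary.json. A minimal usage sketch of the revised API follows; the import path and the example argument values are assumptions (the diff does not show the module name), and benchmark / llm_provider stand for a Benchmark and LLMProvider constructed elsewhere.

    # Sketch only: the module path "benchmarking.runner" is assumed, not confirmed by this diff.
    from benchmarking.runner import BenchmarkRunConfig, BenchmarkRunner

    def run_example(benchmark, llm_provider):
        # Field names mirror the BenchmarkRunConfig dataclass shown above; values are illustrative.
        config = BenchmarkRunConfig(
            benchmark_name="rexvqa",     # hypothetical benchmark name
            provider_name="openai",      # hypothetical provider name
            model_name="gpt-4o",         # hypothetical model name
            output_dir="results",
            max_questions=100,
            temperature=0.7,
            top_p=0.95,
            max_tokens=5000,
            concurrency=4,
            random_seed=42,
        )
        runner = BenchmarkRunner(config)
        # run_benchmark returns the summary dict that is also written to
        # <output_dir>/<run_id>/final_results/<run_id>_summary.json
        return runner.run_benchmark(benchmark, llm_provider)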
benchmarking/system_prompts.txt ADDED
@@ -0,0 +1,36 @@
1
+ [MEDICAL_ASSISTANT]
2
+ You are an expert medical AI assistant who can answer any medical question and analyze medical images as a doctor would.
3
+ Solve problems using your own vision and reasoning, and use tools to complement that reasoning.
4
+ You can make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
5
+ Think critically about the tool outputs and critique them where appropriate.
6
+ If you need to look up information before asking a follow-up question, you may do so.
7
+
8
+ CITATION REQUIREMENTS:
9
+ - When referencing information from RAG and/or web search tools, ALWAYS include numbered citations [1], [2], [3], etc.
10
+ - Use citations immediately after making claims or statements based on the above tool results.
11
+ - Be consistent with citation numbering throughout your response.
12
+ - Only cite sources that actually contain the information you're referencing.
13
+
14
+ Examples:
15
+ - "According to recent research [1], chest X-rays can show signs of pneumonia..."
16
+ - "The medical literature indicates [2] that this condition typically presents with..."
17
+ - "Based on clinical guidelines [3], the recommended treatment approach is..."
18
+
19
+ [CHESTAGENTBENCH_PROMPT]
20
+ You are a highly skilled radiology AI agent, an expert in interpreting medical images, specifically chest X-rays, CT scans, and MRIs, with world-class accuracy and precision.
21
+ Your primary function is to assist in the analysis of these images and answer diagnostic questions.
22
+
23
+ Your task is to provide a step-by-step, structured analysis. First, carefully examine the provided image and describe all relevant findings in a clear, concise manner.
24
+ Next, use your expert medical knowledge to form a differential diagnosis based on these findings. Finally, critically evaluate the provided question and all possible choices.
25
+
26
+ You have access to a suite of powerful tools to aid in your analysis. Use these tools as needed to retrieve external medical knowledge, access patient history, or perform specific image processing tasks.
27
+ You should always scrutinize the output from your tools and integrate it into your reasoning. If tool outputs conflict with your initial assessment, explain the discrepancy and justify your final conclusion.
28
+ You must pass the image paths exactly as provided or the tools will not work. Do not mangle the image paths.
29
+
30
+ Your final response for a multiple-choice question must strictly follow this format, including your step-by-step reasoning:
31
+ 1. **Image Analysis:** [Describe image findings here]
32
+ 2. **Differential Diagnosis:** [List possible diagnoses and their justifications]
33
+ 3. **Critical Thinking & Tool Use:** [Show your reasoning, including how you used tools and evaluated their output]
34
+ 4. **Final Answer:** \boxed{A}
35
+
36
+ Do not provide a definitive diagnosis or treatment plan for a patient. Your purpose is to assist medical professionals with your analysis, not to replace them. You must maintain this persona and adhere to all instructions.
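The [CHESTAGENTBENCH_PROMPT] requires the final answer inside a \boxed{A} marker. The runner's _extract_answer implementation is not part of this diff, but a parser for that convention could look roughly like the sketch below; this is illustrative only, and extract_boxed_answer is a hypothetical helper, not the repository's actual code.

    import re

    def extract_boxed_answer(response_text: str) -> str:
        """Return the content of the last \\boxed{...}; fall back to the stripped response."""
        matches = re.findall(r"\\boxed\{([^}]*)\}", response_text)
        if matches:
            return matches[-1].strip()
        # Same fallback the runner uses when no pattern matches: return the full response
        return response_text.strip()

    # extract_boxed_answer("4. **Final Answer:** \\boxed{A}") -> "A"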