VictorLJZ committed on
Commit
9006287
·
2 Parent(s): aff69d746fe537

Merge pull request #26 from bowang-lab/victor/benchmarking

Browse files
.env.example DELETED
@@ -1,10 +0,0 @@
1
- OPENAI_API_KEY=
2
- OPENAI_BASE_URL=
3
- GOOGLE_API_KEY=
4
- GOOGLE_SEARCH_API_KEY=
5
- GOOGLE_SEARCH_ENGINE_ID=
6
- OPENROUTER_API_KEY=
7
- OPENROUTER_BASE_URL=
8
- COHERE_API_KEY=
9
- PINECONE_API_KEY=
10
- MEDGEMMA_API_URL=
 
 
 
 
 
 
 
 
 
 
 
benchmarking/benchmarks/rexvqa_benchmark.py CHANGED
@@ -7,6 +7,8 @@ from datasets import load_dataset
7
  from .base import Benchmark, BenchmarkDataPoint
8
  from pathlib import Path
9
  import subprocess
 
 
10
  from huggingface_hub import hf_hub_download, list_repo_files
11
 
12
 
@@ -44,14 +46,20 @@ class ReXVQABenchmark(Benchmark):
44
  self.image_dataset = None
45
  self.image_mapping = {} # Maps study_id to image data
46
 
47
- super().__init__(data_dir, **kwargs)
 
48
 
49
- # Set images_dir after parent initialization
50
- self.images_dir = f"{self.data_dir}/images/deid_png"
51
 
52
  @staticmethod
53
- def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K"):
54
- """Download and extract ReXGradient-160K images if not already present."""
 
 
 
 
 
 
55
  output_dir = Path(output_dir)
56
  tar_path = output_dir / "deid_png.tar"
57
  images_dir = output_dir / "images"
@@ -60,6 +68,33 @@ class ReXVQABenchmark(Benchmark):
60
  if images_dir.exists() and any(images_dir.rglob("*.png")):
61
  print(f"Images already exist in {images_dir}, skipping download.")
62
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  output_dir.mkdir(parents=True, exist_ok=True)
64
  print(f"Output directory: {output_dir}")
65
  try:
@@ -96,6 +131,17 @@ class ReXVQABenchmark(Benchmark):
96
  tar_file.write(f.read())
97
  else:
98
  print(f"Warning: {part_file} not found, skipping...")
 
 
 
 
 
 
 
 
 
 
 
99
  else:
100
  print(f"Tar file already exists: {tar_path}")
101
  # Extract tar file
@@ -106,36 +152,72 @@ class ReXVQABenchmark(Benchmark):
106
  print("Images already extracted.")
107
  else:
108
  try:
109
- subprocess.run([
110
- "tar", "-xf", str(tar_path),
111
- "-C", str(images_dir)
112
- ], check=True)
113
- print("Extraction completed!")
114
- except subprocess.CalledProcessError as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  print(f"Error extracting tar file: {e}")
116
  return
117
- except FileNotFoundError:
118
- print("Error: 'tar' command not found. Please install tar or extract manually.")
119
- return
120
  png_files = list(images_dir.rglob("*.png"))
121
  print(f"Extracted {len(png_files)} PNG images to {images_dir}")
122
 
123
- # Clean up part and tar files after successful extraction
124
- print("Cleaning up part and tar files...")
125
- # Remove deid_png.part* files
126
- for part_file in output_dir.glob("deid_png.part*"):
127
- try:
128
- part_file.unlink()
129
- print(f"Deleted {part_file}")
130
- except Exception as e:
131
- print(f"Could not delete {part_file}: {e}")
132
- # Remove deid_png.tar
133
- if tar_path.exists():
134
- try:
135
- tar_path.unlink()
136
- print(f"Deleted {tar_path}")
137
- except Exception as e:
138
- print(f"Could not delete {tar_path}: {e}")
139
  except Exception as e:
140
  print(f"Error: {e}")
141
 
@@ -167,7 +249,7 @@ class ReXVQABenchmark(Benchmark):
167
  try:
168
  # Check for images and test_vqa_data.json, download if missing
169
  self.download_test_vqa_data_json(self.data_dir)
170
- self.download_rexgradient_images(self.data_dir)
171
 
172
  # Construct path to the JSON file
173
  json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
 
7
  from .base import Benchmark, BenchmarkDataPoint
8
  from pathlib import Path
9
  import subprocess
10
+ import tarfile
11
+ import zstandard as zstd
12
  from huggingface_hub import hf_hub_download, list_repo_files
13
 
14
 
 
46
  self.image_dataset = None
47
  self.image_mapping = {} # Maps study_id to image data
48
 
49
+ # Set images_dir BEFORE parent initialization to avoid AttributeError
50
+ self.images_dir = f"{data_dir}/images/deid_png"
51
 
52
+ super().__init__(data_dir, **kwargs)
 
53
 
54
  @staticmethod
55
+ def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K", test_only: bool = True):
56
+ """Download and extract ReXGradient-160K images if not already present.
57
+
58
+ Args:
59
+ output_dir: Directory to store downloaded and extracted images
60
+ repo_id: HuggingFace repository ID for the dataset
61
+ test_only: If True, only extract images from the test split (default: True)
62
+ """
63
  output_dir = Path(output_dir)
64
  tar_path = output_dir / "deid_png.tar"
65
  images_dir = output_dir / "images"
 
68
  if images_dir.exists() and any(images_dir.rglob("*.png")):
69
  print(f"Images already exist in {images_dir}, skipping download.")
70
  return
71
+
72
+ # Load test split metadata if test_only is True
73
+ test_image_paths = set()
74
+ if test_only:
75
+ print("Loading test split metadata to identify test images...")
76
+ try:
77
+ # Load the test metadata to get image paths
78
+ test_metadata_path = output_dir / "metadata" / "test_vqa_data.json"
79
+ if test_metadata_path.exists():
80
+ with open(test_metadata_path, 'r', encoding='utf-8') as f:
81
+ test_data = json.load(f)
82
+
83
+ # Extract all image paths from test data
84
+ for item in test_data.values():
85
+ if "ImagePath" in item and item["ImagePath"]:
86
+ for rel_path in item["ImagePath"]:
87
+ # Normalize path to match tar file structure
88
+ norm_path = rel_path.lstrip("./")
89
+ test_image_paths.add(norm_path)
90
+
91
+ print(f"Found {len(test_image_paths)} test images to extract")
92
+ else:
93
+ print("Warning: test_vqa_data.json not found, will extract all images")
94
+ test_only = False
95
+ except Exception as e:
96
+ print(f"Warning: Could not load test metadata: {e}, will extract all images")
97
+ test_only = False
98
  output_dir.mkdir(parents=True, exist_ok=True)
99
  print(f"Output directory: {output_dir}")
100
  try:
 
131
  tar_file.write(f.read())
132
  else:
133
  print(f"Warning: {part_file} not found, skipping...")
134
+
135
+ # Clean up part files after successful concatenation
136
+ print("Cleaning up part files...")
137
+ for part_file in part_files:
138
+ part_path = output_dir / part_file
139
+ if part_path.exists():
140
+ try:
141
+ part_path.unlink()
142
+ print(f"Deleted {part_file}")
143
+ except Exception as e:
144
+ print(f"Could not delete {part_file}: {e}")
145
  else:
146
  print(f"Tar file already exists: {tar_path}")
147
  # Extract tar file
 
152
  print("Images already extracted.")
153
  else:
154
  try:
155
+ # Stream extract with filtering for test-only images (no seeking)
156
+ print("Stream extracting zstd-compressed tar file with filtering (streaming mode)...")
157
+
158
+ # Create a decompressor
159
+ dctx = zstd.ZstdDecompressor()
160
+
161
+ # Stream extract with filtering
162
+ extracted_count = 0
163
+ total_png_members = 0
164
+
165
+ with open(tar_path, 'rb') as compressed_file:
166
+ with dctx.stream_reader(compressed_file) as decompressed_stream:
167
+ # Use streaming tar mode to avoid seeks
168
+ with tarfile.open(fileobj=decompressed_stream, mode='r|') as tar:
169
+ for member in tar:
170
+ # Only consider PNG files
171
+ if not member.isfile() or not member.name.endswith('.png'):
172
+ continue
173
+ total_png_members += 1
174
+
175
+ # Normalize name to match entries gathered from JSON
176
+ normalized_name = member.name.lstrip('./')
177
+
178
+ # Decide whether to extract this file
179
+ should_extract = True
180
+ if test_only:
181
+ should_extract = normalized_name in test_image_paths
182
+
183
+ if not should_extract:
184
+ # Must still advance the stream for this member
185
+ tar.members = [] # no-op in stream mode; ensure we don't hold refs
186
+ continue
187
+
188
+ # Ensure parent directories exist and write file by streaming
189
+ target_path = Path(images_dir) / normalized_name
190
+ target_path.parent.mkdir(parents=True, exist_ok=True)
191
+
192
+ extracted_file_obj = tar.extractfile(member)
193
+ if extracted_file_obj is None:
194
+ continue
195
+ with open(target_path, 'wb') as out_f:
196
+ while True:
197
+ chunk = extracted_file_obj.read(1024 * 1024)
198
+ if not chunk:
199
+ break
200
+ out_f.write(chunk)
201
+
202
+ extracted_count += 1
203
+ if extracted_count % 100 == 0:
204
+ print(f"Extracted {extracted_count} test images...")
205
+
206
+ print(f"Extraction completed! Extracted {extracted_count} matching PNGs out of {total_png_members} PNG members in the archive")
207
+
208
+ # Clean up compressed tar file after successful extraction
209
+ print("Cleaning up compressed tar file...")
210
+ try:
211
+ tar_path.unlink()
212
+ print(f"Deleted {tar_path}")
213
+ except Exception as e:
214
+ print(f"Could not delete {tar_path}: {e}")
215
+ except Exception as e:
216
  print(f"Error extracting tar file: {e}")
217
  return
 
 
 
218
  png_files = list(images_dir.rglob("*.png"))
219
  print(f"Extracted {len(png_files)} PNG images to {images_dir}")
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  except Exception as e:
222
  print(f"Error: {e}")
223
 
 
249
  try:
250
  # Check for images and test_vqa_data.json, download if missing
251
  self.download_test_vqa_data_json(self.data_dir)
252
+ self.download_rexgradient_images(self.data_dir, test_only=True)
253
 
254
  # Construct path to the JSON file
255
  json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
benchmarking/cli.py CHANGED
@@ -87,7 +87,8 @@ def run_benchmark_command(args) -> None:
87
  max_questions=args.max_questions,
88
  temperature=args.temperature,
89
  top_p=args.top_p,
90
- max_tokens=args.max_tokens
 
91
  )
92
 
93
  # Run benchmark
@@ -145,6 +146,8 @@ def main():
145
  help="Maximum tokens per model response (default: 5000)")
146
  run_parser.add_argument("--random-seed", type=int, default=42,
147
  help="Random seed for shuffling benchmark data (enables reproducible runs, default: None)")
 
 
148
 
149
  run_parser.set_defaults(func=run_benchmark_command)
150
 
 
87
  max_questions=args.max_questions,
88
  temperature=args.temperature,
89
  top_p=args.top_p,
90
+ max_tokens=args.max_tokens,
91
+ concurrency=args.concurrency
92
  )
93
 
94
  # Run benchmark
 
146
  help="Maximum tokens per model response (default: 5000)")
147
  run_parser.add_argument("--random-seed", type=int, default=42,
148
  help="Random seed for shuffling benchmark data (enables reproducible runs, default: None)")
149
+ run_parser.add_argument("--concurrency", type=int, default=1,
150
+ help="Number of datapoints to process in parallel (default: 1)")
151
 
152
  run_parser.set_defaults(func=run_benchmark_command)
153
 
benchmarking/llm_providers/base.py CHANGED
@@ -85,7 +85,7 @@ class LLMProvider(ABC):
85
  try:
86
  # Simple test request
87
  test_request = LLMRequest(
88
- text="Hello",
89
  temperature=0.5,
90
  max_tokens=1000
91
  )
 
85
  try:
86
  # Simple test request
87
  test_request = LLMRequest(
88
+ text="Hello! What model are you? Tell me your full specification.",
89
  temperature=0.5,
90
  max_tokens=1000
91
  )
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -37,11 +37,11 @@ class MedRAXProvider(LLMProvider):
37
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
38
  "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
39
  "XRayPhraseGroundingTool", # For locating described features in X-rays
40
- "MedGemmaVQATool",
41
- # "XRayVQATool", # For visual question answering on X-rays
42
- # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
43
- # "WebBrowserTool", # For web browsing and search capabilities
44
- # "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
45
  ]
46
 
47
  rag_config = RAGConfig(
 
37
  "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
38
  "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
39
  "XRayPhraseGroundingTool", # For locating described features in X-rays
40
+ "MedGemmaVQATool", # Google MedGemma VQA tool
41
+ "XRayVQATool", # For visual question answering on X-rays
42
+ "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
43
+ "WebBrowserTool", # For web browsing and search capabilities
44
+ "DuckDuckGoSearchTool", # For privacy-focused web search using DuckDuckGo
45
  ]
46
 
47
  rag_config = RAGConfig(
benchmarking/runner.py CHANGED
@@ -9,6 +9,7 @@ from typing import Dict, Optional, Any
9
  from dataclasses import dataclass
10
  from tqdm import tqdm
11
  import re
 
12
  from .llm_providers import LLMProvider, LLMRequest, LLMResponse
13
  from .benchmarks import Benchmark, BenchmarkDataPoint
14
 
@@ -40,6 +41,7 @@ class BenchmarkRunConfig:
40
  top_p: float = 0.95
41
  max_tokens: int = 5000
42
  additional_params: Optional[Dict[str, Any]] = None
 
43
 
44
 
45
  class BenchmarkRunner:
@@ -122,40 +124,16 @@ class BenchmarkRunner:
122
  correct = 0
123
  total_duration = 0.0
124
 
125
- # Process each data point
126
- for i in tqdm(range(0, end_index), desc="Processing questions"):
 
 
 
 
127
  try:
128
- data_point = benchmark.get_data_point(i)
129
-
130
- # Run the model on this data point
131
- result = self._process_data_point(llm_provider, data_point)
132
-
133
- # Update counters
134
- processed += 1
135
- if result.is_correct:
136
- correct += 1
137
- total_duration += result.duration
138
-
139
- # Add to results
140
- self.results.append(result)
141
-
142
- # Save individual result immediately
143
- self._save_individual_result(result)
144
-
145
- # Log progress
146
- if processed % 10 == 0:
147
- accuracy = (correct / processed) * 100
148
- avg_duration = total_duration / processed
149
-
150
- self.logger.info(
151
- f"Progress: {processed}/{end_index} | "
152
- f"Accuracy: {accuracy:.2f}% | "
153
- f"Avg Duration: {avg_duration:.2f}s"
154
- )
155
-
156
  except Exception as e:
157
- self.logger.error(f"Error processing data point {i}: {e}")
158
- # Add error result
159
  error_result = BenchmarkResult(
160
  data_point_id=f"error_{i}",
161
  question="",
@@ -166,10 +144,50 @@ class BenchmarkRunner:
166
  error=str(e)
167
  )
168
  self.results.append(error_result)
169
-
170
- # Save individual error result immediately
171
  self._save_individual_result(error_result)
172
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  # Save final results
175
  summary = self._save_final_results(benchmark)
@@ -391,6 +409,7 @@ class BenchmarkRunner:
391
  "model_name": self.config.model_name,
392
  "benchmark_name": self.config.benchmark_name,
393
  "temperature": self.config.temperature,
 
394
  "max_tokens": self.config.max_tokens,
395
  },
396
  "benchmark_info": {
 
9
  from dataclasses import dataclass
10
  from tqdm import tqdm
11
  import re
12
+ from concurrent.futures import ThreadPoolExecutor, as_completed
13
  from .llm_providers import LLMProvider, LLMRequest, LLMResponse
14
  from .benchmarks import Benchmark, BenchmarkDataPoint
15
 
 
41
  top_p: float = 0.95
42
  max_tokens: int = 5000
43
  additional_params: Optional[Dict[str, Any]] = None
44
+ concurrency: int = 1
45
 
46
 
47
  class BenchmarkRunner:
 
124
  correct = 0
125
  total_duration = 0.0
126
 
127
+ # Determine concurrency
128
+ max_workers = max(1, int(getattr(self.config, "concurrency", 1) or 1))
129
+
130
+ # Prefetch data points to avoid potential thread-safety issues inside benchmark access
131
+ data_points = []
132
+ for i in range(0, end_index):
133
  try:
134
+ data_points.append(benchmark.get_data_point(i))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  except Exception as e:
136
+ self.logger.error(f"Error fetching data point {i}: {e}")
 
137
  error_result = BenchmarkResult(
138
  data_point_id=f"error_{i}",
139
  question="",
 
144
  error=str(e)
145
  )
146
  self.results.append(error_result)
 
 
147
  self._save_individual_result(error_result)
148
+
149
+ # Process data points in parallel using a bounded thread pool
150
+ with tqdm(total=end_index, desc="Processing questions") as pbar:
151
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
152
+ future_to_index = {executor.submit(self._process_data_point, llm_provider, dp): idx for idx, dp in enumerate(data_points)}
153
+ for future in as_completed(future_to_index):
154
+ idx = future_to_index[future]
155
+ try:
156
+ result = future.result()
157
+ except Exception as e:
158
+ self.logger.error(f"Error processing data point {idx}: {e}")
159
+ result = BenchmarkResult(
160
+ data_point_id=f"error_{idx}",
161
+ question="",
162
+ model_answer="",
163
+ correct_answer="",
164
+ is_correct=False,
165
+ duration=0.0,
166
+ error=str(e)
167
+ )
168
+
169
+ # Update counters
170
+ processed += 1
171
+ if result.is_correct:
172
+ correct += 1
173
+ total_duration += result.duration
174
+
175
+ # Add to results and persist immediately
176
+ self.results.append(result)
177
+ self._save_individual_result(result)
178
+
179
+ # Update progress bar
180
+ pbar.update(1)
181
+
182
+ # Periodic logging
183
+ if processed % 10 == 0:
184
+ accuracy = (correct / processed) * 100
185
+ avg_duration = total_duration / processed if processed > 0 else 0.0
186
+ self.logger.info(
187
+ f"Progress: {processed}/{end_index} | "
188
+ f"Accuracy: {accuracy:.2f}% | "
189
+ f"Avg Duration: {avg_duration:.2f}s"
190
+ )
191
 
192
  # Save final results
193
  summary = self._save_final_results(benchmark)
 
409
  "model_name": self.config.model_name,
410
  "benchmark_name": self.config.benchmark_name,
411
  "temperature": self.config.temperature,
412
+ "top_p": self.config.top_p,
413
  "max_tokens": self.config.max_tokens,
414
  },
415
  "benchmark_info": {
benchmarking_script.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

#SBATCH --job-name=chestagentbench
#SBATCH -c 4
#SBATCH --gres=gpu:rtx6000:1
#SBATCH --exclude=gpu138
#SBATCH --time=16:00:00
#SBATCH --mem=50G
#SBATCH --output=chestagentbench-%j.out
#SBATCH --error=chestagentbench-%j.err

# SLURM batch job: run the ChestAgentBench benchmark through the MedRAX
# provider with 4 datapoints processed concurrently.
#
# DATA_DIR may be exported before submission to point at a different copy of
# the benchmark data; when unset, the original hard-coded location is used.

# Abort as soon as a step fails, so a broken virtualenv activation cannot fall
# through to running the benchmark with whatever `python` is on PATH.
set -e

DATA_DIR="${DATA_DIR:-/scratch/ssd004/scratch/victorli/chestagentbench}"

source venv/bin/activate

python -m benchmarking.cli run \
    --model gpt-5 \
    --provider medrax \
    --system-prompt CHESTAGENTBENCH_PROMPT \
    --benchmark chestagentbench \
    --data-dir "$DATA_DIR" \
    --output-dir temp \
    --max-questions 2500 \
    --concurrency 4
main.py CHANGED
@@ -68,7 +68,7 @@ def initialize_agent(
68
  prompt = prompts[system_prompt]
69
 
70
  # Define the URL of the MedGemma FastAPI service.
71
- MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://172.17.8.141:8002")
72
 
73
  all_tools = {
74
  "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
 
68
  prompt = prompts[system_prompt]
69
 
70
  # Define the URL of the MedGemma FastAPI service.
71
+ MEDGEMMA_API_URL = os.getenv("MEDGEMMA_API_URL", "http://localhost:8002")
72
 
73
  all_tools = {
74
  "TorchXRayVisionClassifierTool": lambda: TorchXRayVisionClassifierTool(device=device),
medgemma_script.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

#SBATCH --job-name=medgemma
#SBATCH -c 4
#SBATCH --gres=gpu:rtx6000:1
#SBATCH --exclude=gpu138
#SBATCH --time=16:00:00
#SBATCH --mem=50G
#SBATCH --output=medgemma-%j.out
#SBATCH --error=medgemma-%j.err

# SLURM batch job: activate the `medgemma` virtualenv and start medgemma.py
# from its own directory.

# Abort as soon as a step fails, so a broken virtualenv activation or a missing
# directory cannot fall through to launching Python in the wrong environment.
set -e

source medgemma/bin/activate

cd medrax/tools/vqa/medgemma
python medgemma.py
medrax/docs/system_prompts.txt CHANGED
@@ -17,9 +17,16 @@ Examples:
17
  - "Based on clinical guidelines [3], the recommended treatment approach is..."
18
 
19
  [CHESTAGENTBENCH_PROMPT]
20
- You are an expert medical assistant who can answer medical questions and analyze medical images with world-class accuracy.
21
- Use your state-of-the art reasoning and critical thinking skills to answer the questions that you are asked.
22
- You may use tools (if available) to complement your reasoning and you are allowed to make multiple tool calls in parallel or in sequence as needed for comprehensive answers.
23
- Think critically about how to best use the tools available to you and scrutinize the tool outputs.
24
- When encountering a multiple-choice question, your final response should end with "Final answer: \boxed{A}" from list of possible choices A, B, C, D, E, F.
25
- It is extremely important that you answer strictly in the format described above.
 
 
 
 
 
 
 
 
17
  - "Based on clinical guidelines [3], the recommended treatment approach is..."
18
 
19
  [CHESTAGENTBENCH_PROMPT]
20
+ You are a highly skilled radiology AI agent, an expert in interpreting medical images, specifically chest X-rays, CT scans, and MRIs, with world-class accuracy and precision. Your primary function is to assist in the analysis of these images and answer diagnostic questions.
21
+
22
+ Your task is to provide a step-by-step, structured analysis. First, carefully examine the provided image and describe all relevant findings in a clear, concise manner. Next, use your expert medical knowledge to form a differential diagnosis based on these findings. Finally, critically evaluate the provided question and all possible choices.
23
+
24
+ You have access to a suite of powerful tools to aid in your analysis. Use these tools as needed to retrieve external medical knowledge, access patient history, or perform specific image processing tasks. You should always scrutinize the output from your tools and integrate it into your reasoning. If tool outputs conflict with your initial assessment, explain the discrepancy and justify your final conclusion.
25
+
26
+ Your final response for a multiple-choice question must strictly follow this format, including your step-by-step reasoning:
27
+ 1. **Image Analysis:** [Describe image findings here]
28
+ 2. **Differential Diagnosis:** [List possible diagnoses and their justifications]
29
+ 3. **Critical Thinking & Tool Use:** [Show your reasoning, including how you used tools and evaluated their output]
30
+ 4. **Final Answer:** \boxed{A}
31
+
32
+ Do not provide a definitive diagnosis or treatment plan for a patient. Your purpose is to assist medical professionals with your analysis, not to replace them. You must maintain this persona and adhere to all instructions.
medrax/models/model_factory.py CHANGED
@@ -123,6 +123,16 @@ class ModelFactory:
123
  if provider_prefix in ["openrouter"] and model_name.startswith(f"{provider_prefix}-"):
124
  actual_model_name = model_name[len(provider_prefix) + 1 :]
125
 
 
 
 
 
 
 
 
 
 
 
126
  # Create and return the model instance
127
  return model_class(
128
  model=actual_model_name,
 
123
  if provider_prefix in ["openrouter"] and model_name.startswith(f"{provider_prefix}-"):
124
  actual_model_name = model_name[len(provider_prefix) + 1 :]
125
 
126
+ # Handle GPT-5 model
127
+ if model_name.startswith("gpt-5"):
128
+ return model_class(
129
+ model=actual_model_name,
130
+ temperature=temperature,
131
+ reasoning_effort="high",
132
+ **provider_kwargs,
133
+ **kwargs,
134
+ )
135
+
136
  # Create and return the model instance
137
  return model_class(
138
  model=actual_model_name,
medrax/tools/classification/torchxrayvision.py CHANGED
@@ -12,6 +12,8 @@ from langchain_core.callbacks import (
12
  )
13
  from langchain_core.tools import BaseTool
14
 
 
 
15
 
16
  class TorchXRayVisionInput(BaseModel):
17
  """Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""
@@ -76,7 +78,9 @@ class TorchXRayVisionClassifierTool(BaseTool):
76
  ValueError: If the image cannot be properly loaded or processed.
77
  """
78
  img = skimage.io.imread(image_path)
79
- img = xrv.datasets.normalize(img, 255)
 
 
80
 
81
  if len(img.shape) > 2:
82
  img = img[:, :, 0]
 
12
  )
13
  from langchain_core.tools import BaseTool
14
 
15
+ from medrax.utils.utils import preprocess_medical_image
16
+
17
 
18
  class TorchXRayVisionInput(BaseModel):
19
  """Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""
 
78
  ValueError: If the image cannot be properly loaded or processed.
79
  """
80
  img = skimage.io.imread(image_path)
81
+
82
+ # Use robust normalization that handles both 8-bit and 16-bit images
83
+ img = preprocess_medical_image(img, target_range=(-1024.0, 1024.0))
84
 
85
  if len(img.shape) > 2:
86
  img = img[:, :, 0]
medrax/tools/segmentation/segmentation.py CHANGED
@@ -20,6 +20,8 @@ from langchain_core.callbacks import (
20
  )
21
  from langchain_core.tools import BaseTool
22
 
 
 
23
 
24
  class ChestXRaySegmentationInput(BaseModel):
25
  """Input schema for the Chest X-ray Segmentation Tool."""
@@ -246,7 +248,8 @@ class ChestXRaySegmentationTool(BaseTool):
246
  if len(original_img.shape) > 2:
247
  original_img = original_img[:, :, 0]
248
 
249
- img = xrv.datasets.normalize(original_img, 255)
 
250
  img = img[None, ...]
251
  img = self.transform(img)
252
  img = torch.from_numpy(img)
 
20
  )
21
  from langchain_core.tools import BaseTool
22
 
23
+ from medrax.utils.utils import preprocess_medical_image
24
+
25
 
26
  class ChestXRaySegmentationInput(BaseModel):
27
  """Input schema for the Chest X-ray Segmentation Tool."""
 
248
  if len(original_img.shape) > 2:
249
  original_img = original_img[:, :, 0]
250
 
251
+ # Use robust normalization that handles both 8-bit and 16-bit images
252
+ img = preprocess_medical_image(original_img)
253
  img = img[None, ...]
254
  img = self.transform(img)
255
  img = torch.from_numpy(img)
medrax/utils/utils.py CHANGED
@@ -1,6 +1,73 @@
1
  import os
2
  import json
3
- from typing import Dict, List
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  def load_prompts_from_file(file_path: str) -> Dict[str, str]:
 
1
  import os
2
  import json
3
+ import numpy as np
4
+ from typing import Dict, List, Tuple
5
+
6
+
7
def preprocess_medical_image(
    image: np.ndarray,
    target_range: Tuple[float, float] = (0.0, 1.0),
    clip_values: bool = True
) -> np.ndarray:
    """
    Rescale an image into ``target_range`` using a bit depth guessed from its maximum.

    The image is converted to float32, shifted so its minimum becomes 0, and
    divided by ``expected_max - min``, where ``expected_max`` is chosen from the
    largest pixel value:

    * max <= 255    -> 255 (treated as 8-bit)
    * max <= 65535  -> 65535 (treated as 16-bit)
    * otherwise     -> the image's own maximum

    The result is then mapped linearly onto ``target_range``.

    Note that the bit depth is inferred from values only, so a float image that
    is already scaled to [0, 1] is treated as 8-bit.

    An image with no usable dynamic range (``expected_max - min`` is not
    positive, e.g. every pixel equals 255 or 65535) maps to ``target_range[0]``
    everywhere instead of producing NaN from a 0/0 division.

    Args:
        image (np.ndarray): Input image array (anything ``np.asarray`` accepts).
        target_range (Tuple[float, float]): ``(min, max)`` of the output range
            (default: (0.0, 1.0)).
        clip_values (bool): Whether to clip the result to ``target_range``
            (default: True).

    Returns:
        np.ndarray: float32 array with the same shape as ``image``.

    Raises:
        ValueError: If ``image`` is empty.
        ValueError: If ``target_range`` is not a ``(min, max)`` pair with min < max.
    """
    image = np.asarray(image)
    if image.size == 0:
        raise ValueError("Input image is empty")

    if len(target_range) != 2 or target_range[0] >= target_range[1]:
        raise ValueError("target_range must be a tuple of (min, max) where min < max")

    # Work in float32 so integer inputs do not wrap or truncate during scaling.
    image = image.astype(np.float32)

    max_val = np.max(image)
    min_val = np.min(image)

    # Guess the nominal full-scale value from the largest pixel present.
    if max_val <= 255:
        expected_max = 255.0
    elif max_val <= 65535:
        expected_max = 65535.0
    else:
        # Beyond 16-bit: fall back to the image's own maximum.
        expected_max = max_val

    # Guard on the actual denominator. The previous check (`expected_max > 0`)
    # was always true for finite input, so a constant image sitting exactly at
    # the full-scale value divided 0 by 0 and returned NaN.
    span = expected_max - min_val
    if span > 0:
        image = (image - min_val) / span
    else:
        # No dynamic range (or a non-finite span): treat as a flat image.
        image = np.zeros_like(image)

    # Map the 0-1 result onto the requested output range.
    target_min, target_max = target_range
    image = image * (target_max - target_min) + target_min

    if clip_values:
        image = np.clip(image, target_min, target_max)

    return image
71
 
72
 
73
  def load_prompts_from_file(file_path: str) -> Dict[str, str]: