VictorLJZ committed
Commit c7b65ec · Parent(s): d07a267

rexvqa now works
.gitignore CHANGED
@@ -175,4 +175,6 @@ temp/
 
 hf_files/
 medrax-pdfs/
-model-weights/
+model-weights/
+
+.DS_Store
benchmarking/benchmarks/rexvqa_benchmark.py CHANGED
@@ -2,11 +2,10 @@
 
 import json
 import os
-from pathlib import Path
 from typing import Dict, List, Optional, Any
 from datasets import load_dataset
 from .base import Benchmark, BenchmarkDataPoint
-import hashlib
+from pathlib import Path
 
 
 class ReXVQABenchmark(Benchmark):
@@ -19,7 +18,7 @@ class ReXVQABenchmark(Benchmark):
 
     The dataset consists of two separate HuggingFace datasets:
     - ReXVQA: Contains questions, answers, and metadata
-    - ReXGradient-160K: Contains the actual chest X-ray images
+    - ReXGradient-160K: Contains metadata only (images are in separate part files)
 
     Paper: https://arxiv.org/abs/2506.04353
    Dataset: https://huggingface.co/datasets/rajpurkarlab/ReXVQA
@@ -36,11 +35,13 @@ class ReXVQABenchmark(Benchmark):
             cache_dir (str): Directory for caching HuggingFace datasets
             trust_remote_code (bool): Whether to trust remote code (default: False)
             max_questions (int): Maximum number of questions to load (default: None, load all)
+            images_dir (str): Directory containing extracted PNG images (default: None)
         """
         self.split = kwargs.get("split", "test")
         self.cache_dir = kwargs.get("cache_dir", None)
         self.trust_remote_code = kwargs.get("trust_remote_code", False)
         self.max_questions = kwargs.get("max_questions", None)
+        self.images_dir = "benchmarking/data/rexvqa/images/deid_png"
         self.image_dataset = None
         self.image_mapping = {}  # Maps study_id to image data
 
@@ -50,7 +51,7 @@ class ReXVQABenchmark(Benchmark):
         """Load ReXVQA data from local JSON file."""
         try:
             # Construct path to the JSON file
-            json_file_path = os.path.join("benchmarking", "data", "test_vqa_data.json")
+            json_file_path = os.path.join("benchmarking", "data", "rexvqa", "test_vqa_data.json")
 
             # Check if file exists
             if not os.path.exists(json_file_path):
@@ -71,8 +72,8 @@ class ReXVQABenchmark(Benchmark):
 
             print(f"Loaded {len(questions_list)} questions from local JSON file")
 
-            # Load images dataset from ReXGradient-160K
-            print("Loading ReXGradient-160K images dataset...")
+            # Load images dataset from ReXGradient-160K (metadata only)
+            print("Loading ReXGradient-160K metadata dataset...")
             try:
                 self.image_dataset = load_dataset(
                     "rajpurkarlab/ReXGradient-160K",
@@ -80,9 +81,9 @@ class ReXVQABenchmark(Benchmark):
                     cache_dir=self.cache_dir,
                     trust_remote_code=self.trust_remote_code
                 )
-                print(f"Loaded {len(self.image_dataset)} images from ReXGradient-160K")
+                print(f"Loaded {len(self.image_dataset)} image metadata entries from ReXGradient-160K")
 
-                # Create mapping from study_id to image data
+                # Create mapping from study_id to image metadata
                 self._create_image_mapping()
 
             except Exception as e:
@@ -111,7 +112,7 @@ class ReXVQABenchmark(Benchmark):
             raise RuntimeError(f"Failed to load ReXVQA dataset: {e}")
 
     def _create_image_mapping(self) -> None:
-        """Create mapping from study_id to image data."""
+        """Create mapping from study_id to image metadata."""
        if not self.image_dataset:
            return
 
@@ -120,7 +121,7 @@ class ReXVQABenchmark(Benchmark):
        for item in self.image_dataset:
            study_instance_uid = item.get("StudyInstanceUid", "")
            if study_instance_uid:
-                # Store the image data for this study using StudyInstanceUid as key
+                # Store the image metadata for this study using StudyInstanceUid as key
                if study_instance_uid not in self.image_mapping:
                    self.image_mapping[study_instance_uid] = []
                self.image_mapping[study_instance_uid].append(item)
@@ -152,39 +153,26 @@ class ReXVQABenchmark(Benchmark):
         # Get correct answer
         correct_answer = item.get("correct_answer", "")
 
-        # If we have options and a letter answer, get the full text
-        if options and correct_answer and len(correct_answer) == 1:
-            try:
-                # Find the option that starts with the correct letter
-                for option in options:
-                    if option.strip().startswith(f"{correct_answer}."):
-                        correct_answer = option.strip()
-                        break
-            except:
-                pass  # Keep the original letter if parsing fails
-
         if not question:
             return None
 
-        # Handle images - look for ImagePath field
+        # Handle images using ImagePath field
         images = None
-        image_paths = item.get("ImagePath", [])
-        study_id = item.get("study_id", "")
-        study_instance_uid = item.get("StudyInstanceUid", "")
-
-        if image_paths:
-            # Use local image paths if available
-            images = [str(Path(path)) for path in image_paths if path]
-        elif study_instance_uid and study_instance_uid in self.image_mapping:
-            # Use StudyInstanceUid for matching with HuggingFace images
-            images = self._get_images_for_study(study_instance_uid, question_id)
+        if self.images_dir and "ImagePath" in item and item["ImagePath"]:
+            images = []
+            for rel_path in item["ImagePath"]:
+                # Remove leading ../ if present
+                norm_rel_path = rel_path.lstrip("./")
+                # Join with images_dir root
+                full_path = str(Path(self.images_dir).parent / norm_rel_path)
+                images.append(full_path)
 
         # Extract metadata
         metadata = {
             "dataset": "rexvqa",
             "split": self.split,
-            "study_id": study_id,
-            "study_instance_uid": study_instance_uid,
+            "study_id": item.get("study_id", ""),
+            "study_instance_uid": item.get("StudyInstanceUid", ""),
             "reasoning_type": item.get("task_name", ""),  # task_name maps to reasoning_type
             "category": item.get("category", ""),
             "class": item.get("class", ""),
@@ -201,12 +189,9 @@ class ReXVQABenchmark(Benchmark):
             "correct_answer_explanation": item.get("correct_answer_explanation", ""),
         }
 
-        # Determine category from task_name or category field
-        category = item.get("task_name", item.get("category", ""))
-
-        # Use study_id as case_id for grouping related questions (keep using compound study_id for grouping)
-        case_id = study_id
-
+        case_id = item.get("study_id", "")
+        category = item.get("task_name", "")
+
         return BenchmarkDataPoint(
             id=question_id,
             text=question_with_options,
@@ -216,65 +201,3 @@ class ReXVQABenchmark(Benchmark):
             case_id=case_id,
             category=category,
         )
-
-    def _get_images_for_study(self, study_instance_uid: str, question_id: str) -> Optional[List[str]]:
-        """Get images for a specific study and save them locally.
-
-        Args:
-            study_instance_uid (str): Study Instance UID
-            question_id (str): Question ID for filename
-
-        Returns:
-            Optional[List[str]]: List of image paths
-        """
-        if study_instance_uid not in self.image_mapping:
-            return None
-
-        images = []
-        study_images = self.image_mapping[study_instance_uid]
-
-        # Create images directory if it doesn't exist
-        images_dir = self.data_dir / "images"
-        images_dir.mkdir(parents=True, exist_ok=True)
-
-        # Get every image for the study
-        if not images and study_images:
-            for img_data in study_images:
-                image_path = self._save_image(img_data, question_id, images_dir)
-                if image_path:
-                    images.append(image_path)
-
-        return images if images else None
-
-    def _save_image(self, img_data: Dict[str, Any], question_id: str, images_dir) -> Optional[str]:
-        """Save image data to local file.
-
-        Args:
-            img_data (Dict[str, Any]): Image data from dataset
-            question_id (str): Question ID for filename
-            images_dir: Directory to save images
-
-        Returns:
-            Optional[str]: Path to saved image
-        """
-        try:
-            # Get the image from the dataset item
-            image = img_data.get("image")
-            if image is None:
-                return None
-
-            # Generate filename using StudyInstanceUid
-            study_instance_uid = img_data.get("StudyInstanceUid", "")
-            filename_hash = hashlib.md5(f"{question_id}_{study_instance_uid}".encode()).hexdigest()[:8]
-            image_filename = f"{question_id}_{filename_hash}.png"
-            image_path = images_dir / image_filename
-
-            # Save image if it doesn't exist
-            if not image_path.exists():
-                image.save(str(image_path))
-
-            return str(image_path)
-
-        except Exception as e:
-            print(f"Error saving image for question {question_id}: {e}")
-            return None
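
A note on the new ImagePath handling above: each relative path is joined against the parent of images_dir, not images_dir itself, so entries shaped like "../deid_png/<study>/<image>.png" resolve under benchmarking/data/rexvqa/images/. A minimal sketch of that resolution, using a hypothetical ImagePath value (real values come from test_vqa_data.json):

    from pathlib import Path

    # Hypothetical example path; actual values come from the dataset JSON.
    images_dir = "benchmarking/data/rexvqa/images/deid_png"
    rel_path = "../deid_png/study0001/frontal.png"

    # lstrip("./") strips a *character set* ('.' and '/'), not the literal
    # prefix "../" -- here that removes the leading "../" as intended.
    norm_rel_path = rel_path.lstrip("./")
    full_path = str(Path(images_dir).parent / norm_rel_path)
    print(full_path)
    # benchmarking/data/rexvqa/images/deid_png/study0001/frontal.png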
benchmarking/cli.py CHANGED
@@ -22,7 +22,6 @@ def create_llm_provider(model_name: str, provider_type: str, **kwargs) -> LLMProvider:
     provider_map = {
         "openai": OpenAIProvider,
         "google": GoogleProvider,
-        "xai": XAIProvider,
         "medrax": MedRAXProvider,
     }
 
@@ -78,6 +77,7 @@ def run_benchmark_command(args) -> None:
         output_dir=args.output_dir,
         max_questions=args.max_questions,
         temperature=args.temperature,
+        top_p=args.top_p,
         max_tokens=args.max_tokens
     )
 
@@ -112,12 +112,13 @@ def main():
     # Run benchmark command
     run_parser = subparsers.add_parser("run", help="Run a benchmark")
     run_parser.add_argument("--model", required=True, help="Model name (e.g., gpt-4o, gemini-2.5-pro)")
-    run_parser.add_argument("--provider", required=True, choices=["openai", "google", "xai", "medrax"], help="LLM provider")
+    run_parser.add_argument("--provider", required=True, choices=["openai", "google", "medrax"], help="LLM provider")
     run_parser.add_argument("--benchmark", required=True, choices=["rexvqa", "chestagentbench"], help="Benchmark to run")
     run_parser.add_argument("--data-dir", required=True, help="Directory containing benchmark data")
     run_parser.add_argument("--output-dir", default="benchmark_results", help="Output directory for results")
     run_parser.add_argument("--max-questions", type=int, help="Maximum number of questions to process")
     run_parser.add_argument("--temperature", type=float, default=0.7, help="Model temperature")
+    run_parser.add_argument("--top-p", type=float, default=0.95, help="Top-p value")
     run_parser.add_argument("--max-tokens", type=int, default=5000, help="Maximum tokens per response")
 
     run_parser.set_defaults(func=run_benchmark_command)
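
One wiring detail worth noting: argparse converts the --top-p flag name into the attribute args.top_p (dashes become underscores), which is exactly what run_benchmark_command reads when it passes top_p=args.top_p. A self-contained sketch of that behavior, with defaults copied from the diff:

    import argparse

    # Rebuild just the two sampling flags from main() to show the
    # flag-name -> attribute-name mapping.
    run_parser = argparse.ArgumentParser(prog="run")
    run_parser.add_argument("--temperature", type=float, default=0.7, help="Model temperature")
    run_parser.add_argument("--top-p", type=float, default=0.95, help="Top-p value")

    args = run_parser.parse_args(["--top-p", "0.9"])
    print(args.temperature, args.top_p)  # 0.7 0.9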
benchmarking/llm_providers/__init__.py CHANGED
@@ -4,7 +4,6 @@ from .base import LLMProvider, LLMRequest, LLMResponse
 from .openai_provider import OpenAIProvider
 from .google_provider import GoogleProvider
 from .medrax_provider import MedRAXProvider
-from .xai_provider import XAIProvider
 
 __all__ = [
     "LLMProvider",
@@ -13,5 +12,4 @@ __all__ = [
     "OpenAIProvider",
     "GoogleProvider",
     "MedRAXProvider",
-    "XAIProvider",
 ]