VictorLJZ committed on
Commit
9bb904a
·
1 Parent(s): 99f2cbc

first working version

Browse files
benchmarking/benchmarks/rexvqa_benchmark.py CHANGED
@@ -1,8 +1,12 @@
1
  """ReXVQA benchmark implementation."""
2
 
 
 
 
3
  from typing import Dict, List, Optional, Any
4
  from datasets import load_dataset
5
  from .base import Benchmark, BenchmarkDataPoint
 
6
 
7
 
8
  class ReXVQABenchmark(Benchmark):
@@ -13,8 +17,13 @@ class ReXVQABenchmark(Benchmark):
13
  reasoning skills: presence assessment, location analysis, negation detection,
14
  differential diagnosis, and geometric reasoning.
15
 
 
 
 
 
16
  Paper: https://arxiv.org/abs/2506.04353
17
  Dataset: https://huggingface.co/datasets/rajpurkarlab/ReXVQA
 
18
  """
19
 
20
  def __init__(self, data_dir: str, **kwargs):
@@ -23,34 +32,72 @@ class ReXVQABenchmark(Benchmark):
23
  Args:
24
  data_dir (str): Directory to store/cache downloaded data
25
  **kwargs: Additional configuration parameters
26
- split (str): Dataset split to use ('validation' or 'test', default: 'validation')
27
  cache_dir (str): Directory for caching HuggingFace datasets
28
  trust_remote_code (bool): Whether to trust remote code (default: False)
 
29
  """
30
- self.split = kwargs.get("split", "validation")
31
  self.cache_dir = kwargs.get("cache_dir", None)
32
  self.trust_remote_code = kwargs.get("trust_remote_code", False)
 
 
 
33
 
34
  super().__init__(data_dir, **kwargs)
35
 
36
  def _load_data(self) -> None:
37
- """Load ReXVQA data from HuggingFace."""
38
  try:
39
- # Load dataset from HuggingFace
40
- print(f"Loading ReXVQA {self.split} split from HuggingFace...")
 
 
 
 
 
 
41
 
42
- dataset = load_dataset(
43
- "rajpurkarlab/ReXVQA",
44
- split=self.split,
45
- cache_dir=self.cache_dir,
46
- trust_remote_code=self.trust_remote_code
47
- )
48
 
49
- print(f"Loaded {len(dataset)} examples from ReXVQA {self.split} split")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  self.data_points = []
52
 
53
- for i, item in enumerate(dataset):
 
 
 
 
 
54
  try:
55
  data_point = self._parse_rexvqa_item(item, i)
56
  if data_point:
@@ -63,11 +110,28 @@ class ReXVQABenchmark(Benchmark):
63
  except Exception as e:
64
  raise RuntimeError(f"Failed to load ReXVQA dataset: {e}")
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def _parse_rexvqa_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
67
  """Parse a ReXVQA dataset item.
68
 
69
  Args:
70
- item (Dict[str, Any]): Dataset item from HuggingFace
71
  index (int): Item index
72
 
73
  Returns:
@@ -76,96 +140,141 @@ class ReXVQABenchmark(Benchmark):
76
  # Extract basic information
77
  question_id = item.get("id", f"rexvqa_{self.split}_{index}")
78
  question = item.get("question", "")
79
- answer = item.get("answer", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  if not question:
82
  return None
83
 
84
- # Handle image
85
  images = None
86
- if "image" in item and item["image"] is not None:
87
- # Save image to local cache directory
88
- image_filename = f"{question_id}.png"
89
- image_path = self.data_dir / "images" / image_filename
90
-
91
- # Create images directory if it doesn't exist
92
- image_path.parent.mkdir(parents=True, exist_ok=True)
93
-
94
- # Save image if it doesn't exist
95
- if not image_path.exists():
96
- try:
97
- item["image"].save(str(image_path))
98
- except Exception as e:
99
- print(f"Error saving image for {question_id}: {e}")
100
- return None
101
-
102
- images = [str(image_path)]
103
 
104
  # Extract metadata
105
  metadata = {
106
  "dataset": "rexvqa",
107
  "split": self.split,
108
- "study_id": item.get("study_id", ""),
109
- "image_id": item.get("image_id", ""),
110
- "reasoning_type": item.get("reasoning_type", ""),
111
- "anatomical_location": item.get("anatomical_location", ""),
112
- "pathology": item.get("pathology", ""),
 
 
 
 
 
 
 
 
 
 
 
113
  }
114
 
115
- # Determine category from reasoning type
116
- category = item.get("reasoning_type", "")
117
 
118
- # Use study_id as case_id for grouping related questions
119
- case_id = item.get("study_id", "")
120
 
121
  return BenchmarkDataPoint(
122
  id=question_id,
123
- text=question,
124
  images=images,
125
- correct_answer=answer,
126
  metadata=metadata,
127
  case_id=case_id,
128
  category=category,
129
  )
130
 
131
- def get_pathologies(self) -> List[str]:
132
- """Get all unique pathologies in the dataset.
133
-
134
- Returns:
135
- List[str]: List of unique pathologies
136
- """
137
- pathologies = set()
138
- for dp in self:
139
- pathology = dp.metadata.get("pathology", "")
140
- if pathology:
141
- pathologies.add(pathology)
142
- return sorted(list(pathologies))
143
-
144
- def get_by_pathology(self, pathology: str) -> List[BenchmarkDataPoint]:
145
- """Get all data points about a specific pathology.
146
 
147
  Args:
148
- pathology (str): Pathology to filter by
 
149
 
150
  Returns:
151
- List[BenchmarkDataPoint]: List of data points about the pathology
152
  """
153
- return [dp for dp in self if dp.metadata.get("pathology", "") == pathology]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
- def get_dataset_info(self) -> Dict[str, Any]:
156
- """Get information about the ReXVQA dataset.
157
 
 
 
 
 
 
158
  Returns:
159
- Dict[str, Any]: Dataset information
160
  """
161
- return {
162
- "name": "ReXVQA",
163
- "description": "Large-scale Visual Question Answering Benchmark for Chest Radiology",
164
- "split": self.split,
165
- "size": len(self.data_points),
166
- "reasoning_types": self.get_reasoning_types(),
167
- "pathologies": self.get_pathologies(),
168
- "categories": self.get_categories(),
169
- "paper": "https://arxiv.org/abs/2506.04353",
170
- "dataset_url": "https://huggingface.co/datasets/rajpurkarlab/ReXVQA",
171
- }
 
 
 
 
 
 
 
 
 
 
 
1
  """ReXVQA benchmark implementation."""
2
 
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
  from typing import Dict, List, Optional, Any
7
  from datasets import load_dataset
8
  from .base import Benchmark, BenchmarkDataPoint
9
+ import hashlib
10
 
11
 
12
  class ReXVQABenchmark(Benchmark):
 
17
  reasoning skills: presence assessment, location analysis, negation detection,
18
  differential diagnosis, and geometric reasoning.
19
 
20
+ The dataset consists of two separate HuggingFace datasets:
21
+ - ReXVQA: Contains questions, answers, and metadata
22
+ - ReXGradient-160K: Contains the actual chest X-ray images
23
+
24
  Paper: https://arxiv.org/abs/2506.04353
25
  Dataset: https://huggingface.co/datasets/rajpurkarlab/ReXVQA
26
+ Images: https://huggingface.co/datasets/rajpurkarlab/ReXGradient-160K
27
  """
28
 
29
    def __init__(self, data_dir: str, **kwargs):
        """Initialize the ReXVQA benchmark.

        Args:
            data_dir (str): Directory to store/cache downloaded data
            **kwargs: Additional configuration parameters
                split (str): Dataset split to use (default: 'test')
                cache_dir (str): Directory for caching HuggingFace datasets
                trust_remote_code (bool): Whether to trust remote code (default: False)
                max_questions (int): Maximum number of questions to load (default: None, load all)
        """
        self.split = kwargs.get("split", "test")
        self.cache_dir = kwargs.get("cache_dir", None)
        self.trust_remote_code = kwargs.get("trust_remote_code", False)
        self.max_questions = kwargs.get("max_questions", None)
        # Populated later by _load_data / _create_image_mapping.
        self.image_dataset = None
        self.image_mapping = {}  # Maps StudyInstanceUid -> list of image records

        super().__init__(data_dir, **kwargs)
48
 
49
  def _load_data(self) -> None:
50
+ """Load ReXVQA data from local JSON file."""
51
  try:
52
+ # Construct path to the JSON file
53
+ json_file_path = os.path.join("benchmarking", "data", "test_vqa_data.json")
54
+
55
+ # Check if file exists
56
+ if not os.path.exists(json_file_path):
57
+ raise FileNotFoundError(f"Could not find test_vqa_data.json in the expected location: {json_file_path}")
58
+
59
+ print(f"Loading ReXVQA {self.split} split from local JSON file: {json_file_path}")
60
 
61
+ # Load JSON file directly
62
+ with open(json_file_path, 'r', encoding='utf-8') as f:
63
+ questions_data = json.load(f)
 
 
 
64
 
65
+ # ReXVQA format: {question_id: {question_data}, ...}
66
+ questions_list = []
67
+ for question_id, question_data in questions_data.items():
68
+ # Add the question_id to the question_data for reference
69
+ question_data['id'] = question_id
70
+ questions_list.append(question_data)
71
+
72
+ print(f"Loaded {len(questions_list)} questions from local JSON file")
73
+
74
+ # Load images dataset from ReXGradient-160K
75
+ print("Loading ReXGradient-160K images dataset...")
76
+ try:
77
+ self.image_dataset = load_dataset(
78
+ "rajpurkarlab/ReXGradient-160K",
79
+ split="test",
80
+ cache_dir=self.cache_dir,
81
+ trust_remote_code=self.trust_remote_code
82
+ )
83
+ print(f"Loaded {len(self.image_dataset)} images from ReXGradient-160K")
84
+
85
+ # Create mapping from study_id to image data
86
+ self._create_image_mapping()
87
+
88
+ except Exception as e:
89
+ print(f"Warning: Could not load ReXGradient-160K dataset: {e}")
90
+ print("Proceeding without images...")
91
+ self.load_images = False
92
 
93
  self.data_points = []
94
 
95
+ # Process questions (limit if max_questions is specified)
96
+ questions_to_process = questions_list
97
+ if self.max_questions:
98
+ questions_to_process = questions_list[:min(self.max_questions, len(questions_list))]
99
+
100
+ for i, item in enumerate(questions_to_process):
101
  try:
102
  data_point = self._parse_rexvqa_item(item, i)
103
  if data_point:
 
110
  except Exception as e:
111
  raise RuntimeError(f"Failed to load ReXVQA dataset: {e}")
112
 
113
+ def _create_image_mapping(self) -> None:
114
+ """Create mapping from study_id to image data."""
115
+ if not self.image_dataset:
116
+ return
117
+
118
+ print("Creating image mapping...")
119
+
120
+ for item in self.image_dataset:
121
+ study_instance_uid = item.get("StudyInstanceUid", "")
122
+ if study_instance_uid:
123
+ # Store the image data for this study using StudyInstanceUid as key
124
+ if study_instance_uid not in self.image_mapping:
125
+ self.image_mapping[study_instance_uid] = []
126
+ self.image_mapping[study_instance_uid].append(item)
127
+
128
+ print(f"Created image mapping for {len(self.image_mapping)} studies")
129
+
130
  def _parse_rexvqa_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
131
  """Parse a ReXVQA dataset item.
132
 
133
  Args:
134
+ item (Dict[str, Any]): Dataset item from JSON file
135
  index (int): Item index
136
 
137
  Returns:
 
140
  # Extract basic information
141
  question_id = item.get("id", f"rexvqa_{self.split}_{index}")
142
  question = item.get("question", "")
143
+
144
+ # Handle multiple choice options
145
+ options = item.get("options", [])
146
+ if options:
147
+ # Add options to the question for multiple choice format
148
+ question_with_options = question + "\n\nOptions:\n" + "\n".join(options)
149
+ else:
150
+ question_with_options = question
151
+
152
+ # Get correct answer
153
+ correct_answer = item.get("correct_answer", "")
154
+
155
+ # If we have options and a letter answer, get the full text
156
+ if options and correct_answer and len(correct_answer) == 1:
157
+ try:
158
+ # Find the option that starts with the correct letter
159
+ for option in options:
160
+ if option.strip().startswith(f"{correct_answer}."):
161
+ correct_answer = option.strip()
162
+ break
163
+ except:
164
+ pass # Keep the original letter if parsing fails
165
 
166
  if not question:
167
  return None
168
 
169
+ # Handle images - look for ImagePath field
170
  images = None
171
+ image_paths = item.get("ImagePath", [])
172
+ study_id = item.get("study_id", "")
173
+ study_instance_uid = item.get("StudyInstanceUid", "")
174
+
175
+ if image_paths:
176
+ # Use local image paths if available
177
+ images = [str(Path(path)) for path in image_paths if path]
178
+ elif study_instance_uid and study_instance_uid in self.image_mapping:
179
+ # Use StudyInstanceUid for matching with HuggingFace images
180
+ images = self._get_images_for_study(study_instance_uid, question_id)
 
 
 
 
 
 
 
181
 
182
  # Extract metadata
183
  metadata = {
184
  "dataset": "rexvqa",
185
  "split": self.split,
186
+ "study_id": study_id,
187
+ "study_instance_uid": study_instance_uid,
188
+ "reasoning_type": item.get("task_name", ""), # task_name maps to reasoning_type
189
+ "category": item.get("category", ""),
190
+ "class": item.get("class", ""),
191
+ "subcategory": item.get("subcategory", ""),
192
+ "patient_id": item.get("PatientID", ""),
193
+ "patient_age": item.get("PatientAge", ""),
194
+ "patient_sex": item.get("PatientSex", ""),
195
+ "study_date": item.get("StudyDate", ""),
196
+ "indication": item.get("Indication", ""),
197
+ "findings": item.get("Findings", ""),
198
+ "impression": item.get("Impression", ""),
199
+ "image_modality": item.get("ImageModality", []),
200
+ "image_view_position": item.get("ImageViewPosition", []),
201
+ "correct_answer_explanation": item.get("correct_answer_explanation", ""),
202
  }
203
 
204
+ # Determine category from task_name or category field
205
+ category = item.get("task_name", item.get("category", ""))
206
 
207
+ # Use study_id as case_id for grouping related questions (keep using compound study_id for grouping)
208
+ case_id = study_id
209
 
210
  return BenchmarkDataPoint(
211
  id=question_id,
212
+ text=question_with_options,
213
  images=images,
214
+ correct_answer=correct_answer,
215
  metadata=metadata,
216
  case_id=case_id,
217
  category=category,
218
  )
219
 
220
+ def _get_images_for_study(self, study_instance_uid: str, question_id: str) -> Optional[List[str]]:
221
+ """Get images for a specific study and save them locally.
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
  Args:
224
+ study_instance_uid (str): Study Instance UID
225
+ question_id (str): Question ID for filename
226
 
227
  Returns:
228
+ Optional[List[str]]: List of image paths
229
  """
230
+ if study_instance_uid not in self.image_mapping:
231
+ return None
232
+
233
+ images = []
234
+ study_images = self.image_mapping[study_instance_uid]
235
+
236
+ # Create images directory if it doesn't exist
237
+ images_dir = self.data_dir / "images"
238
+ images_dir.mkdir(parents=True, exist_ok=True)
239
+
240
+ # Get every image for the study
241
+ if not images and study_images:
242
+ for img_data in study_images:
243
+ image_path = self._save_image(img_data, question_id, images_dir)
244
+ if image_path:
245
+ images.append(image_path)
246
+
247
+ return images if images else None
248
 
249
+ def _save_image(self, img_data: Dict[str, Any], question_id: str, images_dir) -> Optional[str]:
250
+ """Save image data to local file.
251
 
252
+ Args:
253
+ img_data (Dict[str, Any]): Image data from dataset
254
+ question_id (str): Question ID for filename
255
+ images_dir: Directory to save images
256
+
257
  Returns:
258
+ Optional[str]: Path to saved image
259
  """
260
+ try:
261
+ # Get the image from the dataset item
262
+ image = img_data.get("image")
263
+ if image is None:
264
+ return None
265
+
266
+ # Generate filename using StudyInstanceUid
267
+ study_instance_uid = img_data.get("StudyInstanceUid", "")
268
+ filename_hash = hashlib.md5(f"{question_id}_{study_instance_uid}".encode()).hexdigest()[:8]
269
+ image_filename = f"{question_id}_{filename_hash}.png"
270
+ image_path = images_dir / image_filename
271
+
272
+ # Save image if it doesn't exist
273
+ if not image_path.exists():
274
+ image.save(str(image_path))
275
+
276
+ return str(image_path)
277
+
278
+ except Exception as e:
279
+ print(f"Error saving image for question {question_id}: {e}")
280
+ return None
benchmarking/cli.py CHANGED
@@ -102,6 +102,12 @@ def run_benchmark_command(args) -> None:
102
  print("\n" + "="*50)
103
  print("BENCHMARK COMPLETED")
104
  print("="*50)
 
 
 
 
 
 
105
  print(f"Overall Accuracy: {summary['results']['accuracy']:.2f}%")
106
  print(f"Total Questions: {summary['results']['total_questions']}")
107
  print(f"Correct Answers: {summary['results']['correct_answers']}")
 
102
  print("\n" + "="*50)
103
  print("BENCHMARK COMPLETED")
104
  print("="*50)
105
+
106
+ # Check if benchmark run was successful
107
+ if "error" in summary:
108
+ print(f"Error: {summary['error']}")
109
+ return
110
+
111
  print(f"Overall Accuracy: {summary['results']['accuracy']:.2f}%")
112
  print(f"Total Questions: {summary['results']['total_questions']}")
113
  print(f"Correct Answers: {summary['results']['correct_answers']}")
benchmarking/llm_providers/base.py CHANGED
@@ -73,8 +73,8 @@ class LLMProvider(ABC):
73
  # Simple test request
74
  test_request = LLMRequest(
75
  text="Hello",
76
- temperature=0.0,
77
- max_tokens=10
78
  )
79
  response = self.generate_response(test_request)
80
  return response.content is not None and len(response.content.strip()) > 0
 
73
  # Simple test request
74
  test_request = LLMRequest(
75
  text="Hello",
76
+ temperature=0.5,
77
+ max_tokens=1000
78
  )
79
  response = self.generate_response(test_request)
80
  return response.content is not None and len(response.content.strip()) > 0
benchmarking/llm_providers/google_provider.py CHANGED
@@ -44,30 +44,28 @@ class GoogleProvider(LLMProvider):
44
  if request.system_prompt:
45
  messages.append(SystemMessage(content=request.system_prompt))
46
 
47
- # Build user message content
48
- user_content = []
49
- user_content.append({
50
- "type": "text",
51
- "text": request.text
52
- })
53
-
54
- # Add images if provided
55
  if request.images:
 
 
 
 
56
  valid_images = self._validate_image_paths(request.images)
57
  for image_path in valid_images:
58
  try:
59
- # For langchain Google, we can pass the image data directly
60
  image_b64 = self._encode_image(image_path)
61
- user_content.append({
62
  "type": "image_url",
63
- "image_url": {
64
- "url": f"data:image/jpeg;base64,{image_b64}"
65
- }
66
  })
67
  except Exception as e:
68
  print(f"Error reading image {image_path}: {e}")
69
-
70
- messages.append(HumanMessage(content=user_content))
 
 
 
71
 
72
  # Make API call using langchain
73
  try:
 
44
  if request.system_prompt:
45
  messages.append(SystemMessage(content=request.system_prompt))
46
 
47
+ # For langchain Google Gemini, we need to construct content differently
 
 
 
 
 
 
 
48
  if request.images:
49
+ # For multimodal content, use a list format
50
+ content_parts = [request.text]
51
+
52
+ # Add images if provided
53
  valid_images = self._validate_image_paths(request.images)
54
  for image_path in valid_images:
55
  try:
56
+ # For langchain Google, pass image data as base64
57
  image_b64 = self._encode_image(image_path)
58
+ content_parts.append({
59
  "type": "image_url",
60
+ "image_url": f"data:image/jpeg;base64,{image_b64}"
 
 
61
  })
62
  except Exception as e:
63
  print(f"Error reading image {image_path}: {e}")
64
+
65
+ messages.append(HumanMessage(content=content_parts))
66
+ else:
67
+ # Text-only message
68
+ messages.append(HumanMessage(content=request.text))
69
 
70
  # Make API call using langchain
71
  try:
benchmarking/runner.py CHANGED
@@ -57,12 +57,12 @@ class BenchmarkRunner:
57
  self.output_dir = Path(config.output_dir)
58
  self.output_dir.mkdir(parents=True, exist_ok=True)
59
 
60
- # Set up logging
61
- self._setup_logging()
62
-
63
  # Generate unique run ID
64
  self.run_id = f"{config.benchmark_name}_{config.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
65
 
 
 
 
66
  self.logger.info(f"Initialized benchmark runner with ID: {self.run_id}")
67
 
68
  def _setup_logging(self) -> None:
 
57
  self.output_dir = Path(config.output_dir)
58
  self.output_dir.mkdir(parents=True, exist_ok=True)
59
 
 
 
 
60
  # Generate unique run ID
61
  self.run_id = f"{config.benchmark_name}_{config.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
62
 
63
+ # Set up logging
64
+ self._setup_logging()
65
+
66
  self.logger.info(f"Initialized benchmark runner with ID: {self.run_id}")
67
 
68
  def _setup_logging(self) -> None: