Junzhe Li committed on
Commit
b93ad3f
·
1 Parent(s): 9006287

updated benchmarks

Browse files
benchmarking/benchmarks/base.py CHANGED
@@ -14,8 +14,6 @@ class BenchmarkDataPoint:
14
  text: str # The question/prompt
15
  images: Optional[List[str]] = None # List of image paths
16
  correct_answer: Optional[str] = None # Ground truth answer
17
- case_id: Optional[str] = None # For grouping related questions
18
- category: Optional[str] = None # Type of question/task
19
  metadata: Optional[Dict[str, Any]] = None # Additional metadata
20
 
21
 
@@ -36,26 +34,32 @@ class Benchmark(ABC):
36
  """
37
  self.data_dir = Path(data_dir)
38
  self.config = kwargs
 
39
  self.data_points = []
40
  self._load_data()
41
  self._shuffle_data()
42
 
 
 
 
 
 
 
 
43
  @abstractmethod
44
  def _load_data(self) -> None:
45
  """Load benchmark data from the data directory."""
46
  pass
47
 
48
- def _shuffle_data(self) -> None:
49
- """Shuffle the data points if a random seed is provided.
50
 
51
  This method is called automatically after data loading to ensure
52
  reproducible benchmark runs when a random seed is specified.
53
  """
54
- random_seed = self.config.get("random_seed", None)
55
- if random_seed is not None:
56
- random.seed(random_seed)
57
- random.shuffle(self.data_points)
58
- print(f"Shuffled {len(self.data_points)} data points with seed {random_seed}")
59
 
60
  def get_data_point(self, index: int) -> BenchmarkDataPoint:
61
  """Get a specific data point by index.
@@ -82,28 +86,6 @@ class Benchmark(ABC):
82
  """
83
  return [self.get_data_point(i) for i in indices]
84
 
85
- def get_by_category(self, category: str) -> List[BenchmarkDataPoint]:
86
- """Get all data points of a specific category.
87
-
88
- Args:
89
- category (str): Category to filter by
90
-
91
- Returns:
92
- List[BenchmarkDataPoint]: List of data points in the category
93
- """
94
- return [dp for dp in self if dp.category == category]
95
-
96
- def get_by_case_id(self, case_id: str) -> List[BenchmarkDataPoint]:
97
- """Get all data points for a specific case.
98
-
99
- Args:
100
- case_id (str): Case ID to filter by
101
-
102
- Returns:
103
- List[BenchmarkDataPoint]: List of data points for the case
104
- """
105
- return [dp for dp in self if dp.case_id == case_id]
106
-
107
  def __str__(self) -> str:
108
  """String representation of the benchmark."""
109
  return f"{self.__class__.__name__}(data_dir={self.data_dir}, size={len(self)})"
@@ -117,56 +99,6 @@ class Benchmark(ABC):
117
  for i in range(len(self)):
118
  yield self.get_data_point(i)
119
 
120
- def get_categories(self) -> List[str]:
121
- """Get all unique categories in the benchmark.
122
-
123
- Returns:
124
- List[str]: List of unique categories
125
- """
126
- categories = set()
127
- for dp in self:
128
- if dp.category:
129
- categories.add(dp.category)
130
- return sorted(list(categories))
131
-
132
- def get_case_ids(self) -> List[str]:
133
- """Get all unique case IDs in the benchmark.
134
-
135
- Returns:
136
- List[str]: List of unique case IDs
137
- """
138
- case_ids = set()
139
- for dp in self:
140
- if dp.case_id:
141
- case_ids.add(dp.case_id)
142
- return sorted(list(case_ids))
143
-
144
- def get_stats(self) -> Dict[str, Any]:
145
- """Get statistics about the benchmark.
146
-
147
- Returns:
148
- Dict[str, Any]: Dictionary containing benchmark statistics
149
- """
150
- stats = {
151
- "total_questions": len(self),
152
- "total_cases": len(self.get_case_ids()),
153
- "categories": self.get_categories(),
154
- "category_counts": {},
155
- "has_images": False,
156
- "num_images": 0,
157
- }
158
-
159
- for dp in self:
160
- # Category counts
161
- if dp.category:
162
- stats["category_counts"][dp.category] = stats["category_counts"].get(dp.category, 0) + 1
163
-
164
- # Image statistics
165
- if dp.images:
166
- stats["has_images"] = True
167
- stats["num_images"] += len(dp.images)
168
- return stats
169
-
170
  def validate_images(self) -> Tuple[List[str], List[str]]:
171
  """Validate that all image paths exist.
172
 
 
14
  text: str # The question/prompt
15
  images: Optional[List[str]] = None # List of image paths
16
  correct_answer: Optional[str] = None # Ground truth answer
 
 
17
  metadata: Optional[Dict[str, Any]] = None # Additional metadata
18
 
19
 
 
34
  """
35
  self.data_dir = Path(data_dir)
36
  self.config = kwargs
37
+
38
  self.data_points = []
39
  self._load_data()
40
  self._shuffle_data()
41
 
42
+ self.max_questions = kwargs.get("max_questions", None)
43
+ if self.max_questions:
44
+ self.data_points = self.data_points[:self.max_questions]
45
+ print(f"Randomly sampled {self.max_questions} questions from {self.__class__.__name__}")
46
+ else:
47
+ print(f"Loaded all {len(self.data_points)} questions from {self.__class__.__name__}")
48
+
49
  @abstractmethod
50
  def _load_data(self) -> None:
51
  """Load benchmark data from the data directory."""
52
  pass
53
 
54
+ def _shuffle_data(self, random_seed: Optional[int]=42) -> None:
55
+ """Shuffle the data points if a random seed is provided. If no random seed is provided, use 42 as default.
56
 
57
  This method is called automatically after data loading to ensure
58
  reproducible benchmark runs when a random seed is specified.
59
  """
60
+ random.seed(random_seed)
61
+ random.shuffle(self.data_points)
62
+ print(f"Shuffled {len(self.data_points)} data points with seed {random_seed}")
 
 
63
 
64
  def get_data_point(self, index: int) -> BenchmarkDataPoint:
65
  """Get a specific data point by index.
 
86
  """
87
  return [self.get_data_point(i) for i in indices]
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def __str__(self) -> str:
90
  """String representation of the benchmark."""
91
  return f"{self.__class__.__name__}(data_dir={self.data_dir}, size={len(self)})"
 
99
  for i in range(len(self)):
100
  yield self.get_data_point(i)
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  def validate_images(self) -> Tuple[List[str], List[str]]:
103
  """Validate that all image paths exist.
104
 
benchmarking/benchmarks/chestagentbench_benchmark.py CHANGED
@@ -9,19 +9,18 @@ class ChestAgentBenchBenchmark(Benchmark):
9
  Loads the dataset from a local metadata.jsonl file and parses each entry into a BenchmarkDataPoint.
10
  """
11
  def __init__(self, data_dir: str, **kwargs):
12
- self.max_questions = kwargs.get("max_questions", None)
13
  super().__init__(data_dir, **kwargs)
14
 
15
  def _load_data(self) -> None:
 
16
  metadata_path = Path(self.data_dir) / "metadata.jsonl"
17
  if not metadata_path.exists():
18
  raise FileNotFoundError(f"Could not find metadata.jsonl in {self.data_dir}")
19
  print(f"Loading ChestAgentBench from local file: {metadata_path}")
20
- self.data_points = []
 
21
  with open(metadata_path, "r", encoding="utf-8") as f:
22
  for i, line in enumerate(f):
23
- if self.max_questions and i >= self.max_questions:
24
- break
25
  try:
26
  item = json.loads(line)
27
  data_point = self._parse_item(item, i)
@@ -30,43 +29,32 @@ class ChestAgentBenchBenchmark(Benchmark):
30
  except Exception as e:
31
  print(f"Error loading item {i}: {e}")
32
  continue
 
33
 
34
  def _parse_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
35
- # Use full_question_id or question_id if available, else fallback
36
- question_id = item.get("full_question_id") or item.get("question_id") or f"chestagentbench_{index}"
37
  question = item.get("question", "")
38
  correct_answer = item.get("answer", "")
39
- explanation = item.get("explanation", "")
40
- images = item.get("images", [])
41
- case_id = item.get("case_id", "")
42
- category = item.get("categories", "")
43
- # Compose question text (options are embedded in the question string)
44
- question_with_options = question
45
  # Map image paths to local figures directory
 
46
  local_images = None
47
  if images:
48
- figures_dir = Path(self.data_dir) / "figures"
49
  local_images = []
50
  for img in images:
51
- # Handle relative paths like "figures/11583/figure_1.jpg"
52
- if img.startswith("figures/"):
53
- # Remove "figures/" prefix and construct full path
54
- relative_path = img[8:] # Remove "figures/" prefix
55
- full_path = figures_dir / relative_path
56
- local_images.append(str(full_path))
57
- else:
58
- # Fallback to original logic
59
- local_images.append(str(figures_dir / Path(img).name))
60
- # Metadata
61
  metadata = dict(item)
62
- metadata["explanation"] = explanation
63
  metadata["dataset"] = "chestagentbench"
 
 
64
  return BenchmarkDataPoint(
65
  id=question_id,
66
- text=question_with_options,
67
  images=local_images,
68
  correct_answer=correct_answer,
69
  metadata=metadata,
70
- case_id=case_id,
71
- category=category,
72
  )
 
9
  Loads the dataset from a local metadata.jsonl file and parses each entry into a BenchmarkDataPoint.
10
  """
11
  def __init__(self, data_dir: str, **kwargs):
 
12
  super().__init__(data_dir, **kwargs)
13
 
14
  def _load_data(self) -> None:
15
+ # Check if metadata.jsonl exists
16
  metadata_path = Path(self.data_dir) / "metadata.jsonl"
17
  if not metadata_path.exists():
18
  raise FileNotFoundError(f"Could not find metadata.jsonl in {self.data_dir}")
19
  print(f"Loading ChestAgentBench from local file: {metadata_path}")
20
+
21
+ # Load metadata.jsonl
22
  with open(metadata_path, "r", encoding="utf-8") as f:
23
  for i, line in enumerate(f):
 
 
24
  try:
25
  item = json.loads(line)
26
  data_point = self._parse_item(item, i)
 
29
  except Exception as e:
30
  print(f"Error loading item {i}: {e}")
31
  continue
32
+
33
 
34
  def _parse_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
35
+ # Extract required fields
36
+ question_id = item.get("full_question_id")
37
  question = item.get("question", "")
38
  correct_answer = item.get("answer", "")
39
+
 
 
 
 
 
40
  # Map image paths to local figures directory
41
+ images = item.get("images", [])
42
  local_images = None
43
  if images:
 
44
  local_images = []
45
  for img in images:
46
+ full_path = Path(self.data_dir) / img
47
+ local_images.append(str(full_path))
48
+
49
+ # Extract metadata
 
 
 
 
 
 
50
  metadata = dict(item)
 
51
  metadata["dataset"] = "chestagentbench"
52
+
53
+ # Return data point
54
  return BenchmarkDataPoint(
55
  id=question_id,
56
+ text=question,
57
  images=local_images,
58
  correct_answer=correct_answer,
59
  metadata=metadata,
 
 
60
  )
benchmarking/benchmarks/rexvqa_benchmark.py CHANGED
@@ -3,13 +3,21 @@
3
  import json
4
  import os
5
  from typing import Dict, Optional, Any
6
- from datasets import load_dataset
7
  from .base import Benchmark, BenchmarkDataPoint
8
  from pathlib import Path
9
- import subprocess
10
  import tarfile
11
  import zstandard as zstd
12
  from huggingface_hub import hf_hub_download, list_repo_files
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  class ReXVQABenchmark(Benchmark):
@@ -40,16 +48,107 @@ class ReXVQABenchmark(Benchmark):
40
  max_questions (int): Maximum number of questions to load (default: None, load all)
41
  images_dir (str): Directory containing extracted PNG images (default: None)
42
  """
 
 
43
  self.split = kwargs.get("split", "test")
44
- self.trust_remote_code = kwargs.get("trust_remote_code", False)
45
- self.max_questions = kwargs.get("max_questions", None)
46
- self.image_dataset = None
47
- self.image_mapping = {} # Maps study_id to image data
48
-
49
- # Set images_dir BEFORE parent initialization to avoid AttributeError
50
  self.images_dir = f"{data_dir}/images/deid_png"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- super().__init__(data_dir, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  @staticmethod
55
  def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K", test_only: bool = True):
@@ -99,7 +198,7 @@ class ReXVQABenchmark(Benchmark):
99
  print(f"Output directory: {output_dir}")
100
  try:
101
  print("Listing files in repository...")
102
- files = list_repo_files(repo_id, repo_type='dataset')
103
  part_files = [f for f in files if f.startswith("deid_png.part")]
104
  if not part_files:
105
  print("No part files found. The images might be in a different format.")
@@ -117,7 +216,8 @@ class ReXVQABenchmark(Benchmark):
117
  filename=part_file,
118
  local_dir=output_dir,
119
  local_dir_use_symlinks=False,
120
- repo_type='dataset'
 
121
  )
122
  # Concatenate part files
123
  if not tar_path.exists():
@@ -237,168 +337,10 @@ class ReXVQABenchmark(Benchmark):
237
  filename="metadata/test_vqa_data.json",
238
  local_dir=output_dir,
239
  local_dir_use_symlinks=False,
240
- repo_type='dataset'
 
241
  )
242
  print("Download complete.")
243
  except Exception as e:
244
  print(f"Error downloading test_vqa_data.json: {e}")
245
- print("You may need to accept the license agreement on HuggingFace.")
246
-
247
- def _load_data(self) -> None:
248
- """Load ReXVQA data from local JSON file."""
249
- try:
250
- # Check for images and test_vqa_data.json, download if missing
251
- self.download_test_vqa_data_json(self.data_dir)
252
- self.download_rexgradient_images(self.data_dir, test_only=True)
253
-
254
- # Construct path to the JSON file
255
- json_file_path = os.path.join("benchmarking", "data", "rexvqa", "metadata", "test_vqa_data.json")
256
-
257
- # Check if file exists
258
- if not os.path.exists(json_file_path):
259
- raise FileNotFoundError(f"Could not find test_vqa_data.json in the expected location: {json_file_path}")
260
-
261
- print(f"Loading ReXVQA {self.split} split from local JSON file: {json_file_path}")
262
-
263
- # Load JSON file directly
264
- with open(json_file_path, 'r', encoding='utf-8') as f:
265
- questions_data = json.load(f)
266
-
267
- # ReXVQA format: {question_id: {question_data}, ...}
268
- questions_list = []
269
- for question_id, question_data in questions_data.items():
270
- # Add the question_id to the question_data for reference
271
- question_data['id'] = question_id
272
- questions_list.append(question_data)
273
-
274
- print(f"Loaded {len(questions_list)} questions from local JSON file")
275
-
276
- # Load images dataset from ReXGradient-160K (metadata only)
277
- print("Loading ReXGradient-160K metadata dataset...")
278
- try:
279
- self.image_dataset = load_dataset(
280
- "rajpurkarlab/ReXGradient-160K",
281
- split="test",
282
- cache_dir=self.data_dir,
283
- trust_remote_code=self.trust_remote_code
284
- )
285
- print(f"Loaded {len(self.image_dataset)} image metadata entries from ReXGradient-160K")
286
-
287
- # Create mapping from study_id to image metadata
288
- self._create_image_mapping()
289
-
290
- except Exception as e:
291
- print(f"Warning: Could not load ReXGradient-160K dataset: {e}")
292
- print("Proceeding without images...")
293
- self.load_images = False
294
-
295
- self.data_points = []
296
-
297
- # Process questions (limit if max_questions is specified)
298
- questions_to_process = questions_list
299
- if self.max_questions:
300
- questions_to_process = questions_list[:min(self.max_questions, len(questions_list))]
301
-
302
- for i, item in enumerate(questions_to_process):
303
- try:
304
- data_point = self._parse_rexvqa_item(item, i)
305
- if data_point:
306
- self.data_points.append(data_point)
307
-
308
- except Exception as e:
309
- print(f"Error loading item {i}: {e}")
310
- continue
311
-
312
- except Exception as e:
313
- raise RuntimeError(f"Failed to load ReXVQA dataset: {e}")
314
-
315
- def _create_image_mapping(self) -> None:
316
- """Create mapping from study_id to image metadata."""
317
- if not self.image_dataset:
318
- return
319
-
320
- print("Creating image mapping...")
321
-
322
- for item in self.image_dataset:
323
- study_instance_uid = item.get("StudyInstanceUid", "")
324
- if study_instance_uid:
325
- # Store the image metadata for this study using StudyInstanceUid as key
326
- if study_instance_uid not in self.image_mapping:
327
- self.image_mapping[study_instance_uid] = []
328
- self.image_mapping[study_instance_uid].append(item)
329
-
330
- print(f"Created image mapping for {len(self.image_mapping)} studies")
331
-
332
- def _parse_rexvqa_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
333
- """Parse a ReXVQA dataset item.
334
-
335
- Args:
336
- item (Dict[str, Any]): Dataset item from JSON file
337
- index (int): Item index
338
-
339
- Returns:
340
- Optional[BenchmarkDataPoint]: Parsed data point
341
- """
342
- # Extract basic information
343
- question_id = item.get("id", f"rexvqa_{self.split}_{index}")
344
- question = item.get("question", "")
345
-
346
- # Handle multiple choice options
347
- options = item.get("options", [])
348
- if options:
349
- # Add options to the question for multiple choice format
350
- question_with_options = question + "\n\nOptions:\n" + "\n".join(options)
351
- else:
352
- question_with_options = question
353
-
354
- # Get correct answer
355
- correct_answer = item.get("correct_answer", "")
356
-
357
- if not question:
358
- return None
359
-
360
- # Handle images using ImagePath field
361
- images = None
362
- if self.images_dir and "ImagePath" in item and item["ImagePath"]:
363
- images = []
364
- for rel_path in item["ImagePath"]:
365
- # Remove leading ../ if present
366
- norm_rel_path = rel_path.lstrip("./")
367
- # Join with images_dir root
368
- full_path = str(Path(self.images_dir).parent / norm_rel_path)
369
- images.append(full_path)
370
-
371
- # Extract metadata
372
- metadata = {
373
- "dataset": "rexvqa",
374
- "split": self.split,
375
- "study_id": item.get("study_id", ""),
376
- "study_instance_uid": item.get("StudyInstanceUid", ""),
377
- "reasoning_type": item.get("task_name", ""), # task_name maps to reasoning_type
378
- "category": item.get("category", ""),
379
- "class": item.get("class", ""),
380
- "subcategory": item.get("subcategory", ""),
381
- "patient_id": item.get("PatientID", ""),
382
- "patient_age": item.get("PatientAge", ""),
383
- "patient_sex": item.get("PatientSex", ""),
384
- "study_date": item.get("StudyDate", ""),
385
- "indication": item.get("Indication", ""),
386
- "findings": item.get("Findings", ""),
387
- "impression": item.get("Impression", ""),
388
- "image_modality": item.get("ImageModality", []),
389
- "image_view_position": item.get("ImageViewPosition", []),
390
- "correct_answer_explanation": item.get("correct_answer_explanation", ""),
391
- }
392
-
393
- case_id = item.get("study_id", "")
394
- category = item.get("task_name", "")
395
-
396
- return BenchmarkDataPoint(
397
- id=question_id,
398
- text=question_with_options,
399
- images=images,
400
- correct_answer=correct_answer,
401
- metadata=metadata,
402
- case_id=case_id,
403
- category=category,
404
- )
 
3
  import json
4
  import os
5
  from typing import Dict, Optional, Any
 
6
  from .base import Benchmark, BenchmarkDataPoint
7
  from pathlib import Path
 
8
  import tarfile
9
  import zstandard as zstd
10
  from huggingface_hub import hf_hub_download, list_repo_files
11
+ import os
12
+
13
+
14
+ def get_hf_token():
15
+ """Get Hugging Face token from cache."""
16
+ token_path = os.path.expanduser("~/.cache/huggingface/token")
17
+ if os.path.exists(token_path):
18
+ with open(token_path, 'r') as f:
19
+ return f.read().strip()
20
+ return None
21
 
22
 
23
  class ReXVQABenchmark(Benchmark):
 
48
  max_questions (int): Maximum number of questions to load (default: None, load all)
49
  images_dir (str): Directory containing extracted PNG images (default: None)
50
  """
51
+ super().__init__(data_dir, **kwargs)
52
+
53
  self.split = kwargs.get("split", "test")
 
 
 
 
 
 
54
  self.images_dir = f"{data_dir}/images/deid_png"
55
+
56
+ def _load_data(self) -> None:
57
+ """Load ReXVQA data from HuggingFace."""
58
+ try:
59
+ # Download images and test_vqa_data.json locally if missing
60
+ self.download_test_vqa_data_json(self.data_dir)
61
+ self.download_rexgradient_images(self.data_dir, test_only=True)
62
+
63
+ # Load JSON file
64
+ json_file_path = os.path.join(self.data_dir, "metadata", "test_vqa_data.json")
65
+ if not os.path.exists(json_file_path):
66
+ raise FileNotFoundError(f"Could not find test_vqa_data.json in the expected location: {json_file_path}")
67
+ print(f"Loading ReXVQA {self.split} split from local JSON file: {json_file_path}")
68
+ with open(json_file_path, 'r', encoding='utf-8') as f:
69
+ questions_data = json.load(f)
70
+
71
+ # ReXVQA format: {question_id: {question_data}, ...}
72
+ questions_list = []
73
+ for question_id, question_data in questions_data.items():
74
+ # Add the question_id to the question_data for reference
75
+ question_data['id'] = question_id
76
+ questions_list.append(question_data)
77
+ print(f"Loaded {len(questions_list)} questions from local JSON file")
78
+
79
+ # Process questions
80
+ for i, item in enumerate(questions_list):
81
+ try:
82
+ data_point = self._parse_rexvqa_item(item, i)
83
+ if data_point:
84
+ self.data_points.append(data_point)
85
+ except Exception as e:
86
+ print(f"Error loading item {i}: {e}")
87
+ continue
88
+
89
+ except Exception as e:
90
+ raise RuntimeError(f"Failed to load ReXVQA dataset: {e}")
91
+
92
+ def _parse_rexvqa_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
93
+ """Parse a ReXVQA dataset item.
94
 
95
+ Args:
96
+ item (Dict[str, Any]): Dataset item from JSON file
97
+ index (int): Item index
98
+
99
+ Returns:
100
+ Optional[BenchmarkDataPoint]: Parsed data point
101
+ """
102
+ # Extract question ID
103
+ question_id = item.get("id", f"rexvqa_{self.split}_{index}")
104
+
105
+ # Extract question and options
106
+ question = item.get("question", "")
107
+ options = item.get("options", [])
108
+ question_with_options = question + "\n\nOptions:\n" + "\n".join(options)
109
+
110
+ # Extract correct answer
111
+ correct_answer = item.get("correct_answer", "")
112
+
113
+ # Extract images
114
+ images = None
115
+ if self.images_dir and "ImagePath" in item and item["ImagePath"]:
116
+ images = []
117
+ for rel_path in item["ImagePath"]:
118
+ norm_rel_path = rel_path.lstrip("./")
119
+ full_path = str(Path(self.images_dir).parent / norm_rel_path)
120
+ images.append(full_path)
121
+
122
+ # Extract metadata
123
+ metadata = {
124
+ "dataset": "rexvqa",
125
+ "split": self.split,
126
+ "study_id": item.get("study_id", ""),
127
+ "study_instance_uid": item.get("StudyInstanceUid", ""),
128
+ "reasoning_type": item.get("task_name", ""), # task_name maps to reasoning_type
129
+ "category": item.get("category", ""),
130
+ "class": item.get("class", ""),
131
+ "subcategory": item.get("subcategory", ""),
132
+ "patient_id": item.get("PatientID", ""),
133
+ "patient_age": item.get("PatientAge", ""),
134
+ "patient_sex": item.get("PatientSex", ""),
135
+ "study_date": item.get("StudyDate", ""),
136
+ "indication": item.get("Indication", ""),
137
+ "findings": item.get("Findings", ""),
138
+ "impression": item.get("Impression", ""),
139
+ "image_modality": item.get("ImageModality", []),
140
+ "image_view_position": item.get("ImageViewPosition", []),
141
+ "correct_answer_explanation": item.get("correct_answer_explanation", ""),
142
+ }
143
+
144
+ # Return data point
145
+ return BenchmarkDataPoint(
146
+ id=question_id,
147
+ text=question_with_options,
148
+ images=images,
149
+ correct_answer=correct_answer,
150
+ metadata=metadata
151
+ )
152
 
153
  @staticmethod
154
  def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K", test_only: bool = True):
 
198
  print(f"Output directory: {output_dir}")
199
  try:
200
  print("Listing files in repository...")
201
+ files = list_repo_files(repo_id, repo_type='dataset', token=get_hf_token())
202
  part_files = [f for f in files if f.startswith("deid_png.part")]
203
  if not part_files:
204
  print("No part files found. The images might be in a different format.")
 
216
  filename=part_file,
217
  local_dir=output_dir,
218
  local_dir_use_symlinks=False,
219
+ repo_type='dataset',
220
+ token=get_hf_token()
221
  )
222
  # Concatenate part files
223
  if not tar_path.exists():
 
337
  filename="metadata/test_vqa_data.json",
338
  local_dir=output_dir,
339
  local_dir_use_symlinks=False,
340
+ repo_type='dataset',
341
+ token=get_hf_token()
342
  )
343
  print("Download complete.")
344
  except Exception as e:
345
  print(f"Error downloading test_vqa_data.json: {e}")
346
+ print("You may need to accept the license agreement on HuggingFace.")