Spaces:
Paused
Paused
File size: 16,463 Bytes
99f2cbc 9bb904a eaff77c 99f2cbc c7b65ec aa37a55 eaff77c b93ad3f 99f2cbc 9bb904a c7b65ec 9bb904a 99f2cbc 9bb904a 99f2cbc 9bb904a 99f2cbc 9bb904a c7b65ec 99f2cbc 9bb904a 16278b5 b93ad3f 89321e2 b93ad3f e97f266 b93ad3f 99f2cbc eaff77c aa37a55 eaff77c aa37a55 eaff77c b93ad3f eaff77c b93ad3f eaff77c aa37a55 eaff77c c963ad3 aa37a55 c963ad3 aa37a55 c963ad3 aa37a55 c963ad3 aa37a55 eaff77c b93ad3f eaff77c b93ad3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 |
"""ReXVQA benchmark implementation."""
import json
import os
from typing import Dict, Optional, Any
from .base import Benchmark, BenchmarkDataPoint
from pathlib import Path
import tarfile
import zstandard as zstd
from huggingface_hub import hf_hub_download, list_repo_files
import os
def get_hf_token():
    """Return the Hugging Face access token, or ``None`` if unavailable.

    Checks the ``HF_TOKEN`` environment variable first (the convention
    honored by ``huggingface_hub``), then falls back to the token cached
    on disk by ``huggingface-cli login``.

    Returns:
        Optional[str]: The token string, or ``None`` when no token is found.
    """
    # Environment variable takes precedence, matching huggingface_hub behavior.
    env_token = os.environ.get("HF_TOKEN")
    if env_token:
        return env_token
    token_path = os.path.expanduser("~/.cache/huggingface/token")
    if os.path.exists(token_path):
        # The cached token file may carry a trailing newline; strip it.
        with open(token_path, "r", encoding="utf-8") as f:
            return f.read().strip()
    return None
class ReXVQABenchmark(Benchmark):
"""ReXVQA benchmark for chest radiology visual question answering.
ReXVQA is a large-scale VQA dataset for chest radiology comprising approximately
696,000 questions paired with 160,000 chest X-rays. It tests 5 core radiological
reasoning skills: presence assessment, location analysis, negation detection,
differential diagnosis, and geometric reasoning.
The dataset consists of two separate HuggingFace datasets:
- ReXVQA: Contains questions, answers, and metadata
- ReXGradient-160K: Contains metadata only (images are in separate part files)
Paper: https://arxiv.org/abs/2506.04353
Dataset: https://huggingface.co/datasets/rajpurkarlab/ReXVQA
Images: https://huggingface.co/datasets/rajpurkarlab/ReXGradient-160K
"""
def __init__(self, data_dir: str, **kwargs):
    """Initialize ReXVQA benchmark.

    Args:
        data_dir (str): Directory to store/cache downloaded data
        **kwargs: Additional configuration parameters
            split (str): Dataset split to use (default: 'test')
            trust_remote_code (bool): Whether to trust remote code (default: False)
            max_questions (int): Maximum number of questions to load (default: None, load all)
            images_dir (str): Directory containing extracted PNG images
                (default: "<data_dir>/images/deid_png")
    """
    self.split = kwargs.get("split", "test")
    # Honor an explicit images_dir override; previously this documented
    # kwarg was silently ignored and the default path was always used.
    self.images_dir = kwargs.get("images_dir") or f"{data_dir}/images/deid_png"
    super().__init__(data_dir, **kwargs)
def _load_data(self) -> None:
    """Load ReXVQA data from HuggingFace.

    Downloads the question metadata and (test-split) images if missing,
    then parses each question entry into a data point appended to
    ``self.data_points``.

    Raises:
        RuntimeError: If the dataset cannot be downloaded or parsed.
    """
    try:
        # Download images and test_vqa_data.json locally if missing
        self.download_test_vqa_data_json(self.data_dir)
        self.download_rexgradient_images(self.data_dir, test_only=True)
        # Load JSON file
        json_file_path = os.path.join(self.data_dir, "metadata", "test_vqa_data.json")
        if not os.path.exists(json_file_path):
            raise FileNotFoundError(f"Could not find test_vqa_data.json in the expected location: {json_file_path}")
        print(f"Loading ReXVQA {self.split} split from local JSON file: {json_file_path}")
        with open(json_file_path, 'r', encoding='utf-8') as f:
            questions_data = json.load(f)
        # ReXVQA format: {question_id: {question_data}, ...}
        questions_list = []
        for question_id, question_data in questions_data.items():
            # Add the question_id to the question_data for reference
            question_data['id'] = question_id
            questions_list.append(question_data)
        print(f"Loaded {len(questions_list)} questions from local JSON file")
        # Process questions; a single malformed item is logged and skipped
        # rather than aborting the whole load.
        for i, item in enumerate(questions_list):
            try:
                data_point = self._parse_rexvqa_item(item, i)
                if data_point:
                    self.data_points.append(data_point)
            except Exception as e:
                print(f"Error loading item {i}: {e}")
                continue
    except Exception as e:
        # Chain the original exception so the root cause is preserved
        # in the traceback instead of being flattened into a message.
        raise RuntimeError(f"Failed to load ReXVQA dataset: {e}") from e
def _parse_rexvqa_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
    """Parse a ReXVQA dataset item.

    Args:
        item (Dict[str, Any]): Dataset item from JSON file
        index (int): Item index

    Returns:
        Optional[BenchmarkDataPoint]: Parsed data point
    """
    # Extract question ID
    question_id = item.get("id", f"rexvqa_{self.split}_{index}")
    # Extract question and options
    question = item.get("question", "")
    options = item.get("options", [])
    question_with_options = question + "\n\nOptions:\n" + "\n".join(options)
    # Extract correct answer
    correct_answer = item.get("correct_answer", "")
    # Extract images
    images = None
    if self.images_dir and "ImagePath" in item and item["ImagePath"]:
        images = []
        for rel_path in item["ImagePath"]:
            # Remove only a leading "./" prefix. The previous
            # rel_path.lstrip("./") stripped ANY leading run of '.' and '/'
            # characters (e.g. ".hidden/x" -> "hidden/x"), corrupting paths.
            norm_rel_path = rel_path[2:] if rel_path.startswith("./") else rel_path
            full_path = str(Path(self.images_dir).parent / norm_rel_path)
            images.append(full_path)
    # Extract metadata
    metadata = {
        "dataset": "rexvqa",
        "split": self.split,
        "study_id": item.get("study_id", ""),
        "study_instance_uid": item.get("StudyInstanceUid", ""),
        "reasoning_type": item.get("task_name", ""),  # task_name maps to reasoning_type
        "category": item.get("category", ""),
        "class": item.get("class", ""),
        "subcategory": item.get("subcategory", ""),
        "patient_id": item.get("PatientID", ""),
        "patient_age": item.get("PatientAge", ""),
        "patient_sex": item.get("PatientSex", ""),
        "study_date": item.get("StudyDate", ""),
        "indication": item.get("Indication", ""),
        "findings": item.get("Findings", ""),
        "impression": item.get("Impression", ""),
        "image_modality": item.get("ImageModality", []),
        "image_view_position": item.get("ImageViewPosition", []),
        "correct_answer_explanation": item.get("correct_answer_explanation", ""),
    }
    # Return data point
    return BenchmarkDataPoint(
        id=question_id,
        text=question_with_options,
        images=images,
        correct_answer=correct_answer,
        metadata=metadata
    )
@staticmethod
def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K", test_only: bool = True):
    """Download and extract ReXGradient-160K images if not already present.

    The images ship as multi-part files ("deid_png.partNN") that are
    concatenated into a single zstd-compressed tar archive and then
    stream-extracted. When ``test_only`` is True, only PNGs referenced by
    the test-split metadata (metadata/test_vqa_data.json) are written.

    Args:
        output_dir: Directory to store downloaded and extracted images
        repo_id: HuggingFace repository ID for the dataset
        test_only: If True, only extract images from the test split (default: True)
    """
    output_dir = Path(output_dir)
    tar_path = output_dir / "deid_png.tar"
    images_dir = output_dir / "images"
    # Check if images already exist
    if images_dir.exists() and any(images_dir.rglob("*.png")):
        print(f"Images already exist in {images_dir}, skipping download.")
        return
    # Load test split metadata if test_only is True
    test_image_paths = set()
    if test_only:
        print("Loading test split metadata to identify test images...")
        try:
            # Load the test metadata to get image paths
            test_metadata_path = output_dir / "metadata" / "test_vqa_data.json"
            if test_metadata_path.exists():
                with open(test_metadata_path, 'r', encoding='utf-8') as f:
                    test_data = json.load(f)
                # Extract all image paths from test data
                for item in test_data.values():
                    if "ImagePath" in item and item["ImagePath"]:
                        for rel_path in item["ImagePath"]:
                            # Normalize path to match tar file structure.
                            # Only remove a leading "./" prefix: lstrip("./")
                            # would strip any leading run of '.'/'/' chars and
                            # corrupt some paths.
                            norm_path = rel_path[2:] if rel_path.startswith("./") else rel_path
                            test_image_paths.add(norm_path)
                print(f"Found {len(test_image_paths)} test images to extract")
            else:
                print("Warning: test_vqa_data.json not found, will extract all images")
                test_only = False
        except Exception as e:
            print(f"Warning: Could not load test metadata: {e}, will extract all images")
            test_only = False
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Output directory: {output_dir}")
    try:
        print("Listing files in repository...")
        files = list_repo_files(repo_id, repo_type='dataset', token=get_hf_token())
        part_files = [f for f in files if f.startswith("deid_png.part")]
        if not part_files:
            print("No part files found. The images might be in a different format.")
            return
        print(f"Found {len(part_files)} part files.")
        # Download part files (skipping any that already exist locally)
        for part_file in part_files:
            output_path = output_dir / part_file
            if output_path.exists():
                print(f"Skipping {part_file} (already exists)")
                continue
            print(f"Downloading {part_file}...")
            hf_hub_download(
                repo_id=repo_id,
                filename=part_file,
                local_dir=output_dir,
                local_dir_use_symlinks=False,
                repo_type='dataset',
                token=get_hf_token()
            )
        # Concatenate part files (sorted so partAA, partAB, ... join in order)
        if not tar_path.exists():
            print("\nConcatenating part files...")
            with open(tar_path, 'wb') as tar_file:
                for part_file in sorted(part_files):
                    part_path = output_dir / part_file
                    if part_path.exists():
                        print(f"Adding {part_file}...")
                        # Copy in 1 MiB chunks; parts are multi-GB, so
                        # reading a whole part into memory is not viable.
                        with open(part_path, 'rb') as f:
                            while True:
                                chunk = f.read(1024 * 1024)
                                if not chunk:
                                    break
                                tar_file.write(chunk)
                    else:
                        print(f"Warning: {part_file} not found, skipping...")
            # Clean up part files after successful concatenation
            print("Cleaning up part files...")
            for part_file in part_files:
                part_path = output_dir / part_file
                if part_path.exists():
                    try:
                        part_path.unlink()
                        print(f"Deleted {part_file}")
                    except Exception as e:
                        print(f"Could not delete {part_file}: {e}")
        else:
            print(f"Tar file already exists: {tar_path}")
    # Extract tar file
        if tar_path.exists():
            print("\nExtracting images...")
            images_dir.mkdir(exist_ok=True)
            if any(images_dir.rglob("*.png")):
                print("Images already extracted.")
            else:
                try:
                    # Stream extract with filtering for test-only images (no seeking)
                    print("Stream extracting zstd-compressed tar file with filtering (streaming mode)...")
                    # Create a decompressor
                    dctx = zstd.ZstdDecompressor()
                    # Stream extract with filtering
                    extracted_count = 0
                    total_png_members = 0
                    with open(tar_path, 'rb') as compressed_file:
                        with dctx.stream_reader(compressed_file) as decompressed_stream:
                            # mode='r|' iterates members sequentially without
                            # seeking, which is required on a zstd stream.
                            with tarfile.open(fileobj=decompressed_stream, mode='r|') as tar:
                                for member in tar:
                                    # Only consider PNG files
                                    if not member.isfile() or not member.name.endswith('.png'):
                                        continue
                                    total_png_members += 1
                                    # Normalize name to match entries gathered from JSON
                                    # (strip only a literal leading "./").
                                    normalized_name = member.name[2:] if member.name.startswith('./') else member.name
                                    # Skip members not in the test split when filtering;
                                    # the stream advances automatically on the next
                                    # iteration, no explicit skip is needed.
                                    if test_only and normalized_name not in test_image_paths:
                                        continue
                                    # Ensure parent directories exist and write file by streaming
                                    target_path = Path(images_dir) / normalized_name
                                    target_path.parent.mkdir(parents=True, exist_ok=True)
                                    extracted_file_obj = tar.extractfile(member)
                                    if extracted_file_obj is None:
                                        continue
                                    with open(target_path, 'wb') as out_f:
                                        while True:
                                            chunk = extracted_file_obj.read(1024 * 1024)
                                            if not chunk:
                                                break
                                            out_f.write(chunk)
                                    extracted_count += 1
                                    if extracted_count % 100 == 0:
                                        print(f"Extracted {extracted_count} test images...")
                    print(f"Extraction completed! Extracted {extracted_count} matching PNGs out of {total_png_members} PNG members in the archive")
                    # Clean up compressed tar file after successful extraction
                    print("Cleaning up compressed tar file...")
                    try:
                        tar_path.unlink()
                        print(f"Deleted {tar_path}")
                    except Exception as e:
                        print(f"Could not delete {tar_path}: {e}")
                except Exception as e:
                    print(f"Error extracting tar file: {e}")
                    return
            png_files = list(images_dir.rglob("*.png"))
            print(f"Extracted {len(png_files)} PNG images to {images_dir}")
    except Exception as e:
        print(f"Error: {e}")
@staticmethod
def download_test_vqa_data_json(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXVQA"):
    """Download test_vqa_data.json from the ReXVQA HuggingFace repo if not already present."""
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    json_path = target_dir / "metadata" / "test_vqa_data.json"
    # Nothing to do when the metadata file is already cached locally.
    if json_path.exists():
        print(f"test_vqa_data.json already exists at {json_path}, skipping download.")
        return
    print(f"Downloading test_vqa_data.json to {json_path}...")
    try:
        hf_hub_download(
            repo_id=repo_id,
            repo_type='dataset',
            filename="metadata/test_vqa_data.json",
            local_dir=target_dir,
            local_dir_use_symlinks=False,
            token=get_hf_token(),
        )
    except Exception as err:
        # The dataset is gated; a failed download usually means the license
        # has not been accepted for this account.
        print(f"Error downloading test_vqa_data.json: {err}")
        print("You may need to accept the license agreement on HuggingFace.")
    else:
        print("Download complete.")