# medrax2/benchmarking/benchmarks/rexvqa_benchmark.py
# Author: Junzhe Li — revamped benchmarking suite (commit 89321e2)
"""ReXVQA benchmark implementation."""
import json
import os
from typing import Dict, Optional, Any
from .base import Benchmark, BenchmarkDataPoint
from pathlib import Path
import tarfile
import zstandard as zstd
from huggingface_hub import hf_hub_download, list_repo_files
import os
def get_hf_token():
    """Return a Hugging Face access token, or None if none is configured.

    Checks, in order:
      1. The ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN`` environment variables
         (the convention huggingface_hub itself honors).
      2. The token cached by ``huggingface-cli login`` at
         ``~/.cache/huggingface/token``.

    Returns:
        Optional[str]: The token string, or None if no token is available.
    """
    # Environment variables take precedence so CI / containers can inject a
    # token without a login cache.
    for env_var in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"):
        token = os.environ.get(env_var)
        if token:
            return token.strip()
    token_path = os.path.expanduser("~/.cache/huggingface/token")
    if os.path.exists(token_path):
        with open(token_path, 'r') as f:
            return f.read().strip()
    return None
class ReXVQABenchmark(Benchmark):
    """ReXVQA benchmark for chest radiology visual question answering.

    ReXVQA is a large-scale VQA dataset for chest radiology comprising
    approximately 696,000 questions paired with 160,000 chest X-rays. It tests
    5 core radiological reasoning skills: presence assessment, location
    analysis, negation detection, differential diagnosis, and geometric
    reasoning.

    The dataset consists of two separate HuggingFace datasets:
    - ReXVQA: Contains questions, answers, and metadata
    - ReXGradient-160K: Contains the images (shipped as split tar part files)

    Paper: https://arxiv.org/abs/2506.04353
    Dataset: https://huggingface.co/datasets/rajpurkarlab/ReXVQA
    Images: https://huggingface.co/datasets/rajpurkarlab/ReXGradient-160K
    """

    def __init__(self, data_dir: str, **kwargs):
        """Initialize ReXVQA benchmark.

        Args:
            data_dir (str): Directory to store/cache downloaded data
            **kwargs: Additional configuration parameters
                split (str): Dataset split to use (default: 'test')
                trust_remote_code (bool): Whether to trust remote code (default: False)
                max_questions (int): Maximum number of questions to load (default: None, load all)
                images_dir (str): Directory containing extracted PNG images (default: None)
        """
        self.split = kwargs.get("split", "test")
        # Fix: these two kwargs were documented but previously ignored.
        # Defaults preserve the old behavior exactly.
        self.max_questions: Optional[int] = kwargs.get("max_questions")
        # Extracted PNGs live under <data_dir>/images/deid_png (the tar layout).
        self.images_dir = kwargs.get("images_dir") or f"{data_dir}/images/deid_png"
        super().__init__(data_dir, **kwargs)

    @staticmethod
    def _normalize_rel_path(rel_path: str) -> str:
        """Strip a single leading './' from a relative path.

        Bug fix: the previous code used ``rel_path.lstrip("./")``, which strips
        ALL leading '.' and '/' characters — e.g. './.hidden/x.png' would
        wrongly become 'hidden/x.png'. This removes only the literal './'
        prefix, leaving every other path untouched.
        """
        return rel_path[2:] if rel_path.startswith("./") else rel_path

    def _load_data(self) -> None:
        """Load ReXVQA data from HuggingFace into ``self.data_points``.

        Raises:
            RuntimeError: If the metadata download or JSON parse fails.
        """
        try:
            # Download images and test_vqa_data.json locally if missing.
            self.download_test_vqa_data_json(self.data_dir)
            self.download_rexgradient_images(self.data_dir, test_only=True)

            json_file_path = os.path.join(self.data_dir, "metadata", "test_vqa_data.json")
            if not os.path.exists(json_file_path):
                raise FileNotFoundError(
                    f"Could not find test_vqa_data.json in the expected location: {json_file_path}"
                )

            print(f"Loading ReXVQA {self.split} split from local JSON file: {json_file_path}")
            with open(json_file_path, 'r', encoding='utf-8') as f:
                questions_data = json.load(f)

            # ReXVQA format: {question_id: {question_data}, ...}
            questions_list = []
            for question_id, question_data in questions_data.items():
                # Keep the mapping key so each record carries its own id.
                question_data['id'] = question_id
                questions_list.append(question_data)
            print(f"Loaded {len(questions_list)} questions from local JSON file")

            # Honor the documented max_questions cap (None means load all).
            max_questions = getattr(self, "max_questions", None)
            if max_questions is not None:
                questions_list = questions_list[:max_questions]

            # Parse each question; one malformed record must not abort the run.
            for i, item in enumerate(questions_list):
                try:
                    data_point = self._parse_rexvqa_item(item, i)
                    if data_point:
                        self.data_points.append(data_point)
                except Exception as e:
                    print(f"Error loading item {i}: {e}")
                    continue
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to load ReXVQA dataset: {e}") from e

    def _parse_rexvqa_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
        """Parse a ReXVQA dataset item.

        Args:
            item (Dict[str, Any]): Dataset item from the metadata JSON file
            index (int): Item index (used for a fallback id)

        Returns:
            Optional[BenchmarkDataPoint]: Parsed data point
        """
        question_id = item.get("id", f"rexvqa_{self.split}_{index}")

        # Present the question as multiple choice with its options inline.
        question = item.get("question", "")
        options = item.get("options", [])
        question_with_options = question + "\n\nOptions:\n" + "\n".join(options)

        correct_answer = item.get("correct_answer", "")

        # Resolve image paths. The stored relative paths already include the
        # leading 'deid_png/' component, so they are joined onto the PARENT of
        # images_dir (i.e. <data_dir>/images).
        images = None
        if self.images_dir and item.get("ImagePath"):
            images_root = Path(self.images_dir).parent
            images = [
                str(images_root / self._normalize_rel_path(rel_path))
                for rel_path in item["ImagePath"]
            ]

        metadata = {
            "dataset": "rexvqa",
            "split": self.split,
            "study_id": item.get("study_id", ""),
            "study_instance_uid": item.get("StudyInstanceUid", ""),
            "reasoning_type": item.get("task_name", ""),  # task_name maps to reasoning_type
            "category": item.get("category", ""),
            "class": item.get("class", ""),
            "subcategory": item.get("subcategory", ""),
            "patient_id": item.get("PatientID", ""),
            "patient_age": item.get("PatientAge", ""),
            "patient_sex": item.get("PatientSex", ""),
            "study_date": item.get("StudyDate", ""),
            "indication": item.get("Indication", ""),
            "findings": item.get("Findings", ""),
            "impression": item.get("Impression", ""),
            "image_modality": item.get("ImageModality", []),
            "image_view_position": item.get("ImageViewPosition", []),
            "correct_answer_explanation": item.get("correct_answer_explanation", ""),
        }

        return BenchmarkDataPoint(
            id=question_id,
            text=question_with_options,
            images=images,
            correct_answer=correct_answer,
            metadata=metadata,
        )

    @staticmethod
    def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K", test_only: bool = True):
        """Download and extract ReXGradient-160K images if not already present.

        The images ship as a zstd-compressed tar split into 'deid_png.partNN'
        files. This downloads the parts, concatenates them, then stream-extracts
        the PNGs — optionally only those referenced by the test split metadata.

        Args:
            output_dir: Directory to store downloaded and extracted images
            repo_id: HuggingFace repository ID for the dataset
            test_only: If True, only extract images from the test split (default: True)
        """
        import shutil  # local import: only needed on this one-off download path

        output_dir = Path(output_dir)
        tar_path = output_dir / "deid_png.tar"
        images_dir = output_dir / "images"

        # Fast path: images already extracted on a previous run.
        if images_dir.exists() and any(images_dir.rglob("*.png")):
            print(f"Images already exist in {images_dir}, skipping download.")
            return

        # Collect the relative paths of every image the test split references,
        # so extraction can skip the rest of the archive.
        test_image_paths = set()
        if test_only:
            print("Loading test split metadata to identify test images...")
            try:
                test_metadata_path = output_dir / "metadata" / "test_vqa_data.json"
                if test_metadata_path.exists():
                    with open(test_metadata_path, 'r', encoding='utf-8') as f:
                        test_data = json.load(f)
                    for item in test_data.values():
                        for rel_path in item.get("ImagePath") or []:
                            # Normalize to match the names seen in the tar.
                            test_image_paths.add(ReXVQABenchmark._normalize_rel_path(rel_path))
                    print(f"Found {len(test_image_paths)} test images to extract")
                else:
                    print("Warning: test_vqa_data.json not found, will extract all images")
                    test_only = False
            except Exception as e:
                # Best effort: fall back to extracting everything.
                print(f"Warning: Could not load test metadata: {e}, will extract all images")
                test_only = False

        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Output directory: {output_dir}")
        try:
            print("Listing files in repository...")
            files = list_repo_files(repo_id, repo_type='dataset', token=get_hf_token())
            part_files = [f for f in files if f.startswith("deid_png.part")]
            if not part_files:
                print("No part files found. The images might be in a different format.")
                return
            print(f"Found {len(part_files)} part files.")

            # Download any part files we do not already have.
            for part_file in part_files:
                output_path = output_dir / part_file
                if output_path.exists():
                    print(f"Skipping {part_file} (already exists)")
                    continue
                print(f"Downloading {part_file}...")
                hf_hub_download(
                    repo_id=repo_id,
                    filename=part_file,
                    local_dir=output_dir,
                    local_dir_use_symlinks=False,
                    repo_type='dataset',
                    token=get_hf_token()
                )

            # Concatenate part files into the single (zstd-compressed) tar.
            if not tar_path.exists():
                print("\nConcatenating part files...")
                with open(tar_path, 'wb') as tar_file:
                    for part_file in sorted(part_files):
                        part_path = output_dir / part_file
                        if not part_path.exists():
                            print(f"Warning: {part_file} not found, skipping...")
                            continue
                        print(f"Adding {part_file}...")
                        # Fix: stream in bounded chunks; f.read() previously
                        # loaded an entire multi-GB part into memory at once.
                        with open(part_path, 'rb') as src:
                            shutil.copyfileobj(src, tar_file, length=16 * 1024 * 1024)
                # Clean up part files after successful concatenation.
                print("Cleaning up part files...")
                for part_file in part_files:
                    part_path = output_dir / part_file
                    if part_path.exists():
                        try:
                            part_path.unlink()
                            print(f"Deleted {part_file}")
                        except Exception as e:
                            print(f"Could not delete {part_file}: {e}")
            else:
                print(f"Tar file already exists: {tar_path}")

            # Extract the tar file.
            if tar_path.exists():
                print("\nExtracting images...")
                images_dir.mkdir(exist_ok=True)
                if any(images_dir.rglob("*.png")):
                    print("Images already extracted.")
                else:
                    try:
                        # Stream extract with filtering for test-only images.
                        print("Stream extracting zstd-compressed tar file with filtering (streaming mode)...")
                        dctx = zstd.ZstdDecompressor()
                        extracted_count = 0
                        total_png_members = 0
                        with open(tar_path, 'rb') as compressed_file:
                            with dctx.stream_reader(compressed_file) as decompressed_stream:
                                # 'r|' (pipe) mode reads strictly sequentially —
                                # required because the zstd stream is not seekable.
                                with tarfile.open(fileobj=decompressed_stream, mode='r|') as tar:
                                    for member in tar:
                                        # Drop already-processed TarInfo objects so
                                        # tar.members does not grow unboundedly across
                                        # the archive's many entries.
                                        tar.members = []
                                        # Only consider PNG regular files.
                                        if not member.isfile() or not member.name.endswith('.png'):
                                            continue
                                        total_png_members += 1
                                        # Normalize to match entries gathered from JSON.
                                        normalized_name = ReXVQABenchmark._normalize_rel_path(member.name)
                                        if test_only and normalized_name not in test_image_paths:
                                            # Skipping is safe: the iterator itself
                                            # advances the underlying stream.
                                            continue
                                        target_path = images_dir / normalized_name
                                        target_path.parent.mkdir(parents=True, exist_ok=True)
                                        extracted_file_obj = tar.extractfile(member)
                                        if extracted_file_obj is None:
                                            continue
                                        # Copy file contents in 1 MiB chunks.
                                        with open(target_path, 'wb') as out_f:
                                            shutil.copyfileobj(extracted_file_obj, out_f, length=1024 * 1024)
                                        extracted_count += 1
                                        if extracted_count % 100 == 0:
                                            print(f"Extracted {extracted_count} test images...")
                        print(f"Extraction completed! Extracted {extracted_count} matching PNGs out of {total_png_members} PNG members in the archive")
                        # Clean up the compressed tar after successful extraction.
                        print("Cleaning up compressed tar file...")
                        try:
                            tar_path.unlink()
                            print(f"Deleted {tar_path}")
                        except Exception as e:
                            print(f"Could not delete {tar_path}: {e}")
                    except Exception as e:
                        print(f"Error extracting tar file: {e}")
                        return
                png_files = list(images_dir.rglob("*.png"))
                print(f"Extracted {len(png_files)} PNG images to {images_dir}")
        except Exception as e:
            # Best-effort downloader: report and leave partial state for retry.
            print(f"Error: {e}")

    @staticmethod
    def download_test_vqa_data_json(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXVQA"):
        """Download test_vqa_data.json from the ReXVQA HuggingFace repo if not already present.

        Args:
            output_dir: Local dataset root; the file lands at
                <output_dir>/metadata/test_vqa_data.json
            repo_id: HuggingFace repository ID holding the question metadata
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        json_path = output_dir / "metadata" / "test_vqa_data.json"
        if json_path.exists():
            print(f"test_vqa_data.json already exists at {json_path}, skipping download.")
            return
        print(f"Downloading test_vqa_data.json to {json_path}...")
        try:
            hf_hub_download(
                repo_id=repo_id,
                filename="metadata/test_vqa_data.json",
                local_dir=output_dir,
                local_dir_use_symlinks=False,
                repo_type='dataset',
                token=get_hf_token()
            )
            print("Download complete.")
        except Exception as e:
            # Likely a gated-dataset authorization failure.
            print(f"Error downloading test_vqa_data.json: {e}")
            print("You may need to accept the license agreement on HuggingFace.")