Spaces:
Sleeping
Sleeping
| """ReXVQA benchmark implementation.""" | |
| import json | |
| import os | |
| from typing import Dict, Optional, Any | |
| from .base import Benchmark, BenchmarkDataPoint | |
| from pathlib import Path | |
| import tarfile | |
| import zstandard as zstd | |
| from huggingface_hub import hf_hub_download, list_repo_files | |
| import os | |
def get_hf_token() -> Optional[str]:
    """Return the Hugging Face access token cached on disk, if any.

    Reads ``~/.cache/huggingface/token`` and strips surrounding whitespace.

    Returns:
        Optional[str]: The token, or None when the file is missing, cannot be
            read, or contains only whitespace.
    """
    token_path = os.path.expanduser("~/.cache/huggingface/token")
    try:
        # Open directly instead of checking for existence first: the file can
        # disappear between the check and the open.
        with open(token_path, "r", encoding="utf-8") as f:
            token = f.read().strip()
    except OSError:
        return None
    # An empty file must not be reported as a (bogus) empty-string token.
    return token or None
class ReXVQABenchmark(Benchmark):
    """ReXVQA benchmark for chest radiology visual question answering.

    ReXVQA is a large-scale VQA dataset for chest radiology comprising approximately
    696,000 questions paired with 160,000 chest X-rays. It tests 5 core radiological
    reasoning skills: presence assessment, location analysis, negation detection,
    differential diagnosis, and geometric reasoning.

    The dataset consists of two separate HuggingFace datasets:
    - ReXVQA: Contains questions, answers, and metadata
    - ReXGradient-160K: Contains metadata only (images are in separate part files)

    Paper: https://arxiv.org/abs/2506.04353
    Dataset: https://huggingface.co/datasets/rajpurkarlab/ReXVQA
    Images: https://huggingface.co/datasets/rajpurkarlab/ReXGradient-160K
    """

    # Buffer size used when copying archive data, so that multi-gigabyte part
    # files and archive members are never held in memory in one piece.
    _COPY_CHUNK_SIZE = 1024 * 1024

    def __init__(self, data_dir: str, **kwargs):
        """Initialize ReXVQA benchmark.

        Args:
            data_dir (str): Directory to store/cache downloaded data
            **kwargs: Additional configuration parameters
                split (str): Dataset split to use (default: 'test')
                trust_remote_code (bool): Whether to trust remote code (default: False)
                max_questions (int): Maximum number of questions to load (default: None, load all)
                images_dir (str): Directory containing extracted PNG images
                    (default: ``<data_dir>/images/deid_png``)
        """
        # NOTE(review): these attributes are assigned before super().__init__,
        # presumably because the base constructor triggers _load_data — confirm.
        self.split = kwargs.get("split", "test")
        # Honour the documented ``images_dir`` option; fall back to the location
        # that download_rexgradient_images extracts into.
        self.images_dir = kwargs.get("images_dir") or f"{data_dir}/images/deid_png"
        super().__init__(data_dir, **kwargs)

    def _load_data(self) -> None:
        """Load ReXVQA data, downloading the question file and images if missing.

        Every entry of ``metadata/test_vqa_data.json`` is parsed into a data point
        and appended to ``self.data_points``. An entry that fails to parse is
        reported and skipped.

        Raises:
            RuntimeError: If the question file cannot be found or read; the
                original exception is attached as the cause.
        """
        try:
            # Download images and test_vqa_data.json locally if missing
            self.download_test_vqa_data_json(self.data_dir)
            self.download_rexgradient_images(self.data_dir, test_only=True)

            json_file_path = os.path.join(self.data_dir, "metadata", "test_vqa_data.json")
            if not os.path.exists(json_file_path):
                raise FileNotFoundError(
                    f"Could not find test_vqa_data.json in the expected location: {json_file_path}"
                )

            print(f"Loading ReXVQA {self.split} split from local JSON file: {json_file_path}")
            with open(json_file_path, "r", encoding="utf-8") as f:
                questions_data = json.load(f)

            # ReXVQA format: {question_id: {question_data}, ...}
            print(f"Loaded {len(questions_data)} questions from local JSON file")

            for i, (question_id, question_data) in enumerate(questions_data.items()):
                try:
                    # Keep the question id on the item so the parser can use it.
                    question_data["id"] = question_id
                    data_point = self._parse_rexvqa_item(question_data, i)
                    if data_point:
                        self.data_points.append(data_point)
                except Exception as e:
                    # One malformed entry must not abort the whole load.
                    print(f"Error loading item {i}: {e}")
        except Exception as e:
            raise RuntimeError(f"Failed to load ReXVQA dataset: {e}") from e

    @staticmethod
    def _normalize_rel_path(rel_path: str) -> str:
        """Drop leading ``./`` prefixes and leading slashes from a relative path.

        ``str.lstrip("./")`` is deliberately not used: it removes any run of
        ``.`` and ``/`` characters, which would turn ``../x`` or ``.hidden/x``
        into a different path.
        """
        path = rel_path
        while path.startswith("./"):
            path = path[2:]
        return path.lstrip("/")

    @staticmethod
    def _item_image_paths(item: Dict[str, Any]) -> list:
        """Return the normalized ``ImagePath`` entries of a question item.

        Returns an empty list when the key is absent or empty. A bare string is
        treated as a single path rather than iterated character by character.
        """
        raw_paths = item.get("ImagePath")
        if not raw_paths:
            return []
        if isinstance(raw_paths, str):
            raw_paths = [raw_paths]
        return [ReXVQABenchmark._normalize_rel_path(path) for path in raw_paths]

    def _parse_rexvqa_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
        """Parse a ReXVQA dataset item.

        Args:
            item (Dict[str, Any]): Dataset item from JSON file
            index (int): Item index, used to build a fallback id

        Returns:
            Optional[BenchmarkDataPoint]: Parsed data point
        """
        question_id = item.get("id", f"rexvqa_{self.split}_{index}")

        # The prompt text is the question followed by its answer options.
        question = item.get("question", "")
        options = item.get("options", [])
        question_with_options = question + "\n\nOptions:\n" + "\n".join(options)

        correct_answer = item.get("correct_answer", "")

        # Image paths are resolved against the parent of images_dir.
        # NOTE(review): presumably because the JSON paths already start with
        # the "deid_png" folder — confirm against the dataset.
        images = None
        rel_paths = self._item_image_paths(item)
        if self.images_dir and rel_paths:
            base_dir = Path(self.images_dir).parent
            images = [str(base_dir / rel_path) for rel_path in rel_paths]

        metadata = {
            "dataset": "rexvqa",
            "split": self.split,
            "study_id": item.get("study_id", ""),
            "study_instance_uid": item.get("StudyInstanceUid", ""),
            "reasoning_type": item.get("task_name", ""),  # task_name maps to reasoning_type
            "category": item.get("category", ""),
            "class": item.get("class", ""),
            "subcategory": item.get("subcategory", ""),
            "patient_id": item.get("PatientID", ""),
            "patient_age": item.get("PatientAge", ""),
            "patient_sex": item.get("PatientSex", ""),
            "study_date": item.get("StudyDate", ""),
            "indication": item.get("Indication", ""),
            "findings": item.get("Findings", ""),
            "impression": item.get("Impression", ""),
            "image_modality": item.get("ImageModality", []),
            "image_view_position": item.get("ImageViewPosition", []),
            "correct_answer_explanation": item.get("correct_answer_explanation", ""),
        }

        return BenchmarkDataPoint(
            id=question_id,
            text=question_with_options,
            images=images,
            correct_answer=correct_answer,
            metadata=metadata,
        )

    @staticmethod
    def _copy_stream(src, dst) -> None:
        """Copy everything from one binary file object to another in chunks."""
        for chunk in iter(lambda: src.read(ReXVQABenchmark._COPY_CHUNK_SIZE), b""):
            dst.write(chunk)

    @staticmethod
    def _collect_test_image_paths(metadata_path: Path) -> set:
        """Return the set of normalized image paths referenced by the test JSON."""
        with open(metadata_path, "r", encoding="utf-8") as f:
            test_data = json.load(f)
        image_paths = set()
        for item in test_data.values():
            image_paths.update(ReXVQABenchmark._item_image_paths(item))
        return image_paths

    @staticmethod
    def _concatenate_parts(output_dir: Path, part_files: list, tar_path: Path) -> None:
        """Join the downloaded part files into ``tar_path`` and delete the parts.

        Args:
            output_dir: Directory holding the part files
            part_files: Part file names, already in concatenation order
            tar_path: Destination archive path

        Raises:
            FileNotFoundError: If any part file is missing. Nothing is written
                or deleted in that case, since the archive would be corrupt.
        """
        missing = [name for name in part_files if not (output_dir / name).exists()]
        if missing:
            raise FileNotFoundError(f"Cannot build {tar_path}: missing part files {missing}")

        print("\nConcatenating part files...")
        # Build under a temporary name so an interrupted run never leaves a
        # truncated archive that a later run would mistake for a complete one.
        partial_path = tar_path.with_name(tar_path.name + ".partial")
        with open(partial_path, "wb") as tar_file:
            for part_file in part_files:
                print(f"Adding {part_file}...")
                with open(output_dir / part_file, "rb") as f:
                    ReXVQABenchmark._copy_stream(f, tar_file)
        os.replace(partial_path, tar_path)

        # Clean up part files after successful concatenation
        print("Cleaning up part files...")
        for part_file in part_files:
            part_path = output_dir / part_file
            try:
                part_path.unlink()
                print(f"Deleted {part_file}")
            except Exception as e:
                print(f"Could not delete {part_file}: {e}")

    @staticmethod
    def _extract_images(tar_path: Path, images_dir: Path, wanted_paths: Optional[set]) -> tuple:
        """Stream-extract PNG members of a zstd-compressed tar into ``images_dir``.

        Args:
            tar_path: Path to the zstd-compressed tar archive
            images_dir: Directory the PNG files are written under
            wanted_paths: Normalized member names to extract, or None to
                extract every PNG member

        Returns:
            tuple: ``(extracted_count, total_png_members)``
        """
        dctx = zstd.ZstdDecompressor()
        extracted_count = 0
        total_png_members = 0
        with open(tar_path, "rb") as compressed_file:
            with dctx.stream_reader(compressed_file) as decompressed_stream:
                # Use streaming tar mode ('r|') to avoid seeks
                with tarfile.open(fileobj=decompressed_stream, mode="r|") as tar:
                    for member in tar:
                        # Ensure we don't hold references to members already seen.
                        tar.members = []

                        # Only consider PNG files
                        if not member.isfile() or not member.name.endswith(".png"):
                            continue
                        total_png_members += 1

                        # Normalize name to match entries gathered from JSON
                        normalized_name = ReXVQABenchmark._normalize_rel_path(member.name)
                        if wanted_paths is not None and normalized_name not in wanted_paths:
                            continue

                        # Never write outside images_dir, whatever the archive says.
                        relative = Path(normalized_name)
                        if relative.is_absolute() or ".." in relative.parts:
                            print(f"Warning: skipping unsafe archive member: {member.name}")
                            continue

                        extracted_file_obj = tar.extractfile(member)
                        if extracted_file_obj is None:
                            continue

                        target_path = images_dir / relative
                        target_path.parent.mkdir(parents=True, exist_ok=True)
                        with open(target_path, "wb") as out_f:
                            ReXVQABenchmark._copy_stream(extracted_file_obj, out_f)

                        extracted_count += 1
                        if extracted_count % 100 == 0:
                            print(f"Extracted {extracted_count} test images...")
        return extracted_count, total_png_members

    @staticmethod
    def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K", test_only: bool = True):
        """Download and extract ReXGradient-160K images if not already present.

        Failures are reported with ``print`` and the function returns; no
        exception is raised to the caller.

        Args:
            output_dir: Directory to store downloaded and extracted images
            repo_id: HuggingFace repository ID for the dataset
            test_only: If True, only extract images from the test split (default: True)
        """
        output_dir = Path(output_dir)
        tar_path = output_dir / "deid_png.tar"
        images_dir = output_dir / "images"

        # Check if images already exist.
        # NOTE(review): any PNG counts, so a partially completed earlier
        # extraction is also treated as "already present".
        if images_dir.exists() and any(images_dir.rglob("*.png")):
            print(f"Images already exist in {images_dir}, skipping download.")
            return

        # Load test split metadata if test_only is True
        test_image_paths = set()
        if test_only:
            print("Loading test split metadata to identify test images...")
            test_metadata_path = output_dir / "metadata" / "test_vqa_data.json"
            try:
                if test_metadata_path.exists():
                    test_image_paths = ReXVQABenchmark._collect_test_image_paths(test_metadata_path)
                    print(f"Found {len(test_image_paths)} test images to extract")
                else:
                    print("Warning: test_vqa_data.json not found, will extract all images")
                    test_only = False
            except Exception as e:
                print(f"Warning: Could not load test metadata: {e}, will extract all images")
                test_only = False

        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Output directory: {output_dir}")

        try:
            if tar_path.exists():
                # The parts were deleted once the archive was built, so
                # listing and downloading them again would fetch everything anew.
                print(f"Tar file already exists: {tar_path}")
            else:
                token = get_hf_token()
                print("Listing files in repository...")
                files = list_repo_files(repo_id, repo_type='dataset', token=token)
                part_files = sorted(f for f in files if f.startswith("deid_png.part"))
                if not part_files:
                    print("No part files found. The images might be in a different format.")
                    return
                print(f"Found {len(part_files)} part files.")

                # Download part files
                for part_file in part_files:
                    if (output_dir / part_file).exists():
                        print(f"Skipping {part_file} (already exists)")
                        continue
                    print(f"Downloading {part_file}...")
                    hf_hub_download(
                        repo_id=repo_id,
                        filename=part_file,
                        local_dir=output_dir,
                        local_dir_use_symlinks=False,
                        repo_type='dataset',
                        token=token
                    )

                ReXVQABenchmark._concatenate_parts(output_dir, part_files, tar_path)

            # Extract tar file
            print("\nExtracting images...")
            images_dir.mkdir(parents=True, exist_ok=True)
            try:
                # Stream extract with filtering for test-only images (no seeking)
                print("Stream extracting zstd-compressed tar file with filtering (streaming mode)...")
                extracted_count, total_png_members = ReXVQABenchmark._extract_images(
                    tar_path, images_dir, test_image_paths if test_only else None
                )
                print(f"Extraction completed! Extracted {extracted_count} matching PNGs out of {total_png_members} PNG members in the archive")

                # Clean up compressed tar file after successful extraction
                print("Cleaning up compressed tar file...")
                try:
                    tar_path.unlink()
                    print(f"Deleted {tar_path}")
                except Exception as e:
                    print(f"Could not delete {tar_path}: {e}")
            except Exception as e:
                print(f"Error extracting tar file: {e}")
                return

            png_files = list(images_dir.rglob("*.png"))
            print(f"Extracted {len(png_files)} PNG images to {images_dir}")
        except Exception as e:
            print(f"Error: {e}")

    @staticmethod
    def download_test_vqa_data_json(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXVQA"):
        """Download test_vqa_data.json from the ReXVQA HuggingFace repo if not already present.

        A download failure is reported with ``print`` and not raised.

        Args:
            output_dir: Directory under which ``metadata/test_vqa_data.json`` is stored
            repo_id: HuggingFace repository ID for the dataset
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        json_path = output_dir / "metadata" / "test_vqa_data.json"
        if json_path.exists():
            print(f"test_vqa_data.json already exists at {json_path}, skipping download.")
            return

        print(f"Downloading test_vqa_data.json to {json_path}...")
        try:
            hf_hub_download(
                repo_id=repo_id,
                filename="metadata/test_vqa_data.json",
                local_dir=output_dir,
                local_dir_use_symlinks=False,
                repo_type='dataset',
                token=get_hf_token()
            )
            print("Download complete.")
        except Exception as e:
            print(f"Error downloading test_vqa_data.json: {e}")
            print("You may need to accept the license agreement on HuggingFace.")