File size: 16,463 Bytes
99f2cbc
 
9bb904a
 
eaff77c
99f2cbc
c7b65ec
aa37a55
 
eaff77c
b93ad3f
 
 
 
 
 
 
 
 
 
99f2cbc
 
 
 
 
 
 
 
 
 
9bb904a
 
c7b65ec
9bb904a
99f2cbc
 
9bb904a
99f2cbc
 
 
 
 
 
 
 
9bb904a
99f2cbc
9bb904a
c7b65ec
99f2cbc
9bb904a
16278b5
b93ad3f
89321e2
 
b93ad3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e97f266
b93ad3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99f2cbc
eaff77c
aa37a55
 
 
 
 
 
 
 
eaff77c
 
 
 
 
 
 
 
aa37a55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eaff77c
 
 
 
b93ad3f
eaff77c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b93ad3f
 
eaff77c
 
 
 
 
 
 
 
 
 
 
 
 
aa37a55
 
 
 
 
 
 
 
 
 
 
eaff77c
 
 
 
 
 
 
 
 
 
c963ad3
 
 
aa37a55
 
c963ad3
aa37a55
 
c963ad3
 
aa37a55
 
c963ad3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa37a55
 
 
 
 
 
 
 
 
eaff77c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b93ad3f
 
eaff77c
 
 
 
b93ad3f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""ReXVQA benchmark implementation."""

import json
import os
from typing import Dict, Optional, Any
from .base import Benchmark, BenchmarkDataPoint
from pathlib import Path
import tarfile
import zstandard as zstd
from huggingface_hub import hf_hub_download, list_repo_files
import os


def get_hf_token():
    """Return the locally cached Hugging Face token, or None if absent.

    Reads the token file written by `huggingface-cli login` from the
    default cache location under the user's home directory.
    """
    token_file = Path(os.path.expanduser("~/.cache/huggingface/token"))
    if not token_file.exists():
        return None
    return token_file.read_text().strip()


class ReXVQABenchmark(Benchmark):
    """ReXVQA benchmark for chest radiology visual question answering.
    
    ReXVQA is a large-scale VQA dataset for chest radiology comprising approximately 
    696,000 questions paired with 160,000 chest X-rays. It tests 5 core radiological 
    reasoning skills: presence assessment, location analysis, negation detection, 
    differential diagnosis, and geometric reasoning.
    
    The dataset consists of two separate HuggingFace datasets:
    - ReXVQA: Contains questions, answers, and metadata
    - ReXGradient-160K: Contains metadata only (images are in separate part files)
    
    Paper: https://arxiv.org/abs/2506.04353
    Dataset: https://huggingface.co/datasets/rajpurkarlab/ReXVQA
    Images: https://huggingface.co/datasets/rajpurkarlab/ReXGradient-160K
    """

    def __init__(self, data_dir: str, **kwargs):
        """Initialize ReXVQA benchmark.
        
        Args:
            data_dir (str): Directory to store/cache downloaded data
            **kwargs: Additional configuration parameters. Only
                ``split`` (str) is read here (default: 'test'); any other
                keys are forwarded unchanged to the ``Benchmark`` base class.
        """
        self.split = kwargs.get("split", "test")
        # Images are extracted under <data_dir>/images/deid_png by
        # download_rexgradient_images(); this path is derived, not configurable.
        self.images_dir = f"{data_dir}/images/deid_png"

        super().__init__(data_dir, **kwargs)

    @staticmethod
    def _strip_dot_slash(path: str) -> str:
        """Remove a single leading './' prefix from a relative path.

        NOTE: the previous implementation used ``str.lstrip("./")``, which
        strips any leading run of the CHARACTERS '.' and '/' — mangling
        names such as '..hidden/x.png' or '.config/x.png'. A literal
        prefix check removes exactly one './' and nothing else.
        """
        return path[2:] if path.startswith("./") else path

    def _load_data(self) -> None:
        """Load ReXVQA data from HuggingFace.

        Downloads the test metadata JSON and (test-split) images if missing,
        then parses each question entry into a BenchmarkDataPoint appended
        to ``self.data_points``.

        Raises:
            RuntimeError: If the dataset cannot be downloaded or parsed.
        """
        try:
            # Download images and test_vqa_data.json locally if missing
            self.download_test_vqa_data_json(self.data_dir)
            self.download_rexgradient_images(self.data_dir, test_only=True)
            
            # Load JSON file
            json_file_path = os.path.join(self.data_dir, "metadata", "test_vqa_data.json")
            if not os.path.exists(json_file_path):
                raise FileNotFoundError(f"Could not find test_vqa_data.json in the expected location: {json_file_path}")
            print(f"Loading ReXVQA {self.split} split from local JSON file: {json_file_path}")
            with open(json_file_path, 'r', encoding='utf-8') as f:
                questions_data = json.load(f)
            
            # ReXVQA format: {question_id: {question_data}, ...}
            questions_list = []
            for question_id, question_data in questions_data.items():
                # Add the question_id to the question_data for reference
                question_data['id'] = question_id
                questions_list.append(question_data)
            print(f"Loaded {len(questions_list)} questions from local JSON file")
    
            # Process questions; a bad item is logged and skipped rather than
            # aborting the whole load.
            for i, item in enumerate(questions_list):
                try:
                    data_point = self._parse_rexvqa_item(item, i)
                    if data_point:
                        self.data_points.append(data_point)
                except Exception as e:
                    print(f"Error loading item {i}: {e}")
                    continue
                
        except Exception as e:
            # Chain the original exception so its traceback is preserved.
            raise RuntimeError(f"Failed to load ReXVQA dataset: {e}") from e

    def _parse_rexvqa_item(self, item: Dict[str, Any], index: int) -> Optional[BenchmarkDataPoint]:
        """Parse a ReXVQA dataset item.
        
        Args:
            item (Dict[str, Any]): Dataset item from JSON file
            index (int): Item index
            
        Returns:
            Optional[BenchmarkDataPoint]: Parsed data point
        """
        # Extract question ID; fall back to a synthetic split/index id.
        question_id = item.get("id", f"rexvqa_{self.split}_{index}")

        # Extract question and options (multiple-choice text appended to question)
        question = item.get("question", "")
        options = item.get("options", [])
        question_with_options = question + "\n\nOptions:\n" + "\n".join(options)
        
        # Extract correct answer
        correct_answer = item.get("correct_answer", "")
        
        # Extract images: resolve each JSON-relative path against the images
        # root (parent of images_dir, since ImagePath entries already include
        # the 'deid_png/' directory component).
        images = None
        if self.images_dir and "ImagePath" in item and item["ImagePath"]:
            images = []
            for rel_path in item["ImagePath"]:
                norm_rel_path = self._strip_dot_slash(rel_path)
                full_path = str(Path(self.images_dir).parent / norm_rel_path)
                images.append(full_path)
        
        # Extract metadata
        metadata = {
            "dataset": "rexvqa",
            "split": self.split,
            "study_id": item.get("study_id", ""),
            "study_instance_uid": item.get("StudyInstanceUid", ""),
            "reasoning_type": item.get("task_name", ""),  # task_name maps to reasoning_type
            "category": item.get("category", ""),
            "class": item.get("class", ""),
            "subcategory": item.get("subcategory", ""),
            "patient_id": item.get("PatientID", ""),
            "patient_age": item.get("PatientAge", ""),
            "patient_sex": item.get("PatientSex", ""),
            "study_date": item.get("StudyDate", ""),
            "indication": item.get("Indication", ""),
            "findings": item.get("Findings", ""),
            "impression": item.get("Impression", ""),
            "image_modality": item.get("ImageModality", []),
            "image_view_position": item.get("ImageViewPosition", []),
            "correct_answer_explanation": item.get("correct_answer_explanation", ""),
        }

        # Return data point
        return BenchmarkDataPoint(
            id=question_id,
            text=question_with_options,
            images=images,
            correct_answer=correct_answer,
            metadata=metadata
        )

    @staticmethod
    def download_rexgradient_images(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXGradient-160K", test_only: bool = True):
        """Download and extract ReXGradient-160K images if not already present.
        
        The repo ships the image archive as multiple 'deid_png.part*' files.
        This method downloads the parts, concatenates them into a single
        zstd-compressed tar, then stream-extracts PNGs (optionally only
        those referenced by the test split) and cleans up intermediates.

        Args:
            output_dir: Directory to store downloaded and extracted images
            repo_id: HuggingFace repository ID for the dataset
            test_only: If True, only extract images from the test split (default: True)
        """
        output_dir = Path(output_dir)
        tar_path = output_dir / "deid_png.tar"
        images_dir = output_dir / "images"

        # Check if images already exist
        if images_dir.exists() and any(images_dir.rglob("*.png")):
            print(f"Images already exist in {images_dir}, skipping download.")
            return
            
        # Load test split metadata if test_only is True
        test_image_paths = set()
        if test_only:
            print("Loading test split metadata to identify test images...")
            try:
                # Load the test metadata to get image paths
                test_metadata_path = output_dir / "metadata" / "test_vqa_data.json"
                if test_metadata_path.exists():
                    with open(test_metadata_path, 'r', encoding='utf-8') as f:
                        test_data = json.load(f)
                    
                    # Extract all image paths from test data
                    for item in test_data.values():
                        if "ImagePath" in item and item["ImagePath"]:
                            for rel_path in item["ImagePath"]:
                                # Normalize path to match tar file structure
                                norm_path = ReXVQABenchmark._strip_dot_slash(rel_path)
                                test_image_paths.add(norm_path)
                    
                    print(f"Found {len(test_image_paths)} test images to extract")
                else:
                    print("Warning: test_vqa_data.json not found, will extract all images")
                    test_only = False
            except Exception as e:
                print(f"Warning: Could not load test metadata: {e}, will extract all images")
                test_only = False
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Output directory: {output_dir}")
        try:
            print("Listing files in repository...")
            files = list_repo_files(repo_id, repo_type='dataset', token=get_hf_token())
            part_files = [f for f in files if f.startswith("deid_png.part")]
            if not part_files:
                print("No part files found. The images might be in a different format.")
                return
            print(f"Found {len(part_files)} part files.")
            # Download part files
            for part_file in part_files:
                output_path = output_dir / part_file
                if output_path.exists():
                    print(f"Skipping {part_file} (already exists)")
                    continue
                print(f"Downloading {part_file}...")
                hf_hub_download(
                    repo_id=repo_id,
                    filename=part_file,
                    local_dir=output_dir,
                    local_dir_use_symlinks=False,
                    repo_type='dataset',
                    token=get_hf_token()
                )
            # Concatenate part files (sorted so parts are appended in order)
            if not tar_path.exists():
                print("\nConcatenating part files...")
                with open(tar_path, 'wb') as tar_file:
                    for part_file in sorted(part_files):
                        part_path = output_dir / part_file
                        if part_path.exists():
                            print(f"Adding {part_file}...")
                            with open(part_path, 'rb') as f:
                                tar_file.write(f.read())
                        else:
                            print(f"Warning: {part_file} not found, skipping...")
                
                # Clean up part files after successful concatenation
                print("Cleaning up part files...")
                for part_file in part_files:
                    part_path = output_dir / part_file
                    if part_path.exists():
                        try:
                            part_path.unlink()
                            print(f"Deleted {part_file}")
                        except Exception as e:
                            print(f"Could not delete {part_file}: {e}")
            else:
                print(f"Tar file already exists: {tar_path}")
            # Extract tar file
            if tar_path.exists():
                print("\nExtracting images...")
                images_dir.mkdir(exist_ok=True)
                if any(images_dir.rglob("*.png")):
                    print("Images already extracted.")
                else:
                    try:
                        # Stream extract with filtering for test-only images (no seeking)
                        print("Stream extracting zstd-compressed tar file with filtering (streaming mode)...")

                        # Create a decompressor (the .tar is zstd-compressed
                        # despite its extension)
                        dctx = zstd.ZstdDecompressor()

                        # Stream extract with filtering
                        extracted_count = 0
                        total_png_members = 0

                        with open(tar_path, 'rb') as compressed_file:
                            with dctx.stream_reader(compressed_file) as decompressed_stream:
                                # Use streaming tar mode ('r|') to avoid seeks,
                                # which the zstd stream reader does not support
                                with tarfile.open(fileobj=decompressed_stream, mode='r|') as tar:
                                    for member in tar:
                                        # Only consider PNG files
                                        if not member.isfile() or not member.name.endswith('.png'):
                                            continue
                                        total_png_members += 1

                                        # Normalize name to match entries gathered from JSON
                                        normalized_name = ReXVQABenchmark._strip_dot_slash(member.name)

                                        # Decide whether to extract this file
                                        should_extract = True
                                        if test_only:
                                            should_extract = normalized_name in test_image_paths

                                        if not should_extract:
                                            # Must still advance the stream for this member;
                                            # drop accumulated member refs to bound memory
                                            tar.members = []
                                            continue

                                        # Ensure parent directories exist and write file by streaming
                                        target_path = Path(images_dir) / normalized_name
                                        target_path.parent.mkdir(parents=True, exist_ok=True)

                                        extracted_file_obj = tar.extractfile(member)
                                        if extracted_file_obj is None:
                                            continue
                                        with open(target_path, 'wb') as out_f:
                                            while True:
                                                chunk = extracted_file_obj.read(1024 * 1024)
                                                if not chunk:
                                                    break
                                                out_f.write(chunk)

                                        extracted_count += 1
                                        if extracted_count % 100 == 0:
                                            print(f"Extracted {extracted_count} test images...")

                        print(f"Extraction completed! Extracted {extracted_count} matching PNGs out of {total_png_members} PNG members in the archive")
                        
                        # Clean up compressed tar file after successful extraction
                        print("Cleaning up compressed tar file...")
                        try:
                            tar_path.unlink()
                            print(f"Deleted {tar_path}")
                        except Exception as e:
                            print(f"Could not delete {tar_path}: {e}")
                    except Exception as e:
                        print(f"Error extracting tar file: {e}")
                        return
                png_files = list(images_dir.rglob("*.png"))
                print(f"Extracted {len(png_files)} PNG images to {images_dir}")

        except Exception as e:
            print(f"Error: {e}")

    @staticmethod
    def download_test_vqa_data_json(output_dir: str = "benchmarking/data/rexvqa", repo_id: str = "rajpurkarlab/ReXVQA"):
        """Download test_vqa_data.json from the ReXVQA HuggingFace repo if not already present."""
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        json_path = output_dir / "metadata" / "test_vqa_data.json"
        if json_path.exists():
            print(f"test_vqa_data.json already exists at {json_path}, skipping download.")
            return
        print(f"Downloading test_vqa_data.json to {json_path}...")
        try:
            hf_hub_download(
                repo_id=repo_id,
                filename="metadata/test_vqa_data.json",
                local_dir=output_dir,
                local_dir_use_symlinks=False,
                repo_type='dataset',
                token=get_hf_token()
            )
            print("Download complete.")
        except Exception as e:
            print(f"Error downloading test_vqa_data.json: {e}")
            print("You may need to accept the license agreement on HuggingFace.")