| """ |
| Data utilities for MR-JEPA. |
| |
| Includes: |
| - Collator that handles variable-length options and multi-image samples (only the first image of each sample is used) |
| - Dataloader factory |
| - Benchmark configuration helpers |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from typing import Optional, Dict, List, Any, Tuple |
| from PIL import Image |
| import numpy as np |
|
|
| from .unified_dataset import UnifiedBenchmarkDataset, BenchmarkSample, BenchmarkType |
|
|
|
|
# Static per-benchmark settings, keyed by benchmark name. Every entry carries
# 'repo_id', 'eval_split', 'metric' and 'answer_type'; 'train_split' and
# 'configs' appear only on the benchmarks that define them below.
BENCHMARK_CONFIGS = dict(
    mmmu=dict(
        repo_id="MMMU/MMMU",
        eval_split="validation",
        metric="accuracy",
        answer_type="mc",
        # One entry per MMMU subject configuration.
        configs=[
            "Accounting",
            "Agriculture",
            "Architecture_and_Engineering",
            "Art",
            "Art_Theory",
            "Basic_Medical_Science",
            "Biology",
            "Chemistry",
            "Clinical_Medicine",
            "Computer_Science",
            "Design",
            "Diagnostics_and_Laboratory_Medicine",
            "Economics",
            "Electronics",
            "Energy_and_Power",
            "Finance",
            "Geography",
            "History",
            "Literature",
            "Manage",
            "Marketing",
            "Materials",
            "Math",
            "Mechanical_Engineering",
            "Music",
            "Pharmacy",
            "Physics",
            "Psychology",
            "Public_Health",
            "Sociology",
        ],
    ),
    mathvista=dict(
        repo_id="AI4Math/MathVista",
        eval_split="testmini",
        metric="accuracy",
        answer_type="mixed",
    ),
    scienceqa=dict(
        repo_id="derek-thomas/ScienceQA",
        eval_split="test",
        train_split="train",
        metric="accuracy",
        answer_type="mc",
    ),
    ai2d=dict(
        repo_id="lmms-lab/ai2d",
        eval_split="test",
        metric="accuracy",
        answer_type="mc",
    ),
    mmbench=dict(
        repo_id="lmms-lab/MMBench",
        eval_split="dev",
        metric="accuracy",
        answer_type="mc",
    ),
    mmstar=dict(
        repo_id="Lin-Chen/MMStar",
        eval_split="val",
        metric="accuracy",
        answer_type="mc",
    ),
    docvqa=dict(
        repo_id="lmms-lab/DocVQA",
        eval_split="validation",
        metric="anls",
        answer_type="open",
    ),
    textvqa=dict(
        repo_id="lmms-lab/textvqa",
        eval_split="validation",
        metric="vqa_accuracy",
        answer_type="open",
    ),
    chartqa=dict(
        repo_id="lmms-lab/ChartQA",
        eval_split="test",
        metric="relaxed_accuracy",
        answer_type="open",
    ),
)
|
|
|
|
def get_benchmark_config(benchmark: str) -> Dict:
    """
    Return the configuration entry for ``benchmark``.

    Args:
        benchmark: Key into ``BENCHMARK_CONFIGS`` (e.g. ``'mmmu'``). The lookup
            is exact; no case folding or aliasing is applied.

    Returns:
        The dict stored in ``BENCHMARK_CONFIGS`` for that key. It is the stored
        object itself, not a copy, so mutating it changes the module-level
        table.

    Raises:
        KeyError: If ``benchmark`` is not a known key. The message names the
            rejected key and lists the available ones.
    """
    try:
        return BENCHMARK_CONFIGS[benchmark]
    except KeyError:
        known = ', '.join(sorted(BENCHMARK_CONFIGS))
        # Same exception type as a plain dict lookup, but with a message that
        # tells the caller which names are valid.
        raise KeyError(
            f"Unknown benchmark {benchmark!r}; available benchmarks: {known}"
        ) from None
|
|
|
|
class MRJEPACollator:
    """
    Collator for MR-JEPA that turns a list of benchmark samples into one batch.

    Handles:
    - Samples carrying several images (only the first image is used)
    - Variable number of answer options (truncated / padded to ``max_options``)
    - Mixed MC/open-ended questions
    - Image preprocessing via the supplied image processor
    - Text tokenization via the supplied tokenizer
    """

    def __init__(
        self,
        image_processor,
        text_tokenizer,
        max_options: int = 8,
        max_text_length: int = 256,
        max_gen_length: int = 64,
        image_size: int = 518,
    ):
        """
        Store the preprocessing callables and the padding limits.

        Args:
            image_processor: Callable invoked as
                ``image_processor(images=<list of PIL images>, return_tensors='pt')``;
                its result is indexed with ``'pixel_values'``.
            text_tokenizer: Callable invoked on a list of strings with
                ``padding='max_length'``, ``truncation=True``, ``max_length=...``
                and ``return_tensors='pt'``; its result is indexed with
                ``'input_ids'`` and ``'attention_mask'``.
            max_options: Number of option slots per sample; extra options are
                dropped and missing ones are padded with empty strings.
            max_text_length: ``max_length`` passed to the tokenizer for questions.
            max_gen_length: ``max_length`` passed to the tokenizer for
                open-ended answer targets.
            image_size: Side length of the white square placeholder used when a
                sample has no usable image.
        """
        self.image_processor = image_processor
        self.text_tokenizer = text_tokenizer
        self.max_options = max_options
        self.max_text_length = max_text_length
        self.max_gen_length = max_gen_length
        self.image_size = image_size

    def _first_image(self, sample):
        """Return the sample's first image as RGB, or a white placeholder."""
        candidates = sample.images
        # An empty or missing image list falls back to the placeholder instead
        # of raising on ``[0]``.
        has_any = candidates is not None and len(candidates) > 0
        img = candidates[0] if has_any else None
        if not isinstance(img, Image.Image):
            img = Image.new('RGB', (self.image_size, self.image_size), 'white')
        return img.convert('RGB')

    def _pad_options(self, options) -> Tuple[List[str], int]:
        """Return (options cut/padded to ``max_options`` slots, number of real options)."""
        # ``list(...)`` copies, so the sample's own option list is never mutated.
        kept = list(options[:self.max_options]) if options else []
        n_real = len(kept)
        kept.extend([""] * (self.max_options - n_real))
        return kept, n_real

    @staticmethod
    def _answer_index(answer, n_opts: int) -> int:
        """Map an int or single-letter answer to an option index within ``[0, n_opts - 1]``."""
        if isinstance(answer, int):
            idx = answer
        elif isinstance(answer, str) and len(answer) == 1:
            idx = ord(answer.upper()) - ord('A')
        else:
            # Anything else (multi-character strings, lists, None, ...) has no
            # obvious index; fall back to the first option.
            return 0
        # Clamp on both sides so a letter beyond the kept options, a non-letter
        # character, or a negative int can never yield an out-of-range label.
        return max(0, min(idx, n_opts - 1))

    def __call__(self, batch: List[BenchmarkSample]) -> Dict[str, Any]:
        """
        Collate a batch of BenchmarkSamples.

        Args:
            batch: Samples exposing ``images``, ``question``, ``options``,
                ``answer``, ``answer_type`` and ``benchmark`` attributes.

        Returns:
            A dict with:
            - ``'pixel_values'``: ``image_processor(...)['pixel_values']``
            - ``'input_ids'`` / ``'attention_mask'``: tokenized questions
            - ``'option_mask'`` (bool tensor), ``'answer_labels'`` (long tensor)
              and ``'option_texts'`` (list of ``max_options`` strings per
              sample): present only if some sample has ``options is not None``.
              Samples without options get an all-False mask and label 0.
            - ``'gen_target_ids'``: tokenized answers, present only if some
              sample has ``answer_type == 'open'``; other samples contribute
              an empty string. For list answers the first element is used.
            - ``'benchmarks'``, ``'answer_types'``, ``'raw_answers'``: plain
              Python lists copied from the samples.
        """
        # --- Images: one image per sample -------------------------------
        images = [self._first_image(sample) for sample in batch]
        pixel_values = self.image_processor(
            images=images,
            return_tensors='pt',
        )['pixel_values']

        # --- Questions ----------------------------------------------------
        text_encoded = self.text_tokenizer(
            [sample.question for sample in batch],
            padding='max_length',
            truncation=True,
            max_length=self.max_text_length,
            return_tensors='pt',
        )

        result = {
            'pixel_values': pixel_values,
            'input_ids': text_encoded['input_ids'],
            'attention_mask': text_encoded['attention_mask'],
        }

        # --- Multiple-choice options ---------------------------------------
        if any(sample.options is not None for sample in batch):
            option_masks = []
            answer_labels = []
            option_texts = []
            for sample in batch:
                texts, n_opts = self._pad_options(sample.options)
                option_texts.append(texts)
                option_masks.append(
                    [slot < n_opts for slot in range(self.max_options)]
                )
                answer_labels.append(
                    self._answer_index(sample.answer, n_opts) if n_opts else 0
                )

            result['option_mask'] = torch.tensor(option_masks, dtype=torch.bool)
            result['answer_labels'] = torch.tensor(answer_labels, dtype=torch.long)
            result['option_texts'] = option_texts

        # --- Open-ended generation targets -----------------------------------
        if any(sample.answer_type == 'open' for sample in batch):
            gen_texts = []
            for sample in batch:
                if sample.answer_type != 'open':
                    # Keep the batch dimension aligned; non-open samples get
                    # an empty target.
                    gen_texts.append("")
                elif isinstance(sample.answer, list):
                    gen_texts.append(str(sample.answer[0]))
                else:
                    gen_texts.append(str(sample.answer))

            gen_encoded = self.text_tokenizer(
                gen_texts,
                padding='max_length',
                truncation=True,
                max_length=self.max_gen_length,
                return_tensors='pt',
            )
            result['gen_target_ids'] = gen_encoded['input_ids']

        # --- Per-sample metadata (left as Python lists) ----------------------
        result['benchmarks'] = [sample.benchmark for sample in batch]
        result['answer_types'] = [sample.answer_type for sample in batch]
        result['raw_answers'] = [sample.answer for sample in batch]

        return result
|
|
|
|
def build_dataloader(
    benchmark: str,
    split: str,
    image_processor,
    text_tokenizer,
    batch_size: int = 32,
    num_workers: int = 4,
    max_samples: Optional[int] = None,
    config: Optional[str] = None,
    **collator_kwargs,
) -> DataLoader:
    """
    Create a DataLoader over one benchmark split, collated by MRJEPACollator.

    Args:
        benchmark: Forwarded to ``UnifiedBenchmarkDataset`` as ``benchmark``.
        split: Forwarded to the dataset as ``split``. Shuffling and
            ``drop_last`` are enabled only when it is ``'train'`` or
            ``'training'``.
        image_processor: Forwarded to ``MRJEPACollator``.
        text_tokenizer: Forwarded to ``MRJEPACollator``.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.
        max_samples: Forwarded to the dataset as ``max_samples``.
        config: Forwarded to the dataset as ``config``.
        **collator_kwargs: Extra keyword arguments for ``MRJEPACollator``.

    Returns:
        A ``DataLoader`` with ``pin_memory=True`` that uses the collator as
        its ``collate_fn``.
    """
    data_source = UnifiedBenchmarkDataset(
        benchmark=benchmark,
        split=split,
        config=config,
        max_samples=max_samples,
    )
    batch_builder = MRJEPACollator(
        image_processor=image_processor,
        text_tokenizer=text_tokenizer,
        **collator_kwargs,
    )

    # Training splits are shuffled and drop their final partial batch;
    # every other split is iterated in order and in full.
    is_training = split in ('train', 'training')

    loader_options = {
        'batch_size': batch_size,
        'shuffle': is_training,
        'num_workers': num_workers,
        'collate_fn': batch_builder,
        'pin_memory': True,
        'drop_last': is_training,
    }
    return DataLoader(data_source, **loader_options)
|
|