""" Data utilities for MR-JEPA. Includes: - Collator that handles variable-length options, multi-image samples - Dataloader factory - Benchmark configuration helpers """ import torch import torch.nn.functional as F from torch.utils.data import DataLoader from typing import Optional, Dict, List, Any, Tuple from PIL import Image import numpy as np from .unified_dataset import UnifiedBenchmarkDataset, BenchmarkSample, BenchmarkType BENCHMARK_CONFIGS = { 'mmmu': { 'repo_id': 'MMMU/MMMU', 'eval_split': 'validation', 'metric': 'accuracy', 'answer_type': 'mc', 'configs': [ 'Accounting', 'Agriculture', 'Architecture_and_Engineering', 'Art', 'Art_Theory', 'Basic_Medical_Science', 'Biology', 'Chemistry', 'Clinical_Medicine', 'Computer_Science', 'Design', 'Diagnostics_and_Laboratory_Medicine', 'Economics', 'Electronics', 'Energy_and_Power', 'Finance', 'Geography', 'History', 'Literature', 'Manage', 'Marketing', 'Materials', 'Math', 'Mechanical_Engineering', 'Music', 'Pharmacy', 'Physics', 'Psychology', 'Public_Health', 'Sociology' ], }, 'mathvista': { 'repo_id': 'AI4Math/MathVista', 'eval_split': 'testmini', 'metric': 'accuracy', 'answer_type': 'mixed', }, 'scienceqa': { 'repo_id': 'derek-thomas/ScienceQA', 'eval_split': 'test', 'train_split': 'train', 'metric': 'accuracy', 'answer_type': 'mc', }, 'ai2d': { 'repo_id': 'lmms-lab/ai2d', 'eval_split': 'test', 'metric': 'accuracy', 'answer_type': 'mc', }, 'mmbench': { 'repo_id': 'lmms-lab/MMBench', 'eval_split': 'dev', 'metric': 'accuracy', 'answer_type': 'mc', }, 'mmstar': { 'repo_id': 'Lin-Chen/MMStar', 'eval_split': 'val', 'metric': 'accuracy', 'answer_type': 'mc', }, 'docvqa': { 'repo_id': 'lmms-lab/DocVQA', 'eval_split': 'validation', 'metric': 'anls', 'answer_type': 'open', }, 'textvqa': { 'repo_id': 'lmms-lab/textvqa', 'eval_split': 'validation', 'metric': 'vqa_accuracy', 'answer_type': 'open', }, 'chartqa': { 'repo_id': 'lmms-lab/ChartQA', 'eval_split': 'test', 'metric': 'relaxed_accuracy', 'answer_type': 'open', }, } def get_benchmark_config(benchmark: str) -> Dict: """Get benchmark configuration.""" return BENCHMARK_CONFIGS[benchmark] class MRJEPACollator: """ Collator for MR-JEPA that handles: - Variable number of images per sample (MMMU) - Variable number of answer options - Mixed MC/open-ended questions - Image preprocessing via backbone processor - Text tokenization """ def __init__( self, image_processor, text_tokenizer, max_options: int = 8, max_text_length: int = 256, max_gen_length: int = 64, image_size: int = 518, ): self.image_processor = image_processor self.text_tokenizer = text_tokenizer self.max_options = max_options self.max_text_length = max_text_length self.max_gen_length = max_gen_length self.image_size = image_size def __call__(self, batch: List[BenchmarkSample]) -> Dict[str, torch.Tensor]: """Collate a batch of BenchmarkSamples.""" B = len(batch) # ==================== Images ==================== # Use first image for now (multi-image MMMU handled separately) images = [] for sample in batch: img = sample.images[0] if not isinstance(img, Image.Image): img = Image.new('RGB', (self.image_size, self.image_size), 'white') images.append(img.convert('RGB')) # Process images through backbone processor pixel_values = self.image_processor( images=images, return_tensors='pt', )['pixel_values'] # [B, C, H, W] # ==================== Question Text ==================== questions = [s.question for s in batch] text_encoded = self.text_tokenizer( questions, padding='max_length', truncation=True, max_length=self.max_text_length, return_tensors='pt', ) # ==================== Options (MC) ==================== # Encode each option separately, pad to max_options option_embeddings_list = [] option_masks = [] answer_labels = [] has_mc = any(s.options is not None for s in batch) if has_mc: for sample in batch: if sample.options: n_opts = min(len(sample.options), self.max_options) # Tokenize options opts_text = sample.options[:n_opts] # Pad option text list to max_options while len(opts_text) < self.max_options: opts_text.append("") mask = [True] * n_opts + [False] * (self.max_options - n_opts) option_masks.append(mask) # Answer label if isinstance(sample.answer, int): answer_labels.append(min(sample.answer, n_opts - 1)) elif isinstance(sample.answer, str) and len(sample.answer) == 1: answer_labels.append(ord(sample.answer.upper()) - ord('A')) else: answer_labels.append(0) else: option_masks.append([False] * self.max_options) answer_labels.append(0) # ==================== Open-ended answers ==================== gen_target_ids = None has_open = any(s.answer_type == 'open' for s in batch) if has_open: # Prepare generative targets gen_texts = [] for sample in batch: if sample.answer_type == 'open': if isinstance(sample.answer, list): gen_texts.append(str(sample.answer[0])) else: gen_texts.append(str(sample.answer)) else: gen_texts.append("") gen_encoded = self.text_tokenizer( gen_texts, padding='max_length', truncation=True, max_length=self.max_gen_length, return_tensors='pt', ) gen_target_ids = gen_encoded['input_ids'] # ==================== Build output dict ==================== result = { 'pixel_values': pixel_values, 'input_ids': text_encoded['input_ids'], 'attention_mask': text_encoded['attention_mask'], } if has_mc: result['option_mask'] = torch.tensor(option_masks, dtype=torch.bool) result['answer_labels'] = torch.tensor(answer_labels, dtype=torch.long) # We need to encode options through text encoder at runtime # Store raw option texts for the model to encode all_option_texts = [] for sample in batch: opts = sample.options or [""] * self.max_options opts = opts[:self.max_options] while len(opts) < self.max_options: opts.append("") all_option_texts.append(opts) result['option_texts'] = all_option_texts if gen_target_ids is not None: result['gen_target_ids'] = gen_target_ids # Metadata result['benchmarks'] = [s.benchmark for s in batch] result['answer_types'] = [s.answer_type for s in batch] result['raw_answers'] = [s.answer for s in batch] return result def build_dataloader( benchmark: str, split: str, image_processor, text_tokenizer, batch_size: int = 32, num_workers: int = 4, max_samples: Optional[int] = None, config: Optional[str] = None, **collator_kwargs, ) -> DataLoader: """Build a DataLoader for a specific benchmark.""" dataset = UnifiedBenchmarkDataset( benchmark=benchmark, split=split, config=config, max_samples=max_samples, ) collator = MRJEPACollator( image_processor=image_processor, text_tokenizer=text_tokenizer, **collator_kwargs, ) return DataLoader( dataset, batch_size=batch_size, shuffle=(split in ('train', 'training')), num_workers=num_workers, collate_fn=collator, pin_memory=True, drop_last=(split in ('train', 'training')), )