MR-JEPA / mr_jepa /data /data_utils.py
JorgeAV's picture
Initial MR-JEPA codebase: architecture, training, evaluation, and tests
dba2c56 verified
"""
Data utilities for MR-JEPA.
Includes:
- Collator that handles variable-length answer options and multi-image samples
  (currently only the first image of each sample is used)
- Dataloader factory
- Benchmark configuration helpers
"""
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Optional, Dict, List, Any, Tuple
from PIL import Image
import numpy as np
from .unified_dataset import UnifiedBenchmarkDataset, BenchmarkSample, BenchmarkType
# Static per-benchmark settings, looked up by lowercase benchmark key.
# Keys used in every entry:
#   'repo_id'     - dataset repository id (presumably a Hugging Face Hub id;
#                   confirm against UnifiedBenchmarkDataset)
#   'eval_split'  - name of the split used for evaluation
#   'metric'      - name of the scoring metric for this benchmark
#   'answer_type' - 'mc' (multiple choice), 'open' (free-form) or 'mixed'
# Optional keys:
#   'train_split' - name of a training split (only ScienceQA declares one)
#   'configs'     - MMMU subject names; presumably dataset config names that
#                   must be loaded one at a time - TODO confirm with the loader
BENCHMARK_CONFIGS = {
    'mmmu': {
        'repo_id': 'MMMU/MMMU',
        'eval_split': 'validation',
        'metric': 'accuracy',
        'answer_type': 'mc',
        'configs': [
            'Accounting', 'Agriculture', 'Architecture_and_Engineering',
            'Art', 'Art_Theory', 'Basic_Medical_Science', 'Biology',
            'Chemistry', 'Clinical_Medicine', 'Computer_Science',
            'Design', 'Diagnostics_and_Laboratory_Medicine', 'Economics',
            'Electronics', 'Energy_and_Power', 'Finance', 'Geography',
            'History', 'Literature', 'Manage', 'Marketing',
            'Materials', 'Math', 'Mechanical_Engineering', 'Music',
            'Pharmacy', 'Physics', 'Psychology', 'Public_Health',
            'Sociology'
        ],
    },
    'mathvista': {
        'repo_id': 'AI4Math/MathVista',
        'eval_split': 'testmini',
        'metric': 'accuracy',
        # Mix of multiple-choice and free-form questions.
        'answer_type': 'mixed',
    },
    'scienceqa': {
        'repo_id': 'derek-thomas/ScienceQA',
        'eval_split': 'test',
        # Only benchmark here that declares a training split.
        'train_split': 'train',
        'metric': 'accuracy',
        'answer_type': 'mc',
    },
    'ai2d': {
        'repo_id': 'lmms-lab/ai2d',
        'eval_split': 'test',
        'metric': 'accuracy',
        'answer_type': 'mc',
    },
    'mmbench': {
        'repo_id': 'lmms-lab/MMBench',
        'eval_split': 'dev',
        'metric': 'accuracy',
        'answer_type': 'mc',
    },
    'mmstar': {
        'repo_id': 'Lin-Chen/MMStar',
        'eval_split': 'val',
        'metric': 'accuracy',
        'answer_type': 'mc',
    },
    # Open-ended benchmarks below each use a benchmark-specific metric name.
    'docvqa': {
        'repo_id': 'lmms-lab/DocVQA',
        'eval_split': 'validation',
        'metric': 'anls',
        'answer_type': 'open',
    },
    'textvqa': {
        'repo_id': 'lmms-lab/textvqa',
        'eval_split': 'validation',
        'metric': 'vqa_accuracy',
        'answer_type': 'open',
    },
    'chartqa': {
        'repo_id': 'lmms-lab/ChartQA',
        'eval_split': 'test',
        'metric': 'relaxed_accuracy',
        'answer_type': 'open',
    },
}
def get_benchmark_config(benchmark: str) -> Dict:
    """Return the configuration dict for ``benchmark``.

    The key is first looked up exactly; if that fails and ``benchmark`` is a
    string, a case-insensitive lookup is tried (e.g. ``'MMMU'`` -> ``'mmmu'``).

    A deep copy is returned so callers can adjust the result without
    mutating the shared module-level ``BENCHMARK_CONFIGS`` table.

    Args:
        benchmark: Benchmark key, e.g. ``'mmmu'`` or ``'docvqa'``.

    Returns:
        A fresh dict with ``repo_id``, ``eval_split``, ``metric`` and
        ``answer_type`` plus any benchmark-specific extras
        (``train_split``, ``configs``).

    Raises:
        KeyError: If no configuration exists for ``benchmark``; the message
            lists the known benchmark names.
    """
    import copy

    config = BENCHMARK_CONFIGS.get(benchmark)
    if config is None and isinstance(benchmark, str):
        # All table keys are lowercase; tolerate 'MMMU', ' ChartQA ', etc.
        config = BENCHMARK_CONFIGS.get(benchmark.strip().lower())
    if config is None:
        known = ', '.join(sorted(BENCHMARK_CONFIGS))
        raise KeyError(f"Unknown benchmark {benchmark!r}; expected one of: {known}")
    # Copy so caller-side edits never leak into the global table.
    return copy.deepcopy(config)
class MRJEPACollator:
    """
    Collator for MR-JEPA that turns a list of ``BenchmarkSample`` objects into
    a single batch dict. It handles:
    - Multi-image samples (MMMU): only the first image of each sample is used;
      samples without a usable image get a blank white placeholder
    - Variable number of answer options (truncated / padded to ``max_options``)
    - Mixed MC/open-ended questions within one batch
    - Image preprocessing via the backbone processor
    - Text tokenization of questions and open-ended answer targets
    """

    def __init__(
        self,
        image_processor,
        text_tokenizer,
        max_options: int = 8,
        max_text_length: int = 256,
        max_gen_length: int = 64,
        image_size: int = 518,
    ):
        """
        Args:
            image_processor: Callable invoked as
                ``image_processor(images=[...], return_tensors='pt')``; its
                result is indexed with ``'pixel_values'``.
            text_tokenizer: Callable invoked with ``padding``, ``truncation``,
                ``max_length`` and ``return_tensors`` keywords; its result is
                indexed with ``'input_ids'`` / ``'attention_mask'``.
            max_options: Number of option slots per sample. Extra options are
                dropped; missing slots are padded with ``""``.
            max_text_length: Token length questions are padded/truncated to.
            max_gen_length: Token length open-ended answer targets are
                padded/truncated to.
            image_size: Side length (pixels) of the blank placeholder image.
        """
        self.image_processor = image_processor
        self.text_tokenizer = text_tokenizer
        self.max_options = max_options
        self.max_text_length = max_text_length
        self.max_gen_length = max_gen_length
        self.image_size = image_size

    def __call__(self, batch: List[BenchmarkSample]) -> Dict[str, Any]:
        """Collate a batch of BenchmarkSamples.

        Returns a dict with:
            pixel_values:    image-processor output, one image per sample
            input_ids, attention_mask: tokenized questions
            option_mask:     bool tensor [B, max_options]  (only if any sample has options)
            answer_labels:   long tensor [B]               (only if any sample has options)
            option_texts:    B lists of ``max_options`` raw strings; the model
                             encodes them at runtime (only if any sample has options)
            gen_target_ids:  tokenized open-ended answers  (only if any sample is 'open')
            benchmarks, answer_types, raw_answers: per-sample metadata lists
        """
        # ==================== Images ====================
        pixel_values = self._collate_images(batch)

        # ==================== Question Text ====================
        text_encoded = self.text_tokenizer(
            [s.question for s in batch],
            padding='max_length',
            truncation=True,
            max_length=self.max_text_length,
            return_tensors='pt',
        )
        result: Dict[str, Any] = {
            'pixel_values': pixel_values,
            'input_ids': text_encoded['input_ids'],
            'attention_mask': text_encoded['attention_mask'],
        }

        # ==================== Options (MC) ====================
        if any(s.options is not None for s in batch):
            option_texts, option_masks, answer_labels = self._collate_options(batch)
            result['option_mask'] = torch.tensor(option_masks, dtype=torch.bool)
            result['answer_labels'] = torch.tensor(answer_labels, dtype=torch.long)
            result['option_texts'] = option_texts

        # ==================== Open-ended answers ====================
        if any(s.answer_type == 'open' for s in batch):
            result['gen_target_ids'] = self._collate_gen_targets(batch)

        # ==================== Metadata ====================
        result['benchmarks'] = [s.benchmark for s in batch]
        result['answer_types'] = [s.answer_type for s in batch]
        result['raw_answers'] = [s.answer for s in batch]
        return result

    def _collate_images(self, batch: List[BenchmarkSample]):
        """Run the image processor on the first image of every sample."""
        images = []
        for sample in batch:
            # Empty/None image lists fall through to the placeholder instead
            # of raising IndexError/TypeError.
            img = sample.images[0] if sample.images else None
            if not isinstance(img, Image.Image):
                img = Image.new('RGB', (self.image_size, self.image_size), 'white')
            images.append(img.convert('RGB'))
        # Presumably [B, C, H, W] for standard vision processors - confirm.
        return self.image_processor(
            images=images,
            return_tensors='pt',
        )['pixel_values']

    def _collate_options(
        self, batch: List[BenchmarkSample]
    ) -> Tuple[List[List[str]], List[List[bool]], List[int]]:
        """Pad every sample's options to ``max_options``.

        Returns (option_texts, option_masks, answer_labels). Samples without
        options get all-"" texts, an all-False mask and label 0.
        """
        option_texts, option_masks, answer_labels = [], [], []
        for sample in batch:
            # list(...) copies, so the sample's own option list is never
            # mutated, and tuples are accepted as well as lists.
            opts = list(sample.options or [])[:self.max_options]
            n_opts = len(opts)
            n_pad = self.max_options - n_opts
            option_texts.append(opts + [""] * n_pad)
            option_masks.append([True] * n_opts + [False] * n_pad)
            answer_labels.append(
                self._answer_to_label(sample.answer, opts) if n_opts else 0
            )
        return option_texts, option_masks, answer_labels

    @staticmethod
    def _answer_to_label(answer: Any, options: List[str]) -> int:
        """Map a raw answer to an option index in ``[0, len(options))``.

        Resolution order:
        1. ``int``: clamped into range.
        2. Single ASCII letter ('A', 'b', ...): letter index, if in range.
        3. String equal to one of the option texts: that option's index.
        4. ASCII digit string: treated as a 0-based index like the int branch
           (assumption - confirm against the dataset loader), if in range.
        Anything unresolvable falls back to 0, as before.
        """
        n_opts = len(options)
        if isinstance(answer, int):
            return max(0, min(answer, n_opts - 1))
        if isinstance(answer, str):
            text = answer.strip()
            if len(text) == 1 and text.isascii() and text.isalpha():
                idx = ord(text.upper()) - ord('A')
                if idx < n_opts:
                    return idx
            if text in options:
                return options.index(text)
            # Checked after the text match so numeric option texts
            # (e.g. options ['3', '5', '7'], answer '5') resolve correctly.
            if text.isascii() and text.isdigit() and int(text) < n_opts:
                return int(text)
        return 0

    def _collate_gen_targets(self, batch: List[BenchmarkSample]):
        """Tokenize open-ended answers; non-open samples get an empty target."""
        gen_texts = []
        for sample in batch:
            answer = sample.answer
            if sample.answer_type != 'open':
                gen_texts.append("")
            elif isinstance(answer, list):
                # Multiple reference answers: the first one is the target.
                gen_texts.append(str(answer[0]) if answer else "")
            else:
                gen_texts.append(str(answer))
        gen_encoded = self.text_tokenizer(
            gen_texts,
            padding='max_length',
            truncation=True,
            max_length=self.max_gen_length,
            return_tensors='pt',
        )
        return gen_encoded['input_ids']
def build_dataloader(
    benchmark: str,
    split: str,
    image_processor,
    text_tokenizer,
    batch_size: int = 32,
    num_workers: int = 4,
    max_samples: Optional[int] = None,
    config: Optional[str] = None,
    shuffle: Optional[bool] = None,
    pin_memory: Optional[bool] = None,
    **collator_kwargs,
) -> DataLoader:
    """Build a DataLoader for a specific benchmark.

    Args:
        benchmark: Benchmark key passed to ``UnifiedBenchmarkDataset``.
        split: Dataset split name. ``'train'`` / ``'training'`` enable
            shuffling (unless overridden) and ``drop_last``.
        image_processor: Forwarded to ``MRJEPACollator``.
        text_tokenizer: Forwarded to ``MRJEPACollator``.
        batch_size: Samples per batch.
        num_workers: DataLoader worker processes.
        max_samples: Optional cap on dataset size.
        config: Optional dataset config name (e.g. an MMMU subject).
        shuffle: Override shuffling; ``None`` shuffles only training splits.
        pin_memory: Override memory pinning; ``None`` pins only when CUDA is
            available, since pinned host memory only speeds up GPU transfers.
        **collator_kwargs: Extra keyword arguments for ``MRJEPACollator``.

    Returns:
        A DataLoader that batches samples with ``MRJEPACollator``.
    """
    is_train = split in ('train', 'training')
    if shuffle is None:
        shuffle = is_train
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    dataset = UnifiedBenchmarkDataset(
        benchmark=benchmark,
        split=split,
        config=config,
        max_samples=max_samples,
    )
    collator = MRJEPACollator(
        image_processor=image_processor,
        text_tokenizer=text_tokenizer,
        **collator_kwargs,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collator,
        pin_memory=pin_memory,
        # Partial final batches are only dropped for training splits.
        drop_last=is_train,
    )