"""COCO caption dataset: pairs a preprocessed image with one of its captions,
tokenized with a T5 tokenizer for seq2seq caption training."""

import json
import os
import random

from PIL import Image
from torch.utils.data import Dataset
from transformers import T5TokenizerFast

from data.transforms import build_coco_transform


class CocoCaptionDataset(Dataset):
    def __init__(
        self,
        split="train",
        image_size=224,
        tokenizer_name="t5-small",
        max_caption_length=64,
        data_dir="data/processed",
        random_caption=True,
        normalize=True,
    ):
        assert split in ["train", "val", "test"]
        self.split = split
        self.image_size = image_size
        self.random_caption = random_caption
        self.max_caption_length = max_caption_length
        self.normalize = normalize
        self.images_dir = os.path.join(data_dir, "images")
        self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_name)

        # Load captions.json ({image_id: {"captions": [...]}}) and splits.json
        # ({"train": [...], "val": [...], "test": [...]}).
        captions_file = os.path.join(data_dir, "captions.json")
        splits_file = os.path.join(data_dir, "splits.json")
        with open(captions_file) as f:
            self.captions_data = json.load(f)
        with open(splits_file) as f:
            self.splits = json.load(f)

        # Cast IDs to strings so they match the keys in captions.json.
        self.image_ids = [str(i) for i in self.splits[split]]

        self.transform = build_coco_transform(image_size=image_size)

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        # COCO image files are named with zero-padded 12-digit IDs.
        img_path = os.path.join(self.images_dir, f"{int(image_id):012d}.jpg")
        img = Image.open(img_path).convert("RGB")
        pixel_values = self.transform(img)

        captions = self.captions_data[image_id]["captions"]
        if self.random_caption:
            caption = random.choice(captions)
        else:
            caption = captions[0]  # deterministic for eval

        # Tokenize caption (no prefix needed for T5-small).
        encoding = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_caption_length,
            return_tensors="pt",
        )
        input_ids = encoding.input_ids.squeeze(0)
        attention_mask = encoding.attention_mask.squeeze(0)

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "image_id": image_id,
        }
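

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the dataset itself. It assumes the
# processed files described above exist under data/processed/ and that
# build_coco_transform returns a (3, H, W) float tensor. Loads one batch and
# prints the shapes a training loop should expect with the default
# image_size=224 and max_caption_length=64.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = CocoCaptionDataset(split="val", random_caption=False)
    loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

    batch = next(iter(loader))
    print(batch["pixel_values"].shape)    # expected: torch.Size([4, 3, 224, 224])
    print(batch["input_ids"].shape)       # expected: torch.Size([4, 64])
    print(batch["attention_mask"].shape)  # expected: torch.Size([4, 64])
    print(batch["image_id"][:2])          # string IDs, collated into a list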