# ==========================================
# Data Processing & Vocabulary
# ==========================================
import os
import json
import re
import random
from collections import Counter, defaultdict

import torch
from torch.utils.data import Dataset

# Pillow is only required when images are actually loaded; guard the import so
# the vocabulary/tokenization utilities stay usable without it installed.
try:
    from PIL import Image
except ImportError:  # pragma: no cover
    Image = None


class Vocabulary:
    """Word-level vocabulary with reserved special tokens.

    Reserved ids: 0=<pad>, 1=<bos>, 2=<eos>, 3=<unk>. Corpus words are added
    starting at index 4 by build_vocab(), subject to a frequency threshold.
    """

    def __init__(self, freq_threshold=2):
        # Minimum corpus frequency for a word to enter the vocabulary.
        self.freq_threshold = freq_threshold
        # BUG FIX: these four tokens were all empty strings, which collapsed
        # the word2idx literal into the single entry {"": 3} and broke the
        # reserved-id scheme (pad id 0 was never actually assigned). Distinct
        # marker tokens are restored; clean_text() removes '<' and '>' so they
        # can never collide with real corpus words.
        self.pad_token = "<pad>"
        self.bos_token = "<bos>"
        self.eos_token = "<eos>"
        self.unk_token = "<unk>"
        self.word2idx = {
            self.pad_token: 0,
            self.bos_token: 1,
            self.eos_token: 2,
            self.unk_token: 3,
        }
        self.idx2word = {v: k for k, v in self.word2idx.items()}

    def __len__(self):
        return len(self.word2idx)

    def clean_text(self, text):
        """Lowercase, drop every non-[a-z] non-space char, collapse whitespace."""
        text = text.lower()
        text = re.sub(r"[^a-z\s]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        return text

    def build_vocab(self, captions):
        """Count words over *captions* and register those meeting freq_threshold.

        Indices continue from the current vocabulary size, so special tokens
        (and any previously added words) keep their ids.
        """
        counter = Counter()
        for cap in captions:
            counter.update(self.clean_text(cap).split())
        idx = len(self.word2idx)
        for word, freq in counter.items():
            if freq >= self.freq_threshold:
                self.word2idx[word] = idx
                self.idx2word[idx] = word
                idx += 1


def tokenize_caption(caption, vocab, max_len):
    """Tokenize a caption string into a list of vocabulary ids.

    Returns at most max_len ids: <bos> + up to (max_len - 2) word ids + <eos>.
    Out-of-vocabulary words map to the <unk> id.
    """
    tokens = vocab.clean_text(caption).split()
    # Truncate to max_len - 2 to leave room for <bos> and <eos>.
    tokens = tokens[: max_len - 2]
    unk_id = vocab.word2idx[vocab.unk_token]
    ids = [vocab.word2idx.get(tok, unk_id) for tok in tokens]
    return (
        [vocab.word2idx[vocab.bos_token]] + ids + [vocab.word2idx[vocab.eos_token]]
    )


class CocoCaptionDataset(Dataset):
    """COCO-format captioning dataset.

    Yields one sample per image; a random caption for that image is chosen on
    every __getitem__ call (a cheap augmentation strategy).
    """

    def __init__(self, images_dir, annotation_file, vocab, max_len=50, transform=None):
        """
        Args:
            images_dir: Directory containing the image files.
            annotation_file: Path to a COCO-style JSON with "images" and
                "annotations" lists.
            vocab: A Vocabulary used to tokenize captions.
            max_len: Maximum caption length in ids (including <bos>/<eos>).
            transform: Optional callable applied to each PIL image.
        """
        self.images_dir = images_dir
        self.transform = transform
        self.vocab = vocab
        self.max_len = max_len

        # Load the COCO-format annotation JSON.
        print(f"Loading annotations from {annotation_file}...")
        with open(annotation_file, "r") as f:
            coco = json.load(f)

        self.images = {img["id"]: img for img in coco["images"]}

        # image_id -> list of captions
        self.image_id_to_captions = defaultdict(list)
        for ann in coco["annotations"]:
            self.image_id_to_captions[ann["image_id"]].append(ann["caption"])

        # Use one sample per image; pick a random caption each __getitem__.
        self.image_ids = list(self.image_id_to_captions.keys())

        total_captions = sum(len(caps) for caps in self.image_id_to_captions.values())
        print(
            f"Loaded {len(self.image_ids)} images with {total_captions} captions."
        )

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return (image, caption_ids, image_id) for the idx-th image."""
        if Image is None:
            raise ImportError("Pillow is required to load images (pip install Pillow)")

        image_id = self.image_ids[idx]
        captions = self.image_id_to_captions[image_id]
        caption = random.choice(captions)

        image_info = self.images[image_id]
        file_name = image_info["file_name"]

        # Load image
        image_path = os.path.join(self.images_dir, file_name)
        try:
            image = Image.open(image_path).convert("RGB")
        except FileNotFoundError:
            # Handle missing images (can happen in partial datasets) by
            # substituting a blank image rather than aborting the epoch.
            print(f"Warning: Image not found {image_path}")
            image = Image.new('RGB', (224, 224), (0, 0, 0))

        if self.transform:
            image = self.transform(image)

        # Tokenize caption
        caption_ids = tokenize_caption(caption, self.vocab, self.max_len)

        return image, torch.tensor(caption_ids, dtype=torch.long), image_id


def collate_fn(batch):
    """
    Custom collate function to handle variable length captions.
    Returns:
        images: (batch_size, 3, 224, 224)
        padded_captions: (batch_size, max_seq_len)
        padding_mask: (batch_size, max_seq_len) - False for token, True for pad
        image_ids: Tuple of image ids
    """
    images, captions, image_ids = zip(*batch)

    # Stack images
    images = torch.stack(images)

    # Pad captions with 0 (the <pad> id).
    lengths = [len(c) for c in captions]
    max_len = max(lengths)
    padded_captions = torch.zeros(len(captions), max_len, dtype=torch.long)
    padding_mask = torch.ones(len(captions), max_len, dtype=torch.bool)

    for i, cap in enumerate(captions):
        end = lengths[i]
        padded_captions[i, :end] = cap
        padding_mask[i, :end] = False

    return images, padded_captions, padding_mask, image_ids