import os
import json
from collections import Counter

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torchvision.transforms as transforms
import spacy

# ===== Load spaCy English tokenizer =====
spacy_eng = spacy.load("en_core_web_sm")


class Vocabulary:
    def __init__(self, freq_threshold):
        """
        freq_threshold: minimum word frequency to keep in vocab
        """
        self.freq_threshold = freq_threshold
        # Special tokens: <PAD> must stay at index 0, since the
        # collate_fn below pads with padding_value=0
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """
        Uses the spaCy tokenizer to split a sentence into a list of tokens
        """
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """
        Builds the vocab: {word -> index} for all words with freq >= threshold
        """
        frequencies = Counter()
        idx = 4  # start indexing after the special tokens

        for sentence in sentence_list:
            frequencies.update(self.tokenizer_eng(sentence))

        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = idx
                self.itos[idx] = word
                idx += 1

    def numericalize(self, text):
        """
        Converts a text caption to a list of vocab indices;
        out-of-vocabulary words map to <UNK>
        """
        tokenized_text = self.tokenizer_eng(text)
        return [
            self.stoi.get(token, self.stoi["<UNK>"])
            for token in tokenized_text
        ]


class CaptionDataset(Dataset):
    def __init__(self, images_dir, captions_file, vocab, transform=None):
        """
        images_dir: path to images/train or images/val
        captions_file: JSON file with "images" and "annotations" entries
        vocab: Vocabulary object
        transform: torchvision transform
        """
        self.images_dir = images_dir
        self.vocab = vocab
        self.transform = transform

        # Load JSON
        with open(captions_file, "r") as f:
            data = json.load(f)

        self.images = data["images"]
        self.annotations = data["annotations"]

        # Create map: image_id -> file_name
        self.id_to_filename = {img["id"]: img["file_name"] for img in self.images}

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        ann = self.annotations[index]
        image_id = ann["image_id"]
        caption = ann["caption"]

        # Build image path
        img_path = os.path.join(self.images_dir, self.id_to_filename[image_id])

        # Open image
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Numericalize caption and wrap it in <SOS> and <EOS> tokens
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return image, torch.tensor(numericalized_caption)


def build_vocab_from_json(captions_file, freq_threshold):
    """
    Builds a Vocabulary object from the captions JSON file.
    """
    with open(captions_file, "r") as f:
        data = json.load(f)

    all_captions = [ann["caption"] for ann in data["annotations"]]

    vocab = Vocabulary(freq_threshold)
    vocab.build_vocabulary(all_captions)
    return vocab
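# --- Illustration: Vocabulary round trip ---
# A minimal sketch of how build_vocabulary()/numericalize() behave; the
# sentences below are made-up examples, not drawn from the dataset.
#
#   v = Vocabulary(freq_threshold=1)
#   v.build_vocabulary(["a dog runs", "a cat sleeps"])
#   v.numericalize("a dog sleeps")   # -> [4, 5, 8]
#   v.numericalize("a bird sleeps")  # -> [4, 3, 8]; "bird" is unseen -> <UNK>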
""" images = [] captions = [] for img, cap in batch: images.append(img) captions.append(cap) images = torch.stack(images, dim=0) captions = pad_sequence(captions, batch_first=True, padding_value=0) # pad with token idx 0 return images, captions # ====== Test block ====== if __name__ == "__main__": # === Paths === captions_train_json = "./Dataset/annotations/captions_train.json" images_train_dir = "./Dataset/images/train/" # === Build vocab === vocab = build_vocab_from_json(captions_train_json, freq_threshold=2) print(f"Vocab size: {len(vocab)}") # === Transforms === transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor() ]) # === Create dataset === train_dataset = CaptionDataset( images_dir=images_train_dir, captions_file=captions_train_json, vocab=vocab, transform=transform ) # === DataLoader with custom collate_fn === train_loader = DataLoader( dataset=train_dataset, batch_size=4, shuffle=True, collate_fn=my_collate_fn # ✅ REQUIRED for variable-length captions ) # === Test loop === for idx, (images, captions) in enumerate(train_loader): print(f"\nBatch {idx + 1}") print("Images shape:", images.shape) # [B, 3, H, W] print("Captions shape:", captions.shape) # [B, T] (padded) print("Sample caption:", captions[0]) break # one batch test only