import os
import json
from PIL import Image
from collections import Counter
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torchvision.transforms as transforms
import spacy
# ===== Load spaCy English tokenizer =====
# (the model must be installed first: python -m spacy download en_core_web_sm)
spacy_eng = spacy.load("en_core_web_sm")

class Vocabulary:
    def __init__(self, freq_threshold):
        """
        freq_threshold: minimum word frequency for a word to be kept in the vocab
        """
        self.freq_threshold = freq_threshold
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """
        Uses the spaCy tokenizer to split a sentence into lowercase tokens.
        """
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """
        Builds the vocab: {word -> index} for all words with freq >= threshold.
        """
        frequencies = Counter()
        idx = 4  # start indexing after the special tokens

        for sentence in sentence_list:
            frequencies.update(self.tokenizer_eng(sentence))

        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = idx
                self.itos[idx] = word
                idx += 1

    def numericalize(self, text):
        """
        Converts a text caption to a list of vocab indices;
        out-of-vocabulary tokens map to <unk>.
        """
        tokenized_text = self.tokenizer_eng(text)
        return [
            self.stoi.get(token, self.stoi["<unk>"])
            for token in tokenized_text
        ]
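
# A quick, self-contained sanity check for Vocabulary (illustrative only and
# never called; the example sentences below are made up):
def _demo_vocabulary():
    vocab = Vocabulary(freq_threshold=1)
    vocab.build_vocabulary(["a dog runs", "a dog sleeps"])
    # The unseen word "flies" falls back to <unk> (index 3)
    return vocab.numericalize("a dog flies")
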
class CaptionDataset(Dataset):
    def __init__(self, images_dir, captions_file, vocab, transform=None):
        """
        images_dir: path to images/train or images/val
        captions_file: JSON file with COCO-style "images" and "annotations" lists
        vocab: Vocabulary object
        transform: torchvision transform applied to each image
        """
        self.images_dir = images_dir
        self.vocab = vocab
        self.transform = transform

        # Load JSON annotations
        with open(captions_file, 'r') as f:
            data = json.load(f)

        self.images = data["images"]
        self.annotations = data["annotations"]

        # Create map: image_id -> file_name
        self.id_to_filename = {img["id"]: img["file_name"] for img in self.images}

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        ann = self.annotations[index]
        image_id = ann["image_id"]
        caption = ann["caption"]

        # Build the image path and open the image
        img_path = os.path.join(self.images_dir, self.id_to_filename[image_id])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Numericalize the caption and wrap it in <start> ... <end> tokens
        numericalized_caption = [self.vocab.stoi["<start>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<end>"])

        return image, torch.tensor(numericalized_caption)
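
# For reference, the dataset above expects a COCO-style JSON layout; a minimal
# captions file would look like this (the ids and file names are made up):
#
#   {
#       "images":      [{"id": 1, "file_name": "0001.jpg"}],
#       "annotations": [{"image_id": 1, "caption": "a dog running on a beach"}]
#   }
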
def build_vocab_from_json(captions_file, freq_threshold):
"""
Builds Vocabulary object from JSON file.
"""
with open(captions_file, 'r') as f:
data = json.load(f)
all_captions = [ann["caption"] for ann in data["annotations"]]
vocab = Vocabulary(freq_threshold)
vocab.build_vocabulary(all_captions)
return vocab
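
# Note: build the vocabulary once from the *training* captions and reuse the
# same Vocabulary object for validation/test datasets, so that token indices
# stay consistent across splits.
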
def my_collate_fn(batch):
"""
Custom collate_fn for variable-length captions:
Pads captions in batch to max length in batch.
"""
images = []
captions = []
for img, cap in batch:
images.append(img)
captions.append(cap)
images = torch.stack(images, dim=0)
captions = pad_sequence(captions, batch_first=True, padding_value=0) # pad with <pad> token idx 0
return images, captions
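
# Illustration of the padding behaviour on toy tensors (never called; the
# token values below are made up):
def _demo_collate_padding():
    a = torch.tensor([1, 5, 2])     # 3 tokens
    b = torch.tensor([1, 6, 7, 2])  # 4 tokens
    padded = pad_sequence([a, b], batch_first=True, padding_value=0)
    # padded == tensor([[1, 5, 2, 0],
    #                   [1, 6, 7, 2]])
    return padded
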
# ====== Test block ======
if __name__ == "__main__":
    # === Paths ===
    captions_train_json = "./Dataset/annotations/captions_train.json"
    images_train_dir = "./Dataset/images/train/"

    # === Build vocab ===
    vocab = build_vocab_from_json(captions_train_json, freq_threshold=2)
    print(f"Vocab size: {len(vocab)}")

    # === Transforms ===
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    # === Create dataset ===
    train_dataset = CaptionDataset(
        images_dir=images_train_dir,
        captions_file=captions_train_json,
        vocab=vocab,
        transform=transform
    )

    # === DataLoader with custom collate_fn ===
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=my_collate_fn  # required for variable-length captions
    )

    # === Test loop ===
    for idx, (images, captions) in enumerate(train_loader):
        print(f"\nBatch {idx + 1}")
        print("Images shape:", images.shape)      # [B, 3, H, W]
        print("Captions shape:", captions.shape)  # [B, T] (padded)
        print("Sample caption:", captions[0])
        break  # test a single batch only
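
    # To read a padded caption back as text, map indices through vocab.itos,
    # skipping the <pad> index 0, e.g.:
    #   words = [vocab.itos[int(i)] for i in captions[0] if int(i) != 0]
    #   print(" ".join(words))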