import os
import json
from collections import Counter

from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torchvision.transforms as transforms

import spacy

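# The small English spaCy model has to be installed once, e.g.:
#   python -m spacy download en_core_web_sm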
spacy_eng = spacy.load("en_core_web_sm")


class Vocabulary:
    def __init__(self, freq_threshold):
        """
        freq_threshold: minimum word frequency to keep in the vocab
        """
        self.freq_threshold = freq_threshold

        # Special tokens take indices 0-3; corpus words are added from index 4.
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """
        Uses the spaCy tokenizer to split a sentence into lowercase tokens.
        """
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """
        Builds the vocab: {word -> index} for all words with freq >= threshold.
        """
        frequencies = Counter()
        idx = 4

        for sentence in sentence_list:
            tokens = self.tokenizer_eng(sentence)
            frequencies.update(tokens)

        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = idx
                self.itos[idx] = word
                idx += 1

    def numericalize(self, text):
        """
        Converts a text caption to a list of vocab indices.
        """
        tokenized_text = self.tokenizer_eng(text)
        return [
            self.stoi.get(token, self.stoi["<unk>"])
            for token in tokenized_text
        ]
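

# Illustrative usage (the exact indices depend on the caption corpus):
#   v = Vocabulary(freq_threshold=1)
#   v.build_vocabulary(["a dog runs", "a dog sleeps"])
#   v.numericalize("a dog flies")   # "flies" is out-of-vocab -> stoi["<unk>"]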


class CaptionDataset(Dataset):
    def __init__(self, images_dir, captions_file, vocab, transform=None):
        """
        images_dir: path to images/train or images/val
        captions_file: COCO-style JSON file with "images" and "annotations" entries
        vocab: Vocabulary object
        transform: torchvision transform applied to each image
        """
        self.images_dir = images_dir
        self.vocab = vocab
        self.transform = transform

        with open(captions_file, "r") as f:
            data = json.load(f)

        self.images = data["images"]
        self.annotations = data["annotations"]

        # Map image id -> file name for fast lookup in __getitem__.
        self.id_to_filename = {img["id"]: img["file_name"] for img in self.images}

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        ann = self.annotations[index]
        image_id = ann["image_id"]
        caption = ann["caption"]

        img_path = os.path.join(self.images_dir, self.id_to_filename[image_id])
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        # <start> + numericalized caption + <end>
        numericalized_caption = [self.vocab.stoi["<start>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<end>"])

        return image, torch.tensor(numericalized_caption)


def build_vocab_from_json(captions_file, freq_threshold):
    """
    Builds a Vocabulary object from all captions in the JSON file.
    """
    with open(captions_file, "r") as f:
        data = json.load(f)

    all_captions = [ann["caption"] for ann in data["annotations"]]

    vocab = Vocabulary(freq_threshold)
    vocab.build_vocabulary(all_captions)

    return vocab


def my_collate_fn(batch):
    """
    Custom collate_fn for variable-length captions:
    pads every caption in the batch to the longest caption in that batch.
    """
    images = []
    captions = []

    for img, cap in batch:
        images.append(img)
        captions.append(cap)

    images = torch.stack(images, dim=0)
    # padding_value=0 matches the index of "<pad>" in the vocabulary.
    captions = pad_sequence(captions, batch_first=True, padding_value=0)

    return images, captions


if __name__ == "__main__":
    captions_train_json = "./Dataset/annotations/captions_train.json"
    images_train_dir = "./Dataset/images/train/"

    # Build the vocabulary from the training captions only.
    vocab = build_vocab_from_json(captions_train_json, freq_threshold=2)
    print(f"Vocab size: {len(vocab)}")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    train_dataset = CaptionDataset(
        images_dir=images_train_dir,
        captions_file=captions_train_json,
        vocab=vocab,
        transform=transform
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=my_collate_fn
    )

    # Inspect a single batch to verify shapes and padding.
    for idx, (images, captions) in enumerate(train_loader):
        print(f"\nBatch {idx + 1}")
        print("Images shape:", images.shape)
        print("Captions shape:", captions.shape)
        print("Sample caption:", captions[0])
        break
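    # Optional sanity check (illustrative): decode the first caption of the batch
    # back into words via vocab.itos, skipping "<pad>" (index 0) for readability.
    decoded = [vocab.itos[int(tok)] for tok in captions[0] if int(tok) != vocab.stoi["<pad>"]]
    print("Decoded caption:", " ".join(decoded))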