# File size: 5,445 Bytes
# Revision: d1e6a4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
import json
from PIL import Image
from collections import Counter

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torchvision.transforms as transforms
import spacy

# ===== Load spaCy English tokenizer =====
# Loaded once, at import time, and shared by everything in this module.
# In this file only its `.tokenizer` is used (see Vocabulary.tokenizer_eng).
# Needs the `en_core_web_sm` model package to be available to spaCy.
# NOTE(review): the full pipeline is loaded although only the tokenizer is
# used here -- consider disabling unused components if import time matters.
spacy_eng = spacy.load("en_core_web_sm")


class Vocabulary:
    """Bidirectional token <-> index mapping for caption text.

    Indices 0-3 are reserved for the special tokens ``<pad>``, ``<start>``,
    ``<end>`` and ``<unk>``; corpus tokens are assigned indices after them.
    """

    def __init__(self, freq_threshold):
        """
        Create a vocabulary that initially holds only the special tokens.

        freq_threshold: minimum number of occurrences (counted within one
            ``build_vocabulary`` call) a token needs in order to be added.
        """
        self.freq_threshold = freq_threshold

        # index -> token and token -> index; build_vocabulary keeps them in sync
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def __len__(self):
        """Return the number of tokens, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """
        Split ``text`` into lower-cased tokens using the spaCy tokenizer.
        """
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """
        Add every token occurring at least ``freq_threshold`` times.

        sentence_list: iterable of raw caption strings.

        New tokens get consecutive indices starting at ``len(self)``, in order
        of first appearance. Tokens already present (special tokens included)
        keep their existing index, so repeated calls extend the vocabulary
        instead of overwriting entries. Frequencies are counted per call and
        are not accumulated across calls.
        """
        frequencies = Counter()
        for sentence in sentence_list:
            frequencies.update(self.tokenizer_eng(sentence))

        # Continue after the last assigned index instead of a fixed 4, so a
        # second call cannot reassign indices that are already in use.
        next_idx = len(self.itos)
        for word, freq in frequencies.items():
            if freq >= self.freq_threshold and word not in self.stoi:
                self.stoi[word] = next_idx
                self.itos[next_idx] = word
                next_idx += 1

    def numericalize(self, text):
        """
        Convert a caption string into a list of vocabulary indices.

        Tokens missing from the vocabulary map to the ``<unk>`` index.
        ``<start>``/``<end>`` markers are NOT added here.
        """
        unk_idx = self.stoi["<unk>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]


class CaptionDataset(Dataset):
    """Map-style dataset yielding ``(image, caption_indices)`` pairs.

    There is one item per entry in the JSON ``"annotations"`` list, so an
    image with several captions appears once per caption.
    """

    def __init__(self, images_dir, captions_file, vocab, transform=None):
        """
        images_dir: directory containing the image files
            (e.g. images/train or images/val)
        captions_file: JSON file with an ``"images"`` list (entries carry
            ``id`` and ``file_name``) and an ``"annotations"`` list (entries
            carry ``image_id`` and ``caption``)
        vocab: Vocabulary object used to numericalize captions
        transform: optional callable applied to each RGB PIL image
            (e.g. a torchvision transform)
        """
        self.images_dir = images_dir
        self.vocab = vocab
        self.transform = transform

        # Explicit UTF-8: captions may contain non-ASCII text, and open()'s
        # default encoding is platform-dependent.
        with open(captions_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        self.images = data["images"]
        self.annotations = data["annotations"]

        # image_id -> file_name, used by __getitem__ to locate the image file
        self.id_to_filename = {img["id"]: img["file_name"] for img in self.images}

    def __len__(self):
        """Return the number of caption annotations."""
        return len(self.annotations)

    def __getitem__(self, index):
        """
        Return ``(image, caption)`` for annotation ``index``.

        image: the RGB PIL image, or ``transform(image)`` when a transform
            was given
        caption: 1-D ``torch.long`` tensor ``[<start>, tokens..., <end>]``
        """
        ann = self.annotations[index]
        image_id = ann["image_id"]
        caption = ann["caption"]

        img_path = os.path.join(self.images_dir, self.id_to_filename[image_id])

        # The context manager closes the file handle. convert() returns a
        # new, fully loaded image, so it stays usable after the file closes.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        numericalized_caption = (
            [self.vocab.stoi["<start>"]]
            + self.vocab.numericalize(caption)
            + [self.vocab.stoi["<end>"]]
        )
        return image, torch.tensor(numericalized_caption, dtype=torch.long)


def build_vocab_from_json(captions_file, freq_threshold):
    """
    Build a Vocabulary from every caption in a JSON captions file.

    captions_file: JSON file whose ``"annotations"`` entries each carry a
        ``"caption"`` string
    freq_threshold: minimum token frequency required to enter the vocabulary

    Returns the populated Vocabulary.
    """
    # Explicit UTF-8: captions may contain non-ASCII text, and open()'s
    # default encoding is platform-dependent.
    with open(captions_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    all_captions = [ann["caption"] for ann in data["annotations"]]

    vocab = Vocabulary(freq_threshold)
    vocab.build_vocabulary(all_captions)

    return vocab


def my_collate_fn(batch):
    """
    Collate a list of (image, caption) pairs into two batch tensors.

    Images are stacked along a new leading dimension, so they must all share
    one shape. 1-D caption tensors are right-padded with 0 (the ``<pad>``
    index) up to the longest caption in the batch, giving (batch, max_len).
    """
    samples = list(batch)
    image_batch = torch.stack([img for img, _ in samples], dim=0)
    # 0 is the "<pad>" index in Vocabulary
    caption_batch = pad_sequence(
        [cap for _, cap in samples], batch_first=True, padding_value=0
    )
    return image_batch, caption_batch


# ====== Test block ======
if __name__ == "__main__":
    # Smoke test: build the vocabulary from the training captions, wrap the
    # training split in a DataLoader and print the shapes of one batch.
    captions_train_json = "./Dataset/annotations/captions_train.json"
    images_train_dir = "./Dataset/images/train/"

    vocab = build_vocab_from_json(captions_train_json, freq_threshold=2)
    print(f"Vocab size: {len(vocab)}")

    # Fixed-size resize so every image tensor in a batch can be stacked
    transform = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()]
    )

    train_dataset = CaptionDataset(
        images_dir=images_train_dir,
        captions_file=captions_train_json,
        vocab=vocab,
        transform=transform,
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=4,
        shuffle=True,
        # Captions differ in length, so the default collate cannot stack them.
        collate_fn=my_collate_fn,
    )

    # Inspect only the first batch; an empty loader prints nothing.
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        idx = 0
        images, captions = first_batch
        print(f"\nBatch {idx + 1}")
        print("Images shape:", images.shape)      # [B, 3, 224, 224]
        print("Captions shape:", captions.shape)  # [B, T], padded to batch max
        print("Sample caption:", captions[0])