File size: 10,742 Bytes
ee1d4aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import os
import json
import pickle
import random
from collections import Counter
from tqdm import tqdm
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from .utils import get_logger, get_eval_transform # Import logger and transforms from utils

# Module-level logger shared by the vocabulary and dataset classes below;
# created through the project's get_logger helper, keyed by this module's name.
logger = get_logger(__name__)


class COCOVocabulary:
    """
    Vocabulary builder for COCO captions.

    Handles tokenization, building word-to-index and index-to-word mappings,
    and converting captions to lists of numerical indices and back.
    """

    # Special tokens, registered before any regular word so that on a fresh
    # vocabulary they receive the lowest indices in exactly this order.
    SPECIAL_TOKENS = ('<PAD>', '<START>', '<END>', '<UNK>')

    def __init__(self, min_word_freq=5):
        """
        Initializes the COCOVocabulary.
        Args:
            min_word_freq (int): Minimum frequency for a word to be included in the vocabulary.
                                 Words less frequent than this will be replaced by <UNK>.
        """
        self.min_word_freq = min_word_freq
        self.word2idx = {}  # Maps words to their numerical indices
        self.idx2word = {}  # Maps numerical indices back to words
        self.word_freq = Counter()  # Counts frequency of each word
        self.vocab_size = 0  # Number of entries in word2idx (special tokens included)

    def _add_word(self, word):
        """Registers `word` under the next free index unless it is already known."""
        if word not in self.word2idx:
            self.word2idx[word] = len(self.word2idx)
            self.idx2word[len(self.idx2word)] = word

    def build_vocabulary(self, captions):
        """
        Builds the vocabulary from a list of captions.

        Word counts are accumulated into `self.word_freq`, so calling this
        method again adds to the counts from earlier calls.
        Args:
            captions (list): A list of strings, where each string is a caption.
        """
        logger.info("Building vocabulary...")

        # 1. Count word frequencies
        for caption in tqdm(captions, desc="Counting word frequencies"):
            self.word_freq.update(self.tokenize(caption))

        # 2. Add special tokens (skipped by _add_word if already present)
        for token in self.SPECIAL_TOKENS:
            self._add_word(token)

        # 3. Add words that meet the minimum frequency threshold
        for word, freq in self.word_freq.items():
            if freq >= self.min_word_freq:
                self._add_word(word)

        self.vocab_size = len(self.word2idx)
        logger.info(f"Vocabulary built successfully. Size: {self.vocab_size}")

    def tokenize(self, caption):
        """
        Simple tokenization: convert to lowercase and split on whitespace.

        `str.split()` without arguments already drops leading/trailing
        whitespace and collapses runs of whitespace, so no separate
        normalization pass is needed. Punctuation is left attached to words.
        Args:
            caption (str): The input caption string.
        Returns:
            list: A list of tokenized words.
        """
        return caption.lower().split()

    def caption_to_indices(self, caption, max_length=20):
        """
        Converts a caption string into a list of numerical indices.
        Adds <START> and <END> tokens and pads with <PAD> up to max_length.

        At most `max_length - 2` caption words are kept so that <START> and
        <END> always fit. For `max_length < 2` the result is simply cut to
        `max_length` entries and therefore may not contain <END>.
        Args:
            caption (str): The input caption string.
            max_length (int): The maximum desired length for the indexed caption.
        Returns:
            list: A list of integer indices representing the caption.
        Raises:
            KeyError: If the special tokens have not been added to the vocabulary yet.
        """
        unk_idx = self.word2idx['<UNK>']
        # Reserve two slots for <START> and <END>; never use a negative budget.
        word_budget = max(max_length - 2, 0)
        kept_tokens = self.tokenize(caption)[:word_budget]

        indices = [self.word2idx['<START>']]
        # Words missing from the vocabulary map to <UNK>
        indices.extend(self.word2idx.get(token, unk_idx) for token in kept_tokens)
        indices.append(self.word2idx['<END>'])

        # Pad with <PAD> tokens if the caption is shorter than max_length
        # (a non-positive multiplier yields an empty list).
        indices.extend([self.word2idx['<PAD>']] * (max_length - len(indices)))

        return indices[:max_length]  # Ensure the caption does not exceed max_length

    def indices_to_caption(self, indices):
        """
        Converts a sequence of numerical indices back into a human-readable caption string.
        Stops at <END> token and ignores <PAD> and <START> tokens. Indices that
        are not in the vocabulary are rendered as '<UNK>'.

        Elements that expose an `.item()` method (NumPy scalars, zero-dimensional
        tensors) are unwrapped to plain Python numbers first. Without this,
        objects that hash by identity never match the integer keys of
        `idx2word`, and every position would decode to '<UNK>'.
        Args:
            indices (iterable): Integer indices, e.g. a list, a numpy.ndarray or
                                a 1-D tensor.
        Returns:
            str: The reconstructed caption string.
        """
        words = []
        for idx in indices:
            if hasattr(idx, 'item'):
                idx = idx.item()  # Unwrap scalar containers to a plain Python number
            word = self.idx2word.get(idx, '<UNK>')  # Get word, default to <UNK>
            if word == '<END>':
                break  # Stop decoding when <END> token is encountered
            if word not in ('<PAD>', '<START>'):  # Ignore special tokens
                words.append(word)
        return ' '.join(words)


class COCODataset(Dataset):
    """
    PyTorch Dataset for COCO Image Captioning.
    Loads image paths and their corresponding captions,
    and returns preprocessed images together with indexed caption tensors.
    """
    def __init__(self, image_dir, caption_file, vocabulary=None,
                 max_caption_length=20, subset_size=None, transform=None):
        """
        Initializes the COCODataset.
        Args:
            image_dir (str): Path to the directory containing COCO images (e.g., 'train2017', 'val2017').
            caption_file (str): Path to the COCO captions JSON file (e.g., 'captions_train2017.json').
            vocabulary (COCOVocabulary, optional): A pre-built COCOVocabulary object. If None,
                                                   a new vocabulary will be built from the captions
                                                   that remain after filtering and subsetting.
            max_caption_length (int): Maximum length for indexed captions.
            subset_size (int, optional): If specified, uses a random subset of this size from the dataset.
            transform (callable, optional): Image transformations to apply. Defaults to
                                            the result of get_eval_transform().
        Raises:
            FileNotFoundError: If caption_file does not exist.
            json.JSONDecodeError: If caption_file is not valid JSON.
        """
        self.image_dir = image_dir
        self.max_caption_length = max_caption_length
        self.transform = transform if transform is not None else get_eval_transform()  # Default transform

        self.coco_data = self._load_caption_file(caption_file)

        # Create a mapping from image ID to its filename for quick lookup
        self.id_to_filename = {
            img_info['id']: img_info['file_name'] for img_info in self.coco_data['images']
        }

        # List of dicts with keys 'image_path', 'caption', 'image_id'
        self.data = self._collect_samples()

        # If subset_size is specified, take a random sample
        if subset_size and subset_size < len(self.data):
            self.data = random.sample(self.data, subset_size)
            logger.info(f"Using subset of {subset_size} samples for the dataset.")

        logger.info(f"Dataset size after filtering: {len(self.data)} samples.")

        # Build vocabulary if not provided
        if vocabulary is None:
            self.vocabulary = COCOVocabulary()
            self.vocabulary.build_vocabulary([item['caption'] for item in self.data])
        else:
            self.vocabulary = vocabulary

    @staticmethod
    def _load_caption_file(caption_file):
        """Reads and parses the captions JSON file, logging and re-raising on failure."""
        try:
            # JSON text is UTF-8; do not depend on the platform's default encoding.
            with open(caption_file, 'r', encoding='utf-8') as f:
                coco_data = json.load(f)
        except FileNotFoundError:
            logger.error(f"Caption file not found at {caption_file}. Please check the path.")
            raise
        except json.JSONDecodeError:
            logger.error(f"Error decoding JSON from {caption_file}. Ensure it's a valid JSON file.")
            raise
        logger.info(f"Successfully loaded captions from {caption_file}")
        return coco_data

    def _collect_samples(self):
        """Pairs every annotation with its image path, skipping annotations whose image is unavailable."""
        samples = []
        missing_image_files = 0
        # image_id -> (full path, exists on disk). Several annotations can refer
        # to the same image, so the filesystem is only queried once per image.
        resolved = {}

        for ann in tqdm(self.coco_data['annotations'], desc="Processing annotations"):
            image_id = ann['image_id']
            if image_id not in self.id_to_filename:
                logger.warning(f"Image ID {image_id} not found in images list. Skipping annotation.")
                continue

            caption = ann['caption']
            if image_id not in resolved:
                full_path = os.path.join(self.image_dir, self.id_to_filename[image_id])
                resolved[image_id] = (full_path, os.path.exists(full_path))
            image_full_path, found = resolved[image_id]

            if not found:
                missing_image_files += 1
                continue

            samples.append({
                'image_path': image_full_path,
                'caption': caption,
                'image_id': image_id  # Store original image_id for evaluation
            })

        if missing_image_files > 0:
            logger.warning(f"Skipped {missing_image_files} annotations due to missing image files. "
                           "Please ensure all images are in the specified directory.")
        return samples

    def _caption_length(self, caption_indices):
        """Returns the number of positions up to and including <END>, or the non-<PAD> count if <END> is absent."""
        word2idx = self.vocabulary.word2idx
        try:
            # Find the index of <END> token, length is (index + 1)
            return caption_indices.index(word2idx['<END>']) + 1
        except ValueError:
            # <END> can be missing when the caption was cut very short;
            # fall back to counting everything that is not padding.
            pad_idx = word2idx['<PAD>']
            return sum(1 for idx in caption_indices if idx != pad_idx)

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieves an item from the dataset at the given index.
        Args:
            idx (int): Position of the sample in self.data.
        Returns:
            tuple: (image, caption_tensor, caption_length, image_id) where
                   caption_tensor is a long tensor of length max_caption_length,
                   caption_length is a zero-dimensional long tensor, and image is
                   whatever self.transform returns (or a zero tensor on failure).
        """
        item = self.data[idx]

        # Load and transform image
        try:
            image = Image.open(item['image_path']).convert('RGB')
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            # Deliberate best-effort: one unreadable file should not abort a whole epoch.
            logger.error(f"Error loading image {item['image_path']}: {e}. Returning a black image as fallback.")
            # Return a black image tensor of expected size (3, 224, 224) if image loading fails
            # NOTE(review): assumes the transform produces 3x224x224 outputs - confirm.
            image = torch.zeros(3, 224, 224)

        # Convert caption to indices
        caption_indices = self.vocabulary.caption_to_indices(
            item['caption'], self.max_caption_length
        )
        caption_tensor = torch.tensor(caption_indices, dtype=torch.long)

        # Actual length of the caption (excluding padding, including START/END)
        caption_length = torch.tensor(self._caption_length(caption_indices), dtype=torch.long)

        # Return image, caption tensor, actual caption length, and original image ID
        return image, caption_tensor, caption_length, item['image_id']