import os
import json
import pickle
import random
from collections import Counter

from tqdm import tqdm
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from .utils import get_logger, get_eval_transform  # Import logger and transforms from utils

logger = get_logger(__name__)


class COCOVocabulary:
    """
    Vocabulary builder for COCO captions.

    Handles tokenization, building word-to-index and index-to-word mappings,
    and converting captions to numerical indices and back.
    """

    # Special tokens. build_vocabulary() inserts them before any real word and
    # in this order, so on a fresh vocabulary <PAD> receives index 0.
    PAD_TOKEN = '<PAD>'
    START_TOKEN = '<START>'
    END_TOKEN = '<END>'
    UNK_TOKEN = '<UNK>'
    SPECIAL_TOKENS = (PAD_TOKEN, START_TOKEN, END_TOKEN, UNK_TOKEN)

    def __init__(self, min_word_freq=5):
        """
        Initializes the COCOVocabulary.

        Args:
            min_word_freq (int): Minimum frequency for a word to be included in
                the vocabulary. Words less frequent than this are mapped to <UNK>.
        """
        self.min_word_freq = min_word_freq
        self.word2idx = {}  # Maps words to their numerical indices
        self.idx2word = {}  # Maps numerical indices back to words
        self.word_freq = Counter()  # Counts frequency of each word
        self.vocab_size = 0  # Number of entries in the vocabulary (special tokens included)

    def _add_word(self, word):
        """Assign the next free index to `word` unless it is already in the vocabulary."""
        if word not in self.word2idx:
            idx = len(self.word2idx)
            self.word2idx[word] = idx
            self.idx2word[idx] = word

    def build_vocabulary(self, captions):
        """
        Builds the vocabulary from a list of captions.

        Word counts accumulate in ``self.word_freq`` across calls and indices that
        were already assigned are never changed, so calling this again extends the
        existing vocabulary instead of rebuilding it from scratch.

        Args:
            captions (list): A list of strings, where each string is a caption.
        """
        logger.info("Building vocabulary...")

        # 1. Count word frequencies
        for caption in tqdm(captions, desc="Counting word frequencies"):
            self.word_freq.update(self.tokenize(caption))

        # 2. Add special tokens first so they get the lowest indices
        for token in self.SPECIAL_TOKENS:
            self._add_word(token)

        # 3. Add words that meet the minimum frequency threshold. Counter keeps
        #    first-seen insertion order, so the resulting mapping is deterministic.
        for word, freq in self.word_freq.items():
            if freq >= self.min_word_freq:
                self._add_word(word)

        self.vocab_size = len(self.word2idx)
        logger.info(f"Vocabulary built successfully. Size: {self.vocab_size}")

    def tokenize(self, caption):
        """
        Simple tokenization: lowercase the caption and split it on whitespace.

        ``str.split()`` without an argument already ignores leading/trailing
        whitespace and treats runs of spaces/tabs/newlines as one separator.

        Args:
            caption (str): The input caption string.

        Returns:
            list: A list of tokenized words.
        """
        return caption.lower().split()

    def caption_to_indices(self, caption, max_length=20):
        """
        Converts a caption string into a list of numerical indices.

        The result is ``<START> w1 ... wk <END>`` followed by ``<PAD>`` up to
        ``max_length``. At most ``max_length - 2`` words are kept so that
        ``<START>`` and ``<END>`` always fit; unknown words map to ``<UNK>``.

        Args:
            caption (str): The input caption string.
            max_length (int): The exact length of the returned list.

        Returns:
            list: A list of ``max_length`` integer indices.
        """
        unk_idx = self.word2idx[self.UNK_TOKEN]
        max_words = max(max_length - 2, 0)  # Reserve room for <START> and <END>

        word_indices = [self.word2idx.get(token, unk_idx)
                        for token in self.tokenize(caption)[:max_words]]
        indices = [self.word2idx[self.START_TOKEN], *word_indices, self.word2idx[self.END_TOKEN]]

        # Pad with <PAD> if the caption is shorter than max_length
        indices.extend([self.word2idx[self.PAD_TOKEN]] * (max_length - len(indices)))
        return indices[:max_length]  # Only trims when max_length < 2

    def indices_to_caption(self, indices):
        """
        Converts a sequence of numerical indices back into a caption string.

        Decoding stops at the first <END>; <START> and <PAD> are skipped and
        indices not present in the vocabulary decode to <UNK>.

        Args:
            indices (iterable): Integer indices, e.g. a list, a numpy array or a
                1-D torch tensor. Each element is normalised with ``int()``, so
                numpy scalars and 0-d tensors are looked up correctly.

        Returns:
            str: The reconstructed caption string.
        """
        skipped = {self.START_TOKEN, self.PAD_TOKEN}
        words = []
        for idx in indices:
            word = self.idx2word.get(int(idx), self.UNK_TOKEN)
            if word == self.END_TOKEN:
                break  # Everything after <END> is padding / garbage
            if word not in skipped:
                words.append(word)
        return ' '.join(words)


class COCODataset(Dataset):
    """
    PyTorch Dataset for COCO Image Captioning.

    Loads image paths and their corresponding captions, and returns preprocessed
    image tensors and indexed caption tensors.
    """

    # Shape of the placeholder tensor returned when an image cannot be loaded.
    # NOTE(review): assumes the transform produces 3x224x224 tensors; if it does
    # not, default batch collation will fail on a fallback sample — confirm
    # against get_eval_transform / the training transform.
    FALLBACK_IMAGE_SHAPE = (3, 224, 224)

    def __init__(self, image_dir, caption_file, vocabulary=None, max_caption_length=20,
                 subset_size=None, transform=None):
        """
        Initializes the COCODataset.

        Args:
            image_dir (str): Path to the directory containing COCO images
                (e.g., 'train2017', 'val2017').
            caption_file (str): Path to the COCO captions JSON file
                (e.g., 'captions_train2017.json').
            vocabulary (COCOVocabulary, optional): A pre-built COCOVocabulary object.
                If None, a new vocabulary is built from the (possibly subsetted) captions.
            max_caption_length (int): Length of every indexed caption.
            subset_size (int, optional): If specified, uses a random subset of this
                size from the dataset (drawn with the global ``random`` state).
            transform (torchvision.transforms.Compose, optional): Image transformations
                to apply. Defaults to ``get_eval_transform()``.

        Raises:
            FileNotFoundError: If ``caption_file`` does not exist.
            json.JSONDecodeError: If ``caption_file`` is not valid JSON.
        """
        self.image_dir = image_dir
        self.max_caption_length = max_caption_length
        self.transform = transform if transform is not None else get_eval_transform()  # Default transform

        self.coco_data = self._load_captions(caption_file)

        # Create a mapping from image ID to its filename for quick lookup
        self.id_to_filename = {img_info['id']: img_info['file_name']
                               for img_info in self.coco_data['images']}

        # List of {'image_path', 'caption', 'image_id'} dicts
        self.data = self._collect_samples(image_dir)

        # If subset_size is specified, take a random sample
        if subset_size and subset_size < len(self.data):
            self.data = random.sample(self.data, subset_size)
            logger.info(f"Using subset of {subset_size} samples for the dataset.")

        logger.info(f"Dataset size after filtering: {len(self.data)} samples.")

        # Build vocabulary if not provided
        if vocabulary is None:
            self.vocabulary = COCOVocabulary()
            self.vocabulary.build_vocabulary([item['caption'] for item in self.data])
        else:
            self.vocabulary = vocabulary

    @staticmethod
    def _load_captions(caption_file):
        """Load and return the parsed COCO captions JSON, logging and re-raising on failure."""
        try:
            # COCO annotation files are UTF-8; do not depend on the locale default.
            with open(caption_file, 'r', encoding='utf-8') as f:
                coco_data = json.load(f)
        except FileNotFoundError:
            logger.error(f"Caption file not found at {caption_file}. Please check the path.")
            raise
        except json.JSONDecodeError:
            logger.error(f"Error decoding JSON from {caption_file}. Ensure it's a valid JSON file.")
            raise
        logger.info(f"Successfully loaded captions from {caption_file}")
        return coco_data

    def _collect_samples(self, image_dir):
        """Pair every annotation with its image file, skipping unknown IDs and missing files."""
        samples = []
        missing_image_files = 0
        unknown_image_ids = 0
        # image_id -> full path, or None when the file is missing. Each image has
        # several captions, so this avoids repeating the filesystem check.
        resolved_paths = {}

        for ann in tqdm(self.coco_data['annotations'], desc="Processing annotations"):
            image_id = ann['image_id']
            filename = self.id_to_filename.get(image_id)
            if filename is None:
                unknown_image_ids += 1
                continue

            if image_id not in resolved_paths:
                full_path = os.path.join(image_dir, filename)
                resolved_paths[image_id] = full_path if os.path.exists(full_path) else None

            image_full_path = resolved_paths[image_id]
            if image_full_path is None:
                missing_image_files += 1
                continue

            samples.append({
                'image_path': image_full_path,
                'caption': ann['caption'],
                'image_id': image_id,  # Original image_id, kept for evaluation
            })

        if unknown_image_ids > 0:
            logger.warning(f"Skipped {unknown_image_ids} annotations whose image ID "
                           "was not found in the images list.")
        if missing_image_files > 0:
            logger.warning(f"Skipped {missing_image_files} annotations due to missing image files. "
                           "Please ensure all images are in the specified directory.")
        return samples

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieves an item from the dataset at the given index.

        Returns:
            tuple: (image_tensor, caption_tensor, caption_length, image_id) where
            caption_length counts tokens up to and including <END>.
        """
        item = self.data[idx]

        # Load and transform image
        try:
            # The context manager closes the file handle; convert() returns a
            # fully-loaded copy that remains valid after the file is closed.
            with Image.open(item['image_path']) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            # Deliberate best-effort: one unreadable image should not abort an epoch.
            logger.error(f"Error loading image {item['image_path']}: {e}. Returning a black image as fallback.")
            image = torch.zeros(*self.FALLBACK_IMAGE_SHAPE)

        # Convert caption to indices
        caption_indices = self.vocabulary.caption_to_indices(item['caption'], self.max_caption_length)
        caption_tensor = torch.tensor(caption_indices, dtype=torch.long)

        # Actual caption length: excludes padding, includes <START>/<END>
        try:
            end_pos = caption_indices.index(self.vocabulary.word2idx[COCOVocabulary.END_TOKEN])
            caption_length = end_pos + 1
        except ValueError:
            # <END> only goes missing when max_caption_length < 2; count non-PAD tokens.
            pad_idx = self.vocabulary.word2idx[COCOVocabulary.PAD_TOKEN]
            caption_length = sum(1 for i in caption_indices if i != pad_idx)
        caption_length = torch.tensor(caption_length, dtype=torch.long)

        # Return image tensor, caption tensor, actual caption length, and original image ID
        return image, caption_tensor, caption_length, item['image_id']