Spaces:
Build error
Build error
File size: 10,742 Bytes
ee1d4aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 |
import os
import json
import pickle
import random
from collections import Counter
from tqdm import tqdm
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from .utils import get_logger, get_eval_transform # Import logger and transforms from utils
# Module-level logger created by the project's get_logger helper (imported from .utils).
logger = get_logger(__name__)
class COCOVocabulary:
    """
    Vocabulary builder for COCO captions.
    Handles tokenization, building word-to-index and index-to-word mappings,
    and converting captions to numerical indices.
    """

    # Special tokens always receive the first indices, in this order (<PAD> -> 0).
    _SPECIAL_TOKENS = ('<PAD>', '<START>', '<END>', '<UNK>')

    def __init__(self, min_word_freq=5):
        """
        Initializes the COCOVocabulary.
        Args:
            min_word_freq (int): Minimum frequency for a word to be included in the vocabulary.
                                 Words less frequent than this will be replaced by <UNK>.
        """
        self.min_word_freq = min_word_freq
        self.word2idx = {}  # Maps words to their numerical indices
        self.idx2word = {}  # Maps numerical indices back to words
        self.word_freq = Counter()  # Counts frequency of each word
        self.vocab_size = 0  # Total number of unique words in the vocabulary

    def _add_token(self, token):
        """
        Registers `token` under the next free index, unless it is already present.
        Both mappings are updated with the same index so they can never drift apart.
        Args:
            token (str): The word or special token to add.
        """
        if token in self.word2idx:
            return
        index = len(self.word2idx)
        self.word2idx[token] = index
        self.idx2word[index] = token

    def build_vocabulary(self, captions):
        """
        Builds the vocabulary from a list of captions.
        Word frequencies accumulate across calls, and already-known words keep their index.
        Args:
            captions (list): A list of strings, where each string is a caption.
        """
        logger.info("Building vocabulary...")
        # 1. Count word frequencies
        for caption in tqdm(captions, desc="Counting word frequencies"):
            self.word_freq.update(self.tokenize(caption))
        # 2. Add special tokens first so they occupy the lowest indices
        for token in self._SPECIAL_TOKENS:
            self._add_token(token)
        # 3. Add words that meet the minimum frequency threshold (first-seen order)
        for word, freq in self.word_freq.items():
            if freq >= self.min_word_freq:
                self._add_token(word)
        self.vocab_size = len(self.word2idx)
        logger.info(f"Vocabulary built successfully. Size: {self.vocab_size}")

    def tokenize(self, caption):
        """
        Simple tokenization: lowercase the caption and split on runs of whitespace.
        Leading/trailing whitespace is ignored; punctuation is kept attached to words.
        Args:
            caption (str): The input caption string.
        Returns:
            list: A list of tokenized words.
        """
        # str.split() with no argument already strips and collapses whitespace runs.
        return caption.lower().split()

    def caption_to_indices(self, caption, max_length=20):
        """
        Converts a caption string into a list of numerical indices.
        Adds <START> and <END> tokens and pads with <PAD> up to max_length.
        At most max_length - 2 words are kept so that <START> and <END> both fit;
        for max_length < 2 the result is simply truncated (and <END> may be lost).
        Unknown words map to <UNK>.
        Args:
            caption (str): The input caption string.
            max_length (int): The maximum desired length for the indexed caption.
        Returns:
            list: A list of integer indices representing the caption.
        """
        start_idx = self.word2idx['<START>']
        end_idx = self.word2idx['<END>']
        pad_idx = self.word2idx['<PAD>']
        unk_idx = self.word2idx['<UNK>']

        tokens = self.tokenize(caption)[:max(max_length - 2, 0)]
        indices = [start_idx]
        indices.extend(self.word2idx.get(token, unk_idx) for token in tokens)
        indices.append(end_idx)
        # Pad with <PAD> tokens if the caption is shorter than max_length
        indices.extend([pad_idx] * (max_length - len(indices)))
        return indices[:max_length]  # Ensure the caption does not exceed max_length

    def indices_to_caption(self, indices):
        """
        Converts a sequence of numerical indices back into a human-readable caption string.
        Stops at the first <END> token and ignores <PAD> and <START> tokens.
        Indices not in the vocabulary decode to '<UNK>'.
        Args:
            indices (list, numpy.ndarray or torch.Tensor): A 1-D sequence of integer indices.
        Returns:
            str: The reconstructed caption string.
        """
        words = []
        for idx in indices:
            # Unwrap numpy / torch scalars: 0-d tensors hash by identity, so they
            # would never match an int key in idx2word.
            if hasattr(idx, 'item'):
                idx = idx.item()
            word = self.idx2word.get(idx, '<UNK>')
            if word == '<END>':
                break  # Stop decoding when <END> token is encountered
            if word not in ('<PAD>', '<START>'):  # Ignore special tokens
                words.append(word)
        return ' '.join(words)
class COCODataset(Dataset):
    """
    PyTorch Dataset for COCO Image Captioning.
    Loads image paths and their corresponding captions,
    and returns preprocessed image tensors and indexed caption tensors.
    Each caption annotation is one sample, so an image appears once per caption.
    """

    def __init__(self, image_dir, caption_file, vocabulary=None,
                 max_caption_length=20, subset_size=None, transform=None):
        """
        Initializes the COCODataset.
        Args:
            image_dir (str): Path to the directory containing COCO images (e.g., 'train2017', 'val2017').
            caption_file (str): Path to the COCO captions JSON file (e.g., 'captions_train2017.json').
            vocabulary (COCOVocabulary, optional): A pre-built COCOVocabulary object. If None,
                                                   a new vocabulary will be built from the captions.
            max_caption_length (int): Maximum length for indexed captions.
            subset_size (int, optional): If specified, uses a random subset of this size from the dataset
                                         (drawn with the global `random` module state).
            transform (torchvision.transforms.Compose, optional): Image transformations to apply.
        Raises:
            FileNotFoundError: If caption_file does not exist.
            json.JSONDecodeError: If caption_file is not valid JSON.
        """
        self.image_dir = image_dir
        self.max_caption_length = max_caption_length
        # Default to the evaluation transform when none is supplied.
        self.transform = transform if transform is not None else get_eval_transform()

        self.coco_data = self._load_caption_file(caption_file)

        # Create a mapping from image ID to its filename for quick lookup
        self.id_to_filename = {
            img_info['id']: img_info['file_name'] for img_info in self.coco_data['images']
        }

        # Each entry is a dict with 'image_path', 'caption' and 'image_id' keys.
        self.data, missing_image_files, unknown_image_ids = self._pair_annotations(
            tqdm(self.coco_data['annotations'], desc="Processing annotations"),
            self.id_to_filename,
            image_dir,
        )
        if unknown_image_ids > 0:
            logger.warning(f"Skipped {unknown_image_ids} annotations whose image ID was not found "
                           "in the images list.")
        if missing_image_files > 0:
            logger.warning(f"Skipped {missing_image_files} annotations due to missing image files. "
                           "Please ensure all images are in the specified directory.")

        # If subset_size is specified, take a random sample
        if subset_size and subset_size < len(self.data):
            self.data = random.sample(self.data, subset_size)
            logger.info(f"Using subset of {subset_size} samples for the dataset.")
        logger.info(f"Dataset size after filtering: {len(self.data)} samples.")

        # Build vocabulary from this dataset's (possibly subsampled) captions if not provided
        if vocabulary is None:
            self.vocabulary = COCOVocabulary()
            self.vocabulary.build_vocabulary([item['caption'] for item in self.data])
        else:
            self.vocabulary = vocabulary

    @staticmethod
    def _load_caption_file(caption_file):
        """
        Loads and returns the parsed COCO captions JSON, logging a hint before re-raising on failure.
        Args:
            caption_file (str): Path to the captions JSON file.
        Returns:
            dict: The parsed JSON document.
        Raises:
            FileNotFoundError: If the file does not exist.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        try:
            # JSON text is UTF-8; don't depend on the platform's default encoding.
            with open(caption_file, 'r', encoding='utf-8') as f:
                coco_data = json.load(f)
        except FileNotFoundError:
            logger.error(f"Caption file not found at {caption_file}. Please check the path.")
            raise
        except json.JSONDecodeError:
            logger.error(f"Error decoding JSON from {caption_file}. Ensure it's a valid JSON file.")
            raise
        logger.info(f"Successfully loaded captions from {caption_file}")
        return coco_data

    @staticmethod
    def _pair_annotations(annotations, id_to_filename, image_dir):
        """
        Pairs each caption annotation with the path of its image on disk.
        Annotations whose image ID is unknown, or whose image file is missing, are skipped and counted.
        Args:
            annotations (iterable): Annotation dicts with 'image_id' and 'caption' keys.
            id_to_filename (dict): Maps image ID to image filename.
            image_dir (str): Directory that contains the image files.
        Returns:
            tuple: (data, missing_image_files, unknown_image_ids), where data is a list of dicts
                   with 'image_path', 'caption' and 'image_id' keys.
        """
        data = []
        missing_image_files = 0
        unknown_image_ids = 0
        # Several annotations can share one image; check each path on disk only once.
        path_exists = {}
        for ann in annotations:
            image_id = ann['image_id']
            if image_id not in id_to_filename:
                unknown_image_ids += 1
                continue
            image_full_path = os.path.join(image_dir, id_to_filename[image_id])
            if image_full_path not in path_exists:
                path_exists[image_full_path] = os.path.exists(image_full_path)
            if not path_exists[image_full_path]:
                missing_image_files += 1
                continue
            data.append({
                'image_path': image_full_path,
                'caption': ann['caption'],
                'image_id': image_id,  # Store original image_id for evaluation
            })
        return data, missing_image_files, unknown_image_ids

    @staticmethod
    def _caption_length(caption_indices, end_idx, pad_idx):
        """
        Returns the number of tokens up to and including the first <END> token.
        If <END> is absent (e.g. truncated away for a very small max_caption_length),
        falls back to counting the non-<PAD> tokens.
        Args:
            caption_indices (list): Indexed caption.
            end_idx (int): Index of the <END> token.
            pad_idx (int): Index of the <PAD> token.
        Returns:
            int: The caption length.
        """
        try:
            return caption_indices.index(end_idx) + 1
        except ValueError:
            return sum(1 for idx in caption_indices if idx != pad_idx)

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieves an item from the dataset at the given index.
        Returns:
            tuple: (image_tensor, caption_tensor, caption_length, image_id)
        """
        item = self.data[idx]
        # Load and transform image
        try:
            # Context manager guarantees the underlying file handle is closed.
            with Image.open(item['image_path']) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            # Deliberate best-effort fallback so one bad file doesn't abort an epoch.
            logger.error(f"Error loading image {item['image_path']}: {e}. Returning a black image as fallback.")
            # NOTE(review): assumes the transform yields (3, 224, 224) tensors - confirm,
            # otherwise batching will fail on the fallback sample.
            image = torch.zeros(3, 224, 224)

        # Convert caption to indices
        caption_indices = self.vocabulary.caption_to_indices(
            item['caption'], self.max_caption_length
        )
        caption_tensor = torch.tensor(caption_indices, dtype=torch.long)

        # Actual caption length (excluding padding, including START/END)
        word2idx = self.vocabulary.word2idx
        caption_length = self._caption_length(caption_indices, word2idx['<END>'], word2idx['<PAD>'])
        caption_length = torch.tensor(caption_length, dtype=torch.long)

        # Return image tensor, caption tensor, actual caption length, and original image ID
        return image, caption_tensor, caption_length, item['image_id']
|