# vit-image-captioning — data_processing_vocabulary.py
# ==========================================
# Data Processing & Vocabulary
# ==========================================
import os
import json
import re
import random
from collections import Counter, defaultdict
from PIL import Image
import torch
from torch.utils.data import Dataset
class Vocabulary:
    """Bidirectional word <-> integer-id mapping with reserved special tokens.

    Ids 0-3 are fixed: <pad>, <bos>, <eos>, <unk>. Corpus words are added by
    build_vocab() once they appear at least `freq_threshold` times.
    """

    def __init__(self, freq_threshold=2):
        # Minimum corpus frequency for a word to earn its own id.
        self.freq_threshold = freq_threshold
        self.pad_token = "<pad>"
        self.bos_token = "<bos>"
        self.eos_token = "<eos>"
        self.unk_token = "<unk>"
        specials = (self.pad_token, self.bos_token, self.eos_token, self.unk_token)
        self.word2idx = {tok: i for i, tok in enumerate(specials)}
        self.idx2word = dict(enumerate(specials))

    def __len__(self):
        return len(self.word2idx)

    def clean_text(self, text):
        """Lowercase, drop everything but a-z and whitespace, collapse runs of spaces."""
        lowered = text.lower()
        letters_only = re.sub(r"[^a-z\s]", "", lowered)
        return re.sub(r"\s+", " ", letters_only).strip()

    def build_vocab(self, captions):
        """Count cleaned words across `captions` and register the frequent ones.

        Words keep first-seen order; ids continue after the special tokens.
        """
        counts = Counter(
            word
            for cap in captions
            for word in self.clean_text(cap).split()
        )
        next_id = len(self.word2idx)
        for word, freq in counts.items():
            if freq >= self.freq_threshold:
                self.word2idx[word] = next_id
                self.idx2word[next_id] = word
                next_id += 1
def tokenize_caption(caption, vocab, max_len):
    """Convert a caption string into a list of token ids, framed by BOS/EOS.

    Args:
        caption: Raw caption string; cleaned via ``vocab.clean_text``.
        vocab: Vocabulary with ``word2idx`` and the special-token attributes
            ``bos_token`` / ``eos_token`` / ``unk_token``.
        max_len: Maximum total length including the BOS and EOS markers.

    Returns:
        List[int] of at most ``max_len`` ids: [bos] + word ids + [eos].

    Consistency fix: look the special tokens up via the vocab's declared
    attributes instead of hard-coded "<bos>"/"<eos>"/"<unk>" literals, so a
    vocabulary with customized special tokens still works.
    """
    cleaned = vocab.clean_text(caption)
    # Truncate to leave room for the BOS and EOS markers.
    tokens = cleaned.split()[: max_len - 2]
    unk_id = vocab.word2idx[vocab.unk_token]
    ids = [vocab.word2idx.get(tok, unk_id) for tok in tokens]
    return [vocab.word2idx[vocab.bos_token]] + ids + [vocab.word2idx[vocab.eos_token]]
class CocoCaptionDataset(Dataset):
    """COCO-style image-captioning dataset.

    Each item is (image, caption_ids, image_id); one sample per image, with a
    caption drawn at random from that image's annotations on every access.
    Missing image files fall back to a blank 224x224 RGB image.
    """

    def __init__(self,
                 images_dir,
                 annotation_file,
                 vocab,
                 max_len=50,
                 transform=None):
        self.images_dir = images_dir
        self.transform = transform
        self.vocab = vocab
        self.max_len = max_len
        # Load the COCO-format annotation JSON.
        print(f"Loading annotations from {annotation_file}...")
        with open(annotation_file, "r") as f:
            coco = json.load(f)
        self.images = {entry["id"]: entry for entry in coco["images"]}
        # Group caption strings under their image_id.
        self.image_id_to_captions = defaultdict(list)
        for ann in coco["annotations"]:
            self.image_id_to_captions[ann["image_id"]].append(ann["caption"])
        # One sample per image; __getitem__ picks a random caption each call.
        self.image_ids = list(self.image_id_to_captions)
        total_captions = sum(map(len, self.image_id_to_captions.values()))
        print(f"Loaded {len(self.image_ids)} images with {total_captions} captions.")

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        caption = random.choice(self.image_id_to_captions[image_id])
        file_name = self.images[image_id]["file_name"]
        image_path = os.path.join(self.images_dir, file_name)
        try:
            image = Image.open(image_path).convert("RGB")
        except FileNotFoundError:
            # Partial datasets may lack some files: warn and substitute a
            # blank image rather than aborting the whole epoch.
            print(f"Warning: Image not found {image_path}")
            image = Image.new('RGB', (224, 224), (0, 0, 0))
        if self.transform:
            image = self.transform(image)
        caption_ids = tokenize_caption(caption, self.vocab, self.max_len)
        return image, torch.tensor(caption_ids, dtype=torch.long), image_id
def collate_fn(batch):
    """Collate (image, caption, image_id) samples with variable-length captions.

    Returns:
        images: (batch_size, 3, 224, 224) stacked image tensors
        padded_captions: (batch_size, max_seq_len) long tensor, zero-padded
        padding_mask: (batch_size, max_seq_len) bool — True where padding
        image_ids: tuple of image ids
    """
    images, captions, image_ids = zip(*batch)
    images = torch.stack(images)
    batch_size = len(captions)
    longest = max(cap.size(0) for cap in captions)
    # Start fully padded/masked, then copy each caption over its prefix.
    padded_captions = torch.zeros(batch_size, longest, dtype=torch.long)
    padding_mask = torch.ones(batch_size, longest, dtype=torch.bool)
    for row, cap in enumerate(captions):
        n = cap.size(0)
        padded_captions[row, :n] = cap
        padding_mask[row, :n] = False
    return images, padded_captions, padding_mask, image_ids