|
|
|
|
|
|
|
|
| import os
|
| import json
|
| import re
|
| import random
|
| from collections import Counter, defaultdict
|
| from PIL import Image
|
| import torch
|
| from torch.utils.data import Dataset
|
|
|
|
|
|
|
class Vocabulary:
    """Word-to-id mapping with four reserved special tokens at ids 0-3."""

    def __init__(self, freq_threshold=2):
        # Words occurring fewer than freq_threshold times are left out,
        # so they fall back to <unk> at tokenization time.
        self.freq_threshold = freq_threshold

        self.pad_token = "<pad>"
        self.bos_token = "<bos>"
        self.eos_token = "<eos>"
        self.unk_token = "<unk>"

        # Specials occupy the first ids, in this fixed order (pad must be 0
        # so zero-padded batches line up with the pad id).
        specials = (self.pad_token, self.bos_token, self.eos_token, self.unk_token)
        self.word2idx = {token: i for i, token in enumerate(specials)}
        self.idx2word = dict(enumerate(specials))

    def __len__(self):
        return len(self.word2idx)

    def clean_text(self, text):
        """Lowercase, drop every non-letter character, collapse whitespace."""
        lowered = text.lower()
        letters_only = re.sub(r"[^a-z\s]", "", lowered)
        return re.sub(r"\s+", " ", letters_only).strip()

    def build_vocab(self, captions):
        """Count words across all captions and assign ids to frequent ones."""
        counts = Counter(
            word
            for caption in captions
            for word in self.clean_text(caption).split()
        )

        # Ids continue after whatever is already registered (the specials).
        next_id = len(self.word2idx)
        for word, freq in counts.items():
            if freq >= self.freq_threshold:
                self.word2idx[word] = next_id
                self.idx2word[next_id] = word
                next_id += 1
|
|
|
def tokenize_caption(caption, vocab, max_len):
    """Convert a raw caption string into a bounded list of token ids.

    Args:
        caption: Raw caption text.
        vocab: Vocabulary exposing clean_text(), word2idx, and the
            bos_token / eos_token / unk_token attributes.
        max_len: Maximum total sequence length, including the <bos> and
            <eos> markers added here.

    Returns:
        List of ints: [bos_id] + word ids (unknown words map to the unk
        id) + [eos_id], of length at most max_len.
    """
    tokens = vocab.clean_text(caption).split()

    # Reserve two slots for the <bos>/<eos> markers appended below.
    tokens = tokens[: max_len - 2]

    # Use the vocabulary's declared special-token attributes rather than
    # duplicating the literal strings here; hoist the unk lookup out of
    # the comprehension.
    unk_id = vocab.word2idx[vocab.unk_token]
    ids = [vocab.word2idx.get(tok, unk_id) for tok in tokens]

    return [vocab.word2idx[vocab.bos_token]] + ids + [vocab.word2idx[vocab.eos_token]]
|
|
|
class CocoCaptionDataset(Dataset):
    """COCO-style captioning dataset.

    Each item pairs an image (optionally transformed) with one of its
    captions, chosen at random and tokenized to ids via the vocabulary.
    """

    def __init__(self,
                 images_dir,
                 annotation_file,
                 vocab,
                 max_len=50,
                 transform=None):
        self.images_dir = images_dir
        self.transform = transform
        self.vocab = vocab
        self.max_len = max_len

        print(f"Loading annotations from {annotation_file}...")
        with open(annotation_file, "r") as f:
            coco = json.load(f)

        # Index image metadata by id for O(1) lookup in __getitem__.
        self.images = {}
        for img in coco["images"]:
            self.images[img["id"]] = img

        # Group every caption string under its owning image id.
        self.image_id_to_captions = defaultdict(list)
        for ann in coco["annotations"]:
            self.image_id_to_captions[ann["image_id"]].append(ann["caption"])

        # Only images that actually have captions are indexable.
        self.image_ids = list(self.image_id_to_captions.keys())

        total_captions = sum(len(caps) for caps in self.image_id_to_captions.values())
        print(
            f"Loaded {len(self.image_ids)} images with {total_captions} captions."
        )

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        captions = self.image_id_to_captions[image_id]
        caption = random.choice(captions)
        file_name = self.images[image_id]["file_name"]

        image_path = os.path.join(self.images_dir, file_name)
        try:
            image = Image.open(image_path).convert("RGB")
        except FileNotFoundError:
            # Fall back to a black placeholder so one missing file does
            # not abort the whole epoch.
            print(f"Warning: Image not found {image_path}")
            image = Image.new('RGB', (224, 224), (0, 0, 0))

        if self.transform:
            image = self.transform(image)

        caption_ids = tokenize_caption(caption, self.vocab, self.max_len)

        return image, torch.tensor(caption_ids, dtype=torch.long), image_id
|
|
|
def collate_fn(batch):
    """
    Custom collate function to handle variable length captions.

    Returns:
        images: (batch_size, 3, H, W) stacked image tensors
        padded_captions: (batch_size, max_seq_len), padded with 0 (the
            <pad> id)
        padding_mask: (batch_size, max_seq_len) bool - True at PADDING
            positions, False at real tokens (the convention expected by
            e.g. nn.Transformer's src_key_padding_mask)
        image_ids: tuple of image ids
    """
    # NOTE: the previous docstring described the mask inverted ("1 for
    # token, 0 for pad"); the code has always produced True-for-pad.
    images, captions, image_ids = zip(*batch)

    images = torch.stack(images)

    lengths = [len(c) for c in captions]
    max_len = max(lengths)

    # zeros == pad id 0; mask starts all-True (padded) and real-token
    # positions are cleared below.
    padded_captions = torch.zeros(len(captions), max_len, dtype=torch.long)
    padding_mask = torch.ones(len(captions), max_len, dtype=torch.bool)

    for i, cap in enumerate(captions):
        end = lengths[i]
        padded_captions[i, :end] = cap
        padding_mask[i, :end] = False

    return images, padded_captions, padding_mask, image_ids
|
|
|