| """Custom EfficientNet-V2-S + Transformer image captioning models. |
| |
| This file contains the architecture needed to load Ali Sedghiye's custom |
| 5k and 100k PyTorch checkpoints. It intentionally contains only inference |
| code, not training code. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Dict, Iterable, List, Tuple |
|
|
| import os |
| import torch |
| torch.set_num_threads(max(1, min(4, os.cpu_count() or 1))) |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
| from torchvision.models import efficientnet_v2_s |
| from torchvision import transforms as T |
|
|
|
|
# Per-channel RGB mean / standard deviation used to normalise input pixels.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every PIL image before it reaches the encoder:
# resize to a fixed 224x224, convert to a tensor, then normalise each
# channel with the statistics above.
custom_transform = T.Compose([
    T.Resize(size=(224, 224)),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
|
|
|
|
class Vocabulary:
    """Token <-> index mapping used to turn generated ids back into text.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. A freshly constructed instance contains only
    those four entries; real vocabularies are loaded with :meth:`from_json`.
    """

    PAD, SOS, EOS, UNK = 0, 1, 2, 3

    def __init__(self, freq_threshold: int = 5):
        """Create a vocabulary holding only the four special tokens.

        Args:
            freq_threshold: Stored on the instance unchanged; nothing in
                this class reads it.
        """
        self.freq_threshold = freq_threshold
        self.itos: Dict[int, str] = {
            0: "<PAD>",
            1: "<SOS>",
            2: "<EOS>",
            3: "<UNK>",
        }
        self.stoi: Dict[str, int] = {v: k for k, v in self.itos.items()}

    def __len__(self) -> int:
        """Return the number of index -> token entries."""
        return len(self.itos)

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Lower-case *text*, drop trailing periods, and split on whitespace."""
        return text.lower().strip().rstrip(".").split()

    def decode(self, indices: Iterable[int]) -> str:
        """Convert token ids into a space-separated caption.

        Decoding stops at the first ``EOS``; ``PAD`` and ``SOS`` are skipped
        and ids missing from ``itos`` are rendered as ``<UNK>``.

        Args:
            indices: Token ids; each element is passed through ``int()``.

        Returns:
            The decoded caption with runs of whitespace collapsed to a
            single space and no leading/trailing whitespace.
        """
        words: List[str] = []
        for idx in indices:
            idx = int(idx)
            if idx == self.EOS:
                break
            if idx not in (self.PAD, self.SOS):
                words.append(self.itos.get(idx, "<UNK>"))
        # split()/join collapses every whitespace run, so an empty or padded
        # token cannot leave doubled spaces in the caption.
        return " ".join(" ".join(words).split())

    @classmethod
    def from_json(cls, path: str | Path) -> "Vocabulary":
        """Load a vocabulary from a JSON file.

        The file must contain an ``"itos"`` object mapping index strings to
        tokens. ``"stoi"`` (token -> index) and ``"freq_threshold"`` are
        optional: a missing ``"stoi"`` is rebuilt by inverting ``"itos"`` and
        a missing threshold defaults to 5.

        Args:
            path: Location of the JSON file.

        Returns:
            A populated :class:`Vocabulary`.

        Raises:
            KeyError: If the file has no ``"itos"`` entry.
        """
        with Path(path).open("r", encoding="utf-8") as f:
            data = json.load(f)

        vocab = cls(freq_threshold=int(data.get("freq_threshold", 5)))
        vocab.itos = {int(k): v for k, v in data["itos"].items()}

        stoi = data.get("stoi")
        if stoi is None:
            vocab.stoi = {token: idx for idx, token in vocab.itos.items()}
        else:
            vocab.stoi = {k: int(v) for k, v in stoi.items()}
        return vocab
|
|
|
|
class EfficientNetEncoder(nn.Module):
    """Image encoder: EfficientNet-V2-S feature trunk plus a linear projection.

    The torchvision backbone is created with ``weights=None``, so no
    pretrained weights are downloaded here. Only its ``features`` trunk is
    kept. Each spatial location of the resulting feature map is projected to
    ``embed_dim`` and offset by a learned positional embedding.
    """

    def __init__(self, embed_dim: int = 256, fine_tune: bool = False):
        """Build the trunk, the projection head and the positional table.

        Args:
            embed_dim: Width of the vectors returned by :meth:`forward`.
            fine_tune: Forwarded to :meth:`set_fine_tune`.
        """
        super().__init__()

        self.features = efficientnet_v2_s(weights=None).features

        # Linear -> LayerNorm -> GELU, applied to every 1280-channel location.
        projection_layers = [
            nn.Linear(1280, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
        ]
        self.proj = nn.Sequential(*projection_layers)

        # One learned offset per location; 49 positions are hard-coded
        # (presumably the 7x7 grid produced for 224x224 inputs; confirm).
        initial_positions = torch.randn(1, 49, embed_dim) * 0.02
        self.pos_embed = nn.Parameter(initial_positions)

        self.set_fine_tune(fine_tune)

    def set_fine_tune(self, enable: bool) -> None:
        """Freeze the trunk, optionally re-enabling gradients for stages 6 and 7.

        Args:
            enable: When true, parameters of ``features[6]`` and
                ``features[7]`` are made trainable again; everything else in
                the trunk stays frozen.
        """
        self.features.requires_grad_(False)
        if not enable:
            return
        for stage_index in (6, 7):
            self.features[stage_index].requires_grad_(True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into a sequence of location embeddings.

        Args:
            x: Image batch fed straight into the convolutional trunk.

        Returns:
            Tensor with one ``embed_dim`` vector per spatial location of the
            trunk's output (batch first, locations second), with the
            positional embedding added.
        """
        feature_map = self.features(x)
        # Collapse the spatial dims and move channels last so the linear
        # projection acts on each location independently.
        locations = feature_map.flatten(start_dim=2).transpose(1, 2)
        return self.proj(locations) + self.pos_embed
|
|
|
|
class TransformerDecoder(nn.Module):
    """Pre-norm Transformer decoder mapping caption prefixes to vocabulary logits.

    Token embeddings and learned positional embeddings are summed, run
    through a stack of ``nn.TransformerDecoderLayer`` blocks that attend to
    the encoder memory, layer-normalised, and projected onto the vocabulary.
    The output projection shares its weight matrix with the token embedding.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 256,
        num_heads: int = 8,
        num_layers: int = 6,
        ff_dim: int = 1024,
        max_len: int = 52,
        dropout: float = 0.1,
    ):
        """Build the embedding tables, decoder stack and output head.

        Args:
            vocab_size: Number of token ids (embedding rows and logit width).
            embed_dim: Width of the embeddings and of every decoder layer.
            num_heads: Attention heads per layer.
            num_layers: Number of stacked decoder layers.
            ff_dim: Hidden width of each layer's feed-forward block.
            max_len: Number of learned positions, i.e. the longest target
                sequence :meth:`forward` accepts.
            dropout: Dropout probability for the embeddings and the layers.
        """
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos = nn.Embedding(max_len, embed_dim)
        self.drop = nn.Dropout(dropout)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
            activation="gelu",
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)
        # Weight tying: the output projection reuses the embedding matrix, so
        # the initialisation below applies to both.
        self.head.weight = self.embed.weight
        nn.init.normal_(self.embed.weight, std=0.02)
        nn.init.normal_(self.pos.weight, std=0.02)

    def forward(
        self,
        tgt_ids: torch.Tensor,
        memory: torch.Tensor,
        tgt_key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute next-token logits for every position of ``tgt_ids``.

        Args:
            tgt_ids: Token ids, batch first; dimension 1 is the sequence.
            memory: Encoder output that the decoder layers attend to.
            tgt_key_padding_mask: Optional mask forwarded unchanged to
                ``nn.TransformerDecoder`` to hide padded target positions.

        Returns:
            Logits with a trailing dimension of ``vocab_size``, one vector
            per target position.

        Raises:
            IndexError: If the target sequence is longer than ``max_len``.
        """
        t = tgt_ids.size(1)
        max_positions = self.pos.num_embeddings
        # Fail early with a readable message: indexing past the positional
        # table otherwise surfaces as an opaque lookup error (and as a
        # device-side assert on CUDA).
        if t > max_positions:
            raise IndexError(
                f"target length {t} exceeds the decoder's max_len of {max_positions}"
            )

        pos = torch.arange(t, device=tgt_ids.device).unsqueeze(0)
        x = self.drop(self.embed(tgt_ids) + self.pos(pos))
        # Causal mask so each position only attends to itself and earlier ones.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(t, device=x.device)
        out = self.transformer(
            x,
            memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )
        return self.head(self.norm(out))
|
|
|
|
class ImageCaptioningModel(nn.Module):
    """EfficientNet encoder + Transformer decoder captioning network.

    ``forward`` computes teacher-forced logits; ``generate_greedy`` and
    ``generate_beam`` produce a caption string for a single image.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 256,
        num_heads: int = 8,
        num_layers: int = 6,
        ff_dim: int = 1024,
        max_len: int = 52,
        dropout: float = 0.1,
    ):
        """Build the encoder and decoder.

        Args:
            vocab_size: Number of token ids the decoder predicts over.
            embed_dim: Shared width of encoder output and decoder layers.
            num_heads: Attention heads per decoder layer.
            num_layers: Number of decoder layers.
            ff_dim: Feed-forward width inside each decoder layer.
            max_len: Number of positions in the decoder's positional table.
            dropout: Dropout probability used by the decoder.
        """
        super().__init__()
        self.encoder = EfficientNetEncoder(embed_dim=embed_dim)
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            ff_dim=ff_dim,
            max_len=max_len,
            dropout=dropout,
        )

    def forward(self, images: torch.Tensor, captions: torch.Tensor) -> torch.Tensor:
        """Return decoder logits for a batch of images and their captions.

        The last caption token is dropped from the decoder input, and
        positions equal to 0 (the PAD id) are masked as padding.

        Args:
            images: Image batch passed to the encoder.
            captions: Token ids, batch first.

        Returns:
            Logits for each position of ``captions[:, :-1]``.
        """
        memory = self.encoder(images)
        tgt = captions[:, :-1]
        pad_mask = tgt == 0
        return self.decoder(tgt, memory, pad_mask)

    def _decode_steps(self, max_len: int) -> int:
        """Clamp *max_len* to the number of positions the decoder can embed."""
        return min(int(max_len), self.decoder.pos.num_embeddings)

    @staticmethod
    def _length_normalised(entry: Tuple[float, List[int]]) -> float:
        """Return a hypothesis' summed log-probability divided by its length."""
        score, seq = entry
        return score / max(len(seq), 1)

    @torch.inference_mode()
    def generate_greedy(
        self,
        image_tensor: torch.Tensor,
        vocab: Vocabulary,
        device: torch.device | str,
        max_len: int = 30,
    ) -> str:
        """Caption one image by repeatedly taking the most likely next token.

        Puts the module in eval mode as a side effect.

        Args:
            image_tensor: A single preprocessed image (no batch dimension).
            vocab: Supplies the ``SOS``/``EOS`` ids and ``decode``.
            device: Device the image and token tensors are placed on.
            max_len: Maximum number of tokens to generate. Values larger
                than the decoder's positional table are clamped to it
                instead of failing inside the decoder.

        Returns:
            The decoded caption.
        """
        self.eval()
        image_tensor = image_tensor.unsqueeze(0).to(device)
        memory = self.encoder(image_tensor)
        tokens = [vocab.SOS]
        for _ in range(self._decode_steps(max_len)):
            tgt = torch.tensor([tokens], dtype=torch.long, device=device)
            next_id = self.decoder(tgt, memory)[0, -1].argmax().item()
            tokens.append(next_id)
            if next_id == vocab.EOS:
                break
        return vocab.decode(tokens[1:])

    @torch.inference_mode()
    def generate_beam(
        self,
        image_tensor: torch.Tensor,
        vocab: Vocabulary,
        device: torch.device | str,
        beam_size: int = 3,
        max_len: int = 30,
    ) -> str:
        """Caption one image with length-normalised beam search.

        Puts the module in eval mode as a side effect.

        Args:
            image_tensor: A single preprocessed image (no batch dimension).
            vocab: Supplies the ``SOS``/``EOS`` ids and ``decode``.
            device: Device the image and token tensors are placed on.
            beam_size: Number of hypotheses kept per step; values below 1
                are treated as 1.
            max_len: Maximum number of tokens to generate. Values larger
                than the decoder's positional table are clamped to it.

        Returns:
            The decoded caption of the hypothesis with the best
            length-normalised log-probability.
        """
        self.eval()
        beam_size = max(1, int(beam_size))
        image_tensor = image_tensor.unsqueeze(0).to(device)
        memory = self.encoder(image_tensor)
        beams: List[Tuple[float, List[int]]] = [(0.0, [vocab.SOS])]
        completed: List[Tuple[float, List[int]]] = []

        for _ in range(self._decode_steps(max_len)):
            candidates: List[Tuple[float, List[int]]] = []
            for score, seq in beams:
                # Finished hypotheses are retired rather than expanded.
                if seq[-1] == vocab.EOS:
                    completed.append((score, seq))
                    continue
                tgt = torch.tensor([seq], dtype=torch.long, device=device)
                log_prob = F.log_softmax(self.decoder(tgt, memory)[0, -1], dim=-1)
                # topk raises when k exceeds the vocabulary size.
                k = min(beam_size, log_prob.size(-1))
                values, indices = log_prob.topk(k)
                for lp, idx in zip(values.tolist(), indices.tolist()):
                    candidates.append((score + float(lp), seq + [int(idx)]))
            if not candidates:
                break
            candidates.sort(key=self._length_normalised, reverse=True)
            beams = candidates[:beam_size]

        best = max(completed + beams, key=self._length_normalised)
        return vocab.decode(best[1][1:])
|
|
|
|
@dataclass
class LoadedCustomModel:
    """Bundle of a captioning model, its vocabulary and its device."""

    # Captioning network used to generate token ids.
    model: ImageCaptioningModel
    # Vocabulary handed to the generate methods for SOS/EOS ids and decoding.
    vocab: Vocabulary
    # Device handed to the generate methods for tensor placement.
    device: torch.device

    def caption(
        self,
        image: Image.Image,
        decoding: str = "Beam search",
        beam_size: int = 3,
        max_len: int = 30,
    ) -> str:
        """Generate a caption for a PIL image.

        Args:
            image: Input image; converted to RGB before preprocessing.
            decoding: ``"Greedy"`` selects greedy decoding; any other value
                selects beam search.
            beam_size: Beam width, used only for beam search.
            max_len: Maximum number of generated tokens.

        Returns:
            The generated caption text.
        """
        pixels = custom_transform(image.convert("RGB"))
        wants_greedy = decoding == "Greedy"
        if not wants_greedy:
            return self.model.generate_beam(
                pixels,
                self.vocab,
                self.device,
                beam_size=beam_size,
                max_len=max_len,
            )
        return self.model.generate_greedy(
            pixels,
            self.vocab,
            self.device,
            max_len=max_len,
        )
|
|
|
|
def load_custom_model(
    checkpoint_path: str | Path,
    vocab_path: str | Path,
    device: torch.device | str,
) -> LoadedCustomModel:
    """Build the captioning model and restore it from a checkpoint.

    The vocabulary is read first because its size fixes the decoder's
    embedding and output dimensions. All other hyper-parameters use the
    ``ImageCaptioningModel`` defaults.

    Args:
        checkpoint_path: File readable by ``torch.load``. Either a raw
            state dict or a dict holding the state dict under ``"model"``.
        vocab_path: JSON vocabulary file for ``Vocabulary.from_json``.
        device: Device the model and checkpoint tensors are placed on.

    Returns:
        A ``LoadedCustomModel`` whose model is in eval mode.
    """
    ckpt_file = Path(checkpoint_path)
    vocab_file = Path(vocab_path)
    target = torch.device(device)

    vocab = Vocabulary.from_json(vocab_file)
    model = ImageCaptioningModel(vocab_size=len(vocab)).to(target)

    # NOTE(review): torch.load unpickles the file. On PyTorch releases where
    # weights_only defaults to False this can execute arbitrary code, so
    # only load checkpoints from a trusted source.
    payload = torch.load(ckpt_file, map_location=target)
    if isinstance(payload, dict) and "model" in payload:
        weights = payload["model"]
    else:
        weights = payload

    # strict=True: any missing or unexpected key is an error.
    model.load_state_dict(weights, strict=True)
    model.eval()
    return LoadedCustomModel(model=model, vocab=vocab, device=target)
|
|