"""Custom EfficientNet-V2-S + Transformer image captioning models. This file contains the architecture needed to load Ali Sedghiye's custom 5k and 100k PyTorch checkpoints. It intentionally contains only inference code, not training code. """ from __future__ import annotations import json from dataclasses import dataclass from pathlib import Path from typing import Dict, Iterable, List, Tuple import os import torch torch.set_num_threads(max(1, min(4, os.cpu_count() or 1))) import torch.nn as nn import torch.nn.functional as F from PIL import Image from torchvision.models import efficientnet_v2_s from torchvision import transforms as T IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] custom_transform = T.Compose( [ T.Resize((224, 224)), T.ToTensor(), T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ] ) class Vocabulary: PAD, SOS, EOS, UNK = 0, 1, 2, 3 def __init__(self, freq_threshold: int = 5): self.freq_threshold = freq_threshold self.itos: Dict[int, str] = { 0: "", 1: "", 2: "", 3: "", } self.stoi: Dict[str, int] = {v: k for k, v in self.itos.items()} def __len__(self) -> int: return len(self.itos) @staticmethod def tokenize(text: str) -> List[str]: return text.lower().strip().rstrip(".").split() def decode(self, indices: Iterable[int]) -> str: words: List[str] = [] for idx in indices: idx = int(idx) if idx == self.EOS: break if idx not in (self.PAD, self.SOS): words.append(self.itos.get(idx, "")) return " ".join(words).replace(" ", " ").strip() @classmethod def from_json(cls, path: str | Path) -> "Vocabulary": path = Path(path) with path.open("r", encoding="utf-8") as f: data = json.load(f) vocab = cls(freq_threshold=int(data.get("freq_threshold", 5))) vocab.itos = {int(k): v for k, v in data["itos"].items()} vocab.stoi = {k: int(v) for k, v in data["stoi"].items()} return vocab class EfficientNetEncoder(nn.Module): def __init__(self, embed_dim: int = 256, fine_tune: bool = False): super().__init__() # weights=None prevents a download at Space startup. The checkpoint # contains the trained encoder weights and will overwrite initialization. backbone = efficientnet_v2_s(weights=None) self.features = backbone.features self.proj = nn.Sequential( nn.Linear(1280, embed_dim), nn.LayerNorm(embed_dim), nn.GELU(), ) self.pos_embed = nn.Parameter(torch.randn(1, 49, embed_dim) * 0.02) self.set_fine_tune(fine_tune) def set_fine_tune(self, enable: bool) -> None: for p in self.features.parameters(): p.requires_grad = False if enable: for i in [6, 7]: for p in self.features[i].parameters(): p.requires_grad = True def forward(self, x: torch.Tensor) -> torch.Tensor: feat = self.features(x) # (B, 1280, 7, 7) feat = feat.flatten(2).permute(0, 2, 1) # (B, 49, 1280) feat = self.proj(feat) # (B, 49, embed_dim) return feat + self.pos_embed class TransformerDecoder(nn.Module): def __init__( self, vocab_size: int, embed_dim: int = 256, num_heads: int = 8, num_layers: int = 6, ff_dim: int = 1024, max_len: int = 52, dropout: float = 0.1, ): super().__init__() self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0) self.pos = nn.Embedding(max_len, embed_dim) self.drop = nn.Dropout(dropout) decoder_layer = nn.TransformerDecoderLayer( d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=dropout, batch_first=True, norm_first=True, activation="gelu", ) self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers) self.norm = nn.LayerNorm(embed_dim) self.head = nn.Linear(embed_dim, vocab_size) self.head.weight = self.embed.weight nn.init.normal_(self.embed.weight, std=0.02) nn.init.normal_(self.pos.weight, std=0.02) def forward( self, tgt_ids: torch.Tensor, memory: torch.Tensor, tgt_key_padding_mask: torch.Tensor | None = None, ) -> torch.Tensor: t = tgt_ids.size(1) pos = torch.arange(t, device=tgt_ids.device).unsqueeze(0) x = self.drop(self.embed(tgt_ids) + self.pos(pos)) causal_mask = nn.Transformer.generate_square_subsequent_mask(t, device=x.device) out = self.transformer( x, memory, tgt_mask=causal_mask, tgt_key_padding_mask=tgt_key_padding_mask, ) return self.head(self.norm(out)) class ImageCaptioningModel(nn.Module): def __init__( self, vocab_size: int, embed_dim: int = 256, num_heads: int = 8, num_layers: int = 6, ff_dim: int = 1024, max_len: int = 52, dropout: float = 0.1, ): super().__init__() self.encoder = EfficientNetEncoder(embed_dim=embed_dim) self.decoder = TransformerDecoder( vocab_size=vocab_size, embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, ff_dim=ff_dim, max_len=max_len, dropout=dropout, ) def forward(self, images: torch.Tensor, captions: torch.Tensor) -> torch.Tensor: memory = self.encoder(images) tgt = captions[:, :-1] pad_mask = tgt == 0 return self.decoder(tgt, memory, pad_mask) @torch.inference_mode() def generate_greedy( self, image_tensor: torch.Tensor, vocab: Vocabulary, device: torch.device | str, max_len: int = 30, ) -> str: self.eval() image_tensor = image_tensor.unsqueeze(0).to(device) memory = self.encoder(image_tensor) tokens = [vocab.SOS] for _ in range(max_len): tgt = torch.tensor([tokens], dtype=torch.long, device=device) next_id = self.decoder(tgt, memory)[0, -1].argmax().item() tokens.append(next_id) if next_id == vocab.EOS: break return vocab.decode(tokens[1:]) @torch.inference_mode() def generate_beam( self, image_tensor: torch.Tensor, vocab: Vocabulary, device: torch.device | str, beam_size: int = 3, max_len: int = 30, ) -> str: self.eval() beam_size = max(1, int(beam_size)) image_tensor = image_tensor.unsqueeze(0).to(device) memory = self.encoder(image_tensor) beams: List[Tuple[float, List[int]]] = [(0.0, [vocab.SOS])] completed: List[Tuple[float, List[int]]] = [] for _ in range(max_len): candidates: List[Tuple[float, List[int]]] = [] for score, seq in beams: if seq[-1] == vocab.EOS: completed.append((score, seq)) continue tgt = torch.tensor([seq], dtype=torch.long, device=device) log_prob = F.log_softmax(self.decoder(tgt, memory)[0, -1], dim=-1) values, indices = log_prob.topk(beam_size) for lp, idx in zip(values.tolist(), indices.tolist()): candidates.append((score + float(lp), seq + [int(idx)])) if not candidates: break candidates.sort(key=lambda x: x[0] / max(len(x[1]), 1), reverse=True) beams = candidates[:beam_size] best = max(completed + beams, key=lambda x: x[0] / max(len(x[1]), 1)) return vocab.decode(best[1][1:]) @dataclass class LoadedCustomModel: model: ImageCaptioningModel vocab: Vocabulary device: torch.device def caption(self, image: Image.Image, decoding: str = "Beam search", beam_size: int = 3, max_len: int = 30) -> str: image = image.convert("RGB") image_tensor = custom_transform(image) if decoding == "Greedy": return self.model.generate_greedy(image_tensor, self.vocab, self.device, max_len=max_len) return self.model.generate_beam( image_tensor, self.vocab, self.device, beam_size=beam_size, max_len=max_len, ) def load_custom_model( checkpoint_path: str | Path, vocab_path: str | Path, device: torch.device | str, ) -> LoadedCustomModel: checkpoint_path = Path(checkpoint_path) vocab_path = Path(vocab_path) device = torch.device(device) vocab = Vocabulary.from_json(vocab_path) model = ImageCaptioningModel(vocab_size=len(vocab)).to(device) checkpoint = torch.load(checkpoint_path, map_location=device) state_dict = checkpoint["model"] if isinstance(checkpoint, dict) and "model" in checkpoint else checkpoint model.load_state_dict(state_dict, strict=True) model.eval() return LoadedCustomModel(model=model, vocab=vocab, device=device)