Spaces:
Sleeping
Sleeping
| import contextlib | |
| from dataclasses import dataclass | |
| from typing import List, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from torchvision import models | |
| from transformers import GPT2LMHeadModel, GPT2TokenizerFast | |
| from .config import TrainingConfig, get_device | |
@dataclass
class ImageCaptioningOutput:
    """
    Container for model outputs.

    Declared as a dataclass so it can be built with keyword arguments
    (``ImageCaptioningOutput(logits=..., loss=...)``); without the decorator
    the class has no generated ``__init__`` and that call raises ``TypeError``.

    Attributes
    ----------
    logits:
        Predicted token logits of shape (batch_size, seq_len, vocab_size),
        where seq_len is the number of text tokens (visual prefix tokens are removed).
    loss:
        Optional cross-entropy loss over caption tokens.
    """

    logits: Tensor
    loss: Optional[Tensor] = None
class EfficientNetB0Encoder(nn.Module):
    """
    EfficientNet-B0 image encoder using torchvision.

    The classification head is removed and only the pooled feature vector
    (dimension 1280) is returned.
    """

    def __init__(self, pretrained: bool = True) -> None:
        """
        Build the encoder from torchvision's EfficientNet-B0.

        Parameters
        ----------
        pretrained:
            When True, load the ImageNet weights; when False, use random
            initialisation.
        """
        super().__init__()

        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of the
        # ``weights=`` enum API. ``IMAGENET1K_V1`` is the checkpoint the legacy
        # ``pretrained=True`` flag maps to, so behaviour is unchanged. Older
        # torchvision releases have no weights enum and use the legacy keyword.
        weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            effnet = models.efficientnet_b0(weights=weights)
        else:
            effnet = models.efficientnet_b0(pretrained=pretrained)

        self.features = effnet.features
        self.avgpool = effnet.avgpool
        self.flatten = nn.Flatten()

        # in_features of the final classifier is the encoder output dim
        self.out_dim: int = effnet.classifier[1].in_features

    def forward(self, images: Tensor) -> Tensor:
        """
        Encode a batch of images into a pooled feature representation.

        Parameters
        ----------
        images:
            Tensor of shape (batch_size, 3, 224, 224).

        Returns
        -------
        Tensor of shape (batch_size, out_dim).
        """
        x = self.features(images)
        x = self.avgpool(x)
        x = self.flatten(x)  # (batch_size, out_dim)
        return x
class ImageCaptioningModel(nn.Module):
    """
    Image captioning model with an EfficientNet-B0 vision encoder and GPT-2 decoder.

    The model projects visual features into a sequence of prefix embeddings that
    are concatenated with GPT-2 token embeddings. GPT-2 then predicts caption tokens.
    """

    def __init__(
        self,
        training_cfg: Optional[TrainingConfig] = None,
        pretrained_encoder: bool = True,
    ) -> None:
        """
        Build encoder, tokenizer, decoder and projection, then move to the device.

        Parameters
        ----------
        training_cfg:
            Training configuration. A default ``TrainingConfig()`` is created
            when omitted. Only ``prefix_length`` is read here.
        pretrained_encoder:
            Forwarded to ``EfficientNetB0Encoder`` as ``pretrained``.

        Raises
        ------
        ValueError
            If the configured ``prefix_length`` is smaller than 1.
        """
        super().__init__()
        self.training_cfg = training_cfg or TrainingConfig()
        self.device: torch.device = get_device()

        # Vision encoder
        self.encoder = EfficientNetB0Encoder(pretrained=pretrained_encoder)

        # Text decoder (GPT-2 small)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        if self.tokenizer.pad_token is None:
            # Use EOS as pad token
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
        self.gpt2.config.pad_token_id = self.tokenizer.pad_token_id

        # Number of visual prefix tokens
        self.prefix_length: int = int(self.training_cfg.prefix_length)
        if self.prefix_length < 1:
            raise ValueError("prefix_length must be >= 1")

        # Project image features to a sequence of prefix token embeddings
        self.visual_projection = nn.Linear(
            self.encoder.out_dim,
            self.gpt2.config.n_embd * self.prefix_length,
        )

        # Shape diagnostics in ``forward`` are printed once per instance.
        self._printed_debug: bool = False
        self.to(self.device)

    # --------------------------------------------------------------------- #
    # Internal utilities
    # --------------------------------------------------------------------- #
    def encode_images(self, images: Tensor) -> Tensor:
        """
        Encode images and produce visual prefix embeddings.

        Parameters
        ----------
        images:
            4-D image batch; the assertion message expects (B, 3, H, W).

        Returns
        -------
        Tensor of shape (batch_size, prefix_length, hidden_size).
        """
        assert images.dim() == 4, f"Expected images of shape (B,3,H,W), got {images.shape}"
        img_features = self.encoder(images)  # (B, encoder_out_dim)
        batch_size = img_features.size(0)
        prefix_embeddings = self.visual_projection(img_features)
        # Split the flat projection into prefix_length vectors of GPT-2 width.
        prefix_embeddings = prefix_embeddings.view(
            batch_size,
            self.prefix_length,
            self.gpt2.config.n_embd,
        )
        return prefix_embeddings

    # --------------------------------------------------------------------- #
    # Forward (training)
    # --------------------------------------------------------------------- #
    def forward(
        self,
        images: Tensor,
        captions: Tensor,
        attention_mask: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
    ) -> ImageCaptioningOutput:
        """
        Forward pass for training.

        Parameters
        ----------
        images:
            Tensor of shape (batch_size, 3, 224, 224).
        captions:
            Token IDs of shape (batch_size, seq_len).
        attention_mask:
            Optional attention mask of shape (batch_size, seq_len). It is
            extended with ones for the visual prefix positions.
        labels:
            Optional target token IDs of shape (batch_size, seq_len).
            If provided, cross-entropy loss is computed, ignoring positions
            with label -100.

        Returns
        -------
        ImageCaptioningOutput
            ``logits`` of shape (batch_size, seq_len, vocab_size) covering only
            the text positions, and ``loss`` (``None`` when ``labels`` is omitted).

        Raises
        ------
        ValueError
            If ``labels`` does not have the same shape as ``captions``.
        """
        images = images.to(self.device)
        captions = captions.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        if labels is not None:
            labels = labels.to(self.device)

        batch_size, seq_len = captions.shape
        assert images.size(0) == batch_size, "Batch size mismatch between images and captions."

        prefix_embeddings = self.encode_images(images)  # (B, P, H)
        token_embeddings = self.gpt2.transformer.wte(captions)  # (B, T, H)
        inputs_embeds = torch.cat([prefix_embeddings, token_embeddings], dim=1)  # (B, P+T, H)

        if attention_mask is not None:
            # Prefix positions are always attended to.
            prefix_mask = torch.ones(
                batch_size,
                self.prefix_length,
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            extended_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
        else:
            extended_attention_mask = None

        if not self._printed_debug:
            print(f"[DEBUG] images shape: {images.shape}")
            print(f"[DEBUG] captions shape: {captions.shape}")
            print(f"[DEBUG] prefix_embeddings: {prefix_embeddings.shape}")
            print(f"[DEBUG] token_embeddings: {token_embeddings.shape}")
            print(f"[DEBUG] inputs_embeds shape: {inputs_embeds.shape}")
            if extended_attention_mask is not None:
                print(f"[DEBUG] attention_mask shape: {extended_attention_mask.shape}")
            self._printed_debug = True

        outputs = self.gpt2(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_attention_mask,
            use_cache=False,
            return_dict=True,
        )

        # Remove visual prefix positions from the logits so that
        # the returned logits only correspond to text tokens.
        logits = outputs.logits[:, self.prefix_length :, :]  # (B, T, V)

        loss: Optional[Tensor] = None
        if labels is not None:
            if labels.shape != (batch_size, seq_len):
                raise ValueError(
                    f"labels shape {labels.shape} does not match captions shape {(batch_size, seq_len)}"
                )
            # Shift logits and labels for next-token prediction: the logits at
            # text position t are scored against the label at position t + 1,
            # so the first caption token is only ever an input, never a target.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return ImageCaptioningOutput(logits=logits, loss=loss)

    # --------------------------------------------------------------------- #
    # Generation (inference)
    # --------------------------------------------------------------------- #
    def generate(
        self,
        images: Tensor,
        max_length: int = 50,
        num_beams: int = 1,
        temperature: float = 1.0,
        top_k: int = 0,
        eos_token_id: Optional[int] = None,
        length_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
    ) -> List[str]:
        """
        Generate captions for a batch of images using a simple beam search.

        Parameters
        ----------
        images:
            Image batch; only batch_size == 1 is supported.
        max_length:
            Maximum number of decoding steps (generated tokens).
        num_beams:
            Number of beams kept after every step; must be >= 1.
        temperature:
            Divisor applied to the next-token logits (floored at 1e-5).
        top_k:
            When > 0, every beam is expanded with its ``top_k`` best tokens and
            log-probabilities are renormalised over those tokens only. When 0,
            every beam is expanded with its ``num_beams`` best tokens under the
            full softmax.
        eos_token_id:
            End-of-sequence token id; defaults to the tokenizer's EOS id.
        length_penalty:
            Exponent ``p`` in ``score / len(tokens) ** p`` used for ranking
            beams; 0 (or None) disables length normalisation.
        repetition_penalty:
            Penalty applied to the logits of tokens already present in a beam
            (positive logits are divided by it, negative ones multiplied).
            1.0 disables it. Must be > 0.

        Returns
        -------
        A one-element list holding the whitespace-normalised caption.

        Raises
        ------
        ValueError
            If batch_size != 1, ``num_beams`` < 1 or ``repetition_penalty`` <= 0.

        Notes
        -----
        - For simplicity and clarity, this implementation currently supports
          batch_size == 1. A ValueError is raised otherwise.
        - The model is switched to eval mode and left there.
        """
        self.eval()
        images = images.to(self.device)
        batch_size = images.size(0)
        if batch_size != 1:
            raise ValueError(f"generate currently supports batch_size == 1, got {batch_size}")
        if num_beams < 1:
            raise ValueError(f"num_beams must be >= 1, got {num_beams}")
        if repetition_penalty <= 0:
            raise ValueError(f"repetition_penalty must be > 0, got {repetition_penalty}")

        # Explicit ``is None`` checks: token id 0 is a valid id and must not be
        # replaced by the default the way ``x or default`` would.
        if eos_token_id is None:
            eos_token_id = self.tokenizer.eos_token_id
        bos_token_id = self.tokenizer.bos_token_id
        if bos_token_id is None:
            bos_token_id = self.tokenizer.eos_token_id

        def _length_normalized_score(tokens: List[int], score: float) -> float:
            if length_penalty is None or length_penalty == 0.0:
                return score
            length = max(1, len(tokens))
            return score / (length ** length_penalty)

        def _is_finished(tokens: List[int]) -> bool:
            return bool(tokens) and tokens[-1] == eos_token_id

        # Each beam is (generated_token_ids, log_prob); BOS is never stored.
        beams: List[Tuple[List[int], float]] = [([], 0.0)]

        # Inference only: avoid building an autograd graph at every step.
        with torch.no_grad():
            prefix_embeddings = self.encode_images(images)  # (1, P, H)

            for _ in range(max_length):
                all_candidates: List[Tuple[List[int], float]] = []

                for seq, score in beams:
                    if _is_finished(seq):
                        # If already finished, keep as-is
                        all_candidates.append((seq, score))
                        continue

                    # Always feed BOS followed by the tokens generated so far.
                    # ``forward`` treats the first caption token as a given
                    # start token, so dropping BOS after the first step would
                    # decode from a context the decoder was not trained on.
                    input_ids = torch.tensor(
                        [[bos_token_id, *seq]],
                        device=self.device,
                        dtype=torch.long,
                    )  # (1, 1 + L)
                    token_embeddings = self.gpt2.transformer.wte(input_ids)  # (1, 1 + L, H)
                    inputs_embeds = torch.cat([prefix_embeddings, token_embeddings], dim=1)
                    attention_mask = torch.ones(
                        inputs_embeds.size()[:-1],
                        dtype=torch.long,
                        device=self.device,
                    )

                    outputs = self.gpt2(
                        inputs_embeds=inputs_embeds,
                        attention_mask=attention_mask,
                        use_cache=False,
                        return_dict=True,
                    )
                    next_logits = outputs.logits[:, -1, :]  # (1, V)

                    if repetition_penalty != 1.0 and seq:
                        # Penalise tokens this beam has already produced.
                        next_logits = next_logits.clone()
                        seen = torch.tensor(
                            sorted(set(seq)),
                            device=next_logits.device,
                            dtype=torch.long,
                        )
                        seen_logits = next_logits[0, seen]
                        next_logits[0, seen] = torch.where(
                            seen_logits > 0,
                            seen_logits / repetition_penalty,
                            seen_logits * repetition_penalty,
                        )

                    next_logits = next_logits / max(temperature, 1e-5)
                    vocab_size = next_logits.size(-1)

                    if top_k > 0:
                        # Clamp so an oversized top_k cannot crash torch.topk.
                        k = min(top_k, vocab_size)
                        topk_logits, topk_indices = torch.topk(next_logits, k, dim=-1)
                        topk_log_probs = torch.log_softmax(topk_logits, dim=-1)
                    else:
                        k = min(num_beams, vocab_size)
                        log_probs = torch.log_softmax(next_logits, dim=-1)
                        topk_log_probs, topk_indices = torch.topk(log_probs, k, dim=-1)

                    # One host transfer per tensor instead of one per element.
                    for token_id, token_log_prob in zip(
                        topk_indices[0].tolist(),
                        topk_log_probs[0].tolist(),
                    ):
                        all_candidates.append((seq + [token_id], score + token_log_prob))

                # Select best beams. With num_beams=1 and length_penalty=0 this
                # reduces to simple greedy decoding, which is fully deterministic.
                beams = sorted(
                    all_candidates,
                    key=lambda x: _length_normalized_score(x[0], x[1]),
                    reverse=True,
                )[:num_beams]

                # If all beams ended with EOS, stop early
                if all(_is_finished(seq) for seq, _ in beams):
                    break

        best_seq, _best_score = max(
            beams,
            key=lambda x: _length_normalized_score(x[0], x[1]),
        )

        # Truncate at EOS if present
        if eos_token_id in best_seq:
            best_seq = best_seq[: best_seq.index(eos_token_id)]

        caption = self.tokenizer.decode(best_seq, skip_special_tokens=True)
        # Normalize whitespace so the final caption is a single, clean string.
        caption = " ".join(caption.strip().split())
        return [caption]

    # --------------------------------------------------------------------- #
    # Dummy test helper
    # --------------------------------------------------------------------- #
    def test_dummy(self) -> None:
        """
        Run a dummy forward pass to verify the model works end-to-end.

        Feeds 2 random images and 2 random captions of length 20 through the
        model (under autocast when the device is CUDA) and asserts that the
        output logits have shape (2, 20, vocab_size). The model is switched to
        eval mode and left there.
        """
        self.eval()
        vocab_size = int(self.gpt2.config.vocab_size)

        dummy_images = torch.randn(2, 3, 224, 224, device=self.device)
        dummy_captions = torch.randint(0, vocab_size, (2, 20), device=self.device)

        # ``torch.cuda.amp.autocast()`` is deprecated; ``torch.autocast`` is the
        # device-generic replacement. No mixed precision off-GPU.
        if self.device.type == "cuda":
            amp_context = torch.autocast(device_type="cuda")
        else:
            amp_context = contextlib.nullcontext()

        with torch.no_grad(), amp_context:
            outputs = self(dummy_images, dummy_captions)

        logits = outputs.logits
        assert logits.shape == (2, 20, vocab_size), (
            f"Output shape mismatch: expected (2, 20, {vocab_size}), "
            f"got {tuple(logits.shape)}"
        )
        print("✓ Model architecture verified successfully!")