import torch
import torch.nn as nn
import torch.nn.functional as F

from cnn_encoder import CNNEncoder
from vit_encoder import ViTEncoder
from transformer_encoder import TransformerEncoder
from transformer_decoder import TransformerDecoder


class ImageCaptioningModel(nn.Module):
    """Encoder-decoder image captioning model.

    A visual backbone (CNN or ViT) extracts image features, a Transformer
    encoder refines them, and a Transformer decoder autoregressively
    generates caption tokens conditioned on those features.
    """

    # Special vocabulary tokens used by the decoding loops.
    # NOTE(review): these literals were garbled (empty strings) in the
    # original source, which made the BOS and EOS lookups resolve to the
    # SAME vocab entry and broke decoding termination. The conventional
    # names below must be confirmed against the project's Vocabulary.
    BOS_TOKEN = "<start>"
    EOS_TOKEN = "<end>"
    PAD_TOKEN = "<pad>"

    def __init__(
        self,
        vocab_size,
        pad_id,
        d_model=512,
        num_encoder_layers=6,
        num_decoder_layers=6,
        num_heads=8,
        dim_ff=2048,
        max_seq_len=50,
        dropout=0.1,
        freeze_backbone=True,
        use_vit=False,
    ):
        """Build the backbone, feature encoder, and caption decoder.

        Args:
            vocab_size: Size of the caption vocabulary.
            pad_id: Index of the padding token (forwarded to the decoder).
            d_model: Shared embedding/feature dimension.
            num_encoder_layers: Transformer-encoder depth.
            num_decoder_layers: Transformer-decoder depth.
            num_heads: Attention heads per layer.
            dim_ff: Feed-forward hidden size inside each Transformer layer.
            max_seq_len: Maximum caption length handled by the decoder.
            dropout: Dropout probability for the Transformer stacks.
            freeze_backbone: If True, the visual backbone starts frozen.
            use_vit: Select ViT (True) or CNN (False) as the backbone.
        """
        super().__init__()
        self.use_vit = use_vit

        # Visual backbone: projects images to d_model-dim feature sequences.
        if self.use_vit:
            self.encoder = ViTEncoder(d_model=d_model, freeze_backbone=freeze_backbone)
        else:
            self.encoder = CNNEncoder(d_model=d_model, freeze_backbone=freeze_backbone)

        # max_len=200 bounds the number of image-feature positions, not the
        # caption length (which is max_seq_len below).
        self.transformer_encoder = TransformerEncoder(
            d_model=d_model,
            num_layers=num_encoder_layers,
            num_heads=num_heads,
            dim_ff=dim_ff,
            max_len=200,
            dropout=dropout,
            use_vit=self.use_vit,
        )

        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            pad_id=pad_id,
            d_model=d_model,
            num_layers=num_decoder_layers,
            num_heads=num_heads,
            dim_ff=dim_ff,
            max_len=max_seq_len,
            dropout=dropout,
        )

        self.d_model = d_model

    def generate_square_subsequent_mask(self, sz):
        """Return a causal (look-ahead) mask of size ``sz`` from the decoder."""
        return self.decoder.generate_square_subsequent_mask(sz)

    def unfreeze_encoder(self, unfreeze=True):
        """(Un)freeze the visual backbone's parameters for fine-tuning."""
        self.encoder.unfreeze_backbone(unfreeze)

    def encode_image(self, images):
        """Extract backbone features and refine them with the Transformer encoder."""
        img_features = self.encoder(images)
        return self.transformer_encoder(img_features)

    def forward(self, images, captions, tgt_mask=None, tgt_padding_mask=None):
        """Teacher-forced forward pass: encode images, decode captions.

        Returns the decoder's logits over the vocabulary for each position.
        """
        img_features = self.encode_image(images)
        return self.decoder(
            captions=captions,
            img_features=img_features,
            tgt_mask=tgt_mask,
            tgt_padding_mask=tgt_padding_mask,
        )

    def _decode_step(self, token_indices, img_features, device):
        """Run one decoder step on a partial caption.

        Args:
            token_indices: List of token ids generated so far (starts with BOS).
            img_features: Encoded image features from :meth:`encode_image`.
            device: Device to run inference on.

        Returns:
            1-D tensor of log-probabilities over the vocabulary for the
            next-token position.
        """
        tgt_tensor = torch.tensor(token_indices).unsqueeze(0).to(device)
        tgt_mask = self.generate_square_subsequent_mask(len(token_indices)).to(device)
        logits = self.decoder(
            captions=tgt_tensor,
            img_features=img_features,
            tgt_mask=tgt_mask,
            tgt_padding_mask=None,
        )
        # Only the prediction for the last position matters during generation.
        return F.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)

    def _indices_to_caption(self, indices, vocab):
        """Convert token ids to a caption string, dropping special tokens.

        Ids missing from ``idx2word`` map to "" and are filtered out, which
        preserves the original behavior of skipping unknown indices.
        """
        skip = {self.BOS_TOKEN, self.EOS_TOKEN, self.PAD_TOKEN, ""}
        words = (vocab.idx2word.get(idx, "") for idx in indices)
        return " ".join(word for word in words if word not in skip)

    def predict_caption_beam(self, image, vocab, beam_width=5, max_len=50, alpha=0.7, device="cpu"):
        """Generate a caption using beam search decoding.

        Args:
            image: Preprocessed image tensor of shape (1, 3, H, W).
            vocab: Vocabulary object with word2idx and idx2word mappings.
            beam_width: Number of candidate sequences kept at each step.
            max_len: Maximum caption length.
            alpha: Length-normalization penalty; higher favors longer captions.
            device: Device to run inference on.

        Returns:
            The highest-scoring caption as a string.
        """
        # Remember the caller's mode so inference doesn't permanently
        # switch the module to eval (the original leaked this side effect).
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                img_features = self.encode_image(image)

                bos_idx = vocab.word2idx[self.BOS_TOKEN]
                eos_idx = vocab.word2idx[self.EOS_TOKEN]

                # Each beam: (cumulative log-probability, token id list)
                beams = [(0.0, [bos_idx])]
                completed = []

                for _ in range(max_len):
                    candidates = []
                    for score, seq in beams:
                        # A beam that already emitted EOS is finished.
                        if seq[-1] == eos_idx:
                            completed.append((score, seq))
                            continue

                        log_probs = self._decode_step(seq, img_features, device)
                        topk_log_probs, topk_indices = log_probs.topk(beam_width)
                        for log_p, idx in zip(topk_log_probs.tolist(), topk_indices.tolist()):
                            candidates.append((score + log_p, seq + [idx]))

                    # Keep the top beam_width partial hypotheses by raw score.
                    candidates.sort(key=lambda x: x[0], reverse=True)
                    beams = candidates[:beam_width]

                    # Early stop: every beam has terminated.
                    if not beams:
                        break

                # Any beams still open at max_len count as finished.
                completed.extend(beams)

                # Length-normalized score: score / (length ** alpha) — prevents
                # the search from always preferring the shortest hypothesis.
                completed.sort(
                    key=lambda x: x[0] / (len(x[1]) ** alpha),
                    reverse=True,
                )
                best_seq = completed[0][1]

                return self._indices_to_caption(best_seq, vocab)
        finally:
            if was_training:
                self.train()

    def predict_caption(self, image, vocab, max_len=50, device="cpu"):
        """Generate a caption using greedy decoding.

        Args:
            image: Preprocessed image tensor of shape (1, 3, H, W).
            vocab: Vocabulary object with word2idx and idx2word mappings.
            max_len: Maximum caption length.
            device: Device to run inference on.

        Returns:
            The generated caption as a string.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                img_features = self.encode_image(image)

                start_token_idx = vocab.word2idx[self.BOS_TOKEN]
                end_token_idx = vocab.word2idx[self.EOS_TOKEN]

                tgt_indices = [start_token_idx]
                for _ in range(max_len):
                    log_probs = self._decode_step(tgt_indices, img_features, device)
                    predicted_id = log_probs.argmax(dim=-1).item()
                    # Stop without appending EOS, matching the original loop.
                    if predicted_id == end_token_idx:
                        break
                    tgt_indices.append(predicted_id)

                return self._indices_to_caption(tgt_indices, vocab)
        finally:
            if was_training:
                self.train()