| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
| from cnn_encoder import CNNEncoder
|
| from vit_encoder import ViTEncoder
|
| from transformer_encoder import TransformerEncoder
|
| from transformer_decoder import TransformerDecoder
|
|
|
|
|
class ImageCaptioningModel(nn.Module):
    """Encoder-decoder transformer for image captioning.

    An image backbone (CNN or ViT, selected by ``use_vit``) extracts visual
    features, a transformer encoder refines them, and a transformer decoder
    attends over them to generate caption tokens autoregressively.
    """

    # Tokens stripped from generated sequences before joining into a caption.
    _SPECIAL_TOKENS = ("<bos>", "<eos>", "<pad>")

    def __init__(
        self,
        vocab_size,
        pad_id,
        d_model=512,
        num_encoder_layers=6,
        num_decoder_layers=6,
        num_heads=8,
        dim_ff=2048,
        max_seq_len=50,
        dropout=0.1,
        freeze_backbone=True,
        use_vit=False,
    ):
        """
        Args:
            vocab_size: Size of the caption vocabulary.
            pad_id: Index of the padding token in the vocabulary.
            d_model: Shared model/embedding dimension across all components.
            num_encoder_layers: Number of transformer encoder layers.
            num_decoder_layers: Number of transformer decoder layers.
            num_heads: Number of attention heads.
            dim_ff: Hidden size of the transformer feed-forward sublayers.
            max_seq_len: Maximum caption length supported by the decoder.
            dropout: Dropout probability used throughout.
            freeze_backbone: If True, the image backbone's weights start frozen.
            use_vit: If True, use the ViT backbone; otherwise the CNN backbone.
        """
        super().__init__()

        self.use_vit = use_vit

        # Visual feature extractor: ViT or CNN backbone, both projecting
        # into the shared d_model space.
        if self.use_vit:
            self.encoder = ViTEncoder(d_model=d_model, freeze_backbone=freeze_backbone)
        else:
            self.encoder = CNNEncoder(d_model=d_model, freeze_backbone=freeze_backbone)

        self.transformer_encoder = TransformerEncoder(
            d_model=d_model,
            num_layers=num_encoder_layers,
            num_heads=num_heads,
            dim_ff=dim_ff,
            # NOTE(review): 200 presumably bounds the number of image feature
            # positions (patches / grid cells) — confirm against the encoders.
            max_len=200,
            dropout=dropout,
            use_vit=self.use_vit,
        )
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            pad_id=pad_id,
            d_model=d_model,
            num_layers=num_decoder_layers,
            num_heads=num_heads,
            dim_ff=dim_ff,
            max_len=max_seq_len,
            dropout=dropout,
        )
        self.d_model = d_model

    def generate_square_subsequent_mask(self, sz):
        """Return a causal (look-ahead) mask of size ``sz`` via the decoder."""
        return self.decoder.generate_square_subsequent_mask(sz)

    def unfreeze_encoder(self, unfreeze=True):
        """(Un)freeze the image backbone's weights (delegates to the encoder)."""
        self.encoder.unfreeze_backbone(unfreeze)

    def encode_image(self, images):
        """Extract backbone features from ``images`` and refine them with the
        transformer encoder. Returns the encoder's memory tensor."""
        img_features = self.encoder(images)
        return self.transformer_encoder(img_features)

    def forward(self, images, captions, tgt_mask=None, tgt_padding_mask=None):
        """Teacher-forced forward pass.

        Args:
            images: Batch of input images for the backbone.
            captions: Batch of target caption token ids.
            tgt_mask: Optional causal mask for the decoder self-attention.
            tgt_padding_mask: Optional padding mask for the captions.

        Returns:
            Decoder output logits over the vocabulary.
        """
        img_features = self.encode_image(images)
        return self.decoder(
            captions=captions,
            img_features=img_features,
            tgt_mask=tgt_mask,
            tgt_padding_mask=tgt_padding_mask,
        )

    def _tokens_to_caption(self, indices, vocab):
        """Map token ids to words, drop special tokens, join into a string.

        Shared post-processing for both greedy and beam-search decoding.
        Unknown ids map to ``"<unk>"``.
        """
        words = (vocab.idx2word.get(idx, "<unk>") for idx in indices)
        return " ".join(w for w in words if w not in self._SPECIAL_TOKENS)

    def predict_caption_beam(self, image, vocab, beam_width=5, max_len=50, alpha=0.7, device="cpu"):
        """
        Generates a caption using beam search decoding.

        Args:
            image: Preprocessed image tensor of shape (1, 3, H, W).
            vocab: Vocabulary object with word2idx and idx2word mappings.
            beam_width: Number of candidate sequences to keep at each step.
            max_len: Maximum caption length.
            alpha: Length normalization penalty. Higher values favor longer captions.
            device: Device to run inference on.

        Returns:
            The highest-scoring caption as a string.
        """
        # Remember the current mode so a call during training does not
        # permanently leave the model in eval mode (dropout disabled).
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                img_features = self.encode_image(image)

                bos_idx = vocab.word2idx["<bos>"]
                eos_idx = vocab.word2idx["<eos>"]

                # Each beam is (cumulative log-probability, token id sequence).
                beams = [(0.0, [bos_idx])]
                completed = []

                for _ in range(max_len):
                    candidates = []

                    for score, seq in beams:
                        # A beam that already emitted <eos> stops expanding.
                        if seq[-1] == eos_idx:
                            completed.append((score, seq))
                            continue

                        tgt_tensor = torch.tensor(
                            seq, dtype=torch.long, device=device
                        ).unsqueeze(0)
                        tgt_mask = self.generate_square_subsequent_mask(len(seq)).to(device)

                        logits = self.decoder(
                            captions=tgt_tensor,
                            img_features=img_features,
                            tgt_mask=tgt_mask,
                            tgt_padding_mask=None,
                        )

                        # Log-distribution over the next token only.
                        log_probs = F.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)
                        topk_log_probs, topk_indices = log_probs.topk(beam_width)

                        for log_p, idx in zip(topk_log_probs.tolist(), topk_indices.tolist()):
                            candidates.append((score + log_p, seq + [idx]))

                    # Keep only the best beam_width partial hypotheses.
                    candidates.sort(key=lambda c: c[0], reverse=True)
                    beams = candidates[:beam_width]

                    # All beams finished: nothing left to expand.
                    if not beams:
                        break

                # Unfinished beams still compete after length normalization.
                completed.extend(beams)

                # Length-normalized score: dividing by len**alpha counteracts
                # the bias toward short captions from summing log-probs.
                completed.sort(
                    key=lambda c: c[0] / (len(c[1]) ** alpha),
                    reverse=True,
                )
                best_seq = completed[0][1]

                return self._tokens_to_caption(best_seq, vocab)
        finally:
            self.train(was_training)

    def predict_caption(self, image, vocab, max_len=50, device="cpu"):
        """
        Generates a caption using greedy decoding.

        Args:
            image: Preprocessed image tensor of shape (1, 3, H, W).
            vocab: Vocabulary object with word2idx and idx2word mappings.
            max_len: Maximum caption length.
            device: Device to run inference on.

        Returns:
            The generated caption as a string.
        """
        # Remember the current mode so a call during training does not
        # permanently leave the model in eval mode (dropout disabled).
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                img_features = self.encode_image(image)

                start_token_idx = vocab.word2idx["<bos>"]
                end_token_idx = vocab.word2idx["<eos>"]

                tgt_indices = [start_token_idx]
                for _ in range(max_len):
                    tgt_tensor = torch.tensor(
                        tgt_indices, dtype=torch.long, device=device
                    ).unsqueeze(0)
                    tgt_mask = self.generate_square_subsequent_mask(len(tgt_indices)).to(device)

                    logits = self.decoder(
                        captions=tgt_tensor,
                        img_features=img_features,
                        tgt_mask=tgt_mask,
                        tgt_padding_mask=None,
                    )

                    # Greedy choice: highest-probability next token.
                    predicted_id = logits[:, -1, :].argmax(dim=-1).item()
                    if predicted_id == end_token_idx:
                        break
                    tgt_indices.append(predicted_id)

                return self._tokens_to_caption(tgt_indices, vocab)
        finally:
            self.train(was_training)
|
|
|