"""image_captioning_model.py — ImageCaptioningModel.

Combines a visual backbone (CNN or ViT) with a Transformer encoder over the
image features and a Transformer decoder for autoregressive caption generation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from cnn_encoder import CNNEncoder
from vit_encoder import ViTEncoder
from transformer_encoder import TransformerEncoder
from transformer_decoder import TransformerDecoder
class ImageCaptioningModel(nn.Module):
    """Encoder-decoder image captioning model.

    A visual backbone (CNN or ViT) extracts image features, a Transformer
    encoder contextualizes them, and a Transformer decoder autoregressively
    generates caption tokens conditioned on the encoded image.

    Args:
        vocab_size: Size of the caption vocabulary.
        pad_id: Index of the ``<pad>`` token.
        d_model: Shared feature/embedding dimension across encoder and decoder.
        num_encoder_layers: Depth of the Transformer encoder.
        num_decoder_layers: Depth of the Transformer decoder.
        num_heads: Number of attention heads per layer.
        dim_ff: Feed-forward hidden size inside each Transformer layer.
        max_seq_len: Maximum caption length supported by the decoder.
        dropout: Dropout probability.
        freeze_backbone: If True, the visual backbone's weights are frozen.
        use_vit: If True, use a ViT backbone; otherwise a CNN backbone.
        max_img_positions: Maximum number of image-feature positions the
            Transformer encoder attends over (previously hard-coded to 200).
            The default preserves the original behavior.
    """

    def __init__(
        self,
        vocab_size,
        pad_id,
        d_model=512,
        num_encoder_layers=6,
        num_decoder_layers=6,
        num_heads=8,
        dim_ff=2048,
        max_seq_len=50,
        dropout=0.1,
        freeze_backbone=True,
        use_vit=False,
        max_img_positions=200,
    ):
        super().__init__()
        self.use_vit = use_vit
        # Select the visual backbone; both expose the same (d_model) feature API.
        if self.use_vit:
            self.encoder = ViTEncoder(d_model=d_model, freeze_backbone=freeze_backbone)
        else:
            self.encoder = CNNEncoder(d_model=d_model, freeze_backbone=freeze_backbone)
        self.transformer_encoder = TransformerEncoder(
            d_model=d_model,
            num_layers=num_encoder_layers,
            num_heads=num_heads,
            dim_ff=dim_ff,
            max_len=max_img_positions,
            dropout=dropout,
            use_vit=self.use_vit,
        )
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            pad_id=pad_id,
            d_model=d_model,
            num_layers=num_decoder_layers,
            num_heads=num_heads,
            dim_ff=dim_ff,
            max_len=max_seq_len,
            dropout=dropout,
        )
        self.d_model = d_model

    def generate_square_subsequent_mask(self, sz):
        """Return the causal (square subsequent) mask of size ``sz`` from the decoder."""
        return self.decoder.generate_square_subsequent_mask(sz)

    def unfreeze_encoder(self, unfreeze=True):
        """Unfreeze (or re-freeze, with ``unfreeze=False``) the visual backbone."""
        self.encoder.unfreeze_backbone(unfreeze)

    def encode_image(self, images):
        """Run images through the backbone and Transformer encoder.

        Args:
            images: Batch of preprocessed image tensors.

        Returns:
            Contextualized image features from the Transformer encoder.
        """
        img_features = self.encoder(images)
        return self.transformer_encoder(img_features)

    def forward(self, images, captions, tgt_mask=None, tgt_padding_mask=None):
        """Teacher-forced forward pass.

        Args:
            images: Batch of preprocessed image tensors.
            captions: Batch of target caption token indices.
            tgt_mask: Optional causal mask for the decoder self-attention.
            tgt_padding_mask: Optional padding mask for the captions.

        Returns:
            Decoder output logits over the vocabulary.
        """
        img_features = self.encode_image(images)
        return self.decoder(
            captions=captions,
            img_features=img_features,
            tgt_mask=tgt_mask,
            tgt_padding_mask=tgt_padding_mask,
        )

    @staticmethod
    def _decode_tokens(token_ids, vocab):
        """Convert token indices to a caption string, dropping special tokens.

        Unknown indices render as ``"<unk>"``; ``<bos>``/``<eos>``/``<pad>``
        are filtered out. Shared by greedy and beam-search decoding.
        """
        special = {"<bos>", "<eos>", "<pad>"}
        words = (vocab.idx2word.get(idx, "<unk>") for idx in token_ids)
        return " ".join(word for word in words if word not in special)

    def predict_caption_beam(self, image, vocab, beam_width=5, max_len=50, alpha=0.7, device="cpu"):
        """
        Generates a caption using beam search decoding.

        Args:
            image: Preprocessed image tensor of shape (1, 3, H, W).
            vocab: Vocabulary object with word2idx and idx2word mappings.
            beam_width: Number of candidate sequences to keep at each step.
            max_len: Maximum caption length.
            alpha: Length normalization penalty. Higher values favor longer captions.
            device: Device to run inference on.

        Returns:
            The highest-scoring caption as a string ("" if max_len == 0).
        """
        # NOTE(review): switches the model to eval mode and does not restore
        # the previous mode afterwards — callers training the model must call
        # .train() again themselves (original behavior, preserved here).
        self.eval()
        with torch.no_grad():
            img_features = self.encode_image(image)
            bos_idx = vocab.word2idx["<bos>"]
            eos_idx = vocab.word2idx["<eos>"]
            # Each beam: (cumulative log-probability, token index list)
            beams = [(0.0, [bos_idx])]
            completed = []
            for _ in range(max_len):
                candidates = []
                for score, seq in beams:
                    # A beam that already emitted <eos> is finished; move it
                    # to `completed` instead of expanding it further.
                    if seq[-1] == eos_idx:
                        completed.append((score, seq))
                        continue
                    tgt_tensor = torch.tensor(seq).unsqueeze(0).to(device)
                    tgt_mask = self.generate_square_subsequent_mask(len(seq)).to(device)
                    logits = self.decoder(
                        captions=tgt_tensor,
                        img_features=img_features,
                        tgt_mask=tgt_mask,
                        tgt_padding_mask=None,
                    )
                    # Log-probabilities for the next-token position only.
                    log_probs = F.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)
                    topk_log_probs, topk_indices = log_probs.topk(beam_width)
                    for log_p, idx in zip(topk_log_probs.tolist(), topk_indices.tolist()):
                        candidates.append((score + log_p, seq + [idx]))
                # Keep only the top `beam_width` expansions.
                candidates.sort(key=lambda c: c[0], reverse=True)
                beams = candidates[:beam_width]
                # Early stop: every beam has ended with <eos>.
                if not beams:
                    break
            # Any still-incomplete beams compete alongside the finished ones.
            completed.extend(beams)
            if not completed:
                # Defensive: only reachable when max_len == 0 (original code
                # raised IndexError here).
                return ""
            # Length-normalized score: score / (length ** alpha). Larger alpha
            # penalizes short captions more strongly.
            best_seq = max(completed, key=lambda c: c[0] / (len(c[1]) ** alpha))[1]
            return self._decode_tokens(best_seq, vocab)

    def predict_caption(self, image, vocab, max_len=50, device="cpu"):
        """
        Generates a caption using greedy decoding.

        Args:
            image: Preprocessed image tensor of shape (1, 3, H, W).
            vocab: Vocabulary object with word2idx and idx2word mappings.
            max_len: Maximum caption length.
            device: Device to run inference on.

        Returns:
            The generated caption as a string.
        """
        # NOTE(review): leaves the model in eval mode (original behavior).
        self.eval()
        with torch.no_grad():
            img_features = self.encode_image(image)
            bos_idx = vocab.word2idx["<bos>"]
            eos_idx = vocab.word2idx["<eos>"]
            tgt_indices = [bos_idx]
            for _ in range(max_len):
                tgt_tensor = torch.tensor(tgt_indices).unsqueeze(0).to(device)
                tgt_mask = self.generate_square_subsequent_mask(len(tgt_indices)).to(device)
                logits = self.decoder(
                    captions=tgt_tensor,
                    img_features=img_features,
                    tgt_mask=tgt_mask,
                    tgt_padding_mask=None,
                )
                # Greedy: take the argmax of the last position's logits.
                predicted_id = logits[:, -1, :].argmax(dim=-1).item()
                if predicted_id == eos_idx:
                    break
                tgt_indices.append(predicted_id)
            return self._decode_tokens(tgt_indices, vocab)