import torch
import torch.nn as nn
import torch.nn.functional as F
from cnn_encoder import CNNEncoder
from vit_encoder import ViTEncoder
from transformer_encoder import TransformerEncoder
from transformer_decoder import TransformerDecoder
class ImageCaptioningModel(nn.Module):
    """Encoder-decoder image captioning model.

    A visual backbone (CNN or ViT) extracts image features, a Transformer
    encoder refines them, and a Transformer decoder autoregressively
    generates caption tokens conditioned on those features.
    """

    def __init__(
        self,
        vocab_size,
        pad_id,
        d_model=512,
        num_encoder_layers=6,
        num_decoder_layers=6,
        num_heads=8,
        dim_ff=2048,
        max_seq_len=50,
        dropout=0.1,
        freeze_backbone=True,
        use_vit=False,
    ):
        super().__init__()
        self.use_vit = use_vit
        # Choose the visual backbone; both project their features to d_model.
        if self.use_vit:
            self.encoder = ViTEncoder(d_model=d_model, freeze_backbone=freeze_backbone)
        else:
            self.encoder = CNNEncoder(d_model=d_model, freeze_backbone=freeze_backbone)
        self.transformer_encoder = TransformerEncoder(
            d_model=d_model,
            num_layers=num_encoder_layers,
            num_heads=num_heads,
            dim_ff=dim_ff,
            max_len=200,  # upper bound on image feature positions (backbone-dependent)
            dropout=dropout,
            use_vit=self.use_vit,
        )
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            pad_id=pad_id,
            d_model=d_model,
            num_layers=num_decoder_layers,
            num_heads=num_heads,
            dim_ff=dim_ff,
            max_len=max_seq_len,
            dropout=dropout,
        )
        self.d_model = d_model

    def generate_square_subsequent_mask(self, sz):
        """Return a (sz, sz) causal mask for autoregressive decoding."""
        return self.decoder.generate_square_subsequent_mask(sz)

    def unfreeze_encoder(self, unfreeze=True):
        """Enable (or disable) gradient updates for the visual backbone."""
        self.encoder.unfreeze_backbone(unfreeze)

    def encode_image(self, images):
        """Extract backbone features and refine them with the transformer encoder."""
        img_features = self.encoder(images)
        return self.transformer_encoder(img_features)

    def forward(self, images, captions, tgt_mask=None, tgt_padding_mask=None):
        """Teacher-forced training pass.

        Args:
            images: Batch of preprocessed image tensors.
            captions: Batch of target token indices.
            tgt_mask: Optional causal mask for the decoder.
            tgt_padding_mask: Optional padding mask for the captions.

        Returns:
            Per-token vocabulary logits from the decoder.
        """
        img_features = self.encode_image(images)
        return self.decoder(
            captions=captions,
            img_features=img_features,
            tgt_mask=tgt_mask,
            tgt_padding_mask=tgt_padding_mask,
        )

    @staticmethod
    def _indices_to_caption(indices, vocab):
        """Convert token indices to a caption string, dropping special tokens.

        Unknown indices map to "<unk>". Shared by greedy and beam-search
        decoding so both produce identically formatted output.
        """
        special = {"<bos>", "<eos>", "<pad>"}
        words = (vocab.idx2word.get(idx, "<unk>") for idx in indices)
        return " ".join(w for w in words if w not in special)

    def _decode_step(self, seq, img_features, device):
        """Run the decoder on one partial sequence; return last-position logits."""
        tgt_tensor = torch.tensor(seq).unsqueeze(0).to(device)
        tgt_mask = self.generate_square_subsequent_mask(len(seq)).to(device)
        logits = self.decoder(
            captions=tgt_tensor,
            img_features=img_features,
            tgt_mask=tgt_mask,
            tgt_padding_mask=None,
        )
        return logits[:, -1, :]

    def predict_caption_beam(self, image, vocab, beam_width=5, max_len=50, alpha=0.7, device="cpu"):
        """
        Generates a caption using beam search decoding.

        Args:
            image: Preprocessed image tensor of shape (1, 3, H, W).
            vocab: Vocabulary object with word2idx and idx2word mappings.
            beam_width: Number of candidate sequences to keep at each step.
            max_len: Maximum caption length.
            alpha: Length normalization penalty. Higher values favor longer captions.
            device: Device to run inference on.

        Returns:
            The highest-scoring caption as a string.
        """
        self.eval()
        with torch.no_grad():
            img_features = self.encode_image(image)
            bos_idx = vocab.word2idx["<bos>"]
            eos_idx = vocab.word2idx["<eos>"]
            # Each beam: (cumulative log-probability, token index list)
            beams = [(0.0, [bos_idx])]
            completed = []
            for _ in range(max_len):
                candidates = []
                for score, seq in beams:
                    # A beam that already emitted <eos> is finished; retire it.
                    if seq[-1] == eos_idx:
                        completed.append((score, seq))
                        continue
                    last_logits = self._decode_step(seq, img_features, device)
                    log_probs = F.log_softmax(last_logits, dim=-1).squeeze(0)
                    topk_log_probs, topk_indices = log_probs.topk(beam_width)
                    for log_p, idx in zip(topk_log_probs.tolist(), topk_indices.tolist()):
                        candidates.append((score + log_p, seq + [idx]))
                # Keep the best beam_width expansions by cumulative score.
                candidates.sort(key=lambda c: c[0], reverse=True)
                beams = candidates[:beam_width]
                if not beams:  # every beam has ended
                    break
            # Any beams still open at max_len compete with the finished ones.
            completed.extend(beams)
            # Length-normalized score: score / length**alpha favors longer
            # captions as alpha grows.
            completed.sort(key=lambda c: c[0] / (len(c[1]) ** alpha), reverse=True)
            best_seq = completed[0][1]
            return self._indices_to_caption(best_seq, vocab)

    def predict_caption(self, image, vocab, max_len=50, device="cpu"):
        """
        Generates a caption using greedy decoding.

        Args:
            image: Preprocessed image tensor of shape (1, 3, H, W).
            vocab: Vocabulary object with word2idx and idx2word mappings.
            max_len: Maximum caption length.
            device: Device to run inference on.

        Returns:
            The generated caption as a string.
        """
        self.eval()
        with torch.no_grad():
            img_features = self.encode_image(image)
            bos_idx = vocab.word2idx["<bos>"]
            eos_idx = vocab.word2idx["<eos>"]
            indices = [bos_idx]
            for _ in range(max_len):
                last_logits = self._decode_step(indices, img_features, device)
                predicted_id = last_logits.argmax(dim=-1).item()
                if predicted_id == eos_idx:
                    break
                indices.append(predicted_id)
            return self._indices_to_caption(indices, vocab)