# image-caption-vi / model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class EncoderCNN(nn.Module):
    """Frozen ResNet-50 backbone that returns spatial feature maps for attention."""

    def __init__(self):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)  # ImageNet weights (legacy torchvision API)
        # Freeze the backbone; only the decoder is trained.
        for param in resnet.parameters():
            param.requires_grad_(False)
        # Drop the average-pooling and fully connected layers to keep the 7x7 spatial grid.
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

    def forward(self, images):
        features = self.resnet(images)                                      # (B, 2048, 7, 7)
        features = features.permute(0, 2, 3, 1)                             # (B, 7, 7, 2048)
        features = features.view(features.size(0), -1, features.size(-1))   # (B, 49, 2048)
        return features


class Attention(nn.Module):
    """Soft (Bahdanau-style) attention over the encoder's spatial features."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.W = nn.Linear(decoder_dim, attention_dim)  # projects the decoder hidden state
        self.U = nn.Linear(encoder_dim, attention_dim)  # projects the encoder features
        self.A = nn.Linear(attention_dim, 1)            # scores each spatial location

    def forward(self, features, hidden_state):
        u_hs = self.U(features)                                # (B, 49, attention_dim)
        w_ah = self.W(hidden_state)                            # (B, attention_dim)
        combined = torch.tanh(u_hs + w_ah.unsqueeze(1))        # (B, 49, attention_dim)
        e = self.A(combined).squeeze(2)                        # (B, 49)
        alpha = F.softmax(e, dim=1)                            # attention weights over the 49 locations
        context = (features * alpha.unsqueeze(2)).sum(dim=1)   # (B, encoder_dim)
        return alpha, context


class DecoderRNN(nn.Module):
    """Attention LSTM decoder; word embeddings are initialised from a pretrained tensor."""

    def __init__(self, embed_tensor, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(embed_tensor, freeze=False)
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.lstm_cell = nn.LSTMCell(embed_tensor.size(1) + encoder_dim, decoder_dim)
        self.fcn = nn.Linear(decoder_dim, vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        # Teacher forcing: at step s the ground-truth token s is fed to the LSTM cell.
        embeds = self.embedding(captions)
        h, c = self.init_hidden_state(features)
        seq_len = captions.size(1) - 1
        batch_size = captions.size(0)
        preds = torch.zeros(batch_size, seq_len, self.fcn.out_features).to(features.device)
        for s in range(seq_len):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            preds[:, s] = self.fcn(self.drop(h))
        return preds

    def generate_caption(self, features, vocab, max_len=20):
        # Greedy decoding: start from <SOS> and feed the argmax token back in until <EOS> or max_len.
        h, c = self.init_hidden_state(features)
        word = torch.tensor([vocab.stoi["<SOS>"]]).to(features.device)
        caption = []
        for _ in range(max_len):
            embed = self.embedding(word)                     # (1, embed_dim)
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embed, context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            predicted = output.argmax(1)
            word = predicted
            word_str = vocab.itos[predicted.item()]
            if word_str == "<EOS>":
                break
            caption.append(word_str)
        return " ".join(caption)

    def init_hidden_state(self, features):
        # Initialise the LSTM state from the mean of the encoder features.
        mean_features = features.mean(dim=1)
        h = self.init_h(mean_features)
        c = self.init_c(mean_features)
        return h, c


class EncoderDecoder(nn.Module):
    """Full captioning model: frozen ResNet-50 encoder plus attention LSTM decoder."""

    def __init__(self, embed_tensor, vocab, attention_dim=256, encoder_dim=2048, decoder_dim=512):
        super().__init__()
        self.encoder = EncoderCNN()
        self.decoder = DecoderRNN(embed_tensor, len(vocab), attention_dim, encoder_dim, decoder_dim)
        self.vocab = vocab

    def forward(self, images, captions):
        features = self.encoder(images)
        return self.decoder(features, captions)

    def predict_caption(self, image_tensor, max_len=20):
        # Inference helper for a single image tensor of shape (3, H, W).
        self.eval()
        with torch.no_grad():
            features = self.encoder(image_tensor.unsqueeze(0).to(next(self.parameters()).device))
            return self.decoder.generate_caption(features, self.vocab, max_len)
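

# Minimal usage sketch, assuming a toy vocabulary with stoi/itos lookups and a random
# 300-dim embedding tensor in place of the project's pretrained word vectors; ToyVocab
# is a hypothetical stand-in, and running this downloads the ImageNet ResNet-50 weights.
if __name__ == "__main__":
    class ToyVocab:
        # Hypothetical vocabulary object with the stoi/itos interface the decoder expects.
        def __init__(self, words):
            self.itos = {i: w for i, w in enumerate(words)}
            self.stoi = {w: i for i, w in self.itos.items()}

        def __len__(self):
            return len(self.itos)

    vocab = ToyVocab(["<PAD>", "<SOS>", "<EOS>", "<UNK>", "a", "dog", "runs"])
    embed_tensor = torch.randn(len(vocab), 300)        # random embeddings stand in for pretrained vectors

    model = EncoderDecoder(embed_tensor, vocab)
    images = torch.randn(2, 3, 224, 224)               # dummy batch of ImageNet-sized images
    captions = torch.randint(0, len(vocab), (2, 12))   # dummy token ids, shape (batch, seq_len)

    preds = model(images, captions)                    # (2, 11, vocab_size): logits per target position
    print(preds.shape)
    print(model.predict_caption(images[0]))            # greedy caption for one image (nonsense with random weights)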