"""Image captioning with a frozen ResNet-50 encoder and a soft-attention
LSTM decoder (Show, Attend and Tell style)."""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class EncoderCNN(nn.Module):
    """Frozen ResNet-50 backbone that returns a grid of spatial features."""

    def __init__(self):
        super().__init__()
        # `pretrained=True` is deprecated in recent torchvision; the weights
        # enum is the current way to request ImageNet weights.
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        for param in resnet.parameters():
            param.requires_grad_(False)
        # Drop the average-pool and FC head to keep the 7x7 feature map.
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

    def forward(self, images):
        features = self.resnet(images)            # (B, 2048, 7, 7)
        features = features.permute(0, 2, 3, 1)   # (B, 7, 7, 2048)
        # permute() returns a non-contiguous tensor, so reshape() is needed
        # here; view() would raise a runtime error.
        features = features.reshape(features.size(0), -1, features.size(-1))  # (B, 49, 2048)
        return features


class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over the encoder's spatial grid."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)
        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):
        u_hs = self.U(features)                               # (B, 49, attention_dim)
        w_ah = self.W(hidden_state)                           # (B, attention_dim)
        combined = torch.tanh(u_hs + w_ah.unsqueeze(1))       # (B, 49, attention_dim)
        e = self.A(combined).squeeze(2)                       # (B, 49)
        alpha = F.softmax(e, dim=1)                           # attention weights
        context = (features * alpha.unsqueeze(2)).sum(dim=1)  # (B, encoder_dim)
        return alpha, context


class DecoderRNN(nn.Module):
    """LSTM decoder that attends over the encoder features at every step."""

    def __init__(self, embed_tensor, vocab_size, attention_dim,
                 encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()
        # Initialise from pretrained word embeddings but keep them trainable.
        self.embedding = nn.Embedding.from_pretrained(embed_tensor, freeze=False)
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.lstm_cell = nn.LSTMCell(embed_tensor.size(1) + encoder_dim, decoder_dim)
        self.fcn = nn.Linear(decoder_dim, vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        """Teacher-forced decoding: feed the ground-truth token at each step."""
        embeds = self.embedding(captions)
        h, c = self.init_hidden_state(features)
        seq_len = captions.size(1) - 1  # predict every token after <SOS>
        batch_size = captions.size(0)
        preds = torch.zeros(batch_size, seq_len,
                            self.fcn.out_features, device=features.device)
        for s in range(seq_len):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            preds[:, s] = self.fcn(self.drop(h))
        return preds

    def generate_caption(self, features, vocab, max_len=20):
        """Greedy decoding: start from <SOS>, stop at <EOS> or max_len."""
        h, c = self.init_hidden_state(features)
        word = torch.tensor([vocab.stoi["<SOS>"]], device=features.device)
        caption = []
        for _ in range(max_len):
            embed = self.embedding(word)  # (1, embed_dim)
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embed, context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            predicted = output.argmax(dim=1)
            word = predicted
            word_str = vocab.itos[predicted.item()]
            if word_str == "<EOS>":
                break
            caption.append(word_str)
        return " ".join(caption)

    def init_hidden_state(self, features):
        # Initialise h and c from the mean of the spatial features.
        mean_features = features.mean(dim=1)
        h = self.init_h(mean_features)
        c = self.init_c(mean_features)
        return h, c


class EncoderDecoder(nn.Module):
    """Convenience wrapper tying the CNN encoder and attention decoder together."""

    def __init__(self, embed_tensor, vocab, attention_dim=256,
                 encoder_dim=2048, decoder_dim=512):
        super().__init__()
        self.encoder = EncoderCNN()
        self.decoder = DecoderRNN(embed_tensor, len(vocab), attention_dim,
                                  encoder_dim, decoder_dim)
        self.vocab = vocab

    def forward(self, images, captions):
        features = self.encoder(images)
        return self.decoder(features, captions)

    def predict_caption(self, image_tensor, max_len=20):
        self.eval()
        with torch.no_grad():
            device = next(self.parameters()).device
            features = self.encoder(image_tensor.unsqueeze(0).to(device))
            return self.decoder.generate_caption(features, self.vocab, max_len)
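

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch, not part of the model itself. It assumes a toy
# Vocab class with the stoi/itos lookups the decoder expects (hypothetical,
# standing in for the project's real vocabulary object) and substitutes a
# random embed_tensor and random image tensors for real pretrained word
# vectors and a real dataset.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class Vocab:
        """Hypothetical stand-in for the project's vocabulary object."""
        def __init__(self, tokens):
            self.itos = list(tokens)
            self.stoi = {t: i for i, t in enumerate(self.itos)}

        def __len__(self):
            return len(self.itos)

    vocab = Vocab(["<PAD>", "<SOS>", "<EOS>", "a", "dog", "runs"])
    embed_dim = 300                                    # e.g. GloVe-300d
    embed_tensor = torch.randn(len(vocab), embed_dim)  # random placeholder

    model = EncoderDecoder(embed_tensor, vocab)

    # One teacher-forced training step on random data. Only the decoder (and
    # embeddings) receive gradients; the ResNet backbone is frozen.
    images = torch.randn(2, 3, 224, 224)               # ImageNet-sized inputs
    captions = torch.tensor([[1, 3, 4, 5, 2],          # <SOS> a dog runs <EOS>
                             [1, 3, 4, 5, 2]])
    preds = model(images, captions)                    # (2, 4, vocab_size)
    loss = F.cross_entropy(preds.reshape(-1, len(vocab)),
                           captions[:, 1:].reshape(-1))
    loss.backward()
    print("loss:", loss.item())

    # Greedy inference on a single (unnormalised, random) image.
    print("caption:", model.predict_caption(images[0]))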