import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
class EncoderCNN(nn.Module):
    """Frozen ResNet-50 backbone that returns a grid of spatial features."""

    def __init__(self):
        super().__init__()
        # pretrained ImageNet weights (replaces the deprecated pretrained=True)
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        for param in resnet.parameters():
            param.requires_grad_(False)  # freeze the backbone
        modules = list(resnet.children())[:-2]  # drop avgpool and fc head
        self.resnet = nn.Sequential(*modules)

    def forward(self, images):
        features = self.resnet(images)            # (B, 2048, 7, 7)
        features = features.permute(0, 2, 3, 1)   # (B, 7, 7, 2048)
        # reshape, not view: permute leaves the tensor non-contiguous,
        # so .view() would raise a RuntimeError here
        features = features.reshape(features.size(0), -1, features.size(-1))  # (B, 49, 2048)
        return features
class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over the spatial feature grid."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)
        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):
        u_hs = self.U(features)                     # (B, 49, attention_dim)
        w_ah = self.W(hidden_state)                 # (B, attention_dim)
        combined = torch.tanh(u_hs + w_ah.unsqueeze(1))
        e = self.A(combined).squeeze(2)             # (B, 49)
        alpha = F.softmax(e, dim=1)                 # attention weights
        context = (features * alpha.unsqueeze(2)).sum(dim=1)  # (B, encoder_dim)
        return alpha, context
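
# The attention above is additive (Bahdanau-style): for each of the 49 feature
# vectors f_i and the current decoder hidden state h,
#     e_i     = A(tanh(U(f_i) + W(h)))    -- one scalar score per grid location
#     alpha   = softmax(e)                -- weights over the 7x7 grid (sum to 1)
#     context = sum_i alpha_i * f_i       -- weighted average fed to the LSTM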
class DecoderRNN(nn.Module):
    def __init__(self, embed_tensor, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()
        # pretrained word embeddings, fine-tuned during training (freeze=False)
        self.embedding = nn.Embedding.from_pretrained(embed_tensor, freeze=False)
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.lstm_cell = nn.LSTMCell(embed_tensor.size(1) + encoder_dim, decoder_dim)
        self.fcn = nn.Linear(decoder_dim, vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        # teacher forcing: feed the ground-truth token at every step
        embeds = self.embedding(captions)
        h, c = self.init_hidden_state(features)
        seq_len = captions.size(1) - 1  # predict every token after <SOS>
        batch_size = captions.size(0)
        preds = torch.zeros(batch_size, seq_len, self.fcn.out_features, device=features.device)
        for s in range(seq_len):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            preds[:, s] = self.fcn(self.drop(h))
        return preds
    def generate_caption(self, features, vocab, max_len=20):
        # greedy decoding: start from <SOS> and take the argmax token each step
        h, c = self.init_hidden_state(features)
        word = torch.tensor([vocab.stoi["<SOS>"]], device=features.device)
        caption = []
        for _ in range(max_len):
            embed = self.embedding(word)  # (1, embed_dim)
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embed, context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            predicted = output.argmax(dim=1)
            word = predicted
            word_str = vocab.itos[predicted.item()]
            if word_str == "<EOS>":
                break
            caption.append(word_str)
        return " ".join(caption)
    def init_hidden_state(self, features):
        # initialize the LSTM state from the mean-pooled spatial features
        mean_features = features.mean(dim=1)
        h = self.init_h(mean_features)
        c = self.init_c(mean_features)
        return h, c
class EncoderDecoder(nn.Module):
    def __init__(self, embed_tensor, vocab, attention_dim=256, encoder_dim=2048, decoder_dim=512):
        super().__init__()
        self.encoder = EncoderCNN()
        self.decoder = DecoderRNN(embed_tensor, len(vocab), attention_dim, encoder_dim, decoder_dim)
        self.vocab = vocab

    def forward(self, images, captions):
        features = self.encoder(images)
        return self.decoder(features, captions)

    def predict_caption(self, image_tensor, max_len=20):
        # caption a single preprocessed image tensor of shape (3, H, W)
        self.eval()
        with torch.no_grad():
            device = next(self.parameters()).device
            features = self.encoder(image_tensor.unsqueeze(0).to(device))
            return self.decoder.generate_caption(features, self.vocab, max_len)
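
# Minimal smoke test, assuming only what the classes above require: a vocab
# object exposing stoi/itos lookups plus len(), and an embedding matrix.
# `ToyVocab` and the 300-d random embeddings are illustrative stand-ins for
# the project's real vocabulary and pretrained word vectors. Note that
# EncoderCNN downloads the ResNet-50 weights on first use.
if __name__ == "__main__":
    class ToyVocab:
        def __init__(self, words):
            self.itos = dict(enumerate(words))
            self.stoi = {w: i for i, w in self.itos.items()}

        def __len__(self):
            return len(self.itos)

    vocab = ToyVocab(["<PAD>", "<SOS>", "<EOS>", "<UNK>", "a", "dog", "runs"])
    embed_tensor = torch.randn(len(vocab), 300)  # random here, pretrained in practice
    model = EncoderDecoder(embed_tensor, vocab)

    images = torch.randn(2, 3, 224, 224)              # dummy image batch
    captions = torch.randint(0, len(vocab), (2, 10))  # dummy token ids
    print(model(images, captions).shape)              # torch.Size([2, 9, 7])

    # With random embeddings the caption is gibberish; this only checks shapes.
    print(model.predict_caption(torch.randn(3, 224, 224)))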