import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class EncoderCNN(nn.Module):
    """Extracts a 7x7 grid of spatial features from a frozen, pretrained ResNet-50."""

    def __init__(self):
        super().__init__()
        # IMAGENET1K_V1 matches the weights the deprecated `pretrained=True` loaded.
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        for param in resnet.parameters():
            param.requires_grad_(False)
        # Drop the average pool and FC head so the spatial feature map is kept.
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

    def forward(self, images):
        # Shape annotations assume 224x224 input images.
        features = self.resnet(images)                  # (B, 2048, 7, 7)
        features = features.permute(0, 2, 3, 1)         # (B, 7, 7, 2048)
        features = features.view(features.size(0), -1, features.size(-1))  # (B, 49, 2048)
        return features


class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over the encoder's spatial locations."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.W = nn.Linear(decoder_dim, attention_dim)  # projects the decoder state
        self.U = nn.Linear(encoder_dim, attention_dim)  # projects the encoder features
        self.A = nn.Linear(attention_dim, 1)            # scores each spatial location

    def forward(self, features, hidden_state):
        u_hs = self.U(features)                          # (B, 49, attention_dim)
        w_ah = self.W(hidden_state)                      # (B, attention_dim)
        combined = torch.tanh(u_hs + w_ah.unsqueeze(1))  # (B, 49, attention_dim)
        e = self.A(combined).squeeze(2)                  # (B, 49)
        alpha = F.softmax(e, dim=1)                      # weights over the 49 locations
        context = (features * alpha.unsqueeze(2)).sum(dim=1)  # (B, encoder_dim)
        return alpha, context


class DecoderRNN(nn.Module):
    """LSTM decoder that attends over the encoder features at every step."""

    def __init__(self, embed_tensor, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()
        # Start from pretrained word embeddings and fine-tune them (freeze=False).
        self.embedding = nn.Embedding.from_pretrained(embed_tensor, freeze=False)
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        # The initial LSTM state is derived from the mean-pooled encoder features.
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        # Each step consumes a word embedding concatenated with the context vector.
        self.lstm_cell = nn.LSTMCell(embed_tensor.size(1) + encoder_dim, decoder_dim)
        self.fcn = nn.Linear(decoder_dim, vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        # Teacher forcing: the ground-truth token at step s is the input at step s,
        # so the model predicts captions[:, 1:] from captions[:, :-1].
        embeds = self.embedding(captions)
        h, c = self.init_hidden_state(features)
        seq_len = captions.size(1) - 1
        batch_size = captions.size(0)
        preds = torch.zeros(batch_size, seq_len, self.fcn.out_features, device=features.device)

        for s in range(seq_len):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            preds[:, s] = self.fcn(self.drop(h))

        return preds

    def generate_caption(self, features, vocab, max_len=20):
        # Greedy decoding: start from <SOS> and feed the argmax token back in.
        # Call model.eval() first so the dropout layer is a no-op.
        h, c = self.init_hidden_state(features)
        word = torch.tensor([vocab.stoi["<SOS>"]], device=features.device)
        caption = []

        for _ in range(max_len):
            embed = self.embedding(word)                 # (1, embed_dim)
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embed, context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            word = output.argmax(dim=1)
            word_str = vocab.itos[word.item()]
            if word_str == "<EOS>":
                break
            caption.append(word_str)

        return " ".join(caption)

    def init_hidden_state(self, features):
        # Project the mean-pooled encoder features into the initial (h, c) pair.
        mean_features = features.mean(dim=1)
        h = self.init_h(mean_features)
        c = self.init_c(mean_features)
        return h, c


class EncoderDecoder(nn.Module):
    """Full captioning model: frozen ResNet-50 encoder + attention LSTM decoder."""

    def __init__(self, embed_tensor, vocab, attention_dim=256, encoder_dim=2048, decoder_dim=512):
        super().__init__()
        self.encoder = EncoderCNN()
        self.decoder = DecoderRNN(embed_tensor, len(vocab), attention_dim, encoder_dim, decoder_dim)
        self.vocab = vocab

    def forward(self, images, captions):
        features = self.encoder(images)
        return self.decoder(features, captions)

    def predict_caption(self, image_tensor, max_len=20):
        # Greedily caption a single unbatched image; leaves the model in eval mode.
        self.eval()
        with torch.no_grad():
            device = next(self.parameters()).device
            features = self.encoder(image_tensor.unsqueeze(0).to(device))
            return self.decoder.generate_caption(features, self.vocab, max_len)
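

if __name__ == "__main__":
    # Minimal smoke test: a usage sketch under assumed shapes, not part of the
    # training pipeline. The Vocab stub is hypothetical and only mimics the
    # `stoi` / `itos` / `__len__` interface the model expects; the random
    # embed_tensor stands in for real pretrained embeddings (e.g. 300-d GloVe).
    class Vocab:
        def __init__(self, tokens):
            self.itos = dict(enumerate(tokens))
            self.stoi = {t: i for i, t in self.itos.items()}

        def __len__(self):
            return len(self.itos)

    vocab = Vocab(["<PAD>", "<SOS>", "<EOS>", "a", "dog", "runs"])
    embed_tensor = torch.randn(len(vocab), 300)

    model = EncoderDecoder(embed_tensor, vocab)  # downloads ResNet-50 weights on first run
    images = torch.randn(2, 3, 224, 224)
    captions = torch.randint(0, len(vocab), (2, 12))

    preds = model(images, captions)              # teacher-forced pass
    print(preds.shape)                           # torch.Size([2, 11, 6])
    print(model.predict_caption(images[0]))      # greedy caption (random weights, so gibberish)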