import torch
import torch.nn as nn
from torchvision import models

class Encoder(nn.Module):
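    """ResNet-50 feature extractor: drops the final classifier and
    flattens the pooled output to a (B, 2048) feature vector."""
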
    def __init__(self):
        super().__init__()
        # Pretrained ResNet-50 with the final classification layer removed
        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        modules = list(backbone.children())[:-1]
        modules.append(nn.Flatten())  # (B, 2048, 1, 1) -> (B, 2048)
        self.backbone = nn.Sequential(*modules)

    def forward(self, x):
        return self.backbone(x)
    
    def fine_tune(self, fine_tune=False):
        # Freeze the entire backbone by default
        for param in self.parameters():
            param.requires_grad = False

        # If fine-tuning, unfreeze only the deeper blocks (children [5:],
        # i.e. layer2 onward); the early layers keep their generic features
        for c in list(self.backbone.children())[5:]:
            for p in c.parameters():
                p.requires_grad = fine_tune

class Decoder(nn.Module):
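    """LSTM caption decoder: the encoder's (B, 2048) feature vector
    initialises the hidden and cell states, then words are predicted one
    timestep at a time (teacher-forced during training)."""
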
    def __init__(self, tokenizer, dropout=0.):
        super().__init__()
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer)
        self.emb = nn.Embedding(self.vocab_size, 512)  # (B, seq_len) -> (B, seq_len, 512)
        self.lstm = nn.LSTMCell(512, 512)              # decodes one timestep at a time
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(512, self.vocab_size)      # hidden state -> vocabulary logits
        self.init_h = nn.Linear(2048, 512)             # image features -> initial hidden state
        self.init_c = nn.Linear(2048, 512)             # image features -> initial cell state
    
    def init_states(self, encoder_out):
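        # Project the flattened image features into the LSTM's initial states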
        h = self.init_h(encoder_out)
        c = self.init_c(encoder_out)
        return h, c

    def forward(self, enc_out, captions, caplens, device):
        batch_size = enc_out.shape[0]

        # Sort the batch by decreasing caption length so that, at every
        # timestep, the still-active sequences form a contiguous prefix
        caplens, sort_idx = caplens.squeeze(1).sort(dim=0, descending=True)
        enc_out = enc_out[sort_idx]
        captions = captions[sort_idx]
        h, c = self.init_states(enc_out)

        # Embedding
        embeddings = self.emb(captions)  # (batch_size, max_caption_length, embed_dim)


        # We won't decode at the <end> position, since generation is finished
        # as soon as <end> is produced, so decode lengths are true lengths - 1
        caplens = (caplens - 1).tolist()
        max_timesteps = max(caplens)

        # Tensor to hold word prediction scores at every timestep
        predictions = torch.zeros(batch_size, max_timesteps, self.vocab_size, device=device)

        for t in range(max_timesteps):
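            # Only sequences longer than t are still decoding; because the
            # batch is sorted by length, they form the leading slice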
            batch_size_t = sum([l > t for l in caplens])
            h, c = self.lstm(embeddings[:batch_size_t, t, :], (h[:batch_size_t], c[:batch_size_t]))
            preds = self.fc(self.dropout(h))
            predictions[:batch_size_t, t, :] = preds
        
        return predictions, captions, caplens, sort_idx

    def predict(self, enc_out, device, max_steps):
        """Greedy decoding: feed the most likely token back in at each step."""
        with torch.no_grad():
            batch_size = enc_out.shape[0]
            h, c = self.init_states(enc_out)

            captions = []

            for i in range(batch_size):
                temp = []
                # Start every caption from the <start> token embedding
                next_token = self.emb(torch.LongTensor([self.tokenizer.val2idx['<start>']]).to(device))
                h_, c_ = h[i].unsqueeze(0), c[i].unsqueeze(0)

                step = 1
                while True:
                    h_, c_ = self.lstm(next_token, (h_, c_))
                    preds = self.fc(h_)  # dropout is skipped at inference time

                    # Greedy choice: take the highest-scoring token
                    max_idx = preds.argmax(dim=1).item()
                    temp.append(max_idx)

                    if max_idx == self.tokenizer.val2idx['<end>'] or step == max_steps:
                        break
                    next_token = self.emb(torch.LongTensor([max_idx]).to(device))
                    step += 1
                captions.append(temp)
        return captions


class CaptionModel(nn.Module):
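    """End-to-end image captioner: ResNet-50 encoder feeding an LSTM decoder."""
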
    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.vocab_size = len(self.tokenizer)
        self.encoder = Encoder()
        self.decoder = Decoder(tokenizer)

    def forward(self, x, captions, caplens, device):
        encoder_out = self.encoder(x)
        predictions, captions, caplens, sort_idx = self.decoder(encoder_out, captions, caplens, device)
        return predictions, captions, caplens, sort_idx
    
    def predict(self, x, device, max_steps=25):
        # No gradients are needed at inference time
        with torch.no_grad():
            encoder_out = self.encoder(x)
        captions = self.decoder.predict(encoder_out, device, max_steps)
        return captions
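

# A minimal smoke-test sketch (an illustration, not part of the training
# pipeline): `DummyTokenizer` is a hypothetical stand-in for the real
# tokenizer, exposing only the interface the model assumes above, namely
# len(), .vocab, and .val2idx with '<start>' and '<end>' entries.
# Shapes follow the model: images (B, 3, H, W), captions (B, max_len)
# token indices, caplens (B, 1) true caption lengths.
if __name__ == "__main__":
    class DummyTokenizer:
        def __init__(self):
            self.vocab = ['<pad>', '<start>', '<end>', 'a', 'dog', 'runs']
            self.val2idx = {v: i for i, v in enumerate(self.vocab)}

        def __len__(self):
            return len(self.vocab)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CaptionModel(DummyTokenizer()).to(device)
    model.eval()

    images = torch.randn(2, 3, 224, 224, device=device)
    captions = torch.randint(0, len(model.tokenizer), (2, 10), device=device)
    caplens = torch.tensor([[10], [8]], device=device)

    # Teacher-forced pass: logits for every decode timestep
    preds, caps, lens, sort_idx = model(images, captions, caplens, device)
    print(preds.shape)  # (2, max(lens), vocab_size) == (2, 9, 6)

    # Greedy decoding from the image features alone
    print(model.predict(images, device, max_steps=10))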