Mayanand commited on
Commit
dedcc83
·
1 Parent(s): 1074b8a

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +118 -0
model.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+
5
+ class Encoder(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+ backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
9
+ backbone = [module for module in backbone.children()][:-1]
10
+ backbone.append(nn.Flatten())
11
+ self.backbone = nn.Sequential(*backbone)
12
+
13
+
14
+ def forward(self, x):
15
+ return self.backbone(x)
16
+
17
+ def fine_tune(self, fine_tune=False):
18
+ for param in self.parameters():
19
+ param.requires_grad = False
20
+
21
+ # If fine-tuning, only fine-tune bottom layers
22
+ for c in list(self.backbone.children())[5:]:
23
+ for p in c.parameters():
24
+ p.requires_grad = fine_tune
25
+
26
+ class Decoder(nn.Module):
27
+ def __init__(self, tokenizer, dropout=0.):
28
+ super().__init__()
29
+ self.tokenizer = tokenizer
30
+ self.vocab_size = len(tokenizer)
31
+ self.emb = nn.Embedding(self.vocab_size, 512) # size (b, 512)
32
+ self.lstm = nn.LSTMCell(512, 512)
33
+ self.dropout = nn.Dropout(p=dropout)
34
+ self.fc = nn.Linear(512, len(tokenizer.vocab))
35
+ self.init_h = nn.Linear(2048, 512)
36
+ self.init_c = nn.Linear(2048, 512)
37
+
38
+ def init_states(self, encoder_out):
39
+ h = self.init_h(encoder_out)
40
+ c = self.init_c(encoder_out)
41
+ return h, c
42
+
43
+ def forward(self, enc_out, captions, caplens, device):
44
+ batch_size = enc_out.shape[0]
45
+ caplens, sort_idx = caplens.squeeze(1).sort(dim=0, descending=True)
46
+ enc_out = enc_out[sort_idx]
47
+ captions = captions[sort_idx]
48
+ h, c = self.init_states(enc_out)
49
+
50
+ # Embedding
51
+ embeddings = self.emb(captions) # (batch_size, max_caption_length, embed_dim)
52
+
53
+
54
+ # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
55
+ # So, decoding lengths are actual lengths - 1
56
+ caplens = (caplens - 1).tolist()
57
+
58
+
59
+ # Create tensors to hold word predicion scores
60
+ predictions = torch.zeros(batch_size, max(caplens), self.vocab_size).to(device)
61
+
62
+ max_timesteps = max(caplens)
63
+
64
+ for t in range(max_timesteps):
65
+ batch_size_t = sum([l > t for l in caplens])
66
+ h, c = self.lstm(embeddings[:batch_size_t, t, :], (h[:batch_size_t], c[:batch_size_t]))
67
+ preds = self.fc(self.dropout(h))
68
+ predictions[:batch_size_t, t, :] = preds
69
+
70
+ return predictions, captions, caplens, sort_idx
71
+
72
+ def predict(self, enc_out, device, max_steps):
73
+ with torch.no_grad():
74
+ batch_size = enc_out.shape[0]
75
+ h, c = self.init_states(enc_out)
76
+
77
+ captions = []
78
+
79
+ for i in range(batch_size):
80
+ temp = []
81
+ next_token = self.emb(torch.LongTensor([self.tokenizer.val2idx['<start>']]).to(device))
82
+ h_, c_ = h[i].unsqueeze(0), c[i].unsqueeze(0)
83
+
84
+ step = 1
85
+ while True:
86
+ h_, c_ = self.lstm(next_token, (h_, c_))
87
+ preds = self.fc(self.dropout(h_))
88
+
89
+ max_val, max_idx = torch.max(preds, dim=1)
90
+ max_idx = max_idx.item()
91
+ temp.append(max_idx)
92
+
93
+ if max_idx in [self.tokenizer.val2idx['<end>']] or step == max_steps:
94
+ break
95
+ next_token = self.emb(torch.LongTensor([max_idx]).to(device))
96
+ step += 1
97
+ captions.append(temp)
98
+ return captions
99
+
100
+
101
+
102
+ class CaptionModel(nn.Module):
103
+ def __init__(self, tokenizer):
104
+ super().__init__()
105
+ self.tokenizer = tokenizer
106
+ self.vocab_size = len(self.tokenizer)
107
+ self.encoder = Encoder()
108
+ self.decoder = Decoder(tokenizer)
109
+
110
+ def forward(self, x, captions, caplens, device):
111
+ encoder_out = self.encoder(x)
112
+ predictions, captions, caplens, sort_idx = self.decoder(encoder_out, captions, caplens, device)
113
+ return predictions, captions, caplens, sort_idx
114
+
115
+ def predict(self, x, device, max_steps=25):
116
+ encoder_out = self.encoder(x)
117
+ captions = self.decoder.predict(encoder_out, device, max_steps)
118
+ return captions