VIKRAM989 committed
Commit d75e81d · verified · 1 Parent(s): 4e28109

Update model.py

Files changed (1):
  1. model.py +101 -77
model.py CHANGED
@@ -4,13 +4,11 @@ import torchvision.transforms as transforms
 import torchvision.models as models
 from PIL import Image
 import pickle
-import sys
 import os
 import re
 from collections import Counter
 from huggingface_hub import hf_hub_download
 
-
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 EMBED_DIM = 512
@@ -38,6 +36,7 @@ class Vocabulary:
 
     def build_vocabulary(self, sentence_list):
         frequencies = Counter()
+
        for sentence in sentence_list:
            tokens = self.tokenizer(sentence)
            frequencies.update(tokens)
@@ -51,11 +50,10 @@ class Vocabulary:
     def numericalize(self, text):
         tokens = self.tokenizer(text)
         numericalized = []
+
         for token in tokens:
-            if token in self.stoi:
-                numericalized.append(self.stoi[token])
-            else:
-                numericalized.append(self.stoi["unk"])
+            numericalized.append(self.stoi.get(token, self.stoi["unk"]))
+
         return numericalized
 
 
@@ -65,21 +63,28 @@ class Vocabulary:
 class ResNetEncoder(nn.Module):
     def __init__(self, embed_dim):
         super().__init__()
-        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
-        for param in resnet.parameters():
-            param.requires_grad = True
+
+        resnet = models.resnet50(weights=None)
+
         modules = list(resnet.children())[:-1]
+
         self.resnet = nn.Sequential(*modules)
-
+
         self.fc = nn.Linear(resnet.fc.in_features, embed_dim)
+
         self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.01)
 
     def forward(self, images):
+
         with torch.no_grad():
-            features = self.resnet(images)  # (batch_size, 2048, 1, 1)
+            features = self.resnet(images)
+
         features = features.view(features.size(0), -1)
+
         features = self.fc(features)
+
         features = self.batch_norm(features)
+
         return features
 
 
@@ -87,20 +92,31 @@ class ResNetEncoder(nn.Module):
 # Decoder
 # -----------------------
 class DecoderLSTM(nn.Module):
+
     def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers=1):
+
         super().__init__()
+
         self.embedding = nn.Embedding(vocab_size, embed_dim)
+
         self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
+
         self.fc = nn.Linear(hidden_dim, vocab_size)
 
     def forward(self, features, captions):
-        # remove the last token for input
-        captions_in = captions[:, :-1]
-        emb = self.embedding(captions_in)
+
+        captions = captions[:, :-1]
+
+        emb = self.embedding(captions)
+
         features = features.unsqueeze(1)
+
         lstm_input = torch.cat((features, emb), dim=1)
+
         outputs, _ = self.lstm(lstm_input)
+
         logits = self.fc(outputs)
+
         return logits
 
 
@@ -108,19 +124,26 @@ class DecoderLSTM(nn.Module):
 # Caption Model
 # -----------------------
 class ImageCaptioningModel(nn.Module):
+
     def __init__(self, encoder, decoder):
+
         super().__init__()
+
         self.encoder = encoder
+
         self.decoder = decoder
 
     def forward(self, images, captions):
+
         features = self.encoder(images)
+
         outputs = self.decoder(features, captions)
+
         return outputs
 
 
 # -----------------------
-# Caption generator
+# Caption Generator
 # -----------------------
 def generate_caption(model, image, vocab):
 
@@ -128,102 +151,103 @@ def generate_caption(model, image, vocab):
 
     image = image.unsqueeze(0).to(DEVICE)
 
+    sentence = []
+
     with torch.no_grad():
-        # Get image features
-        features = model.encoder(image)  # (1, embed_dim)
 
-        # Start with the start token
+        features = model.encoder(image)
+
         word_idx = vocab.stoi["startofseq"]
-        sentence = []
-
-        # Initialize hidden state for LSTM
-        h = None
-
+
+        hidden = None
+
         for _ in range(MAX_LEN):
-            # Create input: concatenate features with embedding of previous word
+
             word_tensor = torch.tensor([word_idx]).to(DEVICE)
-            emb = model.decoder.embedding(word_tensor)  # (1, embed_dim)
-
-            if h is None:
-                # First step: concatenate features with embedding
-                lstm_input = torch.cat([features.unsqueeze(1), emb.unsqueeze(1)], dim=1)  # (1, 2, embed_dim)
+
+            emb = model.decoder.embedding(word_tensor)
+
+            if hidden is None:
+
+                lstm_input = torch.cat(
+                    [features.unsqueeze(1), emb.unsqueeze(1)], dim=1
+                )
+
             else:
-                lstm_input = emb.unsqueeze(1)  # (1, 1, embed_dim)
-
-            # Forward through LSTM
-            output, h_new = model.decoder.lstm(lstm_input, h)
-            h = h_new
-
-            # Predict next token
-            logits = model.decoder.fc(output[:, -1, :])  # (1, vocab_size)
+
+                lstm_input = emb.unsqueeze(1)
+
+            output, hidden = model.decoder.lstm(lstm_input, hidden)
+
+            logits = model.decoder.fc(output[:, -1, :])
+
             predicted = logits.argmax(1).item()
+
             token = vocab.itos[predicted]
-
+
             if token == "endofseq":
                 break
-
+
             sentence.append(token)
+
             word_idx = predicted
 
     return " ".join(sentence)
 
 
 # -----------------------
-# Image transform
+# Image Transform
 # -----------------------
-transform = transforms.Compose([
-    transforms.Resize((224,224)),
-    transforms.ToTensor(),
-    transforms.Normalize(
-        mean=[0.485,0.456,0.406],
-        std=[0.229,0.224,0.225]
-    )
-])
+transform = transforms.Compose(
+    [
+        transforms.Resize((224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225],
+        ),
+    ]
+)
 
 
 # -----------------------
-# Main
+# Load Model Once
 # -----------------------
-def main():
 
-    image_path = sys.argv[1]
+script_dir = os.path.dirname(os.path.abspath(__file__))
 
-    # Get the directory where this script is located
-    script_dir = os.path.dirname(os.path.abspath(__file__))
-    CHECKPOINT_PATH = hf_hub_download(
-        repo_id="VIKRAM989/image-label",
-        filename="best_checkpoint.pth"
-    )
-    VOCAB_PATH = os.path.join(script_dir, "vocab.pkl")
+CHECKPOINT_PATH = hf_hub_download(
    repo_id="VIKRAM989/image-label",
    filename="best_checkpoint.pth"
+)
 
-    # load vocab
-    with open(VOCAB_PATH, "rb") as f:
-        vocab = pickle.load(f)
+VOCAB_PATH = os.path.join(script_dir, "vocab.pkl")
 
-    vocab_size = len(vocab)
+with open(VOCAB_PATH, "rb") as f:
+    vocab = pickle.load(f)
 
-    # rebuild model
-    encoder = ResNetEncoder(EMBED_DIM)
-    decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
-    model = ImageCaptioningModel(encoder, decoder).to(DEVICE)
+vocab_size = len(vocab)
 
-    # load checkpoint
-    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
+encoder = ResNetEncoder(EMBED_DIM)
 
-    model.load_state_dict(checkpoint["model_state_dict"])
+decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
 
-    model.eval()
+model = ImageCaptioningModel(encoder, decoder).to(DEVICE)
 
-    # load image
-    img = Image.open(image_path).convert("RGB")
-    img = transform(img)
+checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
 
-    caption = generate_caption(model, img, vocab)
+model.load_state_dict(checkpoint["model_state_dict"])
+
+model.eval()
 
-    print("\nCaption:", caption)
 
+# -----------------------
+# Public Function for API
+# -----------------------
+def caption_image(pil_image):
 
-if __name__ == "__main__":
-    main()
+    img = transform(pil_image).to(DEVICE)
+
+    caption = generate_caption(model, img, vocab)
+
+    return caption
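
After this commit, model.py no longer runs as a standalone script: the checkpoint and vocab.pkl are loaded once at import time, and caption_image(pil_image) is the entry point for serving code. A minimal usage sketch follows (the file name example.jpg is illustrative and not part of the commit; since caption_image does not call .convert("RGB") the way the removed main() did, the caller should pass an RGB PIL image):

from PIL import Image

import model  # import runs the "Load Model Once" section: checkpoint download, vocab load, model build

img = Image.open("example.jpg").convert("RGB")  # hypothetical input file
print(model.caption_image(img))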