Prakhar Trivedi committed on
Commit
2713ac2
·
1 Parent(s): 5cea8ef

added app script for model loading and inference

Browse files
Files changed (2) hide show
  1. app.py +198 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, torch, pickle, re
2
+ from io import BytesIO
3
+ from torchvision import models, transforms
4
+ from matplotlib import pyplot as plt
5
+ from torch import nn
6
+ from collections import Counter
7
+ from PIL import Image
8
+ import gradio as gr
9
+ from huggingface_hub import snapshot_download
10
+
11
# Model hyperparameters — these must match the checkpoint this app downloads
# (best_finetuned_infer.pth); changing them will break load_state_dict.
EMBED_DIM = 256        # encoder output / decoder embedding size
HIDDEN_DIM = 512       # LSTM hidden size
MAX_SEQ_LENGTH = 25    # cap on generated caption length at inference
VOCAB_SIZE = 8492      # vocabulary size the checkpoint was trained with
DEVICE = torch.device("cpu")  # CPU-only inference

# Inference preprocessing: 384x384 resize matches the ViT-B/16 SWAG E2E
# backbone used by the encoder; Normalize maps pixels to roughly [-1, 1].
# NOTE(review): presumably identical to the training-time transform — verify.
transform_inference = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )
])
25
+
26
class Vocabulary:
    """Word-level vocabulary mapping tokens <-> integer ids.

    Ids 0-3 are reserved for the special tokens "pad", "startofseq",
    "endofseq" and "unk"; regular words receive ids from 4 upward via
    build_vocabulary().  This class is also the unpickle target for the
    vocab.pkl shipped with the model, so its attribute names
    (freq_threshold, itos, stoi, index) must stay stable.
    """

    def __init__(self, freq_threshold=5):
        # Minimum corpus frequency for a word to get its own id.
        self.freq_threshold = freq_threshold
        # Special tokens are stored as plain words (no angle brackets) so
        # the \w+ tokenizer below can never split them apart.
        self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.index = 4  # next id to assign

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        """Lowercase *text* and split it into alphanumeric word tokens."""
        text = text.lower()
        tokens = re.findall(r"\w+", text)
        return tokens

    def build_vocabulary(self, sentence_list):
        """Assign ids to every word in *sentence_list* whose corpus
        frequency is at least freq_threshold."""
        frequencies = Counter()
        for sentence in sentence_list:
            frequencies.update(self.tokenizer(sentence))

        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = self.index
                self.itos[self.index] = word
                self.index += 1

    def numericalize(self, text):
        """Convert *text* to a list of token ids; out-of-vocabulary words
        map to the "unk" id.

        Bug fix: the original looked up self.stoi["<unk>"], but the special
        tokens are stored without angle brackets ("unk"), so any OOV word
        raised KeyError instead of being mapped to the unknown id.
        """
        unk_id = self.stoi["unk"]
        return [self.stoi.get(token, unk_id) for token in self.tokenizer(text)]
63
+
64
class ViTEncoder(nn.Module):
    """Image encoder: a frozen pretrained ViT-B/16 backbone followed by a
    trainable linear projection + batch norm down to *embed_dim* features."""

    def __init__(self, embed_dim):
        super().__init__()
        # High-quality SWAG end-to-end pretrained weights (384x384 input).
        backbone = models.vit_b_16(
            weights=models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
        )
        backbone.heads = nn.Identity()   # drop the ImageNet classifier head
        backbone.requires_grad_(False)   # freeze the whole backbone
        self.vit = backbone

        # Project the ViT feature vector to the decoder's embedding size.
        self.fc = nn.Linear(self.vit.hidden_dim, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def forward(self, images):
        """(B, 3, H, W) images -> (B, embed_dim) feature vectors."""
        out = self.vit(images)   # (B, vit.hidden_dim)
        out = self.fc(out)       # (B, embed_dim)
        return self.batch_norm(out)
89
+
90
class DecoderLSTM(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is prepended as a pseudo-token to the embedded
    caption tokens, so the LSTM output has seq_len + 1 positions.
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, features, captions, states):
        """Teacher-forced step.

        features: (B, embed_dim) image features.
        captions: (B, T) token ids.
        states:   LSTM (h, c) tuple, or None to start from zeros.
        Returns logits of shape (B, T + 1, vocab_size) and the new states.
        """
        embeddings = self.embedding(captions)
        inputs = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        lstm_out, states = self.lstm(inputs, states)
        logits = self.fc(lstm_out)
        return logits, states

    def generate(self, features, max_len=20):
        """Greedy decoding for a single image (features: (1, embed_dim)).

        Returns a list of exactly *max_len* token ids; there is no early
        stop on "endofseq" — the caller trims at the end token.

        NOTE(review): each step re-feeds the full token prefix while also
        carrying over the LSTM states from the previous step, so the prefix
        is effectively processed twice per step.  Kept as-is because the
        released checkpoint was evaluated with exactly this decoding loop.
        """
        states = None
        generated = []

        start_idx = 1  # id of "startofseq"
        current_tokens = [start_idx]

        for _ in range(max_len):
            inputs = torch.LongTensor(current_tokens).to(features.device).unsqueeze(0)
            logits, states = self.forward(features, inputs, states)
            # Fix: reshape by self.vocab_size instead of the module-level
            # VOCAB_SIZE constant so the decoder works for any vocab size
            # (identical behavior here, since it is built with VOCAB_SIZE).
            flat = logits.contiguous().view(-1, self.vocab_size)
            predicted = flat.argmax(dim=1)[-1].item()

            generated.append(predicted)
            current_tokens.append(predicted)

        return generated
124
+
125
class ImageCaptioningModel(nn.Module):
    """Thin encoder/decoder wrapper; only `generate` is used by this app."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def generate(self, images, max_len=MAX_SEQ_LENGTH):
        """Encode *images* and greedily decode a caption (list of token ids)."""
        return self.decoder.generate(self.encoder(images), max_len=max_len)
134
+
135
def load_model_and_vocab(repo_id):
    """Download the checkpoint and vocabulary from Hub repo *repo_id* and
    return (model, vocab) ready for CPU inference.

    Expects the repo to contain best_finetuned_infer.pth (a dict with a
    'model_state_dict' entry) and vocab.pkl (a pickled Vocabulary).
    """
    download_dir = snapshot_download(repo_id)
    print(download_dir)
    model_path = os.path.join(download_dir, "best_finetuned_infer.pth")
    vocab_path = os.path.join(download_dir, "vocab.pkl")

    # Rebuild the architecture with the same hyperparameters the checkpoint
    # was trained with, then load the weights into it.
    encoder = ViTEncoder(embed_dim=EMBED_DIM)
    decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, VOCAB_SIZE)
    model = ImageCaptioningModel(encoder, decoder).to(DEVICE)

    # NOTE(review): torch.load without weights_only=True and pickle.load
    # both execute arbitrary code from the downloaded files — acceptable
    # only because the repo is controlled by the app author.
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict['model_state_dict'])
    model.eval()  # inference mode (BatchNorm uses running stats)

    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    return model, vocab
153
+
154
# Load the model and vocabulary once at import time so every Gradio
# request reuses the same instances.
model, vocab = load_model_and_vocab("prakhartrivedi/ImageCaptioningSpace")
print("Model and vocabulary loaded successfully.")
156
+
157
def generate_caption_for_image(img):
    """Gradio handler: caption PIL image *img* and return a rendered
    matplotlib figure (as a PIL image) showing the input with the caption
    as its title.
    """
    pil_img = img.convert("RGB")  # guard against grayscale/RGBA uploads
    img_tensor = transform_inference(pil_img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output_indices = model.generate(img_tensor, max_len=MAX_SEQ_LENGTH)

    # Map ids back to words, stopping at the end-of-sequence token and
    # dropping any special tokens that slipped through.
    result_words = []
    end_token_idx = vocab.stoi["endofseq"]
    for idx in output_indices:
        if idx == end_token_idx:
            break
        word = vocab.itos.get(idx, "unk")
        if word not in ["startofseq", "pad", "endofseq"]:
            result_words.append(word)
    cap = " ".join(result_words)

    # Convert tensor (1, 3, H, W) to (H, W, 3) for display.
    image_np = img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    image_np = (image_np * 0.5 + 0.5).clip(0, 1)  # undo Normalize(0.5, 0.5)

    # Plot the image and caption.
    fig = plt.figure(figsize=(5, 5))
    plt.imshow(image_np)
    plt.axis("off")
    plt.title(cap)

    # Save the plot to a buffer.
    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    # Fix: close *this* figure explicitly — plt.close() closes whichever
    # figure is "current", which may differ if requests interleave.
    plt.close(fig)
    buf.seek(0)

    # Fix: force-decode the PNG so the returned image does not lazily
    # depend on `buf` (Image.open is lazy until .load()).
    result_img = Image.open(buf)
    result_img.load()
    return result_img
193
+
194
# Minimal Gradio UI: upload a PIL image, receive the captioned figure back.
# share=True additionally requests a public tunnel link when run locally;
# NOTE(review): presumably unnecessary on Hugging Face Spaces — confirm.
gr.Interface(
    fn=generate_caption_for_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil")
).launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ huggingface_hub
2
+ torch
3
+ pillow
4
+ numpy
5
+ torchvision
6
+ matplotlib