BriranSus committed on
Commit
a667af8
·
1 Parent(s): f56b89a

initial commit

Browse files
Files changed (2) hide show
  1. app.py +225 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ from transformers import ViTModel
6
+ from PIL import Image
7
+ import pickle
8
+ import re
9
+
10
class Vocabulary:
    """Word-level vocabulary mapping between tokens and integer ids.

    Ids 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and
    <UNK>; regular words would be assigned ids starting at 4.
    """

    def __init__(self, freq_threshold=5):
        # Minimum corpus frequency a word needs before earning an id
        # (only stored here; the vocabulary is built at training time).
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {word: idx for idx, word in self.itos.items()}
        self.index = 4  # next id to hand out to a new word

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        """Lower-case *text* and split it into alphanumeric word tokens."""
        return re.findall(r"\w+", text.lower())

    def numericalize(self, text):
        """Convert *text* into a list of token ids, mapping unknown words to <UNK>."""
        unk = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk) for token in self.tokenizer(text)]
34
+
35
class Encoder(nn.Module):
    """ViT-MAE backbone followed by a small projection head.

    Maps a batch of images to per-patch feature vectors of size
    *embed_dim* for the attention-based decoder to attend over.
    """

    def __init__(self, embed_dim, freeze=False):
        super().__init__()
        self.vit = ViTModel.from_pretrained("facebook/vit-mae-base")

        if freeze:
            # Keep the pretrained backbone fixed; only the head trains.
            for p in self.vit.parameters():
                p.requires_grad = False

        self.linear = nn.Sequential(
            nn.Linear(self.vit.config.hidden_size, embed_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),
        )

    def forward(self, images):
        """Return (batch, num_patches, embed_dim) patch features for *images*."""
        hidden = self.vit(pixel_values=images).last_hidden_state
        # Drop the leading CLS-style token; keep only the patch embeddings.
        patches = hidden[:, 1:, :]
        return self.linear(patches)
57
+
58
class MultiHeadAttention(nn.Module):
    """Single-query multi-head attention from a decoder state over encoder features.

    The query is the decoder's current hidden state (one vector per batch
    element); keys and values come from the encoder's patch features.
    Returns one context vector of size *encoder_dim* per batch element.
    """

    def __init__(self, hidden_dim, encoder_dim, num_heads=4):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // num_heads

        # Projection layers (creation order kept stable for reproducible init).
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(encoder_dim, hidden_dim)
        self.value = nn.Linear(encoder_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, encoder_dim)

    def forward(self, hidden, encoder_outputs):
        """Attend *hidden* (B, hidden_dim) over *encoder_outputs* (B, N, encoder_dim)."""
        batch, num_patches, _ = encoder_outputs.shape
        heads, dim = self.num_heads, self.head_dim

        # Project and split into heads.
        q = self.query(hidden).view(batch, heads, dim)
        k = self.key(encoder_outputs).view(batch, num_patches, heads, dim).transpose(1, 2)
        v = self.value(encoder_outputs).view(batch, num_patches, heads, dim).transpose(1, 2)

        # Scaled dot-product attention for the single query position.
        weights = torch.softmax(
            torch.matmul(q.unsqueeze(2), k.transpose(-2, -1)) / (dim ** 0.5), dim=-1
        )  # (B, H, 1, N)
        mixed = torch.matmul(weights, v)  # (B, H, 1, head_dim)
        flat = mixed.transpose(1, 2).contiguous().view(batch, self.hidden_dim)
        return self.fc_out(flat)
83
+
84
class Decoder(nn.Module):
    """LSTM caption decoder with multi-head attention over encoder features.

    At each step the previous word embedding is concatenated with an
    attention context (computed from the top LSTM layer's hidden state)
    and fed through the LSTM; a linear layer maps the output to
    vocabulary logits.
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size, encoder_dim=256, num_layers=2, dropout=0.3, num_heads=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(dropout)
        # Inter-layer LSTM dropout is only meaningful for stacked LSTMs.
        self.lstm = nn.LSTM(embed_dim + encoder_dim, hidden_dim, num_layers,
                            batch_first=True,
                            dropout=dropout if num_layers > 1 else 0)
        self.attention = MultiHeadAttention(hidden_dim, encoder_dim, num_heads=num_heads)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def _step(self, token, states, features):
        """Run one decoding step; return (vocab logits, new LSTM states)."""
        emb = self.embedding(token).squeeze(1)
        # Attend with the top layer's current hidden state.
        context = self.attention(states[0][-1], features)
        lstm_input = torch.cat((emb, context), dim=1).unsqueeze(1)
        out, states = self.lstm(lstm_input, states)
        return self.fc(out.squeeze(1)), states

    def generate(self, features, max_len=50, start_index=1, end_index=2, beam_size=3, beam_search=True):
        """Decode a caption (list of token ids) from patch *features*.

        NOTE(review): both paths effectively assume a batch of one image;
        the `.item()` calls would fail for B > 1.
        """
        B = features.size(0)
        device = features.device

        states = (torch.zeros(self.lstm.num_layers, B, self.lstm.hidden_size, device=device),
                  torch.zeros(self.lstm.num_layers, B, self.lstm.hidden_size, device=device))

        if not beam_search:
            # Greedy decoding: always take the argmax token.
            generated = []
            current = torch.LongTensor([start_index]).to(device).unsqueeze(0)
            for _ in range(max_len):
                logits, states = self._step(current, states, features)
                predicted = logits.argmax(dim=1).item()
                generated.append(predicted)
                if predicted == end_index:
                    break
                current = torch.LongTensor([predicted]).to(device).unsqueeze(0)
            return generated

        # Beam search. FIX: start from a single seed beam — seeding with
        # beam_size identical copies (the previous behaviour) only produced
        # duplicate hypotheses on the first expansion.
        beams = [([start_index], 0.0, states)]
        for _ in range(max_len):
            new_beams = []
            for seq, log_prob, (h, c) in beams:
                # FIX: a finished hypothesis must not be extended further;
                # extending past <EOS> kept lowering its score unfairly.
                if seq[-1] == end_index:
                    new_beams.append((seq, log_prob, (h, c)))
                    continue
                current = torch.LongTensor([seq[-1]]).to(device).unsqueeze(0)
                logits, (h_new, c_new) = self._step(current, (h, c), features)
                log_probs = torch.log_softmax(logits, dim=1)
                top_log_probs, top_indices = log_probs.topk(beam_size, dim=1)
                for k in range(beam_size):
                    new_beams.append((seq + [top_indices[0, k].item()],
                                      log_prob + top_log_probs[0, k].item(),
                                      (h_new, c_new)))
            beams = sorted(new_beams, key=lambda b: b[1], reverse=True)[:beam_size]
            if all(seq[-1] == end_index for seq, _, _ in beams):
                break

        best_seq = beams[0][0]
        # Strip the leading <SOS> so callers receive content tokens only.
        if best_seq[0] == start_index:
            best_seq = best_seq[1:]
        return best_seq
140
+
141
class Model(nn.Module):
    """Thin encoder/decoder wrapper used for caption inference."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def generate(self, images, max_len=50):
        """Encode *images* and beam-search decode a caption token list."""
        return self.decoder.generate(self.encoder(images), max_len=max_len, beam_search=True)
151
+
152
# ---------------------------------------------------------------------------
# App setup: device, hyper-parameters, and saved artifacts.
# ---------------------------------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 256    # encoder output / word-embedding size (must match training)
HIDDEN_DIM = 512   # LSTM hidden size (must match training)
VOCAB_PATH = "vocab.pkl"
MODEL_PATH = "vit_lstm.pth"

print("Loading Vocabulary...")
try:
    # NOTE: pickle executes arbitrary code on load — only deploy a
    # vocab.pkl produced by your own training run.
    with open(VOCAB_PATH, "rb") as f:
        vocab = pickle.load(f)
    print(f"Vocabulary Loaded. Size: {len(vocab)}")
except FileNotFoundError:
    raise RuntimeError("vocab.pkl not found! Please upload it.")

print("Initializing Model...")
encoder = Encoder(EMBED_DIM, freeze=True)
decoder = Decoder(EMBED_DIM, HIDDEN_DIM, len(vocab))
model = Model(encoder, decoder).to(DEVICE)

print("Loading Weights...")
try:
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
    # Accept either a raw state_dict or a training checkpoint wrapper.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    print("Model Loaded Successfully!")
except FileNotFoundError:
    raise RuntimeError("vit_lstm.pth not found! Please upload it.")
except Exception as e:
    # Best-effort: keep serving with whatever weights we have.
    print(f"Warning loading weights: {e}")

# FIX: always switch to eval mode, even when weight loading only emitted a
# warning — previously eval() was skipped on that path, leaving dropout
# active and making captions nondeterministic.
model.eval()

# Standard ImageNet preprocessing, matching what the ViT backbone expects.
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
190
+
191
def generate_caption(image):
    """Gradio handler: produce a caption string for an uploaded PIL image.

    Returns a human-readable message instead of raising, so the UI always
    shows something sensible.
    """
    if image is None:
        return "Please upload an image."

    try:
        tensor = inference_transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            token_ids = model.generate(tensor)

        # Map ids back to words, stopping at <EOS> and hiding control tokens.
        words = []
        for token_id in token_ids:
            word = vocab.itos.get(token_id, "<UNK>")
            if word == "<EOS>":
                break
            if word not in ("<SOS>", "<PAD>"):
                words.append(word)

        return " ".join(words)

    except Exception as e:
        return f"Error occurred: {str(e)}"
215
+
216
# Gradio UI wiring: a single image input mapped to a caption textbox.
iface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Textbox(label="Generated Caption"),
    title="ViT + LSTM Image Captioning",
    description="Upload an image to generate a caption using a Vision Transformer (Encoder) and LSTM (Decoder) architecture."
)

# Launch the web app only when run as a script (Spaces imports this module).
if __name__ == "__main__":
    iface.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ gradio
5
+ pillow
6
+ numpy