AsadAnalyst commited on
Commit
d8c12d2
·
verified ·
1 Parent(s): 1874b32

Upload 7 files

Browse files
Files changed (7) hide show
  1. README (1).md +13 -0
  2. app.py +376 -0
  3. best_model.pth +3 -0
  4. config.json +11 -0
  5. gitattributes +35 -0
  6. requirements (1).txt +5 -0
  7. vocab.pkl +3 -0
README (1).md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Neural Storyteller
3
+ emoji: 🌍
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 6.5.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Neural Storyteller – Gradio App for Hugging Face Spaces (Attention model)."""
3
+
4
+ import os, json, pickle
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torchvision import models, transforms
10
+ from PIL import Image
11
+ import gradio as gr
12
+
13
+ # ── Device ──
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
# ── Load config ──
# Hyperparameters live in config.json so the app stays in sync with the
# training run that produced best_model.pth.
with open("config.json", mode="r") as config_file:
    cfg = json.load(config_file)

EMBED_SIZE = cfg["embed_size"]
HIDDEN_SIZE = cfg["hidden_size"]
NUM_REGIONS = cfg["num_regions"]
VOCAB_SIZE = cfg["vocab_size"]
MAX_LEN = cfg["max_len"]
DROPOUT = cfg["dropout"]
BEAM_WIDTH = cfg["beam_width"]
# Decoding penalties are optional in the config; fall back to defaults.
LENGTH_PEN = cfg.get("length_penalty", 0.7)
REP_PEN = cfg.get("repetition_penalty", 1.2)
29
+
30
# ── Vocabulary class (required for unpickling) ──
class Vocabulary:
    """Word <-> index mapping.

    vocab.pkl holds a pickled instance of this class, so an identical
    class definition must exist in this module for pickle.load to work.
    """

    # Special tokens shared with the training pipeline.
    PAD, START, END, UNK = '<pad>', '<start>', '<end>', '<unk>'

    def __init__(self, freq_threshold=5):
        # Minimum corpus frequency a word needs to receive its own index.
        self.freq_threshold = freq_threshold
        self.word2idx = {}
        self.idx2word = {}
        self._idx = 0  # next free index when the vocab is being built

    def __len__(self):
        """Number of distinct indexed words."""
        return len(self.word2idx)
42
+
43
# ── Load vocabulary ──
# NOTE: pickle is acceptable here only because vocab.pkl ships with the
# Space itself; never unpickle untrusted files.
with open("vocab.pkl", "rb") as vocab_file:
    vocab = pickle.load(vocab_file)
46
+
47
+
48
+ # ══════════════ Model Definitions (must match training) ══════════════
49
+
50
class Encoder(nn.Module):
    """Turn one global ResNet feature vector into attention regions.

    Projects a (B, feature_dim) feature into `num_regions` pseudo-region
    vectors of size `hidden_size`, and derives the decoder LSTM's
    initial hidden/cell state from the same feature.
    """

    def __init__(self, feature_dim=2048, hidden_size=HIDDEN_SIZE,
                 num_regions=NUM_REGIONS, dropout=DROPOUT):
        super().__init__()
        self.num_regions = num_regions
        self.hidden_size = hidden_size
        # Single wide projection, later reshaped to (num_regions, hidden).
        self.project = nn.Linear(feature_dim, hidden_size * num_regions)
        self.bn = nn.BatchNorm1d(hidden_size * num_regions)
        self.dropout = nn.Dropout(dropout)
        # Separate heads produce the decoder's initial h0 and c0.
        self.init_h = nn.Linear(feature_dim, hidden_size)
        self.init_c = nn.Linear(feature_dim, hidden_size)

    def forward(self, features):
        """features: (B, feature_dim) -> (regions (B, R, H), h0, c0)."""
        projected = F.relu(self.bn(self.project(features)))
        regions = self.dropout(projected).view(
            -1, self.num_regions, self.hidden_size)
        h0 = torch.tanh(self.init_h(features))
        c0 = torch.tanh(self.init_c(features))
        return regions, h0, c0
68
+
69
+
70
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention over encoder regions."""

    def __init__(self, hidden_size):
        super().__init__()
        self.W_enc = nn.Linear(hidden_size, hidden_size)
        self.W_dec = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, 1)

    def forward(self, encoder_out, decoder_hidden):
        """Score every region against the current decoder state.

        encoder_out: (B, R, H); decoder_hidden: (B, H).
        Returns (context (B, H), weights (B, R)); weights sum to 1 over R.
        """
        # Broadcast the decoder state across all R regions before scoring.
        scores = self.V(torch.tanh(
            self.W_enc(encoder_out)
            + self.W_dec(decoder_hidden).unsqueeze(1)))
        attn_weights = F.softmax(scores.squeeze(2), dim=1)
        # Weighted sum of region vectors -> single context vector.
        context = torch.sum(attn_weights.unsqueeze(2) * encoder_out, dim=1)
        return context, attn_weights
84
+
85
+
86
class AttentionDecoder(nn.Module):
    """LSTMCell decoder that attends over encoder regions at each step."""

    def __init__(self, vocab_size, embed_size=EMBED_SIZE,
                 hidden_size=HIDDEN_SIZE, dropout=DROPOUT):
        super().__init__()
        # padding_idx=0 keeps the <pad> embedding fixed at zero.
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.attention = BahdanauAttention(hidden_size)
        # LSTM input = word embedding concatenated with attention context.
        self.lstm_cell = nn.LSTMCell(embed_size + hidden_size, hidden_size)
        # Output head sees both the new hidden state and the context.
        self.fc_out = nn.Linear(hidden_size + hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward_step(self, word_idx, h, c, encoder_out):
        """Run a single decoding step.

        word_idx: (B,) previous token ids; h, c: (B, H) LSTM state;
        encoder_out: (B, R, H). Returns (logits, h, c, attn weights).
        """
        embedded = self.dropout(self.embed(word_idx))
        context, attn_weights = self.attention(encoder_out, h)
        h, c = self.lstm_cell(torch.cat([embedded, context], dim=1), (h, c))
        fused = torch.cat([h, context], dim=1)
        logits = self.fc_out(self.dropout(fused))
        return logits, h, c, attn_weights
103
+
104
+
105
class Seq2SeqCaptioner(nn.Module):
    """Complete captioner: feature Encoder + attention LSTM decoder."""

    def __init__(self, vocab_size, embed_size=EMBED_SIZE,
                 hidden_size=HIDDEN_SIZE, dropout=DROPOUT,
                 num_regions=NUM_REGIONS):
        super().__init__()
        self.encoder = Encoder(2048, hidden_size, num_regions, dropout)
        self.decoder = AttentionDecoder(vocab_size, embed_size,
                                        hidden_size, dropout)
        self.hidden_size = hidden_size

    def forward(self, features, captions, teacher_forcing_ratio=1.0):
        """Teacher-forced training pass (unused at inference time).

        features: (B, 2048) image features; captions: (B, T+1) token ids
        starting with <start>. Returns logits of shape (B, T, vocab).
        """
        import random  # local: only needed for the teacher-forcing coin flip

        batch = features.size(0)
        steps = captions.size(1) - 1
        vocab_dim = self.decoder.fc_out.out_features
        encoder_out, h, c = self.encoder(features)
        outputs = torch.zeros(batch, steps, vocab_dim, device=features.device)

        token = captions[:, 0]  # always begin from <start>
        for step in range(steps):
            logits, h, c, _ = self.decoder.forward_step(
                token, h, c, encoder_out)
            outputs[:, step] = logits
            # Teacher forcing: feed the ground-truth next word with
            # probability `teacher_forcing_ratio`, else the model's guess.
            if random.random() < teacher_forcing_ratio:
                token = captions[:, step + 1]
            else:
                token = logits.argmax(dim=-1)
        return outputs
130
+
131
+
132
# ── Load trained weights ──
caption_model = Seq2SeqCaptioner(
    VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE, DROPOUT, NUM_REGIONS).to(device)
caption_model.load_state_dict(torch.load("best_model.pth", map_location=device))
caption_model.eval()  # inference only: disables dropout, uses BN running stats

# ── ResNet50 feature extractor ──
# Strip the final FC layer so the backbone emits a 2048-d pooled feature.
_backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet = nn.Sequential(*list(_backbone.children())[:-1]).to(device)
resnet.eval()

# Standard ImageNet preprocessing — must match the backbone's training.
img_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
149
+
150
+
151
# ── Greedy Search (faster, simpler) ──
@torch.no_grad()
def greedy_search_inference(feature):
    """Decode a caption by always picking the most probable next word.

    feature: 1-D image feature tensor of shape (2048,).
    Returns the caption string with special tokens stripped.
    """
    encoder_out, h, c = caption_model.encoder(feature.unsqueeze(0).to(device))

    start_idx = vocab.word2idx[vocab.START]
    end_idx = vocab.word2idx[vocab.END]
    specials = (vocab.START, vocab.END, vocab.PAD)

    tokens = [start_idx]
    for _ in range(MAX_LEN):  # hard cap prevents runaway generation
        inp = torch.tensor([tokens[-1]], device=device)
        logits, h, c, _ = caption_model.decoder.forward_step(
            inp, h, c, encoder_out)
        next_idx = logits.argmax(dim=-1).item()
        if next_idx == end_idx:
            break
        tokens.append(next_idx)

    return " ".join(vocab.idx2word[i] for i in tokens
                    if vocab.idx2word[i] not in specials)
176
+
177
+
178
# ── Beam Search with penalties ──
@torch.no_grad()
def beam_search_inference(feature, beam_width=BEAM_WIDTH,
                          length_penalty=LENGTH_PEN,
                          repetition_penalty=REP_PEN):
    """Beam-search decoding with length normalisation and a repetition
    penalty.

    feature: 1-D image feature tensor of shape (2048,).
    beam_width: number of live hypotheses kept per step.
    length_penalty: exponent for length normalisation of final scores.
    repetition_penalty: > 1 discourages re-emitting already-used words.
    Returns the best caption string with special tokens stripped.
    """
    feature = feature.unsqueeze(0).to(device)
    encoder_out, h0, c0 = caption_model.encoder(feature)

    start_idx = vocab.word2idx[vocab.START]
    end_idx = vocab.word2idx[vocab.END]
    pad_idx = vocab.word2idx[vocab.PAD]

    # Each beam: (cumulative log-prob, token sequence, LSTM h, LSTM c).
    beams = [(0.0, [start_idx], h0, c0)]
    completed = []

    for _ in range(MAX_LEN):
        new_beams = []
        for log_prob, seq, h, c in beams:
            inp = torch.tensor([seq[-1]], device=device)
            logits, h_new, c_new, _ = caption_model.decoder.forward_step(
                inp, h, c, encoder_out)
            logits = logits.squeeze(0)

            # Repetition penalty, sign-aware (CTRL-style, as used by
            # HF transformers): shrink positive logits and push negative
            # logits further down. BUGFIX: the previous code divided
            # unconditionally, which RAISES negative logits and makes
            # repeated words MORE likely.
            for prev_tok in set(seq):
                if prev_tok not in (start_idx, end_idx, pad_idx):
                    if logits[prev_tok] > 0:
                        logits[prev_tok] /= repetition_penalty
                    else:
                        logits[prev_tok] *= repetition_penalty

            log_probs = F.log_softmax(logits, dim=-1)
            topk_lp, topk_idx = log_probs.topk(beam_width)

            for k in range(beam_width):
                token = topk_idx[k].item()
                new_lp = log_prob + topk_lp[k].item()
                new_seq = seq + [token]
                if token == end_idx:
                    # Length-normalise so longer captions are not
                    # unfairly penalised by summed log-probs.
                    score = new_lp / (len(new_seq) ** length_penalty)
                    completed.append((score, new_seq))
                else:
                    new_beams.append((new_lp, new_seq, h_new, c_new))

        # Keep only the best `beam_width` live hypotheses.
        new_beams.sort(key=lambda x: x[0], reverse=True)
        beams = new_beams[:beam_width]
        if not beams or len(completed) >= beam_width:
            break

    # If nothing reached <end> within MAX_LEN, score the live beams.
    if not completed and beams:
        for lp, seq, _, _ in beams:
            completed.append((lp / (len(seq) ** length_penalty), seq))

    completed.sort(key=lambda x: x[0], reverse=True)
    best_seq = completed[0][1] if completed else [start_idx]

    words = [vocab.idx2word[i] for i in best_seq
             if vocab.idx2word[i] not in (vocab.START, vocab.END, vocab.PAD)]
    return " ".join(words)
233
+
234
+
235
# ── Prediction function for Gradio ──
def predict(image, search_method, beam_width, length_penalty, repetition_penalty):
    """Take a PIL image -> return generated caption string.

    image: PIL.Image or None (Gradio passes None when nothing was uploaded).
    search_method: radio-button label choosing greedy vs beam decoding.
    beam_width / length_penalty / repetition_penalty: beam-search sliders.
    Returns an HTML string rendered inside the gr.HTML output component.
    """
    # Friendly error card instead of crashing when no image is supplied.
    if image is None:
        return """
        <div style="background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); padding: 30px; border-radius: 15px; text-align: center;">
            <p style="color: white; font-size: 20px; margin: 0;">⚠️ Please upload an image first</p>
        </div>
        """

    # Force 3-channel RGB (uploads may be RGBA/grayscale), then preprocess.
    image = image.convert("RGB")
    img_tensor = img_transform(image).unsqueeze(0).to(device)

    # Extract the 2048-d pooled ResNet feature; no gradients needed.
    with torch.no_grad():
        feature = resnet(img_tensor).view(1, -1).squeeze(0)

    # String must match the gr.Radio choice label exactly.
    if search_method == "Greedy Search (Fast)":
        caption = greedy_search_inference(feature)
        method_info = "🚀 Generated using Greedy Search"
    else:  # Beam Search
        # Sliders deliver floats; beam width must be an int.
        caption = beam_search_inference(
            feature,
            beam_width=int(beam_width),
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty
        )
        method_info = f"🔍 Generated using Beam Search (width={int(beam_width)})"

    # Return beautiful HTML formatted caption
    return f"""
    <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 40px; border-radius: 15px; box-shadow: 0 8px 32px rgba(0,0,0,0.1);">
        <p style="color: white; font-size: 28px; font-weight: 600; text-align: center; line-height: 1.6; margin: 0; text-shadow: 2px 2px 4px rgba(0,0,0,0.2);">
            "{caption}"
        </p>
        <p style="color: rgba(255,255,255,0.9); font-size: 14px; text-align: center; margin-top: 20px; font-style: italic;">
            {method_info}
        </p>
    </div>
    """
274
+
275
+
276
# ── Gradio Interface ──
# Declarative UI layout; the only event wiring is generate_btn.click below.
with gr.Blocks(theme=gr.themes.Soft(), title="Neural Storyteller", css="""
.caption-box {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 30px;
    border-radius: 15px;
    box-shadow: 0 8px 32px rgba(0,0,0,0.1);
    margin: 20px 0;
}
.caption-text {
    color: white;
    font-size: 24px;
    font-weight: 600;
    text-align: center;
    line-height: 1.6;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}
.method-info {
    color: rgba(255,255,255,0.9);
    font-size: 14px;
    text-align: center;
    margin-top: 15px;
    font-style: italic;
}
""") as demo:
    gr.Markdown("""
    # 🧠 Neural Storyteller – AI Image Captioning

    Upload any image and let the AI generate a natural language description using a **Seq2Seq model**
    with ResNet50 encoder and Attention-based LSTM decoder, trained on Flickr30k dataset.
    """)

    # Two-column layout: image upload on the left, settings on the right.
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="📸 Upload Your Image", height=400)

        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Generation Settings")

            # NOTE: these choice strings are matched verbatim in predict().
            search_method = gr.Radio(
                choices=["Greedy Search (Fast)", "Beam Search (Better Quality)"],
                value="Beam Search (Better Quality)",
                label="🎯 Decoding Method",
                info="Greedy is faster, Beam produces better results"
            )

            # Sliders only affect the beam-search path in predict().
            with gr.Accordion("🔧 Advanced Options (Beam Search Only)", open=False):
                beam_width = gr.Slider(
                    minimum=1, maximum=10, value=5, step=1,
                    label="Beam Width",
                    info="Number of candidates to explore (higher = better quality but slower)"
                )

                length_penalty = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.7, step=0.1,
                    label="Length Penalty",
                    info="Controls caption length (lower = shorter, higher = longer)"
                )

                repetition_penalty = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.2, step=0.1,
                    label="Repetition Penalty",
                    info="Reduces word repetition (higher = less repetition)"
                )

            generate_btn = gr.Button("✨ Generate Caption", variant="primary", size="lg", scale=1)

    # Beautiful caption display area
    gr.Markdown("## 📝 Generated Caption")
    output_text = gr.HTML(label="")  # predict() returns styled HTML

    with gr.Accordion("💡 Tips & Model Details", open=False):
        gr.Markdown("""
        ### Tips:
        - Try both **Greedy** and **Beam** search to compare results
        - Increase **Beam Width** for more diverse captions
        - Adjust **Length Penalty** if captions are too short/long
        - Use **Repetition Penalty** to avoid repeated words

        ### Model Details:
        - **Encoder**: ResNet50 (pretrained on ImageNet)
        - **Decoder**: Attention-based LSTM
        - **Training Data**: Flickr30k dataset
        - **Vocabulary**: ~8000 words
        """)

    # Wire the button to the inference function.
    generate_btn.click(
        fn=predict,
        inputs=[image_input, search_method, beam_width, length_penalty, repetition_penalty],
        outputs=output_text
    )

    gr.Markdown("""
    ---
    <p style="text-align: center; color: #666;">
    Built with PyTorch, Gradio, and ❤️ | Model trained on Flickr30k
    </p>
    """)

# Launch the app only when run as a script (Spaces executes app.py directly).
if __name__ == "__main__":
    demo.launch()
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00544b3749ce68d18959d6f2330457512bede59a6d6b6936190299c48f8299fe
3
+ size 127596021
config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_size": 256,
3
+ "hidden_size": 512,
4
+ "num_regions": 16,
5
+ "vocab_size": 7673,
6
+ "max_len": 40,
7
+ "dropout": 0.4,
8
+ "beam_width": 5,
9
+ "length_penalty": 0.7,
10
+ "repetition_penalty": 1.2
11
+ }
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements (1).txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow
5
+ numpy
vocab.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3168f2dffb4f92750512bf7f3afe22a87bd8a44deec19c580ba6b213a206abe4
3
+ size 157392