waleed-12 committed
Commit 59ba849 · verified · 1 Parent(s): 362b97f

Create app.py

Files changed (1):
  app.py +165 -0

app.py ADDED
@@ -0,0 +1,165 @@
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
from torchvision import models

# -------------------------
# Model definitions (must match training)
# -------------------------
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"

class Encoder(nn.Module):
    def __init__(self, in_dim=2048, hidden_size=512):
        super().__init__()
        self.fc = nn.Linear(in_dim, hidden_size)
        self.relu = nn.ReLU()

    def forward(self, feat):
        return self.relu(self.fc(feat))

class Decoder(nn.Module):
    def __init__(self, vocab_size, pad_id, embed_dim=256, hidden_size=512, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
        self.lstm = nn.LSTM(embed_dim, hidden_size, num_layers=1, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(hidden_size, vocab_size)

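# Note: Decoder defines no forward(); the decoding loops below step the LSTM
# manually via embed / lstm / fc_out. The dropout layer is presumably kept to
# mirror the training-time class definition; it is never applied at inference.
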
class Img2Caption(nn.Module):
    def __init__(self, vocab_size, pad_id, hidden_size=512, embed_dim=256):
        super().__init__()
        self.encoder = Encoder(in_dim=2048, hidden_size=hidden_size)
        self.decoder = Decoder(vocab_size=vocab_size, pad_id=pad_id, embed_dim=embed_dim, hidden_size=hidden_size)

# -------------------------
# Load checkpoint
# -------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CKPT_PATH = "img_caption_seq2seq.pth"

ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
word2idx = ckpt["word2idx"]
idx2word = ckpt["idx2word"]
max_len = ckpt.get("max_len", 30)

pad_id = word2idx[PAD_TOKEN]
start_id = word2idx["<start>"]
end_id = word2idx["<end>"]

model = Img2Caption(vocab_size=len(word2idx), pad_id=pad_id).to(DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()

# -------------------------
# ResNet50 feature extractor (on-the-fly)
# -------------------------
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet = nn.Sequential(*list(resnet.children())[:-1]).to(DEVICE)
resnet.eval()

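# Dropping the final fc layer leaves the global average pool as the last stage,
# so each image yields a [1, 2048, 1, 1] tensor; caption_image() below flattens
# it to a 2048-dim feature vector.
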
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

def decode_tokens(token_ids):
    words = []
    for tid in token_ids:
        w = idx2word.get(int(tid), UNK_TOKEN)
        if w == "<end>":
            break
        if w not in ["<start>", "<pad>"]:
            words.append(w)
    return " ".join(words)

@torch.no_grad()
def greedy_caption(feat_vec, max_words=30):
    feat = torch.tensor(feat_vec, dtype=torch.float32).unsqueeze(0).to(DEVICE)  # [1, 2048]
    h0 = model.encoder(feat)  # [1, hidden]

    last = start_id
    out_tokens = []

    h = h0.unsqueeze(0)  # [1, 1, hidden]
    c = torch.zeros_like(h)

    for _ in range(max_words):
        cur = torch.tensor([[last]], dtype=torch.long).to(DEVICE)
        emb = model.decoder.embed(cur)                      # [1, 1, E]
        lstm_out, (h, c) = model.decoder.lstm(emb, (h, c))  # [1, 1, H]
        logits = model.decoder.fc_out(lstm_out.squeeze(1))  # [1, V]
        nxt = int(torch.argmax(logits, dim=-1).item())

        if nxt == end_id:
            break
        out_tokens.append(nxt)
        last = nxt

    return decode_tokens(out_tokens)

@torch.no_grad()
def beam_caption(feat_vec, beam_size=3, max_words=30):
    feat = torch.tensor(feat_vec, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    h0 = model.encoder(feat)

    h = h0.unsqueeze(0)
    c = torch.zeros_like(h)

    beams = [([], 0.0, h, c, start_id)]  # (tokens, score, h, c, last)

    for _ in range(max_words):
        new_beams = []
        for tokens, score, h_i, c_i, last in beams:
            if last == end_id:
                new_beams.append((tokens, score, h_i, c_i, last))
                continue

            cur = torch.tensor([[last]], dtype=torch.long).to(DEVICE)
            emb = model.decoder.embed(cur)
            lstm_out, (h_new, c_new) = model.decoder.lstm(emb, (h_i, c_i))
            logits = model.decoder.fc_out(lstm_out.squeeze(1))
            log_probs = torch.log_softmax(logits, dim=-1).squeeze(0)

            topk = torch.topk(log_probs, beam_size)
            for lp, idx in zip(topk.values.tolist(), topk.indices.tolist()):
                new_beams.append((tokens + [idx], score + lp, h_new, c_new, idx))

        new_beams.sort(key=lambda x: x[1], reverse=True)
        beams = new_beams[:beam_size]

        if all(b[4] == end_id for b in beams):
            break

    best = beams[0][0]
    if len(best) and best[-1] == end_id:
        best = best[:-1]
    return decode_tokens(best)
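
# Note: beam scores are raw sums of token log-probabilities with no length
# normalization, so among finished beams shorter captions score slightly higher.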

@torch.no_grad()
def caption_image(img: Image.Image, decoding="Beam Search"):
    if img is None:  # Gradio passes None when no image is uploaded
        return "Please upload an image."
    img = img.convert("RGB")
    x = transform(img).unsqueeze(0).to(DEVICE)

    feat = resnet(x).view(1, -1).squeeze(0).cpu().numpy()  # [2048]

    # Cap generation at the training-time caption length stored in the checkpoint.
    if decoding == "Greedy":
        return greedy_caption(feat, max_words=max_len)
    return beam_caption(feat, beam_size=3, max_words=max_len)

demo = gr.Interface(
    fn=caption_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Radio(["Beam Search", "Greedy"], value="Beam Search", label="Decoding")
    ],
    outputs=gr.Textbox(label="Generated Caption"),
    title="Seq2Seq Image Captioning (Flickr30k)",
    description="Upload an image and generate a caption using a ResNet50 + LSTM Seq2Seq model."
)

if __name__ == "__main__":
    demo.launch()
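
Note: app.py assumes img_caption_seq2seq.pth was saved with the vocabulary and
length metadata alongside the weights. A minimal sketch of the training-side
save call this implies, inferred from the keys the app reads (the trained
model, word2idx, and idx2word objects come from the training script, which is
not part of this commit):

    # Hypothetical save call; key names must match what app.py loads.
    torch.save({
        "model_state": model.state_dict(),  # Img2Caption weights
        "word2idx": word2idx,               # dict[str, int]; must include <pad>, <unk>, <start>, <end>
        "idx2word": idx2word,               # dict[int, str]; inverse of word2idx
        "max_len": 30,                      # longest caption length used in training
    }, "img_caption_seq2seq.pth")

With the checkpoint in place, installing torch, torchvision, gradio, and pillow
and running "python app.py" should launch the demo locally.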