hackergeek committed on
Commit
3471015
·
verified ·
1 Parent(s): 3fb11ac

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ViTModel
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import json
6
+ import os
7
+
8
+ # SimpleTokenizer and BiasDecoder originally come from the training script.
9
+ # Their definitions are repeated below so that this file is runnable on its own.
10
+
11
+ # Re-define necessary components and classes for a self-contained example
12
# Side length (in pixels) that every input image is resized to.
IMG_SIZE = 224
# Default token-sequence length used by the tokenizer and the caption generator;
# the decoder's positional table is sized from it as well.
SEQ_LEN = 32
# Default vocabulary size of the decoder's embedding and output layers.
VOCAB_SIZE = 75460

# Image pipeline: resize to a fixed square, then convert to a tensor.
# No normalization step is applied here.
transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ]
)
20
+
21
def preprocess_image(img):
    """Convert an input image into the tensor produced by ``transform``.

    Args:
        img: A ``PIL.Image.Image`` or any object accepted by
            ``Image.fromarray`` (presumably a NumPy array -- confirm with
            callers).

    Returns:
        The result of applying the module-level ``transform`` to the
        RGB version of the image.

    Raises:
        ValueError: If ``img`` is ``None``.
    """
    if img is None:
        raise ValueError("Image is None")

    # Anything that is not already a PIL image is wrapped into one.
    pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
    # Force three channels so grayscale / RGBA / palette inputs are handled.
    rgb_img = pil_img if pil_img.mode == "RGB" else pil_img.convert("RGB")
    return transform(rgb_img)
26
+
27
+ # SimpleTokenizer class (copy-pasted from notebook for self-contained example)
28
class SimpleTokenizer:
    """Whitespace tokenizer backed by a ``word -> index`` vocabulary.

    The vocabulary is expected to contain the special entries ``"<PAD>"``,
    ``"<SOS>"`` and ``"<EOS>"``; ``encode`` and ``decode`` raise ``KeyError``
    when they are missing.
    """

    def __init__(self, word2idx=None):
        """Store the vocabulary and build the reverse ``index -> word`` map.

        Args:
            word2idx: Mapping from word to integer id. ``None`` yields an
                empty vocabulary (a placeholder until a real one is loaded).
        """
        self.word2idx = {} if word2idx is None else word2idx
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}

    def encode(self, text, max_len=SEQ_LEN):
        """Turn ``text`` into a fixed-length tensor of token ids.

        The text is lower-cased and split on whitespace. The result is
        ``<SOS>``, at most ``max_len - 2`` word ids, ``<EOS>``, then
        ``<PAD>`` up to ``max_len``.

        Args:
            text: Input string.
            max_len: Target sequence length. For ``max_len < 2`` the result
                still contains ``<SOS>`` and ``<EOS>`` (length 2).

        Returns:
            A 1-D ``torch.long`` tensor of token ids.

        Raises:
            KeyError: If a special token is absent from the vocabulary.
        """
        # Look the special ids up once instead of once per word.
        pad_id = self.word2idx["<PAD>"]
        sos_id = self.word2idx["<SOS>"]
        eos_id = self.word2idx["<EOS>"]

        # Out-of-vocabulary words are mapped to the <PAD> id.
        word_ids = [self.word2idx.get(w, pad_id) for w in text.lower().split()]
        # Clamp at zero: a negative bound would slice from the end and keep
        # words when max_len < 2.
        room = max(max_len - 2, 0)
        tokens = [sos_id] + word_ids[:room] + [eos_id]
        tokens += [pad_id] * (max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def decode(self, tokens):
        """Turn a sequence of token ids back into a space-joined string.

        ``<PAD>``, ``<SOS>`` and ``<EOS>`` ids are dropped; ids missing from
        the vocabulary are rendered as ``"<UNK>"``.

        Args:
            tokens: A 1-D tensor of ids, or any iterable of ints / 0-dim
                tensors.

        Returns:
            The decoded text.

        Raises:
            KeyError: If ``tokens`` is non-empty and a special token is
                absent from the vocabulary.
        """
        raw = tokens.tolist() if isinstance(tokens, torch.Tensor) else tokens
        # int() accepts both Python ints and 0-dim tensors, so plain lists
        # of ids decode as well as tensors do.
        ids = [int(t) for t in raw]
        if not ids:
            return ""
        special = {self.word2idx[name] for name in ("<PAD>", "<SOS>", "<EOS>")}
        return " ".join(
            self.idx2word.get(i, "<UNK>") for i in ids if i not in special
        )

    @classmethod
    def load(cls, path):
        """Build a tokenizer from ``<path>/vocab.json``.

        Args:
            path: Directory containing ``vocab.json``.

        Returns:
            A tokenizer initialised with the loaded ``word -> index`` map.
        """
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(f"{path}/vocab.json", "r", encoding="utf-8") as f:
            word2idx = json.load(f)
        return cls(word2idx)
52
+
53
+ # BiasDecoder class (copy-pasted from notebook for self-contained example)
54
class BiasDecoder(torch.nn.Module):
    """Per-position decoder that adds an image feature to token embeddings.

    Each output position is computed from the token embedding, a positional
    embedding and the image feature at that position only; there is no
    attention or recurrence across positions.
    """

    def __init__(self, feature_dim=768, vocab_size=VOCAB_SIZE):
        """Create the embedding tables and the output projection.

        Args:
            feature_dim: Width of the embeddings and of the image feature.
            vocab_size: Number of token ids (input and output).
        """
        super().__init__()
        self.token_emb = torch.nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=feature_dim
        )
        # The positional table has SEQ_LEN - 1 rows; forward() clamps indices
        # so longer inputs reuse the last row.
        self.pos_emb = torch.nn.Embedding(
            num_embeddings=SEQ_LEN - 1, embedding_dim=feature_dim
        )
        self.final_layer = torch.nn.Linear(
            in_features=feature_dim, out_features=vocab_size
        )

    def forward(self, img_feat, target_seq):
        """Return vocabulary logits for every position of ``target_seq``.

        Args:
            img_feat: Image feature tensor; presumably ``(batch, feature_dim)``
                -- it is unsqueezed at dim 1 and broadcast over positions.
            target_seq: Integer tensor of token ids; presumably
                ``(batch, seq_len)``.

        Returns:
            The output of ``final_layer`` applied at every position.
        """
        embedded = self.token_emb(target_seq)
        last_row = self.pos_emb.num_embeddings - 1
        positions = torch.arange(embedded.size(1), device=embedded.device)
        positions = positions.clamp(max=last_row)
        # Same left-to-right addition order as before: tokens + positions,
        # then the image feature.
        conditioned = embedded + self.pos_emb(positions) + img_feat.unsqueeze(1)
        return self.final_layer(conditioned)
67
+
68
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Load the ViT backbone and put it in eval mode. Note that gradients are not
# disabled here; callers should wrap inference in torch.no_grad().
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
vit.eval()
vit = vit.to(device)

# Load the decoder weights from 'pytorch_model.bin' in the working directory.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
decoder = BiasDecoder()
decoder = decoder.to(device)
decoder.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
decoder.eval()

# Load the tokenizer from 'vocab.json' in the working directory and cache the
# padding id used when building decoder inputs.
tokenizer = SimpleTokenizer.load("./")
pad_idx = tokenizer.word2idx["<PAD>"]
86
+
87
+ # Generation function
88
@torch.no_grad()
def generate_caption(model, img_feat, max_len=SEQ_LEN, beam_size=3):
    """Generate a caption for one image with beam search.

    Relies on the module-level ``tokenizer``, ``device`` and ``pad_idx``.

    Args:
        model: Decoder called as ``model(img_feat, token_ids)`` and returning
            per-position vocabulary logits.
        img_feat: Image feature tensor for a single image; presumably
            ``(1, feature_dim)`` -- only batch index 0 of the logits is read.
        max_len: Maximum caption length in tokens, including ``<SOS>``.
        beam_size: Number of hypotheses kept after every step.

    Returns:
        The best hypothesis as a space-joined string, without ``<SOS>``,
        ``<EOS>`` or ``<PAD>`` tokens.
    """
    model.eval()
    img_feat = img_feat.to(device)

    sos_idx = tokenizer.word2idx["<SOS>"]
    eos_idx = tokenizer.word2idx["<EOS>"]

    # Each beam is (token ids so far, accumulated log-probability).
    beams = [([sos_idx], 0.0)]
    for _ in range(max_len - 1):
        candidates = []
        for seq, score in beams:
            if seq[-1] == eos_idx:
                # A finished hypothesis is carried over unchanged; expanding
                # it would append tokens after <EOS>.
                candidates.append((seq, score))
                continue

            padded = seq + [pad_idx] * (SEQ_LEN - len(seq))
            inp = torch.tensor(padded, device=device).unsqueeze(0)
            logits = model(img_feat, inp)
            # The logits at the last real token predict the next token.
            log_probs = torch.nn.functional.log_softmax(
                logits[0, len(seq) - 1], dim=-1
            )
            # topk raises if asked for more entries than the vocabulary has.
            k = min(beam_size, log_probs.size(-1))
            top_p, top_i = torch.topk(log_probs, k)
            for log_p, token_id in zip(top_p.tolist(), top_i.tolist()):
                candidates.append((seq + [token_id], score + log_p))

        beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]
        if all(s[-1] == eos_idx for s, _ in beams):
            break

    # Drop every special token so the caption never contains a literal "<EOS>".
    special = {pad_idx, sos_idx, eos_idx}
    best_ids = beams[0][0][1:]
    words = [tokenizer.idx2word.get(i, "<UNK>") for i in best_ids if i not in special]
    return " ".join(words)
106
+
107
+ # Example: Generate a caption for an image
108
+ # For a real example, you would load an actual image and process it.
109
+ # img_path = "path/to/your/image.jpg"
110
+ # image = Image.open(img_path).convert("RGB")
111
+ # img_tensor = preprocess_image(image).unsqueeze(0).to(device)
112
+ # img_feat = vit(pixel_values=img_tensor).pooler_output
113
+ # generated_caption = generate_caption(decoder, img_feat)
114
+ # print(f"Generated caption: {generated_caption}")