hackergeek commited on
Commit
9ec767e
Β·
verified Β·
1 Parent(s): 9286537

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -23
app.py CHANGED
@@ -5,13 +5,30 @@ from PIL import Image
5
  from torchvision import transforms
6
  import json
7
  import os
 
8
 
 
 
 
9
  IMG_SIZE = 224
10
  SEQ_LEN = 32
11
  VOCAB_SIZE = 75460
12
-
 
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
 
 
 
 
 
 
 
 
 
 
 
15
  transform = transforms.Compose([
16
  transforms.Resize((IMG_SIZE, IMG_SIZE)),
17
  transforms.ToTensor(),
@@ -33,68 +50,79 @@ class SimpleTokenizer:
33
 
34
  @classmethod
35
  def load(cls, path):
36
- with open(f"{path}/vocab.json", "r") as f:
37
  word2idx = json.load(f)
38
  return cls(word2idx)
39
 
 
 
 
40
  class BiasDecoder(torch.nn.Module):
41
  def __init__(self, feature_dim=768, vocab_size=VOCAB_SIZE):
42
  super().__init__()
43
- self.token_emb = torch.nn.Embedding(vocab_size, feature_dim)
44
- self.pos_emb = torch.nn.Embedding(SEQ_LEN-1, feature_dim)
45
- self.final_layer = torch.nn.Linear(feature_dim, vocab_size)
46
 
47
  def forward(self, img_feat, target_seq):
48
- x = self.token_emb(target_seq)
49
- pos = torch.arange(x.size(1), device=x.device).clamp(max=self.pos_emb.num_embeddings - 1)
50
- x = x + self.pos_emb(pos)
51
- x = x + img_feat.unsqueeze(1)
52
  return self.final_layer(x)
53
 
54
- # Load ViT
 
 
55
  vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)
56
  vit.eval()
57
 
58
- # Load decoder weights from RADIOCAP13 folder
59
  decoder = BiasDecoder().to(device)
60
- decoder.load_state_dict(torch.load("main/pytorch_model.bin", map_location=device))
61
  decoder.eval()
62
 
63
- # Load tokenizer from same folder
64
- tokenizer = SimpleTokenizer.load("RADIOCAP13")
65
- pad_idx = tokenizer.word2idx["<PAD>"]
66
 
 
 
 
67
  @torch.no_grad()
68
- def generate_caption(img):
69
  img_tensor = preprocess_image(img).unsqueeze(0).to(device)
70
- img_feat = vit(pixel_values=img_tensor).pooler_output
71
 
72
  beams = [([tokenizer.word2idx["<SOS>"]], 0.0)]
73
- beam_size = 3
74
 
75
- for _ in range(SEQ_LEN - 1):
76
  candidates = []
77
  for seq, score in beams:
78
- inp = torch.tensor(seq + [pad_idx] * (SEQ_LEN - len(seq)), device=device).unsqueeze(0)
79
  logits = decoder(img_feat, inp)
80
- probs = torch.nn.functional.log_softmax(logits[0, len(seq)-1], dim=-1)
81
  top_p, top_i = torch.topk(probs, beam_size)
 
82
  for i in range(beam_size):
83
  candidates.append((seq + [top_i[i].item()], score + top_p[i].item()))
 
84
  beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]
 
85
  if all(s[-1] == tokenizer.word2idx["<EOS>"] for s, _ in beams):
86
  break
87
 
88
  words = [tokenizer.idx2word.get(i, "<UNK>") for i in beams[0][0][1:] if i != pad_idx]
89
  return " ".join(words)
90
 
 
 
 
91
  with gr.Blocks() as demo:
92
  gr.Markdown("# RADIOCAP13 β€” Image Captioning Demo")
93
  gr.Markdown(f"**Device:** {'GPU πŸš€' if torch.cuda.is_available() else 'CPU 🐒'}")
94
 
95
  img_in = gr.Image(type="pil", label="Upload an Image")
96
- out = gr.Textbox(label="Generated Caption")
97
- btn = gr.Button("Generate Caption")
98
  status = gr.Markdown("Ready.")
99
 
100
  def wrapped(img):
 
5
  from torchvision import transforms
6
  import json
7
  import os
8
+ from huggingface_hub import hf_hub_download
9
 
10
+ # ---------------------
11
+ # Config
12
+ # ---------------------
13
  IMG_SIZE = 224
14
  SEQ_LEN = 32
15
  VOCAB_SIZE = 75460
16
+ REPO_ID = "hackergeek/RADIOCAP13" # your HF repo
17
+ WEIGHTS_FILENAME = "pytorch_model.bin"
18
+ VOCAB_FILENAME = "vocab.json"
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
+ # ---------------------
22
+ # Download model files (if not present)
23
+ # ---------------------
24
+ # Download weights
25
+ weights_path = hf_hub_download(repo_id=REPO_ID, filename=WEIGHTS_FILENAME)
26
+ # Download vocab
27
+ vocab_path = hf_hub_download(repo_id=REPO_ID, filename=VOCAB_FILENAME)
28
+
29
+ # ---------------------
30
+ # Preprocessing & Tokenizer
31
+ # ---------------------
32
  transform = transforms.Compose([
33
  transforms.Resize((IMG_SIZE, IMG_SIZE)),
34
  transforms.ToTensor(),
 
50
 
51
  @classmethod
52
  def load(cls, path):
53
+ with open(path, "r") as f:
54
  word2idx = json.load(f)
55
  return cls(word2idx)
56
 
57
+ # ---------------------
58
+ # Decoder
59
+ # ---------------------
60
  class BiasDecoder(torch.nn.Module):
61
  def __init__(self, feature_dim=768, vocab_size=VOCAB_SIZE):
62
  super().__init__()
63
+ self.token_emb = torch.nn.Embedding(vocab_size, feature_dim)
64
+ self.pos_emb = torch.nn.Embedding(SEQ_LEN-1, feature_dim)
65
+ self.final_layer = torch.nn.Linear(feature_dim, vocab_size)
66
 
67
  def forward(self, img_feat, target_seq):
68
+ x = self.token_emb(target_seq)
69
+ pos = torch.arange(x.size(1), device=x.device).clamp(max=self.pos_emb.num_embeddings-1)
70
+ x = x + self.pos_emb(pos)
71
+ x = x + img_feat.unsqueeze(1)
72
  return self.final_layer(x)
73
 
74
+ # ---------------------
75
+ # Load models
76
+ # ---------------------
77
  vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)
78
  vit.eval()
79
 
 
80
  decoder = BiasDecoder().to(device)
81
+ decoder.load_state_dict(torch.load(weights_path, map_location=device))
82
  decoder.eval()
83
 
84
+ tokenizer = SimpleTokenizer.load(vocab_path)
85
+ pad_idx = tokenizer.word2idx["<PAD>"]
 
86
 
87
+ # ---------------------
88
+ # Caption generation
89
+ # ---------------------
90
  @torch.no_grad()
91
+ def generate_caption(img, max_len=SEQ_LEN, beam_size=3):
92
  img_tensor = preprocess_image(img).unsqueeze(0).to(device)
93
+ img_feat = vit(pixel_values=img_tensor).pooler_output
94
 
95
  beams = [([tokenizer.word2idx["<SOS>"]], 0.0)]
 
96
 
97
+ for _ in range(max_len - 1):
98
  candidates = []
99
  for seq, score in beams:
100
+ inp = torch.tensor(seq + [pad_idx] * (SEQ_LEN - len(seq)), device=device).unsqueeze(0)
101
  logits = decoder(img_feat, inp)
102
+ probs = torch.nn.functional.log_softmax(logits[0, len(seq)-1], dim=-1)
103
  top_p, top_i = torch.topk(probs, beam_size)
104
+
105
  for i in range(beam_size):
106
  candidates.append((seq + [top_i[i].item()], score + top_p[i].item()))
107
+
108
  beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]
109
+
110
  if all(s[-1] == tokenizer.word2idx["<EOS>"] for s, _ in beams):
111
  break
112
 
113
  words = [tokenizer.idx2word.get(i, "<UNK>") for i in beams[0][0][1:] if i != pad_idx]
114
  return " ".join(words)
115
 
116
+ # ---------------------
117
+ # Gradio interface
118
+ # ---------------------
119
  with gr.Blocks() as demo:
120
  gr.Markdown("# RADIOCAP13 β€” Image Captioning Demo")
121
  gr.Markdown(f"**Device:** {'GPU πŸš€' if torch.cuda.is_available() else 'CPU 🐒'}")
122
 
123
  img_in = gr.Image(type="pil", label="Upload an Image")
124
+ out = gr.Textbox(label="Generated Caption")
125
+ btn = gr.Button("Generate Caption")
126
  status = gr.Markdown("Ready.")
127
 
128
  def wrapped(img):