hackergeek committed on
Commit
c0eb6b0
·
verified ·
1 Parent(s): a7016ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -1,16 +1,10 @@
1
- # 📦 RADIOCAP13 — HuggingFace Space
2
-
3
- #Below is a complete multi-file project layout for deploying your image-captioning model as a HuggingFace Space.
4
- #You can copy/paste these into your repository.
5
-
6
-
7
- ## **app.py**
8
  import gradio as gr
9
  import torch
10
  from transformers import ViTModel
11
  from PIL import Image
12
  from torchvision import transforms
13
  import json
 
14
 
15
  IMG_SIZE = 224
16
  SEQ_LEN = 32
@@ -57,15 +51,17 @@ class BiasDecoder(torch.nn.Module):
57
  x = x + img_feat.unsqueeze(1)
58
  return self.final_layer(x)
59
 
60
- # Load models
61
- decoder = BiasDecoder().to(device)
62
- decoder.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
63
- decoder.eval()
64
-
65
  vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)
66
  vit.eval()
67
 
68
- tokenizer = SimpleTokenizer.load("./")
 
 
 
 
 
 
69
  pad_idx = tokenizer.word2idx["<PAD>"]
70
 
71
  @torch.no_grad()
@@ -94,10 +90,20 @@ def generate_caption(img):
94
 
95
  with gr.Blocks() as demo:
96
  gr.Markdown("# RADIOCAP13 — Image Captioning Demo")
 
 
97
  img_in = gr.Image(type="pil", label="Upload an Image")
98
  out = gr.Textbox(label="Generated Caption")
99
  btn = gr.Button("Generate Caption")
100
- btn.click(generate_caption, inputs=img_in, outputs=out)
 
 
 
 
 
 
 
 
101
 
102
  if __name__ == "__main__":
103
  demo.launch()
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import ViTModel
4
  from PIL import Image
5
  from torchvision import transforms
6
  import json
7
+ import os
8
 
9
  IMG_SIZE = 224
10
  SEQ_LEN = 32
 
51
  x = x + img_feat.unsqueeze(1)
52
  return self.final_layer(x)
53
 
54
+ # Load ViT
 
 
 
 
55
  vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)
56
  vit.eval()
57
 
58
+ # Load decoder weights from RADIOCAP13 folder
59
+ decoder = BiasDecoder().to(device)
60
+ decoder.load_state_dict(torch.load("RADIOCAP13/pytorch_model.bin", map_location=device))
61
+ decoder.eval()
62
+
63
+ # Load tokenizer from same folder
64
+ tokenizer = SimpleTokenizer.load("RADIOCAP13")
65
  pad_idx = tokenizer.word2idx["<PAD>"]
66
 
67
  @torch.no_grad()
 
90
 
91
  with gr.Blocks() as demo:
92
  gr.Markdown("# RADIOCAP13 — Image Captioning Demo")
93
+ gr.Markdown(f"**Device:** {'GPU 🚀' if torch.cuda.is_available() else 'CPU 🐢'}")
94
+
95
  img_in = gr.Image(type="pil", label="Upload an Image")
96
  out = gr.Textbox(label="Generated Caption")
97
  btn = gr.Button("Generate Caption")
98
+ status = gr.Markdown("Ready.")
99
+
100
+ def wrapped(img):
101
+ status.update("Processing…")
102
+ caption = generate_caption(img)
103
+ status.update("Done ✔️")
104
+ return caption
105
+
106
+ btn.click(wrapped, inputs=img_in, outputs=out)
107
 
108
  if __name__ == "__main__":
109
  demo.launch()