Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,16 +1,10 @@
|
|
| 1 |
-
# 📦 RADIOCAP13 — HuggingFace Space
|
| 2 |
-
|
| 3 |
-
#Below is a complete multi-file project layout for deploying your image-captioning model as a HuggingFace Space.
|
| 4 |
-
#You can copy/paste these into your repository.
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
## **app.py**
|
| 8 |
import gradio as gr
|
| 9 |
import torch
|
| 10 |
from transformers import ViTModel
|
| 11 |
from PIL import Image
|
| 12 |
from torchvision import transforms
|
| 13 |
import json
|
|
|
|
| 14 |
|
| 15 |
IMG_SIZE = 224
|
| 16 |
SEQ_LEN = 32
|
|
@@ -57,15 +51,17 @@ class BiasDecoder(torch.nn.Module):
|
|
| 57 |
x = x + img_feat.unsqueeze(1)
|
| 58 |
return self.final_layer(x)
|
| 59 |
|
| 60 |
-
# Load
|
| 61 |
-
decoder = BiasDecoder().to(device)
|
| 62 |
-
decoder.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
|
| 63 |
-
decoder.eval()
|
| 64 |
-
|
| 65 |
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)
|
| 66 |
vit.eval()
|
| 67 |
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
pad_idx = tokenizer.word2idx["<PAD>"]
|
| 70 |
|
| 71 |
@torch.no_grad()
|
|
@@ -94,10 +90,20 @@ def generate_caption(img):
|
|
| 94 |
|
| 95 |
with gr.Blocks() as demo:
|
| 96 |
gr.Markdown("# RADIOCAP13 — Image Captioning Demo")
|
|
|
|
|
|
|
| 97 |
img_in = gr.Image(type="pil", label="Upload an Image")
|
| 98 |
out = gr.Textbox(label="Generated Caption")
|
| 99 |
btn = gr.Button("Generate Caption")
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
if __name__ == "__main__":
|
| 103 |
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
from transformers import ViTModel
|
| 4 |
from PIL import Image
|
| 5 |
from torchvision import transforms
|
| 6 |
import json
|
| 7 |
+
import os
|
| 8 |
|
| 9 |
# Model input / generation constants.
IMG_SIZE = 224  # ViT-base patch16-224 expects 224x224 pixel inputs
SEQ_LEN = 32    # presumably the max caption length in tokens — TODO confirm against the decoder's training config
|
|
|
|
| 51 |
x = x + img_feat.unsqueeze(1)
|
| 52 |
return self.final_layer(x)
|
| 53 |
|
| 54 |
+
# Load ViT
# Frozen pretrained ViT encoder used as the image feature extractor.
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)
vit.eval()  # inference mode: disables dropout / batch-norm updates

# Load decoder weights from RADIOCAP13 folder
decoder = BiasDecoder().to(device)
# NOTE(review): torch.load unpickles arbitrary objects; prefer
# torch.load(..., weights_only=True) if the Space's torch version supports it.
decoder.load_state_dict(torch.load("RADIOCAP13/pytorch_model.bin", map_location=device))
decoder.eval()

# Load tokenizer from same folder
# SimpleTokenizer is defined elsewhere in this project; presumably load()
# restores the word2idx/idx2word vocab from files under RADIOCAP13 — verify.
tokenizer = SimpleTokenizer.load("RADIOCAP13")
pad_idx = tokenizer.word2idx["<PAD>"]  # padding token id used by the decoder
|
| 66 |
|
| 67 |
@torch.no_grad()
|
|
|
|
| 90 |
|
| 91 |
with gr.Blocks() as demo:
    gr.Markdown("# RADIOCAP13 — Image Captioning Demo")
    gr.Markdown(f"**Device:** {'GPU 🚀' if torch.cuda.is_available() else 'CPU 🐢'}")

    img_in = gr.Image(type="pil", label="Upload an Image")
    out = gr.Textbox(label="Generated Caption")
    btn = gr.Button("Generate Caption")
    status = gr.Markdown("Ready.")

    def wrapped(img):
        """Run captioning on *img* and report completion status.

        BUG FIX: the previous version called ``status.update(...)`` inside
        the handler, which never reaches the browser — in Gradio, component
        updates must be *returned* from the event handler and wired through
        the event's ``outputs``. ``status`` is now a proper output below.
        (Both updates also ran within one handler call, so "Processing…"
        could never have been displayed anyway.)
        """
        if img is None:  # guard: button clicked with no image uploaded
            return "", "Please upload an image first."
        caption = generate_caption(img)
        return caption, "Done ✔️"

    # status is an output of the click event so its text actually refreshes.
    btn.click(wrapped, inputs=img_in, outputs=[out, status])

if __name__ == "__main__":
    demo.launch()
|