app demo
Browse files- app.py +44 -0
- requirements.txt +4 -0
app.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from transformers import ViltProcessor, ViltForQuestionAnswering
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 7 |
+
|
| 8 |
+
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
| 9 |
+
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa").to(device)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def predict(img, prompt, return_topk):
|
| 13 |
+
encoding = processor(img, prompt, return_tensors="pt")
|
| 14 |
+
outputs = model(**encoding)
|
| 15 |
+
with torch.no_grad():
|
| 16 |
+
probs = torch.nn.Sigmoid()(outputs.logits)
|
| 17 |
+
topk_anss = torch.topk(probs, return_topk)
|
| 18 |
+
# these are the indices of the top-k outputs
|
| 19 |
+
indices = topk_anss.indices.flatten().numpy()
|
| 20 |
+
# create a dataframe with two columns/series:
|
| 21 |
+
# class labels and corresponding probabilities
|
| 22 |
+
out_df = pd.DataFrame(
|
| 23 |
+
{
|
| 24 |
+
"answer": [model.config.id2label[key] for key in indices],
|
| 25 |
+
"probability": topk_anss.values.flatten().numpy()
|
| 26 |
+
}
|
| 27 |
+
)
|
| 28 |
+
return out_df
|
| 29 |
+
|
| 30 |
+
demo = gr.Interface(
|
| 31 |
+
fn = predict,
|
| 32 |
+
# we use the type='pil' parameter so that gradio passes to our function
|
| 33 |
+
# a picture that is already in the PIL format,
|
| 34 |
+
# see https://www.gradio.app/docs/gradio/image#description
|
| 35 |
+
inputs = [gr.Image(type="pil"),
|
| 36 |
+
"textbox",
|
| 37 |
+
# value is the default value, it can be lower than 1
|
| 38 |
+
gr.Number(value=4, minimum=1)],
|
| 39 |
+
outputs = gr.BarPlot(x="answer", y="probability",
|
| 40 |
+
title="Multi-class probabilities")
|
| 41 |
+
# outputs="dataframe" # uncomment if the gradio interface is unresponsive
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
pandas
|
| 3 |
+
transformers
|
| 4 |
+
gradio
|