# vqa-vilt / app.py
# Last commit by e-p: "app demo" (d471219)
import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering
import pandas as pd
import torch
# Single checkpoint id so the processor and the model always come from the
# same pretrained ViLT VQA checkpoint.
MODEL_ID = "dandelin/vilt-b32-finetuned-vqa"

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

processor = ViltProcessor.from_pretrained(MODEL_ID)
model = ViltForQuestionAnswering.from_pretrained(MODEL_ID)
model = model.to(device)
def predict(img, prompt, return_topk):
encoding = processor(img, prompt, return_tensors="pt")
outputs = model(**encoding)
with torch.no_grad():
probs = torch.nn.Sigmoid()(outputs.logits)
topk_anss = torch.topk(probs, return_topk)
# these are the indices of the top-k outputs
indices = topk_anss.indices.flatten().numpy()
# create a dataframe with two columns/series:
# class labels and corresponding probabilities
out_df = pd.DataFrame(
{
"answer": [model.config.id2label[key] for key in indices],
"probability": topk_anss.values.flatten().numpy()
}
)
return out_df
demo = gr.Interface(
    fn=predict,
    inputs=[
        # type="pil" makes Gradio hand the picture to `predict` already
        # converted to a PIL image,
        # see https://www.gradio.app/docs/gradio/image#description
        gr.Image(type="pil"),
        "textbox",
        # value is the default number of answers to show. minimum=1 rejects
        # values below 1. precision=0 makes Gradio round the input and pass it
        # as an int, because torch.topk does not accept a float k.
        gr.Number(value=4, minimum=1, precision=0),
    ],
    outputs=gr.BarPlot(
        x="answer",
        y="probability",
        title="Multi-class probabilities",
    ),
    # outputs="dataframe",  # use instead of the BarPlot above if the gradio interface is unresponsive
)
demo.launch()