GAP-week7 / app.py
hatethisworld's picture
Add app.py
5671150
import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering
import pandas as pd
import torch
# Hugging Face Hub checkpoint used for both the input processor and the model.
_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"

# Loaded once at import time; predict() reuses these module-level objects
# on every call instead of reloading the weights.
processor = ViltProcessor.from_pretrained(_CHECKPOINT)
model = ViltForQuestionAnswering.from_pretrained(_CHECKPOINT)
def predict(img, prompt, return_topk):
    """Answer a question about an image and return the top-K candidate answers.

    Args:
        img: Image to query; handed unchanged to the module-level ``processor``
            (the UI below supplies a PIL image).
        prompt: Question text about the image.
        return_topk: How many answers to return. May be an int or a float
            (number fields can deliver floats); it is truncated to an int and
            clamped to the range ``[1, number of answer classes]``.

    Returns:
        A ``pandas.DataFrame`` with one row per answer and two columns:
        ``"answer"`` (label text from ``model.config.id2label``) and
        ``"probability"`` (per-label sigmoid score, so the values are
        independent and need not sum to 1). Rows are in the order returned by
        ``torch.topk``, i.e. highest score first.
    """
    encoding = processor(img, prompt, return_tensors="pt")

    # Inference only: the forward pass itself must run under no_grad,
    # otherwise an autograd graph is built (and kept alive) for every request.
    with torch.no_grad():
        logits = model(**encoding).logits
        probs = torch.sigmoid(logits)

    # torch.topk requires an integer k that does not exceed the size of the
    # dimension it searches, so coerce and clamp the user-supplied value.
    k = max(1, min(int(return_topk), probs.shape[-1]))
    topk_anss = torch.topk(probs, k)

    # Plain Python ints for the id2label dictionary lookup.
    indices = topk_anss.indices.flatten().tolist()
    values = topk_anss.values.flatten().numpy()

    out_df = pd.DataFrame({
        "answer": [model.config.id2label[i] for i in indices],
        "probability": values,
    })
    return out_df
# Gradio UI: image + question + Top-K in, bar plot of answer scores out.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Question"),
        # precision=0 makes the field deliver a whole number; Top-K is a
        # count and torch.topk (called in predict) rejects fractional k.
        gr.Number(value=4, minimum=1, precision=0, label="Top-K"),
    ],
    # Column names must match the DataFrame returned by predict().
    outputs=gr.BarPlot(
        x="answer",
        y="probability",
        title="Predicted Answers",
    ),
)
demo.launch()