# Spaces: Paused | File size: 1,116 Bytes | 5671150
import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering
import pandas as pd
import torch
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
def predict(img, prompt, return_topk):
    """Answer a question about an image and return the top-k candidate answers.

    Args:
        img: Image to ask about; passed unchanged to the module-level
            ``processor``.
        prompt: Question text; passed unchanged to ``processor``.
        return_topk: How many answers to return. Coerced with ``int()``
            (fractional values are truncated) and clamped to
            ``[1, number of logits]`` so ``torch.topk`` always gets a valid
            ``k``.

    Returns:
        A ``pandas.DataFrame`` with an ``"answer"`` column (labels looked up
        in ``model.config.id2label``) and a ``"probability"`` column (the
        sigmoid of the corresponding logit), in the order produced by
        ``torch.topk`` (highest score first).
    """
    encoding = processor(img, prompt, return_tensors="pt")

    # The forward pass itself must run under no_grad: this is pure inference,
    # and building an autograd graph for it only costs memory.
    with torch.no_grad():
        outputs = model(**encoding)
        probs = torch.sigmoid(outputs.logits)

    # UI number inputs may arrive as floats or exceed the label count, while
    # torch.topk needs an integer k no larger than the last dimension.
    k = max(1, min(int(return_topk), probs.shape[-1]))
    topk_anss = torch.topk(probs, k)

    # Plain Python ints for the id2label dictionary lookup.
    indices = topk_anss.indices.flatten().tolist()
    values = topk_anss.values.flatten().numpy()

    return pd.DataFrame({
        "answer": [model.config.id2label[i] for i in indices],
        "probability": values,
    })
# Gradio UI: an image, a free-text question and a Top-K count are passed to
# `predict`; the DataFrame it returns is rendered as a bar chart of
# answer vs. probability.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Question"),
        # precision=0 makes Gradio round the entry and deliver an int:
        # Top-K is used as the `k` of torch.topk, which must be integral.
        gr.Number(value=4, minimum=1, precision=0, label="Top-K"),
    ],
    outputs=gr.BarPlot(
        x="answer",
        y="probability",
        title="Predicted Answers",
    ),
)
demo.launch()
|