|
|
import gradio as gr |
|
|
from transformers import ViltProcessor, ViltForQuestionAnswering |
|
|
import pandas as pd |
|
|
import torch |
|
|
|
|
|
# Prefer the GPU when one is present; everything below follows this choice.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ViLT checkpoint fine-tuned for visual question answering (VQAv2 labels).
_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"

# Processor prepares (image, question) pairs; model is moved to the
# selected device once, at startup.
processor = ViltProcessor.from_pretrained(_CHECKPOINT)
model = ViltForQuestionAnswering.from_pretrained(_CHECKPOINT).to(device)
|
|
|
|
|
|
|
|
def predict(img, prompt, return_topk):
    """Answer a natural-language question about an image.

    Args:
        img: PIL image to ask about.
        prompt: The question, as free text.
        return_topk: How many of the highest-probability answers to
            return. Gradio's Number component delivers this as a float,
            so it is coerced to int before use.

    Returns:
        pandas.DataFrame with columns "answer" (label text) and
        "probability" (sigmoid score), ordered from most to least likely.
    """
    # Move the encoded inputs to the model's device — the model was moved
    # to `device` at startup, and mixing cpu inputs with a cuda model raises.
    encoding = processor(img, prompt, return_tensors="pt").to(device)

    # Inference only: run the forward pass under no_grad so no autograd
    # graph is built (the original ran the forward pass outside the block).
    with torch.no_grad():
        outputs = model(**encoding)
        # This VQA head is multi-label, hence element-wise sigmoid
        # rather than a softmax over classes.
        probs = torch.sigmoid(outputs.logits)
        # torch.topk requires an int k; the Number widget may pass a float.
        topk_anss = torch.topk(probs, int(return_topk))

    # .cpu() before .numpy(): numpy conversion fails on CUDA tensors.
    indices = topk_anss.indices.flatten().cpu().numpy()
    values = topk_anss.values.flatten().cpu().numpy()

    return pd.DataFrame(
        {
            "answer": [model.config.id2label[key] for key in indices],
            "probability": values,
        }
    )
|
|
|
|
|
# Assemble the UI: an image, a question textbox, and a top-k count go in;
# a bar chart of answer probabilities comes out.
topk_count = gr.Number(value=4, minimum=1)
prob_chart = gr.BarPlot(
    x="answer",
    y="probability",
    title="Multi-class probabilities",
)

demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type="pil"), "textbox", topk_count],
    outputs=prob_chart,
)

demo.launch()