"""Gradio demo: visual question answering with a fine-tuned ViLT model."""

import gradio as gr
import pandas as pd
import torch
from transformers import ViltForQuestionAnswering, ViltProcessor

MODEL_NAME = "dandelin/vilt-b32-finetuned-vqa"
# Number of answers shown by default (also used when the number field is empty).
DEFAULT_TOP_K = 4

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = ViltProcessor.from_pretrained(MODEL_NAME)
model = ViltForQuestionAnswering.from_pretrained(MODEL_NAME).to(device)


def predict(img, prompt, return_topk):
    """Return the top-k answers of the VQA model for an image and a question.

    Args:
        img: Image in PIL format, as passed by ``gr.Image(type="pil")``.
        prompt: The question text about the image.
        return_topk: How many of the highest-scoring answers to return.
            Floats are truncated to ``int``; ``None`` falls back to
            ``DEFAULT_TOP_K``; the value is clamped to the range
            ``[1, number of answer labels]``.

    Returns:
        A ``pandas.DataFrame`` with two columns: ``"answer"`` (label text)
        and ``"probability"`` (sigmoid of the corresponding logit), ordered
        from highest to lowest score.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please provide an image.")

    # The model lives on `device`, so its inputs must be moved there as well.
    encoding = processor(img, prompt, return_tensors="pt").to(device)

    # Run the whole forward pass without autograd: this is inference only.
    with torch.no_grad():
        logits = model(**encoding).logits
        # Sigmoid scores each answer independently, so the values
        # do not necessarily sum to one.
        probs = torch.sigmoid(logits)

    # gr.Number may deliver a float (or None when the field is cleared),
    # while torch.topk requires an int no larger than the last dimension.
    k = DEFAULT_TOP_K if return_topk is None else int(return_topk)
    k = max(1, min(k, probs.shape[-1]))
    topk_anss = torch.topk(probs, k)

    # Move results to the CPU before leaving torch; plain Python ints are
    # used as keys into the id -> label mapping.
    indices = topk_anss.indices.flatten().cpu().tolist()
    scores = topk_anss.values.flatten().cpu().numpy()

    # Two columns/series: class labels and corresponding probabilities.
    return pd.DataFrame(
        {
            "answer": [model.config.id2label[idx] for idx in indices],
            "probability": scores,
        }
    )


demo = gr.Interface(
    fn=predict,
    inputs=[
        # type="pil" makes gradio pass the picture to `predict` already in
        # the PIL format,
        # see https://www.gradio.app/docs/gradio/image#description
        gr.Image(type="pil"),
        "textbox",
        # `value` is the default; `minimum=1` keeps the input from going
        # below 1; `precision=0` makes gradio pass an integer.
        gr.Number(value=DEFAULT_TOP_K, minimum=1, precision=0),
    ],
    outputs=gr.BarPlot(
        x="answer",
        y="probability",
        title="Multi-class probabilities",
    ),
    # NOTE: if the gradio interface is unresponsive, replace the BarPlot
    # above with outputs="dataframe".
)

if __name__ == "__main__":
    demo.launch()