"""Gradio demo for visual question answering with a fine-tuned ViLT model."""

import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering
import pandas as pd
import torch

# Single source of truth for the checkpoint so the processor and the model
# can never drift apart.
MODEL_ID = "dandelin/vilt-b32-finetuned-vqa"

# Matches the initial value shown in the Top-K input; also used as the
# fallback when the user clears that field.
DEFAULT_TOP_K = 4

processor = ViltProcessor.from_pretrained(MODEL_ID)
model = ViltForQuestionAnswering.from_pretrained(MODEL_ID)


def predict(img, prompt, return_topk):
    """Answer a question about an image and return the top-k candidates.

    Args:
        img: The image to query, as delivered by the ``gr.Image(type="pil")``
            input. ``None`` when the user has not supplied an image.
        prompt: The natural-language question about the image.
        return_topk: How many answers to return. Gradio's ``Number`` input
            may deliver a float (or ``None`` when the field is empty), so the
            value is coerced to an int and clamped to
            ``[1, number of answer labels]``.

    Returns:
        A ``pandas.DataFrame`` with one row per candidate and the columns
        ``"answer"`` (label text from ``model.config.id2label``) and
        ``"probability"`` (sigmoid of the corresponding logit), in the order
        produced by ``torch.topk``.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Surface a readable message in the UI instead of an opaque
        # processor failure.
        raise gr.Error("Please provide an image.")

    encoding = processor(img, prompt, return_tensors="pt")

    # Run the *whole* forward pass without autograd bookkeeping; the original
    # only wrapped the sigmoid, so every request still built a graph.
    with torch.no_grad():
        outputs = model(**encoding)
        probs = torch.sigmoid(outputs.logits)

    # torch.topk needs an integer k that does not exceed the size of the
    # dimension it is applied to.
    requested = DEFAULT_TOP_K if return_topk is None else int(return_topk)
    k = max(1, min(requested, probs.shape[-1]))

    topk_anss = torch.topk(probs, k)
    indices = topk_anss.indices.flatten().tolist()
    values = topk_anss.values.flatten().tolist()

    return pd.DataFrame(
        {
            "answer": [model.config.id2label[i] for i in indices],
            "probability": values,
        }
    )


demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Question"),
        # precision=0 asks the component to round to an integer; predict()
        # still coerces defensively.
        gr.Number(value=DEFAULT_TOP_K, minimum=1, precision=0, label="Top-K"),
    ],
    outputs=gr.BarPlot(
        x="answer",
        y="probability",
        title="Predicted Answers",
    ),
)

demo.launch()