Spaces:
Paused
Paused
| import gradio as gr | |
| from transformers import ViltProcessor, ViltForQuestionAnswering | |
| import pandas as pd | |
| import torch | |
# The preprocessor and the VQA model are both loaded once, at import time,
# from the same pretrained checkpoint identifier.
_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
processor = ViltProcessor.from_pretrained(_CHECKPOINT)
model = ViltForQuestionAnswering.from_pretrained(_CHECKPOINT)
def predict(img, prompt, return_topk):
    """Answer a question about an image and return the top-k candidate answers.

    Args:
        img: Image the question refers to; passed straight to ``processor``.
        prompt: Question text; passed straight to ``processor``.
        return_topk: How many answers to return. Coerced to ``int`` (a
            numeric UI widget may deliver a float such as ``4.0``) and
            clamped to ``[1, number of answer labels]`` so ``torch.topk``
            never receives an invalid ``k``.

    Returns:
        pandas.DataFrame with an ``"answer"`` column (label strings looked up
        in ``model.config.id2label``) and a ``"probability"`` column (sigmoid
        of the corresponding logit), one row per returned candidate.
    """
    encoding = processor(img, prompt, return_tensors="pt")

    # Inference only: run the forward pass itself under no_grad so that no
    # autograd graph is recorded for the model's activations.
    with torch.no_grad():
        logits = model(**encoding).logits
        # Sigmoid scores each label independently; values need not sum to 1.
        probs = torch.sigmoid(logits)
        k = max(1, min(int(return_topk), probs.shape[-1]))
        top = torch.topk(probs, k)

    # Plain Python ints for the dictionary lookup; probabilities stay a
    # NumPy array, which pandas accepts directly as a column.
    label_ids = top.indices.flatten().tolist()
    scores = top.values.flatten().numpy()

    return pd.DataFrame({
        "answer": [model.config.id2label[i] for i in label_ids],
        "probability": scores,
    })
# Web UI: an image, a free-text question and a Top-K count go into predict();
# the DataFrame it returns is rendered as a bar plot of answer vs. probability.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Question"),
        # Top-K is a count, so precision=0 makes the component deliver an
        # integer rather than a fractional number.
        gr.Number(value=4, minimum=1, precision=0, label="Top-K"),
    ],
    outputs=gr.BarPlot(
        x="answer",
        y="probability",
        title="Predicted Answers",
    ),
)
demo.launch()