e-p committed on
Commit
d471219
·
1 Parent(s): d3b7c3f
Files changed (2) hide show
  1. app.py +44 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import ViltProcessor, ViltForQuestionAnswering
3
+ import pandas as pd
4
+ import torch
5
+
6
# Run on GPU when one is available; the model weights are moved there once
# at startup so every request reuses the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ViLT checkpoint fine-tuned for visual question answering (VQAv2 labels).
checkpoint = "dandelin/vilt-b32-finetuned-vqa"
processor = ViltProcessor.from_pretrained(checkpoint)
model = ViltForQuestionAnswering.from_pretrained(checkpoint).to(device)
10
+
11
+
12
def predict(img, prompt, return_topk):
    """Answer a question about an image with ViLT VQA.

    Args:
        img: PIL image handed over by gradio (``type="pil"`` input).
        prompt: natural-language question about the image.
        return_topk: how many top answers to return (>= 1; gradio's
            ``gr.Number`` delivers this as a float, so it is coerced to int).

    Returns:
        pandas.DataFrame with columns ``answer`` and ``probability``,
        one row per top-k answer.
    """
    # Move the encoded inputs to the same device as the model — the model
    # was .to(device)'d at startup, so a CUDA run would otherwise crash
    # with a device mismatch.
    encoding = processor(img, prompt, return_tensors="pt").to(device)
    # no_grad must wrap the forward pass itself (the original only wrapped
    # the post-processing), otherwise gradients are tracked and memory is
    # wasted on every request.
    with torch.no_grad():
        outputs = model(**encoding)
        # The VQA head is multi-label, hence per-class sigmoid, not softmax.
        probs = torch.sigmoid(outputs.logits)
    # Coerce to int (gr.Number yields floats) and clamp to the label count
    # so an oversized request does not raise in torch.topk.
    k = max(1, min(int(return_topk), probs.shape[-1]))
    topk_anss = torch.topk(probs, k)
    # .cpu() before .numpy(): CUDA tensors cannot be converted directly.
    indices = topk_anss.indices.flatten().cpu().numpy()
    # Two columns: decoded class labels and their probabilities.
    return pd.DataFrame(
        {
            "answer": [model.config.id2label[key] for key in indices],
            "probability": topk_anss.values.flatten().cpu().numpy(),
        }
    )
29
+
30
# Wire the model into a web UI: an image, a free-text question, and a
# top-k count go in; a probability bar chart comes out.
demo = gr.Interface(
    fn=predict,
    inputs=[
        # we use the type='pil' parameter so that gradio passes to our function
        # a picture that is already in the PIL format,
        # see https://www.gradio.app/docs/gradio/image#description
        gr.Image(type="pil"),
        "textbox",
        # value is the default value, it can be lower than 1
        gr.Number(value=4, minimum=1),
    ],
    outputs=gr.BarPlot(
        x="answer",
        y="probability",
        title="Multi-class probabilities",
    ),
    # outputs="dataframe" # uncomment if the gradio interface is unresponsive
)

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ pandas
3
+ transformers
4
+ gradio