sanjay-906 committed on
Commit
dac0ce7
·
verified ·
1 Parent(s): 05fa164

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModel
3
+ import pickle
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Optional
7
+ import matplotlib.pyplot as plt
8
+ import VQA
9
+
10
# Load the answer vocabulary and the pickled text/image encoders produced at
# training time.  `with` blocks guarantee every file handle is closed even if
# reading or unpickling raises (the original left `fin` open forever).
with open("answer_space.txt") as fin:
    answer_space = fin.read().splitlines()

# NOTE(review): unpickling is only acceptable because these files ship with
# the app itself — never load pickles from untrusted sources.
with open("text-encoder.pickle", "rb") as pkl_file:
    tokenizer = pickle.load(pkl_file)

with open("image-encoder.pickle", "rb") as pkl_file:
    preprocessor = pickle.load(pkl_file)
20
def encode_text(text_encoder, question, device):
    """Tokenize a single question and move the resulting tensors to *device*.

    NOTE: the ``text_encoder`` argument is ignored — the module-level
    ``tokenizer`` loaded from ``text-encoder.pickle`` is always used.

    Returns a dict with ``input_ids``, ``token_type_ids`` and
    ``attention_mask`` tensors, each already on *device*.
    """
    # Shadow the (unused) parameter with the global tokenizer.
    text_encoder = tokenizer
    encoded = text_encoder(
        text=[question],
        padding='longest',
        max_length=24,
        truncation=True,
        return_tensors='pt',
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    wanted = ("input_ids", "token_type_ids", "attention_mask")
    return {key: encoded[key].to(device) for key in wanted}
37
+
38
def encode_image(image_encoder, image, device):
    """Preprocess one image into model-ready pixel tensors on *device*.

    NOTE: the ``image_encoder`` argument is ignored — the module-level
    ``preprocessor`` loaded from ``image-encoder.pickle`` is always used.
    """
    # Shadow the (unused) parameter with the global preprocessor.
    image_encoder = preprocessor
    batch = image_encoder(images=[image], return_tensors="pt")
    return {"pixel_values": batch["pixel_values"].to(device)}
47
+
48
def get_inputs(question, image):
    """Gradio callback: predict an answer to *question* about *image*.

    Parameters
    ----------
    question : str  — free-text question typed by the user.
    image    : the image supplied by the gradio "image" input.

    Returns the predicted answer string looked up in ``answer_space``.
    """
    # Normalize the question before tokenizing.  BUG FIX: the original
    # computed this cleaned string and then tokenized the raw `question`,
    # silently discarding the normalization.
    cleaned = question.lower().replace("?", "").strip()
    question_encoded = encode_text("bert-base-uncased", cleaned, device)
    image_encoded = encode_image("google/vit-base-patch16-224-in21k", image, device)

    model.eval()

    # encode_text/encode_image already placed every tensor on `device`,
    # so no further .to(device) calls are needed.  Run inference without
    # building a gradient graph.
    with torch.no_grad():
        output = model(
            question_encoded["input_ids"],
            image_encoded["pixel_values"],
            question_encoded["attention_mask"],
            question_encoded["token_type_ids"],
        )

    preds = output["logits"].argmax(axis=-1).cpu().numpy()
    return answer_space[preds[0]]
65
+
66
# Build the VQA model, restore its trained weights (loaded onto CPU first),
# then move it to the best available device: GPU when present, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VQA.VQA()
state = torch.load("state_dict_model.pt", map_location=torch.device('cpu'))
model.load_state_dict(state)
model.to(device)

# Wire the prediction callback into a simple text+image -> text UI and serve.
interface = gr.Interface(
    fn=get_inputs,
    inputs=["text", "image"],
    outputs="text",
)
interface.launch()