Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from transformers import AutoModel
|
| 3 |
+
import pickle
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from typing import Optional
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import VQA
|
| 9 |
+
|
| 10 |
+
# Load the fixed answer vocabulary: the model predicts an index into this
# list, so line order must match the training-time answer space.
with open("answer_space.txt") as fin:
    answer_space = fin.read().splitlines()

# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# Only load these files from a trusted source (they ship with this Space).
# `with` blocks fix the original unclosed-file-handle leaks.
with open("text-encoder.pickle", "rb") as pkl_file:
    tokenizer = pickle.load(pkl_file)

with open("image-encoder.pickle", "rb") as pkl_file:
    preprocessor = pickle.load(pkl_file)
|
| 19 |
+
|
| 20 |
+
def encode_text(text_encoder, question, device):
    """Tokenize a question string and move the encoded tensors to *device*.

    Args:
        text_encoder: unused legacy argument, kept so existing callers
            (which pass a model name string) keep working; the module-level
            ``tokenizer`` loaded from ``text-encoder.pickle`` is always used.
        question: question text to encode.
        device: torch device the returned tensors are placed on.

    Returns:
        dict with ``input_ids``, ``token_type_ids`` and ``attention_mask``
        tensors (batch of 1, truncated to at most 24 tokens) on *device*.
    """
    # The original body shadowed the `text_encoder` parameter with the global
    # tokenizer, silently discarding the argument; use the global directly.
    encoded_text = tokenizer(
        text=[question],
        padding="longest",
        max_length=24,
        truncation=True,
        return_tensors="pt",
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    return {
        "input_ids": encoded_text["input_ids"].to(device),
        "token_type_ids": encoded_text["token_type_ids"].to(device),
        "attention_mask": encoded_text["attention_mask"].to(device),
    }
|
| 37 |
+
|
| 38 |
+
def encode_image(image_encoder, image, device):
    """Preprocess a single image and move the pixel tensor to *device*.

    Args:
        image_encoder: unused legacy argument, kept so existing callers
            (which pass a model name string) keep working; the module-level
            ``preprocessor`` loaded from ``image-encoder.pickle`` is always used.
        image: input image in a format the preprocessor accepts (e.g. a
            PIL image or numpy array from the Gradio widget).
        device: torch device the returned tensor is placed on.

    Returns:
        dict with a ``pixel_values`` tensor (batch of 1) on *device*.
    """
    # The original body shadowed the `image_encoder` parameter with the global
    # preprocessor, silently discarding the argument; use the global directly.
    processed_images = preprocessor(
        images=[image],
        return_tensors="pt",
    )
    return {
        "pixel_values": processed_images["pixel_values"].to(device),
    }
|
| 47 |
+
|
| 48 |
+
def get_inputs(question, image):
    """Answer *question* about *image* using the global VQA model.

    This is the Gradio callback: it encodes both inputs, runs the model,
    and maps the argmax logit back into ``answer_space``.

    Args:
        question: free-text question from the UI.
        image: image from the UI (format as delivered by the Gradio widget).

    Returns:
        The predicted answer string.
    """
    # BUG FIX: the original computed this normalized form and then discarded
    # it, tokenizing the raw question instead. Use the cleaned text.
    cleaned_question = question.lower().replace("?", "").strip()
    question_encoded = encode_text("bert-base-uncased", cleaned_question, device)
    image_encoded = encode_image("google/vit-base-patch16-224-in21k", image, device)

    model.eval()

    # Tensors are already on `device` (encode_* moved them), so the original
    # redundant .to(device) calls are dropped. no_grad avoids building an
    # autograd graph for pure inference.
    with torch.no_grad():
        output = model(
            question_encoded["input_ids"],
            image_encoded["pixel_values"],
            question_encoded["attention_mask"],
            question_encoded["token_type_ids"],
        )

    # Model output is assumed to be a mapping with a "logits" tensor of shape
    # (1, num_answers) -- TODO confirm against VQA.VQA's forward().
    preds = output["logits"].argmax(axis=-1).cpu().numpy()
    return answer_space[preds[0]]
|
| 65 |
+
|
| 66 |
+
# --- Model construction and weight loading -------------------------------
model = VQA.VQA()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Weights are loaded onto the CPU first (works on CPU-only Spaces hardware);
# .to(device) then moves them to the GPU when one is available.
# NOTE(security): torch.load unpickles the checkpoint -- only load trusted files.
state_dict = torch.load("state_dict_model.pt", map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
model.to(device)

# --- Gradio UI: a question box plus an image upload, answered as text ----
interface = gr.Interface(
    fn=get_inputs,
    inputs=["text", "image"],
    outputs="text",
)
interface.launch()
|