# VQA / app.py
# Author: sanjay-906 (Hugging Face Space)
# Note: the lines above/below this header were scraped UI text from the
# hosting page; they are kept only as a comment so the module parses.
import gradio as gr
from transformers import AutoModel
import pickle
import torch
import torch.nn as nn
from typing import Optional
import matplotlib.pyplot as plt
import VQA
# Load the closed answer vocabulary: one answer per line; the line index
# is the class id emitted by the model head.
with open("answer_space.txt") as fin:
    answer_space = fin.read().splitlines()

# SECURITY NOTE(review): pickle.load executes arbitrary code embedded in the
# file — only load these artifacts from a trusted source.
# (Original code left the answer-space file handle open; `with` fixes the leak.)
with open("text-encoder.pickle", "rb") as pkl_file:
    tokenizer = pickle.load(pkl_file)       # HF tokenizer for the text branch

with open("image-encoder.pickle", "rb") as pkl_file:
    preprocessor = pickle.load(pkl_file)    # HF image processor for the vision branch
def encode_text(text_encoder, question, device):
    """Tokenize a single question for the text branch of the VQA model.

    Parameters
    ----------
    text_encoder : ignored
        Kept only for backward compatibility with existing callers (which
        pass a model-name string); the module-level ``tokenizer`` loaded
        from ``text-encoder.pickle`` is always used.
    question : str
        The question text to tokenize.
    device : torch.device
        Device the returned tensors are moved to.

    Returns
    -------
    dict
        ``input_ids``, ``token_type_ids`` and ``attention_mask`` tensors,
        each already on ``device``.
    """
    # NOTE(review): the original body immediately shadowed `text_encoder`
    # with the global `tokenizer`; the dead reassignment is removed but the
    # ignored-parameter behavior is preserved.
    encoded_text = tokenizer(
        text=[question],
        padding='longest',
        max_length=24,          # questions are short; 24 tokens is the cap
        truncation=True,
        return_tensors='pt',
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    return {
        "input_ids": encoded_text['input_ids'].to(device),
        "token_type_ids": encoded_text['token_type_ids'].to(device),
        "attention_mask": encoded_text['attention_mask'].to(device),
    }
def encode_image(image_encoder, image, device):
    """Preprocess a single image for the vision branch of the VQA model.

    Parameters
    ----------
    image_encoder : ignored
        Kept only for backward compatibility with existing callers (which
        pass a model-name string); the module-level ``preprocessor`` loaded
        from ``image-encoder.pickle`` is always used.
    image :
        Image in a format the preprocessor accepts (e.g. PIL image / array).
    device : torch.device
        Device the returned tensor is moved to.

    Returns
    -------
    dict
        ``pixel_values`` tensor already on ``device``.
    """
    # NOTE(review): the original body immediately shadowed `image_encoder`
    # with the global `preprocessor`; the dead reassignment is removed but
    # the ignored-parameter behavior is preserved.
    processed_images = preprocessor(
        images=[image],
        return_tensors="pt",
    )
    return {
        "pixel_values": processed_images['pixel_values'].to(device),
    }
def get_inputs(question, image):
    """Run visual question answering and return the predicted answer string.

    Parameters
    ----------
    question : str
        Free-text question about the image.
    image :
        Image from the Gradio input component.

    Returns
    -------
    str
        The highest-scoring answer from ``answer_space``.
    """
    # BUG FIX: the original computed the normalized question and then
    # discarded it, encoding the raw `question` instead. Encode the
    # normalized form (lowercased, '?' removed, whitespace stripped).
    normalized = question.lower().replace("?", "").strip()
    encoded = encode_text("bert-base-uncased", normalized, device)
    pixels = encode_image("google/vit-base-patch16-224-in21k", image, device)
    model.eval()
    # Tensors are already on `device` (moved inside the encode_* helpers);
    # no gradients are needed for inference.
    with torch.no_grad():
        output = model(
            encoded["input_ids"],
            pixels["pixel_values"],
            encoded["attention_mask"],
            encoded["token_type_ids"],
        )
    preds = output["logits"].argmax(axis=-1).cpu().numpy()
    return answer_space[preds[0]]
# Build the model, restore the trained weights, then serve the Gradio UI.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VQA.VQA()
# Weights are loaded onto CPU first, then the whole model is moved to `device`.
state = torch.load("state_dict_model.pt", map_location=torch.device('cpu'))
model.load_state_dict(state)
model.to(device)

interface = gr.Interface(
    fn=get_inputs,
    inputs=["text", "image"],
    outputs="text",
)
interface.launch()