import gradio as gr
from transformers import AutoModel
import pickle
import torch
import torch.nn as nn
from typing import Optional
import matplotlib.pyplot as plt
import VQA
# --- Load the answer vocabulary and the pickled text/image pre-processors. ---
# answer_space.txt holds one answer string per line; the model's argmax index
# is looked up in this list to produce the final answer.
with open("answer_space.txt", encoding="utf-8") as fin:
    answer_space = fin.read().splitlines()

# NOTE(security): pickle.load executes arbitrary code from the file — only
# load pickles that ship with this app, never user-supplied ones.
with open("text-encoder.pickle", "rb") as pkl_file:
    tokenizer = pickle.load(pkl_file)

with open("image-encoder.pickle", "rb") as pkl_file:
    preprocessor = pickle.load(pkl_file)
def encode_text(text_encoder, question, device):
    """Tokenize a single question and move the encoded tensors to *device*.

    Parameters
    ----------
    text_encoder : unused
        Kept only for backward compatibility with existing callers: the
        original immediately reassigned this parameter to the pickled
        module-level ``tokenizer``, so the argument was always ignored.
    question : str
        Question text to tokenize.
    device : torch.device
        Target device for the returned tensors.

    Returns
    -------
    dict
        ``input_ids``, ``token_type_ids`` and ``attention_mask`` tensors,
        each already moved to *device*.
    """
    # Always use the pickled tokenizer loaded at module import time.
    encoded_text = tokenizer(
        text=[question],
        padding='longest',
        max_length=24,
        truncation=True,
        return_tensors='pt',
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    return {
        key: encoded_text[key].to(device)
        for key in ("input_ids", "token_type_ids", "attention_mask")
    }
def encode_image(image_encoder, image, device):
    """Preprocess a single image and move the pixel tensor to *device*.

    Parameters
    ----------
    image_encoder : unused
        Kept only for backward compatibility with existing callers: the
        original immediately reassigned this parameter to the pickled
        module-level ``preprocessor``, so the argument was always ignored.
    image
        Image in whatever format the pickled preprocessor accepts
        (PIL image / array — TODO confirm against the pickle).
    device : torch.device
        Target device for the returned tensor.

    Returns
    -------
    dict
        ``pixel_values`` tensor already moved to *device*.
    """
    # Always use the pickled image preprocessor loaded at module import time.
    processed_images = preprocessor(
        images=[image],
        return_tensors="pt",
    )
    return {
        "pixel_values": processed_images["pixel_values"].to(device),
    }
def get_inputs(question, image):
    """Run one VQA inference: answer *question* about *image*.

    Parameters
    ----------
    question : str
        Free-form question text (case/punctuation are normalized here).
    image
        Image as delivered by the Gradio "image" input component.

    Returns
    -------
    str
        The predicted answer, looked up in ``answer_space``.
    """
    # Normalize the question (lowercase, strip '?' and whitespace).
    normalized = question.lower().replace("?", "").strip()
    # BUG FIX: the original computed the normalized form and then discarded
    # it, encoding the raw `question` instead. Encode the normalized text.
    question_encoded = encode_text("bert-base-uncased", normalized, device)
    image_encoded = encode_image("google/vit-base-patch16-224-in21k", image, device)

    model.eval()
    # Inference only — disable autograd to avoid building a graph.
    # (The encode_* helpers already moved every tensor to `device`.)
    with torch.no_grad():
        output = model(
            question_encoded["input_ids"],
            image_encoded["pixel_values"],
            question_encoded["attention_mask"],
            question_encoded["token_type_ids"],
        )
    preds = output["logits"].argmax(axis=-1).cpu().numpy()
    return answer_space[preds[0]]
# Build the VQA model and restore its trained weights onto the best device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VQA.VQA()
# Weights are loaded onto CPU first so the checkpoint opens even on
# CUDA-less machines, then the whole model is moved to `device`.
state_dict = torch.load("state_dict_model.pt", map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
model.to(device)
# Wire the prediction function into a simple Gradio UI and serve it:
# one text box for the question, one image upload, one text output.
interface = gr.Interface(
    fn=get_inputs,
    inputs=["text", "image"],
    outputs="text",
)
interface.launch()