# Gradio Space: Visual Question Answering (VQA) demo
import gradio as gr
from transformers import AutoModel
import pickle
import torch
import torch.nn as nn
from typing import Optional
import matplotlib.pyplot as plt
import VQA
| fin= open("answer_space.txt") | |
| answer_space= fin.read().splitlines() | |
| pkl_file = open("text-encoder.pickle", 'rb') | |
| tokenizer= pickle.load(pkl_file) | |
| pkl_file.close() | |
| pkl_file = open("image-encoder.pickle", 'rb') | |
| preprocessor= pickle.load(pkl_file) | |
| pkl_file.close() | |
def encode_text(text_encoder, question, device):
    """Tokenize a single question and move the resulting tensors to *device*.

    Parameters
    ----------
    text_encoder : unused
        Kept only for backward compatibility with existing callers; the
        module-level ``tokenizer`` loaded from ``text-encoder.pickle`` is
        always used (the original code overwrote this argument on entry).
    question : str
        The question text to encode.
    device : torch.device
        Target device for the returned tensors.

    Returns
    -------
    dict
        ``input_ids``, ``token_type_ids`` and ``attention_mask`` tensors
        on *device*.
    """
    encoded_text = tokenizer(
        text=[question],
        padding='longest',
        max_length=24,  # questions longer than 24 tokens are truncated
        truncation=True,
        return_tensors='pt',
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    return {
        "input_ids": encoded_text['input_ids'].to(device),
        "token_type_ids": encoded_text['token_type_ids'].to(device),
        "attention_mask": encoded_text['attention_mask'].to(device),
    }
def encode_image(image_encoder, image, device):
    """Preprocess a single image and move the pixel tensor to *device*.

    Parameters
    ----------
    image_encoder : unused
        Kept only for backward compatibility with existing callers; the
        module-level ``preprocessor`` loaded from ``image-encoder.pickle``
        is always used (the original code overwrote this argument on entry).
    image
        The input image (any format the preprocessor accepts).
    device : torch.device
        Target device for the returned tensor.

    Returns
    -------
    dict
        ``pixel_values`` tensor on *device*.
    """
    processed_images = preprocessor(
        images=[image],
        return_tensors="pt",
    )
    return {
        "pixel_values": processed_images['pixel_values'].to(device),
    }
def get_inputs(question, image):
    """Answer *question* about *image* using the global VQA model.

    Parameters
    ----------
    question : str
        Free-form question from the Gradio text input.
    image
        Image from the Gradio image input.

    Returns
    -------
    str
        The predicted answer string from ``answer_space``.
    """
    # Normalize the question (lowercase, strip '?' and whitespace).
    normalized = question.lower().replace("?", "").strip()
    # BUG FIX: the original computed the normalized text and then encoded
    # the RAW question, discarding the normalization; encode the
    # normalized form instead.
    question_encoded = encode_text("bert-base-uncased", normalized, device)
    image_encoded = encode_image("google/vit-base-patch16-224-in21k", image, device)

    model.eval()  # inference mode (disables dropout etc.)
    # The encoders already moved every tensor to `device`, so no extra
    # .to(device) calls are needed here.
    with torch.no_grad():  # no gradients needed for inference
        output = model(
            question_encoded["input_ids"],
            image_encoded["pixel_values"],
            question_encoded["attention_mask"],
            question_encoded["token_type_ids"],
        )
    preds = output["logits"].argmax(axis=-1).cpu().numpy()
    return answer_space[preds[0]]
# --- Model setup and Gradio app ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = VQA.VQA()
# Load the trained weights directly onto the target device — the original
# always loaded to CPU and then moved the whole model a second time.
model.load_state_dict(torch.load("state_dict_model.pt", map_location=device))
model.to(device)
model.eval()  # this app only performs inference

interface = gr.Interface(
    fn=get_inputs,
    inputs=["text", "image"],
    outputs="text",
)
interface.launch()