"""Gradio VQA demo: answer a free-text question about an uploaded image.

Loads a pickled text tokenizer and image preprocessor, restores a trained
VQA model checkpoint, and serves a Gradio interface that maps
(question, image) to the highest-scoring answer in ``answer_space.txt``.
"""

import gradio as gr
from transformers import AutoModel
import pickle
import torch
import torch.nn as nn
from typing import Optional
import matplotlib.pyplot as plt
import VQA

# One candidate answer per line; row index i corresponds to model logit i.
with open("answer_space.txt") as fin:
    answer_space = fin.read().splitlines()

# NOTE(review): unpickling executes arbitrary code — only load encoder
# pickles that come from a trusted source.
with open("text-encoder.pickle", "rb") as pkl_file:
    tokenizer = pickle.load(pkl_file)

with open("image-encoder.pickle", "rb") as pkl_file:
    preprocessor = pickle.load(pkl_file)


def encode_text(text_encoder, question, device):
    """Tokenize *question* and return its tensors moved to *device*.

    Returns a dict with ``input_ids``, ``token_type_ids`` and
    ``attention_mask`` (each shaped for a single-item batch).

    NOTE(review): the ``text_encoder`` argument is ignored; the
    module-level ``tokenizer`` from ``text-encoder.pickle`` is always
    used. The parameter is kept so existing callers keep working.
    """
    encoded_text = tokenizer(
        text=[question],
        padding="longest",
        max_length=24,          # questions are truncated to 24 tokens
        truncation=True,
        return_tensors="pt",
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    return {
        "input_ids": encoded_text["input_ids"].to(device),
        "token_type_ids": encoded_text["token_type_ids"].to(device),
        "attention_mask": encoded_text["attention_mask"].to(device),
    }


def encode_image(image_encoder, image, device):
    """Preprocess *image* into a ``pixel_values`` tensor on *device*.

    NOTE(review): ``image_encoder`` is ignored; the module-level
    ``preprocessor`` from ``image-encoder.pickle`` is always used. The
    parameter is kept so existing callers keep working.
    """
    processed_images = preprocessor(
        images=[image],
        return_tensors="pt",
    )
    return {"pixel_values": processed_images["pixel_values"].to(device)}


def get_inputs(question, image):
    """Run the VQA model on (question, image); return the predicted answer.

    Gradio callback: ``question`` is the raw text box value, ``image``
    the uploaded image. The answer is looked up in ``answer_space`` via
    the argmax over the model's logits.
    """
    # Normalize the question: lowercase, drop question marks, trim.
    question_normalized = question.lower().replace("?", "").strip()
    # BUG FIX: the original computed the normalized question and then
    # encoded the raw one, discarding the normalization entirely.
    question_encoded = encode_text("bert-base-uncased", question_normalized, device)
    image_encoded = encode_image("google/vit-base-patch16-224-in21k", image, device)

    model.eval()
    # Inference only — no gradient graph needed.
    with torch.no_grad():
        output = model(
            question_encoded["input_ids"],
            image_encoded["pixel_values"],
            question_encoded["attention_mask"],
            question_encoded["token_type_ids"],
        )
    preds = output["logits"].argmax(axis=-1).cpu().numpy()
    return answer_space[preds[0]]


model = VQA.VQA()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint is loaded onto CPU first, then the whole model is moved to
# the selected device (GPU when available).
model.load_state_dict(
    torch.load("state_dict_model.pt", map_location=torch.device("cpu"))
)
model.to(device)

interface = gr.Interface(fn=get_inputs, inputs=["text", "image"], outputs="text")
interface.launch()