File size: 2,276 Bytes
dac0ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
from transformers import AutoModel
import pickle
import torch
import torch.nn as nn
from typing import Optional
import matplotlib.pyplot as plt
import VQA

# Load the closed answer vocabulary: one candidate answer per line.
# The model's output logits index into this list.
with open("answer_space.txt") as fin:
    answer_space = fin.read().splitlines()

# NOTE(review): pickle.load is only safe on trusted files — these are
# local, self-produced artifacts, so this is acceptable here.
with open("text-encoder.pickle", "rb") as pkl_file:
    tokenizer = pickle.load(pkl_file)

with open("image-encoder.pickle", "rb") as pkl_file:
    preprocessor = pickle.load(pkl_file)

def encode_text(text_encoder, question, device):
  """Tokenize a single question and move the resulting tensors to *device*.

  Args:
      text_encoder: unused; kept only for backward compatibility with
          existing callers (the original implementation immediately
          overwrote it with the module-global ``tokenizer``).
      question: the raw question string to tokenize.
      device: torch device the output tensors are moved to.

  Returns:
      dict with ``input_ids``, ``token_type_ids`` and ``attention_mask``
      tensors, each already on *device*.
  """
  # Uses the module-global tokenizer loaded from text-encoder.pickle.
  encoded_text = tokenizer(
      text=[question],
      padding='longest',
      max_length=24,       # questions are short; 24 tokens is the cap
      truncation=True,
      return_tensors='pt',
      return_token_type_ids=True,
      return_attention_mask=True,
  )

  return {
      "input_ids": encoded_text['input_ids'].to(device),
      "token_type_ids": encoded_text['token_type_ids'].to(device),
      "attention_mask": encoded_text['attention_mask'].to(device),
  }

def encode_image(image_encoder, image, device):
  """Preprocess a single image and move the pixel tensor to *device*.

  Args:
      image_encoder: unused; kept only for backward compatibility with
          existing callers (the original implementation immediately
          overwrote it with the module-global ``preprocessor``).
      image: the input image (as provided by the Gradio image component).
      device: torch device the output tensor is moved to.

  Returns:
      dict with a single ``pixel_values`` tensor already on *device*.
  """
  # Uses the module-global preprocessor loaded from image-encoder.pickle.
  processed_images = preprocessor(
      images=[image],
      return_tensors="pt",
  )
  return {
      "pixel_values": processed_images['pixel_values'].to(device),
  }

def get_inputs(question, image):
  """Answer *question* about *image* with the global VQA model.

  Args:
      question: free-text question from the Gradio text input.
      image: image from the Gradio image input.

  Returns:
      The predicted answer string, looked up in ``answer_space``.
  """
  # Normalize the question (lowercase, strip '?' and whitespace).
  # BUG FIX: the original computed this normalized form and then
  # discarded it, passing the raw `question` to the tokenizer.
  normalized = question.lower().replace("?", "").strip()
  encoded_question = encode_text("bert-base-uncased", normalized, device)
  encoded_image = encode_image("google/vit-base-patch16-224-in21k", image, device)

  model.eval()
  # encode_text / encode_image already placed these tensors on `device`,
  # so no second `.to(device)` round trip is needed.
  with torch.no_grad():  # inference only; skip gradient bookkeeping
    output = model(
        encoded_question["input_ids"],
        encoded_image["pixel_values"],
        encoded_question["attention_mask"],
        encoded_question["token_type_ids"],
    )

  # Highest-scoring class index -> answer string.
  preds = output["logits"].argmax(axis=-1).cpu().numpy()
  return answer_space[preds[0]]
    
# Pick the compute device, then build the model and restore its weights.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VQA.VQA()
state = torch.load("state_dict_model.pt", map_location=torch.device('cpu'))
model.load_state_dict(state)
model.to(device)

# Gradio UI: a text question and an image in, an answer string out.
interface = gr.Interface(
    fn=get_inputs,
    inputs=["text", "image"],
    outputs="text",
)
interface.launch()