sanjay-906 committed on
Commit
ad73ef8
·
verified ·
1 Parent(s): dac0ce7

Create VQA.py

Browse files
Files changed (1) hide show
  1. VQA.py +56 -0
VQA.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import Optional
5
# Load the answer vocabulary: every line of answer_space.txt becomes one entry
# of `answer_space`, in file order. The path is relative, so it is resolved
# against the current working directory at import time.
# The `with` block closes the handle as soon as the file has been read (the
# previous code left it open for the life of the process); `fin` remains bound
# at module level, now to the closed file object.
# NOTE(review): the encoding is pinned to UTF-8 so the result does not depend on
# the platform locale - confirm the file is UTF-8/ASCII.
with open("answer_space.txt", encoding="utf-8") as fin:
    answer_space = fin.read().splitlines()
7
+
8
class VQA(nn.Module):
    """Late-fusion classifier over a text encoder and an image encoder.

    The ``pooler_output`` of both pretrained encoders is concatenated along
    the feature dimension, projected through a Linear -> ReLU -> Dropout
    block and mapped to ``num_labels`` logits.

    Args:
        text_encoder_name: Checkpoint name passed to
            ``AutoModel.from_pretrained`` for the text branch.
        image_encoder_name: Checkpoint name passed to
            ``AutoModel.from_pretrained`` for the image branch.
        num_labels: Size of the classifier output.
            NOTE(review): presumably the number of entries in the answer
            space - confirm against the answer file.
        intermediate_dim: Width of the fused representation (default 1059,
            the value that was previously hard-coded).
        dropout: Dropout probability applied after the fusion projection
            (default 0.6, the value that was previously hard-coded).
    """

    def __init__(self,
                 text_encoder_name='bert-base-uncased',
                 image_encoder_name='google/vit-base-patch16-224-in21k',
                 num_labels=582,
                 intermediate_dim=1059,
                 dropout=0.6):
        super().__init__()
        self.num_labels = num_labels
        self.text_encoder_name = text_encoder_name
        self.image_encoder_name = image_encoder_name
        self.text_encoder = AutoModel.from_pretrained(self.text_encoder_name)
        self.image_encoder = AutoModel.from_pretrained(self.image_encoder_name)

        # The fusion input width is read from the encoder configs rather than
        # assumed, so swapping checkpoints does not require editing this layer.
        fused_dim = (self.text_encoder.config.hidden_size
                     + self.image_encoder.config.hidden_size)
        self.combine = nn.Sequential(
            nn.Linear(fused_dim, intermediate_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        # NOTE: `layer1` is never called in `forward`. It is kept (and created
        # in the same position) so that state_dict keys and the order of
        # parameter initialisation stay identical to earlier versions.
        self.layer1 = nn.Linear(intermediate_dim, intermediate_dim)
        self.classifier = nn.Linear(intermediate_dim, self.num_labels)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_ids, pixel_values, attention_mask=None,
                token_type_ids=None,
                labels: Optional[torch.LongTensor] = None):
        """Run both encoders, fuse their pooled outputs and classify.

        Args:
            input_ids: Token ids forwarded to the text encoder.
            pixel_values: Image tensor forwarded to the image encoder.
            attention_mask: Optional mask forwarded to the text encoder;
                omitted from the call when ``None``.
            token_type_ids: Optional segment ids forwarded to the text
                encoder; omitted from the call when ``None`` so that text
                encoders without this argument can still be used.
            labels: Optional targets passed to ``nn.CrossEntropyLoss``
                together with the logits.

        Returns:
            A dict with ``'logits'`` and, only when ``labels`` is given,
            ``'loss'``.
        """
        # Only forward the optional text inputs that were actually supplied.
        text_inputs = {'input_ids': input_ids, 'return_dict': True}
        if attention_mask is not None:
            text_inputs['attention_mask'] = attention_mask
        if token_type_ids is not None:
            text_inputs['token_type_ids'] = token_type_ids

        encoded_text = self.text_encoder(**text_inputs)
        encoded_image = self.image_encoder(
            pixel_values=pixel_values,
            return_dict=True,
        )

        # Both encoders must expose `pooler_output`; the two pooled vectors
        # are joined along dim 1 (the feature axis).
        fused = torch.cat(
            [
                encoded_text['pooler_output'],
                encoded_image['pooler_output'],
            ],
            dim=1,
        )
        logits = self.classifier(self.combine(fused))

        output = {'logits': logits}
        if labels is not None:
            output['loss'] = self.criterion(logits, labels)

        return output