yfyangd committed
Commit 19449ba
1 Parent(s): 9827d9c

Upload blip_itm.py

Files changed (1)
  1. BLIP/models/blip_itm.py +76 -0
BLIP/models/blip_itm.py ADDED
@@ -0,0 +1,76 @@
+ from models.med import BertConfig, BertModel
+ from transformers import BertTokenizer
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+ class BLIP_ITM(nn.Module):
+     def __init__(self,
+                  med_config='configs/med_config.json',
+                  image_size=384,
+                  vit='base',
+                  vit_grad_ckpt=False,
+                  vit_ckpt_layer=0,
+                  embed_dim=256,
+                  ):
+         """
+         Args:
+             med_config (str): path to the mixture-of-encoder-decoder model's configuration file
+             image_size (int): input image size
+             vit (str): model size of the vision transformer
+         """
+         super().__init__()
+
+         self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
+         self.tokenizer = init_tokenizer()
+         med_config = BertConfig.from_json_file(med_config)
+         med_config.encoder_width = vision_width
+         self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
+
+         text_width = self.text_encoder.config.hidden_size
+
+         self.vision_proj = nn.Linear(vision_width, embed_dim)
+         self.text_proj = nn.Linear(text_width, embed_dim)
+
+         self.itm_head = nn.Linear(text_width, 2)
+
+
+     def forward(self, image, caption, match_head='itm'):
+
+         image_embeds = self.visual_encoder(image)
+         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+         text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
+                               return_tensors="pt").to(image.device)
+
+
+         if match_head == 'itm':
+             output = self.text_encoder(text.input_ids,
+                                        attention_mask=text.attention_mask,
+                                        encoder_hidden_states=image_embeds,
+                                        encoder_attention_mask=image_atts,
+                                        return_dict=True,
+                                        )
+             itm_output = self.itm_head(output.last_hidden_state[:, 0, :])
+             return itm_output
+
+         elif match_head == 'itc':
+             text_output = self.text_encoder(text.input_ids, attention_mask=text.attention_mask,
+                                             return_dict=True, mode='text')
+             image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
+             text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
+
+             sim = image_feat @ text_feat.t()
+             return sim
+
+
+ def blip_itm(pretrained='', **kwargs):
+     model = BLIP_ITM(**kwargs)
+     if pretrained:
+         model, msg = load_checkpoint(model, pretrained)
+         assert len(msg.missing_keys) == 0
+     return model
+
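
For reference, a minimal usage sketch (not part of this commit) showing how the uploaded module can score an image-caption pair with both heads. The checkpoint filename, the demo image, and the preprocessing transform are assumptions, not defined in this file; the normalization constants follow the BLIP demo, and any BLIP retrieval checkpoint should work.

# Usage sketch; run from the BLIP repo root so `models.*` imports resolve.
import torch
from PIL import Image
from torchvision import transforms
from models.blip_itm import blip_itm

image_size = 384
# Preprocessing assumed to match the BLIP demo's transform.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size),
                      interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])

# Hypothetical local checkpoint path.
model = blip_itm(pretrained='model_base_retrieval_coco.pth',
                 image_size=image_size, vit='base')
model.eval()

image = transform(Image.open('demo.jpg').convert('RGB')).unsqueeze(0)  # hypothetical image
caption = 'a dog running on the beach'

with torch.no_grad():
    itm_logits = model(image, caption, match_head='itm')  # shape (1, 2)
    match_prob = torch.softmax(itm_logits, dim=1)[:, 1]   # probability the pair matches
    itc_sim = model(image, caption, match_head='itc')     # cosine similarity, shape (1, 1)

With match_head='itm' the caption cross-attends to the image features and itm_head returns match/no-match logits; with match_head='itc' only the unimodal [CLS] features are projected and compared by cosine similarity, which is cheaper but coarser.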