from extras.BLIP.models.med import BertConfig, BertModel
from transformers import BertTokenizer

import torch
from torch import nn
import torch.nn.functional as F

from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint

class BLIP_ITM(nn.Module):
    def __init__(self,
                 med_config='configs/med_config.json',
                 image_size=384,
                 vit='base',
                 vit_grad_ckpt=False,
                 vit_ckpt_layer=0,
                 embed_dim=256,
                 ):
        """
        Args:
            med_config (str): path to the mixture-of-encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of the vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        # Projection heads mapping vision/text features into the shared
        # embedding space used by the contrastive (ITC) head.
        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        # Binary classifier over the [CLS] token for image-text matching (ITM).
        self.itm_head = nn.Linear(text_width, 2)

    def forward(self, image, caption, match_head='itm'):
        # Encode the image; the all-ones mask lets every ViT token be attended to.
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
                              return_tensors="pt").to(image.device)

        if match_head == 'itm':
            # Image-text matching: the text encoder cross-attends to the image
            # embeddings; a binary head on the [CLS] token yields
            # (no-match, match) logits.
            output = self.text_encoder(text.input_ids,
                                       attention_mask=text.attention_mask,
                                       encoder_hidden_states=image_embeds,
                                       encoder_attention_mask=image_atts,
                                       return_dict=True,
                                       )
            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])
            return itm_output

        elif match_head == 'itc':
            # Image-text contrastive: encode the text unimodally, project both
            # modalities into the shared embedding space, and return the matrix
            # of cosine similarities.
            text_output = self.text_encoder(text.input_ids, attention_mask=text.attention_mask,
                                            return_dict=True, mode='text')
            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
            text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)

            sim = image_feat @ text_feat.t()
            return sim

        raise ValueError(f"match_head must be 'itm' or 'itc', got {match_head!r}")

def blip_itm(pretrained='', **kwargs):
    model = BLIP_ITM(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        assert len(msg.missing_keys) == 0
    return model
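
# --- Usage sketch (illustrative, not part of the upstream BLIP code) ---
# A minimal example of scoring one image-caption pair with both heads. The
# checkpoint path, image file, and preprocessing below are assumptions: BLIP's
# released demos use a similar bicubic resize with CLIP-style normalization,
# but verify against the transform your checkpoint was trained with.
if __name__ == '__main__':
    from PIL import Image
    from torchvision import transforms

    image_size = 384
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size),
                          interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])

    # Hypothetical checkpoint path; substitute a real BLIP ITM/retrieval checkpoint.
    model = blip_itm(pretrained='checkpoints/model_base_retrieval_coco.pth',
                     image_size=image_size, vit='base')
    model.eval()

    image = transform(Image.open('example.jpg').convert('RGB')).unsqueeze(0)
    caption = ['a photo of a dog']

    with torch.no_grad():
        itm_logits = model(image, caption, match_head='itm')  # [1, 2] logits
        itm_score = F.softmax(itm_logits, dim=1)[:, 1]        # probability of a match
        itc_sim = model(image, caption, match_head='itc')     # [1, 1] cosine similarity

    print(f'ITM match probability: {itm_score.item():.4f}')
    print(f'ITC cosine similarity: {itc_sim.item():.4f}')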