| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| from lavis.common.registry import registry | |
| from lavis.models.blip_models.blip import BlipBase | |
| from torch import nn | |
| from lavis.models.med import XBertEncoder | |
| from lavis.models.vit import VisionTransformerEncoder | |

@registry.register_model("blip_image_text_matching")
class BlipITM(BlipBase):
    """
    BLIP Image-Text Matching (ITM) model.

    Supported model types:
        - base: BLIP-base retrieval weights fine-tuned on the COCO dataset (Karpathy split).
        - large: BLIP-large retrieval weights fine-tuned on the COCO dataset (Karpathy split).

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip_image_text_matching", "base")
        >>> model = load_model("blip_image_text_matching", "large")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "base": "configs/models/blip_itm_base.yaml",
        "large": "configs/models/blip_itm_large.yaml",
    }
    def __init__(self, image_encoder, text_encoder, embed_dim=256, max_txt_len=35):
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.text_encoder = text_encoder
        self.visual_encoder = image_encoder

        self.max_txt_len = max_txt_len

        # projection layers that map image/text features into the shared
        # embedding space used by the ITC (image-text contrastive) head
        text_width = text_encoder.config.hidden_size
        vision_width = image_encoder.vision_width

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        # binary classification head for ITM (match / no-match)
        self.itm_head = nn.Linear(text_width, 2)
    def forward(self, samples, match_head="itm"):
        image = samples["image"]
        caption = samples["text_input"]

        image_embeds = self.visual_encoder.forward_features(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )

        text = self.tokenizer(
            caption,
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(image.device)
        if match_head == "itm":
            encoder_input_ids = text.input_ids.clone()
            # replace [CLS] with the [ENC] token so the text encoder runs in
            # multimodal (image-grounded) mode
            encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
            output = self.text_encoder(
                encoder_input_ids,
                attention_mask=text.attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])
            return itm_output
        elif match_head == "itc":
            text_output = self.text_encoder(
                text.input_ids,
                attention_mask=text.attention_mask,
                return_dict=True,
                mode="text",
            )
            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
            text_feat = F.normalize(
                self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
            )

            sim = image_feat @ text_feat.t()
            return sim
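
    # Illustrative usage sketch (not part of the original module; the
    # processor calls are assumptions based on the standard LAVIS API):
    #
    #   from lavis.models import load_model_and_preprocess
    #   model, vis_processors, txt_processors = load_model_and_preprocess(
    #       "blip_image_text_matching", "base", is_eval=True
    #   )
    #   img = vis_processors["eval"](raw_image).unsqueeze(0)
    #   txt = txt_processors["eval"]("a dog running on the grass")
    #   itm_logits = model({"image": img, "text_input": [txt]}, match_head="itm")
    #   p_match = torch.softmax(itm_logits, dim=1)[:, 1]   # match probability
    #   sim = model({"image": img, "text_input": [txt]}, match_head="itc")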

    def itm_rank(self, image_embeds, image_atts, encoder_input_ids, match_head="itm"):
        encoder_input_ids = encoder_input_ids.clone()
        # strip the first three ids before rescoring; this assumes the
        # candidate captions carry a fixed three-token prompt prefix from
        # caption generation
        encoder_input_ids = encoder_input_ids[:, 3:]
        text_attention_mask = (encoder_input_ids != self.tokenizer.pad_token_id).long()

        if match_head == "itm":
            # replace the first token with [ENC] for multimodal encoding
            encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
            output = self.text_encoder(
                encoder_input_ids,
                attention_mask=text_attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])
            # probability of the "match" class
            itm_output = F.softmax(itm_output, dim=1)[:, 1]
            return itm_output
        elif match_head == "itc":
            encoder_input_ids[:, 0] = self.tokenizer.cls_token_id
            text_output = self.text_encoder(
                encoder_input_ids,
                attention_mask=text_attention_mask,
                return_dict=True,
                mode="text",
            )
            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
            text_feat = F.normalize(
                self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
            )

            sim = image_feat @ text_feat.t()
            return sim
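
    # Illustrative rerank sketch (not part of the original module; variable
    # names are assumptions). itm_rank reuses precomputed image features to
    # score many candidate texts against one image; the ids are expected to
    # include the three-token prompt prefix stripped above.
    #
    #   image_embeds = model.visual_encoder.forward_features(img)        # (1, P, D)
    #   image_embeds = image_embeds.repeat(len(candidate_ids), 1, 1)     # match text batch
    #   image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
    #       image_embeds.device
    #   )
    #   scores = model.itm_rank(image_embeds, image_atts, candidate_ids)  # (N,) match probs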

    @classmethod
    def from_config(cls, cfg=None):
        image_encoder = VisionTransformerEncoder.from_config(cfg)
        text_encoder = XBertEncoder.from_config(cfg)

        embed_dim = cfg.get("embed_dim", 256)
        max_txt_len = cfg.get("max_txt_len", 35)

        model = cls(
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            embed_dim=embed_dim,
            max_txt_len=max_txt_len,
        )

        model.load_checkpoint_from_config(cfg)

        return model


def compute_gradcam(model, visual_input, text_input, tokenized_text, block_num=6):
    # record cross-attention maps and their gradients at the chosen layer
    model.text_encoder.base_model.base_model.encoder.layer[
        block_num
    ].crossattention.self.save_attention = True

    output = model({"image": visual_input, "text_input": text_input}, match_head="itm")
    # backpropagate the "match" logit to obtain attention gradients
    loss = output[:, 1].sum()

    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        mask = tokenized_text.attention_mask.view(
            tokenized_text.attention_mask.size(0), 1, -1, 1, 1
        )  # (bsz, 1, token_len, 1, 1)

        # token count excluding the special [ENC] and [SEP] tokens
        token_length = tokenized_text.attention_mask.sum(dim=-1) - 2
        token_length = token_length.cpu()
        # grads and cams: (bsz, num_head, seq_len, num_image_patches)
        grads = model.text_encoder.base_model.base_model.encoder.layer[
            block_num
        ].crossattention.self.get_attn_gradients()
        cams = model.text_encoder.base_model.base_model.encoder.layer[
            block_num
        ].crossattention.self.get_attention_map()

        # assumes a ViT with 24 x 24 = 576 image patches; drop the [CLS]
        # position (index 0) and reshape the rest onto the patch grid
        cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
        grads = (
            grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24)
            * mask
        )
        gradcams = cams * grads
        gradcam_list = []

        for ind in range(visual_input.size(0)):
            token_length_ = token_length[ind]
            gradcam = gradcams[ind].mean(0).cpu().detach()
            # rows: the [ENC]-token gradcam, the gradcam averaged across
            # tokens, then the gradcam for each individual token
            gradcam = torch.cat(
                (
                    gradcam[0:1, :],
                    gradcam[1 : token_length_ + 1, :].sum(dim=0, keepdim=True)
                    / token_length_,
                    gradcam[1:, :],
                )
            )
            gradcam_list.append(gradcam)

    return gradcam_list, output
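

# Illustrative GradCAM usage sketch (not part of the original module; the
# variable names are assumptions):
#
#   txt = model.tokenizer(caption, return_tensors="pt").to(img.device)
#   gradcam_list, itm_out = compute_gradcam(model, img, [caption], txt)
#   # gradcam_list[i] rows: [0] the [ENC]-token map, [1] the token-averaged
#   # map, [2:] one 24x24 map per remaining token.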