Spaces:
Build error
Build error
| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| from lavis.common.registry import registry | |
| from lavis.models.blip2_models.blip2_qformer import Blip2Qformer | |
class Blip2ITM(Blip2Qformer):
    """
    BLIP-2 Image-Text Matching (ITM) model.

    Supported model types:
        - pretrained: pretrained model
        - coco: finetuned model on coco

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip2_image_text_matching", "pretrained")
        >>> model = load_model("blip2_image_text_matching", "coco")
    """

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=256,
        max_txt_len=32,
    ):
        """Forward every constructor argument unchanged to ``Blip2Qformer``."""
        super().__init__(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            num_query_token=num_query_token,
            cross_attention_freq=cross_attention_freq,
            embed_dim=embed_dim,
            max_txt_len=max_txt_len,
        )

    def forward(self, samples, match_head="itm"):
        """Score how well each caption matches its image.

        Args:
            samples: Mapping with key ``"image"`` (a tensor fed to the visual
                encoder) and key ``"text_input"`` (caption text handed to the
                tokenizer; presumably a string or a list of strings with one
                entry per image -- confirm against callers).
            match_head: ``"itm"`` to run text and query tokens jointly through
                the Q-Former and return the ITM head output averaged over the
                query positions, or ``"itc"`` to return the maximum, over the
                query positions, of the dot product between the normalized
                image query features and the normalized text feature.

        Returns:
            For ``"itm"``: ``self.itm_head`` output averaged over dim 1.
            For ``"itc"``: per-sample maximum image/text similarity.

        Raises:
            ValueError: If ``match_head`` is neither ``"itm"`` nor ``"itc"``.
        """
        # Fail fast (before the expensive visual encoder pass) instead of
        # silently falling through both branches and returning None.
        if match_head not in ("itm", "itc"):
            raise ValueError(
                f"Unsupported match_head {match_head!r}; expected 'itm' or 'itc'."
            )

        image = samples["image"]
        caption = samples["text_input"]
        device = image.device

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_embeds = image_embeds.float()
        # All image positions are attended to; allocate the mask directly on
        # the target device rather than on CPU followed by a copy.
        image_atts = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=device
        )

        # padding="longest" lets captions of different token lengths be stacked
        # into one tensor; it is a no-op for a single caption or for captions
        # that already share a length.
        text = self.tokenizer(
            caption,
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(device)

        # One copy (view) of the learned query tokens per image in the batch.
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        if match_head == "itm":
            query_atts = torch.ones(
                query_tokens.size()[:-1], dtype=torch.long, device=device
            )
            # Query positions come first, followed by the text positions.
            attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
            output_itm = self.Qformer.bert(
                text.input_ids,
                query_embeds=query_tokens,
                attention_mask=attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            # Keep only the hidden states at the query positions.
            itm_embeddings = output_itm.last_hidden_state[
                :, : query_tokens.size(1), :
            ]
            itm_logit = self.itm_head(itm_embeddings)
            # Average the head output over the query positions.
            return itm_logit.mean(dim=1)

        # match_head == "itc"
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        image_feats = F.normalize(
            self.vision_proj(query_output.last_hidden_state), dim=-1
        )

        text_output = self.Qformer.bert(
            text.input_ids,
            attention_mask=text.attention_mask,
            return_dict=True,
        )
        # The text feature is taken from the first token position only.
        text_feat = F.normalize(
            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )

        # Dot product of every query feature with the text feature, then keep
        # the best-matching query per sample.
        sims = torch.bmm(image_feats, text_feat.unsqueeze(-1))
        sim, _ = torch.max(sims, dim=1)
        return sim