| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| from copy import deepcopy | |
| import torch | |
| import torch.nn.functional as F | |
| from lavis.common.registry import registry | |
| from lavis.common.utils import get_abs_path | |
| from lavis.models.albef_models import AlbefBase | |
| from lavis.models.albef_models.albef_outputs import AlbefIntermediateOutput, AlbefOutput | |
| from lavis.models.base_model import MomentumDistilationMixin | |
| from lavis.models.med import BertModel | |
| from lavis.models.vit import VisionTransformerEncoder | |
| from torch import nn | |
| from transformers import BertConfig | |
@registry.register_model("albef_nlvr")
class AlbefNLVR(AlbefBase, MomentumDistilationMixin):
    PRETRAINED_MODEL_CONFIG_DICT = {
        "nlvr": "configs/models/albef_nlvr.yaml",
    }
    def __init__(
        self,
        image_encoder,
        text_encoder,
        num_classes,
        momentum=0.995,
        alpha=0.4,
        use_distill=True,
        max_txt_len=40,
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer()
        self.max_txt_len = max_txt_len

        self.use_distill = use_distill

        self.visual_encoder = image_encoder
        self.text_encoder = text_encoder

        hidden_size = text_encoder.config.hidden_size
        # Two-layer MLP classification head on top of the [CLS] representation.
        self.cls_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

        self.share_cross_attention(self.text_encoder.encoder)

        if self.use_distill:
            # Momentum (EMA) copies of the encoders and head, used to produce
            # soft distillation targets during training.
            self.visual_encoder_m = deepcopy(self.visual_encoder)
            self.text_encoder_m = deepcopy(self.text_encoder)
            self.cls_head_m = deepcopy(self.cls_head)

            self.share_cross_attention(self.text_encoder_m.encoder)

            self.momentum = momentum
            self.alpha = alpha

            self.model_pairs = [
                [self.visual_encoder, self.visual_encoder_m],
                [self.text_encoder, self.text_encoder_m],
                [self.cls_head, self.cls_head_m],
            ]

            self.copy_params()
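    # The factor below ramps the distillation weight linearly from 0 to 1 over
    # the first two epochs, then stays capped at 1, so the effective weight of
    # the soft targets grows from 0 to self.alpha. Illustrative values,
    # assuming num_iters_per_epoch=100:
    #   epoch=0, iters=0   -> 0.0
    #   epoch=1, iters=0   -> 0.5
    #   epoch=2, iters=0   -> 1.0 (and 1.0 thereafter)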
    def _rampup_factor(self, epoch, iters, num_iters_per_epoch):
        return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch))
    def forward(self, samples, is_train=True):
        """
        Forward function for training and evaluation.

        Args:
            samples (dict): a dict of input samples, which contains the following keys:
                - image0 (torch.Tensor): input image 0, shape (batch_size, 3, H, W), default H=384, W=384.
                - image1 (torch.Tensor): input image 1, shape (batch_size, 3, H, W), default H=384, W=384.
                - text_input (list): list of strings, each string is a natural language sentence.
                - label (torch.LongTensor): ground truth label with shape (batch_size,).
            is_train (bool): whether the model is in training mode.
                If True, the model will return the loss;
                if False, the model will return the prediction.

        Examples:
            >>> import torch
            >>> from lavis.models import load_model
            >>> model = load_model("albef_nlvr")
            >>> samples = {
            ...     "image0": torch.randn(2, 3, 384, 384),
            ...     "image1": torch.randn(2, 3, 384, 384),
            ...     "text_input": ["there is a ferret in tall grass", "there are lips in one of the images"],
            ...     "label": torch.tensor([0, 1]),
            ... }
            >>> output = model(samples)
            >>> output.keys()
            odict_keys(['intermediate_output', 'loss'])
        """
        text = samples["text_input"]
        text = self.tokenizer(
            text,
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(self.device)

        targets = samples["label"]

        image0 = samples["image0"]
        image1 = samples["image1"]
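        # Batch both images through the visual encoder in one forward pass,
        # then split the embeddings back into per-image halves; each half feeds
        # one of the two cross-attention streams of the text encoder.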
        images = torch.cat([image0, image1], dim=0)

        image_embeds = self.visual_encoder.forward_features(images)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )
        image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0))

        encoder_output = self.text_encoder(
            text.input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=[image0_embeds, image1_embeds],
            encoder_attention_mask=[
                image_atts[: image0_embeds.size(0)],
                image_atts[image0_embeds.size(0) :],
            ],
            return_dict=True,
        )

        prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :])
        if is_train:
            if self.use_distill:
                with torch.no_grad():
                    self._momentum_update()

                    image_embeds_m = self.visual_encoder_m(images)
                    image0_embeds_m, image1_embeds_m = torch.split(
                        image_embeds_m, targets.size(0)
                    )
                    # The soft targets must come from the momentum text encoder,
                    # not the online one.
                    encoder_output_m = self.text_encoder_m(
                        text.input_ids,
                        attention_mask=text.attention_mask,
                        encoder_hidden_states=[image0_embeds_m, image1_embeds_m],
                        encoder_attention_mask=[
                            image_atts[: image0_embeds_m.size(0)],
                            image_atts[image0_embeds_m.size(0) :],
                        ],
                        return_dict=True,
                    )

                    prediction_m = self.cls_head_m(
                        encoder_output_m.last_hidden_state[:, 0, :]
                    )

                alpha = self.alpha * self._rampup_factor(
                    epoch=samples["epoch"],
                    iters=samples["iters"],
                    num_iters_per_epoch=samples["num_iters_per_epoch"],
                )
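                # Distillation objective: a weighted sum of the hard-label
                # cross-entropy and a soft cross-entropy against the momentum
                # model's predictions,
                #   loss = (1 - alpha) * CE(p, y) - alpha * sum_c q_c * log p_c,
                # where p = softmax(prediction) and q = softmax(prediction_m).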
                loss = (1 - alpha) * F.cross_entropy(
                    prediction, targets
                ) - alpha * torch.sum(
                    F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1),
                    dim=1,
                ).mean()
            else:
                loss = F.cross_entropy(prediction, targets)

                encoder_output_m = None
                image0_embeds_m, image1_embeds_m = None, None

            return AlbefOutput(
                loss=loss,
                intermediate_output=AlbefIntermediateOutput(
                    image_embeds=torch.stack([image0_embeds, image1_embeds], dim=0),
                    # Guard against use_distill=False, where the momentum
                    # embeddings are None and cannot be stacked.
                    image_embeds_m=None
                    if image0_embeds_m is None
                    else torch.stack([image0_embeds_m, image1_embeds_m], dim=0),
                    encoder_output=encoder_output,
                    encoder_output_m=encoder_output_m,
                ),
            )
        else:
            return {"predictions": prediction, "targets": targets}
    def share_cross_attention(self, model):
        for i in range(6):
            layer_num = 6 + i * 2
            modules_0 = model.layer[layer_num].crossattention.self._modules
            modules_1 = model.layer[layer_num + 1].crossattention.self._modules

            for name in modules_0.keys():
                if "key" in name or "value" in name:
                    module_0 = modules_0[name]
                    module_1 = modules_1[name]
                    if hasattr(module_0, "weight"):
                        module_0.weight = module_1.weight
                        if hasattr(module_0, "bias"):
                            module_0.bias = module_1.bias
    def predict(self, samples):
        output = self.forward(samples, is_train=False)
        return output

    def load_from_pretrained(self, url_or_filename, use_distill=True):
        _, msg = super().load_from_pretrained(url_or_filename)

        if use_distill and any(["_m" in k for k in msg.missing_keys]):
            # this is required when initializing the model from TA pre-trained weights
            self.copy_params()

        return msg
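    # from_config builds the model from a LAVIS model config (see
    # configs/models/albef_nlvr.yaml for the shipped defaults). An illustrative
    # minimal cfg (values below are assumptions, not the shipped defaults)
    # would provide:
    #   med_config_path: <path to the multimodal BERT json config>
    #   num_classes: 2            # required; must be > 1
    #   alpha: 0.4                # optional
    #   momentum: 0.995           # optional
    #   use_distill: True         # optional
    #   max_txt_len: 40           # optional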
    @classmethod
    def from_config(cls, cfg=None):
        image_encoder = VisionTransformerEncoder.from_config(cfg)

        # text encoder + multimodal encoder
        bert_config = BertConfig.from_json_file(get_abs_path(cfg["med_config_path"]))
        bert_config.num_hidden_layers = 18

        text_encoder = BertModel.from_pretrained(
            "bert-base-uncased", config=bert_config, add_pooling_layer=False
        )

        alpha = cfg.get("alpha", 0.4)
        momentum = cfg.get("momentum", 0.995)
        use_distill = cfg.get("use_distill", True)
        num_classes = cfg.get("num_classes", -1)
        max_txt_len = cfg.get("max_txt_len", 40)

        assert num_classes > 1, "Invalid number of classes provided, found {}".format(
            num_classes
        )

        model = cls(
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            use_distill=use_distill,
            alpha=alpha,
            num_classes=num_classes,
            momentum=momentum,
            max_txt_len=max_txt_len,
        )

        model.load_checkpoint_from_config(cfg)

        return model