| """
|
| Copyright (c) 2022, salesforce.com, inc.
|
| All rights reserved.
|
| SPDX-License-Identifier: BSD-3-Clause
|
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| """
|
|
|
| import logging
|
| import os
|
|
|
| import numpy as np
|
| import torch
|
| import torch.nn as nn
|
| from lavis.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
|
| from lavis.common.utils import get_abs_path, is_url
|
| from omegaconf import OmegaConf
|
| from transformers import PreTrainedModel, PretrainedConfig
|
|
|
|
|
class FAPMConfig(PretrainedConfig):
    """Hugging Face configuration class for FAPM models.

    Args:
        important_param (int): placeholder hyper-parameter (default 42). It is
            stored on the config so it survives ``to_dict``/``save_pretrained``
            round-trips instead of being silently dropped.
            NOTE(review): nothing in this file reads it — confirm it is needed.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = 'FAPM'

    def __init__(self, important_param=42, **kwargs):
        self.important_param = important_param
        super().__init__(**kwargs)
|
|
|
|
|
class BaseModel(PreTrainedModel):
    """Base class for models.

    A ``transformers.PreTrainedModel`` (configured by ``FAPMConfig``) extended
    with helpers for loading checkpoints from a URL or local path, building a
    model from a YAML model config, and grouping parameters for an optimizer.

    Subclasses are expected to define ``PRETRAINED_MODEL_CONFIG_DICT`` and
    ``from_config``, plus ``load_from_pretrained`` when pretrained (not
    finetuned) weights are loaded through ``load_checkpoint_from_config``.
    """

    config_class = FAPMConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

    @property
    def device(self):
        """Device of the first registered parameter.

        Assumes all parameters live on a single device.

        Raises:
            IndexError: if the model has no parameters.
        """
        # Stop at the first parameter instead of materialising the full
        # parameter list on every access (``device`` is read frequently).
        first_param = next(self.parameters(), None)
        if first_param is None:
            raise IndexError("model has no parameters")
        return first_param.device

    def load_checkpoint(self, url_or_filename):
        """
        Load from a finetuned checkpoint.

        This should expect no mismatch in the model keys and the checkpoint keys.
        The state dict is loaded with ``strict=False`` and missing keys are logged.

        Args:
            url_or_filename (str): URL (downloaded and cached) or local file path.
                If the loaded object has a ``"model"`` entry, that entry is used
                as the state dict; otherwise the object itself is.

        Returns:
            The result of ``load_state_dict`` (``missing_keys``/``unexpected_keys``).

        Raises:
            RuntimeError: if ``url_or_filename`` is neither a URL nor an existing file.
        """
        # NOTE(review): torch.load unpickles its input; only load trusted checkpoints.
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys %s", msg.missing_keys)
        logging.info("load checkpoint from %s", url_or_filename)

        return msg

    @classmethod
    def from_pretrained(cls, model_type):
        """
        Build a pretrained model from default configuration file, specified by model_type.

        Args:
            - model_type (str): model type, specifying architecture and checkpoints.

        Returns:
            - model (nn.Module): pretrained or finetuned model, depending on the configuration.
        """
        model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
        model = cls.from_config(model_cfg)

        return model

    @classmethod
    def default_config_path(cls, model_type):
        """Return the absolute path of the default YAML config for ``model_type``.

        Raises:
            AssertionError: if ``model_type`` is not a key of
                ``PRETRAINED_MODEL_CONFIG_DICT``.
        """
        assert (
            model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
        ), "Unknown model type {}".format(model_type)
        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])

    def load_checkpoint_from_config(self, cfg, **kwargs):
        """
        Load checkpoint as specified in the config file.

        If load_finetuned is True (default), load the finetuned model from
        ``cfg.finetuned``; otherwise, if load_pretrained is True (default), load
        the pretrained model from ``cfg.pretrained``. When loading the pretrained
        model, each task-specific architecture may define their own
        load_from_pretrained() method; it receives ``**kwargs``.

        Raises:
            AssertionError: if the selected checkpoint path is missing from ``cfg``.
        """
        load_finetuned = cfg.get("load_finetuned", True)
        if load_finetuned:
            finetune_path = cfg.get("finetuned", None)
            assert (
                finetune_path is not None
            ), "Found load_finetuned is True, but finetune_path is None."
            self.load_checkpoint(url_or_filename=finetune_path)
        else:
            load_pretrained = cfg.get("load_pretrained", True)
            if load_pretrained:
                pretrain_path = cfg.get("pretrained", None)
                # Validate the path itself: a bare ``assert "<message>"`` tests a
                # non-empty string and can never fail.
                assert (
                    pretrain_path is not None
                ), "Found load_finetuned is False, but pretrain_path is None."
                self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)

    def before_training(self, **kwargs):
        """Hook called before training; no-op by default."""
        pass

    def get_optimizer_params(self, weight_decay, lr_scale=1):
        """Split trainable parameters into weight-decay / no-decay groups.

        Parameters with fewer than 2 dims, or whose name contains "bias", "ln"
        or "bn", get ``weight_decay`` 0; all other trainable parameters get
        ``weight_decay``. Frozen parameters (``requires_grad=False``) are skipped.

        Returns:
            list[dict]: two optimizer param groups, both tagged with ``lr_scale``.
        """
        p_wd, p_non_wd = [], []
        for n, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
                p_non_wd.append(p)
            else:
                p_wd.append(p)
        optim_params = [
            {"params": p_wd, "weight_decay": weight_decay, "lr_scale": lr_scale},
            {"params": p_non_wd, "weight_decay": 0, "lr_scale": lr_scale},
        ]
        return optim_params

    def before_evaluation(self, **kwargs):
        """Hook called before evaluation; no-op by default."""
        pass

    def show_n_params(self, return_str=True):
        """Count all parameters of the model (trainable or not).

        Args:
            return_str (bool): if True, return a human-readable string in
                millions (e.g. "12.3M") when the count is >= 1e6, otherwise in
                thousands (e.g. "4.5K"); if False, return the raw integer count.
        """
        tot = sum(p.numel() for p in self.parameters())
        if return_str:
            if tot >= 1e6:
                return "{:.1f}M".format(tot / 1e6)
            return "{:.1f}K".format(tot / 1e3)
        return tot
|
|
|
|
|
class BaseEncoder(nn.Module):
    """
    Base class for primitive encoders, such as ViT, TimeSformer, etc.

    Subclasses must implement ``forward_features``.
    """

    def __init__(self):
        super().__init__()

    def forward_features(self, samples, **kwargs):
        """Encode ``samples``; must be overridden by subclasses."""
        raise NotImplementedError

    @property
    def device(self):
        """Device of the first registered parameter (assumes a single device).

        Raises:
            IndexError: if the encoder has no parameters.
        """
        # Stop at the first parameter rather than materialising the full list.
        first_param = next(self.parameters(), None)
        if first_param is None:
            raise IndexError("encoder has no parameters")
        return first_param.device
|
|
|
|
|
class SharedQueueMixin:
    """Mixin that pushes gathered features into fixed-size circular queues.

    The host class must provide ``image_queue``, ``text_queue``, ``queue_ptr``
    and ``queue_size`` (plus ``idx_queue`` when ``idxs`` is passed).
    """

    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
        """Gather features from all ranks and overwrite the oldest queue columns.

        Queue columns ``[ptr, ptr + batch)`` are replaced by the transposed
        gathered features, then the pointer advances modulo ``queue_size``.
        Presumably features are (batch, dim) and queues (dim, queue_size) —
        TODO confirm against the host model.
        """
        gathered_image = concat_all_gather(image_feat)
        gathered_text = concat_all_gather(text_feat)

        n_new = gathered_image.shape[0]
        start = int(self.queue_ptr)
        # Only whole batches fit, so the write below never wraps around.
        assert self.queue_size % n_new == 0
        stop = start + n_new

        self.image_queue[:, start:stop] = gathered_image.T
        self.text_queue[:, start:stop] = gathered_text.T

        if idxs is not None:
            self.idx_queue[:, start:stop] = concat_all_gather(idxs).T

        self.queue_ptr[0] = stop % self.queue_size
|
|
|
|
|
class MomentumDistilationMixin:
    """Helpers for a momentum (exponential-moving-average) copy of sub-modules.

    The host class must define ``self.model_pairs`` — pairs whose element 0 is
    the online module and element 1 its momentum copy — and ``self.momentum``.
    """

    @torch.no_grad()
    def copy_params(self):
        """Initialise every momentum module from its online module and freeze it."""
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                param_m.data.copy_(param.data)
                param_m.requires_grad = False

    @torch.no_grad()
    def _momentum_update(self):
        """EMA step: ``param_m = param_m * momentum + param * (1 - momentum)``."""
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                # Update in place so each momentum parameter keeps its storage
                # and no full-size result tensor is allocated per step.
                param_m.data.mul_(self.momentum).add_(
                    param.data * (1.0 - self.momentum)
                )
|
|
|
|
|
class GatherLayer(torch.autograd.Function):
    """All-gather that keeps the autograd graph connected.

    ``torch.distributed.all_gather`` does not propagate gradients. Here
    ``forward`` returns one tensor per rank. ``backward`` all-reduces the stacked
    incoming gradients and returns the entry for the local rank.
    """

    @staticmethod
    def forward(ctx, x):
        n_ranks = torch.distributed.get_world_size()
        buffers = [torch.zeros_like(x) for _ in range(n_ranks)]
        torch.distributed.all_gather(buffers, x)
        return tuple(buffers)

    @staticmethod
    def backward(ctx, *grads):
        local_rank = torch.distributed.get_rank()
        summed = torch.stack(grads)
        # Default reduce op (SUM): every rank receives the total gradient for
        # every gathered shard, then keeps its own.
        torch.distributed.all_reduce(summed)
        return summed[local_rank]
|
|
|
|
|
def all_gather_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.

    Args:
        tensors (torch.Tensor): local tensor; every rank must pass the same
            shape (a requirement of ``torch.distributed.all_gather``).

    Returns:
        torch.Tensor: rank-ordered concatenation along dim 0, or ``tensors``
        unchanged when no process group is initialized or world size is 1.
    """
    # Like concat_all_gather, fall back to the local tensor outside distributed
    # runs: get_world_size() raises when no process group is initialized.
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return tensors

    if torch.distributed.get_world_size() == 1:
        return tensors

    tensor_all = GatherLayer.apply(tensors)

    return torch.cat(tensor_all, dim=0)
|
|
|
|
|
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather ``tensor`` from every rank and concatenate the results along dim 0.

    Returns ``tensor`` itself when torch.distributed is unavailable or not
    initialized. All ranks must supply tensors of the same shape.

    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    if is_dist_avail_and_initialized():
        n_ranks = torch.distributed.get_world_size()
        slots = [torch.ones_like(tensor) for _ in range(n_ranks)]
        torch.distributed.all_gather(slots, tensor, async_op=False)
        return torch.cat(slots, dim=0)
    return tensor
|
|
|
|
|
def tile(x, dim, n_tile):
    """Repeat each slice of ``x`` along ``dim`` ``n_tile`` times, consecutively.

    Along ``dim``, ``[a, b]`` becomes ``[a, a, ..., b, b, ...]`` (each slice
    repeated ``n_tile`` times in place), unlike ``Tensor.repeat``, which
    produces ``[a, b, a, b, ...]``.

    Args:
        x (torch.Tensor): input tensor.
        dim (int): dimension to expand; negative values are allowed.
        n_tile (int): number of copies of each slice.

    Returns:
        torch.Tensor: same dtype/device as ``x``, with size
        ``x.size(dim) * n_tile`` along ``dim``.
    """
    # repeat_interleave gives exactly the "repeat then reorder" result without
    # building the interleaving index on the host with NumPy and copying it to
    # the device, and it also handles a zero-size ``dim``.
    return torch.repeat_interleave(x, n_tile, dim=dim)
|
|
|