Spaces:
Build error
Build error
| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| Based on https://github.com/facebookresearch/TimeSformer | |
| """ | |
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| # Copyright 2020 Ross Wightman | |
| # Modified model creation / weight loading / state_dict helpers | |
| import logging, warnings | |
| import os | |
| import math | |
| from collections import OrderedDict | |
| import torch | |
| import torch.utils.model_zoo as model_zoo | |
| import torch.nn.functional as F | |
def _strip_key_prefix(state_dict, prefix):
    """Return an OrderedDict copy of ``state_dict`` with ``prefix`` removed from keys that start with it."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def load_state_dict(checkpoint_path, use_ema=False):
    """Load a model state dict from a checkpoint file on disk.

    The file is read with ``torch.load(..., map_location="cpu")``. When the
    loaded object is a dict, the weights are taken from, in order of priority:

    * ``"state_dict_ema"`` (only if ``use_ema`` is True and the key exists) or
      ``"state_dict"`` -- a leading ``"module."`` is stripped from every key;
    * ``"model_state"`` -- a leading ``"model."`` is stripped from every key;
    * otherwise the loaded object itself is returned unchanged.

    Args:
        checkpoint_path: path to the checkpoint file.
        use_ema: prefer the ``"state_dict_ema"`` entry when present.

    Returns:
        The (possibly key-renamed) state dict.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is empty or is not an existing file.
    """
    if not (checkpoint_path and os.path.isfile(checkpoint_path)):
        logging.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint_path))

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict_key = "state_dict"
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        if use_ema and "state_dict_ema" in checkpoint:
            state_dict_key = "state_dict_ema"
        if state_dict_key in checkpoint:
            # Match the full "module." prefix (dot included) so keys such as
            # "module_x.weight" are not mangled.
            state_dict = _strip_key_prefix(checkpoint[state_dict_key], "module.")
        elif "model_state" in checkpoint:
            state_dict_key = "model_state"
            # Same reasoning: require the dot so "modelx" stays intact.
            state_dict = _strip_key_prefix(checkpoint[state_dict_key], "model.")
    logging.info(
        "Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)
    )
    return state_dict
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
    """Read the checkpoint at ``checkpoint_path`` and load it into ``model``.

    The weights are obtained via ``load_state_dict`` (optionally preferring the
    EMA copy) and passed to ``model.load_state_dict`` with the given ``strict``
    flag. Nothing is returned.
    """
    weights = load_state_dict(checkpoint_path, use_ema)
    model.load_state_dict(weights, strict=strict)
| # def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): | |
| # resume_epoch = None | |
| # if os.path.isfile(checkpoint_path): | |
| # checkpoint = torch.load(checkpoint_path, map_location='cpu') | |
| # if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: | |
| # if log_info: | |
| # _logger.info('Restoring model state from checkpoint...') | |
| # new_state_dict = OrderedDict() | |
| # for k, v in checkpoint['state_dict'].items(): | |
| # name = k[7:] if k.startswith('module') else k | |
| # new_state_dict[name] = v | |
| # model.load_state_dict(new_state_dict) | |
| # if optimizer is not None and 'optimizer' in checkpoint: | |
| # if log_info: | |
| # _logger.info('Restoring optimizer state from checkpoint...') | |
| # optimizer.load_state_dict(checkpoint['optimizer']) | |
| # if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: | |
| # if log_info: | |
| # _logger.info('Restoring AMP loss scaler state from checkpoint...') | |
| # loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) | |
| # if 'epoch' in checkpoint: | |
| # resume_epoch = checkpoint['epoch'] | |
| # if 'version' in checkpoint and checkpoint['version'] > 1: | |
| # resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save | |
| # if log_info: | |
| # _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) | |
| # else: | |
| # model.load_state_dict(checkpoint) | |
| # if log_info: | |
| # _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) | |
| # return resume_epoch | |
| # else: | |
| # _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) | |
| # raise FileNotFoundError() | |
def _adapt_first_conv(state_dict, conv1_name, in_chans):
    """Rewrite (or delete) the pretrained first-conv weight in place for ``in_chans`` != 3."""
    weight_key = conv1_name + ".weight"
    if in_chans == 1:
        logging.info(
            "Converting first conv (%s) pretrained weights from 3 to 1 channel"
            % conv1_name
        )
    conv1_weight = state_dict[weight_key]
    conv1_type = conv1_weight.dtype
    # Do the arithmetic in float32, cast back to the checkpoint dtype at the end.
    conv1_weight = conv1_weight.float()
    O, I, J, K = conv1_weight.shape
    if in_chans == 1:
        if I > 3:
            # For models with space2depth stems: sum each group of 3 input channels.
            assert conv1_weight.shape[1] % 3 == 0
            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
        else:
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
    elif I != 3:
        logging.warning(
            "Deleting first conv (%s) from pretrained weights." % conv1_name
        )
        del state_dict[weight_key]
        return
    else:
        logging.info(
            "Repeating first conv (%s) weights in channel dim." % conv1_name
        )
        repeat = int(math.ceil(in_chans / 3))
        conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
        # Rescale so the sum over input channels stays approximately that of
        # the original 3-channel kernel.
        conv1_weight *= 3 / float(in_chans)
    state_dict[weight_key] = conv1_weight.to(conv1_type)


def _init_temporal_from_spatial(state_dict):
    """Copy each block's ``attn``/``norm1`` entry to ``temporal_attn``/``temporal_norm1`` if that key is absent."""
    new_state_dict = state_dict.copy()
    for key, value in state_dict.items():
        # Skip keys that are already temporal; otherwise "temporal_attn" (which
        # contains "attn") would yield bogus "temporal_temporal_attn" entries.
        if "blocks" not in key or "temporal_" in key:
            continue
        if "attn" in key:
            new_state_dict.setdefault(key.replace("attn", "temporal_attn"), value)
        if "norm1" in key:
            new_state_dict.setdefault(key.replace("norm1", "temporal_norm1"), value)
    return new_state_dict


def load_pretrained(
    model,
    cfg=None,
    num_classes=1000,
    in_chans=3,
    filter_fn=None,
    img_size=224,
    num_frames=8,
    num_patches=196,
    attention_type="divided_space_time",
    pretrained_model="",
    strict=True,
):
    """Load pretrained weights into ``model`` after adapting them to its configuration.

    Weights come from the local checkpoint ``pretrained_model`` when it is
    non-empty (using its ``"model"`` entry if present), otherwise from
    ``cfg["url"]`` via ``torch.utils.model_zoo``. The state dict is then
    adapted: optional ``filter_fn``; first-conv input channels; classifier
    rows; ``pos_embed`` / ``time_embed`` resizing; and, for
    ``attention_type == "divided_space_time"``, temporal attention / norm
    entries initialised from the spatial ones.

    Args:
        model: module to load into. ``model.default_cfg`` is used when ``cfg`` is None.
        cfg: dict providing ``"url"``; ``"classifier"`` is always read,
            ``"first_conv"`` when ``in_chans != 3`` and ``"num_classes"`` when
            ``num_classes == 1000``.
        num_classes: number of output classes of ``model``.
        in_chans: number of input channels of ``model``.
        filter_fn: optional callable applied to the loaded state dict.
        img_size: unused; kept for API compatibility.
        num_frames: target length of ``time_embed`` (dim 1).
        num_patches: target number of non-first tokens in ``pos_embed``.
        attention_type: "divided_space_time" enables temporal weight init.
        pretrained_model: optional path to a local checkpoint.
        strict: unused -- the final load is always non-strict (original behavior).

    Returns:
        None. Returns early (with a warning) if no usable ``cfg["url"]`` exists.
    """
    if cfg is None:
        cfg = getattr(model, "default_cfg", None)
    if cfg is None or "url" not in cfg or not cfg["url"]:
        logging.warning("Pretrained model URL is invalid, using random initialization.")
        return

    if pretrained_model:
        # Load once; some checkpoints nest the weights under a "model" key.
        checkpoint = load_state_dict(pretrained_model)
        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
    else:
        logging.info(f"Loading pretrained weights from {cfg['url']}.")
        state_dict = model_zoo.load_url(cfg["url"], progress=False, map_location="cpu")

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans != 3:
        _adapt_first_conv(state_dict, cfg["first_conv"], in_chans)

    classifier_name = cfg["classifier"]
    classifier_weight_key = classifier_name + ".weight"
    classifier_bias_key = classifier_name + ".bias"
    if num_classes == 1000 and cfg["num_classes"] == 1001:
        # special case for imagenet trained models with extra background class in pretrained weights
        state_dict[classifier_weight_key] = state_dict[classifier_weight_key][1:]
        state_dict[classifier_bias_key] = state_dict[classifier_bias_key][1:]
    elif (
        classifier_weight_key in state_dict
        and num_classes != state_dict[classifier_weight_key].size(0)
    ):
        # Discard the pretrained classifier when its output size differs from the model's.
        del state_dict[classifier_weight_key]
        state_dict.pop(classifier_bias_key, None)

    ## Resizing the positional embeddings in case they don't match
    if "pos_embed" in state_dict and num_patches + 1 != state_dict["pos_embed"].size(1):
        state_dict["pos_embed"] = resize_spatial_embedding(
            state_dict, "pos_embed", num_patches
        )

    ## Resizing time embeddings in case they don't match
    if "time_embed" in state_dict and num_frames != state_dict["time_embed"].size(1):
        state_dict["time_embed"] = resize_temporal_embedding(
            state_dict, "time_embed", num_frames
        )

    ## Initializing temporal attention
    if attention_type == "divided_space_time":
        state_dict = _init_temporal_from_spatial(state_dict)

    ## Loading the weights (always non-strict; `strict` is intentionally not used here)
    model.load_state_dict(state_dict, strict=False)
def load_pretrained_imagenet(
    model,
    pretrained_model,
    cfg=None,
    ignore_classifier=True,
    num_frames=8,
    num_patches=196,
    **kwargs,
):
    """Initialise ``model`` from timm's ImageNet ``vit_base_patch16_224`` weights.

    The timm model is built with ``pretrained=True`` and its ``head.weight`` /
    ``head.bias`` entries are dropped. Each ``blocks.*`` entry containing
    ``attn`` / ``norm1`` is also copied to the matching ``temporal_attn`` /
    ``temporal_norm1`` key (unless it already exists). Only entries whose name
    AND shape match ``model`` are loaded (non-strict); key differences and
    shape mismatches are reported via ``logging.info``.

    NOTE(review): ``pretrained_model``, ``cfg``, ``ignore_classifier``,
    ``num_frames``, ``num_patches`` and ``kwargs`` are not used here; they
    presumably exist so this shares a call signature with the other loaders.
    Entries whose shape differs (e.g. ``pos_embed``) are skipped, not resized.
    """
    import timm

    logging.info(f"Loading vit_base_patch16_224 checkpoints.")
    loaded_state_dict = timm.models.vision_transformer.vit_base_patch16_224(
        pretrained=True
    ).state_dict()
    # Drop the ImageNet classification head.
    del loaded_state_dict["head.weight"]
    del loaded_state_dict["head.bias"]

    ## Initializing temporal attention from the spatial weights
    new_state_dict = loaded_state_dict.copy()
    for key, value in loaded_state_dict.items():
        if "blocks" not in key or "temporal_" in key:
            continue
        if "attn" in key:
            new_state_dict.setdefault(key.replace("attn", "temporal_attn"), value)
        if "norm1" in key:
            new_state_dict.setdefault(key.replace("norm1", "temporal_norm1"), value)
    loaded_state_dict = new_state_dict

    # Build the model's state dict once: state_dict() reconstructs the whole
    # mapping on every call, so calling it per key was quadratic.
    model_state_dict = model.state_dict()
    loaded_keys = loaded_state_dict.keys()
    model_keys = model_state_dict.keys()

    load_not_in_model = [k for k in loaded_keys if k not in model_keys]
    model_not_in_load = [k for k in model_keys if k not in loaded_keys]

    toload = {}
    mismatched_shape_keys = []
    for k in model_keys:
        if k not in loaded_keys:
            continue
        if model_state_dict[k].shape != loaded_state_dict[k].shape:
            mismatched_shape_keys.append(k)
        else:
            toload[k] = loaded_state_dict[k]

    logging.info("Keys in loaded but not in model:")
    logging.info(f"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}")
    logging.info("Keys in model but not in loaded:")
    logging.info(f"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}")
    logging.info("Keys in model and loaded, but shape mismatched:")
    logging.info(
        f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}"
    )

    model.load_state_dict(toload, strict=False)
def load_pretrained_kinetics(
    model,
    pretrained_model,
    cfg=None,
    ignore_classifier=True,
    num_frames=8,
    num_patches=196,
    **kwargs,
):
    """Load a Kinetics-pretrained checkpoint from ``pretrained_model`` into ``model``.

    The checkpoint's classifier entries (``cfg["classifier"]`` + ``.weight`` /
    ``.bias``) are replaced with the model's current ones, ``pos_embed`` and
    ``time_embed`` are resized when their lengths differ from ``num_patches + 1``
    and ``num_frames``, and the result is loaded strictly. A failed load is
    logged with its traceback rather than raised (best-effort by design).

    Args:
        model: module to load into; ``model.default_cfg`` is used when ``cfg`` is None.
        pretrained_model: path to the checkpoint; must be non-empty.
        cfg: dict providing ``"url"`` and ``"classifier"``.
        ignore_classifier: must be True; False raises NotImplementedError.
        num_frames: target length of ``time_embed`` (dim 1).
        num_patches: target number of non-first tokens in ``pos_embed``.
        **kwargs: unused.

    Returns:
        None. Returns early (with a warning) if no usable ``cfg["url"]`` exists.
    """
    if cfg is None:
        cfg = getattr(model, "default_cfg", None)
    if cfg is None or "url" not in cfg or not cfg["url"]:
        logging.warning("Pretrained model URL is invalid, using random initialization.")
        return

    assert pretrained_model, "Path to pre-trained Kinetics weights not provided."
    state_dict = load_state_dict(pretrained_model)

    classifier_name = cfg["classifier"]
    if ignore_classifier:
        # Overwrite the checkpoint's classifier with the model's current one so
        # the strict load below does not fail on a different number of classes.
        model_state_dict = model.state_dict()
        for suffix in (".weight", ".bias"):
            key = classifier_name + suffix
            state_dict[key] = model_state_dict[key]
    else:
        raise NotImplementedError(
            "[dxli] Not supporting loading Kinetics-pretrained ckpt with classifier."
        )

    ## Resizing the positional embeddings in case they don't match
    if "pos_embed" in state_dict and num_patches + 1 != state_dict["pos_embed"].size(1):
        state_dict["pos_embed"] = resize_spatial_embedding(
            state_dict, "pos_embed", num_patches
        )

    ## Resizing time embeddings in case they don't match
    if "time_embed" in state_dict and num_frames != state_dict["time_embed"].size(1):
        state_dict["time_embed"] = resize_temporal_embedding(
            state_dict, "time_embed", num_frames
        )

    ## Loading the weights
    try:
        model.load_state_dict(state_dict, strict=True)
    except Exception:
        # Best-effort: keep going, but record why loading failed.
        logging.exception("Error in loading Kinetics pre-trained weights.")
    else:
        logging.info("Succeeded in loading Kinetics pre-trained weights.")
def resize_spatial_embedding(state_dict, key, num_patches):
    """Return ``state_dict[key]`` resized to ``num_patches + 1`` tokens along dim 1.

    The embedding is indexed as ``[batch, token, channel]`` and only batch
    index 0 is used. Token 0 is kept as-is; tokens 1.. are resampled to
    ``num_patches`` entries with 1-D nearest-neighbour interpolation. The
    result has shape ``(1, num_patches + 1, channels)``.
    """
    embedding = state_dict[key]
    logging.info(
        f"Resizing spatial position embedding from {embedding.size(1)} to {num_patches + 1}"
    )
    first_token = embedding[0, 0, :][None, None, :]
    # (1, tokens, channels) -> (1, channels, tokens) so interpolate runs over tokens.
    remaining = embedding[0, 1:, :][None].permute(0, 2, 1)
    resampled = F.interpolate(remaining, size=(num_patches), mode="nearest")
    return torch.cat([first_token, resampled.permute(0, 2, 1)], dim=1)
def resize_temporal_embedding(state_dict, key, num_frames):
    """Return ``state_dict[key]`` resampled to ``num_frames`` entries along dim 1.

    Dims 1 and 2 are swapped so ``F.interpolate`` (nearest mode) works along
    the original dim 1, then swapped back.
    """
    current = state_dict[key]
    logging.info(
        f"Resizing temporal position embedding from {current.size(1)} to {num_frames}"
    )
    swapped = torch.transpose(current, 1, 2)
    resampled = F.interpolate(swapped, size=(num_frames), mode="nearest")
    return torch.transpose(resampled, 1, 2)
def detach_variable(inputs):
    """Detach every tensor in a tuple while preserving its ``requires_grad`` flag.

    Args:
        inputs: tuple of tensors.

    Returns:
        A new tuple of tensors detached from the autograd graph, each with the
        same ``requires_grad`` value as the corresponding input.

    Raises:
        RuntimeError: if ``inputs`` is not a tuple.
    """
    if not isinstance(inputs, tuple):
        # Format the message as one string (previously two positional args,
        # which rendered as a tuple repr).
        raise RuntimeError(
            "Only tuple of tensors is supported. "
            f"Got unsupported input type: {type(inputs).__name__}"
        )
    return tuple(inp.detach().requires_grad_(inp.requires_grad) for inp in inputs)
def check_backward_validity(inputs):
    """Emit a warning when none of ``inputs`` has ``requires_grad`` set.

    Stops scanning at the first input whose ``requires_grad`` is truthy; if no
    such input exists (including when ``inputs`` is empty) a ``warnings.warn``
    message is issued.
    """
    for candidate in inputs:
        if candidate.requires_grad:
            return
    warnings.warn(
        "None of the inputs have requires_grad=True. Gradients will be None"
    )