# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import torch
import torch.nn as nn

# Index (in named_parameters order) of the last backbone parameter for each
# supported architecture; used by freeze_backbone when the model has no
# explicit `backbone` attribute.
layers_position = {
    'PoseResNet_50': 158,
    'PoseResNet_101': 311,
    'PoseEfficientNet_B4': 415,
}


def preset_model(cfg, model, optimizer=None):
    # Load the model from the config; make sure the pretrained path matches the model name.
    start_epoch = 0
    if 'pretrained' in cfg.TRAIN and os.path.isfile(cfg.TRAIN.pretrained):
        model, optimizer, start_epoch = load_model(
            model, cfg.TRAIN.pretrained, optimizer=optimizer,
            resume=cfg.TRAIN.resume, lr=cfg.TRAIN.lr,
            lr_step=cfg.TRAIN.lr_scheduler.milestones,
            gamma=cfg.TRAIN.lr_scheduler.gamma)
    else:
        model.init_weights(**cfg.MODEL.INIT_WEIGHTS)
    print('Loaded model successfully -- {}'.format(cfg.MODEL.type))

    # Freeze the backbone while start_epoch is still inside the warm-up phase.
    if cfg.TRAIN.freeze_backbone and start_epoch < cfg.TRAIN.warm_up:
        freeze_backbone(cfg.MODEL, model)

    print('Number of parameters:', sum(p.numel() for p in model.parameters()))
    print('Number of trainable parameters:',
          sum(p.numel() for p in model.parameters() if p.requires_grad))
    return model, optimizer, start_epoch


def load_pretrained(model, weight_path):
    '''
    This function only handles the state dict of the model.
    For other components such as the optimizer or resuming training, refer to @load_model.
    '''
    state_dict = torch.load(weight_path)['state_dict']
    model.load_state_dict(state_dict, strict=True)
    return model


def freeze_backbone(cfg, model):
    '''
    Freeze the backbone layers to warm up the model.
    '''
    if hasattr(model, 'backbone'):
        for param in model.backbone.parameters():
            param.requires_grad = False
    else:
        # Fall back to freezing by parameter index when no `backbone` attribute exists.
        for i, (n, p) in enumerate(model.named_parameters()):
            if i <= layers_position[f'{cfg.type}_{cfg.num_layers}']:
                p.requires_grad = False


def unfreeze_backbone(model):
    '''
    Unfreeze all model layers.
    '''
    for param in model.parameters():
        if not param.requires_grad:
            param.requires_grad = True


def load_model(model, model_path, optimizer=None, resume=False,
               lr=None, lr_step=None, gamma=None):
    start_epoch = 0
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    print('Loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
    state_dict_ = checkpoint['state_dict']
    state_dict = {}

    # Convert a DataParallel state dict (keys prefixed with 'module.') to a plain model state dict.
    for k in state_dict_:
        if k.startswith('module') and not k.startswith('module_list'):
            state_dict[k[7:]] = state_dict_[k]
        else:
            state_dict[k] = state_dict_[k]
    model_state_dict = model.state_dict()

    # Check loaded parameters against the created model's parameters.
    msg = 'If you see this, your model does not fully load the ' + \
          'pre-trained weight. Please make sure ' + \
          'you have correctly specified --arch xxx ' + \
          'or set the correct --num_classes for your own dataset.'
    for k in state_dict:
        if k in model_state_dict:
            if state_dict[k].shape != model_state_dict[k].shape:
                print('Skip loading parameter {}, required shape {}, '
                      'loaded shape {}. {}'.format(
                          k, model_state_dict[k].shape, state_dict[k].shape, msg))
                state_dict[k] = model_state_dict[k]
        else:
            print('Drop parameter {}. '.format(k) + msg)
    for k in model_state_dict:
        if k not in state_dict:
            print('No param {}. '.format(k) + msg)
            state_dict[k] = model_state_dict[k]
    model.load_state_dict(state_dict, strict=False)

    # Resume optimizer state and recompute the learning rate for the resumed epoch.
    if optimizer is not None and resume:
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch'] + 1
            start_lr = lr
            for step in lr_step:
                if start_epoch >= step:
                    start_lr *= gamma
            for param_group in optimizer.param_groups:
                param_group['lr'] = start_lr
            print('Resumed optimizer with start lr', start_lr)
        else:
            print('No optimizer parameters in checkpoint.')
    return model, optimizer, start_epoch


def save_model(path, epoch, model, optimizer=None):
    if isinstance(model, torch.nn.DataParallel):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    data = {'epoch': epoch, 'state_dict': state_dict}
    if optimizer is not None:
        data['optimizer'] = optimizer.state_dict()
    torch.save(data, path)
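

# Illustrative sketch (not part of the original training pipeline): a minimal
# round trip through save_model -> load_model with a toy module and optimizer,
# showing the checkpoint format these helpers expect. The TinyNet class, the
# checkpoint path, and the milestone/gamma values below are hypothetical and
# chosen only for demonstration.
if __name__ == '__main__':
    class TinyNet(nn.Module):
        def __init__(self):
            super(TinyNet, self).__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    net = TinyNet()
    opt = torch.optim.SGD(net.parameters(), lr=0.01)

    # Save at epoch 4, then resume. With milestones [3, 8] and gamma 0.1 the
    # restored lr becomes 0.01 * 0.1 = 0.001, since start_epoch (5) >= 3 but < 8.
    save_model('tiny_checkpoint.pth', 4, net, optimizer=opt)
    net, opt, start_epoch = load_model(
        TinyNet(), 'tiny_checkpoint.pth', optimizer=opt, resume=True,
        lr=0.01, lr_step=[3, 8], gamma=0.1)
    print('Resumed at epoch', start_epoch)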