# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import torch
import torch.nn as nn

# Index (in named_parameters() order) of the last parameter frozen during warm-up
# for each supported architecture; keys follow the '{cfg.type}_{cfg.num_layers}'
# pattern used in freeze_backbone below.
layers_position = {
    'PoseResNet_50': 158,
    'PoseResNet_101': 311,
    'PoseEfficientNet_B4': 415,
}

def preset_model(cfg, model, optimizer=None):
    # Load the model described by the config; make sure the pretrained path matches the model name
    start_epoch = 0
    if 'pretrained' in cfg.TRAIN and os.path.isfile(cfg.TRAIN.pretrained):
        model, optimizer, start_epoch = load_model(model,
                                                   cfg.TRAIN.pretrained,
                                                   optimizer=optimizer,
                                                   resume=cfg.TRAIN.resume,
                                                   lr=cfg.TRAIN.lr,
                                                   lr_step=cfg.TRAIN.lr_scheduler.milestones,
                                                   gamma=cfg.TRAIN.lr_scheduler.gamma)
    else:
        model.init_weights(**cfg.MODEL.INIT_WEIGHTS)
    print('Loaded model successfully -- {}'.format(cfg.MODEL.type))

    # Freeze the backbone if the starting epoch is still inside the warm-up period
    if cfg.TRAIN.freeze_backbone and start_epoch < cfg.TRAIN.warm_up:
        freeze_backbone(cfg.MODEL, model)

    print('Number of parameters:', sum(p.numel() for p in model.parameters()))
    print('Number of trainable parameters:', sum(p.numel() for p in model.parameters() if p.requires_grad))
    return model, optimizer, start_epoch
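
# Usage sketch (illustrative only): preset_model expects an attribute-style cfg
# with TRAIN and MODEL sections; the field names below follow the calls above,
# but the EasyDict wrapper and the concrete values are assumptions, not part of
# this module.
#
#   from easydict import EasyDict
#   cfg = EasyDict({
#       'MODEL': {'type': 'PoseResNet', 'num_layers': 50, 'INIT_WEIGHTS': {}},
#       'TRAIN': {'pretrained': 'weights/pose_resnet_50.pth', 'resume': True,
#                 'lr': 1e-3, 'freeze_backbone': True, 'warm_up': 5,
#                 'lr_scheduler': {'milestones': [90, 120], 'gamma': 0.1}},
#   })
#   model, optimizer, start_epoch = preset_model(cfg, model, optimizer)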

def load_pretrained(model, weight_path):
    '''
    Load only the model's state dict from a checkpoint (strict key matching).
    For restoring the optimizer or resuming training, see load_model below.
    '''
    state_dict = torch.load(weight_path)['state_dict']
    model.load_state_dict(state_dict, strict=True)
    return model

def freeze_backbone(cfg, model):
    '''
    Freeze the backbone layers to warm up the model (only the head is trained).
    '''
    if hasattr(model, 'backbone'):
        backbone = model.backbone
        for param in backbone.parameters():
            param.requires_grad = False
    else:
        # Fall back to freezing the first N parameters, where N is looked up in
        # layers_position by architecture name (e.g. 'PoseResNet_50').
        for i, (n, p) in enumerate(model.named_parameters()):
            if i <= layers_position[f'{cfg.type}_{cfg.num_layers}']:
                p.requires_grad = False

def unfreeze_backbone(model):
    '''
    Unfreeze every frozen parameter so the whole model trains again.
    '''
    for param in model.parameters():
        if not param.requires_grad:
            param.requires_grad = True
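
# Warm-up sketch (illustrative only; train_one_epoch, dataloader, and
# cfg.TRAIN.end_epoch are assumptions): freeze the backbone for the first
# warm-up epochs, then unfreeze it for the rest of training.
#
#   freeze_backbone(cfg.MODEL, model)
#   for epoch in range(start_epoch, cfg.TRAIN.end_epoch):
#       if epoch == cfg.TRAIN.warm_up:
#           unfreeze_backbone(model)
#       train_one_epoch(model, dataloader, optimizer)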

def load_model(model, model_path, optimizer=None, resume=False,
               lr=None, lr_step=None, gamma=None):
    start_epoch = 0
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    print('Loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
    state_dict_ = checkpoint['state_dict']
    state_dict = {}

    # Strip the 'module.' prefix added by DataParallel so keys match the bare model
    for k in state_dict_:
        if k.startswith('module') and not k.startswith('module_list'):
            state_dict[k[7:]] = state_dict_[k]
        else:
            state_dict[k] = state_dict_[k]
    model_state_dict = model.state_dict()

    # Check loaded parameters against the created model's parameters
    msg = 'If you see this, your model does not fully load the ' + \
          'pre-trained weight. Please make sure ' + \
          'you have correctly specified --arch xxx ' + \
          'or set the correct --num_classes for your own dataset.'
    for k in state_dict:
        if k in model_state_dict:
            if state_dict[k].shape != model_state_dict[k].shape:
                print('Skip loading parameter {}, required shape {}, '
                      'loaded shape {}. {}'.format(
                          k, model_state_dict[k].shape, state_dict[k].shape, msg))
                state_dict[k] = model_state_dict[k]
        else:
            print('Drop parameter {}. {}'.format(k, msg))
    for k in model_state_dict:
        if not (k in state_dict):
            print('No param {}. {}'.format(k, msg))
            state_dict[k] = model_state_dict[k]
    model.load_state_dict(state_dict, strict=False)

    # Resume optimizer state and recompute the learning rate for the resumed epoch
    if optimizer is not None and resume:
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch'] + 1
            start_lr = lr
            for step in lr_step:
                if start_epoch >= step:
                    start_lr *= gamma
            for param_group in optimizer.param_groups:
                param_group['lr'] = start_lr
            print('Resumed optimizer with start lr', start_lr)
        else:
            print('No optimizer parameters in checkpoint.')
    return model, optimizer, start_epoch

def save_model(path, epoch, model, optimizer=None):
    # Unwrap DataParallel so the checkpoint keys carry no 'module.' prefix
    if isinstance(model, torch.nn.DataParallel):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    data = {'epoch': epoch,
            'state_dict': state_dict}
    if optimizer is not None:
        data['optimizer'] = optimizer.state_dict()
    torch.save(data, path)
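
# Checkpoint round-trip sketch (illustrative only; the path, epoch, and scheduler
# values are placeholders): save_model writes the {'epoch', 'state_dict',
# 'optimizer'} dict that load_model / load_pretrained expect to read back.
#
#   save_model('checkpoints/epoch_10.pth', 10, model, optimizer)
#   model, optimizer, start_epoch = load_model(model, 'checkpoints/epoch_10.pth',
#                                              optimizer=optimizer, resume=True,
#                                              lr=1e-3, lr_step=[90, 120], gamma=0.1)
#   # start_epoch is now 11 and the optimizer lr reflects any passed milestones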