# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import torch
import torch.nn as nn
# Index of the last backbone parameter (in named_parameters() order) for each
# supported architecture; used by freeze_backbone when the model does not
# expose an explicit `backbone` attribute.
layers_position = {
    'PoseResNet_50': 158,
    'PoseResNet_101': 311,
    'PoseEfficientNet_B4': 415,
}
def preset_model(cfg, model, optimizer=None):
    # Load the model from the config; make sure the pretrained path in the
    # config matches the model name.
    start_epoch = 0
    if 'pretrained' in cfg.TRAIN and os.path.isfile(cfg.TRAIN.pretrained):
        model, optimizer, start_epoch = load_model(model,
                                                   cfg.TRAIN.pretrained,
                                                   optimizer=optimizer,
                                                   resume=cfg.TRAIN.resume,
                                                   lr=cfg.TRAIN.lr,
                                                   lr_step=cfg.TRAIN.lr_scheduler.milestones,
                                                   gamma=cfg.TRAIN.lr_scheduler.gamma)
    else:
        model.init_weights(**cfg.MODEL.INIT_WEIGHTS)
    print('Loaded model successfully -- {}'.format(cfg.MODEL.type))
    # Freeze the backbone while start_epoch is still inside the warm-up phase.
    if cfg.TRAIN.freeze_backbone and start_epoch < cfg.TRAIN.warm_up:
        freeze_backbone(cfg.MODEL, model)
    print('Number of parameters:', sum(p.numel() for p in model.parameters()))
    print('Number of trainable parameters:', sum(p.numel() for p in model.parameters() if p.requires_grad))
    return model, optimizer, start_epoch
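# Typical call site (a minimal sketch; `get_config` and `build_model` are
# hypothetical stand-ins for this repo's config loader and model factory):
#
#   cfg = get_config('configs/pose_resnet_50.yaml')
#   model = build_model(cfg.MODEL)
#   optimizer = torch.optim.Adam(model.parameters(), lr=cfg.TRAIN.lr)
#   model, optimizer, start_epoch = preset_model(cfg, model, optimizer)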
def load_pretrained(model, weight_path):
    '''
    Load only the model's state dict from a checkpoint, strictly.
    For optimizer state and training resumption, see load_model.
    '''
    # map_location keeps checkpoints saved on GPU loadable on CPU-only machines.
    state_dict = torch.load(weight_path, map_location=lambda storage, loc: storage)['state_dict']
    model.load_state_dict(state_dict, strict=True)
    return model
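# Example (sketch; the checkpoint path is hypothetical). Note the contrast with
# load_model below: load_pretrained requires an exact key/shape match
# (strict=True), while load_model patches mismatches and loads with strict=False.
#
#   model = load_pretrained(model, 'weights/pose_resnet_50.pth')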
def freeze_backbone(cfg, model):
    '''
    Freeze the backbone layers so the rest of the model can warm up.
    '''
    if hasattr(model, 'backbone'):
        backbone = model.backbone
        for param in backbone.parameters():
            param.requires_grad = False
    else:
        # No explicit backbone attribute: freeze every parameter up to the
        # architecture's last backbone index from layers_position.
        for i, (n, p) in enumerate(model.named_parameters()):
            if i <= layers_position[f'{cfg.type}_{cfg.num_layers}']:
                p.requires_grad = False
def unfreeze_backbone(model):
    '''
    Unfreeze all model layers.
    '''
    for param in model.parameters():
        if not param.requires_grad:
            param.requires_grad = True
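# Typical warm-up flow (sketch; `train_one_epoch` and `num_epochs` are
# hypothetical names for this repo's training loop):
#
#   for epoch in range(start_epoch, num_epochs):
#       if epoch == cfg.TRAIN.warm_up:
#           unfreeze_backbone(model)  # backbone trains only after warm-up
#       train_one_epoch(model, optimizer, epoch)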
def load_model(model, model_path, optimizer=None, resume=False,
               lr=None, lr_step=None, gamma=None):
    start_epoch = 0
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
    state_dict_ = checkpoint['state_dict']
    state_dict = {}
    # Strip the 'module.' prefix that nn.DataParallel adds to parameter names.
    for k in state_dict_:
        if k.startswith('module') and not k.startswith('module_list'):
            state_dict[k[7:]] = state_dict_[k]
        else:
            state_dict[k] = state_dict_[k]
    model_state_dict = model.state_dict()
    # Reconcile the loaded parameters with the created model's parameters.
    msg = 'If you see this, your model does not fully load the ' + \
          'pre-trained weight. Please make sure ' + \
          'you have correctly specified --arch xxx ' + \
          'or set the correct --num_classes for your own dataset.'
    for k in state_dict:
        if k in model_state_dict:
            if state_dict[k].shape != model_state_dict[k].shape:
                print('Skip loading parameter {}, required shape {}, '
                      'loaded shape {}. {}'.format(
                          k, model_state_dict[k].shape, state_dict[k].shape, msg))
                state_dict[k] = model_state_dict[k]
        else:
            print('Drop parameter {}. {}'.format(k, msg))
    for k in model_state_dict:
        if k not in state_dict:
            print('No param {}. {}'.format(k, msg))
            state_dict[k] = model_state_dict[k]
    model.load_state_dict(state_dict, strict=False)
    # Resume optimizer state and recompute the learning rate for start_epoch.
    if optimizer is not None and resume:
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch'] + 1
            start_lr = lr
            for step in lr_step:
                if start_epoch >= step:
                    start_lr *= gamma
            for param_group in optimizer.param_groups:
                param_group['lr'] = start_lr
            print('Resumed optimizer with start lr', start_lr)
        else:
            print('No optimizer parameters in checkpoint.')
    return model, optimizer, start_epoch
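# Worked example of the resumed-LR arithmetic above: with lr=1e-3,
# lr_step=[90, 120], gamma=0.1 and a checkpoint saved at epoch 99,
# start_epoch becomes 100; one milestone (90) has passed, so the
# optimizer resumes at 1e-3 * 0.1 = 1e-4.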
def save_model(path, epoch, model, optimizer=None):
    # Unwrap DataParallel/DistributedDataParallel so the saved keys carry no
    # 'module.' prefix.
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    data = {'epoch': epoch,
            'state_dict': state_dict}
    if optimizer is not None:
        data['optimizer'] = optimizer.state_dict()
    torch.save(data, path)
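# Checkpoint round trip (sketch; the path and hyperparameters are examples):
#
#   save_model('checkpoints/epoch_10.pth', 10, model, optimizer)
#   model, optimizer, start_epoch = load_model(
#       model, 'checkpoints/epoch_10.pth', optimizer=optimizer,
#       resume=True, lr=1e-3, lr_step=[90, 120], gamma=0.1)
#   assert start_epoch == 11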