from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import torch
import torch.nn as nn


# Index (in named_parameters() order) of the last backbone parameter for each
# architecture; used by freeze_backbone() when the model does not expose a
# `backbone` attribute.
layers_position = {
    'PoseResNet_50': 158,
    'PoseResNet_101': 311,
    'PoseEfficientNet_B4': 415,
}


def preset_model(cfg, model, optimizer=None):
    '''
    Load pre-trained weights (or initialize the model from scratch) and
    optionally freeze the backbone for the warm-up epochs.
    '''
    start_epoch = 0
    if 'pretrained' in cfg.TRAIN and os.path.isfile(cfg.TRAIN.pretrained):
        model, optimizer, start_epoch = load_model(model,
                                                   cfg.TRAIN.pretrained,
                                                   optimizer=optimizer,
                                                   resume=cfg.TRAIN.resume,
                                                   lr=cfg.TRAIN.lr,
                                                   lr_step=cfg.TRAIN.lr_scheduler.milestones,
                                                   gamma=cfg.TRAIN.lr_scheduler.gamma)
    else:
        model.init_weights(**cfg.MODEL.INIT_WEIGHTS)
    print('Loaded model successfully -- {}'.format(cfg.MODEL.type))

    if cfg.TRAIN.freeze_backbone and start_epoch < cfg.TRAIN.warm_up:
        freeze_backbone(cfg.MODEL, model)

    print('Number of parameters:', sum(p.numel() for p in model.parameters()))
    print('Number of trainable parameters:',
          sum(p.numel() for p in model.parameters() if p.requires_grad))
    return model, optimizer, start_epoch
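

# Usage sketch (illustrative; the config layout referenced below is an
# assumption, not the project's actual schema):
#
#   model, optimizer, start_epoch = preset_model(cfg, model, optimizer)
#   # cfg.TRAIN.pretrained      -> checkpoint path (the model is initialized
#   #                              via model.init_weights() if the file is missing)
#   # cfg.TRAIN.freeze_backbone -> freeze the backbone while
#   # cfg.TRAIN.warm_up            start_epoch < warm_up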


def load_pretrained(model, weight_path):
    '''
    Load only the model's state dict from a checkpoint.
    For restoring the optimizer or resuming training, see load_model().
    '''
    state_dict = torch.load(weight_path)['state_dict']
    model.load_state_dict(state_dict, strict=True)
    return model
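

# Example (illustrative; the path is a placeholder): restore only the network
# weights from a checkpoint written by save_model(). Because strict=True, this
# raises if any key or shape mismatches.
#
#   model = load_pretrained(model, 'checkpoints/best.pth')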


def freeze_backbone(cfg, model):
    '''
    Freeze the backbone layers to warm up the model.
    '''
    if hasattr(model, 'backbone'):
        for param in model.backbone.parameters():
            param.requires_grad = False
    else:
        # Fall back to freezing by parameter index when the backbone is not
        # exposed as a submodule.
        for i, (_, p) in enumerate(model.named_parameters()):
            if i <= layers_position[f'{cfg.type}_{cfg.num_layers}']:
                p.requires_grad = False


def unfreeze_backbone(model):
    '''
    Unfreeze all model layers.
    '''
    for param in model.parameters():
        param.requires_grad = True
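

# Warm-up sketch (illustrative; `warm_up`, `num_epochs` and train_one_epoch()
# are placeholder names): keep the backbone frozen for the first epochs, then
# release every layer.
#
#   freeze_backbone(cfg.MODEL, model)
#   for epoch in range(num_epochs):
#       if epoch == warm_up:
#           unfreeze_backbone(model)
#       train_one_epoch(model, ...)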


def load_model(model, model_path, optimizer=None, resume=False,
               lr=None, lr_step=None, gamma=None):
    start_epoch = 0
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
    state_dict_ = checkpoint['state_dict']
    state_dict = {}

    # Strip the 'module.' prefix added by DataParallel, if present.
    for k in state_dict_:
        if k.startswith('module') and not k.startswith('module_list'):
            state_dict[k[7:]] = state_dict_[k]
        else:
            state_dict[k] = state_dict_[k]
    model_state_dict = model.state_dict()

    msg = 'If you see this, your model does not fully load the ' + \
          'pre-trained weight. Please make sure ' + \
          'you have correctly specified --arch xxx ' + \
          'or set the correct --num_classes for your own dataset.'
    for k in state_dict:
        if k in model_state_dict:
            if state_dict[k].shape != model_state_dict[k].shape:
                print('Skip loading parameter {}, required shape {}, '
                      'loaded shape {}. {}'.format(
                          k, model_state_dict[k].shape, state_dict[k].shape, msg))
                state_dict[k] = model_state_dict[k]
        else:
            print('Drop parameter {}. {}'.format(k, msg))
    for k in model_state_dict:
        if k not in state_dict:
            print('No param {}. {}'.format(k, msg))
            state_dict[k] = model_state_dict[k]
    model.load_state_dict(state_dict, strict=False)

    if optimizer is not None and resume:
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch'] + 1
            # Re-derive the learning rate: apply gamma once per milestone
            # already passed before the resumed epoch.
            start_lr = lr
            for step in lr_step:
                if start_epoch >= step:
                    start_lr *= gamma
            for param_group in optimizer.param_groups:
                param_group['lr'] = start_lr
            print('Resumed optimizer with start lr', start_lr)
        else:
            print('No optimizer parameters in checkpoint.')
    return model, optimizer, start_epoch
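

# Worked example of the resume schedule above: with lr=1e-3, lr_step=[90, 120],
# gamma=0.1 and a checkpoint saved at epoch 99, start_epoch is 100; only the
# milestone 90 has passed, so the resumed lr is 1e-3 * 0.1 = 1e-4.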


def save_model(path, epoch, model, optimizer=None):
    if isinstance(model, torch.nn.DataParallel):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    data = {'epoch': epoch,
            'state_dict': state_dict}
    if optimizer is not None:
        data['optimizer'] = optimizer.state_dict()
    torch.save(data, path)
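

# Minimal round-trip sketch (illustration only, not part of the training
# pipeline): save a checkpoint with save_model(), then restore it with
# load_model(). Runs only when this file is executed directly.
if __name__ == '__main__':
    demo_net = nn.Linear(4, 2)
    demo_opt = torch.optim.SGD(demo_net.parameters(), lr=0.1)
    save_model('demo_checkpoint.pth', epoch=3, model=demo_net, optimizer=demo_opt)
    demo_net, demo_opt, start = load_model(demo_net, 'demo_checkpoint.pth',
                                           optimizer=demo_opt, resume=True,
                                           lr=0.1, lr_step=[2], gamma=0.5)
    print('Resumed at epoch', start)  # epoch 4; lr rescaled once to 0.05
    os.remove('demo_checkpoint.pth')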