import torch
import numpy as np


def load_model(model, model_path, optimizer=None, resume=False,
               lr=None, lr_step=None):
    """Load a checkpoint into `model`; optionally restore the optimizer state
    and re-apply the step lr schedule when resuming training."""
    start_epoch = 0
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    print(f'loaded {model_path}')
    state_dict = checkpoint['model']
    model_state_dict = model.state_dict()

    # Reconcile the loaded weights with the model's own state dict, warning
    # about shape mismatches (usually caused by a different --num_classes).
    msg = 'If you see this, your model does not fully load the ' \
          'pre-trained weight. Please make sure ' \
          'you set the correct --num_classes for your own dataset.'
    for k in state_dict:
        if k in model_state_dict:
            if state_dict[k].shape != model_state_dict[k].shape:
                print('Skip loading parameter {}, required shape {}, '
                      'loaded shape {}. {}'.format(
                          k, model_state_dict[k].shape, state_dict[k].shape, msg))
                if 'class_embed' in k:
                    print('load class_embed: {} shape={}'.format(k, state_dict[k].shape))
                    # Slice the pre-trained classification head (starting from
                    # class index 1) down to the number of classes the current
                    # model expects, then keep the sliced weights.
                    if model_state_dict[k].shape[0] == 1:
                        state_dict[k] = state_dict[k][1:2]
                    elif model_state_dict[k].shape[0] == 2:
                        state_dict[k] = state_dict[k][1:3]
                    elif model_state_dict[k].shape[0] == 3:
                        state_dict[k] = state_dict[k][1:4]
                    else:
                        raise NotImplementedError(
                            'invalid shape: {}'.format(model_state_dict[k].shape))
                    continue
                # Any other mismatched parameter falls back to the model's
                # randomly initialized weights.
                state_dict[k] = model_state_dict[k]
        else:
            print('Drop parameter {}. '.format(k) + msg)
    for k in model_state_dict:
        if k not in state_dict:
            print('No param {}. '.format(k) + msg)
            state_dict[k] = model_state_dict[k]
    model.load_state_dict(state_dict, strict=False)

    # Resume the optimizer state and re-apply the step-decay lr schedule.
    if optimizer is not None and resume:
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            start_lr = lr
            # Decay the lr once for every milestone the saved epoch has passed.
            for step in lr_step:
                if start_epoch >= step:
                    start_lr *= 0.1
            for param_group in optimizer.param_groups:
                param_group['lr'] = start_lr
            print('Resumed optimizer with start lr', start_lr)
        else:
            print('No optimizer parameters in checkpoint.')
    if optimizer is not None:
        return model, optimizer, start_epoch
    else:
        return model
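

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original code: build a tiny model,
    # save a checkpoint in the format load_model expects
    # ({'model', 'optimizer', 'epoch'}), then reload it with resume=True.
    # The path 'tmp_checkpoint.pth' and the lr / lr_step values below are
    # illustrative assumptions only.
    demo_model = torch.nn.Linear(4, 2)
    demo_optimizer = torch.optim.AdamW(demo_model.parameters(), lr=2e-4)
    torch.save({'model': demo_model.state_dict(),
                'optimizer': demo_optimizer.state_dict(),
                'epoch': 10},
               'tmp_checkpoint.pth')

    demo_model, demo_optimizer, start_epoch = load_model(
        demo_model, 'tmp_checkpoint.pth', optimizer=demo_optimizer,
        resume=True, lr=2e-4, lr_step=[5, 8])
    # With epoch 10 and milestones [5, 8], the lr is decayed twice: 2e-4 -> 2e-6.
    print('resume from epoch', start_epoch,
          'lr', demo_optimizer.param_groups[0]['lr'])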