| import os | |
| import sys | |
| import time | |
| import random | |
| import argparse | |
| from collections import OrderedDict, defaultdict | |
| import torch | |
| import torch.utils.model_zoo as model_zoo | |
# Download URLs of the ResNet checkpoint files hosted on download.pytorch.org,
# keyed by architecture name (insertion order: resnet18 .. resnet152).
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
)
def load_model(model, model_file, is_restore=False, map_location='cpu'):
    """Load weights from a checkpoint into ``model`` (non-strict) and return it.

    Args:
        model: The ``torch.nn.Module`` to populate; it is modified in place.
        model_file: One of
            * ``None`` -- nothing is loaded and ``model`` is returned as-is;
            * a ``str`` or ``os.PathLike`` path -- read with ``torch.load``;
              if the loaded mapping has a ``'model'`` entry, that entry is
              used as the state dict;
            * any other mapping -- used directly as the state dict.
        is_restore: If true, prefix every checkpoint key with ``'module.'``
            so the keys match a model whose parameter names carry that
            prefix (e.g. a model wrapped in ``torch.nn.DataParallel``).
        map_location: Forwarded to ``torch.load`` when ``model_file`` is a
            path. Defaults to ``'cpu'`` so loading neither requires nor
            allocates on the device the checkpoint was saved from;
            ``load_state_dict`` copies the values into the model's existing
            parameters wherever they live.

    Returns:
        The same ``model`` object.
    """
    import logging

    logger = logging.getLogger(__name__)
    if model_file is None:
        return model

    t_start = time.perf_counter()
    if isinstance(model_file, (str, os.PathLike)):
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(model_file, map_location=map_location)
        # Training checkpoints commonly bundle the weights under 'model'.
        if 'model' in state_dict:
            state_dict = state_dict['model']
    else:
        state_dict = model_file
    t_ioend = time.perf_counter()

    if is_restore:
        # Add the 'module.' prefix used by wrapper modules such as DataParallel.
        state_dict = OrderedDict(
            ('module.' + key, value) for key, value in state_dict.items()
        )

    # strict=False skips mismatched keys instead of raising, so report them
    # explicitly below rather than letting a bad checkpoint load silently.
    model.load_state_dict(state_dict, strict=False)

    ckpt_keys = set(state_dict.keys())
    own_keys = set(model.state_dict().keys())
    missing_keys = own_keys - ckpt_keys
    unexpected_keys = ckpt_keys - own_keys
    if missing_keys:
        logger.warning('Missing keys in checkpoint: %s',
                       ', '.join(sorted(missing_keys)))
    if unexpected_keys:
        logger.warning('Unexpected keys in checkpoint: %s',
                       ', '.join(sorted(unexpected_keys)))

    # Drop our reference to the (possibly large) checkpoint before returning.
    del state_dict
    t_end = time.perf_counter()
    logger.debug('Loaded model weights: IO %.2fs, load_state_dict %.2fs',
                 t_ioend - t_start, t_end - t_ioend)
    return model