|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import OrderedDict |
|
|
|
|
|
from mmengine.runner import CheckpointLoader, load_state_dict |
|
|
|
|
|
|
|
|
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI into ``model``.

    The weights are taken from the ``'state_dict'`` entry if present, else
    from the ``'model'`` entry, else the checkpoint dict itself is used.
    Before loading, one leading ``module.backbone.``, ``module.`` or
    ``backbone.`` prefix is stripped from each parameter name.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to strictly enforce that the (prefix-stripped)
            checkpoint keys match the model's keys. Forwarded unchanged to
            ``load_state_dict``. Default: False.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint as returned by the loader,
        not the prefix-stripped copy that is fed to the model.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
    """
    checkpoint = CheckpointLoader.load_checkpoint(filename, map_location)

    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    if 'state_dict' in checkpoint:
        state_dict_tmp = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict_tmp = checkpoint['model']
    else:
        state_dict_tmp = checkpoint

    # Only the first matching prefix is removed. 'module.backbone.' must be
    # checked before 'module.', which is itself a prefix of it.
    prefixes = ('module.backbone.', 'module.', 'backbone.')
    state_dict = OrderedDict()
    for k, v in state_dict_tmp.items():
        for prefix in prefixes:
            if k.startswith(prefix):
                k = k[len(prefix):]
                break
        state_dict[k] = v

    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
|
|
|
|
|
|
|
|
def get_state_dict(filename, map_location='cpu'):
    """Get state_dict from a file or URI.

    The weights are taken from the ``'state_dict'`` entry if present, else
    from the ``'model'`` entry, else the checkpoint dict itself is used.
    This is the same lookup order as ``load_checkpoint``. One leading
    ``module.backbone.``, ``module.`` or ``backbone.`` prefix is stripped
    from each parameter name.

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``.
        map_location (str): Same as :func:`torch.load`.

    Returns:
        OrderedDict: The state_dict with prefixes stripped from its keys.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
    """
    checkpoint = CheckpointLoader.load_checkpoint(filename, map_location)

    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # Unwrap the weights the same way ``load_checkpoint`` does. Previously a
    # checkpoint storing its weights under 'model' was returned whole.
    if 'state_dict' in checkpoint:
        state_dict_tmp = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict_tmp = checkpoint['model']
    else:
        state_dict_tmp = checkpoint

    # Only the first matching prefix is removed. 'module.backbone.' must be
    # checked before 'module.', which is itself a prefix of it.
    prefixes = ('module.backbone.', 'module.', 'backbone.')
    state_dict = OrderedDict()
    for k, v in state_dict_tmp.items():
        for prefix in prefixes:
            if k.startswith(prefix):
                k = k[len(prefix):]
                break
        state_dict[k] = v

    return state_dict
|
|
|