Spaces:
Runtime error
Runtime error
| from mmengine.runner.checkpoint import CheckpointLoader | |
def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
    """Load a checkpoint and return the part of its state dict under ``prefix``.

    The checkpoint is loaded with ``CheckpointLoader.load_checkpoint``. If the
    loaded object contains a ``'state_dict'`` entry, that entry is used;
    otherwise the loaded object itself is treated as the state dict.

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        prefix (str | None): The prefix of the sub-module. A trailing ``'.'``
            is appended if missing. When ``None`` or empty, the whole state
            dict is returned unchanged. Defaults to None.
        map_location (str | None): Same as :func:`torch.load`.
            Defaults to ``'cpu'``.
        logger: Logger forwarded to ``CheckpointLoader.load_checkpoint``.
            Defaults to ``'current'``.

    Returns:
        dict or OrderedDict: The whole state dict when ``prefix`` is empty;
        otherwise a new ``dict`` holding only the entries whose keys start
        with ``prefix``, with that prefix stripped from each key.

    Raises:
        AssertionError: If ``prefix`` is given but no key starts with it.
    """
    checkpoint = CheckpointLoader.load_checkpoint(
        filename, map_location=map_location, logger=logger)
    # Some checkpoints nest the weights under 'state_dict'; otherwise the
    # loaded object is assumed to be the state dict itself.
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    if not prefix:
        return state_dict

    # Match whole module-path components: 'backbone' must not match
    # 'backbone2.weight', so require the separating dot.
    if not prefix.endswith('.'):
        prefix += '.'
    prefix_len = len(prefix)
    sub_state_dict = {
        key[prefix_len:]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }

    if not sub_state_dict:
        # Explicit raise instead of ``assert`` so the check is not stripped
        # under ``python -O``; AssertionError is kept so existing callers
        # that catch it keep working.
        raise AssertionError(f'{prefix} is not in the pretrained model')
    return sub_state_dict