|
|
import torch |
|
|
from torch._dynamo import OptimizedModule |
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
import numpy |
|
|
|
|
|
|
|
|
def load_pretrained_weights(network, fname, verbose=False):
    """
    Transfers all weights between matching keys in state_dicts. Matching is done by name and we only transfer if the
    shape is also the same. Segmentation layers (the 1x1(x1) layers that produce the segmentation maps,
    identified by keys containing '.seg_layers.') are not transferred!

    If the pretrained weights were obtained with a training outside nnU-Net and DDP or torch.compile was used,
    you need to change the keys of the pretrained state_dict. DDP adds a 'module.' prefix and torch.compile
    (Dynamo) adds '_orig_mod'. You DO NOT need to worry about this if pretraining was done with nnU-Net as
    nnUNetTrainer.save_checkpoint takes care of that!

    :param network: network instance to load the weights into; may be wrapped in DDP and/or OptimizedModule
    :param fname: path to the checkpoint file; the checkpoint must contain a 'network_weights' entry
    :param verbose: if True, print every transferred key and its shape
    :raises AssertionError: if a non-skipped key is missing from the pretrained weights or its shape differs
    """
    # Safelist the numpy objects that nnU-Net checkpoints may contain so torch.load works with
    # weights_only=True (the default since torch 2.6). add_safe_globals only exists in torch >= 2.4;
    # older versions load without a safelist, so skip the call there.
    if hasattr(torch.serialization, 'add_safe_globals'):
        torch.serialization.add_safe_globals([numpy.core.multiarray.scalar, numpy.dtype, numpy.dtypes.Float32DType])
    saved_model = torch.load(fname)
    pretrained_dict = saved_model['network_weights']

    # Keys containing any of these substrings are never transferred (segmentation heads stay untouched).
    skip_strings_in_pretrained = [
        '.seg_layers.',
    ]

    # Unwrap DDP and torch.compile wrappers to reach the underlying module whose state_dict keys
    # match the (unprefixed) keys stored in the checkpoint.
    if isinstance(network, DDP):
        mod = network.module
    else:
        mod = network
    if isinstance(mod, OptimizedModule):
        mod = mod._orig_mod

    model_dict = mod.state_dict()

    # Verify compatibility up front: every non-skipped key of our model must exist in the pretrained
    # weights with an identical shape. Fail loudly rather than silently loading a partial match.
    for key in model_dict:
        if all(i not in key for i in skip_strings_in_pretrained):
            assert key in pretrained_dict, \
                f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " \
                f"compatible with your network."
            # NOTE: report the shape, not the full tensor, on mismatch.
            assert model_dict[key].shape == pretrained_dict[key].shape, \
                f"The shape of the parameters of key {key} is not the same. Pretrained model: " \
                f"{pretrained_dict[key].shape}; your network: {model_dict[key].shape}. The pretrained model " \
                f"does not seem to be compatible with your network."

    # Keep only the transferable keys, then overlay them onto our own state_dict so that skipped
    # keys (seg layers) retain their freshly initialized values.
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict.keys() and all(i not in k for i in skip_strings_in_pretrained)}

    model_dict.update(pretrained_dict)

    print("################### Loading pretrained weights from file ", fname, '###################')
    if verbose:
        print("Below is the list of overlapping blocks in pretrained model and nnUNet architecture:")
        for key, value in pretrained_dict.items():
            print(key, 'shape', value.shape)
    print("################### Done ###################")
    mod.load_state_dict(model_dict)
|
|
|
|
|
|
|
|
|