import nnunet
import torch
from batchgenerators.utilities.file_and_folder_operations import *
import importlib
import pkgutil
from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
|
|
|
|
def recursive_find_python_class(folder, trainer_name, current_module):
    """
    Recursively searches the modules in folder (a list with a single path, as pkgutil expects) for a class
    named trainer_name and returns it. Subpackages are searched as well. Returns None if nothing is found.
    """
    tr = None
    for importer, modname, ispkg in pkgutil.iter_modules(folder):
        if not ispkg:
            m = importlib.import_module(current_module + "." + modname)
            if hasattr(m, trainer_name):
                tr = getattr(m, trainer_name)
                break

    if tr is None:
        # nothing found at this level -> descend into subpackages
        for importer, modname, ispkg in pkgutil.iter_modules(folder):
            if ispkg:
                next_current_module = current_module + "." + modname
                tr = recursive_find_python_class([join(folder[0], modname)], trainer_name,
                                                 current_module=next_current_module)
            if tr is not None:
                break

    return tr
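
# Example (sketch): how the helper above is typically called. This mirrors the lookup performed by
# restore_model below; "nnUNetTrainerV2" is only an illustrative trainer name.
#
#   search_in = join(nnunet.__path__[0], "training", "network_training")
#   trainer_class = recursive_find_python_class([search_in], "nnUNetTrainerV2",
#                                               current_module="nnunet.training.network_training")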
|
|
|
|
def restore_model(pkl_file, checkpoint=None, train=False, fp16=None):
    """
    This is a utility function to load any nnUNet trainer from a pkl file. It will recursively search
    nnunet.training.network_training for the module that contains the trainer and instantiate it with the
    arguments saved in the pkl file. If checkpoint is specified, it will furthermore load the checkpoint file
    in train/test mode (as specified by train).
    The pkl file required here is the one that is saved automatically when calling nnUNetTrainer.save_checkpoint.
    :param pkl_file: path to the .pkl file written next to the checkpoint
    :param checkpoint: path to the .model checkpoint file (optional)
    :param train: whether to load the checkpoint in training mode
    :param fp16: if None then we take no action. If True/False we overwrite what the model has in its init
    :return: the restored trainer instance
    """
    info = load_pickle(pkl_file)
    init = info['init']
    name = info['name']
    search_in = join(nnunet.__path__[0], "training", "network_training")
    tr = recursive_find_python_class([search_in], name, current_module="nnunet.training.network_training")

    if tr is None:
        # "Fabian only": this will trigger searching for trainer classes in other repositories as well
        try:
            import meddec
            search_in = join(meddec.__path__[0], "model_training")
            tr = recursive_find_python_class([search_in], name, current_module="meddec.model_training")
        except ImportError:
            pass

    if tr is None:
        raise RuntimeError("Could not find the model trainer specified in the checkpoint in "
                           "nnunet.training.network_training. If it is not located there, please move it or "
                           "change the code of restore_model. Your model trainer can be located in any "
                           "directory within nnunet.training.network_training (search is recursive)."
                           "\nDebug info: \ncheckpoint file: %s\nName of trainer: %s " % (checkpoint, name))
    assert issubclass(tr, nnUNetTrainer), "The network trainer was found but is not a subclass of nnUNetTrainer. " \
                                          "Please make it so!"

    # legacy handling for checkpoints written by older nnU-Net versions (kept for reference):
    # if len(init) == 7:
    #     print("warning: this model seems to have been saved with a previous version of nnUNet. Attempting to "
    #           "load it anyway. Expect the unexpected.")
    #     print("manually editing init args...")
    #     init = [init[i] for i in range(len(init)) if i != 2]

    trainer = tr(*init)

    if fp16 is not None:
        trainer.fp16 = fp16

    trainer.process_plans(info['plans'])
    if checkpoint is not None:
        trainer.load_checkpoint(checkpoint, train)
    return trainer
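
# Example (sketch): restoring a trainer for inference. The path is a placeholder; passing fp16=True only
# overwrites the trainer's mixed-precision flag, it does not convert the stored weights.
#
#   checkpoint = "/path/to/fold_0/model_best.model"
#   trainer = restore_model(checkpoint + ".pkl", checkpoint, train=False, fp16=True)
#   network = trainer.network  # torch.nn.Module, available once the checkpoint has been loaded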
|
|
|
|
def load_best_model_for_inference(folder):
    checkpoint = join(folder, "model_best.model")
    pkl_file = checkpoint + ".pkl"
    return restore_model(pkl_file, checkpoint, False)
|
|
|
|
def load_model_and_checkpoint_files(folder, folds=None, mixed_precision=None, checkpoint_name="model_best"):
    """
    Use this if you need to ensemble the five models of a cross-validation. This will restore the model from the
    checkpoint in fold 0, load the parameters of all five folds into RAM and return both. This allows for fast
    switching between the parameter sets (as opposed to loading them from disk each time).

    This is best used for inference and test prediction.
    :param folder: output folder of the training (contains the fold_X subfolders)
    :param folds: which folds to load; list of int, int, str or None
    :param mixed_precision: if None then we take no action. If True/False we overwrite what the model has in its init
    :param checkpoint_name: name of the checkpoint files without extension, e.g. "model_best"
    :return: the restored trainer and a list with the parameters of each fold
    """
    if isinstance(folds, str):
        # a string is interpreted as the "all" training (single model trained on all data)
        folds = [join(folder, "all")]
        assert isdir(folds[0]), "no output folder for fold %s found" % folds
    elif isinstance(folds, (list, tuple)):
        if len(folds) == 1 and folds[0] == "all":
            folds = [join(folder, "all")]
        else:
            folds = [join(folder, "fold_%d" % i) for i in folds]
        assert all([isdir(i) for i in folds]), "list of folds specified but not all output folders are present"
    elif isinstance(folds, int):
        folds = [join(folder, "fold_%d" % folds)]
        assert all([isdir(i) for i in folds]), "output folder missing for fold %s" % folds
    elif folds is None:
        print("folds is None so we will automatically look for output folders (not using \'all\'!)")
        folds = subfolders(folder, prefix="fold")
        print("found the following folds: ", folds)
    else:
        raise ValueError("Unknown value for folds. Type: %s. Expected: list of int, int, str or None" % str(type(folds)))

    trainer = restore_model(join(folds[0], "%s.model.pkl" % checkpoint_name), fp16=mixed_precision)
    trainer.output_folder = folder
    trainer.output_folder_base = folder
    trainer.update_fold(0)
    trainer.initialize(False)
    all_best_model_files = [join(i, "%s.model" % checkpoint_name) for i in folds]
    print("using the following model files: ", all_best_model_files)
    all_params = [torch.load(i, map_location=torch.device('cpu')) for i in all_best_model_files]
    return trainer, all_params
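
# Example (sketch): ensembling the folds returned above. This assumes the trainer exposes
# load_checkpoint_ram(params, train) and predict_preprocessed_data_return_seg_and_softmax(...), the pattern
# followed by nnU-Net's prediction code; adapt the names if your trainer differs. "data" stands for a
# preprocessed case (numpy array) and the folder path is a placeholder.
#
#   trainer, all_params = load_model_and_checkpoint_files("/path/to/trained_model_output", folds=None)
#   softmax_sum = None
#   for params in all_params:
#       trainer.load_checkpoint_ram(params, False)  # swap this fold's weights in without touching the disk
#       softmax = trainer.predict_preprocessed_data_return_seg_and_softmax(data, do_mirroring=True)[1]
#       softmax_sum = softmax if softmax_sum is None else softmax_sum + softmax
#   softmax_mean = softmax_sum / len(all_params)  # average of the per-fold softmax predictions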
|
|
|
|
if __name__ == "__main__":
    pkl = "/home/fabian/PhD/results/nnUNetV2/nnUNetV2_3D_fullres/Task004_Hippocampus/fold0/model_best.model.pkl"
    checkpoint = pkl[:-4]
    train = False
    trainer = restore_model(pkl, checkpoint, train)
|