| import os |
| import socket |
| from typing import Union, Optional |
|
|
| import nnunetv2 |
| import torch.cuda |
| import torch.distributed as dist |
| import torch.multiprocessing as mp |
| from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json |
| from nnunetv2.paths import nnUNet_preprocessed |
| from nnunetv2.run.load_pretrained_weights import load_pretrained_weights |
| from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer |
| from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name |
| from nnunetv2.utilities.find_class_by_name import recursive_find_python_class |
| from nnunetv2.training.dataloading.utils import get_case_identifiers, unpack_dataset |
| from torch.backends import cudnn |
|
|
|
|
def find_free_network_port() -> int:
    """Find a currently unused TCP port.

    It is useful in single-node training when we don't want to connect to a real main node but have to set the
    `MASTER_PORT` environment variable.

    A temporary socket is bound to port 0 so that the operating system picks an unused port. The socket is closed
    again before returning, so the port is only reserved for the lifetime of this call.

    Returns:
        The port number the operating system assigned to the temporary socket.
    """
    # Using the socket as a context manager guarantees it is closed even if bind()/getsockname() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]
|
|
|
|
def get_trainer_from_args(dataset_name_or_id: Union[int, str],
                          configuration: str,
                          fold: int,
                          trainer_name: str = 'nnUNetTrainer',
                          plans_identifier: str = 'nnUNetPlans',
                          use_compressed: bool = False,
                          device: torch.device = torch.device('cuda')):
    """Locate the requested trainer class and instantiate it for the given dataset/configuration/fold.

    Args:
        dataset_name_or_id: Either a dataset name starting with 'Dataset', an integer dataset ID, or a string
            that can be converted to an integer dataset ID.
        configuration: Configuration name, forwarded unchanged to the trainer constructor.
        fold: Fold, forwarded unchanged to the trainer constructor.
        trainer_name: Name of the trainer class to look up below nnunetv2/training/nnUNetTrainer.
        plans_identifier: Base name (without '.json') of the plans file inside the preprocessed dataset folder.
        use_compressed: If True the trainer is created with ``unpack_dataset=False``.
        device: Device handed to the trainer constructor.

    Returns:
        The instantiated trainer.

    Raises:
        RuntimeError: If no class named ``trainer_name`` could be found.
        AssertionError: If the class that was found does not inherit from nnUNetTrainer.
        ValueError: If ``dataset_name_or_id`` is a string that neither starts with 'Dataset' nor is an integer.
    """
    # load nnunet class and do sanity checks
    trainer_search_dir = join(nnunetv2.__path__[0], "training", "nnUNetTrainer")
    nnunet_trainer = recursive_find_python_class(trainer_search_dir, trainer_name,
                                                 'nnunetv2.training.nnUNetTrainer')
    if nnunet_trainer is None:
        raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in '
                           f'nnunetv2.training.nnUNetTrainer ('
                           f'{trainer_search_dir}). If it is located somewhere '
                           f'else, please move it there.')
    assert issubclass(nnunet_trainer, nnUNetTrainer), 'The requested nnunet trainer class must inherit from ' \
                                                      'nnUNetTrainer'

    # Handle dataset input. Only strings need normalisation: a string that is not a dataset name must be an
    # integer ID. Integer inputs (allowed by the signature) are passed through as they are; previously they
    # crashed with an AttributeError on `.startswith`.
    if isinstance(dataset_name_or_id, str) and not dataset_name_or_id.startswith('Dataset'):
        try:
            dataset_name_or_id = int(dataset_name_or_id)
        except ValueError:
            raise ValueError(f'dataset_name_or_id must either be an integer or a valid dataset name with the pattern '
                             f'DatasetXXX_YYY where XXX are the three(!) task ID digits. Your '
                             f'input: {dataset_name_or_id}')

    # initialize nnunet trainer
    # NOTE(review): assumes maybe_convert_to_dataset_name accepts both int IDs and dataset names - confirm.
    preprocessed_dataset_folder_base = join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id))
    plans_file = join(preprocessed_dataset_folder_base, plans_identifier + '.json')
    plans = load_json(plans_file)
    dataset_json = load_json(join(preprocessed_dataset_folder_base, 'dataset.json'))
    nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold,
                                    dataset_json=dataset_json, unpack_dataset=not use_compressed, device=device)
    return nnunet_trainer
|
|
|
|
def maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool, validation_only: bool,
                          pretrained_weights_file: str = None):
    """Load a checkpoint or pretrained weights into the trainer, depending on the requested mode.

    Args:
        nnunet_trainer: Trainer whose ``output_folder`` is searched for checkpoints.
        continue_training: If True, load the first existing file among checkpoint_final.pth,
            checkpoint_latest.pth and checkpoint_best.pth (in that order). If none exists a warning is printed
            and nothing is loaded.
        validation_only: Only considered when ``continue_training`` is False. Requires checkpoint_final.pth.
        pretrained_weights_file: Only used when neither of the two flags is set. The trainer is initialized if
            needed and the weights are loaded into its network.

    Raises:
        RuntimeError: If ``continue_training`` is combined with ``pretrained_weights_file``, or if
            ``validation_only`` is requested but checkpoint_final.pth does not exist.
    """
    if continue_training and pretrained_weights_file is not None:
        raise RuntimeError('Cannot both continue a training AND load pretrained weights. Pretrained weights can only '
                           'be used at the beginning of the training.')

    checkpoint_to_load = None
    if continue_training:
        # Candidates in order of preference; the first one that exists on disk wins.
        for checkpoint_name in ('checkpoint_final.pth', 'checkpoint_latest.pth', 'checkpoint_best.pth'):
            candidate = join(nnunet_trainer.output_folder, checkpoint_name)
            if isfile(candidate):
                checkpoint_to_load = candidate
                break
        else:
            print(f"WARNING: Cannot continue training because there seems to be no checkpoint available to "
                  f"continue from. Starting a new training...")
    elif validation_only:
        checkpoint_to_load = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
        if not isfile(checkpoint_to_load):
            raise RuntimeError(f"Cannot run validation because the training is not finished yet!")
    elif pretrained_weights_file is not None:
        # Fresh training that starts from pretrained weights; no checkpoint is loaded afterwards.
        if not nnunet_trainer.was_initialized:
            nnunet_trainer.initialize()
        load_pretrained_weights(nnunet_trainer.network, pretrained_weights_file, verbose=True)

    if checkpoint_to_load is not None:
        nnunet_trainer.load_checkpoint(checkpoint_to_load)
|
|
|
|
def setup_ddp(rank, world_size):
    """Initialize the torch.distributed process group for the calling process using the NCCL backend.

    Args:
        rank: Rank of the calling process.
        world_size: Total number of participating processes.
    """
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
|
|
|
def cleanup_ddp():
    """Destroy the torch.distributed process group of the calling process."""
    dist.destroy_process_group()
|
|
|
|
def run_ddp(rank, dataset_name_or_id, configuration, fold, tr, p, use_compressed, disable_checkpointing, c, val,
            pretrained_weights, npz, val_with_best, world_size):
    """Worker function for multi-GPU (DDP) training; one call per process.

    Args:
        rank: Rank of this process, used to initialize the process group.
        dataset_name_or_id: Dataset name or ID, forwarded to get_trainer_from_args.
        configuration: Configuration name, forwarded to get_trainer_from_args.
        fold: Fold, forwarded to get_trainer_from_args.
        tr: Trainer class name.
        p: Plans identifier.
        use_compressed: Forwarded to get_trainer_from_args.
        disable_checkpointing: If truthy, sets ``disable_checkpointing`` on the trainer.
        c: Continue training from an existing checkpoint.
        val: Only run the validation (skips ``run_training``).
        pretrained_weights: Optional pretrained weights file, forwarded to maybe_load_checkpoint.
        npz: Forwarded to the trainer's ``perform_actual_validation``.
        val_with_best: If truthy, checkpoint_best.pth is loaded before validation.
        world_size: Total number of DDP processes.

    Raises:
        AssertionError: If both ``c`` and ``val`` are set.
    """
    # Validate the flag combination before any distributed state is created. Previously this check ran after the
    # process group had been initialized and the trainer had been built, so an invalid call left the process
    # group behind without ever reaching cleanup_ddp().
    assert not (c and val), 'Cannot set --c and --val flag at the same time. Dummy.'

    setup_ddp(rank, world_size)
    # bind this process to the GPU whose index equals its rank
    torch.cuda.set_device(torch.device('cuda', dist.get_rank()))

    nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p,
                                           use_compressed)

    if disable_checkpointing:
        nnunet_trainer.disable_checkpointing = disable_checkpointing

    maybe_load_checkpoint(nnunet_trainer, c, val, pretrained_weights)

    if torch.cuda.is_available():
        cudnn.deterministic = False
        cudnn.benchmark = True

    if not val:
        nnunet_trainer.run_training()

    if val_with_best:
        nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth'))
    nnunet_trainer.perform_actual_validation(npz)
    cleanup_ddp()
|
|
|
|
def run_unpacking(dataset_name_or_id: Union[str, int],
                  configuration: str, fold: Union[int, str],
                  trainer_class_name: str = 'nnUNetTrainer',
                  plans_identifier: str = 'nnUNetPlans',
                  pretrained_weights: Optional[str] = None,
                  num_gpus: int = 1,
                  use_compressed_data: bool = False,
                  export_validation_probabilities: bool = False,
                  continue_training: bool = False,
                  only_run_validation: bool = False,
                  disable_checkpointing: bool = False,
                  val_with_best: bool = False,
                  device: torch.device = torch.device('cuda')):
    """Build the trainer for the given arguments and unpack its preprocessed dataset folder.

    The signature mirrors ``run_training``. ``pretrained_weights``, ``num_gpus``,
    ``export_validation_probabilities``, ``continue_training`` and ``only_run_validation`` are accepted but not
    used by this function.

    Args:
        dataset_name_or_id: Dataset name or ID, forwarded to get_trainer_from_args.
        configuration: Configuration name, forwarded to get_trainer_from_args.
        fold: 'all', an int, or a string that can be converted to an int.
        trainer_class_name: Name of the trainer class.
        plans_identifier: Plans identifier.
        use_compressed_data: Forwarded to get_trainer_from_args.
        disable_checkpointing: Only checked for compatibility with ``val_with_best``.
        val_with_best: Must not be combined with ``disable_checkpointing``.
        device: Device handed to the trainer.

    Raises:
        ValueError: If ``fold`` is a string other than 'all' that cannot be converted to int.
        AssertionError: If ``val_with_best`` and ``disable_checkpointing`` are both set.
    """
    fold_needs_conversion = isinstance(fold, str) and fold != 'all'
    if fold_needs_conversion:
        try:
            fold = int(fold)
        except ValueError as e:
            print(f'Unable to convert given value for fold to int: {fold}. fold must bei either "all" or an integer!')
            raise e

    assert not (val_with_best and disable_checkpointing), '--val_best is not compatible with --disable_checkpointing'

    trainer = get_trainer_from_args(dataset_name_or_id=dataset_name_or_id,
                                    configuration=configuration,
                                    fold=fold,
                                    trainer_name=trainer_class_name,
                                    plans_identifier=plans_identifier,
                                    use_compressed=use_compressed_data,
                                    device=device)

    dataset_folder = trainer.preprocessed_dataset_folder
    unpack_dataset(dataset_folder)
|
|
|
|
|
|
|
|
def run_training(dataset_name_or_id: Union[str, int],
                 configuration: str, fold: Union[int, str],
                 trainer_class_name: str = 'nnUNetTrainer',
                 plans_identifier: str = 'nnUNetPlans',
                 pretrained_weights: Optional[str] = None,
                 num_gpus: int = 1,
                 use_compressed_data: bool = False,
                 export_validation_probabilities: bool = False,
                 continue_training: bool = False,
                 only_run_validation: bool = False,
                 disable_checkpointing: bool = False,
                 val_with_best: bool = False,
                 device: torch.device = torch.device('cuda')):
    """Run training and/or validation, either in this process or via DDP on several GPUs.

    With ``num_gpus > 1`` one ``run_ddp`` worker per GPU is spawned. Otherwise the trainer is built and run in the
    current process.

    Args:
        dataset_name_or_id: Dataset name or ID.
        configuration: Configuration name.
        fold: 'all', an int, or a string that can be converted to an int.
        trainer_class_name: Name of the trainer class.
        plans_identifier: Plans identifier.
        pretrained_weights: Optional path to pretrained weights, forwarded to maybe_load_checkpoint.
        num_gpus: Number of GPUs. Values above 1 trigger DDP training.
        use_compressed_data: Forwarded to get_trainer_from_args.
        export_validation_probabilities: Forwarded to the trainer's ``perform_actual_validation``.
        continue_training: Continue from an existing checkpoint.
        only_run_validation: Skip training and only run the validation.
        disable_checkpointing: If True, sets ``disable_checkpointing`` on the trainer.
        val_with_best: Load checkpoint_best.pth before validation. Not compatible with
            ``disable_checkpointing``.
        device: Device to train on. Must be a cuda device when ``num_gpus > 1``.

    Raises:
        ValueError: If ``fold`` is a string other than 'all' that cannot be converted to int.
        AssertionError: If incompatible options are combined (``val_with_best`` with ``disable_checkpointing``,
            ``continue_training`` with ``only_run_validation``, or DDP with a non-cuda device).
    """
    if isinstance(fold, str):
        if fold != 'all':
            try:
                fold = int(fold)
            except ValueError as e:
                print(f'Unable to convert given value for fold to int: {fold}. fold must bei either "all" or an integer!')
                raise e

    if val_with_best:
        assert not disable_checkpointing, '--val_best is not compatible with --disable_checkpointing'

    # Checked up front for BOTH execution paths. Previously only the single-process branch checked this here; the
    # DDP path first spawned all workers and initialized the process group before the same check failed in each
    # worker.
    assert not (continue_training and only_run_validation), 'Cannot set --c and --val flag at the same time. Dummy.'

    if num_gpus > 1:
        assert device.type == 'cuda', f"DDP training (triggered by num_gpus > 1) is only implemented for cuda devices. Your device: {device}"

        os.environ['MASTER_ADDR'] = 'localhost'
        # respect a port that was set by the user, otherwise pick a free one
        if 'MASTER_PORT' not in os.environ:
            port = str(find_free_network_port())
            print(f"using port {port}")
            os.environ['MASTER_PORT'] = port  # str(port)

        # mp.spawn passes the process index as first argument (rank) to run_ddp
        mp.spawn(run_ddp,
                 args=(
                     dataset_name_or_id,
                     configuration,
                     fold,
                     trainer_class_name,
                     plans_identifier,
                     use_compressed_data,
                     disable_checkpointing,
                     continue_training,
                     only_run_validation,
                     pretrained_weights,
                     export_validation_probabilities,
                     val_with_best,
                     num_gpus),
                 nprocs=num_gpus,
                 join=True)
    else:
        nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name,
                                               plans_identifier, use_compressed_data, device=device)

        if disable_checkpointing:
            nnunet_trainer.disable_checkpointing = disable_checkpointing

        maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)

        if torch.cuda.is_available():
            cudnn.deterministic = False
            cudnn.benchmark = True

        if not only_run_validation:
            nnunet_trainer.run_training()

        if val_with_best:
            nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth'))
        nnunet_trainer.perform_actual_validation(export_validation_probabilities)
|
|
|
|
def run_training_entry():
    """Command line entry point: parse the arguments, configure torch for the chosen device and call run_training.

    Raises:
        AssertionError: If ``-device`` is not one of 'cpu', 'cuda' or 'mps'.
    """
    import argparse

    cli = argparse.ArgumentParser()

    def add_flag(name, help_text):
        # optional boolean switch that is False unless given on the command line
        cli.add_argument(name, action='store_true', required=False, help=help_text)

    cli.add_argument('dataset_name_or_id', type=str,
                     help="Dataset name or ID to train with")
    cli.add_argument('configuration', type=str,
                     help="Configuration that should be trained")
    cli.add_argument('fold', type=str,
                     help='Fold of the 5-fold cross-validation. Should be an int between 0 and 4.')
    cli.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
                     help='[OPTIONAL] Use this flag to specify a custom trainer. Default: nnUNetTrainer')
    cli.add_argument('-p', type=str, required=False, default='nnUNetPlans',
                     help='[OPTIONAL] Use this flag to specify a custom plans identifier. Default: nnUNetPlans')
    cli.add_argument('-pretrained_weights', type=str, required=False, default=None,
                     help='[OPTIONAL] path to nnU-Net checkpoint file to be used as pretrained model. Will only '
                          'be used when actually training. Beta. Use with caution.')
    cli.add_argument('-num_gpus', type=int, default=1, required=False,
                     help='Specify the number of GPUs to use for training')
    cli.add_argument("--use_compressed", default=False, action="store_true", required=False,
                     help="[OPTIONAL] If you set this flag the training cases will not be decompressed. Reading compressed "
                          "data is much more CPU and (potentially) RAM intensive and should only be used if you "
                          "know what you are doing")
    add_flag('--npz',
             '[OPTIONAL] Save softmax predictions from final validation as npz files (in addition to predicted '
             'segmentations). Needed for finding the best ensemble.')
    add_flag('--c',
             '[OPTIONAL] Continue training from latest checkpoint')
    add_flag('--val',
             '[OPTIONAL] Set this flag to only run the validation. Requires training to have finished.')
    add_flag('--val_best',
             '[OPTIONAL] If set, the validation will be performed with the checkpoint_best instead '
             'of checkpoint_final. NOT COMPATIBLE with --disable_checkpointing! '
             'WARNING: This will use the same \'validation\' folder as the regular validation '
             'with no way of distinguishing the two!')
    add_flag('--disable_checkpointing',
             '[OPTIONAL] Set this flag to disable checkpointing. Ideal for testing things out and '
             'you dont want to flood your hard drive with checkpoints.')
    cli.add_argument('-device', type=str, default='cuda', required=False,
                     help="Use this to set the device the training should run with. Available options are 'cuda' "
                          "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
                          "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!")
    args = cli.parse_args()

    device_name = args.device
    assert device_name in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
    if device_name == 'cuda':
        # multithreading in torch doesn't help nnU-Net if run on GPU
        torch.set_num_threads(1)
        torch.set_num_interop_threads(1)
    elif device_name == 'cpu':
        # let's allow torch to use hella threads
        import multiprocessing
        torch.set_num_threads(multiprocessing.cpu_count())
    device = torch.device(device_name)

    run_training(dataset_name_or_id=args.dataset_name_or_id,
                 configuration=args.configuration,
                 fold=args.fold,
                 trainer_class_name=args.tr,
                 plans_identifier=args.p,
                 pretrained_weights=args.pretrained_weights,
                 num_gpus=args.num_gpus,
                 use_compressed_data=args.use_compressed,
                 export_validation_probabilities=args.npz,
                 continue_training=args.c,
                 only_run_validation=args.val,
                 disable_checkpointing=args.disable_checkpointing,
                 val_with_best=args.val_best,
                 device=device)
|
|
def run_unpacking_entry():
    """Command line entry point: parse the arguments, configure torch for the chosen device and call run_unpacking.

    Accepts the same command line arguments as ``run_training_entry``.

    Raises:
        AssertionError: If ``-device`` is not one of 'cpu', 'cuda' or 'mps'.
    """
    import argparse

    cli = argparse.ArgumentParser()

    def add_flag(name, help_text):
        # optional boolean switch that is False unless given on the command line
        cli.add_argument(name, action='store_true', required=False, help=help_text)

    cli.add_argument('dataset_name_or_id', type=str,
                     help="Dataset name or ID to train with")
    cli.add_argument('configuration', type=str,
                     help="Configuration that should be trained")
    cli.add_argument('fold', type=str,
                     help='Fold of the 5-fold cross-validation. Should be an int between 0 and 4.')
    cli.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
                     help='[OPTIONAL] Use this flag to specify a custom trainer. Default: nnUNetTrainer')
    cli.add_argument('-p', type=str, required=False, default='nnUNetPlans',
                     help='[OPTIONAL] Use this flag to specify a custom plans identifier. Default: nnUNetPlans')
    cli.add_argument('-pretrained_weights', type=str, required=False, default=None,
                     help='[OPTIONAL] path to nnU-Net checkpoint file to be used as pretrained model. Will only '
                          'be used when actually training. Beta. Use with caution.')
    cli.add_argument('-num_gpus', type=int, default=1, required=False,
                     help='Specify the number of GPUs to use for training')
    cli.add_argument("--use_compressed", default=False, action="store_true", required=False,
                     help="[OPTIONAL] If you set this flag the training cases will not be decompressed. Reading compressed "
                          "data is much more CPU and (potentially) RAM intensive and should only be used if you "
                          "know what you are doing")
    add_flag('--npz',
             '[OPTIONAL] Save softmax predictions from final validation as npz files (in addition to predicted '
             'segmentations). Needed for finding the best ensemble.')
    add_flag('--c',
             '[OPTIONAL] Continue training from latest checkpoint')
    add_flag('--val',
             '[OPTIONAL] Set this flag to only run the validation. Requires training to have finished.')
    add_flag('--val_best',
             '[OPTIONAL] If set, the validation will be performed with the checkpoint_best instead '
             'of checkpoint_final. NOT COMPATIBLE with --disable_checkpointing! '
             'WARNING: This will use the same \'validation\' folder as the regular validation '
             'with no way of distinguishing the two!')
    add_flag('--disable_checkpointing',
             '[OPTIONAL] Set this flag to disable checkpointing. Ideal for testing things out and '
             'you dont want to flood your hard drive with checkpoints.')
    cli.add_argument('-device', type=str, default='cuda', required=False,
                     help="Use this to set the device the training should run with. Available options are 'cuda' "
                          "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
                          "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!")
    args = cli.parse_args()

    device_name = args.device
    assert device_name in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
    if device_name == 'cuda':
        # multithreading in torch doesn't help nnU-Net if run on GPU
        torch.set_num_threads(1)
        torch.set_num_interop_threads(1)
    elif device_name == 'cpu':
        # let's allow torch to use hella threads
        import multiprocessing
        torch.set_num_threads(multiprocessing.cpu_count())
    device = torch.device(device_name)

    run_unpacking(dataset_name_or_id=args.dataset_name_or_id,
                  configuration=args.configuration,
                  fold=args.fold,
                  trainer_class_name=args.tr,
                  plans_identifier=args.p,
                  pretrained_weights=args.pretrained_weights,
                  num_gpus=args.num_gpus,
                  use_compressed_data=args.use_compressed,
                  export_validation_probabilities=args.npz,
                  continue_training=args.c,
                  only_run_validation=args.val,
                  disable_checkpointing=args.disable_checkpointing,
                  val_with_best=args.val_best,
                  device=device)
|
|
|
|
|
|
# Executing this module as a script starts the training command line interface.
if __name__ == "__main__":
    run_training_entry()
|
|