# This config file implements the training workflow. It can be combined with multi_gpu_train.yaml to use DDP for
# multi-GPU runs. Many definitions in this file are duplicated across other files for compatibility with MONAI
# Label, e.g. network_def, but ideally these would be in a common.yaml file used in conjunction with this one
# or the other config files for testing or inference.

# modules made available to the `$`-prefixed Python expressions below
imports:
  - $import os
  - $import datetime
  - $import torch
  - $import glob

# pull out some constants from MONAI
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED
both_keys: ['@image', '@label']

# multi-gpu values, `rank` will be replaced in a separate script implementing multi-gpu changes
rank: 0  # without multi-gpu support consider the process as rank 0 anyway
is_not_rank0: '$@rank > 0'  # true if not main process, used to disable handlers for other ranks

# hyperparameters for you to modify on the command line
val_interval: 1  # how often to perform validation after an epoch
ckpt_interval: 1  # how often to save a checkpoint after an epoch
rand_prob: 0.5  # probability a random transform is applied
batch_size: 5  # number of images per batch
num_epochs: 20  # number of epochs to train for
num_substeps: 1  # how many times to repeatedly train with the same batch
num_workers: 4  # number of workers to generate batches with
learning_rate: 0.001  # initial learning rate
num_classes: 4  # number of classes in training data which network should predict
device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# define various paths
bundle_root: .  # root directory of the bundle
ckpt_path: $@bundle_root + '/models/model.pt'  # checkpoint to load before starting
dataset_dir: $@bundle_root + '/train_data'  # where data is coming from
results_dir: $@bundle_root + '/results'  # where results are being stored to
# a new output directory is chosen using a timestamp for every invocation
output_dir: '$datetime.datetime.now().strftime(@results_dir + ''/output_%y%m%d_%H%M%S'')'

# network definition, this could be parameterised by pre-defined values or on the command line
network_def:
  _target_: UNet
  spatial_dims: 3
  in_channels: 1
  out_channels: '@num_classes'
  channels: [8, 16, 32, 64]
  strides: [2, 2, 2]
  num_res_units: 2
network: $@network_def.to(@device)

# dataset value, this assumes a directory filled with img##.nii.gz and lbl##.nii.gz files
imgs: '$sorted(glob.glob(@dataset_dir+''/img*.nii.gz''))'
# only the file name is rewritten (img## -> lbl##) so an 'img' substring in a parent directory is left untouched
lbls: '$[os.path.join(os.path.dirname(i), os.path.basename(i).replace(''img'', ''lbl'', 1)) for i in @imgs]'
all_pairs: '$[{@image: i, @label: l} for i, l in zip(@imgs, @lbls)]'
partitions: '$monai.data.partition_dataset(@all_pairs, (4, 1), shuffle=True, seed=0)'
train_sub: '$@partitions[0]'  # train partition
val_sub: '$@partitions[1]'  # validation partition

# these transforms are used for training and validation transform sequences
base_transforms:
  - _target_: LoadImaged
    keys: '@both_keys'
    image_only: true
  - _target_: EnsureChannelFirstd
    keys: '@both_keys'

# these are the random and regularising transforms used only for training
train_transforms:
  - _target_: RandAxisFlipd
    keys: '@both_keys'
    prob: '@rand_prob'
  - _target_: RandRotate90d
    keys: '@both_keys'
    prob: '@rand_prob'
  - _target_: RandGaussianNoised
    keys: '@image'
    prob: '@rand_prob'
    std: 0.05
  - _target_: ScaleIntensityd
    keys: '@image'

# these are used for validation data so no randomness
val_transforms:
  - _target_: ScaleIntensityd
    keys: '@image'

# define the Compose objects for training and validation
preprocessing:
  _target_: Compose
  transforms: $@base_transforms + @train_transforms

val_preprocessing:
  _target_: Compose
  transforms: $@base_transforms + @val_transforms

# define the datasets for training and validation
train_dataset:
  _target_: Dataset
  data: '@train_sub'
  transform: '@preprocessing'

val_dataset:
  _target_: Dataset
  data: '@val_sub'
  transform: '@val_preprocessing'

# define the dataloaders for training and validation
train_dataloader:
  _target_: ThreadDataLoader  # generate data asynchronously from training
  dataset: '@train_dataset'
  batch_size: '@batch_size'
  repeats: '@num_substeps'
  num_workers: '@num_workers'

val_dataloader:
  _target_: DataLoader  # faster transforms probably won't benefit from threading
  dataset: '@val_dataset'
  batch_size: '@batch_size'
  num_workers: '@num_workers'

# Simple Dice loss configured for multi-class segmentation, for binary segmentation
# use include_background==True and sigmoid==True instead of these values
lossfn:
  _target_: DiceLoss
  include_background: true  # if your segmentations are relatively small it might help for this to be false
  to_onehot_y: true  # convert ground truth to one-hot for training
  softmax: true  # softmax applied to prediction

# hyperparameters could be added for other arguments of this class
optimizer:
  _target_: torch.optim.Adam
  params: $@network.parameters()
  lr: '@learning_rate'

# should be replaced with other inferer types if training process is different for your network
inferer:
  _target_: SimpleInferer

# transform to apply to data from network to be suitable for validation
postprocessing:
  _target_: Compose
  transforms:
    - _target_: Activationsd
      keys: '@pred'
      softmax: true
    - _target_: AsDiscreted
      keys: ['@pred', '@label']
      argmax: [true, false]
      to_onehot: '@num_classes'

# validation handlers to gather statistics, log these to a file, and save best checkpoint
val_handlers:
  - _target_: StatsHandler
    name: null  # use engine.logger as the Logger object to log to
    output_transform: '$lambda x: None'
  - _target_: LogfileHandler  # log outputs from the validation engine
    output_dir: '@output_dir'
  - _target_: CheckpointSaver
    _disabled_: '@is_not_rank0'  # only need rank 0 to save
    save_dir: '@output_dir'
    save_dict:
      model: '@network'
    save_interval: 0  # don't save iterations, just when the metric improves
    save_final: false
    epoch_level: false
    save_key_metric: true
    key_metric_name: val_mean_dice  # save the checkpoint when this value improves

# engine for running validation, ties together objects defined above and has metric definitions
evaluator:
  _target_: SupervisedEvaluator
  device: '@device'
  val_data_loader: '@val_dataloader'
  network: '@network'
  postprocessing: '@postprocessing'
  key_val_metric:
    # val_mean_dice is the metric named by key_metric_name in the validation CheckpointSaver
    val_mean_dice:
      _target_: MeanDice
      include_background: false
      output_transform: $monai.handlers.from_engine([@pred, @label])
    val_mean_iou:
      _target_: MeanIoUHandler
      include_background: false
      output_transform: $monai.handlers.from_engine([@pred, @label])
  additional_metrics:
    val_mae:  # can have other metrics, MAE not great for segmentation tasks so here just to demo
      _target_: MeanAbsoluteError
      output_transform: $monai.handlers.from_engine([@pred, @label])
  val_handlers: '@val_handlers'

# gathers the loss and validation values for each iteration, referred to by CheckpointSaver so defined separately
metriclogger:
  _target_: MetricLogger
  evaluator: '@evaluator'

# handlers attached to the training engine (referenced by `trainer` as train_handlers)
handlers:
  - '@metriclogger'
  - _target_: CheckpointLoader
    _disabled_: $not os.path.exists(@ckpt_path)  # skip loading when no checkpoint file exists
    load_path: '@ckpt_path'
    load_dict:
      model: '@network'
  # run validation at the set interval, bridge between trainer and evaluator objects
  - _target_: ValidationHandler
    validator: '@evaluator'
    epoch_level: true
    interval: '@val_interval'
  - _target_: CheckpointSaver
    _disabled_: '@is_not_rank0'  # only need rank 0 to save
    save_dir: '@output_dir'
    save_dict:  # every epoch checkpoint saves the network and the metric logger in a dictionary
      model: '@network'
      logger: '@metriclogger'
    save_interval: '@ckpt_interval'
    save_final: true
    epoch_level: true
  - _target_: StatsHandler
    name: null  # use engine.logger as the Logger object to log to
    tag_name: train_loss
    output_transform: $monai.handlers.from_engine(['loss'], first=True)  # log loss value
  - _target_: LogfileHandler  # log outputs from the training engine
    output_dir: '@output_dir'

# engine for training, ties values defined above together into the main engine for the training process
trainer:
  _target_: SupervisedTrainer
  max_epochs: '@num_epochs'
  device: '@device'
  train_data_loader: '@train_dataloader'
  network: '@network'
  inferer: '@inferer'  # unnecessary since SimpleInferer is the default if this isn't provided
  loss_function: '@lossfn'
  optimizer: '@optimizer'
  # postprocessing: '@postprocessing'  # uncomment if you have train metrics that need post-processing
  key_train_metric: null
  train_handlers: '@handlers'

# expressions evaluated to execute the workflow
run:
  - $@trainer.run()