# This implements the workflow for applying the network to a directory of
# images and measuring network performance with metrics.

imports:
  - $import os
  - $import datetime
  - $import torch
  - $import glob

# pull out some constants from MONAI, used as dictionary keys throughout
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED
both_keys: ['@image', '@label']

# hyperparameters for you to modify on the command line
batch_size: 1  # number of images per batch
num_workers: 0  # number of workers to generate batches with
num_classes: 4  # number of classes in training data which network should predict
save_pred: false  # whether to save prediction images or just run metric tests
# first CUDA device when one is available, otherwise the CPU
device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# define various paths
bundle_root: .  # root directory of the bundle
ckpt_path: $@bundle_root + '/models/model.pt'  # checkpoint to load before starting
dataset_dir: $@bundle_root + '/test_data'  # where data is coming from
output_dir: './outputs'  # directory to store images to if save_pred is true

# network definition, this could be parameterised by pre-defined values or on
# the command line
network_def:
  _target_: UNet
  spatial_dims: 3
  in_channels: 1
  out_channels: '@num_classes'
  channels: [8, 16, 32, 64]
  strides: [2, 2, 2]
  num_res_units: 2
# the instantiated network, moved onto the selected device
network: $@network_def.to(@device)

# list all niftis in the input directory
# (data_list is not referenced by the dataset below, which uses data_dicts)
file_pattern: '*.nii*'
data_list: '$list(sorted(glob.glob(os.path.join(@dataset_dir, @file_pattern))))'

# collect data dictionaries for all files: images are matched by the
# 'img*.nii.gz' pattern and each label path is derived from its image path.
# NOTE: str.replace substitutes every occurrence of 'img' in the path,
# including any in directory names, not only the file name prefix.
imgs: '$sorted(glob.glob(@dataset_dir+''/img*.nii.gz''))'
lbls: '$[i.replace(''img'',''lbl'') for i in @imgs]'
data_dicts: '$[{@image: i, @label: l} for i, l in zip(@imgs, @lbls)]'

# these transforms are used for inference to load and regularise inputs
transforms:
  - _target_: LoadImaged
    keys: '@both_keys'
    image_only: true
  - _target_: EnsureChannelFirstd
    keys: '@both_keys'
  - _target_: ScaleIntensityd
    keys: '@image'  # intensity scaling is applied to the image only

preprocessing:
  _target_: Compose
  transforms: $@transforms

dataset:
  _target_: Dataset
  data: '@data_dicts'
  transform: '@preprocessing'

dataloader:
  _target_: ThreadDataLoader  # generate data asynchronously from inference
  dataset: '@dataset'
  batch_size: '@batch_size'
  num_workers: '@num_workers'

# should be replaced with other inferer types if training process is different
# for your network
inferer:
  _target_: SimpleInferer

# transform to apply to data from network to be suitable for loss function and
# validation
postprocessing:
  _target_: Compose
  transforms:
    - _target_: Activationsd
      keys: '@pred'
      softmax: true
    - _target_: AsDiscreted
      keys: '@pred'
      argmax: true
    # only active when save_pred is true
    - _target_: SaveImaged
      _disabled_: '$not @save_pred'
      keys: '@pred'
      meta_keys: pred_meta_dict
      data_root_dir: '@dataset_dir'
      output_dir: '@output_dir'
      dtype: $None
      output_dtype: $None
      output_postfix: ''
      resample: false
      separate_folder: true

# inference handlers to load checkpoint, gather statistics
handlers:
  # skipped entirely when no checkpoint file exists at ckpt_path
  - _target_: CheckpointLoader
    _disabled_: $not os.path.exists(@ckpt_path)
    load_path: '@ckpt_path'
    load_dict:
      model: '@network'
    # remap stored tensors onto the selected device so that a checkpoint saved
    # from a GPU still loads when `device` has fallen back to the CPU
    map_location: '@device'
  - _target_: StatsHandler
    name: null  # use engine.logger as the Logger object to log to
    output_transform: '$lambda x: None'  # yields None for every engine output

# engine for running inference, ties together objects defined above and has
# metric definitions
evaluator:
  _target_: SupervisedEvaluator
  device: '@device'
  val_data_loader: '@dataloader'
  network: '@network'
  postprocessing: '@postprocessing'
  # NOTE(review): '@pred' is reduced to a single class-index channel by the
  # argmax above and '@label' is passed through without one-hot encoding;
  # confirm MeanDice scores every foreground class with inputs in this form.
  key_val_metric:
    val_mean_dice:
      _target_: MeanDice
      include_background: false
      output_transform: $monai.handlers.from_engine([@pred, @label])
  val_handlers: '@handlers'

# run the evaluator, then print the Dice details stored under 'val_mean_dice'
run:
  - $@evaluator.run()
  - '$print(''Per-image Dice:\n'',@evaluator.state.metric_details[''val_mean_dice''].cpu().numpy())'