File size: 3,882 Bytes
3f29a93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# This implements the workflow for applying the network to a directory of images and measuring network performance with metrics.

imports:
- $import os
- $import datetime
- $import torch
- $import glob

# dictionary key names are read from MONAI's CommonKeys rather than hard-coded here
image: '$monai.utils.CommonKeys.IMAGE'
label: '$monai.utils.CommonKeys.LABEL'
pred: '$monai.utils.CommonKeys.PRED'
# the two keys handled together by the loading transforms: image and label
both_keys:
- '@image'
- '@label'

# hyperparameters, intended to be overridden on the command line

# number of images per batch
batch_size: 1
# number of workers to generate batches with
num_workers: 0
# number of classes in training data which network should predict
num_classes: 4
# whether to save prediction images or just run metric tests
save_pred: false
# 'cuda:0' when torch reports CUDA as available, otherwise 'cpu'
device: '$torch.device(''cuda:0'' if torch.cuda.is_available() else ''cpu'')'

# filesystem locations used by the workflow

# root directory of the bundle
bundle_root: '.'
# checkpoint to load before starting, resolved under bundle_root
ckpt_path: '$@bundle_root + ''/models/model.pt'''
# where data is coming from, resolved under bundle_root
dataset_dir: '$@bundle_root + ''/test_data'''
# directory to store images to if save_pred is true
# (a literal relative path; it is not derived from bundle_root)
output_dir: ./outputs

# network definition; its values may be replaced by pre-defined values or on the command line
network_def:
  _target_: UNet
  spatial_dims: 3
  in_channels: 1
  # one output channel per class
  out_channels: '@num_classes'
  channels: [8, 16, 32, 64]
  strides: [2, 2, 2]
  num_res_units: 2
# the instantiated network_def moved onto the selected device
network: '$@network_def.to(@device)'

# list all niftis in the input directory
# NOTE(review): data_list is not referenced by any other entry visible in this file;
# it is kept so that external overrides depending on it keep working
file_pattern: '*.nii*'
data_list: '$sorted(glob.glob(os.path.join(@dataset_dir, @file_pattern)))'
# collect data dictionaries for all files: every img*.nii.gz in dataset_dir is paired
# with the lbl*.nii.gz file of the same suffix in the same directory
imgs: '$sorted(glob.glob(@dataset_dir+''/img*.nii.gz''))'
# only the file name is rewritten (first 'img' -> 'lbl'); replacing on the whole path
# would also alter any directory component that happens to contain 'img'
lbls: '$[os.path.join(os.path.dirname(i), os.path.basename(i).replace(''img'', ''lbl'', 1)) for i in @imgs]'
data_dicts: '$[{@image: i, @label: l} for i, l in zip(@imgs, @lbls)]'

# transform definitions used at inference time to load and regularise the inputs
transforms:
# applied to both the image and the label entries
- _target_: LoadImaged
  image_only: true
  keys: '@both_keys'
# applied to both the image and the label entries
- _target_: EnsureChannelFirstd
  keys: '@both_keys'
# applied to the image entry only; the label is left untouched
- _target_: ScaleIntensityd
  keys: '@image'

# compose the `transforms` list into the single preprocessing object used by the dataset
preprocessing:
  _target_: Compose
  transforms: '$@transforms'

# dataset built from the image/label dictionaries, with preprocessing as its transform
dataset:
  _target_: Dataset
  transform: '@preprocessing'
  data: '@data_dicts'

# loader feeding batches of the dataset to the evaluator
dataloader:
  # generate data asynchronously from inference
  _target_: ThreadDataLoader
  dataset: '@dataset'
  num_workers: '@num_workers'
  batch_size: '@batch_size'

# inferer definition; replace it with another inferer type if the training process
# is different for your network
inferer:
  _target_: SimpleInferer

# transforms applied to the network output so it can be saved and scored by the metric
postprocessing:
  _target_: Compose
  transforms:
  # turn the raw network output into per-class probabilities
  - _target_: Activationsd
    keys: '@pred'
    softmax: true
  # reduce the probabilities to a single-channel map of class indices
  - _target_: AsDiscreted
    keys: '@pred'
    argmax: true
  # optionally write the class-index map to disk; this sits before the one-hot step
  # below so the saved files remain label maps rather than one-hot volumes
  - _target_: SaveImaged
    _disabled_: '$not @save_pred'
    keys: '@pred'
    meta_keys: pred_meta_dict
    data_root_dir: '@dataset_dir'
    output_dir: '@output_dir'
    dtype: $None
    output_dtype: $None
    output_postfix: ''
    resample: false
    separate_folder: true
  # one-hot encode both prediction and label: the Dice metric expects one channel per
  # class, and single-channel class-index inputs are not scored per class
  - _target_: AsDiscreted
    keys: ['@pred', '@label']
    to_onehot: '@num_classes'

# handlers passed to the evaluator: one loads the checkpoint, one gathers statistics
handlers:
# disabled when no file exists at ckpt_path; otherwise loads it into the network
- _target_: CheckpointLoader
  _disabled_: '$not os.path.exists(@ckpt_path)'
  load_path: '@ckpt_path'
  load_dict: {model: '@network'}
- _target_: StatsHandler
  # use engine.logger as the Logger object to log to
  name: null
  output_transform: '$lambda x: None'

# engine for running inference, ties together objects defined above and has metric definitions
evaluator:
  _target_: SupervisedEvaluator
  device: '@device'
  val_data_loader: '@dataloader'
  network: '@network'
  # pass the configured inferer explicitly; without this entry the `inferer` definition
  # is never used and overriding it (in this file or on the command line) has no effect
  inferer: '@inferer'
  postprocessing: '@postprocessing'
  key_val_metric:
    # mean Dice computed from the postprocessed prediction and the label, background excluded
    val_mean_dice:
      _target_: MeanDice
      include_background: false
      output_transform: $monai.handlers.from_engine([@pred, @label])
  val_handlers: '@handlers'

# expressions evaluated in order: run the evaluator, then print the stored Dice details
run:
- '$@evaluator.run()'
# block scalar keeps the expression's quotes and the \n escape exactly as Python sees them
- >-
  $print('Per-image Dice:\n',@evaluator.state.metric_details['val_mean_dice'].cpu().numpy())