sakshirathi360 commited on
Commit
1f90840
·
verified ·
1 Parent(s): 9976caf

Upload folder using huggingface_hub

Browse files
configs/hyper_parameters.yaml ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _meta_: {}
2
+ bundle_root: /Users/sakshirathi/Downloads/work_dir/segresnet_0
3
+ ckpt_path: $@bundle_root + '/model'
4
+ mlflow_tracking_uri: $@ckpt_path + '/mlruns/'
5
+ mlflow_experiment_name: Auto3DSeg
6
+ data_file_base_dir: /Users/sakshirathi/Documents/ShamLab
7
+ data_list_file_path: /Users/sakshirathi/Downloads/work_dir/dataset_local.json
8
+ modality: ct
9
+ fold: 0
10
+ input_channels: 1
11
+ output_classes: 2
12
+ class_names: null
13
+ class_index: null
14
+ debug: false
15
+ ckpt_save: true
16
+ cache_rate: null
17
+ roi_size: [384, 384, 60]
18
+ auto_scale_allowed: true
19
+ auto_scale_batch: true
20
+ auto_scale_roi: false
21
+ auto_scale_filters: false
22
+ quick: false
23
+ channels_last: true
24
+ validate_final_original_res: true
25
+ calc_val_loss: false
26
+ amp: true
27
+ log_output_file: null
28
+ cache_class_indices: null
29
+ early_stopping_fraction: 0.001
30
+ determ: false
31
+ orientation_ras: true
32
+ crop_foreground: true
33
+ learning_rate: 0.0002
34
+ batch_size: 1
35
+ num_images_per_batch: 1
36
+ num_epochs: 1250
37
+ num_warmup_epochs: 3
38
+ sigmoid: false
39
+ resample: true
40
+ resample_resolution: [0.48766356436698155, 0.4876635832539761, 2.748479210553717]
41
+ crop_mode: ratio
42
+ normalize_mode: range
43
+ intensity_bounds: [39.63595217750186, 97.59593563988095]
44
+ num_epochs_per_validation: null
45
+ num_epochs_per_saving: 1
46
+ num_workers: 4
47
+ num_steps_per_image: null
48
+ num_crops_per_image: 2
49
+ loss: {_target_: DiceCELoss, include_background: true, squared_pred: true, smooth_nr: 0,
50
+ smooth_dr: 1.0e-05, softmax: $not @sigmoid, sigmoid: $@sigmoid, to_onehot_y: $not
51
+ @sigmoid}
52
+ optimizer: {_target_: torch.optim.AdamW, lr: '@learning_rate', weight_decay: 1.0e-05}
53
+ network:
54
+ _target_: SegResNetDS
55
+ init_filters: 32
56
+ blocks_down: [1, 2, 2, 4, 4]
57
+ norm: INSTANCE_NVFUSER
58
+ in_channels: '@input_channels'
59
+ out_channels: '@output_classes'
60
+ dsdepth: 4
61
+ finetune: {enabled: false, ckpt_name: $@bundle_root + '/model/model.pt'}
62
+ validate: {enabled: false, ckpt_name: $@bundle_root + '/model/model.pt', output_path: $@bundle_root
63
+ + '/prediction_validation', save_mask: false, invert: true}
64
+ infer: {enabled: false, ckpt_name: $@bundle_root + '/model/model.pt', output_path: $@bundle_root
65
+ + '/prediction_' + @infer#data_list_key, data_list_key: testing}
66
+ anisotropic_scales: true
67
+ spacing_median: [0.48766356436698155, 0.4876635832539761, 4.770811902267695]
68
+ spacing_lower: [0.42813486948609353, 0.428134856247896, 2.499999978382533]
69
+ spacing_upper: [0.5859375, 0.5859375004856939, 5.012642938162783]
70
+ image_size_mm_median: [249.68374495589455, 249.68375462603575, 168.30083390623668]
71
+ image_size_mm_90: [265.61599121093747, 265.6159922216141, 190.12765338720757]
72
+ image_size: [544, 544, 69]
model/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc8b31e85759b2e6f77b9ec71df0c07988281f8a8ec349b349c2a31c68a3b846
3
+ size 345158862
model/training.log ADDED
@@ -0,0 +1,1818 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _meta_: {}
2
+ acc: null
3
+ amp: false
4
+ anisotropic_scales: true
5
+ auto_scale_allowed: true
6
+ auto_scale_batch: true
7
+ auto_scale_filters: false
8
+ auto_scale_roi: false
9
+ batch_size: 1
10
+ bundle_root: /Users/sakshirathi/neurotk/bundles/segresnet
11
+ cache_class_indices: null
12
+ cache_rate: null
13
+ calc_val_loss: false
14
+ channels_last: true
15
+ ckpt_path: /Users/sakshirathi/neurotk/bundles/segresnet/model
16
+ ckpt_save: true
17
+ class_index: null
18
+ class_names:
19
+ - acc_0
20
+ crop_add_background: true
21
+ crop_foreground: true
22
+ crop_mode: ratio
23
+ crop_ratios: null
24
+ cuda: false
25
+ data_file_base_dir: /Users/sakshirathi/neurotk/bundles
26
+ data_list_file_path: /Users/sakshirathi/Downloads/work_dir/dataset_local.json
27
+ debug: false
28
+ determ: false
29
+ early_stopping_fraction: 0.001
30
+ extra_modalities: {}
31
+ finetune:
32
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
33
+ enabled: false
34
+ float32_precision: null
35
+ fold: 0
36
+ fork: true
37
+ global_rank: 0
38
+ image_size:
39
+ - 544
40
+ - 544
41
+ - 69
42
+ image_size_mm_90:
43
+ - 265.61599121093747
44
+ - 265.6159922216141
45
+ - 190.12765338720757
46
+ image_size_mm_median:
47
+ - 249.68374495589455
48
+ - 249.68375462603575
49
+ - 168.30083390623668
50
+ infer:
51
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
52
+ data_list_key: testing
53
+ enabled: true
54
+ output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_testing
55
+ input_channels: 1
56
+ intensity_bounds:
57
+ - 39.63595217750186
58
+ - 97.59593563988095
59
+ learning_rate: 0.0002
60
+ log_output_file: /Users/sakshirathi/neurotk/bundles/segresnet/model/training.log
61
+ loss:
62
+ _target_: DiceCELoss
63
+ include_background: true
64
+ sigmoid: false
65
+ smooth_dr: 1.0e-05
66
+ smooth_nr: 0
67
+ softmax: true
68
+ squared_pred: true
69
+ to_onehot_y: true
70
+ max_samples_per_class: 12500
71
+ mlflow_experiment_name: Auto3DSeg
72
+ mlflow_tracking_uri: /Users/sakshirathi/neurotk/bundles/segresnet/model/mlruns/
73
+ modality: ct
74
+ network:
75
+ _target_: SegResNetDS
76
+ blocks_down:
77
+ - 1
78
+ - 2
79
+ - 2
80
+ - 4
81
+ - 4
82
+ dsdepth: 4
83
+ in_channels: 1
84
+ init_filters: 32
85
+ norm: INSTANCE_NVFUSER
86
+ out_channels: 2
87
+ normalize_mode: range
88
+ notf32: false
89
+ num_crops_per_image: 2
90
+ num_epochs: 1250
91
+ num_epochs_per_saving: 1
92
+ num_epochs_per_validation: null
93
+ num_images_per_batch: 1
94
+ num_steps_per_image: null
95
+ num_warmup_epochs: 3
96
+ num_workers: 4
97
+ optimizer:
98
+ _target_: torch.optim.AdamW
99
+ lr: 0.0002
100
+ weight_decay: 1.0e-05
101
+ orientation_ras: true
102
+ output_classes: 2
103
+ pretrained_ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
104
+ quick: false
105
+ rank: 0
106
+ resample: true
107
+ resample_resolution:
108
+ - 0.48766356436698155
109
+ - 0.4876635832539761
110
+ - 2.748479210553717
111
+ roi_size:
112
+ - 384
113
+ - 384
114
+ - 60
115
+ sigmoid: false
116
+ spacing_lower:
117
+ - 0.42813486948609353
118
+ - 0.428134856247896
119
+ - 2.499999978382533
120
+ spacing_median:
121
+ - 0.48766356436698155
122
+ - 0.4876635832539761
123
+ - 4.770811902267695
124
+ spacing_upper:
125
+ - 0.5859375
126
+ - 0.5859375004856939
127
+ - 5.012642938162783
128
+ start_epoch: 0
129
+ stop_on_lowacc: true
130
+ validate:
131
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
132
+ enabled: false
133
+ invert: true
134
+ output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_validation
135
+ save_mask: false
136
+ validate_final_original_res: true
137
+
138
+ auto_adjust_network_settings no distributed global_rank 0
139
+ GPU device memory min: 16
140
+ base_numel 7225344 gpu_factor 1 gpu_factor_init 1
141
+ input roi [224 224 144] image_size [ 512.000 512.000 61.000] numel 7225344
142
+ increasing roi step [ 257.600 257.600 61.000]
143
+ increasing roi result 1 [ 257.600 257.600 61.000]
144
+ increasing roi step [ 296.240 296.240 61.000]
145
+ increasing roi result 1 [ 296.240 296.240 61.000]
146
+ increasing roi step [ 340.676 340.676 61.000]
147
+ increasing roi result 1 [ 340.676 340.676 61.000]
148
+ increasing roi step [ 391.777 391.777 61.000]
149
+ increasing roi result 1 [ 391.777 391.777 61.000]
150
+ roi_size factored [ 384.000 384.000 60.000] factor [ 16.000 16.000 4.000] extra_levels [ 0.000 0.000 2.000]
151
+ kept filters the same base_numel 7225344, gpu_factor 1
152
+ kept batch the same base_numel 7225344, gpu_factor 1, gpu_factor_init 1
153
+ Suggested network parameters:
154
+ Batch size 1 => 1
155
+ ROI size [224, 224, 144] => [384, 384, 60]
156
+ init_filters 32 => 32
157
+ aniso: True image_size_mm: [249.68374495589455, 249.68375462603575, 168.30083390623668] spacing: [0.48766356436698155, 0.4876635832539761, 2.748479210553717] levels: 5
158
+
159
+ Using anisotropic scales {'_target_': 'SegResNetDS', 'init_filters': 32, 'blocks_down': [1, 2, 2, 4, 4], 'norm': 'INSTANCE', 'in_channels': 1, 'out_channels': 2, 'dsdepth': 4, 'resolution': [0.48766356436698155, 0.4876635832539761, 2.748479210553717]}
160
+ SegResNetDS(
161
+ (encoder): SegResEncoder(
162
+ (conv_init): Conv3d(1, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
163
+ (layers): ModuleList(
164
+ (0): ModuleDict(
165
+ (blocks): Sequential(
166
+ (0): SegResBlock(
167
+ (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
168
+ (act1): ReLU(inplace=True)
169
+ (conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
170
+ (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
171
+ (act2): ReLU(inplace=True)
172
+ (conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
173
+ )
174
+ )
175
+ (downsample): Conv3d(32, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
176
+ )
177
+ (1): ModuleDict(
178
+ (blocks): Sequential(
179
+ (0): SegResBlock(
180
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
181
+ (act1): ReLU(inplace=True)
182
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
183
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
184
+ (act2): ReLU(inplace=True)
185
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
186
+ )
187
+ (1): SegResBlock(
188
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
189
+ (act1): ReLU(inplace=True)
190
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
191
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
192
+ (act2): ReLU(inplace=True)
193
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
194
+ )
195
+ )
196
+ (downsample): Conv3d(64, 128, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
197
+ )
198
+ (2): ModuleDict(
199
+ (blocks): Sequential(
200
+ (0): SegResBlock(
201
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
202
+ (act1): ReLU(inplace=True)
203
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
204
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
205
+ (act2): ReLU(inplace=True)
206
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
207
+ )
208
+ (1): SegResBlock(
209
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
210
+ (act1): ReLU(inplace=True)
211
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
212
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
213
+ (act2): ReLU(inplace=True)
214
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
215
+ )
216
+ )
217
+ (downsample): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
218
+ )
219
+ (3): ModuleDict(
220
+ (blocks): Sequential(
221
+ (0): SegResBlock(
222
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
223
+ (act1): ReLU(inplace=True)
224
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
225
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
226
+ (act2): ReLU(inplace=True)
227
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
228
+ )
229
+ (1): SegResBlock(
230
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
231
+ (act1): ReLU(inplace=True)
232
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
233
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
234
+ (act2): ReLU(inplace=True)
235
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
236
+ )
237
+ (2): SegResBlock(
238
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
239
+ (act1): ReLU(inplace=True)
240
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
241
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
242
+ (act2): ReLU(inplace=True)
243
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
244
+ )
245
+ (3): SegResBlock(
246
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
247
+ (act1): ReLU(inplace=True)
248
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
249
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
250
+ (act2): ReLU(inplace=True)
251
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
252
+ )
253
+ )
254
+ (downsample): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
255
+ )
256
+ (4): ModuleDict(
257
+ (blocks): Sequential(
258
+ (0): SegResBlock(
259
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
260
+ (act1): ReLU(inplace=True)
261
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
262
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
263
+ (act2): ReLU(inplace=True)
264
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
265
+ )
266
+ (1): SegResBlock(
267
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
268
+ (act1): ReLU(inplace=True)
269
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
270
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
271
+ (act2): ReLU(inplace=True)
272
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
273
+ )
274
+ (2): SegResBlock(
275
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
276
+ (act1): ReLU(inplace=True)
277
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
278
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
279
+ (act2): ReLU(inplace=True)
280
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
281
+ )
282
+ (3): SegResBlock(
283
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
284
+ (act1): ReLU(inplace=True)
285
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
286
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
287
+ (act2): ReLU(inplace=True)
288
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
289
+ )
290
+ )
291
+ (downsample): Identity()
292
+ )
293
+ )
294
+ )
295
+ (up_layers): ModuleList(
296
+ (0): ModuleDict(
297
+ (upsample): UpSample(
298
+ (deconv): ConvTranspose3d(512, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
299
+ )
300
+ (blocks): Sequential(
301
+ (0): SegResBlock(
302
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
303
+ (act1): ReLU(inplace=True)
304
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
305
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
306
+ (act2): ReLU(inplace=True)
307
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
308
+ )
309
+ )
310
+ (head): Conv3d(256, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
311
+ )
312
+ (1): ModuleDict(
313
+ (upsample): UpSample(
314
+ (deconv): ConvTranspose3d(256, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
315
+ )
316
+ (blocks): Sequential(
317
+ (0): SegResBlock(
318
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
319
+ (act1): ReLU(inplace=True)
320
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
321
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
322
+ (act2): ReLU(inplace=True)
323
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
324
+ )
325
+ )
326
+ (head): Conv3d(128, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
327
+ )
328
+ (2): ModuleDict(
329
+ (upsample): UpSample(
330
+ (deconv): ConvTranspose3d(128, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
331
+ )
332
+ (blocks): Sequential(
333
+ (0): SegResBlock(
334
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
335
+ (act1): ReLU(inplace=True)
336
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
337
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
338
+ (act2): ReLU(inplace=True)
339
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
340
+ )
341
+ )
342
+ (head): Conv3d(64, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
343
+ )
344
+ (3): ModuleDict(
345
+ (upsample): UpSample(
346
+ (deconv): ConvTranspose3d(64, 32, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
347
+ )
348
+ (blocks): Sequential(
349
+ (0): SegResBlock(
350
+ (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
351
+ (act1): ReLU(inplace=True)
352
+ (conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
353
+ (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
354
+ (act2): ReLU(inplace=True)
355
+ (conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
356
+ )
357
+ )
358
+ (head): Conv3d(32, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
359
+ )
360
+ )
361
+ )
362
+ => loaded checkpoint /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt (epoch 1122) (best_metric 0.843817412853241) setting start_epoch 0
363
+ Total parameters count: 86278888 distributed: False
364
+ Inference complete, time 234.85s shape torch.Size([2, 512, 512, 40]) {'image': 'sample_data/images/TBI_INVAC184NYT.nii'}
365
+ _meta_: {}
366
+ acc: null
367
+ amp: false
368
+ anisotropic_scales: true
369
+ auto_scale_allowed: true
370
+ auto_scale_batch: true
371
+ auto_scale_filters: false
372
+ auto_scale_roi: false
373
+ batch_size: 1
374
+ bundle_root: /Users/sakshirathi/neurotk/bundles/segresnet
375
+ cache_class_indices: null
376
+ cache_rate: null
377
+ calc_val_loss: false
378
+ channels_last: true
379
+ ckpt_path: /Users/sakshirathi/neurotk/bundles/segresnet/model
380
+ ckpt_save: true
381
+ class_index: null
382
+ class_names:
383
+ - acc_0
384
+ crop_add_background: true
385
+ crop_foreground: true
386
+ crop_mode: ratio
387
+ crop_ratios: null
388
+ cuda: false
389
+ data_file_base_dir: /Users/sakshirathi/neurotk/bundles
390
+ data_list_file_path: /Users/sakshirathi/Downloads/work_dir/dataset_local.json
391
+ debug: false
392
+ determ: false
393
+ early_stopping_fraction: 0.001
394
+ extra_modalities: {}
395
+ finetune:
396
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
397
+ enabled: false
398
+ float32_precision: null
399
+ fold: 0
400
+ fork: true
401
+ global_rank: 0
402
+ image_size:
403
+ - 544
404
+ - 544
405
+ - 69
406
+ image_size_mm_90:
407
+ - 265.61599121093747
408
+ - 265.6159922216141
409
+ - 190.12765338720757
410
+ image_size_mm_median:
411
+ - 249.68374495589455
412
+ - 249.68375462603575
413
+ - 168.30083390623668
414
+ infer:
415
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
416
+ data_list_key: testing
417
+ enabled: true
418
+ output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_testing
419
+ input_channels: 1
420
+ intensity_bounds:
421
+ - 39.63595217750186
422
+ - 97.59593563988095
423
+ learning_rate: 0.0002
424
+ log_output_file: /Users/sakshirathi/neurotk/bundles/segresnet/model/training.log
425
+ loss:
426
+ _target_: DiceCELoss
427
+ include_background: true
428
+ sigmoid: false
429
+ smooth_dr: 1.0e-05
430
+ smooth_nr: 0
431
+ softmax: true
432
+ squared_pred: true
433
+ to_onehot_y: true
434
+ max_samples_per_class: 12500
435
+ mlflow_experiment_name: Auto3DSeg
436
+ mlflow_tracking_uri: /Users/sakshirathi/neurotk/bundles/segresnet/model/mlruns/
437
+ modality: ct
438
+ network:
439
+ _target_: SegResNetDS
440
+ blocks_down:
441
+ - 1
442
+ - 2
443
+ - 2
444
+ - 4
445
+ - 4
446
+ dsdepth: 4
447
+ in_channels: 1
448
+ init_filters: 32
449
+ norm: INSTANCE_NVFUSER
450
+ out_channels: 2
451
+ normalize_mode: range
452
+ notf32: false
453
+ num_crops_per_image: 2
454
+ num_epochs: 1250
455
+ num_epochs_per_saving: 1
456
+ num_epochs_per_validation: null
457
+ num_images_per_batch: 1
458
+ num_steps_per_image: null
459
+ num_warmup_epochs: 3
460
+ num_workers: 4
461
+ optimizer:
462
+ _target_: torch.optim.AdamW
463
+ lr: 0.0002
464
+ weight_decay: 1.0e-05
465
+ orientation_ras: true
466
+ output_classes: 2
467
+ pretrained_ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
468
+ quick: false
469
+ rank: 0
470
+ resample: true
471
+ resample_resolution:
472
+ - 0.48766356436698155
473
+ - 0.4876635832539761
474
+ - 2.748479210553717
475
+ roi_size:
476
+ - 384
477
+ - 384
478
+ - 60
479
+ sigmoid: false
480
+ spacing_lower:
481
+ - 0.42813486948609353
482
+ - 0.428134856247896
483
+ - 2.499999978382533
484
+ spacing_median:
485
+ - 0.48766356436698155
486
+ - 0.4876635832539761
487
+ - 4.770811902267695
488
+ spacing_upper:
489
+ - 0.5859375
490
+ - 0.5859375004856939
491
+ - 5.012642938162783
492
+ start_epoch: 0
493
+ stop_on_lowacc: true
494
+ validate:
495
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
496
+ enabled: false
497
+ invert: true
498
+ output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_validation
499
+ save_mask: false
500
+ validate_final_original_res: true
501
+
502
+ auto_adjust_network_settings no distributed global_rank 0
503
+ GPU device memory min: 16
504
+ base_numel 7225344 gpu_factor 1 gpu_factor_init 1
505
+ input roi [224 224 144] image_size [ 512.000 512.000 61.000] numel 7225344
506
+ increasing roi step [ 257.600 257.600 61.000]
507
+ increasing roi result 1 [ 257.600 257.600 61.000]
508
+ increasing roi step [ 296.240 296.240 61.000]
509
+ increasing roi result 1 [ 296.240 296.240 61.000]
510
+ increasing roi step [ 340.676 340.676 61.000]
511
+ increasing roi result 1 [ 340.676 340.676 61.000]
512
+ increasing roi step [ 391.777 391.777 61.000]
513
+ increasing roi result 1 [ 391.777 391.777 61.000]
514
+ roi_size factored [ 384.000 384.000 60.000] factor [ 16.000 16.000 4.000] extra_levels [ 0.000 0.000 2.000]
515
+ kept filters the same base_numel 7225344, gpu_factor 1
516
+ kept batch the same base_numel 7225344, gpu_factor 1, gpu_factor_init 1
517
+ Suggested network parameters:
518
+ Batch size 1 => 1
519
+ ROI size [224, 224, 144] => [384, 384, 60]
520
+ init_filters 32 => 32
521
+ aniso: True image_size_mm: [249.68374495589455, 249.68375462603575, 168.30083390623668] spacing: [0.48766356436698155, 0.4876635832539761, 2.748479210553717] levels: 5
522
+
523
+ Using anisotropic scales {'_target_': 'SegResNetDS', 'init_filters': 32, 'blocks_down': [1, 2, 2, 4, 4], 'norm': 'INSTANCE', 'in_channels': 1, 'out_channels': 2, 'dsdepth': 4, 'resolution': [0.48766356436698155, 0.4876635832539761, 2.748479210553717]}
524
+ SegResNetDS(
525
+ (encoder): SegResEncoder(
526
+ (conv_init): Conv3d(1, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
527
+ (layers): ModuleList(
528
+ (0): ModuleDict(
529
+ (blocks): Sequential(
530
+ (0): SegResBlock(
531
+ (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
532
+ (act1): ReLU(inplace=True)
533
+ (conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
534
+ (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
535
+ (act2): ReLU(inplace=True)
536
+ (conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
537
+ )
538
+ )
539
+ (downsample): Conv3d(32, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
540
+ )
541
+ (1): ModuleDict(
542
+ (blocks): Sequential(
543
+ (0): SegResBlock(
544
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
545
+ (act1): ReLU(inplace=True)
546
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
547
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
548
+ (act2): ReLU(inplace=True)
549
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
550
+ )
551
+ (1): SegResBlock(
552
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
553
+ (act1): ReLU(inplace=True)
554
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
555
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
556
+ (act2): ReLU(inplace=True)
557
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
558
+ )
559
+ )
560
+ (downsample): Conv3d(64, 128, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
561
+ )
562
+ (2): ModuleDict(
563
+ (blocks): Sequential(
564
+ (0): SegResBlock(
565
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
566
+ (act1): ReLU(inplace=True)
567
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
568
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
569
+ (act2): ReLU(inplace=True)
570
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
571
+ )
572
+ (1): SegResBlock(
573
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
574
+ (act1): ReLU(inplace=True)
575
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
576
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
577
+ (act2): ReLU(inplace=True)
578
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
579
+ )
580
+ )
581
+ (downsample): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
582
+ )
583
+ (3): ModuleDict(
584
+ (blocks): Sequential(
585
+ (0): SegResBlock(
586
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
587
+ (act1): ReLU(inplace=True)
588
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
589
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
590
+ (act2): ReLU(inplace=True)
591
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
592
+ )
593
+ (1): SegResBlock(
594
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
595
+ (act1): ReLU(inplace=True)
596
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
597
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
598
+ (act2): ReLU(inplace=True)
599
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
600
+ )
601
+ (2): SegResBlock(
602
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
603
+ (act1): ReLU(inplace=True)
604
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
605
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
606
+ (act2): ReLU(inplace=True)
607
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
608
+ )
609
+ (3): SegResBlock(
610
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
611
+ (act1): ReLU(inplace=True)
612
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
613
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
614
+ (act2): ReLU(inplace=True)
615
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
616
+ )
617
+ )
618
+ (downsample): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
619
+ )
620
+ (4): ModuleDict(
621
+ (blocks): Sequential(
622
+ (0): SegResBlock(
623
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
624
+ (act1): ReLU(inplace=True)
625
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
626
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
627
+ (act2): ReLU(inplace=True)
628
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
629
+ )
630
+ (1): SegResBlock(
631
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
632
+ (act1): ReLU(inplace=True)
633
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
634
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
635
+ (act2): ReLU(inplace=True)
636
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
637
+ )
638
+ (2): SegResBlock(
639
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
640
+ (act1): ReLU(inplace=True)
641
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
642
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
643
+ (act2): ReLU(inplace=True)
644
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
645
+ )
646
+ (3): SegResBlock(
647
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
648
+ (act1): ReLU(inplace=True)
649
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
650
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
651
+ (act2): ReLU(inplace=True)
652
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
653
+ )
654
+ )
655
+ (downsample): Identity()
656
+ )
657
+ )
658
+ )
659
+ (up_layers): ModuleList(
660
+ (0): ModuleDict(
661
+ (upsample): UpSample(
662
+ (deconv): ConvTranspose3d(512, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
663
+ )
664
+ (blocks): Sequential(
665
+ (0): SegResBlock(
666
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
667
+ (act1): ReLU(inplace=True)
668
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
669
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
670
+ (act2): ReLU(inplace=True)
671
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
672
+ )
673
+ )
674
+ (head): Conv3d(256, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
675
+ )
676
+ (1): ModuleDict(
677
+ (upsample): UpSample(
678
+ (deconv): ConvTranspose3d(256, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
679
+ )
680
+ (blocks): Sequential(
681
+ (0): SegResBlock(
682
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
683
+ (act1): ReLU(inplace=True)
684
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
685
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
686
+ (act2): ReLU(inplace=True)
687
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
688
+ )
689
+ )
690
+ (head): Conv3d(128, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
691
+ )
692
+ (2): ModuleDict(
693
+ (upsample): UpSample(
694
+ (deconv): ConvTranspose3d(128, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
695
+ )
696
+ (blocks): Sequential(
697
+ (0): SegResBlock(
698
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
699
+ (act1): ReLU(inplace=True)
700
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
701
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
702
+ (act2): ReLU(inplace=True)
703
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
704
+ )
705
+ )
706
+ (head): Conv3d(64, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
707
+ )
708
+ (3): ModuleDict(
709
+ (upsample): UpSample(
710
+ (deconv): ConvTranspose3d(64, 32, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
711
+ )
712
+ (blocks): Sequential(
713
+ (0): SegResBlock(
714
+ (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
715
+ (act1): ReLU(inplace=True)
716
+ (conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
717
+ (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
718
+ (act2): ReLU(inplace=True)
719
+ (conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
720
+ )
721
+ )
722
+ (head): Conv3d(32, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
723
+ )
724
+ )
725
+ )
726
+ => loaded checkpoint /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt (epoch 1122) (best_metric 0.843817412853241) setting start_epoch 0
727
+ Total parameters count: 86278888 distributed: False
728
+ Inference complete, time 233.93s shape torch.Size([2, 512, 512, 40]) {'image': 'sample_data/images/TBI_INVAC184NYT.nii'}
729
+ _meta_: {}
730
+ acc: null
731
+ amp: false
732
+ anisotropic_scales: true
733
+ auto_scale_allowed: true
734
+ auto_scale_batch: true
735
+ auto_scale_filters: false
736
+ auto_scale_roi: false
737
+ batch_size: 1
738
+ bundle_root: /Users/sakshirathi/neurotk/bundles/segresnet
739
+ cache_class_indices: null
740
+ cache_rate: null
741
+ calc_val_loss: false
742
+ channels_last: true
743
+ ckpt_path: /Users/sakshirathi/neurotk/bundles/segresnet/model
744
+ ckpt_save: true
745
+ class_index: null
746
+ class_names:
747
+ - acc_0
748
+ crop_add_background: true
749
+ crop_foreground: true
750
+ crop_mode: ratio
751
+ crop_ratios: null
752
+ cuda: false
753
+ data_file_base_dir: /Users/sakshirathi/neurotk/bundles
754
+ data_list_file_path: /Users/sakshirathi/Downloads/work_dir/dataset_local.json
755
+ debug: false
756
+ determ: false
757
+ early_stopping_fraction: 0.001
758
+ extra_modalities: {}
759
+ finetune:
760
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
761
+ enabled: false
762
+ float32_precision: null
763
+ fold: 0
764
+ fork: true
765
+ global_rank: 0
766
+ image_size:
767
+ - 544
768
+ - 544
769
+ - 69
770
+ image_size_mm_90:
771
+ - 265.61599121093747
772
+ - 265.6159922216141
773
+ - 190.12765338720757
774
+ image_size_mm_median:
775
+ - 249.68374495589455
776
+ - 249.68375462603575
777
+ - 168.30083390623668
778
+ infer:
779
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
780
+ data_list_key: testing
781
+ enabled: true
782
+ output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_testing
783
+ input_channels: 1
784
+ intensity_bounds:
785
+ - 39.63595217750186
786
+ - 97.59593563988095
787
+ learning_rate: 0.0002
788
+ log_output_file: /Users/sakshirathi/neurotk/bundles/segresnet/model/training.log
789
+ loss:
790
+ _target_: DiceCELoss
791
+ include_background: true
792
+ sigmoid: false
793
+ smooth_dr: 1.0e-05
794
+ smooth_nr: 0
795
+ softmax: true
796
+ squared_pred: true
797
+ to_onehot_y: true
798
+ max_samples_per_class: 12500
799
+ mlflow_experiment_name: Auto3DSeg
800
+ mlflow_tracking_uri: /Users/sakshirathi/neurotk/bundles/segresnet/model/mlruns/
801
+ modality: ct
802
+ network:
803
+ _target_: SegResNetDS
804
+ blocks_down:
805
+ - 1
806
+ - 2
807
+ - 2
808
+ - 4
809
+ - 4
810
+ dsdepth: 4
811
+ in_channels: 1
812
+ init_filters: 32
813
+ norm: INSTANCE_NVFUSER
814
+ out_channels: 2
815
+ normalize_mode: range
816
+ notf32: false
817
+ num_crops_per_image: 2
818
+ num_epochs: 1250
819
+ num_epochs_per_saving: 1
820
+ num_epochs_per_validation: null
821
+ num_images_per_batch: 1
822
+ num_steps_per_image: null
823
+ num_warmup_epochs: 3
824
+ num_workers: 4
825
+ optimizer:
826
+ _target_: torch.optim.AdamW
827
+ lr: 0.0002
828
+ weight_decay: 1.0e-05
829
+ orientation_ras: true
830
+ output_classes: 2
831
+ pretrained_ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
832
+ quick: false
833
+ rank: 0
834
+ resample: true
835
+ resample_resolution:
836
+ - 0.48766356436698155
837
+ - 0.4876635832539761
838
+ - 2.748479210553717
839
+ roi_size:
840
+ - 384
841
+ - 384
842
+ - 60
843
+ sigmoid: false
844
+ spacing_lower:
845
+ - 0.42813486948609353
846
+ - 0.428134856247896
847
+ - 2.499999978382533
848
+ spacing_median:
849
+ - 0.48766356436698155
850
+ - 0.4876635832539761
851
+ - 4.770811902267695
852
+ spacing_upper:
853
+ - 0.5859375
854
+ - 0.5859375004856939
855
+ - 5.012642938162783
856
+ start_epoch: 0
857
+ stop_on_lowacc: true
858
+ validate:
859
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
860
+ enabled: false
861
+ invert: true
862
+ output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_validation
863
+ save_mask: false
864
+ validate_final_original_res: true
865
+
866
+ auto_adjust_network_settings no distributed global_rank 0
867
+ GPU device memory min: 16
868
+ base_numel 7225344 gpu_factor 1 gpu_factor_init 1
869
+ input roi [224 224 144] image_size [ 512.000 512.000 61.000] numel 7225344
870
+ increasing roi step [ 257.600 257.600 61.000]
871
+ increasing roi result 1 [ 257.600 257.600 61.000]
872
+ increasing roi step [ 296.240 296.240 61.000]
873
+ increasing roi result 1 [ 296.240 296.240 61.000]
874
+ increasing roi step [ 340.676 340.676 61.000]
875
+ increasing roi result 1 [ 340.676 340.676 61.000]
876
+ increasing roi step [ 391.777 391.777 61.000]
877
+ increasing roi result 1 [ 391.777 391.777 61.000]
878
+ roi_size factored [ 384.000 384.000 60.000] factor [ 16.000 16.000 4.000] extra_levels [ 0.000 0.000 2.000]
879
+ kept filters the same base_numel 7225344, gpu_factor 1
880
+ kept batch the same base_numel 7225344, gpu_factor 1, gpu_factor_init 1
881
+ Suggested network parameters:
882
+ Batch size 1 => 1
883
+ ROI size [224, 224, 144] => [384, 384, 60]
884
+ init_filters 32 => 32
885
+ aniso: True image_size_mm: [249.68374495589455, 249.68375462603575, 168.30083390623668] spacing: [0.48766356436698155, 0.4876635832539761, 2.748479210553717] levels: 5
886
+
887
+ Using anisotropic scales {'_target_': 'SegResNetDS', 'init_filters': 32, 'blocks_down': [1, 2, 2, 4, 4], 'norm': 'INSTANCE', 'in_channels': 1, 'out_channels': 2, 'dsdepth': 4, 'resolution': [0.48766356436698155, 0.4876635832539761, 2.748479210553717]}
888
+ SegResNetDS(
889
+ (encoder): SegResEncoder(
890
+ (conv_init): Conv3d(1, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
891
+ (layers): ModuleList(
892
+ (0): ModuleDict(
893
+ (blocks): Sequential(
894
+ (0): SegResBlock(
895
+ (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
896
+ (act1): ReLU(inplace=True)
897
+ (conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
898
+ (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
899
+ (act2): ReLU(inplace=True)
900
+ (conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
901
+ )
902
+ )
903
+ (downsample): Conv3d(32, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
904
+ )
905
+ (1): ModuleDict(
906
+ (blocks): Sequential(
907
+ (0): SegResBlock(
908
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
909
+ (act1): ReLU(inplace=True)
910
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
911
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
912
+ (act2): ReLU(inplace=True)
913
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
914
+ )
915
+ (1): SegResBlock(
916
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
917
+ (act1): ReLU(inplace=True)
918
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
919
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
920
+ (act2): ReLU(inplace=True)
921
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
922
+ )
923
+ )
924
+ (downsample): Conv3d(64, 128, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
925
+ )
926
+ (2): ModuleDict(
927
+ (blocks): Sequential(
928
+ (0): SegResBlock(
929
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
930
+ (act1): ReLU(inplace=True)
931
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
932
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
933
+ (act2): ReLU(inplace=True)
934
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
935
+ )
936
+ (1): SegResBlock(
937
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
938
+ (act1): ReLU(inplace=True)
939
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
940
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
941
+ (act2): ReLU(inplace=True)
942
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
943
+ )
944
+ )
945
+ (downsample): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
946
+ )
947
+ (3): ModuleDict(
948
+ (blocks): Sequential(
949
+ (0): SegResBlock(
950
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
951
+ (act1): ReLU(inplace=True)
952
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
953
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
954
+ (act2): ReLU(inplace=True)
955
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
956
+ )
957
+ (1): SegResBlock(
958
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
959
+ (act1): ReLU(inplace=True)
960
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
961
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
962
+ (act2): ReLU(inplace=True)
963
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
964
+ )
965
+ (2): SegResBlock(
966
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
967
+ (act1): ReLU(inplace=True)
968
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
969
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
970
+ (act2): ReLU(inplace=True)
971
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
972
+ )
973
+ (3): SegResBlock(
974
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
975
+ (act1): ReLU(inplace=True)
976
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
977
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
978
+ (act2): ReLU(inplace=True)
979
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
980
+ )
981
+ )
982
+ (downsample): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
983
+ )
984
+ (4): ModuleDict(
985
+ (blocks): Sequential(
986
+ (0): SegResBlock(
987
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
988
+ (act1): ReLU(inplace=True)
989
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
990
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
991
+ (act2): ReLU(inplace=True)
992
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
993
+ )
994
+ (1): SegResBlock(
995
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
996
+ (act1): ReLU(inplace=True)
997
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
998
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
999
+ (act2): ReLU(inplace=True)
1000
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1001
+ )
1002
+ (2): SegResBlock(
1003
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1004
+ (act1): ReLU(inplace=True)
1005
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1006
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1007
+ (act2): ReLU(inplace=True)
1008
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1009
+ )
1010
+ (3): SegResBlock(
1011
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1012
+ (act1): ReLU(inplace=True)
1013
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1014
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1015
+ (act2): ReLU(inplace=True)
1016
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1017
+ )
1018
+ )
1019
+ (downsample): Identity()
1020
+ )
1021
+ )
1022
+ )
1023
+ (up_layers): ModuleList(
1024
+ (0): ModuleDict(
1025
+ (upsample): UpSample(
1026
+ (deconv): ConvTranspose3d(512, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
1027
+ )
1028
+ (blocks): Sequential(
1029
+ (0): SegResBlock(
1030
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1031
+ (act1): ReLU(inplace=True)
1032
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1033
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1034
+ (act2): ReLU(inplace=True)
1035
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1036
+ )
1037
+ )
1038
+ (head): Conv3d(256, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
1039
+ )
1040
+ (1): ModuleDict(
1041
+ (upsample): UpSample(
1042
+ (deconv): ConvTranspose3d(256, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
1043
+ )
1044
+ (blocks): Sequential(
1045
+ (0): SegResBlock(
1046
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1047
+ (act1): ReLU(inplace=True)
1048
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1049
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1050
+ (act2): ReLU(inplace=True)
1051
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1052
+ )
1053
+ )
1054
+ (head): Conv3d(128, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
1055
+ )
1056
+ (2): ModuleDict(
1057
+ (upsample): UpSample(
1058
+ (deconv): ConvTranspose3d(128, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
1059
+ )
1060
+ (blocks): Sequential(
1061
+ (0): SegResBlock(
1062
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1063
+ (act1): ReLU(inplace=True)
1064
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1065
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1066
+ (act2): ReLU(inplace=True)
1067
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1068
+ )
1069
+ )
1070
+ (head): Conv3d(64, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
1071
+ )
1072
+ (3): ModuleDict(
1073
+ (upsample): UpSample(
1074
+ (deconv): ConvTranspose3d(64, 32, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
1075
+ )
1076
+ (blocks): Sequential(
1077
+ (0): SegResBlock(
1078
+ (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1079
+ (act1): ReLU(inplace=True)
1080
+ (conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1081
+ (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1082
+ (act2): ReLU(inplace=True)
1083
+ (conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1084
+ )
1085
+ )
1086
+ (head): Conv3d(32, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
1087
+ )
1088
+ )
1089
+ )
1090
+ => loaded checkpoint /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt (epoch 1122) (best_metric 0.843817412853241) setting start_epoch 0
1091
+ Total parameters count: 86278888 distributed: False
1092
+ Inference complete, time 226.94s shape torch.Size([2, 512, 512, 40]) {'image': 'sample_data/images/TBI_INVAC184NYT.nii'}
1093
+ _meta_: {}
1094
+ acc: null
1095
+ amp: false
1096
+ anisotropic_scales: true
1097
+ auto_scale_allowed: true
1098
+ auto_scale_batch: true
1099
+ auto_scale_filters: false
1100
+ auto_scale_roi: false
1101
+ batch_size: 1
1102
+ bundle_root: /Users/sakshirathi/neurotk/bundles/segresnet
1103
+ cache_class_indices: null
1104
+ cache_rate: null
1105
+ calc_val_loss: false
1106
+ channels_last: true
1107
+ ckpt_path: /Users/sakshirathi/neurotk/bundles/segresnet/model
1108
+ ckpt_save: true
1109
+ class_index: null
1110
+ class_names:
1111
+ - acc_0
1112
+ crop_add_background: true
1113
+ crop_foreground: true
1114
+ crop_mode: ratio
1115
+ crop_ratios: null
1116
+ cuda: false
1117
+ data_file_base_dir: /Users/sakshirathi/neurotk/bundles
1118
+ data_list_file_path: /Users/sakshirathi/Downloads/work_dir/dataset_local.json
1119
+ debug: false
1120
+ determ: false
1121
+ early_stopping_fraction: 0.001
1122
+ extra_modalities: {}
1123
+ finetune:
1124
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
1125
+ enabled: false
1126
+ float32_precision: null
1127
+ fold: 0
1128
+ fork: true
1129
+ global_rank: 0
1130
+ image_size:
1131
+ - 544
1132
+ - 544
1133
+ - 69
1134
+ image_size_mm_90:
1135
+ - 265.61599121093747
1136
+ - 265.6159922216141
1137
+ - 190.12765338720757
1138
+ image_size_mm_median:
1139
+ - 249.68374495589455
1140
+ - 249.68375462603575
1141
+ - 168.30083390623668
1142
+ infer:
1143
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
1144
+ data_list_key: testing
1145
+ enabled: true
1146
+ output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_testing
1147
+ input_channels: 1
1148
+ intensity_bounds:
1149
+ - 39.63595217750186
1150
+ - 97.59593563988095
1151
+ learning_rate: 0.0002
1152
+ log_output_file: /Users/sakshirathi/neurotk/bundles/segresnet/model/training.log
1153
+ loss:
1154
+ _target_: DiceCELoss
1155
+ include_background: true
1156
+ sigmoid: false
1157
+ smooth_dr: 1.0e-05
1158
+ smooth_nr: 0
1159
+ softmax: true
1160
+ squared_pred: true
1161
+ to_onehot_y: true
1162
+ max_samples_per_class: 12500
1163
+ mlflow_experiment_name: Auto3DSeg
1164
+ mlflow_tracking_uri: /Users/sakshirathi/neurotk/bundles/segresnet/model/mlruns/
1165
+ modality: ct
1166
+ network:
1167
+ _target_: SegResNetDS
1168
+ blocks_down:
1169
+ - 1
1170
+ - 2
1171
+ - 2
1172
+ - 4
1173
+ - 4
1174
+ dsdepth: 4
1175
+ in_channels: 1
1176
+ init_filters: 32
1177
+ norm: INSTANCE_NVFUSER
1178
+ out_channels: 2
1179
+ normalize_mode: range
1180
+ notf32: false
1181
+ num_crops_per_image: 2
1182
+ num_epochs: 1250
1183
+ num_epochs_per_saving: 1
1184
+ num_epochs_per_validation: null
1185
+ num_images_per_batch: 1
1186
+ num_steps_per_image: null
1187
+ num_warmup_epochs: 3
1188
+ num_workers: 4
1189
+ optimizer:
1190
+ _target_: torch.optim.AdamW
1191
+ lr: 0.0002
1192
+ weight_decay: 1.0e-05
1193
+ orientation_ras: true
1194
+ output_classes: 2
1195
+ pretrained_ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
1196
+ quick: false
1197
+ rank: 0
1198
+ resample: true
1199
+ resample_resolution:
1200
+ - 0.48766356436698155
1201
+ - 0.4876635832539761
1202
+ - 2.748479210553717
1203
+ roi_size:
1204
+ - 384
1205
+ - 384
1206
+ - 60
1207
+ sigmoid: false
1208
+ spacing_lower:
1209
+ - 0.42813486948609353
1210
+ - 0.428134856247896
1211
+ - 2.499999978382533
1212
+ spacing_median:
1213
+ - 0.48766356436698155
1214
+ - 0.4876635832539761
1215
+ - 4.770811902267695
1216
+ spacing_upper:
1217
+ - 0.5859375
1218
+ - 0.5859375004856939
1219
+ - 5.012642938162783
1220
+ start_epoch: 0
1221
+ stop_on_lowacc: true
1222
+ validate:
1223
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
1224
+ enabled: false
1225
+ invert: true
1226
+ output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_validation
1227
+ save_mask: false
1228
+ validate_final_original_res: true
1229
+
1230
+ auto_adjust_network_settings no distributed global_rank 0
1231
+ GPU device memory min: 16
1232
+ base_numel 7225344 gpu_factor 1 gpu_factor_init 1
1233
+ input roi [224 224 144] image_size [ 512.000 512.000 61.000] numel 7225344
1234
+ increasing roi step [ 257.600 257.600 61.000]
1235
+ increasing roi result 1 [ 257.600 257.600 61.000]
1236
+ increasing roi step [ 296.240 296.240 61.000]
1237
+ increasing roi result 1 [ 296.240 296.240 61.000]
1238
+ increasing roi step [ 340.676 340.676 61.000]
1239
+ increasing roi result 1 [ 340.676 340.676 61.000]
1240
+ increasing roi step [ 391.777 391.777 61.000]
1241
+ increasing roi result 1 [ 391.777 391.777 61.000]
1242
+ roi_size factored [ 384.000 384.000 60.000] factor [ 16.000 16.000 4.000] extra_levels [ 0.000 0.000 2.000]
1243
+ kept filters the same base_numel 7225344, gpu_factor 1
1244
+ kept batch the same base_numel 7225344, gpu_factor 1, gpu_factor_init 1
1245
+ Suggested network parameters:
1246
+ Batch size 1 => 1
1247
+ ROI size [224, 224, 144] => [384, 384, 60]
1248
+ init_filters 32 => 32
1249
+ aniso: True image_size_mm: [249.68374495589455, 249.68375462603575, 168.30083390623668] spacing: [0.48766356436698155, 0.4876635832539761, 2.748479210553717] levels: 5
1250
+
1251
+ Using anisotropic scales {'_target_': 'SegResNetDS', 'init_filters': 32, 'blocks_down': [1, 2, 2, 4, 4], 'norm': 'INSTANCE', 'in_channels': 1, 'out_channels': 2, 'dsdepth': 4, 'resolution': [0.48766356436698155, 0.4876635832539761, 2.748479210553717]}
1252
+ SegResNetDS(
1253
+ (encoder): SegResEncoder(
1254
+ (conv_init): Conv3d(1, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1255
+ (layers): ModuleList(
1256
+ (0): ModuleDict(
1257
+ (blocks): Sequential(
1258
+ (0): SegResBlock(
1259
+ (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1260
+ (act1): ReLU(inplace=True)
1261
+ (conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1262
+ (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1263
+ (act2): ReLU(inplace=True)
1264
+ (conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1265
+ )
1266
+ )
1267
+ (downsample): Conv3d(32, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
1268
+ )
1269
+ (1): ModuleDict(
1270
+ (blocks): Sequential(
1271
+ (0): SegResBlock(
1272
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1273
+ (act1): ReLU(inplace=True)
1274
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1275
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1276
+ (act2): ReLU(inplace=True)
1277
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1278
+ )
1279
+ (1): SegResBlock(
1280
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1281
+ (act1): ReLU(inplace=True)
1282
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1283
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1284
+ (act2): ReLU(inplace=True)
1285
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1286
+ )
1287
+ )
1288
+ (downsample): Conv3d(64, 128, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
1289
+ )
1290
+ (2): ModuleDict(
1291
+ (blocks): Sequential(
1292
+ (0): SegResBlock(
1293
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1294
+ (act1): ReLU(inplace=True)
1295
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1296
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1297
+ (act2): ReLU(inplace=True)
1298
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1299
+ )
1300
+ (1): SegResBlock(
1301
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1302
+ (act1): ReLU(inplace=True)
1303
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1304
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1305
+ (act2): ReLU(inplace=True)
1306
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1307
+ )
1308
+ )
1309
+ (downsample): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
1310
+ )
1311
+ (3): ModuleDict(
1312
+ (blocks): Sequential(
1313
+ (0): SegResBlock(
1314
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1315
+ (act1): ReLU(inplace=True)
1316
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1317
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1318
+ (act2): ReLU(inplace=True)
1319
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1320
+ )
1321
+ (1): SegResBlock(
1322
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1323
+ (act1): ReLU(inplace=True)
1324
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1325
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1326
+ (act2): ReLU(inplace=True)
1327
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1328
+ )
1329
+ (2): SegResBlock(
1330
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1331
+ (act1): ReLU(inplace=True)
1332
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1333
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1334
+ (act2): ReLU(inplace=True)
1335
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1336
+ )
1337
+ (3): SegResBlock(
1338
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1339
+ (act1): ReLU(inplace=True)
1340
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1341
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1342
+ (act2): ReLU(inplace=True)
1343
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1344
+ )
1345
+ )
1346
+ (downsample): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
1347
+ )
1348
+ (4): ModuleDict(
1349
+ (blocks): Sequential(
1350
+ (0): SegResBlock(
1351
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1352
+ (act1): ReLU(inplace=True)
1353
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1354
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1355
+ (act2): ReLU(inplace=True)
1356
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1357
+ )
1358
+ (1): SegResBlock(
1359
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1360
+ (act1): ReLU(inplace=True)
1361
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1362
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1363
+ (act2): ReLU(inplace=True)
1364
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1365
+ )
1366
+ (2): SegResBlock(
1367
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1368
+ (act1): ReLU(inplace=True)
1369
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1370
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1371
+ (act2): ReLU(inplace=True)
1372
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1373
+ )
1374
+ (3): SegResBlock(
1375
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1376
+ (act1): ReLU(inplace=True)
1377
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1378
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1379
+ (act2): ReLU(inplace=True)
1380
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1381
+ )
1382
+ )
1383
+ (downsample): Identity()
1384
+ )
1385
+ )
1386
+ )
1387
+ (up_layers): ModuleList(
1388
+ (0): ModuleDict(
1389
+ (upsample): UpSample(
1390
+ (deconv): ConvTranspose3d(512, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
1391
+ )
1392
+ (blocks): Sequential(
1393
+ (0): SegResBlock(
1394
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1395
+ (act1): ReLU(inplace=True)
1396
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1397
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1398
+ (act2): ReLU(inplace=True)
1399
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1400
+ )
1401
+ )
1402
+ (head): Conv3d(256, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
1403
+ )
1404
+ (1): ModuleDict(
1405
+ (upsample): UpSample(
1406
+ (deconv): ConvTranspose3d(256, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
1407
+ )
1408
+ (blocks): Sequential(
1409
+ (0): SegResBlock(
1410
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1411
+ (act1): ReLU(inplace=True)
1412
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1413
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1414
+ (act2): ReLU(inplace=True)
1415
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1416
+ )
1417
+ )
1418
+ (head): Conv3d(128, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
1419
+ )
1420
+ (2): ModuleDict(
1421
+ (upsample): UpSample(
1422
+ (deconv): ConvTranspose3d(128, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
1423
+ )
1424
+ (blocks): Sequential(
1425
+ (0): SegResBlock(
1426
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1427
+ (act1): ReLU(inplace=True)
1428
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1429
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1430
+ (act2): ReLU(inplace=True)
1431
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1432
+ )
1433
+ )
1434
+ (head): Conv3d(64, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
1435
+ )
1436
+ (3): ModuleDict(
1437
+ (upsample): UpSample(
1438
+ (deconv): ConvTranspose3d(64, 32, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
1439
+ )
1440
+ (blocks): Sequential(
1441
+ (0): SegResBlock(
1442
+ (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1443
+ (act1): ReLU(inplace=True)
1444
+ (conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1445
+ (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1446
+ (act2): ReLU(inplace=True)
1447
+ (conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1448
+ )
1449
+ )
1450
+ (head): Conv3d(32, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
1451
+ )
1452
+ )
1453
+ )
1454
+ => loaded checkpoint /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt (epoch 1122) (best_metric 0.843817412853241) setting start_epoch 0
1455
+ Total parameters count: 86278888 distributed: False
1456
+ _meta_: {}
1457
+ acc: null
1458
+ amp: false
1459
+ anisotropic_scales: true
1460
+ auto_scale_allowed: true
1461
+ auto_scale_batch: true
1462
+ auto_scale_filters: false
1463
+ auto_scale_roi: false
1464
+ batch_size: 1
1465
+ bundle_root: /Users/sakshirathi/neurotk/bundles/segresnet
1466
+ cache_class_indices: null
1467
+ cache_rate: null
1468
+ calc_val_loss: false
1469
+ channels_last: true
1470
+ ckpt_path: /Users/sakshirathi/neurotk/bundles/segresnet/model
1471
+ ckpt_save: true
1472
+ class_index: null
1473
+ class_names:
1474
+ - acc_0
1475
+ crop_add_background: true
1476
+ crop_foreground: true
1477
+ crop_mode: ratio
1478
+ crop_ratios: null
1479
+ cuda: false
1480
+ data_file_base_dir: /Users/sakshirathi/neurotk/bundles
1481
+ data_list_file_path: /Users/sakshirathi/Downloads/work_dir/dataset_local.json
1482
+ debug: false
1483
+ determ: false
1484
+ early_stopping_fraction: 0.001
1485
+ extra_modalities: {}
1486
+ finetune:
1487
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
1488
+ enabled: false
1489
+ float32_precision: null
1490
+ fold: 0
1491
+ fork: true
1492
+ global_rank: 0
1493
+ image_size:
1494
+ - 544
1495
+ - 544
1496
+ - 69
1497
+ image_size_mm_90:
1498
+ - 265.61599121093747
1499
+ - 265.6159922216141
1500
+ - 190.12765338720757
1501
+ image_size_mm_median:
1502
+ - 249.68374495589455
1503
+ - 249.68375462603575
1504
+ - 168.30083390623668
1505
+ infer:
1506
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
1507
+ data_list_key: testing
1508
+ enabled: true
1509
+ output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_testing
1510
+ input_channels: 1
1511
+ intensity_bounds:
1512
+ - 39.63595217750186
1513
+ - 97.59593563988095
1514
+ learning_rate: 0.0002
1515
+ log_output_file: /Users/sakshirathi/neurotk/bundles/segresnet/model/training.log
1516
+ loss:
1517
+ _target_: DiceCELoss
1518
+ include_background: true
1519
+ sigmoid: false
1520
+ smooth_dr: 1.0e-05
1521
+ smooth_nr: 0
1522
+ softmax: true
1523
+ squared_pred: true
1524
+ to_onehot_y: true
1525
+ max_samples_per_class: 12500
1526
+ mlflow_experiment_name: Auto3DSeg
1527
+ mlflow_tracking_uri: /Users/sakshirathi/neurotk/bundles/segresnet/model/mlruns/
1528
+ modality: ct
1529
+ network:
1530
+ _target_: SegResNetDS
1531
+ blocks_down:
1532
+ - 1
1533
+ - 2
1534
+ - 2
1535
+ - 4
1536
+ - 4
1537
+ dsdepth: 4
1538
+ in_channels: 1
1539
+ init_filters: 32
1540
+ norm: INSTANCE_NVFUSER
1541
+ out_channels: 2
1542
+ normalize_mode: range
1543
+ notf32: false
1544
+ num_crops_per_image: 2
1545
+ num_epochs: 1250
1546
+ num_epochs_per_saving: 1
1547
+ num_epochs_per_validation: null
1548
+ num_images_per_batch: 1
1549
+ num_steps_per_image: null
1550
+ num_warmup_epochs: 3
1551
+ num_workers: 4
1552
+ optimizer:
1553
+ _target_: torch.optim.AdamW
1554
+ lr: 0.0002
1555
+ weight_decay: 1.0e-05
1556
+ orientation_ras: true
1557
+ output_classes: 2
1558
+ pretrained_ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
1559
+ quick: false
1560
+ rank: 0
1561
+ resample: true
1562
+ resample_resolution:
1563
+ - 0.48766356436698155
1564
+ - 0.4876635832539761
1565
+ - 2.748479210553717
1566
+ roi_size:
1567
+ - 384
1568
+ - 384
1569
+ - 60
1570
+ sigmoid: false
1571
+ spacing_lower:
1572
+ - 0.42813486948609353
1573
+ - 0.428134856247896
1574
+ - 2.499999978382533
1575
+ spacing_median:
1576
+ - 0.48766356436698155
1577
+ - 0.4876635832539761
1578
+ - 4.770811902267695
1579
+ spacing_upper:
1580
+ - 0.5859375
1581
+ - 0.5859375004856939
1582
+ - 5.012642938162783
1583
+ start_epoch: 0
1584
+ stop_on_lowacc: true
1585
+ validate:
1586
+ ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
1587
+ enabled: false
1588
+ invert: true
1589
+ output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_validation
1590
+ save_mask: false
1591
+ validate_final_original_res: true
1592
+
1593
+ auto_adjust_network_settings no distributed global_rank 0
1594
+ GPU device memory min: 16
1595
+ base_numel 7225344 gpu_factor 1 gpu_factor_init 1
1596
+ input roi [224 224 144] image_size [ 512.000 512.000 61.000] numel 7225344
1597
+ increasing roi step [ 257.600 257.600 61.000]
1598
+ increasing roi result 1 [ 257.600 257.600 61.000]
1599
+ increasing roi step [ 296.240 296.240 61.000]
1600
+ increasing roi result 1 [ 296.240 296.240 61.000]
1601
+ increasing roi step [ 340.676 340.676 61.000]
1602
+ increasing roi result 1 [ 340.676 340.676 61.000]
1603
+ increasing roi step [ 391.777 391.777 61.000]
1604
+ increasing roi result 1 [ 391.777 391.777 61.000]
1605
+ roi_size factored [ 384.000 384.000 60.000] factor [ 16.000 16.000 4.000] extra_levels [ 0.000 0.000 2.000]
1606
+ kept filters the same base_numel 7225344, gpu_factor 1
1607
+ kept batch the same base_numel 7225344, gpu_factor 1, gpu_factor_init 1
1608
+ Suggested network parameters:
1609
+ Batch size 1 => 1
1610
+ ROI size [224, 224, 144] => [384, 384, 60]
1611
+ init_filters 32 => 32
1612
+ aniso: True image_size_mm: [249.68374495589455, 249.68375462603575, 168.30083390623668] spacing: [0.48766356436698155, 0.4876635832539761, 2.748479210553717] levels: 5
1613
+
1614
+ Using anisotropic scales {'_target_': 'SegResNetDS', 'init_filters': 32, 'blocks_down': [1, 2, 2, 4, 4], 'norm': 'INSTANCE', 'in_channels': 1, 'out_channels': 2, 'dsdepth': 4, 'resolution': [0.48766356436698155, 0.4876635832539761, 2.748479210553717]}
1615
+ SegResNetDS(
1616
+ (encoder): SegResEncoder(
1617
+ (conv_init): Conv3d(1, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1618
+ (layers): ModuleList(
1619
+ (0): ModuleDict(
1620
+ (blocks): Sequential(
1621
+ (0): SegResBlock(
1622
+ (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1623
+ (act1): ReLU(inplace=True)
1624
+ (conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1625
+ (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1626
+ (act2): ReLU(inplace=True)
1627
+ (conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1628
+ )
1629
+ )
1630
+ (downsample): Conv3d(32, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
1631
+ )
1632
+ (1): ModuleDict(
1633
+ (blocks): Sequential(
1634
+ (0): SegResBlock(
1635
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1636
+ (act1): ReLU(inplace=True)
1637
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1638
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1639
+ (act2): ReLU(inplace=True)
1640
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1641
+ )
1642
+ (1): SegResBlock(
1643
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1644
+ (act1): ReLU(inplace=True)
1645
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1646
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1647
+ (act2): ReLU(inplace=True)
1648
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1649
+ )
1650
+ )
1651
+ (downsample): Conv3d(64, 128, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
1652
+ )
1653
+ (2): ModuleDict(
1654
+ (blocks): Sequential(
1655
+ (0): SegResBlock(
1656
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1657
+ (act1): ReLU(inplace=True)
1658
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1659
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1660
+ (act2): ReLU(inplace=True)
1661
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1662
+ )
1663
+ (1): SegResBlock(
1664
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1665
+ (act1): ReLU(inplace=True)
1666
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1667
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1668
+ (act2): ReLU(inplace=True)
1669
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1670
+ )
1671
+ )
1672
+ (downsample): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
1673
+ )
1674
+ (3): ModuleDict(
1675
+ (blocks): Sequential(
1676
+ (0): SegResBlock(
1677
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1678
+ (act1): ReLU(inplace=True)
1679
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1680
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1681
+ (act2): ReLU(inplace=True)
1682
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1683
+ )
1684
+ (1): SegResBlock(
1685
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1686
+ (act1): ReLU(inplace=True)
1687
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1688
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1689
+ (act2): ReLU(inplace=True)
1690
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1691
+ )
1692
+ (2): SegResBlock(
1693
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1694
+ (act1): ReLU(inplace=True)
1695
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1696
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1697
+ (act2): ReLU(inplace=True)
1698
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1699
+ )
1700
+ (3): SegResBlock(
1701
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1702
+ (act1): ReLU(inplace=True)
1703
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1704
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1705
+ (act2): ReLU(inplace=True)
1706
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1707
+ )
1708
+ )
1709
+ (downsample): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
1710
+ )
1711
+ (4): ModuleDict(
1712
+ (blocks): Sequential(
1713
+ (0): SegResBlock(
1714
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1715
+ (act1): ReLU(inplace=True)
1716
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1717
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1718
+ (act2): ReLU(inplace=True)
1719
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1720
+ )
1721
+ (1): SegResBlock(
1722
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1723
+ (act1): ReLU(inplace=True)
1724
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1725
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1726
+ (act2): ReLU(inplace=True)
1727
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1728
+ )
1729
+ (2): SegResBlock(
1730
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1731
+ (act1): ReLU(inplace=True)
1732
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1733
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1734
+ (act2): ReLU(inplace=True)
1735
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1736
+ )
1737
+ (3): SegResBlock(
1738
+ (norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1739
+ (act1): ReLU(inplace=True)
1740
+ (conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1741
+ (norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1742
+ (act2): ReLU(inplace=True)
1743
+ (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1744
+ )
1745
+ )
1746
+ (downsample): Identity()
1747
+ )
1748
+ )
1749
+ )
1750
+ (up_layers): ModuleList(
1751
+ (0): ModuleDict(
1752
+ (upsample): UpSample(
1753
+ (deconv): ConvTranspose3d(512, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
1754
+ )
1755
+ (blocks): Sequential(
1756
+ (0): SegResBlock(
1757
+ (norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1758
+ (act1): ReLU(inplace=True)
1759
+ (conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1760
+ (norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1761
+ (act2): ReLU(inplace=True)
1762
+ (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1763
+ )
1764
+ )
1765
+ (head): Conv3d(256, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
1766
+ )
1767
+ (1): ModuleDict(
1768
+ (upsample): UpSample(
1769
+ (deconv): ConvTranspose3d(256, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
1770
+ )
1771
+ (blocks): Sequential(
1772
+ (0): SegResBlock(
1773
+ (norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1774
+ (act1): ReLU(inplace=True)
1775
+ (conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1776
+ (norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1777
+ (act2): ReLU(inplace=True)
1778
+ (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
1779
+ )
1780
+ )
1781
+ (head): Conv3d(128, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
1782
+ )
1783
+ (2): ModuleDict(
1784
+ (upsample): UpSample(
1785
+ (deconv): ConvTranspose3d(128, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
1786
+ )
1787
+ (blocks): Sequential(
1788
+ (0): SegResBlock(
1789
+ (norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1790
+ (act1): ReLU(inplace=True)
1791
+ (conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1792
+ (norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1793
+ (act2): ReLU(inplace=True)
1794
+ (conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1795
+ )
1796
+ )
1797
+ (head): Conv3d(64, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
1798
+ )
1799
+ (3): ModuleDict(
1800
+ (upsample): UpSample(
1801
+ (deconv): ConvTranspose3d(64, 32, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
1802
+ )
1803
+ (blocks): Sequential(
1804
+ (0): SegResBlock(
1805
+ (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1806
+ (act1): ReLU(inplace=True)
1807
+ (conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1808
+ (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
1809
+ (act2): ReLU(inplace=True)
1810
+ (conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
1811
+ )
1812
+ )
1813
+ (head): Conv3d(32, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
1814
+ )
1815
+ )
1816
+ )
1817
+ => loaded checkpoint /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt (epoch 1122) (best_metric 0.843817412853241) setting start_epoch 0
1818
+ Total parameters count: 86278888 distributed: False
scripts/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
scripts/__pycache__/segmenter.cpython-310.pyc ADDED
Binary file (53.7 kB). View file
 
scripts/__pycache__/utils.cpython-310.pyc ADDED
Binary file (5.08 kB). View file
 
scripts/segmenter.py ADDED
@@ -0,0 +1,2212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import copy
15
+ import csv
16
+ import gc
17
+ import logging
18
+ import multiprocessing as mp
19
+ import os
20
+ import shutil
21
+ import sys
22
+ import time
23
+ import warnings
24
+ from datetime import datetime, timedelta
25
+ from pathlib import Path
26
+ from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
27
+
28
+ import numpy as np
29
+ import psutil
30
+ import torch
31
+ import torch.distributed as dist
32
+ import torch.multiprocessing as mp
33
+ import yaml
34
+ from torch.amp import GradScaler, autocast
35
+ from torch.nn.parallel import DistributedDataParallel
36
+ from torch.utils.data.distributed import DistributedSampler
37
+ from torch.utils.tensorboard import SummaryWriter
38
+
39
+ from monai.apps.auto3dseg.auto_runner import logger
40
+ from monai.apps.auto3dseg.transforms import EnsureSameShaped
41
+ from monai.auto3dseg.utils import datafold_read
42
+ from monai.bundle.config_parser import ConfigParser
43
+ from monai.config import KeysCollection
44
+ from monai.data import CacheDataset, DataLoader, Dataset, DistributedSampler, decollate_batch, list_data_collate
45
+ from monai.inferers import SlidingWindowInfererAdapt
46
+ from monai.losses import DeepSupervisionLoss
47
+ from monai.metrics import CumulativeAverage, DiceHelper
48
+ from monai.networks.layers.factories import split_args
49
+ from monai.optimizers.lr_scheduler import WarmupCosineSchedule
50
+ from monai.transforms import (
51
+ AsDiscreted,
52
+ CastToTyped,
53
+ ClassesToIndicesd,
54
+ Compose,
55
+ ConcatItemsd,
56
+ CopyItemsd,
57
+ CropForegroundd,
58
+ DataStatsd,
59
+ DeleteItemsd,
60
+ EnsureTyped,
61
+ Identityd,
62
+ Invertd,
63
+ Lambdad,
64
+ LoadImaged,
65
+ NormalizeIntensityd,
66
+ Orientationd,
67
+ RandAdjustContrastd,
68
+ RandAffined,
69
+ RandCropByLabelClassesd,
70
+ RandFlipd,
71
+ RandGaussianNoised,
72
+ RandGaussianSmoothd,
73
+ RandHistogramShiftd,
74
+ RandIdentity,
75
+ RandRotate90d,
76
+ RandScaleIntensityd,
77
+ RandScaleIntensityFixedMeand,
78
+ RandShiftIntensityd,
79
+ RandSpatialCropd,
80
+ ResampleToMatchd,
81
+ SaveImaged,
82
+ ScaleIntensityRanged,
83
+ Spacingd,
84
+ SpatialPadd,
85
+ ToDeviced,
86
+ )
87
+ from monai.transforms.transform import MapTransform
88
+ from monai.utils import ImageMetaKey, convert_to_dst_type, optional_import, set_determinism
89
+
90
+ mlflow, mlflow_is_imported = optional_import("mlflow")
91
+
92
+
93
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:2048"
94
+ print = logger.debug
95
+ tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
96
+
97
+ if __package__ in (None, ""):
98
+ from utils import auto_adjust_network_settings, logger_configure
99
+ else:
100
+ from .utils import auto_adjust_network_settings, logger_configure
101
+
102
+
103
+ class LabelEmbedClassIndex(MapTransform):
104
+ """
105
+ Label embedding according to class_index
106
+ """
107
+
108
+ def __init__(
109
+ self, keys: KeysCollection = "label", allow_missing_keys: bool = False, class_index: Optional[List] = None
110
+ ) -> None:
111
+ """
112
+ Args:
113
+ keys: keys of the corresponding items to be compared to the source_key item shape.
114
+ allow_missing_keys: do not raise exception if key is missing.
115
+ class_index: a list of class indices
116
+ """
117
+ super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
118
+ self.class_index = class_index
119
+
120
+ def label_mapping(self, x: torch.Tensor) -> torch.Tensor:
121
+ dtype = x.dtype
122
+ return torch.cat([sum([x == i for i in c]) for c in self.class_index], dim=0).to(dtype=dtype)
123
+
124
+ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
125
+ d = dict(data)
126
+ if self.class_index is not None:
127
+ for key in self.key_iterator(d):
128
+ d[key] = self.label_mapping(d[key])
129
+ return d
130
+
131
+
132
+ def schedule_validation_epochs(num_epochs, num_epochs_per_validation=None, fraction=0.16) -> list:
133
+ """
134
+ Schedule of epochs to validate (progressively more frequently)
135
+ num_epochs - total number of epochs
136
+ num_epochs_per_validation - if provided use a linear schedule with this step
137
+ init_step
138
+ """
139
+
140
+ if num_epochs_per_validation is None:
141
+ x = (np.sin(np.linspace(0, np.pi / 2, max(10, int(fraction * num_epochs)))) * num_epochs).astype(int)
142
+ x = np.cumsum(np.sort(np.diff(np.unique(x)))[::-1])
143
+ x[-1] = num_epochs
144
+ x = x.tolist()
145
+ else:
146
+ if num_epochs_per_validation >= num_epochs:
147
+ x = [num_epochs_per_validation]
148
+ else:
149
+ x = list(range(num_epochs_per_validation, num_epochs, num_epochs_per_validation))
150
+
151
+ if len(x) == 0:
152
+ x = [0]
153
+
154
+ return x
155
+
156
+
157
+ class DataTransformBuilder:
158
+ def __init__(
159
+ self,
160
+ roi_size: list,
161
+ image_key: str = "image",
162
+ label_key: str = "label",
163
+ resample: bool = False,
164
+ resample_resolution: Optional[list] = None,
165
+ normalize_mode: str = "meanstd",
166
+ normalize_params: Optional[dict] = None,
167
+ crop_mode: str = "ratio",
168
+ crop_params: Optional[dict] = None,
169
+ extra_modalities: Optional[dict] = None,
170
+ custom_transforms=None,
171
+ augment_params: Optional[dict] = None,
172
+ debug: bool = False,
173
+ rank: int = 0,
174
+ class_index=None,
175
+ **kwargs,
176
+ ) -> None:
177
+ self.roi_size, self.image_key, self.label_key = roi_size, image_key, label_key
178
+
179
+ self.resample, self.resample_resolution = resample, resample_resolution
180
+ self.normalize_mode = normalize_mode
181
+ self.normalize_params = normalize_params if normalize_params is not None else {}
182
+ self.crop_mode = crop_mode
183
+ self.crop_params = crop_params if crop_params is not None else {}
184
+ self.augment_params = augment_params if augment_params is not None else {}
185
+
186
+ self.extra_modalities = extra_modalities if extra_modalities is not None else {}
187
+ self.custom_transforms = custom_transforms if custom_transforms is not None else {}
188
+
189
+ self.extra_options = kwargs
190
+ self.debug = debug
191
+ self.rank = rank
192
+ self.class_index = class_index
193
+
194
+ def get_custom(self, name, **kwargs):
195
+ tr = []
196
+ for t in self.custom_transforms.get(name, []):
197
+ if isinstance(t, dict):
198
+ t.update(kwargs)
199
+ t = ConfigParser(t).get_parsed_content(instantiate=True)
200
+ tr.append(t)
201
+
202
+ return tr
203
+
204
+ def get_load_transforms(self):
205
+ ts = self.get_custom("load_transforms")
206
+ if len(ts) > 0:
207
+ return ts
208
+
209
+ keys = [self.image_key, self.label_key] + list(self.extra_modalities)
210
+ ts.append(
211
+ LoadImaged(keys=keys, ensure_channel_first=True, dtype=None, allow_missing_keys=True, image_only=True)
212
+ )
213
+ ts.append(EnsureTyped(keys=keys, data_type="tensor", dtype=torch.float, allow_missing_keys=True))
214
+ ts.append(
215
+ EnsureSameShaped(keys=self.label_key, source_key=self.image_key, allow_missing_keys=True, warn=self.debug)
216
+ )
217
+
218
+ ts.extend(self.get_custom("after_load_transforms"))
219
+
220
+ return ts
221
+
222
+ def get_resample_transforms(self, resample_label=True):
223
+ ts = self.get_custom("resample_transforms", resample_label=resample_label)
224
+ if len(ts) > 0:
225
+ return ts
226
+
227
+ keys = [self.image_key, self.label_key] if resample_label else [self.image_key]
228
+ mode = ["bilinear", "nearest"] if resample_label else ["bilinear"]
229
+ extra_keys = list(self.extra_modalities)
230
+
231
+ if self.extra_options.get("orientation_ras", False):
232
+ ts.append(Orientationd(keys=keys, axcodes="RAS", labels=(("L", "R"), ("P", "A"), ("I", "S"))))
233
+
234
+ if self.extra_options.get("crop_foreground", False) and len(extra_keys) == 0:
235
+ ts.append(
236
+ CropForegroundd(
237
+ keys=keys, source_key=self.image_key, allow_missing_keys=True, margin=10, allow_smaller=True
238
+ )
239
+ )
240
+ if self.resample:
241
+ if self.resample_resolution is None:
242
+ raise ValueError("resample_resolution is not provided")
243
+
244
+ pixdim = self.resample_resolution
245
+ ts.append(
246
+ Spacingd(
247
+ keys=keys,
248
+ pixdim=pixdim,
249
+ mode=mode,
250
+ dtype=torch.float,
251
+ min_pixdim=np.array(pixdim) * 0.75,
252
+ max_pixdim=np.array(pixdim) * 1.25,
253
+ allow_missing_keys=True,
254
+ )
255
+ )
256
+
257
+ if resample_label:
258
+ ts.append(
259
+ EnsureSameShaped(
260
+ keys=self.label_key, source_key=self.image_key, allow_missing_keys=True, warn=self.debug
261
+ )
262
+ )
263
+
264
+ for extra_key in extra_keys:
265
+ ts.append(ResampleToMatchd(keys=extra_key, key_dst=self.image_key, dtype=np.float32))
266
+
267
+ ts.extend(self.get_custom("after_resample_transforms", resample_label=resample_label))
268
+
269
+ return ts
270
+
271
+ def get_normalize_transforms(self):
272
+
273
+ ts = self.get_custom("normalize_transforms")
274
+ if len(ts) > 0:
275
+ return ts
276
+
277
+ label_dtype = self.normalize_params.get("label_dtype", None)
278
+ if label_dtype is not None:
279
+ ts.append(CastToTyped(keys=self.label_key, dtype=label_dtype, allow_missing_keys=True))
280
+ image_dtype = self.normalize_params.get("image_dtype", None)
281
+ if image_dtype is not None:
282
+ ts.append(CastToTyped(keys=self.image_key, dtype=image_dtype, allow_missing_keys=True)) # for caching
283
+ ts.append(RandIdentity()) # indicate to stop caching after this point
284
+ ts.append(CastToTyped(keys=self.image_key, dtype=torch.float, allow_missing_keys=True))
285
+
286
+ modalities = {self.image_key: self.normalize_mode}
287
+ modalities.update(self.extra_modalities)
288
+
289
+ for key, normalize_mode in modalities.items():
290
+ if normalize_mode == "none":
291
+ pass
292
+ elif normalize_mode in ["range", "ct"]:
293
+ intensity_bounds = self.normalize_params.get("intensity_bounds", None)
294
+ if intensity_bounds is None:
295
+ intensity_bounds = [-250, 250]
296
+ warnings.warn(f"intensity_bounds is not specified, assuming {intensity_bounds}")
297
+
298
+ ts.append(
299
+ ScaleIntensityRanged(
300
+ keys=key, a_min=intensity_bounds[0], a_max=intensity_bounds[1], b_min=-1, b_max=1, clip=False
301
+ )
302
+ )
303
+ ts.append(Lambdad(keys=key, func=lambda x: torch.sigmoid(x)))
304
+ elif normalize_mode in ["meanstd", "mri"]:
305
+ ts.append(NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True))
306
+ elif normalize_mode in ["meanstdtanh"]:
307
+ ts.append(NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True))
308
+ ts.append(Lambdad(keys=key, func=lambda x: 3 * torch.tanh(x / 3)))
309
+ elif normalize_mode in ["pet"]:
310
+ ts.append(Lambdad(keys=key, func=lambda x: torch.sigmoid((x - x.min()) / x.std())))
311
+ else:
312
+ raise ValueError("Unsupported normalize_mode" + str(normalize_mode))
313
+
314
+ if len(self.extra_modalities) > 0:
315
+ ts.append(ConcatItemsd(keys=list(modalities), name=self.image_key)) # concat
316
+ ts.append(DeleteItemsd(keys=list(self.extra_modalities))) # release memory
317
+
318
+ ts.extend(self.get_custom("after_normalize_transforms"))
319
+ return ts
320
+
321
+ def get_crop_transforms(self):
322
+ ts = self.get_custom("crop_transforms")
323
+ if len(ts) > 0:
324
+ return ts
325
+
326
+ if self.roi_size is None:
327
+ raise ValueError("roi_size is not specified")
328
+
329
+ keys = [self.image_key, self.label_key]
330
+ ts = []
331
+ ts.append(SpatialPadd(keys=keys, spatial_size=self.roi_size))
332
+
333
+ if self.crop_mode == "ratio":
334
+ output_classes = self.crop_params.get("output_classes", None)
335
+ if output_classes is None:
336
+ raise ValueError("crop_params option output_classes must be specified")
337
+
338
+ crop_ratios = self.crop_params.get("crop_ratios", None)
339
+ cache_class_indices = self.crop_params.get("cache_class_indices", False)
340
+ max_samples_per_class = self.crop_params.get("max_samples_per_class", None)
341
+ if max_samples_per_class <= 0:
342
+ max_samples_per_class = None
343
+ indices_key = None
344
+
345
+ sigmoid = self.extra_options.get("sigmoid", False)
346
+ crop_add_background = self.crop_params.get("crop_add_background", False)
347
+
348
+ if crop_ratios is None:
349
+ crop_classes = output_classes
350
+ if sigmoid and crop_add_background and self.class_index is not None and len(self.class_index) > 1:
351
+ crop_classes = crop_classes + 1
352
+ else:
353
+ crop_classes = len(crop_ratios)
354
+
355
+ if self.debug:
356
+ print(
357
+ f"Cropping with classes {crop_classes} and crop_add_background {crop_add_background} ratios {crop_ratios}"
358
+ )
359
+
360
+ if cache_class_indices:
361
+ ts.append(
362
+ ClassesToIndicesd(
363
+ keys=self.label_key,
364
+ num_classes=crop_classes,
365
+ indices_postfix="_cls_indices",
366
+ max_samples_per_class=max_samples_per_class,
367
+ )
368
+ )
369
+ indices_key = self.label_key + "_cls_indices"
370
+
371
+ num_crops_per_image = self.crop_params.get("num_crops_per_image", 1)
372
+ # if num_crops_per_image > 1:
373
+ # print(f"Cropping with num_crops_per_image {num_crops_per_image}")
374
+
375
+ ts.append(
376
+ RandCropByLabelClassesd(
377
+ keys=keys,
378
+ label_key=self.label_key,
379
+ num_classes=crop_classes,
380
+ spatial_size=self.roi_size,
381
+ num_samples=num_crops_per_image,
382
+ ratios=crop_ratios,
383
+ indices_key=indices_key,
384
+ warn=False,
385
+ )
386
+ )
387
+ elif self.crop_mode == "rand":
388
+ ts.append(RandSpatialCropd(keys=keys, roi_size=self.roi_size, random_size=False))
389
+ else:
390
+ raise ValueError("Unsupported crop mode" + str(self.crop_mode))
391
+
392
+ ts.extend(self.get_custom("after_crop_transforms"))
393
+
394
+ return ts
395
+
396
+ def get_augment_transforms(self):
397
+ ts = self.get_custom("augment_transforms")
398
+ if len(ts) > 0:
399
+ return ts
400
+
401
+ if self.roi_size is None:
402
+ raise ValueError("roi_size is not specified")
403
+
404
+ augment_mode = self.augment_params.get("augment_mode", None)
405
+ augment_flips = self.augment_params.get("augment_flips", None)
406
+ augment_rots = self.augment_params.get("augment_rots", None)
407
+
408
+ if self.debug:
409
+ print(f"Using augment_mode {augment_mode}, augment_flips {augment_flips} augment_rots {augment_rots}")
410
+
411
+ ts = []
412
+
413
+ if augment_mode is None or augment_mode == "default":
414
+
415
+ ts.append(
416
+ RandAffined(
417
+ keys=[self.image_key, self.label_key],
418
+ prob=0.2,
419
+ rotate_range=[0.26, 0.26, 0.26],
420
+ scale_range=[0.2, 0.2, 0.2],
421
+ mode=["bilinear", "nearest"],
422
+ spatial_size=self.roi_size,
423
+ cache_grid=True,
424
+ padding_mode="border",
425
+ )
426
+ )
427
+ ts.append(
428
+ RandGaussianSmoothd(
429
+ keys=self.image_key, prob=0.2, sigma_x=[0.5, 1.0], sigma_y=[0.5, 1.0], sigma_z=[0.5, 1.0]
430
+ )
431
+ )
432
+ ts.append(RandScaleIntensityd(keys=self.image_key, prob=0.5, factors=0.3))
433
+ ts.append(RandShiftIntensityd(keys=self.image_key, prob=0.5, offsets=0.1))
434
+ ts.append(RandGaussianNoised(keys=self.image_key, prob=0.2, mean=0.0, std=0.1))
435
+
436
+ elif augment_mode == "none":
437
+
438
+ augment_flips = []
439
+ augment_rots = []
440
+
441
+ elif augment_mode == "ct_ax_1":
442
+
443
+ ts.append(RandHistogramShiftd(keys="image", prob=0.5, num_control_points=16))
444
+ ts.append(RandAdjustContrastd(keys="image", prob=0.2, gamma=[0.5, 3.0]))
445
+
446
+ ts.append(
447
+ RandAffined(
448
+ keys=[self.image_key, self.label_key],
449
+ prob=0.5,
450
+ rotate_range=[0, 0, 0.26],
451
+ scale_range=[0.35, 0.35, 0],
452
+ mode=["bilinear", "nearest"],
453
+ spatial_size=self.roi_size,
454
+ cache_grid=True,
455
+ padding_mode="border",
456
+ )
457
+ )
458
+
459
+ elif augment_mode == "mri_1":
460
+
461
+ ts.append(
462
+ RandAffined(
463
+ keys=[self.image_key, self.label_key],
464
+ prob=0.2,
465
+ rotate_range=[0.26, 0.26, 0.26],
466
+ scale_range=[0.2, 0.2, 0.2],
467
+ mode=["bilinear", "nearest"],
468
+ spatial_size=self.roi_size,
469
+ cache_grid=True,
470
+ padding_mode="border",
471
+ )
472
+ )
473
+
474
+ ts.append(RandGaussianNoised(keys=self.image_key, prob=0.2, mean=0.0, std=0.1))
475
+
476
+ ts.append(
477
+ RandGaussianSmoothd(
478
+ keys=self.image_key, prob=0.2, sigma_x=[0.5, 1.0], sigma_y=[0.5, 1.0], sigma_z=[0.5, 1.0]
479
+ )
480
+ )
481
+
482
+ ts.append(RandScaleIntensityFixedMeand(keys="image", prob=0.2, fixed_mean=True, factors=0.3))
483
+ ts.append(
484
+ RandAdjustContrastd(keys="image", prob=0.2, gamma=[0.7, 1.5], retain_stats=True, invert_image=False)
485
+ )
486
+ ts.append(
487
+ RandAdjustContrastd(keys="image", prob=0.2, gamma=[0.7, 1.5], retain_stats=True, invert_image=True)
488
+ )
489
+
490
+ else:
491
+ raise ValueError("Unsupported augment_mode: " + str(augment_mode))
492
+
493
+ # default to all flips
494
+ if augment_flips is None:
495
+ augment_flips = [0, 1, 2]
496
+ for sa in augment_flips:
497
+ ts.append(RandFlipd(keys=[self.image_key, self.label_key], prob=0.5, spatial_axis=sa))
498
+
499
+ # default to no rots
500
+ if augment_rots is not None:
501
+ for sa in augment_rots:
502
+ ts.append(RandRotate90d(keys=[self.image_key, self.label_key], prob=0.5, spatial_axes=sa))
503
+
504
+ ts.extend(self.get_custom("after_augment_transforms"))
505
+
506
+ return ts
507
+
508
+ def get_final_transforms(self):
509
+ return self.get_custom("final_transforms")
510
+
511
+ @classmethod
512
+ def get_postprocess_transform(
513
+ cls,
514
+ save_mask=False,
515
+ invert=False,
516
+ transform=None,
517
+ sigmoid=False,
518
+ output_path=None,
519
+ resample=False,
520
+ data_root_dir="",
521
+ output_dtype=np.uint8,
522
+ save_mask_mode=None,
523
+ ) -> Compose:
524
+ ts = []
525
+ if invert and transform is not None:
526
+ # if resample:
527
+ # ts.append(ToDeviced(keys="pred", device=torch.device("cpu")))
528
+ ts.append(Invertd(keys="pred", orig_keys="image", transform=transform, nearest_interp=False))
529
+
530
+ if save_mask and output_path is not None:
531
+ ts.append(CopyItemsd(keys="pred", times=1, names="seg"))
532
+ if save_mask_mode == "prob":
533
+ output_dtype = np.float32
534
+ else:
535
+ ts.append(
536
+ AsDiscreted(keys="seg", argmax=True) if not sigmoid else AsDiscreted(keys="seg", threshold=0.5)
537
+ )
538
+ ts.append(
539
+ SaveImaged(
540
+ keys=["seg"],
541
+ output_dir=output_path,
542
+ output_postfix="",
543
+ data_root_dir=data_root_dir,
544
+ output_dtype=output_dtype,
545
+ separate_folder=False,
546
+ squeeze_end_dims=True,
547
+ resample=False,
548
+ print_log=False,
549
+ )
550
+ )
551
+
552
+ return Compose(ts)
553
+
554
+ def __call__(self, augment=False, resample_label=False) -> Compose:
555
+ ts = []
556
+ ts.extend(self.get_load_transforms())
557
+ ts.extend(self.get_resample_transforms(resample_label=resample_label))
558
+ ts.extend(self.get_normalize_transforms())
559
+
560
+ if augment:
561
+ ts.extend(self.get_crop_transforms())
562
+ ts.extend(self.get_augment_transforms())
563
+
564
+ ts.extend(self.get_final_transforms())
565
+
566
+ compose_ts = Compose(ts)
567
+
568
+ return compose_ts
569
+
570
+ def __repr__(self) -> str:
571
+ out: str = f"DataTransformBuilder: with image_key: {self.image_key}, label_key: {self.label_key} \n"
572
+ out += f"roi_size {self.roi_size} resample {self.resample} resample_resolution {self.resample_resolution} \n"
573
+ out += f"normalize_mode {self.normalize_mode} normalize_params {self.normalize_params} \n"
574
+ out += f"crop_mode {self.crop_mode} crop_params {self.crop_params} \n"
575
+ out += f"extra_modalities {self.extra_modalities} \n"
576
+ for k, trs in self.custom_transforms.items():
577
+ out += f"Custom {k} : {str(trs)} \n"
578
+ return out
579
+
580
+
581
+ class Segmenter:
582
+ def __init__(
583
+ self,
584
+ config_file: Optional[Union[str, Sequence[str]]] = None,
585
+ config_dict: Dict = {},
586
+ rank: int = 0,
587
+ global_rank: int = 0,
588
+ ) -> None:
589
+ self.rank = rank
590
+ self.global_rank = global_rank
591
+ self.distributed = dist.is_initialized()
592
+
593
+ if self.global_rank == 0:
594
+ print(f"Segmenter started config_file: {config_file}, config_dict: {config_dict}")
595
+
596
+ np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True)
597
+ logging.getLogger("torch.nn.parallel.distributed").setLevel(logging.WARNING)
598
+
599
+ config = self.parse_input_config(config_file=config_file, override=config_dict)
600
+ self.config = config
601
+ self.config_file = config_file if not isinstance(config_file, (list, tuple)) else config_file[0]
602
+ self.override = config_dict
603
+
604
+ if config["ckpt_path"] is not None and not os.path.exists(config["ckpt_path"]):
605
+ os.makedirs(config["ckpt_path"], exist_ok=True)
606
+
607
+ if config["log_output_file"] is None:
608
+ config["log_output_file"] = os.path.join(self.config["ckpt_path"], "training.log")
609
+ logger_configure(log_output_file=config["log_output_file"], debug=config["debug"], global_rank=self.global_rank)
610
+
611
+ if config["fork"] and "fork" in mp.get_all_start_methods():
612
+ mp.set_start_method("fork", force=True) # lambda functions fail to pickle without it
613
+ else:
614
+ warnings.warn(
615
+ "Multiprocessing method fork is not available, some non-picklable objects (e.g. lambda ) may fail"
616
+ )
617
+
618
+ if config["cuda"] and torch.cuda.is_available():
619
+ self.device = torch.device(self.rank)
620
+ if self.distributed and dist.get_backend() == dist.Backend.NCCL:
621
+ torch.cuda.set_device(rank)
622
+ else:
623
+ self.device = torch.device("cpu")
624
+
625
+ if self.global_rank == 0:
626
+ print(yaml.dump(config))
627
+
628
+ if config["determ"]:
629
+ set_determinism(seed=0)
630
+ elif torch.cuda.is_available():
631
+ torch.backends.cudnn.benchmark = True
632
+
633
+ if config["notf32"]:
634
+ torch.backends.cuda.matmul.allow_tf32 = False
635
+ torch.backends.cudnn.allow_tf32 = False
636
+ print(f"!!!disabling tf32")
637
+ if config.get("float32_precision", None) is not None:
638
+ torch.set_float32_matmul_precision(config["float32_precision"])
639
+ print(f"!!!setting matmul precession {config['float32_precision']}")
640
+
641
+ # auto adjust network settings
642
+ if config["auto_scale_allowed"]:
643
+ if config["auto_scale_batch"] or config["auto_scale_roi"] or config["auto_scale_filters"]:
644
+ roi_size, _, init_filters, batch_size = auto_adjust_network_settings(
645
+ auto_scale_batch=config["auto_scale_batch"],
646
+ auto_scale_roi=config["auto_scale_roi"],
647
+ auto_scale_filters=config["auto_scale_filters"],
648
+ image_size_mm=config["image_size_mm_median"],
649
+ spacing=config["resample_resolution"],
650
+ anisotropic_scales=config["anisotropic_scales"],
651
+ levels=len(config["network"]["blocks_down"]),
652
+ output_classes=config["output_classes"],
653
+ )
654
+
655
+ config["roi_size"] = roi_size
656
+ if config["auto_scale_batch"]:
657
+ config["batch_size"] = batch_size
658
+ if config["auto_scale_filters"] and config["pretrained_ckpt_name"] is None:
659
+ config["network"]["init_filters"] = init_filters
660
+
661
+ self.model = self.setup_model(pretrained_ckpt_name=config["pretrained_ckpt_name"])
662
+
663
+ loss_function = ConfigParser(config["loss"]).get_parsed_content(instantiate=True)
664
+ self.loss_function = DeepSupervisionLoss(loss_function)
665
+
666
+ dice_ignore_empty = config.get("dice_ignore_empty", True)
667
+ self.acc_function = DiceHelper(threshold=config["sigmoid"], ignore_empty=dice_ignore_empty)
668
+ self.amp_device_type = "cuda" if torch.cuda.is_available() else "cpu"
669
+ self.grad_scaler = GradScaler(self.amp_device_type, enabled=config["amp"])
670
+
671
+ if config.get("sliding_inferrer") is not None:
672
+ self.sliding_inferrer = ConfigParser(config["sliding_inferrer"]).get_parsed_content()
673
+ else:
674
+ self.sliding_inferrer = SlidingWindowInfererAdapt(
675
+ roi_size=config["roi_size"],
676
+ sw_batch_size=1,
677
+ overlap=0.625,
678
+ mode="gaussian",
679
+ cache_roi_weight_map=True,
680
+ progress=False,
681
+ )
682
+
683
+ self._data_transform_builder: DataTransformBuilder = None
684
+ self.lr_scheduler = None
685
+ self.optimizer = None
686
+
687
+ def get_custom_transforms(self):
688
+ config = self.config
689
+
690
+ # check for custom transforms
691
+ custom_transforms = {}
692
+ for tr in config.get("custom_data_transforms", []):
693
+ must_include_keys = ("key", "path", "transform")
694
+ if not all(k in tr for k in must_include_keys):
695
+ raise ValueError("custom transform must include " + str(must_include_keys))
696
+
697
+ if os.path.abspath(tr["path"]) not in sys.path:
698
+ sys.path.append(os.path.abspath(tr["path"]))
699
+
700
+ custom_transforms.setdefault(tr["key"], [])
701
+ custom_transforms[tr["key"]].append(tr["transform"])
702
+
703
+ if len(custom_transforms) > 0 and self.global_rank == 0:
704
+ print(f"Using custom transforms {custom_transforms}")
705
+
706
+ if isinstance(config["class_index"], list) and len(config["class_index"]) > 0:
707
+ # custom label embedding, if class_index provided
708
+ custom_transforms.setdefault("final_transforms", [])
709
+ custom_transforms["final_transforms"].append(
710
+ LabelEmbedClassIndex(keys="label", class_index=config["class_index"], allow_missing_keys=True)
711
+ )
712
+
713
+ return custom_transforms
714
+
715
+ def get_data_transform_builder(self):
716
+ if self._data_transform_builder is None:
717
+ config = self.config
718
+ custom_transforms = self.get_custom_transforms()
719
+
720
+ self._data_transform_builder = DataTransformBuilder(
721
+ roi_size=config["roi_size"],
722
+ resample=config["resample"],
723
+ resample_resolution=config["resample_resolution"],
724
+ normalize_mode=config["normalize_mode"],
725
+ normalize_params={
726
+ "intensity_bounds": config["intensity_bounds"],
727
+ "label_dtype": torch.uint8 if config["input_channels"] < 255 else torch.int16,
728
+ "image_dtype": torch.int16 if config.get("cache_image_int16", False) else None,
729
+ },
730
+ crop_mode=config["crop_mode"],
731
+ crop_params={
732
+ "output_classes": config["output_classes"],
733
+ "input_channels": config["input_channels"],
734
+ "crop_ratios": config["crop_ratios"],
735
+ "cache_class_indices": config["cache_class_indices"],
736
+ "num_crops_per_image": config["num_crops_per_image"],
737
+ "max_samples_per_class": config["max_samples_per_class"],
738
+ "crop_add_background": config["crop_add_background"],
739
+ },
740
+ augment_params={
741
+ "augment_mode": config.get("augment_mode", None),
742
+ "augment_flips": config.get("augment_flips", None),
743
+ "augment_rots": config.get("augment_rots", None),
744
+ },
745
+ extra_modalities=config["extra_modalities"],
746
+ custom_transforms=custom_transforms,
747
+ crop_foreground=config.get("crop_foreground", True),
748
+ sigmoid=config["sigmoid"],
749
+ orientation_ras=config.get("orientation_ras", False),
750
+ class_index=config["class_index"],
751
+ debug=config["debug"],
752
+ )
753
+
754
+ return self._data_transform_builder
755
+
756
+ def setup_model(self, pretrained_ckpt_name=None):
757
+ config = self.config
758
+ spatial_dims = config["network"].get("spatial_dims", 3)
759
+ norm_name, norm_args = split_args(config["network"].get("norm", ""))
760
+ norm_name = norm_name.upper()
761
+
762
+ if norm_name == "INSTANCE_NVFUSER":
763
+ _, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")
764
+ if has_nvfuser and spatial_dims == 3:
765
+ act = config["network"].get("act", "relu")
766
+ if isinstance(act, str):
767
+ config["network"]["act"] = [act, {"inplace": False}]
768
+ else:
769
+ norm_name = "INSTANCE"
770
+
771
+ if len(norm_name) > 0:
772
+ config["network"]["norm"] = norm_name if len(norm_args) == 0 else [norm_name, norm_args]
773
+
774
+ if spatial_dims == 3:
775
+ if config.get("anisotropic_scales", False) and "SegResNetDS" in config["network"]["_target_"]:
776
+ config["network"]["resolution"] = copy.deepcopy(config["resample_resolution"])
777
+ if self.global_rank == 0:
778
+ print(f"Using anisotropic scales {config['network']}")
779
+
780
+ model = ConfigParser(config["network"]).get_parsed_content()
781
+
782
+ if self.global_rank == 0:
783
+ print(str(model))
784
+
785
+ if pretrained_ckpt_name is not None:
786
+ self.checkpoint_load(ckpt=pretrained_ckpt_name, model=model)
787
+
788
+ model = model.to(self.device)
789
+
790
+ if spatial_dims == 3:
791
+ memory_format = torch.channels_last_3d if config["channels_last"] else torch.preserve_format
792
+ model = model.to(memory_format=memory_format)
793
+
794
+ if self.distributed and not config["infer"]["enabled"]:
795
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
796
+ model = DistributedDataParallel(
797
+ module=model, device_ids=[self.rank], output_device=self.rank, find_unused_parameters=False
798
+ )
799
+
800
+ if self.global_rank == 0:
801
+ pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
802
+ print(f"Total parameters count: {pytorch_total_params} distributed: {self.distributed}")
803
+
804
+ return model
805
+
806
+ def parse_input_config(
807
+ self, config_file: Optional[Union[str, Sequence[str]]] = None, override: Dict = {}
808
+ ) -> Tuple[ConfigParser, Dict]:
809
+ config = {}
810
+ if config_file is None or override.get("use_ckpt_config", False):
811
+ # attempt to load config from model ckpt file
812
+ for ckpt_key in ["pretrained_ckpt_name", "validate#ckpt_name", "infer#ckpt_name", "finetune#ckpt_name"]:
813
+ ckpt = override.get(ckpt_key, None)
814
+ if ckpt and os.path.exists(ckpt):
815
+ checkpoint = torch.load(ckpt, map_location="cpu")
816
+ config = checkpoint.get("config", {})
817
+ if self.global_rank == 0:
818
+ print(f"Initializing config from the checkpoint {ckpt}: {yaml.dump(config)}")
819
+
820
+ if len(config) == 0 and config_file is None:
821
+ warnings.warn("No input config_file provided, and no valid checkpoints found")
822
+
823
+ if config_file is not None and len(config) == 0:
824
+ config = ConfigParser.load_config_files(config_file)
825
+ config.setdefault("finetune", {"enabled": False, "ckpt_name": None})
826
+ config.setdefault(
827
+ "validate", {"enabled": False, "ckpt_name": None, "save_mask": False, "output_path": None}
828
+ )
829
+ config.setdefault("infer", {"enabled": False, "ckpt_name": None})
830
+
831
+ parser = ConfigParser(config=config)
832
+ parser.update(pairs=override)
833
+ config = parser.config # just in case
834
+
835
+ if config.get("data_file_base_dir", None) is None or config.get("data_list_file_path", None) is None:
836
+ raise ValueError("CONFIG: data_file_base_dir and data_list_file_path must be provided")
837
+
838
+ if config.get("bundle_root", None) is None:
839
+ config["bundle_root"] = str(Path(__file__).parent.parent)
840
+
841
+ if "modality" not in config:
842
+ if self.global_rank == 0:
843
+ warnings.warn("CONFIG: modality is not provided, assuming MRI")
844
+ config["modality"] = "mri"
845
+
846
+ if "normalize_mode" not in config:
847
+ config["normalize_mode"] = "range" if config["modality"].lower() == "ct" else "meanstd"
848
+ if self.global_rank == 0:
849
+ print(f"CONFIG: normalize_mode is not provided, assuming: {config['normalize_mode']}")
850
+
851
+ # assign defaults
852
+ config.setdefault("debug", False)
853
+
854
+ config.setdefault("loss", None)
855
+ config.setdefault("acc", None)
856
+ config.setdefault("amp", True)
857
+ config.setdefault("cuda", True)
858
+ config.setdefault("fold", 0)
859
+ config.setdefault("batch_size", 1)
860
+ config.setdefault("determ", False)
861
+ config.setdefault("quick", False)
862
+ config.setdefault("sigmoid", False)
863
+ config.setdefault("cache_rate", None)
864
+ config.setdefault("cache_class_indices", None)
865
+ config.setdefault("crop_add_background", True)
866
+ config.setdefault("orientation_ras", False)
867
+
868
+ config.setdefault("channels_last", True)
869
+ config.setdefault("fork", True)
870
+
871
+ config.setdefault("num_epochs", 300)
872
+ config.setdefault("num_warmup_epochs", 3)
873
+ config.setdefault("num_epochs_per_validation", None)
874
+ config.setdefault("num_epochs_per_saving", 10)
875
+ config.setdefault("num_steps_per_image", None)
876
+ config.setdefault("num_crops_per_image", 1)
877
+ config.setdefault("max_samples_per_class", None)
878
+
879
+ config.setdefault("calc_val_loss", False)
880
+ config.setdefault("validate_final_original_res", False)
881
+ config.setdefault("early_stopping_fraction", 0)
882
+ config.setdefault("start_epoch", 0)
883
+
884
+ config.setdefault("ckpt_path", None)
885
+ config.setdefault("ckpt_save", True)
886
+ config.setdefault("log_output_file", None)
887
+
888
+ config.setdefault("crop_mode", "ratio")
889
+ config.setdefault("crop_ratios", None)
890
+ config.setdefault("resample_resolution", [1.0, 1.0, 1.0])
891
+ config.setdefault("resample", False)
892
+ config.setdefault("roi_size", [128, 128, 128])
893
+ config.setdefault("num_workers", 4)
894
+ config.setdefault("extra_modalities", {})
895
+ config.setdefault("intensity_bounds", [-250, 250])
896
+ config.setdefault("stop_on_lowacc", True)
897
+
898
+ config.setdefault("float32_precision", None)
899
+ config.setdefault("notf32", False)
900
+
901
+ config.setdefault("class_index", None)
902
+ config.setdefault("class_names", [])
903
+ if not isinstance(config["class_names"], (list, tuple)):
904
+ config["class_names"] = []
905
+
906
+ if len(config["class_names"]) == 0:
907
+ n_foreground_classes = int(config["output_classes"])
908
+ if not config["sigmoid"]:
909
+ n_foreground_classes -= 1
910
+ config["class_names"] = ["acc_" + str(i) for i in range(n_foreground_classes)]
911
+
912
+ pretrained_ckpt_name = config.get("pretrained_ckpt_name", None)
913
+ if pretrained_ckpt_name is None:
914
+ if config["validate"]["enabled"]:
915
+ pretrained_ckpt_name = config["validate"]["ckpt_name"]
916
+ elif config["infer"]["enabled"]:
917
+ pretrained_ckpt_name = config["infer"]["ckpt_name"]
918
+ elif config["finetune"]["enabled"]:
919
+ pretrained_ckpt_name = config["finetune"]["ckpt_name"]
920
+ config["pretrained_ckpt_name"] = pretrained_ckpt_name
921
+
922
+ config.setdefault("auto_scale_allowed", False)
923
+ config.setdefault("auto_scale_batch", False)
924
+ config.setdefault("auto_scale_roi", False)
925
+ config.setdefault("auto_scale_filters", False)
926
+
927
+ if pretrained_ckpt_name is not None:
928
+ config["auto_scale_roi"] = False
929
+ config["auto_scale_filters"] = False
930
+
931
+ if config["max_samples_per_class"] is None:
932
+ config["max_samples_per_class"] = 10 * config["num_epochs"]
933
+
934
+ if not torch.cuda.is_available() and config["cuda"]:
935
+ print("No cuda is available.! Running on CPU!!!")
936
+ config["cuda"] = False
937
+
938
+ config["amp"] = config["amp"] and config["cuda"]
939
+ config["rank"] = self.rank
940
+ config["global_rank"] = self.global_rank
941
+
942
+ # resolve content
943
+ for k, v in config.items():
944
+ if isinstance(v, dict) and "_target_" in v:
945
+ config[k] = parser.get_parsed_content(k, instantiate=False).config
946
+ elif "_target_" in str(v):
947
+ config[k] = copy.deepcopy(v)
948
+ else:
949
+ config[k] = parser.get_parsed_content(k)
950
+
951
+ return config
952
+
953
+ def config_save_updated(self, save_path=None):
954
+ if self.global_rank == 0 and self.config["auto_scale_allowed"]:
955
+ # reload input config
956
+ config = ConfigParser.load_config_files(self.config_file)
957
+ parser = ConfigParser(config=config)
958
+ parser.update(pairs=self.override)
959
+ config = parser.config
960
+
961
+ config["batch_size"] = self.config["batch_size"]
962
+ config["roi_size"] = self.config["roi_size"]
963
+ config["num_crops_per_image"] = self.config["num_crops_per_image"]
964
+
965
+ if "init_filters" in self.config["network"]:
966
+ config["network"]["init_filters"] = self.config["network"]["init_filters"]
967
+
968
+ if save_path is None:
969
+ save_path = self.config_file
970
+
971
+ print(f"Re-saving main config to {save_path}.")
972
+ ConfigParser.export_config_file(config, save_path, fmt="yaml", default_flow_style=None, sort_keys=False)
973
+
974
+ def config_with_relpath(self, config=None):
975
+ if config is None:
976
+ config = self.config
977
+ config = copy.deepcopy(config)
978
+ bundle_root = config["bundle_root"]
979
+
980
+ def convert_rel_path(conf):
981
+ for k, v in conf.items():
982
+ if isinstance(v, str) and v.startswith(bundle_root):
983
+ conf[k] = f"$@bundle_root + '/{os.path.relpath(v, bundle_root)}'"
984
+
985
+ convert_rel_path(config)
986
+ convert_rel_path(config["finetune"])
987
+ convert_rel_path(config["validate"])
988
+ convert_rel_path(config["infer"])
989
+ config["bundle_root"] = bundle_root
990
+
991
+ return config
992
+
993
+ def checkpoint_save(self, ckpt: str, model: torch.nn.Module, **kwargs):
994
+ save_time = time.time()
995
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
996
+ state_dict = model.module.state_dict()
997
+ else:
998
+ state_dict = model.state_dict()
999
+
1000
+ config = self.config_with_relpath()
1001
+
1002
+ torch.save({"state_dict": state_dict, "config": config, **kwargs}, ckpt)
1003
+
1004
+ save_time = time.time() - save_time
1005
+ print(f"Saving checkpoint process: {ckpt}, {kwargs}, save_time {save_time:.2f}s")
1006
+
1007
+ return save_time
1008
+
1009
+ def checkpoint_load(self, ckpt: str, model: torch.nn.Module, **kwargs):
1010
+ if not os.path.isfile(ckpt):
1011
+ if self.global_rank == 0:
1012
+ warnings.warn("Invalid checkpoint file: " + str(ckpt))
1013
+ else:
1014
+ checkpoint = torch.load(ckpt, map_location="cpu")
1015
+ model.load_state_dict(checkpoint["state_dict"], strict=True)
1016
+ epoch = checkpoint.get("epoch", 0)
1017
+ best_metric = checkpoint.get("best_metric", 0)
1018
+
1019
+ if self.config.pop("continue", False):
1020
+ if "epoch" in checkpoint:
1021
+ self.config["start_epoch"] = checkpoint["epoch"]
1022
+ if "best_metric" in checkpoint:
1023
+ self.config["best_metric"] = checkpoint["best_metric"]
1024
+
1025
+ print(
1026
+ f"=> loaded checkpoint {ckpt} (epoch {epoch}) (best_metric {best_metric}) setting start_epoch {self.config['start_epoch']}"
1027
+ )
1028
+ self.config["start_epoch"] = int(self.config["start_epoch"]) + 1
1029
+
1030
+ def get_shared_memory_list(self, length=0):
1031
+ mp.current_process().authkey = np.arange(32, dtype=np.uint8).tobytes()
1032
+ shl0 = mp.Manager().list([None] * length)
1033
+
1034
+ if self.distributed:
1035
+ # to support multi-node training, we need check for a local process group
1036
+ is_multinode = False
1037
+
1038
+ if dist_launched():
1039
+ local_world_size = int(os.getenv("LOCAL_WORLD_SIZE"))
1040
+ world_size = int(os.getenv("WORLD_SIZE"))
1041
+ group_rank = int(os.getenv("GROUP_RANK"))
1042
+ if world_size > local_world_size:
1043
+ is_multinode = True
1044
+ # we're in multi-node, get local world sizes
1045
+ lw = torch.tensor(local_world_size, dtype=torch.int, device=self.device)
1046
+ lw_sizes = [torch.zeros_like(lw) for _ in range(world_size)]
1047
+ dist.all_gather(tensor_list=lw_sizes, tensor=lw)
1048
+
1049
+ src = g_rank = 0
1050
+ while src < world_size:
1051
+ # create sub-groups local to a node, to share memory only within a node
1052
+ # and broadcast shared list within a node
1053
+ group = dist.new_group(ranks=list(range(src, src + local_world_size)))
1054
+ if group_rank == g_rank:
1055
+ shl_list = [shl0]
1056
+ dist.broadcast_object_list(shl_list, src=src, group=group, device=self.device)
1057
+ shl = shl_list[0]
1058
+ dist.destroy_process_group(group)
1059
+ src = src + lw_sizes[src].item() # rank of first process in the next node
1060
+ g_rank += 1
1061
+
1062
+ if not is_multinode:
1063
+ shl_list = [shl0]
1064
+ dist.broadcast_object_list(shl_list, src=0, device=self.device)
1065
+ shl = shl_list[0]
1066
+
1067
+ else:
1068
+ shl = shl0
1069
+
1070
+ return shl
1071
+
1072
+ def get_train_loader(self, data, cache_rate=0, persistent_workers=False):
1073
+ distributed = self.distributed
1074
+ num_workers = self.config["num_workers"]
1075
+ batch_size = self.config["batch_size"]
1076
+
1077
+ train_transform = self.get_data_transform_builder()(augment=True, resample_label=True)
1078
+
1079
+ if cache_rate > 0:
1080
+ runtime_cache = self.get_shared_memory_list(length=len(data))
1081
+ train_ds = CacheDataset(
1082
+ data=data,
1083
+ transform=train_transform,
1084
+ copy_cache=False,
1085
+ cache_rate=cache_rate,
1086
+ runtime_cache=runtime_cache,
1087
+ )
1088
+ else:
1089
+ train_ds = Dataset(data=data, transform=train_transform)
1090
+
1091
+ train_sampler = DistributedSampler(train_ds, shuffle=True) if distributed else None
1092
+ train_loader = DataLoader(
1093
+ train_ds,
1094
+ batch_size=batch_size,
1095
+ shuffle=(train_sampler is None),
1096
+ num_workers=num_workers,
1097
+ sampler=train_sampler,
1098
+ persistent_workers=persistent_workers and num_workers > 0,
1099
+ pin_memory=True,
1100
+ )
1101
+
1102
+ return train_loader
1103
+
1104
+ def get_val_loader(self, data, cache_rate=0, resample_label=False, persistent_workers=False):
1105
+ distributed = self.distributed
1106
+ num_workers = self.config["num_workers"]
1107
+
1108
+ val_transform = self.get_data_transform_builder()(augment=False, resample_label=resample_label)
1109
+
1110
+ if cache_rate > 0:
1111
+ runtime_cache = self.get_shared_memory_list(length=len(data))
1112
+ val_ds = CacheDataset(
1113
+ data=data, transform=val_transform, copy_cache=False, cache_rate=cache_rate, runtime_cache=runtime_cache
1114
+ )
1115
+ else:
1116
+ val_ds = Dataset(data=data, transform=val_transform)
1117
+
1118
+ val_sampler = DistributedSampler(val_ds, shuffle=False) if distributed else None
1119
+ val_loader = DataLoader(
1120
+ val_ds,
1121
+ batch_size=1,
1122
+ shuffle=False,
1123
+ num_workers=num_workers,
1124
+ sampler=val_sampler,
1125
+ persistent_workers=persistent_workers and num_workers > 0,
1126
+ pin_memory=True,
1127
+ )
1128
+
1129
+ return val_loader
1130
+
1131
+ def train(self):
1132
+ if self.global_rank == 0:
1133
+ print("Segmenter train called")
1134
+
1135
+ if self.loss_function is None:
1136
+ raise ValueError("CONFIG loss function is not provided")
1137
+ if self.acc_function is None:
1138
+ raise ValueError("CONFIG accuracy function is not provided")
1139
+
1140
+ config = self.config
1141
+ distributed = self.distributed
1142
+ sliding_inferrer = self.sliding_inferrer
1143
+
1144
+ loss_function = self.loss_function
1145
+ acc_function = self.acc_function
1146
+ grad_scaler = self.grad_scaler
1147
+
1148
+ use_amp = config["amp"]
1149
+ use_cuda = config["cuda"]
1150
+ ckpt_path = config["ckpt_path"]
1151
+ sigmoid = config["sigmoid"]
1152
+ channels_last = config["channels_last"]
1153
+ calc_val_loss = config["calc_val_loss"]
1154
+
1155
+ data_list_file_path = config["data_list_file_path"]
1156
+ if not os.path.isabs(data_list_file_path):
1157
+ data_list_file_path = os.path.abspath(os.path.join(config["bundle_root"], data_list_file_path))
1158
+
1159
+ if config.get("validation_key", None) is not None:
1160
+ train_files, _ = datafold_read(datalist=data_list_file_path, basedir=config["data_file_base_dir"], fold=-1)
1161
+ validation_files, _ = datafold_read(
1162
+ datalist=data_list_file_path,
1163
+ basedir=config["data_file_base_dir"],
1164
+ fold=-1,
1165
+ key=config["validation_key"],
1166
+ )
1167
+ else:
1168
+ train_files, validation_files = datafold_read(
1169
+ datalist=data_list_file_path, basedir=config["data_file_base_dir"], fold=config["fold"]
1170
+ )
1171
+
1172
+ if config["quick"]: # quick run on a smaller subset of files
1173
+ train_files, validation_files = train_files[:8], validation_files[:8]
1174
+ if self.global_rank == 0:
1175
+ print(f"train_files files {len(train_files)}, validation files {len(validation_files)}")
1176
+
1177
+ if len(validation_files) == 0:
1178
+ warnings.warn("No validation files found!")
1179
+
1180
+ cache_rate_train, cache_rate_val = self.get_cache_rate(
1181
+ train_cases=len(train_files), validation_cases=len(validation_files)
1182
+ )
1183
+
1184
+ if config["cache_class_indices"] is None:
1185
+ config["cache_class_indices"] = cache_rate_train > 0
1186
+
1187
+ if self.global_rank == 0:
1188
+ print(
1189
+ f"Auto setting max_samples_per_class: {config['max_samples_per_class']} cache_class_indices: {config['cache_class_indices']}"
1190
+ )
1191
+
1192
+ num_steps_per_image = config["num_steps_per_image"]
1193
+ if config["auto_scale_allowed"] and num_steps_per_image is None:
1194
+ be = config["batch_size"]
1195
+
1196
+ if config["crop_mode"] == "ratio":
1197
+ config["num_crops_per_image"] = config["batch_size"]
1198
+ config["batch_size"] = 1
1199
+ else:
1200
+ config["num_crops_per_image"] = 1
1201
+
1202
+ if cache_rate_train < 0.75:
1203
+ num_steps_per_image = max(1, 4 // be)
1204
+ else:
1205
+ num_steps_per_image = 1
1206
+
1207
+ elif num_steps_per_image is None:
1208
+ num_steps_per_image = 1
1209
+
1210
+ num_crops_per_image = int(config["num_crops_per_image"])
1211
+ num_epochs_per_saving = max(1, config["num_epochs_per_saving"] // num_crops_per_image)
1212
+ num_warmup_epochs = max(3, config["num_warmup_epochs"] // num_crops_per_image)
1213
+ num_epochs_per_validation = config["num_epochs_per_validation"]
1214
+ num_epochs = max(1, config["num_epochs"] // min(3, num_crops_per_image))
1215
+ if self.global_rank == 0:
1216
+ print(
1217
+ f"Given num_crops_per_image {num_crops_per_image}, num_epochs was adjusted {config['num_epochs']} => {num_epochs}"
1218
+ )
1219
+
1220
+ if num_epochs_per_validation is not None:
1221
+ num_epochs_per_validation = max(1, num_epochs_per_validation // num_crops_per_image)
1222
+
1223
+ val_schedule_list = schedule_validation_epochs(
1224
+ num_epochs=num_epochs,
1225
+ num_epochs_per_validation=num_epochs_per_validation,
1226
+ fraction=min(0.3, 0.16 * num_crops_per_image),
1227
+ )
1228
+ if self.global_rank == 0:
1229
+ print(f"Scheduling validation loops at epochs: {val_schedule_list}")
1230
+
1231
+ train_loader = self.get_train_loader(data=train_files, cache_rate=cache_rate_train, persistent_workers=True)
1232
+
1233
+ val_loader = self.get_val_loader(
1234
+ data=validation_files, cache_rate=cache_rate_val, resample_label=True, persistent_workers=True
1235
+ )
1236
+
1237
+ optim_name = config.get("optim_name", None) # experimental
1238
+ if optim_name is not None:
1239
+ if self.global_rank == 0:
1240
+ print(f"Using optimizer: {optim_name}")
1241
+ if optim_name == "fusednovograd":
1242
+ import apex
1243
+
1244
+ optimizer = apex.optimizers.FusedNovoGrad(
1245
+ params=self.model.parameters(), lr=config["learning_rate"], weight_decay=1.0e-5
1246
+ )
1247
+ elif optim_name == "sgd":
1248
+ momentum = config.get("sgd_momentum", 0.9)
1249
+ optimizer = torch.optim.SGD(
1250
+ params=self.model.parameters(), lr=config["learning_rate"], weight_decay=1.0e-5, momentum=momentum
1251
+ )
1252
+ if self.global_rank == 0:
1253
+ print(f"Using momentum: {momentum}")
1254
+ else:
1255
+ raise ValueError("Unsupported optim_name" + str(optim_name))
1256
+
1257
+ elif self.optimizer is None:
1258
+ optimizer_part = ConfigParser(config["optimizer"]).get_parsed_content(instantiate=False)
1259
+ optimizer = optimizer_part.instantiate(params=self.model.parameters())
1260
+ else:
1261
+ optimizer = self.optimizer
1262
+
1263
+ tb_writer = None
1264
+ csv_path = progress_path = None
1265
+
1266
+ if self.global_rank == 0 and ckpt_path is not None:
1267
+ # rank 0 is responsible for heavy lifting of logging/saving
1268
+ progress_path = os.path.join(ckpt_path, "progress.yaml")
1269
+
1270
+ tb_writer = SummaryWriter(log_dir=ckpt_path)
1271
+ print(f"Writing Tensorboard logs to {tb_writer.log_dir}")
1272
+
1273
+ if mlflow_is_imported:
1274
+ mlflow.set_tracking_uri(config["mlflow_tracking_uri"])
1275
+ mlflow.set_experiment(config["mlflow_experiment_name"])
1276
+ mlflow.start_run(run_name=f'segresnet - fold{config["fold"]} - train')
1277
+
1278
+ csv_path = os.path.join(ckpt_path, "accuracy_history.csv")
1279
+ self.save_history_csv(
1280
+ csv_path=csv_path,
1281
+ header=["epoch", "metric", "loss", "iter", "time", "train_time", "validation_time", "epoch_time"],
1282
+ )
1283
+
1284
+ do_torch_save = (self.global_rank == 0) and ckpt_path is not None and config["ckpt_save"]
1285
+ best_ckpt_path = os.path.join(ckpt_path, "model.pt")
1286
+ intermediate_ckpt_path = os.path.join(ckpt_path, "model_final.pt")
1287
+
1288
+ best_metric = -1
1289
+ best_metric_epoch = -1
1290
+ pre_loop_time = time.time()
1291
+ report_num_epochs = num_epochs * num_crops_per_image
1292
+ train_time = validation_time = 0
1293
+ val_acc_history = []
1294
+
1295
+ start_epoch = config["start_epoch"]
1296
+ if "best_metric" in config:
1297
+ best_metric = float(config["best_metric"])
1298
+
1299
+ start_epoch = start_epoch // num_crops_per_image
1300
+ if start_epoch > 0:
1301
+ val_schedule_list = [v for v in val_schedule_list if v >= start_epoch]
1302
+ if len(val_schedule_list) == 0:
1303
+ val_schedule_list = [start_epoch]
1304
+ print(f"adjusted schedule_list {val_schedule_list}")
1305
+
1306
+ if self.global_rank == 0:
1307
+ print(
1308
+ f"Using num_epochs => {num_epochs}\n "
1309
+ f"Using start_epoch => {start_epoch}\n "
1310
+ f"batch_size => {config['batch_size']} \n "
1311
+ f"num_crops_per_image => {config['num_crops_per_image']} \n "
1312
+ f"num_steps_per_image => {num_steps_per_image} \n "
1313
+ f"num_warmup_epochs => {num_warmup_epochs} \n "
1314
+ )
1315
+
1316
+ if self.lr_scheduler is None:
1317
+ lr_scheduler = WarmupCosineSchedule(
1318
+ optimizer=optimizer, warmup_steps=num_warmup_epochs, warmup_multiplier=0.1, t_total=num_epochs
1319
+ )
1320
+ else:
1321
+ lr_scheduler = self.lr_scheduler
1322
+ if lr_scheduler is not None and start_epoch > 0:
1323
+ lr_scheduler.last_epoch = start_epoch
1324
+
1325
+ range_num_epochs = range(start_epoch, num_epochs)
1326
+ if self.global_rank == 0 and has_tqdm and not config["debug"]:
1327
+ range_num_epochs = tqdm(
1328
+ range(start_epoch, num_epochs),
1329
+ desc=str(os.path.basename(config["bundle_root"])) + " - training",
1330
+ unit="epoch",
1331
+ )
1332
+
1333
+ if distributed:
1334
+ dist.barrier()
1335
+ self.config_save_updated(save_path=self.config_file) # overwriting main input config
1336
+
1337
+ for epoch in range_num_epochs:
1338
+ report_epoch = epoch * num_crops_per_image
1339
+
1340
+ if distributed:
1341
+ if isinstance(train_loader.sampler, DistributedSampler):
1342
+ train_loader.sampler.set_epoch(epoch)
1343
+ dist.barrier()
1344
+
1345
+ epoch_time = start_time = time.time()
1346
+
1347
+ train_loss, train_acc = 0, 0
1348
+ if not config.get("skip_train", False):
1349
+ train_loss, train_acc = self.train_epoch(
1350
+ model=self.model,
1351
+ train_loader=train_loader,
1352
+ optimizer=optimizer,
1353
+ loss_function=loss_function,
1354
+ acc_function=acc_function,
1355
+ grad_scaler=grad_scaler,
1356
+ epoch=report_epoch,
1357
+ rank=self.rank,
1358
+ global_rank=self.global_rank,
1359
+ num_epochs=report_num_epochs,
1360
+ sigmoid=sigmoid,
1361
+ use_amp=use_amp,
1362
+ use_cuda=use_cuda,
1363
+ channels_last=channels_last,
1364
+ num_steps_per_image=num_steps_per_image,
1365
+ )
1366
+
1367
+ train_time = time.time() - start_time
1368
+
1369
+ if self.global_rank == 0:
1370
+ print(
1371
+ f"Final training {report_epoch}/{report_num_epochs - 1} "
1372
+ f"loss: {train_loss:.4f} acc_avg: {np.mean(train_acc):.4f} "
1373
+ f"acc {train_acc} time {train_time:.2f}s "
1374
+ f"lr: {optimizer.param_groups[0]['lr']:.4e}"
1375
+ )
1376
+
1377
+ if tb_writer is not None:
1378
+ tb_writer.add_scalar("train/loss", train_loss, report_epoch)
1379
+ tb_writer.add_scalar("train/acc", np.mean(train_acc), report_epoch)
1380
+ if mlflow_is_imported:
1381
+ mlflow.log_metric("train/loss", train_loss, step=report_epoch)
1382
+
1383
+ # validate every num_epochs_per_validation epochs (defaults to 1, every epoch)
1384
+ val_acc_mean = -1
1385
+ if (
1386
+ len(val_schedule_list) > 0
1387
+ and epoch + 1 >= val_schedule_list[0]
1388
+ and val_loader is not None
1389
+ and len(val_loader) > 0
1390
+ ):
1391
+ val_schedule_list.pop(0)
1392
+
1393
+ start_time = time.time()
1394
+ torch.cuda.empty_cache()
1395
+
1396
+ val_loss, val_acc = self.val_epoch(
1397
+ model=self.model,
1398
+ val_loader=val_loader,
1399
+ sliding_inferrer=sliding_inferrer,
1400
+ loss_function=loss_function,
1401
+ acc_function=acc_function,
1402
+ epoch=report_epoch,
1403
+ rank=self.rank,
1404
+ global_rank=self.global_rank,
1405
+ num_epochs=report_num_epochs,
1406
+ sigmoid=sigmoid,
1407
+ use_amp=use_amp,
1408
+ use_cuda=use_cuda,
1409
+ channels_last=channels_last,
1410
+ calc_val_loss=calc_val_loss,
1411
+ )
1412
+
1413
+ torch.cuda.empty_cache()
1414
+ validation_time = time.time() - start_time
1415
+
1416
+ val_acc_mean = float(np.mean(val_acc))
1417
+ val_acc_history.append((report_epoch, val_acc_mean))
1418
+
1419
+ if self.global_rank == 0:
1420
+ print(
1421
+ f"Final validation {report_epoch}/{report_num_epochs - 1} "
1422
+ f"loss: {val_loss:.4f} acc_avg: {val_acc_mean:.4f} acc: {val_acc} time: {validation_time:.2f}s"
1423
+ )
1424
+
1425
+ if tb_writer is not None:
1426
+ tb_writer.add_scalar("val/acc", val_acc_mean, report_epoch)
1427
+ if mlflow_is_imported:
1428
+ mlflow.log_metric("val/acc", val_acc_mean, step=report_epoch)
1429
+
1430
+ for i in range(min(len(config["class_names"]), len(val_acc))): # accuracy per class
1431
+ tb_writer.add_scalar("val_class/" + config["class_names"][i], val_acc[i], report_epoch)
1432
+ if mlflow_is_imported:
1433
+ mlflow.log_metric(
1434
+ "val_class/" + config["class_names"][i], val_acc[i], step=report_epoch
1435
+ )
1436
+
1437
+ if calc_val_loss:
1438
+ tb_writer.add_scalar("val/loss", val_loss, report_epoch)
1439
+
1440
+ timing_dict = dict(
1441
+ time="{:.2f} hr".format((time.time() - pre_loop_time) / 3600),
1442
+ train_time="{:.2f}s".format(train_time),
1443
+ validation_time="{:.2f}s".format(validation_time),
1444
+ epoch_time="{:.2f}s".format(time.time() - epoch_time),
1445
+ )
1446
+
1447
+ if val_acc_mean > best_metric:
1448
+ print(f"New best metric ({best_metric:.6f} --> {val_acc_mean:.6f}). ")
1449
+ best_metric, best_metric_epoch = val_acc_mean, report_epoch
1450
+ save_time = 0
1451
+ if do_torch_save:
1452
+ save_time = self.checkpoint_save(
1453
+ ckpt=best_ckpt_path, model=self.model, epoch=best_metric_epoch, best_metric=best_metric
1454
+ )
1455
+
1456
+ if progress_path is not None:
1457
+ self.save_progress_yaml(
1458
+ progress_path=progress_path,
1459
+ ckpt=best_ckpt_path if do_torch_save else None,
1460
+ best_avg_dice_score_epoch=best_metric_epoch,
1461
+ best_avg_dice_score=best_metric,
1462
+ save_time=save_time,
1463
+ **timing_dict,
1464
+ )
1465
+ if csv_path is not None:
1466
+ self.save_history_csv(
1467
+ csv_path=csv_path,
1468
+ epoch=report_epoch,
1469
+ metric="{:.4f}".format(val_acc_mean),
1470
+ loss="{:.4f}".format(train_loss),
1471
+ iter=report_epoch * len(train_loader.dataset),
1472
+ **timing_dict,
1473
+ )
1474
+
1475
+ # sanity check
1476
+ if epoch > max(20, num_epochs / 4) and 0 <= val_acc_mean < 0.01 and config["stop_on_lowacc"]:
1477
+ raise ValueError(
1478
+ f"Accuracy seems very low at epoch {report_epoch}, acc {val_acc_mean}. "
1479
+ f"Most likely optimization diverged, try setting a smaller learning_rate than {config['learning_rate']}"
1480
+ )
1481
+
1482
+ # early stopping
1483
+ if config["early_stopping_fraction"] > 0 and epoch > num_epochs / 2 and len(val_acc_history) > 10:
1484
+ check_interval = int(0.1 * num_epochs * num_crops_per_image)
1485
+ check_stats = [
1486
+ va[1] for va in val_acc_history if report_epoch - va[0] < check_interval
1487
+ ] # at least 10% epochs
1488
+ if len(check_stats) < 10:
1489
+ check_stats = [va[1] for va in val_acc_history[-10:]] # at least 10 sample points
1490
+ mac, mic = max(check_stats), min(check_stats)
1491
+
1492
+ early_stopping_fraction = (mac - mic) / (abs(mac) + 1e-8)
1493
+ if mac > 0 and mic > 0 and early_stopping_fraction < config["early_stopping_fraction"]:
1494
+ if self.global_rank == 0:
1495
+ print(
1496
+ f"Early stopping at epoch {report_epoch} fraction {early_stopping_fraction} !!! max {mac} min {mic} samples count {len(check_stats)} {check_stats[-50:]}"
1497
+ )
1498
+ break
1499
+ else:
1500
+ if self.global_rank == 0:
1501
+ print(
1502
+ f"No stopping at epoch {report_epoch} fraction {early_stopping_fraction} !!! max {mac} min {mic} samples count {len(check_stats)} {check_stats[-50:]}"
1503
+ )
1504
+
1505
+ # save intermediate checkpoint every num_epochs_per_saving epochs
1506
+ if do_torch_save and ((epoch + 1) % num_epochs_per_saving == 0 or (epoch + 1) >= num_epochs):
1507
+ if report_epoch != best_metric_epoch:
1508
+ self.checkpoint_save(
1509
+ ckpt=intermediate_ckpt_path, model=self.model, epoch=report_epoch, best_metric=val_acc_mean
1510
+ )
1511
+ else:
1512
+ shutil.copyfile(best_ckpt_path, intermediate_ckpt_path) # if already saved once
1513
+
1514
+ if lr_scheduler is not None:
1515
+ lr_scheduler.step()
1516
+
1517
+ if self.global_rank == 0:
1518
+ # report time estimate
1519
+ time_remaining_estimate = train_time * (num_epochs - epoch)
1520
+ if val_loader is not None and len(val_loader) > 0:
1521
+ if validation_time == 0:
1522
+ validation_time = train_time
1523
+ time_remaining_estimate += validation_time * len(val_schedule_list)
1524
+
1525
+ print(
1526
+ f"Estimated remaining training time for the current model fold {config['fold']} is "
1527
+ f"{time_remaining_estimate/3600:.2f} hr, "
1528
+ f"running time {(time.time() - pre_loop_time)/3600:.2f} hr, "
1529
+ f"est total time {(time.time() - pre_loop_time + time_remaining_estimate)/3600:.2f} hr \n"
1530
+ )
1531
+
1532
+ # end of main epoch loop
1533
+
1534
+ train_loader = val_loader = optimizer = None
1535
+
1536
+ # optionally validate best checkpoint at the original image resolution
1537
+ orig_res = config["resample"] == False
1538
+ if config["validate_final_original_res"] and config["resample"]:
1539
+ pretrained_ckpt_name = best_ckpt_path if os.path.exists(best_ckpt_path) else intermediate_ckpt_path
1540
+ if os.path.exists(pretrained_ckpt_name):
1541
+ self.model = None
1542
+ gc.collect()
1543
+ torch.cuda.empty_cache()
1544
+
1545
+ best_metric = self.original_resolution_validate(
1546
+ pretrained_ckpt_name=pretrained_ckpt_name,
1547
+ progress_path=progress_path,
1548
+ best_metric_epoch=best_metric_epoch,
1549
+ pre_loop_time=pre_loop_time,
1550
+ )
1551
+ orig_res = True
1552
+ else:
1553
+ if self.global_rank == 0:
1554
+ print(
1555
+ f"Unable to validate at the original res since no model checkpoints found {best_ckpt_path}, {intermediate_ckpt_path}"
1556
+ )
1557
+
1558
+ if tb_writer is not None:
1559
+ tb_writer.flush()
1560
+ tb_writer.close()
1561
+
1562
+ if mlflow_is_imported:
1563
+ mlflow.end_run()
1564
+
1565
+ if self.global_rank == 0:
1566
+ print(
1567
+ f"=== DONE: best_metric: {best_metric:.4f} at epoch: {best_metric_epoch} of {report_num_epochs} orig_res {orig_res}. Training time {(time.time() - pre_loop_time)/3600:.2f} hr."
1568
+ )
1569
+
1570
+ return best_metric
1571
+
1572
+ def original_resolution_validate(self, pretrained_ckpt_name, progress_path, best_metric_epoch, pre_loop_time):
1573
+ if self.global_rank == 0:
1574
+ print("Running final best model validation on the original image resolution!")
1575
+
1576
+ self.model = self.setup_model(pretrained_ckpt_name=pretrained_ckpt_name)
1577
+
1578
+ # validate
1579
+ start_time = time.time()
1580
+ val_acc_mean, val_loss, val_acc = self.validate()
1581
+ validation_time = "{:.2f}s".format(time.time() - start_time)
1582
+ val_acc_mean = float(np.mean(val_acc))
1583
+ if self.global_rank == 0:
1584
+ print(
1585
+ f"Original resolution validation: "
1586
+ f"loss: {val_loss:.4f} acc_avg: {val_acc_mean:.4f} "
1587
+ f"acc {val_acc} time {validation_time}"
1588
+ )
1589
+
1590
+ if progress_path is not None:
1591
+ self.save_progress_yaml(
1592
+ progress_path=progress_path,
1593
+ ckpt=pretrained_ckpt_name,
1594
+ best_avg_dice_score_epoch=best_metric_epoch,
1595
+ best_avg_dice_score=val_acc_mean,
1596
+ validation_time=validation_time,
1597
+ inverted_best_validation=True,
1598
+ time="{:.2f} hr".format((time.time() - pre_loop_time) / 3600),
1599
+ )
1600
+
1601
+ return val_acc_mean
1602
+
1603
+ def validate(self, validation_files=None):
1604
+ config = self.config
1605
+ resample = config["resample"]
1606
+
1607
+ val_config = self.config["validate"]
1608
+ output_path = val_config.get("output_path", None)
1609
+ save_mask = val_config.get("save_mask", False) and output_path is not None
1610
+ invert = val_config.get("invert", True)
1611
+
1612
+ data_list_file_path = config["data_list_file_path"]
1613
+ if not os.path.isabs(data_list_file_path):
1614
+ data_list_file_path = os.path.abspath(os.path.join(config["bundle_root"], data_list_file_path))
1615
+
1616
+ if validation_files is None:
1617
+ if config.get("validation_key", None) is not None:
1618
+ validation_files, _ = datafold_read(
1619
+ datalist=data_list_file_path,
1620
+ basedir=config["data_file_base_dir"],
1621
+ fold=-1,
1622
+ key=config["validation_key"],
1623
+ )
1624
+ else:
1625
+ _, validation_files = datafold_read(
1626
+ datalist=data_list_file_path, basedir=config["data_file_base_dir"], fold=config["fold"]
1627
+ )
1628
+
1629
+ if self.global_rank == 0:
1630
+ print(f"validation files {len(validation_files)}")
1631
+
1632
+ if len(validation_files) == 0:
1633
+ warnings.warn("No validation files found!")
1634
+ return
1635
+
1636
+ val_loader = self.get_val_loader(data=validation_files, resample_label=not invert)
1637
+ val_transform = val_loader.dataset.transform
1638
+
1639
+ post_transforms = None
1640
+ if save_mask or invert:
1641
+ post_transforms = DataTransformBuilder.get_postprocess_transform(
1642
+ save_mask=save_mask,
1643
+ invert=invert,
1644
+ transform=val_transform,
1645
+ sigmoid=self.config["sigmoid"],
1646
+ output_path=output_path,
1647
+ resample=resample,
1648
+ data_root_dir=self.config["data_file_base_dir"],
1649
+ output_dtype=np.uint8 if self.config["output_classes"] < 255 else np.uint16,
1650
+ save_mask_mode=self.config.get("save_mask_mode", None),
1651
+ )
1652
+
1653
+ start_time = time.time()
1654
+ val_loss, val_acc = self.val_epoch(
1655
+ model=self.model,
1656
+ val_loader=val_loader,
1657
+ sliding_inferrer=self.sliding_inferrer,
1658
+ loss_function=self.loss_function,
1659
+ acc_function=self.acc_function,
1660
+ rank=self.rank,
1661
+ global_rank=self.global_rank,
1662
+ sigmoid=self.config["sigmoid"],
1663
+ use_amp=self.config["amp"],
1664
+ use_cuda=self.config["cuda"],
1665
+ post_transforms=post_transforms,
1666
+ channels_last=self.config["channels_last"],
1667
+ calc_val_loss=self.config["calc_val_loss"],
1668
+ )
1669
+ val_acc_mean = float(np.mean(val_acc))
1670
+
1671
+ if self.global_rank == 0:
1672
+ print(
1673
+ f"Validation complete, loss_avg: {val_loss:.4f} "
1674
+ f"acc_avg: {val_acc_mean:.4f} acc {val_acc} time {time.time() - start_time:.2f}s"
1675
+ )
1676
+
1677
+ return val_acc_mean, val_loss, val_acc
1678
+
1679
+ def infer(self, testing_files=None):
1680
+ output_path = self.config["infer"].get("output_path", None)
1681
+ testing_key = self.config["infer"].get("data_list_key", "testing")
1682
+
1683
+ if output_path is None:
1684
+ if self.global_rank == 0:
1685
+ print("Inference output_path is not specified")
1686
+ return
1687
+
1688
+ if testing_files is None:
1689
+ data_list_file_path = self.config["data_list_file_path"]
1690
+ if not os.path.isabs(data_list_file_path):
1691
+ data_list_file_path = os.path.abspath(os.path.join(self.config["bundle_root"], data_list_file_path))
1692
+
1693
+ testing_files, _ = datafold_read(
1694
+ datalist=data_list_file_path, basedir=self.config["data_file_base_dir"], fold=-1, key=testing_key
1695
+ )
1696
+
1697
+ if self.global_rank == 0:
1698
+ print(f"testing_files files {len(testing_files)}")
1699
+
1700
+ if len(testing_files) == 0:
1701
+ warnings.warn("No testing_files files found!")
1702
+ return
1703
+
1704
+ inf_loader = self.get_val_loader(data=testing_files, resample_label=False)
1705
+ inf_transform = inf_loader.dataset.transform
1706
+
1707
+ post_transforms = DataTransformBuilder.get_postprocess_transform(
1708
+ save_mask=True,
1709
+ invert=True,
1710
+ transform=inf_transform,
1711
+ sigmoid=self.config["sigmoid"],
1712
+ output_path=output_path,
1713
+ resample=self.config["resample"],
1714
+ data_root_dir=self.config["data_file_base_dir"],
1715
+ output_dtype=np.uint8 if self.config["output_classes"] < 255 else np.uint16,
1716
+ save_mask_mode=self.config.get("save_mask_mode", None),
1717
+ )
1718
+
1719
+ start_time = time.time()
1720
+ self.val_epoch(
1721
+ model=self.model,
1722
+ val_loader=inf_loader,
1723
+ sliding_inferrer=self.sliding_inferrer,
1724
+ rank=self.rank,
1725
+ global_rank=self.global_rank,
1726
+ sigmoid=self.config["sigmoid"],
1727
+ use_amp=self.config["amp"],
1728
+ use_cuda=self.config["cuda"],
1729
+ post_transforms=post_transforms,
1730
+ channels_last=self.config["channels_last"],
1731
+ calc_val_loss=self.config["calc_val_loss"],
1732
+ )
1733
+
1734
+ if self.global_rank == 0:
1735
+ print(f"Inference complete, time {time.time() - start_time:.2f}s")
1736
+
1737
+ @torch.no_grad()
1738
+ def infer_image(self, image_file):
1739
+ self.model.eval()
1740
+
1741
+ infer_config = self.config["infer"]
1742
+ output_path = infer_config.get("output_path", None)
1743
+ save_mask = infer_config.get("save_mask", False) and output_path is not None
1744
+ invert_on_gpu = infer_config.get("invert_on_gpu", False)
1745
+
1746
+ start_time = time.time()
1747
+ sigmoid = self.config["sigmoid"]
1748
+ resample = self.config["resample"]
1749
+ channels_last = self.config["channels_last"]
1750
+
1751
+ inf_transform = self.get_data_transform_builder()(augment=False, resample_label=False)
1752
+
1753
+ batch_data = inf_transform([image_file])
1754
+ batch_data = list_data_collate([batch_data])
1755
+
1756
+ memory_format = torch.channels_last_3d if channels_last else torch.preserve_format
1757
+ data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=self.device)
1758
+
1759
+ with autocast(self.amp_device_type, enabled=self.config["amp"]):
1760
+ logits = self.sliding_inferrer(inputs=data, network=self.model)
1761
+
1762
+ data = None
1763
+
1764
+ try:
1765
+ pred = self.logits2pred(logits, sigmoid=sigmoid)
1766
+ except RuntimeError as e:
1767
+ if not logits.is_cuda:
1768
+ raise e
1769
+ print(f"logits2pred failed on GPU pred retrying on CPU {logits.shape}")
1770
+ logits = logits.cpu()
1771
+ pred = self.logits2pred(logits, sigmoid=sigmoid)
1772
+
1773
+ logits = None
1774
+
1775
+ if not invert_on_gpu:
1776
+ pred = pred.cpu() # invert on cpu (default)
1777
+
1778
+ post_transforms = DataTransformBuilder.get_postprocess_transform(
1779
+ save_mask=save_mask,
1780
+ invert=True,
1781
+ transform=inf_transform,
1782
+ sigmoid=sigmoid,
1783
+ output_path=output_path,
1784
+ resample=resample,
1785
+ data_root_dir=self.config["data_file_base_dir"],
1786
+ output_dtype=np.uint8 if self.config["output_classes"] < 255 else np.uint16,
1787
+ save_mask_mode=self.config.get("save_mask_mode", None),
1788
+ )
1789
+
1790
+ batch_data["pred"] = convert_to_dst_type(pred, batch_data["image"], dtype=pred.dtype, device=pred.device)[
1791
+ 0
1792
+ ] # make Meta tensor
1793
+ pred = [post_transforms(x)["pred"] for x in decollate_batch(batch_data)]
1794
+
1795
+ pred = pred[0]
1796
+
1797
+ print(f"Inference complete, time {time.time() - start_time:.2f}s shape {pred.shape} {image_file}")
1798
+
1799
+ return pred
1800
+
1801
+ def train_epoch(
1802
+ self,
1803
+ model,
1804
+ train_loader,
1805
+ optimizer,
1806
+ loss_function,
1807
+ acc_function,
1808
+ grad_scaler,
1809
+ epoch,
1810
+ rank,
1811
+ global_rank=0,
1812
+ num_epochs=0,
1813
+ sigmoid=False,
1814
+ use_amp=True,
1815
+ use_cuda=True,
1816
+ channels_last=False,
1817
+ num_steps_per_image=1,
1818
+ ):
1819
+ model.train()
1820
+ device = torch.device(rank) if use_cuda else torch.device("cpu")
1821
+ memory_format = torch.channels_last_3d if channels_last else torch.preserve_format
1822
+
1823
+ run_loss = CumulativeAverage()
1824
+ run_acc = CumulativeAverage()
1825
+
1826
+ start_time = time.time()
1827
+ avg_loss = avg_acc = 0
1828
+ for idx, batch_data in enumerate(train_loader):
1829
+ data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device)
1830
+ target = batch_data["label"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device)
1831
+
1832
+ data_list = data.chunk(num_steps_per_image) if num_steps_per_image > 1 else [data]
1833
+ target_list = target.chunk(num_steps_per_image) if num_steps_per_image > 1 else [target]
1834
+
1835
+ for ich in range(min(num_steps_per_image, len(data_list))):
1836
+ data = data_list[ich]
1837
+ target = target_list[ich]
1838
+
1839
+ # optimizer.zero_grad(set_to_none=True)
1840
+ for param in model.parameters():
1841
+ param.grad = None
1842
+
1843
+ with autocast(self.amp_device_type, enabled=use_amp):
1844
+ logits = model(data)
1845
+
1846
+ loss = loss_function(logits, target)
1847
+ grad_scaler.scale(loss).backward()
1848
+ grad_scaler.step(optimizer)
1849
+ grad_scaler.update()
1850
+
1851
+ with torch.no_grad():
1852
+ pred = self.logits2pred(logits, sigmoid=sigmoid, skip_softmax=True)
1853
+ acc = acc_function(pred, target)
1854
+
1855
+ batch_size_adjusted = batch_size = data.shape[0]
1856
+ if isinstance(acc, (list, tuple)):
1857
+ acc, batch_size_adjusted = acc
1858
+
1859
+ run_loss.append(loss, count=batch_size)
1860
+ run_acc.append(acc, count=batch_size_adjusted)
1861
+
1862
+ avg_loss = run_loss.aggregate()
1863
+ avg_acc = run_acc.aggregate()
1864
+
1865
+ if global_rank == 0:
1866
+ print(
1867
+ f"Epoch {epoch}/{num_epochs} {idx}/{len(train_loader)} "
1868
+ f"loss: {avg_loss:.4f} acc {avg_acc} time {time.time() - start_time:.2f}s "
1869
+ )
1870
+ start_time = time.time()
1871
+
1872
+ # optimizer.zero_grad(set_to_none=True)
1873
+ for param in model.parameters():
1874
+ param.grad = None
1875
+
1876
+ data = None
1877
+ target = None
1878
+ data_list = None
1879
+ target_list = None
1880
+ batch_data = None
1881
+
1882
+ return avg_loss, avg_acc
1883
+
1884
+ @torch.no_grad()
1885
+ def val_epoch(
1886
+ self,
1887
+ model,
1888
+ val_loader,
1889
+ sliding_inferrer,
1890
+ loss_function=None,
1891
+ acc_function=None,
1892
+ epoch=0,
1893
+ rank=0,
1894
+ global_rank=0,
1895
+ num_epochs=0,
1896
+ sigmoid=False,
1897
+ use_amp=True,
1898
+ use_cuda=True,
1899
+ post_transforms=None,
1900
+ channels_last=False,
1901
+ calc_val_loss=False,
1902
+ ):
1903
+ model.eval()
1904
+ device = torch.device(rank) if use_cuda else torch.device("cpu")
1905
+ memory_format = torch.channels_last_3d if channels_last else torch.preserve_format
1906
+ distributed = dist.is_initialized()
1907
+
1908
+ run_loss = CumulativeAverage()
1909
+ run_acc = CumulativeAverage()
1910
+ run_loss.append(torch.tensor(0, device=device), count=0)
1911
+
1912
+ avg_loss = avg_acc = 0
1913
+ start_time = time.time()
1914
+
1915
+ # In DDP, each replica has a subset of data, but if total data length is not evenly divisible by num_replicas, then some replicas has 1 extra repeated item.
1916
+ # For proper validation with batch of 1, we only want to collect metrics for non-repeated items, hence let's compute a proper subset length
1917
+ nonrepeated_data_length = len(val_loader.dataset)
1918
+ sampler = val_loader.sampler
1919
+ if dist.is_initialized and isinstance(sampler, DistributedSampler) and not sampler.drop_last:
1920
+ nonrepeated_data_length = len(range(sampler.rank, len(sampler.dataset), sampler.num_replicas))
1921
+
1922
+ for idx, batch_data in enumerate(val_loader):
1923
+ data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device)
1924
+ filename = batch_data["image"].meta[ImageMetaKey.FILENAME_OR_OBJ]
1925
+ batch_size = data.shape[0]
1926
+
1927
+ with autocast(self.amp_device_type, enabled=use_amp):
1928
+ logits = sliding_inferrer(inputs=data, network=model)
1929
+
1930
+ data = None
1931
+
1932
+ if post_transforms:
1933
+
1934
+ try:
1935
+ pred = self.logits2pred(logits, sigmoid=sigmoid)
1936
+ except RuntimeError as e:
1937
+ if not logits.is_cuda:
1938
+ raise e
1939
+ print(f"logits2pred failed on GPU pred retrying on CPU {logits.shape} {filename}")
1940
+ logits = logits.cpu()
1941
+ pred = self.logits2pred(logits, sigmoid=sigmoid)
1942
+
1943
+ if not calc_val_loss:
1944
+ logits = None
1945
+
1946
+ batch_data["pred"] = convert_to_dst_type(
1947
+ pred, batch_data["image"], dtype=pred.dtype, device=pred.device
1948
+ )[0]
1949
+ pred = None
1950
+
1951
+ try:
1952
+ # inverting on gpu can OOM due inverse resampling or un-cropping
1953
+ pred = torch.stack([post_transforms(x)["pred"] for x in decollate_batch(batch_data)])
1954
+ except RuntimeError as e:
1955
+ if not batch_data["pred"].is_cuda:
1956
+ raise e
1957
+ print(f"post_transforms failed on GPU pred retrying on CPU {batch_data['pred'].shape}")
1958
+ batch_data["pred"] = batch_data["pred"].cpu()
1959
+ pred = torch.stack([post_transforms(x)["pred"] for x in decollate_batch(batch_data)])
1960
+
1961
+ batch_data["pred"] = None
1962
+ if logits is not None and pred.shape != logits.shape:
1963
+ logits = None # if shape has changed due to inverse resampling or un-cropping
1964
+ else:
1965
+ pred = self.logits2pred(logits, sigmoid=sigmoid, skip_softmax=True)
1966
+
1967
+ if "label" in batch_data and loss_function is not None and acc_function is not None:
1968
+ loss = acc = None
1969
+ target = batch_data["label"].as_subclass(torch.Tensor)
1970
+
1971
+ if calc_val_loss:
1972
+ if logits is not None:
1973
+ loss = loss_function(logits, target.to(device=logits.device))
1974
+ run_loss.append(loss.to(device=device), count=batch_size)
1975
+ logits = None
1976
+
1977
+ with torch.no_grad():
1978
+ try:
1979
+ acc = acc_function(pred.to(device=device), target.to(device=device)) # try GPU
1980
+ except RuntimeError as e:
1981
+ if "OutOfMemoryError" not in str(type(e).__name__):
1982
+ raise e
1983
+ print(
1984
+ f"acc_function val failed on GPU pred: {pred.shape} on {pred.device}, target: {target.shape} on {target.device}. retrying on CPU"
1985
+ )
1986
+ acc = acc_function(pred.cpu(), target.cpu())
1987
+
1988
+ batch_size_adjusted = batch_size
1989
+ if isinstance(acc, (list, tuple)):
1990
+ acc, batch_size_adjusted = acc
1991
+ acc = acc.detach().clone()
1992
+
1993
+ if idx < nonrepeated_data_length:
1994
+ run_acc.append(acc.to(device=device), count=batch_size_adjusted)
1995
+ else:
1996
+ run_acc.append(torch.zeros_like(acc, device=device), count=torch.zeros_like(batch_size_adjusted))
1997
+
1998
+ avg_loss = loss.cpu() if loss is not None else 0
1999
+ avg_acc = acc.cpu().numpy() if acc is not None else 0
2000
+ pred, target = None, None
2001
+
2002
+ if global_rank == 0:
2003
+ print(
2004
+ f"Val {epoch}/{num_epochs} {idx}/{len(val_loader)} loss: {avg_loss:.4f} "
2005
+ f"acc {avg_acc} time {time.time() - start_time:.2f}s {filename}"
2006
+ )
2007
+
2008
+ else:
2009
+ if global_rank == 0:
2010
+ print(f"Val {epoch}/{num_epochs} {idx}/{len(val_loader)} time {time.time() - start_time:.2f}s")
2011
+
2012
+ start_time = time.time()
2013
+
2014
+ pred = target = data = batch_data = None
2015
+
2016
+ if distributed:
2017
+ dist.barrier()
2018
+
2019
+ avg_loss = run_loss.aggregate()
2020
+ avg_acc = run_acc.aggregate()
2021
+
2022
+ if np.any(avg_acc < 0):
2023
+ dist.barrier()
2024
+ warnings.warn("Avg dice accuracy is negative, something went wrong!!!!!")
2025
+
2026
+ return avg_loss, avg_acc
2027
+
2028
+ def logits2pred(self, logits, sigmoid=False, dim=1, skip_softmax=False):
2029
+ if isinstance(logits, (list, tuple)):
2030
+ logits = logits[0]
2031
+
2032
+ if sigmoid:
2033
+ pred = torch.sigmoid(logits)
2034
+ else:
2035
+ pred = logits if skip_softmax else torch.softmax(logits, dim=dim, dtype=torch.double).float()
2036
+
2037
+ return pred
2038
+
2039
+ def get_avail_cpu_memory(self):
2040
+ avail_memory = psutil.virtual_memory().available
2041
+
2042
+ # check if in docker
2043
+ memory_limit_filename = "/sys/fs/cgroup/memory/memory.limit_in_bytes"
2044
+ if os.path.exists(memory_limit_filename):
2045
+ with open(memory_limit_filename, "r") as f:
2046
+ docker_limit = int(f.read())
2047
+ avail_memory = min(docker_limit, avail_memory) # could be lower limit in docker
2048
+
2049
+ return avail_memory
2050
+
2051
+ def get_cache_rate(self, train_cases=0, validation_cases=0, prioritise_train=True):
2052
+ config = self.config
2053
+ cache_rate = config["cache_rate"]
2054
+ avail_memory = None
2055
+
2056
+ total_cases = train_cases + validation_cases
2057
+
2058
+ image_size_mm_90 = config.get("image_size_mm_90", None)
2059
+ if config["resample"] and image_size_mm_90 is not None:
2060
+ image_size = (
2061
+ (np.array(image_size_mm_90) / np.array(config["resample_resolution"])).astype(np.int32).tolist()
2062
+ )
2063
+ else:
2064
+ image_size = config["image_size"]
2065
+
2066
+ approx_data_cache_required = (4 * config["input_channels"] + 1) * np.prod(image_size) * total_cases
2067
+ approx_os_cache_required = 50 * 1024**3 # reserve 50gb
2068
+
2069
+ if cache_rate is None:
2070
+ cache_rate = 0
2071
+
2072
+ if image_size is not None:
2073
+ avail_memory = self.get_avail_cpu_memory()
2074
+ cache_rate = min(avail_memory / float(approx_data_cache_required + approx_os_cache_required), 1.0)
2075
+ if cache_rate < 0.1:
2076
+ cache_rate = 0.0 # don't cache small
2077
+
2078
+ if self.global_rank == 0:
2079
+ print(
2080
+ f"Calculating cache required {approx_data_cache_required >> 30}GB, available RAM {avail_memory >> 30}GB given avg image size {image_size}."
2081
+ )
2082
+ if cache_rate < 1:
2083
+ print(
2084
+ f"Available RAM is not enought to cache full dataset, caching a fraction {cache_rate:.2f}"
2085
+ )
2086
+ else:
2087
+ print("Caching full dataset in RAM")
2088
+ else:
2089
+ print("Cant calculate cache_rate since image_size is not provided!!!!")
2090
+
2091
+ else:
2092
+ if self.global_rank == 0:
2093
+ print(f"Using user specified cache_rate={cache_rate} to cache data in RAM")
2094
+
2095
+ # allocate cache_rate to training files first
2096
+ cache_rate_train = cache_rate_val = cache_rate
2097
+
2098
+ if prioritise_train:
2099
+ if cache_rate > 0 and cache_rate < 1:
2100
+ cache_num = cache_rate * total_cases
2101
+ cache_rate_train = min(1.0, cache_num / train_cases) if train_cases > 0 else 0
2102
+ if (cache_rate_train < 1 and train_cases > 0) or validation_cases == 0:
2103
+ cache_rate_val = 0
2104
+ else:
2105
+ cache_rate_val = (cache_num - cache_rate_train * train_cases) / validation_cases
2106
+
2107
+ if self.global_rank == 0:
2108
+ print(f"Prioritizing cache_rate training {cache_rate_train} validation {cache_rate_val}")
2109
+
2110
+ return cache_rate_train, cache_rate_val
2111
+
2112
+ def save_history_csv(self, csv_path=None, header=None, **kwargs):
2113
+ if csv_path is not None:
2114
+ if header is not None:
2115
+ with open(csv_path, "a") as myfile:
2116
+ wrtr = csv.writer(myfile, delimiter="\t")
2117
+ wrtr.writerow(header)
2118
+ if len(kwargs):
2119
+ with open(csv_path, "a") as myfile:
2120
+ wrtr = csv.writer(myfile, delimiter="\t")
2121
+ wrtr.writerow(list(kwargs.values()))
2122
+
2123
+ def save_progress_yaml(self, progress_path=None, ckpt=None, **report):
2124
+ if ckpt is not None:
2125
+ report["model"] = ckpt
2126
+
2127
+ report["date"] = str(datetime.now())[:19]
2128
+
2129
+ if progress_path is not None:
2130
+ yaml.add_representer(
2131
+ float, lambda dumper, value: dumper.represent_scalar("tag:yaml.org,2002:float", "{0:.4f}".format(value))
2132
+ )
2133
+ with open(progress_path, "a") as progress_file:
2134
+ yaml.dump([report], stream=progress_file, allow_unicode=True, default_flow_style=None, sort_keys=False)
2135
+
2136
+ print("Progress:" + ",".join(f" {k}: {v}" for k, v in report.items()))
2137
+
2138
+ def run(self):
2139
+ if self.config["validate"]["enabled"]:
2140
+ self.validate()
2141
+ elif self.config["infer"]["enabled"]:
2142
+ self.infer()
2143
+ else:
2144
+ self.train()
2145
+
2146
+
2147
+ def run_segmenter_worker(rank=0, config_file: Optional[Union[str, Sequence[str]]] = None, override: Dict = {}):
2148
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
2149
+ dist_available = dist.is_available()
2150
+ global_rank = rank
2151
+
2152
+ if type(config_file) == str and "," in config_file:
2153
+ config_file = config_file.split(",")
2154
+
2155
+ if dist_available:
2156
+ mgpu = override.get("mgpu", None)
2157
+ if mgpu is not None:
2158
+ logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.WARNING)
2159
+ dist.init_process_group(backend="nccl", rank=rank, timeout=timedelta(seconds=5400), **mgpu)
2160
+ mgpu.update({"rank": rank, "global_rank": rank})
2161
+ if rank == 0:
2162
+ print(f"Distributed: initializing multi-gpu tcp:// process group {mgpu}")
2163
+
2164
+ elif dist_launched() and torch.cuda.device_count() > 1:
2165
+ rank = int(os.getenv("LOCAL_RANK"))
2166
+ global_rank = int(os.getenv("RANK"))
2167
+ world_size = int(os.getenv("LOCAL_WORLD_SIZE"))
2168
+ logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.WARNING)
2169
+ dist.init_process_group(backend="nccl", init_method="env://") # torchrun spawned it
2170
+ override["mgpu"] = {"world_size": world_size, "rank": rank, "global_rank": global_rank}
2171
+
2172
+ print(f"Distributed launched: initializing multi-gpu env:// process group {override['mgpu']}")
2173
+
2174
+ segmenter = Segmenter(config_file=config_file, config_dict=override, rank=rank, global_rank=global_rank)
2175
+ best_metric = segmenter.run()
2176
+ segmenter = None
2177
+
2178
+ if dist_available and dist.is_initialized():
2179
+ dist.destroy_process_group()
2180
+
2181
+ return best_metric
2182
+
2183
+
2184
+ def dist_launched() -> bool:
2185
+ return dist.is_torchelastic_launched() or (
2186
+ os.getenv("NGC_ARRAY_SIZE") is not None and int(os.getenv("NGC_ARRAY_SIZE")) > 1
2187
+ )
2188
+
2189
+
2190
+ def run_segmenter(config_file: Optional[Union[str, Sequence[str]]] = None, **kwargs):
2191
+ """
2192
+ if multiple gpu available, start multiprocessing for all gpus
2193
+ """
2194
+
2195
+ nprocs = torch.cuda.device_count()
2196
+
2197
+ if nprocs > 1 and not dist_launched():
2198
+ print("Manually spawning processes {nprocs}")
2199
+ kwargs["mgpu"] = {"world_size": nprocs, "init_method": kwargs.get("init_method", "tcp://127.0.0.1:23456")}
2200
+ torch.multiprocessing.spawn(run_segmenter_worker, nprocs=nprocs, args=(config_file, kwargs))
2201
+ else:
2202
+ print("Not spawning processes, dist is already launched {nprocs}")
2203
+ run_segmenter_worker(0, config_file, kwargs)
2204
+
2205
+
2206
+ if __name__ == "__main__":
2207
+ fire, fire_is_imported = optional_import("fire")
2208
+ if fire_is_imported:
2209
+ fire.Fire(run_segmenter)
2210
+ else:
2211
+ warnings.warn("Fire commandline parser cannot be imported, using options from config/hyper_parameters.yaml")
2212
+ run_segmenter(config_file="config/hyper_parameters.yaml")
scripts/utils.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ from monai.apps.auto3dseg.auto_runner import logger
9
+
10
+ print = logger.debug
11
+ roi_size_default = [224, 224, 144]
12
+
13
+
14
+ def logger_configure(log_output_file: str = None, debug=False, global_rank=0) -> None:
15
+ log_config = {
16
+ "version": 1,
17
+ "disable_existing_loggers": False,
18
+ "formatters": {"monai_default": {"format": "%(message)s"}},
19
+ "loggers": {
20
+ "monai.apps.auto3dseg.auto_runner": {"handlers": ["console", "file"], "level": "DEBUG", "propagate": False}
21
+ },
22
+ # "filters": {"rank_filter": {"()": RankFilter}},
23
+ "handlers": {
24
+ "console": {
25
+ "class": "logging.StreamHandler",
26
+ "level": "INFO",
27
+ "formatter": "monai_default",
28
+ # "filters": ["rank_filter"],
29
+ },
30
+ "file": {
31
+ "class": "logging.FileHandler",
32
+ "filename": "runner.log",
33
+ "mode": "a",
34
+ "level": "DEBUG",
35
+ "formatter": "monai_default",
36
+ # "filters": ["rank_filter"],
37
+ },
38
+ },
39
+ }
40
+
41
+ if log_output_file is not None:
42
+ log_config["handlers"]["file"]["filename"] = log_output_file
43
+ log_config["handlers"]["file"]["level"] = "DEBUG"
44
+ else:
45
+ log_config["handlers"]["file"]["level"] = "CRITICAL"
46
+
47
+ if debug or bool(os.environ.get("SEGRESNET_DEBUG", False)):
48
+ log_config["handlers"]["console"]["level"] = "DEBUG"
49
+
50
+ logging.config.dictConfig(log_config)
51
+ # if global_rank!=0:
52
+ # logger.addFilter(lambda x: False)
53
+
54
+
55
+ def get_gpu_mem_size():
56
+ gpu_mem = 0
57
+ n_gpus = torch.cuda.device_count()
58
+ if n_gpus > 0:
59
+ gpu_mem = min([torch.cuda.get_device_properties(i).total_memory for i in range(n_gpus)])
60
+ gpu_mem = gpu_mem / 1024**3
61
+ else:
62
+ gpu_mem = 16
63
+
64
+ return gpu_mem
65
+
66
+
67
+ def auto_adjust_network_settings(
68
+ auto_scale_roi=False,
69
+ auto_scale_batch=False,
70
+ auto_scale_filters=False,
71
+ image_size_mm=None,
72
+ spacing=None,
73
+ output_classes=None,
74
+ levels=None,
75
+ anisotropic_scales=False,
76
+ levels_limit=5,
77
+ gpu_mem=None,
78
+ ):
79
+ global_rank = 0
80
+ if dist.is_available() and dist.is_initialized():
81
+ global_rank = dist.get_rank()
82
+ print(f"auto_adjust_network_settings dist global_rank {global_rank}")
83
+ else:
84
+ print(f"auto_adjust_network_settings no distributed global_rank {global_rank}")
85
+
86
+ batch_size_default = 1
87
+ init_filters_default = 32
88
+
89
+ roi_size = np.array(roi_size_default)
90
+ base_numel = roi_size.prod()
91
+ gpu_factor = 1
92
+
93
+ if gpu_mem is None:
94
+ gpu_mem = get_gpu_mem_size()
95
+ if global_rank == 0:
96
+ print(f"GPU device memory min: {gpu_mem}")
97
+
98
+ # adapting
99
+ if auto_scale_batch or auto_scale_roi or auto_scale_filters:
100
+ gpu_factor_init = gpu_factor = max(1, gpu_mem / 16)
101
+ if anisotropic_scales:
102
+ gpu_factor = max(1, 0.8 * gpu_factor)
103
+ if global_rank == 0:
104
+ print(f"base_numel {base_numel} gpu_factor {gpu_factor} gpu_factor_init {gpu_factor_init}")
105
+ else:
106
+ gpu_mem = 16
107
+ gpu_factor = gpu_factor_init = 1
108
+
109
+ # account for output_classes
110
+ output_classes_thresh = 20
111
+ if output_classes is not None and output_classes > output_classes_thresh:
112
+ base_adjust = gpu_mem / (output_classes * 0.2 + 11.5)
113
+ if gpu_mem < 17:
114
+ base_adjust /= 2
115
+
116
+ if global_rank == 0:
117
+ print(f"base_adjust {base_adjust} since output_classes {output_classes} > {output_classes_thresh}")
118
+ if base_adjust < 0.95: # reduce roi
119
+ base_numel *= base_adjust
120
+ r = int(base_numel ** (1 / 3) / 2**4)
121
+ if r == 0 and global_rank == 0:
122
+ print(f"Warning: given output_classes {output_classes}, unable to fit any ROI on the gpu {gpu_mem} Gb!")
123
+ roi_size = np.array([max(1, r) * 2**4] * 3)
124
+ gpu_factor = gpu_factor_init = 1
125
+ auto_scale_roi = False
126
+ else:
127
+ gpu_factor_init = gpu_factor = base_adjust
128
+
129
+ if global_rank == 0:
130
+ print(f"base_numel {base_numel} roi_size {roi_size} gpu_factor {gpu_factor}")
131
+
132
+ if image_size_mm is not None and spacing is not None:
133
+ image_size = np.floor(np.array(image_size_mm) / np.array(spacing))
134
+ if global_rank == 0:
135
+ print(f"input roi {roi_size} image_size {image_size} numel {roi_size.prod()}")
136
+ roi_size = np.minimum(roi_size, image_size)
137
+ else:
138
+ raise ValueError("image_size_mm or spacing is not provided, network params may be inaccuracy")
139
+
140
+ # adjust ROI
141
+ max_numel = base_numel * gpu_factor if auto_scale_roi else base_numel
142
+ while roi_size.prod() < max_numel:
143
+ old_numel = roi_size.prod()
144
+ roi_size = np.minimum(roi_size * 1.15, image_size)
145
+ if global_rank == 0:
146
+ print(f"increasing roi step {roi_size}")
147
+ if roi_size.prod() == old_numel:
148
+ break
149
+ if global_rank == 0:
150
+ print(f"increasing roi result 1 {roi_size}")
151
+
152
+ # adjust number of network downsize levels
153
+ if not anisotropic_scales:
154
+ if levels is None:
155
+ levels = np.floor(np.log2(roi_size))
156
+ if global_rank == 0:
157
+ print(f"levels 1 {levels}")
158
+ levels = min(min(levels), levels_limit) # limit to 5
159
+ if global_rank == 0:
160
+ print(f"levels 2' {levels}")
161
+
162
+ factor = 2 ** (levels - 1)
163
+ roi_size = factor * np.maximum(2, np.floor(roi_size / factor))
164
+ if global_rank == 0:
165
+ print(f"roi_size factored {roi_size}")
166
+
167
+ else:
168
+ extra_levels = np.floor(np.log2(np.max(spacing) / spacing))
169
+ extra_levels = max(extra_levels) - extra_levels
170
+
171
+ if levels is None:
172
+ # calc levels
173
+ levels = np.floor(np.log2(roi_size))
174
+ if global_rank == 0:
175
+ print(f"levels 1 aniso {levels} extra_levels {extra_levels}")
176
+ levels = min(min(levels + extra_levels), levels_limit) # limit to 5
177
+ if global_rank == 0:
178
+ print(f"levels 2 {levels}")
179
+
180
+ factor = 2 ** (np.maximum(1, levels - extra_levels) - 1)
181
+ roi_size = factor * np.maximum(2, np.floor(roi_size / factor))
182
+ if global_rank == 0:
183
+ print(f"roi_size factored {roi_size} factor {factor} extra_levels {extra_levels}")
184
+
185
+ # optionally adjust initial filters (above 32)
186
+ if auto_scale_filters and roi_size.prod() < base_numel * gpu_factor:
187
+ init_filters = int(max(32, np.floor(4 * (base_numel / roi_size.prod())) * 8))
188
+ if global_rank == 0:
189
+ print(f"checking to increase init_filters {init_filters}")
190
+ gpu_factor_init *= init_filters / 32
191
+ gpu_factor *= init_filters / 32
192
+ else:
193
+ if global_rank == 0:
194
+ print(f"kept filters the same base_numel {base_numel}, gpu_factor {gpu_factor}")
195
+
196
+ init_filters = init_filters_default
197
+
198
+ # finally scale batch
199
+ if auto_scale_batch and roi_size.prod() < base_numel * gpu_factor_init:
200
+ batch_size = int(1.1 * gpu_factor_init)
201
+ if global_rank == 0:
202
+ print(
203
+ f"increased batch_size {batch_size} base_numel {base_numel}, gpu_factor {gpu_factor}, gpu_factor_init {gpu_factor_init}"
204
+ )
205
+
206
+ else:
207
+ batch_size = batch_size_default
208
+ if global_rank == 0:
209
+ print(
210
+ f"kept batch the same base_numel {base_numel}, gpu_factor {gpu_factor}, gpu_factor_init {gpu_factor_init}"
211
+ )
212
+
213
+ levels = int(levels)
214
+ roi_size = roi_size.astype(int).tolist()
215
+
216
+ if global_rank == 0:
217
+ print(
218
+ f"Suggested network parameters: \n"
219
+ f"Batch size {batch_size_default} => {batch_size} \n"
220
+ f"ROI size {roi_size_default} => {roi_size} \n"
221
+ f"init_filters {init_filters_default} => {init_filters} \n"
222
+ f"aniso: {anisotropic_scales} image_size_mm: {image_size_mm} spacing: {spacing} levels: {levels} \n"
223
+ )
224
+
225
+ return roi_size, levels, init_filters, batch_size