Upload folder using huggingface_hub
Browse files- configs/hyper_parameters.yaml +72 -0
- model/model.pt +3 -0
- model/training.log +1818 -0
- scripts/__init__.py +10 -0
- scripts/__pycache__/segmenter.cpython-310.pyc +0 -0
- scripts/__pycache__/utils.cpython-310.pyc +0 -0
- scripts/segmenter.py +2212 -0
- scripts/utils.py +225 -0
configs/hyper_parameters.yaml
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_meta_: {}
|
| 2 |
+
bundle_root: /Users/sakshirathi/Downloads/work_dir/segresnet_0
|
| 3 |
+
ckpt_path: $@bundle_root + '/model'
|
| 4 |
+
mlflow_tracking_uri: $@ckpt_path + '/mlruns/'
|
| 5 |
+
mlflow_experiment_name: Auto3DSeg
|
| 6 |
+
data_file_base_dir: /Users/sakshirathi/Documents/ShamLab
|
| 7 |
+
data_list_file_path: /Users/sakshirathi/Downloads/work_dir/dataset_local.json
|
| 8 |
+
modality: ct
|
| 9 |
+
fold: 0
|
| 10 |
+
input_channels: 1
|
| 11 |
+
output_classes: 2
|
| 12 |
+
class_names: null
|
| 13 |
+
class_index: null
|
| 14 |
+
debug: false
|
| 15 |
+
ckpt_save: true
|
| 16 |
+
cache_rate: null
|
| 17 |
+
roi_size: [384, 384, 60]
|
| 18 |
+
auto_scale_allowed: true
|
| 19 |
+
auto_scale_batch: true
|
| 20 |
+
auto_scale_roi: false
|
| 21 |
+
auto_scale_filters: false
|
| 22 |
+
quick: false
|
| 23 |
+
channels_last: true
|
| 24 |
+
validate_final_original_res: true
|
| 25 |
+
calc_val_loss: false
|
| 26 |
+
amp: true
|
| 27 |
+
log_output_file: null
|
| 28 |
+
cache_class_indices: null
|
| 29 |
+
early_stopping_fraction: 0.001
|
| 30 |
+
determ: false
|
| 31 |
+
orientation_ras: true
|
| 32 |
+
crop_foreground: true
|
| 33 |
+
learning_rate: 0.0002
|
| 34 |
+
batch_size: 1
|
| 35 |
+
num_images_per_batch: 1
|
| 36 |
+
num_epochs: 1250
|
| 37 |
+
num_warmup_epochs: 3
|
| 38 |
+
sigmoid: false
|
| 39 |
+
resample: true
|
| 40 |
+
resample_resolution: [0.48766356436698155, 0.4876635832539761, 2.748479210553717]
|
| 41 |
+
crop_mode: ratio
|
| 42 |
+
normalize_mode: range
|
| 43 |
+
intensity_bounds: [39.63595217750186, 97.59593563988095]
|
| 44 |
+
num_epochs_per_validation: null
|
| 45 |
+
num_epochs_per_saving: 1
|
| 46 |
+
num_workers: 4
|
| 47 |
+
num_steps_per_image: null
|
| 48 |
+
num_crops_per_image: 2
|
| 49 |
+
loss: {_target_: DiceCELoss, include_background: true, squared_pred: true, smooth_nr: 0,
|
| 50 |
+
smooth_dr: 1.0e-05, softmax: $not @sigmoid, sigmoid: $@sigmoid, to_onehot_y: $not
|
| 51 |
+
@sigmoid}
|
| 52 |
+
optimizer: {_target_: torch.optim.AdamW, lr: '@learning_rate', weight_decay: 1.0e-05}
|
| 53 |
+
network:
|
| 54 |
+
_target_: SegResNetDS
|
| 55 |
+
init_filters: 32
|
| 56 |
+
blocks_down: [1, 2, 2, 4, 4]
|
| 57 |
+
norm: INSTANCE_NVFUSER
|
| 58 |
+
in_channels: '@input_channels'
|
| 59 |
+
out_channels: '@output_classes'
|
| 60 |
+
dsdepth: 4
|
| 61 |
+
finetune: {enabled: false, ckpt_name: $@bundle_root + '/model/model.pt'}
|
| 62 |
+
validate: {enabled: false, ckpt_name: $@bundle_root + '/model/model.pt', output_path: $@bundle_root
|
| 63 |
+
+ '/prediction_validation', save_mask: false, invert: true}
|
| 64 |
+
infer: {enabled: false, ckpt_name: $@bundle_root + '/model/model.pt', output_path: $@bundle_root
|
| 65 |
+
+ '/prediction_' + @infer#data_list_key, data_list_key: testing}
|
| 66 |
+
anisotropic_scales: true
|
| 67 |
+
spacing_median: [0.48766356436698155, 0.4876635832539761, 4.770811902267695]
|
| 68 |
+
spacing_lower: [0.42813486948609353, 0.428134856247896, 2.499999978382533]
|
| 69 |
+
spacing_upper: [0.5859375, 0.5859375004856939, 5.012642938162783]
|
| 70 |
+
image_size_mm_median: [249.68374495589455, 249.68375462603575, 168.30083390623668]
|
| 71 |
+
image_size_mm_90: [265.61599121093747, 265.6159922216141, 190.12765338720757]
|
| 72 |
+
image_size: [544, 544, 69]
|
model/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bc8b31e85759b2e6f77b9ec71df0c07988281f8a8ec349b349c2a31c68a3b846
|
| 3 |
+
size 345158862
|
model/training.log
ADDED
|
@@ -0,0 +1,1818 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_meta_: {}
|
| 2 |
+
acc: null
|
| 3 |
+
amp: false
|
| 4 |
+
anisotropic_scales: true
|
| 5 |
+
auto_scale_allowed: true
|
| 6 |
+
auto_scale_batch: true
|
| 7 |
+
auto_scale_filters: false
|
| 8 |
+
auto_scale_roi: false
|
| 9 |
+
batch_size: 1
|
| 10 |
+
bundle_root: /Users/sakshirathi/neurotk/bundles/segresnet
|
| 11 |
+
cache_class_indices: null
|
| 12 |
+
cache_rate: null
|
| 13 |
+
calc_val_loss: false
|
| 14 |
+
channels_last: true
|
| 15 |
+
ckpt_path: /Users/sakshirathi/neurotk/bundles/segresnet/model
|
| 16 |
+
ckpt_save: true
|
| 17 |
+
class_index: null
|
| 18 |
+
class_names:
|
| 19 |
+
- acc_0
|
| 20 |
+
crop_add_background: true
|
| 21 |
+
crop_foreground: true
|
| 22 |
+
crop_mode: ratio
|
| 23 |
+
crop_ratios: null
|
| 24 |
+
cuda: false
|
| 25 |
+
data_file_base_dir: /Users/sakshirathi/neurotk/bundles
|
| 26 |
+
data_list_file_path: /Users/sakshirathi/Downloads/work_dir/dataset_local.json
|
| 27 |
+
debug: false
|
| 28 |
+
determ: false
|
| 29 |
+
early_stopping_fraction: 0.001
|
| 30 |
+
extra_modalities: {}
|
| 31 |
+
finetune:
|
| 32 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 33 |
+
enabled: false
|
| 34 |
+
float32_precision: null
|
| 35 |
+
fold: 0
|
| 36 |
+
fork: true
|
| 37 |
+
global_rank: 0
|
| 38 |
+
image_size:
|
| 39 |
+
- 544
|
| 40 |
+
- 544
|
| 41 |
+
- 69
|
| 42 |
+
image_size_mm_90:
|
| 43 |
+
- 265.61599121093747
|
| 44 |
+
- 265.6159922216141
|
| 45 |
+
- 190.12765338720757
|
| 46 |
+
image_size_mm_median:
|
| 47 |
+
- 249.68374495589455
|
| 48 |
+
- 249.68375462603575
|
| 49 |
+
- 168.30083390623668
|
| 50 |
+
infer:
|
| 51 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 52 |
+
data_list_key: testing
|
| 53 |
+
enabled: true
|
| 54 |
+
output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_testing
|
| 55 |
+
input_channels: 1
|
| 56 |
+
intensity_bounds:
|
| 57 |
+
- 39.63595217750186
|
| 58 |
+
- 97.59593563988095
|
| 59 |
+
learning_rate: 0.0002
|
| 60 |
+
log_output_file: /Users/sakshirathi/neurotk/bundles/segresnet/model/training.log
|
| 61 |
+
loss:
|
| 62 |
+
_target_: DiceCELoss
|
| 63 |
+
include_background: true
|
| 64 |
+
sigmoid: false
|
| 65 |
+
smooth_dr: 1.0e-05
|
| 66 |
+
smooth_nr: 0
|
| 67 |
+
softmax: true
|
| 68 |
+
squared_pred: true
|
| 69 |
+
to_onehot_y: true
|
| 70 |
+
max_samples_per_class: 12500
|
| 71 |
+
mlflow_experiment_name: Auto3DSeg
|
| 72 |
+
mlflow_tracking_uri: /Users/sakshirathi/neurotk/bundles/segresnet/model/mlruns/
|
| 73 |
+
modality: ct
|
| 74 |
+
network:
|
| 75 |
+
_target_: SegResNetDS
|
| 76 |
+
blocks_down:
|
| 77 |
+
- 1
|
| 78 |
+
- 2
|
| 79 |
+
- 2
|
| 80 |
+
- 4
|
| 81 |
+
- 4
|
| 82 |
+
dsdepth: 4
|
| 83 |
+
in_channels: 1
|
| 84 |
+
init_filters: 32
|
| 85 |
+
norm: INSTANCE_NVFUSER
|
| 86 |
+
out_channels: 2
|
| 87 |
+
normalize_mode: range
|
| 88 |
+
notf32: false
|
| 89 |
+
num_crops_per_image: 2
|
| 90 |
+
num_epochs: 1250
|
| 91 |
+
num_epochs_per_saving: 1
|
| 92 |
+
num_epochs_per_validation: null
|
| 93 |
+
num_images_per_batch: 1
|
| 94 |
+
num_steps_per_image: null
|
| 95 |
+
num_warmup_epochs: 3
|
| 96 |
+
num_workers: 4
|
| 97 |
+
optimizer:
|
| 98 |
+
_target_: torch.optim.AdamW
|
| 99 |
+
lr: 0.0002
|
| 100 |
+
weight_decay: 1.0e-05
|
| 101 |
+
orientation_ras: true
|
| 102 |
+
output_classes: 2
|
| 103 |
+
pretrained_ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 104 |
+
quick: false
|
| 105 |
+
rank: 0
|
| 106 |
+
resample: true
|
| 107 |
+
resample_resolution:
|
| 108 |
+
- 0.48766356436698155
|
| 109 |
+
- 0.4876635832539761
|
| 110 |
+
- 2.748479210553717
|
| 111 |
+
roi_size:
|
| 112 |
+
- 384
|
| 113 |
+
- 384
|
| 114 |
+
- 60
|
| 115 |
+
sigmoid: false
|
| 116 |
+
spacing_lower:
|
| 117 |
+
- 0.42813486948609353
|
| 118 |
+
- 0.428134856247896
|
| 119 |
+
- 2.499999978382533
|
| 120 |
+
spacing_median:
|
| 121 |
+
- 0.48766356436698155
|
| 122 |
+
- 0.4876635832539761
|
| 123 |
+
- 4.770811902267695
|
| 124 |
+
spacing_upper:
|
| 125 |
+
- 0.5859375
|
| 126 |
+
- 0.5859375004856939
|
| 127 |
+
- 5.012642938162783
|
| 128 |
+
start_epoch: 0
|
| 129 |
+
stop_on_lowacc: true
|
| 130 |
+
validate:
|
| 131 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 132 |
+
enabled: false
|
| 133 |
+
invert: true
|
| 134 |
+
output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_validation
|
| 135 |
+
save_mask: false
|
| 136 |
+
validate_final_original_res: true
|
| 137 |
+
|
| 138 |
+
auto_adjust_network_settings no distributed global_rank 0
|
| 139 |
+
GPU device memory min: 16
|
| 140 |
+
base_numel 7225344 gpu_factor 1 gpu_factor_init 1
|
| 141 |
+
input roi [224 224 144] image_size [ 512.000 512.000 61.000] numel 7225344
|
| 142 |
+
increasing roi step [ 257.600 257.600 61.000]
|
| 143 |
+
increasing roi result 1 [ 257.600 257.600 61.000]
|
| 144 |
+
increasing roi step [ 296.240 296.240 61.000]
|
| 145 |
+
increasing roi result 1 [ 296.240 296.240 61.000]
|
| 146 |
+
increasing roi step [ 340.676 340.676 61.000]
|
| 147 |
+
increasing roi result 1 [ 340.676 340.676 61.000]
|
| 148 |
+
increasing roi step [ 391.777 391.777 61.000]
|
| 149 |
+
increasing roi result 1 [ 391.777 391.777 61.000]
|
| 150 |
+
roi_size factored [ 384.000 384.000 60.000] factor [ 16.000 16.000 4.000] extra_levels [ 0.000 0.000 2.000]
|
| 151 |
+
kept filters the same base_numel 7225344, gpu_factor 1
|
| 152 |
+
kept batch the same base_numel 7225344, gpu_factor 1, gpu_factor_init 1
|
| 153 |
+
Suggested network parameters:
|
| 154 |
+
Batch size 1 => 1
|
| 155 |
+
ROI size [224, 224, 144] => [384, 384, 60]
|
| 156 |
+
init_filters 32 => 32
|
| 157 |
+
aniso: True image_size_mm: [249.68374495589455, 249.68375462603575, 168.30083390623668] spacing: [0.48766356436698155, 0.4876635832539761, 2.748479210553717] levels: 5
|
| 158 |
+
|
| 159 |
+
Using anisotropic scales {'_target_': 'SegResNetDS', 'init_filters': 32, 'blocks_down': [1, 2, 2, 4, 4], 'norm': 'INSTANCE', 'in_channels': 1, 'out_channels': 2, 'dsdepth': 4, 'resolution': [0.48766356436698155, 0.4876635832539761, 2.748479210553717]}
|
| 160 |
+
SegResNetDS(
|
| 161 |
+
(encoder): SegResEncoder(
|
| 162 |
+
(conv_init): Conv3d(1, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 163 |
+
(layers): ModuleList(
|
| 164 |
+
(0): ModuleDict(
|
| 165 |
+
(blocks): Sequential(
|
| 166 |
+
(0): SegResBlock(
|
| 167 |
+
(norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 168 |
+
(act1): ReLU(inplace=True)
|
| 169 |
+
(conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 170 |
+
(norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 171 |
+
(act2): ReLU(inplace=True)
|
| 172 |
+
(conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 173 |
+
)
|
| 174 |
+
)
|
| 175 |
+
(downsample): Conv3d(32, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
|
| 176 |
+
)
|
| 177 |
+
(1): ModuleDict(
|
| 178 |
+
(blocks): Sequential(
|
| 179 |
+
(0): SegResBlock(
|
| 180 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 181 |
+
(act1): ReLU(inplace=True)
|
| 182 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 183 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 184 |
+
(act2): ReLU(inplace=True)
|
| 185 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 186 |
+
)
|
| 187 |
+
(1): SegResBlock(
|
| 188 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 189 |
+
(act1): ReLU(inplace=True)
|
| 190 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 191 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 192 |
+
(act2): ReLU(inplace=True)
|
| 193 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 194 |
+
)
|
| 195 |
+
)
|
| 196 |
+
(downsample): Conv3d(64, 128, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
|
| 197 |
+
)
|
| 198 |
+
(2): ModuleDict(
|
| 199 |
+
(blocks): Sequential(
|
| 200 |
+
(0): SegResBlock(
|
| 201 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 202 |
+
(act1): ReLU(inplace=True)
|
| 203 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 204 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 205 |
+
(act2): ReLU(inplace=True)
|
| 206 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 207 |
+
)
|
| 208 |
+
(1): SegResBlock(
|
| 209 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 210 |
+
(act1): ReLU(inplace=True)
|
| 211 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 212 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 213 |
+
(act2): ReLU(inplace=True)
|
| 214 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 215 |
+
)
|
| 216 |
+
)
|
| 217 |
+
(downsample): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
|
| 218 |
+
)
|
| 219 |
+
(3): ModuleDict(
|
| 220 |
+
(blocks): Sequential(
|
| 221 |
+
(0): SegResBlock(
|
| 222 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 223 |
+
(act1): ReLU(inplace=True)
|
| 224 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 225 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 226 |
+
(act2): ReLU(inplace=True)
|
| 227 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 228 |
+
)
|
| 229 |
+
(1): SegResBlock(
|
| 230 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 231 |
+
(act1): ReLU(inplace=True)
|
| 232 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 233 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 234 |
+
(act2): ReLU(inplace=True)
|
| 235 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 236 |
+
)
|
| 237 |
+
(2): SegResBlock(
|
| 238 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 239 |
+
(act1): ReLU(inplace=True)
|
| 240 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 241 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 242 |
+
(act2): ReLU(inplace=True)
|
| 243 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 244 |
+
)
|
| 245 |
+
(3): SegResBlock(
|
| 246 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 247 |
+
(act1): ReLU(inplace=True)
|
| 248 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 249 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 250 |
+
(act2): ReLU(inplace=True)
|
| 251 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 252 |
+
)
|
| 253 |
+
)
|
| 254 |
+
(downsample): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
|
| 255 |
+
)
|
| 256 |
+
(4): ModuleDict(
|
| 257 |
+
(blocks): Sequential(
|
| 258 |
+
(0): SegResBlock(
|
| 259 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 260 |
+
(act1): ReLU(inplace=True)
|
| 261 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 262 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 263 |
+
(act2): ReLU(inplace=True)
|
| 264 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 265 |
+
)
|
| 266 |
+
(1): SegResBlock(
|
| 267 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 268 |
+
(act1): ReLU(inplace=True)
|
| 269 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 270 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 271 |
+
(act2): ReLU(inplace=True)
|
| 272 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 273 |
+
)
|
| 274 |
+
(2): SegResBlock(
|
| 275 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 276 |
+
(act1): ReLU(inplace=True)
|
| 277 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 278 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 279 |
+
(act2): ReLU(inplace=True)
|
| 280 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 281 |
+
)
|
| 282 |
+
(3): SegResBlock(
|
| 283 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 284 |
+
(act1): ReLU(inplace=True)
|
| 285 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 286 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 287 |
+
(act2): ReLU(inplace=True)
|
| 288 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 289 |
+
)
|
| 290 |
+
)
|
| 291 |
+
(downsample): Identity()
|
| 292 |
+
)
|
| 293 |
+
)
|
| 294 |
+
)
|
| 295 |
+
(up_layers): ModuleList(
|
| 296 |
+
(0): ModuleDict(
|
| 297 |
+
(upsample): UpSample(
|
| 298 |
+
(deconv): ConvTranspose3d(512, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
|
| 299 |
+
)
|
| 300 |
+
(blocks): Sequential(
|
| 301 |
+
(0): SegResBlock(
|
| 302 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 303 |
+
(act1): ReLU(inplace=True)
|
| 304 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 305 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 306 |
+
(act2): ReLU(inplace=True)
|
| 307 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 308 |
+
)
|
| 309 |
+
)
|
| 310 |
+
(head): Conv3d(256, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 311 |
+
)
|
| 312 |
+
(1): ModuleDict(
|
| 313 |
+
(upsample): UpSample(
|
| 314 |
+
(deconv): ConvTranspose3d(256, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
|
| 315 |
+
)
|
| 316 |
+
(blocks): Sequential(
|
| 317 |
+
(0): SegResBlock(
|
| 318 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 319 |
+
(act1): ReLU(inplace=True)
|
| 320 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 321 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 322 |
+
(act2): ReLU(inplace=True)
|
| 323 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 324 |
+
)
|
| 325 |
+
)
|
| 326 |
+
(head): Conv3d(128, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 327 |
+
)
|
| 328 |
+
(2): ModuleDict(
|
| 329 |
+
(upsample): UpSample(
|
| 330 |
+
(deconv): ConvTranspose3d(128, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
|
| 331 |
+
)
|
| 332 |
+
(blocks): Sequential(
|
| 333 |
+
(0): SegResBlock(
|
| 334 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 335 |
+
(act1): ReLU(inplace=True)
|
| 336 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 337 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 338 |
+
(act2): ReLU(inplace=True)
|
| 339 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 340 |
+
)
|
| 341 |
+
)
|
| 342 |
+
(head): Conv3d(64, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 343 |
+
)
|
| 344 |
+
(3): ModuleDict(
|
| 345 |
+
(upsample): UpSample(
|
| 346 |
+
(deconv): ConvTranspose3d(64, 32, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
|
| 347 |
+
)
|
| 348 |
+
(blocks): Sequential(
|
| 349 |
+
(0): SegResBlock(
|
| 350 |
+
(norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 351 |
+
(act1): ReLU(inplace=True)
|
| 352 |
+
(conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 353 |
+
(norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 354 |
+
(act2): ReLU(inplace=True)
|
| 355 |
+
(conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 356 |
+
)
|
| 357 |
+
)
|
| 358 |
+
(head): Conv3d(32, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 359 |
+
)
|
| 360 |
+
)
|
| 361 |
+
)
|
| 362 |
+
=> loaded checkpoint /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt (epoch 1122) (best_metric 0.843817412853241) setting start_epoch 0
|
| 363 |
+
Total parameters count: 86278888 distributed: False
|
| 364 |
+
Inference complete, time 234.85s shape torch.Size([2, 512, 512, 40]) {'image': 'sample_data/images/TBI_INVAC184NYT.nii'}
|
| 365 |
+
_meta_: {}
|
| 366 |
+
acc: null
|
| 367 |
+
amp: false
|
| 368 |
+
anisotropic_scales: true
|
| 369 |
+
auto_scale_allowed: true
|
| 370 |
+
auto_scale_batch: true
|
| 371 |
+
auto_scale_filters: false
|
| 372 |
+
auto_scale_roi: false
|
| 373 |
+
batch_size: 1
|
| 374 |
+
bundle_root: /Users/sakshirathi/neurotk/bundles/segresnet
|
| 375 |
+
cache_class_indices: null
|
| 376 |
+
cache_rate: null
|
| 377 |
+
calc_val_loss: false
|
| 378 |
+
channels_last: true
|
| 379 |
+
ckpt_path: /Users/sakshirathi/neurotk/bundles/segresnet/model
|
| 380 |
+
ckpt_save: true
|
| 381 |
+
class_index: null
|
| 382 |
+
class_names:
|
| 383 |
+
- acc_0
|
| 384 |
+
crop_add_background: true
|
| 385 |
+
crop_foreground: true
|
| 386 |
+
crop_mode: ratio
|
| 387 |
+
crop_ratios: null
|
| 388 |
+
cuda: false
|
| 389 |
+
data_file_base_dir: /Users/sakshirathi/neurotk/bundles
|
| 390 |
+
data_list_file_path: /Users/sakshirathi/Downloads/work_dir/dataset_local.json
|
| 391 |
+
debug: false
|
| 392 |
+
determ: false
|
| 393 |
+
early_stopping_fraction: 0.001
|
| 394 |
+
extra_modalities: {}
|
| 395 |
+
finetune:
|
| 396 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 397 |
+
enabled: false
|
| 398 |
+
float32_precision: null
|
| 399 |
+
fold: 0
|
| 400 |
+
fork: true
|
| 401 |
+
global_rank: 0
|
| 402 |
+
image_size:
|
| 403 |
+
- 544
|
| 404 |
+
- 544
|
| 405 |
+
- 69
|
| 406 |
+
image_size_mm_90:
|
| 407 |
+
- 265.61599121093747
|
| 408 |
+
- 265.6159922216141
|
| 409 |
+
- 190.12765338720757
|
| 410 |
+
image_size_mm_median:
|
| 411 |
+
- 249.68374495589455
|
| 412 |
+
- 249.68375462603575
|
| 413 |
+
- 168.30083390623668
|
| 414 |
+
infer:
|
| 415 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 416 |
+
data_list_key: testing
|
| 417 |
+
enabled: true
|
| 418 |
+
output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_testing
|
| 419 |
+
input_channels: 1
|
| 420 |
+
intensity_bounds:
|
| 421 |
+
- 39.63595217750186
|
| 422 |
+
- 97.59593563988095
|
| 423 |
+
learning_rate: 0.0002
|
| 424 |
+
log_output_file: /Users/sakshirathi/neurotk/bundles/segresnet/model/training.log
|
| 425 |
+
loss:
|
| 426 |
+
_target_: DiceCELoss
|
| 427 |
+
include_background: true
|
| 428 |
+
sigmoid: false
|
| 429 |
+
smooth_dr: 1.0e-05
|
| 430 |
+
smooth_nr: 0
|
| 431 |
+
softmax: true
|
| 432 |
+
squared_pred: true
|
| 433 |
+
to_onehot_y: true
|
| 434 |
+
max_samples_per_class: 12500
|
| 435 |
+
mlflow_experiment_name: Auto3DSeg
|
| 436 |
+
mlflow_tracking_uri: /Users/sakshirathi/neurotk/bundles/segresnet/model/mlruns/
|
| 437 |
+
modality: ct
|
| 438 |
+
network:
|
| 439 |
+
_target_: SegResNetDS
|
| 440 |
+
blocks_down:
|
| 441 |
+
- 1
|
| 442 |
+
- 2
|
| 443 |
+
- 2
|
| 444 |
+
- 4
|
| 445 |
+
- 4
|
| 446 |
+
dsdepth: 4
|
| 447 |
+
in_channels: 1
|
| 448 |
+
init_filters: 32
|
| 449 |
+
norm: INSTANCE_NVFUSER
|
| 450 |
+
out_channels: 2
|
| 451 |
+
normalize_mode: range
|
| 452 |
+
notf32: false
|
| 453 |
+
num_crops_per_image: 2
|
| 454 |
+
num_epochs: 1250
|
| 455 |
+
num_epochs_per_saving: 1
|
| 456 |
+
num_epochs_per_validation: null
|
| 457 |
+
num_images_per_batch: 1
|
| 458 |
+
num_steps_per_image: null
|
| 459 |
+
num_warmup_epochs: 3
|
| 460 |
+
num_workers: 4
|
| 461 |
+
optimizer:
|
| 462 |
+
_target_: torch.optim.AdamW
|
| 463 |
+
lr: 0.0002
|
| 464 |
+
weight_decay: 1.0e-05
|
| 465 |
+
orientation_ras: true
|
| 466 |
+
output_classes: 2
|
| 467 |
+
pretrained_ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 468 |
+
quick: false
|
| 469 |
+
rank: 0
|
| 470 |
+
resample: true
|
| 471 |
+
resample_resolution:
|
| 472 |
+
- 0.48766356436698155
|
| 473 |
+
- 0.4876635832539761
|
| 474 |
+
- 2.748479210553717
|
| 475 |
+
roi_size:
|
| 476 |
+
- 384
|
| 477 |
+
- 384
|
| 478 |
+
- 60
|
| 479 |
+
sigmoid: false
|
| 480 |
+
spacing_lower:
|
| 481 |
+
- 0.42813486948609353
|
| 482 |
+
- 0.428134856247896
|
| 483 |
+
- 2.499999978382533
|
| 484 |
+
spacing_median:
|
| 485 |
+
- 0.48766356436698155
|
| 486 |
+
- 0.4876635832539761
|
| 487 |
+
- 4.770811902267695
|
| 488 |
+
spacing_upper:
|
| 489 |
+
- 0.5859375
|
| 490 |
+
- 0.5859375004856939
|
| 491 |
+
- 5.012642938162783
|
| 492 |
+
start_epoch: 0
|
| 493 |
+
stop_on_lowacc: true
|
| 494 |
+
validate:
|
| 495 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 496 |
+
enabled: false
|
| 497 |
+
invert: true
|
| 498 |
+
output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_validation
|
| 499 |
+
save_mask: false
|
| 500 |
+
validate_final_original_res: true
|
| 501 |
+
|
| 502 |
+
auto_adjust_network_settings no distributed global_rank 0
|
| 503 |
+
GPU device memory min: 16
|
| 504 |
+
base_numel 7225344 gpu_factor 1 gpu_factor_init 1
|
| 505 |
+
input roi [224 224 144] image_size [ 512.000 512.000 61.000] numel 7225344
|
| 506 |
+
increasing roi step [ 257.600 257.600 61.000]
|
| 507 |
+
increasing roi result 1 [ 257.600 257.600 61.000]
|
| 508 |
+
increasing roi step [ 296.240 296.240 61.000]
|
| 509 |
+
increasing roi result 1 [ 296.240 296.240 61.000]
|
| 510 |
+
increasing roi step [ 340.676 340.676 61.000]
|
| 511 |
+
increasing roi result 1 [ 340.676 340.676 61.000]
|
| 512 |
+
increasing roi step [ 391.777 391.777 61.000]
|
| 513 |
+
increasing roi result 1 [ 391.777 391.777 61.000]
|
| 514 |
+
roi_size factored [ 384.000 384.000 60.000] factor [ 16.000 16.000 4.000] extra_levels [ 0.000 0.000 2.000]
|
| 515 |
+
kept filters the same base_numel 7225344, gpu_factor 1
|
| 516 |
+
kept batch the same base_numel 7225344, gpu_factor 1, gpu_factor_init 1
|
| 517 |
+
Suggested network parameters:
|
| 518 |
+
Batch size 1 => 1
|
| 519 |
+
ROI size [224, 224, 144] => [384, 384, 60]
|
| 520 |
+
init_filters 32 => 32
|
| 521 |
+
aniso: True image_size_mm: [249.68374495589455, 249.68375462603575, 168.30083390623668] spacing: [0.48766356436698155, 0.4876635832539761, 2.748479210553717] levels: 5
|
| 522 |
+
|
| 523 |
+
Using anisotropic scales {'_target_': 'SegResNetDS', 'init_filters': 32, 'blocks_down': [1, 2, 2, 4, 4], 'norm': 'INSTANCE', 'in_channels': 1, 'out_channels': 2, 'dsdepth': 4, 'resolution': [0.48766356436698155, 0.4876635832539761, 2.748479210553717]}
|
| 524 |
+
SegResNetDS(
|
| 525 |
+
(encoder): SegResEncoder(
|
| 526 |
+
(conv_init): Conv3d(1, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 527 |
+
(layers): ModuleList(
|
| 528 |
+
(0): ModuleDict(
|
| 529 |
+
(blocks): Sequential(
|
| 530 |
+
(0): SegResBlock(
|
| 531 |
+
(norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 532 |
+
(act1): ReLU(inplace=True)
|
| 533 |
+
(conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 534 |
+
(norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 535 |
+
(act2): ReLU(inplace=True)
|
| 536 |
+
(conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 537 |
+
)
|
| 538 |
+
)
|
| 539 |
+
(downsample): Conv3d(32, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
|
| 540 |
+
)
|
| 541 |
+
(1): ModuleDict(
|
| 542 |
+
(blocks): Sequential(
|
| 543 |
+
(0): SegResBlock(
|
| 544 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 545 |
+
(act1): ReLU(inplace=True)
|
| 546 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 547 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 548 |
+
(act2): ReLU(inplace=True)
|
| 549 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 550 |
+
)
|
| 551 |
+
(1): SegResBlock(
|
| 552 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 553 |
+
(act1): ReLU(inplace=True)
|
| 554 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 555 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 556 |
+
(act2): ReLU(inplace=True)
|
| 557 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 558 |
+
)
|
| 559 |
+
)
|
| 560 |
+
(downsample): Conv3d(64, 128, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
|
| 561 |
+
)
|
| 562 |
+
(2): ModuleDict(
|
| 563 |
+
(blocks): Sequential(
|
| 564 |
+
(0): SegResBlock(
|
| 565 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 566 |
+
(act1): ReLU(inplace=True)
|
| 567 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 568 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 569 |
+
(act2): ReLU(inplace=True)
|
| 570 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 571 |
+
)
|
| 572 |
+
(1): SegResBlock(
|
| 573 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 574 |
+
(act1): ReLU(inplace=True)
|
| 575 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 576 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 577 |
+
(act2): ReLU(inplace=True)
|
| 578 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 579 |
+
)
|
| 580 |
+
)
|
| 581 |
+
(downsample): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
|
| 582 |
+
)
|
| 583 |
+
(3): ModuleDict(
|
| 584 |
+
(blocks): Sequential(
|
| 585 |
+
(0): SegResBlock(
|
| 586 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 587 |
+
(act1): ReLU(inplace=True)
|
| 588 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 589 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 590 |
+
(act2): ReLU(inplace=True)
|
| 591 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 592 |
+
)
|
| 593 |
+
(1): SegResBlock(
|
| 594 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 595 |
+
(act1): ReLU(inplace=True)
|
| 596 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 597 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 598 |
+
(act2): ReLU(inplace=True)
|
| 599 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 600 |
+
)
|
| 601 |
+
(2): SegResBlock(
|
| 602 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 603 |
+
(act1): ReLU(inplace=True)
|
| 604 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 605 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 606 |
+
(act2): ReLU(inplace=True)
|
| 607 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 608 |
+
)
|
| 609 |
+
(3): SegResBlock(
|
| 610 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 611 |
+
(act1): ReLU(inplace=True)
|
| 612 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 613 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 614 |
+
(act2): ReLU(inplace=True)
|
| 615 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 616 |
+
)
|
| 617 |
+
)
|
| 618 |
+
(downsample): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
|
| 619 |
+
)
|
| 620 |
+
(4): ModuleDict(
|
| 621 |
+
(blocks): Sequential(
|
| 622 |
+
(0): SegResBlock(
|
| 623 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 624 |
+
(act1): ReLU(inplace=True)
|
| 625 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 626 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 627 |
+
(act2): ReLU(inplace=True)
|
| 628 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 629 |
+
)
|
| 630 |
+
(1): SegResBlock(
|
| 631 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 632 |
+
(act1): ReLU(inplace=True)
|
| 633 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 634 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 635 |
+
(act2): ReLU(inplace=True)
|
| 636 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 637 |
+
)
|
| 638 |
+
(2): SegResBlock(
|
| 639 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 640 |
+
(act1): ReLU(inplace=True)
|
| 641 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 642 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 643 |
+
(act2): ReLU(inplace=True)
|
| 644 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 645 |
+
)
|
| 646 |
+
(3): SegResBlock(
|
| 647 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 648 |
+
(act1): ReLU(inplace=True)
|
| 649 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 650 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 651 |
+
(act2): ReLU(inplace=True)
|
| 652 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 653 |
+
)
|
| 654 |
+
)
|
| 655 |
+
(downsample): Identity()
|
| 656 |
+
)
|
| 657 |
+
)
|
| 658 |
+
)
|
| 659 |
+
(up_layers): ModuleList(
|
| 660 |
+
(0): ModuleDict(
|
| 661 |
+
(upsample): UpSample(
|
| 662 |
+
(deconv): ConvTranspose3d(512, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
|
| 663 |
+
)
|
| 664 |
+
(blocks): Sequential(
|
| 665 |
+
(0): SegResBlock(
|
| 666 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 667 |
+
(act1): ReLU(inplace=True)
|
| 668 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 669 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 670 |
+
(act2): ReLU(inplace=True)
|
| 671 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 672 |
+
)
|
| 673 |
+
)
|
| 674 |
+
(head): Conv3d(256, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 675 |
+
)
|
| 676 |
+
(1): ModuleDict(
|
| 677 |
+
(upsample): UpSample(
|
| 678 |
+
(deconv): ConvTranspose3d(256, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
|
| 679 |
+
)
|
| 680 |
+
(blocks): Sequential(
|
| 681 |
+
(0): SegResBlock(
|
| 682 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 683 |
+
(act1): ReLU(inplace=True)
|
| 684 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 685 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 686 |
+
(act2): ReLU(inplace=True)
|
| 687 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 688 |
+
)
|
| 689 |
+
)
|
| 690 |
+
(head): Conv3d(128, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 691 |
+
)
|
| 692 |
+
(2): ModuleDict(
|
| 693 |
+
(upsample): UpSample(
|
| 694 |
+
(deconv): ConvTranspose3d(128, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
|
| 695 |
+
)
|
| 696 |
+
(blocks): Sequential(
|
| 697 |
+
(0): SegResBlock(
|
| 698 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 699 |
+
(act1): ReLU(inplace=True)
|
| 700 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 701 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 702 |
+
(act2): ReLU(inplace=True)
|
| 703 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 704 |
+
)
|
| 705 |
+
)
|
| 706 |
+
(head): Conv3d(64, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 707 |
+
)
|
| 708 |
+
(3): ModuleDict(
|
| 709 |
+
(upsample): UpSample(
|
| 710 |
+
(deconv): ConvTranspose3d(64, 32, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
|
| 711 |
+
)
|
| 712 |
+
(blocks): Sequential(
|
| 713 |
+
(0): SegResBlock(
|
| 714 |
+
(norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 715 |
+
(act1): ReLU(inplace=True)
|
| 716 |
+
(conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 717 |
+
(norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 718 |
+
(act2): ReLU(inplace=True)
|
| 719 |
+
(conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 720 |
+
)
|
| 721 |
+
)
|
| 722 |
+
(head): Conv3d(32, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 723 |
+
)
|
| 724 |
+
)
|
| 725 |
+
)
|
| 726 |
+
=> loaded checkpoint /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt (epoch 1122) (best_metric 0.843817412853241) setting start_epoch 0
|
| 727 |
+
Total parameters count: 86278888 distributed: False
|
| 728 |
+
Inference complete, time 233.93s shape torch.Size([2, 512, 512, 40]) {'image': 'sample_data/images/TBI_INVAC184NYT.nii'}
|
| 729 |
+
_meta_: {}
|
| 730 |
+
acc: null
|
| 731 |
+
amp: false
|
| 732 |
+
anisotropic_scales: true
|
| 733 |
+
auto_scale_allowed: true
|
| 734 |
+
auto_scale_batch: true
|
| 735 |
+
auto_scale_filters: false
|
| 736 |
+
auto_scale_roi: false
|
| 737 |
+
batch_size: 1
|
| 738 |
+
bundle_root: /Users/sakshirathi/neurotk/bundles/segresnet
|
| 739 |
+
cache_class_indices: null
|
| 740 |
+
cache_rate: null
|
| 741 |
+
calc_val_loss: false
|
| 742 |
+
channels_last: true
|
| 743 |
+
ckpt_path: /Users/sakshirathi/neurotk/bundles/segresnet/model
|
| 744 |
+
ckpt_save: true
|
| 745 |
+
class_index: null
|
| 746 |
+
class_names:
|
| 747 |
+
- acc_0
|
| 748 |
+
crop_add_background: true
|
| 749 |
+
crop_foreground: true
|
| 750 |
+
crop_mode: ratio
|
| 751 |
+
crop_ratios: null
|
| 752 |
+
cuda: false
|
| 753 |
+
data_file_base_dir: /Users/sakshirathi/neurotk/bundles
|
| 754 |
+
data_list_file_path: /Users/sakshirathi/Downloads/work_dir/dataset_local.json
|
| 755 |
+
debug: false
|
| 756 |
+
determ: false
|
| 757 |
+
early_stopping_fraction: 0.001
|
| 758 |
+
extra_modalities: {}
|
| 759 |
+
finetune:
|
| 760 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 761 |
+
enabled: false
|
| 762 |
+
float32_precision: null
|
| 763 |
+
fold: 0
|
| 764 |
+
fork: true
|
| 765 |
+
global_rank: 0
|
| 766 |
+
image_size:
|
| 767 |
+
- 544
|
| 768 |
+
- 544
|
| 769 |
+
- 69
|
| 770 |
+
image_size_mm_90:
|
| 771 |
+
- 265.61599121093747
|
| 772 |
+
- 265.6159922216141
|
| 773 |
+
- 190.12765338720757
|
| 774 |
+
image_size_mm_median:
|
| 775 |
+
- 249.68374495589455
|
| 776 |
+
- 249.68375462603575
|
| 777 |
+
- 168.30083390623668
|
| 778 |
+
infer:
|
| 779 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 780 |
+
data_list_key: testing
|
| 781 |
+
enabled: true
|
| 782 |
+
output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_testing
|
| 783 |
+
input_channels: 1
|
| 784 |
+
intensity_bounds:
|
| 785 |
+
- 39.63595217750186
|
| 786 |
+
- 97.59593563988095
|
| 787 |
+
learning_rate: 0.0002
|
| 788 |
+
log_output_file: /Users/sakshirathi/neurotk/bundles/segresnet/model/training.log
|
| 789 |
+
loss:
|
| 790 |
+
_target_: DiceCELoss
|
| 791 |
+
include_background: true
|
| 792 |
+
sigmoid: false
|
| 793 |
+
smooth_dr: 1.0e-05
|
| 794 |
+
smooth_nr: 0
|
| 795 |
+
softmax: true
|
| 796 |
+
squared_pred: true
|
| 797 |
+
to_onehot_y: true
|
| 798 |
+
max_samples_per_class: 12500
|
| 799 |
+
mlflow_experiment_name: Auto3DSeg
|
| 800 |
+
mlflow_tracking_uri: /Users/sakshirathi/neurotk/bundles/segresnet/model/mlruns/
|
| 801 |
+
modality: ct
|
| 802 |
+
network:
|
| 803 |
+
_target_: SegResNetDS
|
| 804 |
+
blocks_down:
|
| 805 |
+
- 1
|
| 806 |
+
- 2
|
| 807 |
+
- 2
|
| 808 |
+
- 4
|
| 809 |
+
- 4
|
| 810 |
+
dsdepth: 4
|
| 811 |
+
in_channels: 1
|
| 812 |
+
init_filters: 32
|
| 813 |
+
norm: INSTANCE_NVFUSER
|
| 814 |
+
out_channels: 2
|
| 815 |
+
normalize_mode: range
|
| 816 |
+
notf32: false
|
| 817 |
+
num_crops_per_image: 2
|
| 818 |
+
num_epochs: 1250
|
| 819 |
+
num_epochs_per_saving: 1
|
| 820 |
+
num_epochs_per_validation: null
|
| 821 |
+
num_images_per_batch: 1
|
| 822 |
+
num_steps_per_image: null
|
| 823 |
+
num_warmup_epochs: 3
|
| 824 |
+
num_workers: 4
|
| 825 |
+
optimizer:
|
| 826 |
+
_target_: torch.optim.AdamW
|
| 827 |
+
lr: 0.0002
|
| 828 |
+
weight_decay: 1.0e-05
|
| 829 |
+
orientation_ras: true
|
| 830 |
+
output_classes: 2
|
| 831 |
+
pretrained_ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 832 |
+
quick: false
|
| 833 |
+
rank: 0
|
| 834 |
+
resample: true
|
| 835 |
+
resample_resolution:
|
| 836 |
+
- 0.48766356436698155
|
| 837 |
+
- 0.4876635832539761
|
| 838 |
+
- 2.748479210553717
|
| 839 |
+
roi_size:
|
| 840 |
+
- 384
|
| 841 |
+
- 384
|
| 842 |
+
- 60
|
| 843 |
+
sigmoid: false
|
| 844 |
+
spacing_lower:
|
| 845 |
+
- 0.42813486948609353
|
| 846 |
+
- 0.428134856247896
|
| 847 |
+
- 2.499999978382533
|
| 848 |
+
spacing_median:
|
| 849 |
+
- 0.48766356436698155
|
| 850 |
+
- 0.4876635832539761
|
| 851 |
+
- 4.770811902267695
|
| 852 |
+
spacing_upper:
|
| 853 |
+
- 0.5859375
|
| 854 |
+
- 0.5859375004856939
|
| 855 |
+
- 5.012642938162783
|
| 856 |
+
start_epoch: 0
|
| 857 |
+
stop_on_lowacc: true
|
| 858 |
+
validate:
|
| 859 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 860 |
+
enabled: false
|
| 861 |
+
invert: true
|
| 862 |
+
output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_validation
|
| 863 |
+
save_mask: false
|
| 864 |
+
validate_final_original_res: true
|
| 865 |
+
|
| 866 |
+
auto_adjust_network_settings no distributed global_rank 0
|
| 867 |
+
GPU device memory min: 16
|
| 868 |
+
base_numel 7225344 gpu_factor 1 gpu_factor_init 1
|
| 869 |
+
input roi [224 224 144] image_size [ 512.000 512.000 61.000] numel 7225344
|
| 870 |
+
increasing roi step [ 257.600 257.600 61.000]
|
| 871 |
+
increasing roi result 1 [ 257.600 257.600 61.000]
|
| 872 |
+
increasing roi step [ 296.240 296.240 61.000]
|
| 873 |
+
increasing roi result 1 [ 296.240 296.240 61.000]
|
| 874 |
+
increasing roi step [ 340.676 340.676 61.000]
|
| 875 |
+
increasing roi result 1 [ 340.676 340.676 61.000]
|
| 876 |
+
increasing roi step [ 391.777 391.777 61.000]
|
| 877 |
+
increasing roi result 1 [ 391.777 391.777 61.000]
|
| 878 |
+
roi_size factored [ 384.000 384.000 60.000] factor [ 16.000 16.000 4.000] extra_levels [ 0.000 0.000 2.000]
|
| 879 |
+
kept filters the same base_numel 7225344, gpu_factor 1
|
| 880 |
+
kept batch the same base_numel 7225344, gpu_factor 1, gpu_factor_init 1
|
| 881 |
+
Suggested network parameters:
|
| 882 |
+
Batch size 1 => 1
|
| 883 |
+
ROI size [224, 224, 144] => [384, 384, 60]
|
| 884 |
+
init_filters 32 => 32
|
| 885 |
+
aniso: True image_size_mm: [249.68374495589455, 249.68375462603575, 168.30083390623668] spacing: [0.48766356436698155, 0.4876635832539761, 2.748479210553717] levels: 5
|
| 886 |
+
|
| 887 |
+
Using anisotropic scales {'_target_': 'SegResNetDS', 'init_filters': 32, 'blocks_down': [1, 2, 2, 4, 4], 'norm': 'INSTANCE', 'in_channels': 1, 'out_channels': 2, 'dsdepth': 4, 'resolution': [0.48766356436698155, 0.4876635832539761, 2.748479210553717]}
|
| 888 |
+
SegResNetDS(
|
| 889 |
+
(encoder): SegResEncoder(
|
| 890 |
+
(conv_init): Conv3d(1, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 891 |
+
(layers): ModuleList(
|
| 892 |
+
(0): ModuleDict(
|
| 893 |
+
(blocks): Sequential(
|
| 894 |
+
(0): SegResBlock(
|
| 895 |
+
(norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 896 |
+
(act1): ReLU(inplace=True)
|
| 897 |
+
(conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 898 |
+
(norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 899 |
+
(act2): ReLU(inplace=True)
|
| 900 |
+
(conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 901 |
+
)
|
| 902 |
+
)
|
| 903 |
+
(downsample): Conv3d(32, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
|
| 904 |
+
)
|
| 905 |
+
(1): ModuleDict(
|
| 906 |
+
(blocks): Sequential(
|
| 907 |
+
(0): SegResBlock(
|
| 908 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 909 |
+
(act1): ReLU(inplace=True)
|
| 910 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 911 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 912 |
+
(act2): ReLU(inplace=True)
|
| 913 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 914 |
+
)
|
| 915 |
+
(1): SegResBlock(
|
| 916 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 917 |
+
(act1): ReLU(inplace=True)
|
| 918 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 919 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 920 |
+
(act2): ReLU(inplace=True)
|
| 921 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 922 |
+
)
|
| 923 |
+
)
|
| 924 |
+
(downsample): Conv3d(64, 128, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
|
| 925 |
+
)
|
| 926 |
+
(2): ModuleDict(
|
| 927 |
+
(blocks): Sequential(
|
| 928 |
+
(0): SegResBlock(
|
| 929 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 930 |
+
(act1): ReLU(inplace=True)
|
| 931 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 932 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 933 |
+
(act2): ReLU(inplace=True)
|
| 934 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 935 |
+
)
|
| 936 |
+
(1): SegResBlock(
|
| 937 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 938 |
+
(act1): ReLU(inplace=True)
|
| 939 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 940 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 941 |
+
(act2): ReLU(inplace=True)
|
| 942 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 943 |
+
)
|
| 944 |
+
)
|
| 945 |
+
(downsample): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
|
| 946 |
+
)
|
| 947 |
+
(3): ModuleDict(
|
| 948 |
+
(blocks): Sequential(
|
| 949 |
+
(0): SegResBlock(
|
| 950 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 951 |
+
(act1): ReLU(inplace=True)
|
| 952 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 953 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 954 |
+
(act2): ReLU(inplace=True)
|
| 955 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 956 |
+
)
|
| 957 |
+
(1): SegResBlock(
|
| 958 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 959 |
+
(act1): ReLU(inplace=True)
|
| 960 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 961 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 962 |
+
(act2): ReLU(inplace=True)
|
| 963 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 964 |
+
)
|
| 965 |
+
(2): SegResBlock(
|
| 966 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 967 |
+
(act1): ReLU(inplace=True)
|
| 968 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 969 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 970 |
+
(act2): ReLU(inplace=True)
|
| 971 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 972 |
+
)
|
| 973 |
+
(3): SegResBlock(
|
| 974 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 975 |
+
(act1): ReLU(inplace=True)
|
| 976 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 977 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 978 |
+
(act2): ReLU(inplace=True)
|
| 979 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 980 |
+
)
|
| 981 |
+
)
|
| 982 |
+
(downsample): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
|
| 983 |
+
)
|
| 984 |
+
(4): ModuleDict(
|
| 985 |
+
(blocks): Sequential(
|
| 986 |
+
(0): SegResBlock(
|
| 987 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 988 |
+
(act1): ReLU(inplace=True)
|
| 989 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 990 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 991 |
+
(act2): ReLU(inplace=True)
|
| 992 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 993 |
+
)
|
| 994 |
+
(1): SegResBlock(
|
| 995 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 996 |
+
(act1): ReLU(inplace=True)
|
| 997 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 998 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 999 |
+
(act2): ReLU(inplace=True)
|
| 1000 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1001 |
+
)
|
| 1002 |
+
(2): SegResBlock(
|
| 1003 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1004 |
+
(act1): ReLU(inplace=True)
|
| 1005 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1006 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1007 |
+
(act2): ReLU(inplace=True)
|
| 1008 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1009 |
+
)
|
| 1010 |
+
(3): SegResBlock(
|
| 1011 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1012 |
+
(act1): ReLU(inplace=True)
|
| 1013 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1014 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1015 |
+
(act2): ReLU(inplace=True)
|
| 1016 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1017 |
+
)
|
| 1018 |
+
)
|
| 1019 |
+
(downsample): Identity()
|
| 1020 |
+
)
|
| 1021 |
+
)
|
| 1022 |
+
)
|
| 1023 |
+
(up_layers): ModuleList(
|
| 1024 |
+
(0): ModuleDict(
|
| 1025 |
+
(upsample): UpSample(
|
| 1026 |
+
(deconv): ConvTranspose3d(512, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
|
| 1027 |
+
)
|
| 1028 |
+
(blocks): Sequential(
|
| 1029 |
+
(0): SegResBlock(
|
| 1030 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1031 |
+
(act1): ReLU(inplace=True)
|
| 1032 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1033 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1034 |
+
(act2): ReLU(inplace=True)
|
| 1035 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1036 |
+
)
|
| 1037 |
+
)
|
| 1038 |
+
(head): Conv3d(256, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 1039 |
+
)
|
| 1040 |
+
(1): ModuleDict(
|
| 1041 |
+
(upsample): UpSample(
|
| 1042 |
+
(deconv): ConvTranspose3d(256, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
|
| 1043 |
+
)
|
| 1044 |
+
(blocks): Sequential(
|
| 1045 |
+
(0): SegResBlock(
|
| 1046 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1047 |
+
(act1): ReLU(inplace=True)
|
| 1048 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1049 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1050 |
+
(act2): ReLU(inplace=True)
|
| 1051 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1052 |
+
)
|
| 1053 |
+
)
|
| 1054 |
+
(head): Conv3d(128, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 1055 |
+
)
|
| 1056 |
+
(2): ModuleDict(
|
| 1057 |
+
(upsample): UpSample(
|
| 1058 |
+
(deconv): ConvTranspose3d(128, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
|
| 1059 |
+
)
|
| 1060 |
+
(blocks): Sequential(
|
| 1061 |
+
(0): SegResBlock(
|
| 1062 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1063 |
+
(act1): ReLU(inplace=True)
|
| 1064 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1065 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1066 |
+
(act2): ReLU(inplace=True)
|
| 1067 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1068 |
+
)
|
| 1069 |
+
)
|
| 1070 |
+
(head): Conv3d(64, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 1071 |
+
)
|
| 1072 |
+
(3): ModuleDict(
|
| 1073 |
+
(upsample): UpSample(
|
| 1074 |
+
(deconv): ConvTranspose3d(64, 32, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
|
| 1075 |
+
)
|
| 1076 |
+
(blocks): Sequential(
|
| 1077 |
+
(0): SegResBlock(
|
| 1078 |
+
(norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1079 |
+
(act1): ReLU(inplace=True)
|
| 1080 |
+
(conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1081 |
+
(norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1082 |
+
(act2): ReLU(inplace=True)
|
| 1083 |
+
(conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1084 |
+
)
|
| 1085 |
+
)
|
| 1086 |
+
(head): Conv3d(32, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 1087 |
+
)
|
| 1088 |
+
)
|
| 1089 |
+
)
|
| 1090 |
+
=> loaded checkpoint /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt (epoch 1122) (best_metric 0.843817412853241) setting start_epoch 0
|
| 1091 |
+
Total parameters count: 86278888 distributed: False
|
| 1092 |
+
Inference complete, time 226.94s shape torch.Size([2, 512, 512, 40]) {'image': 'sample_data/images/TBI_INVAC184NYT.nii'}
|
| 1093 |
+
_meta_: {}
|
| 1094 |
+
acc: null
|
| 1095 |
+
amp: false
|
| 1096 |
+
anisotropic_scales: true
|
| 1097 |
+
auto_scale_allowed: true
|
| 1098 |
+
auto_scale_batch: true
|
| 1099 |
+
auto_scale_filters: false
|
| 1100 |
+
auto_scale_roi: false
|
| 1101 |
+
batch_size: 1
|
| 1102 |
+
bundle_root: /Users/sakshirathi/neurotk/bundles/segresnet
|
| 1103 |
+
cache_class_indices: null
|
| 1104 |
+
cache_rate: null
|
| 1105 |
+
calc_val_loss: false
|
| 1106 |
+
channels_last: true
|
| 1107 |
+
ckpt_path: /Users/sakshirathi/neurotk/bundles/segresnet/model
|
| 1108 |
+
ckpt_save: true
|
| 1109 |
+
class_index: null
|
| 1110 |
+
class_names:
|
| 1111 |
+
- acc_0
|
| 1112 |
+
crop_add_background: true
|
| 1113 |
+
crop_foreground: true
|
| 1114 |
+
crop_mode: ratio
|
| 1115 |
+
crop_ratios: null
|
| 1116 |
+
cuda: false
|
| 1117 |
+
data_file_base_dir: /Users/sakshirathi/neurotk/bundles
|
| 1118 |
+
data_list_file_path: /Users/sakshirathi/Downloads/work_dir/dataset_local.json
|
| 1119 |
+
debug: false
|
| 1120 |
+
determ: false
|
| 1121 |
+
early_stopping_fraction: 0.001
|
| 1122 |
+
extra_modalities: {}
|
| 1123 |
+
finetune:
|
| 1124 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 1125 |
+
enabled: false
|
| 1126 |
+
float32_precision: null
|
| 1127 |
+
fold: 0
|
| 1128 |
+
fork: true
|
| 1129 |
+
global_rank: 0
|
| 1130 |
+
image_size:
|
| 1131 |
+
- 544
|
| 1132 |
+
- 544
|
| 1133 |
+
- 69
|
| 1134 |
+
image_size_mm_90:
|
| 1135 |
+
- 265.61599121093747
|
| 1136 |
+
- 265.6159922216141
|
| 1137 |
+
- 190.12765338720757
|
| 1138 |
+
image_size_mm_median:
|
| 1139 |
+
- 249.68374495589455
|
| 1140 |
+
- 249.68375462603575
|
| 1141 |
+
- 168.30083390623668
|
| 1142 |
+
infer:
|
| 1143 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 1144 |
+
data_list_key: testing
|
| 1145 |
+
enabled: true
|
| 1146 |
+
output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_testing
|
| 1147 |
+
input_channels: 1
|
| 1148 |
+
intensity_bounds:
|
| 1149 |
+
- 39.63595217750186
|
| 1150 |
+
- 97.59593563988095
|
| 1151 |
+
learning_rate: 0.0002
|
| 1152 |
+
log_output_file: /Users/sakshirathi/neurotk/bundles/segresnet/model/training.log
|
| 1153 |
+
loss:
|
| 1154 |
+
_target_: DiceCELoss
|
| 1155 |
+
include_background: true
|
| 1156 |
+
sigmoid: false
|
| 1157 |
+
smooth_dr: 1.0e-05
|
| 1158 |
+
smooth_nr: 0
|
| 1159 |
+
softmax: true
|
| 1160 |
+
squared_pred: true
|
| 1161 |
+
to_onehot_y: true
|
| 1162 |
+
max_samples_per_class: 12500
|
| 1163 |
+
mlflow_experiment_name: Auto3DSeg
|
| 1164 |
+
mlflow_tracking_uri: /Users/sakshirathi/neurotk/bundles/segresnet/model/mlruns/
|
| 1165 |
+
modality: ct
|
| 1166 |
+
network:
|
| 1167 |
+
_target_: SegResNetDS
|
| 1168 |
+
blocks_down:
|
| 1169 |
+
- 1
|
| 1170 |
+
- 2
|
| 1171 |
+
- 2
|
| 1172 |
+
- 4
|
| 1173 |
+
- 4
|
| 1174 |
+
dsdepth: 4
|
| 1175 |
+
in_channels: 1
|
| 1176 |
+
init_filters: 32
|
| 1177 |
+
norm: INSTANCE_NVFUSER
|
| 1178 |
+
out_channels: 2
|
| 1179 |
+
normalize_mode: range
|
| 1180 |
+
notf32: false
|
| 1181 |
+
num_crops_per_image: 2
|
| 1182 |
+
num_epochs: 1250
|
| 1183 |
+
num_epochs_per_saving: 1
|
| 1184 |
+
num_epochs_per_validation: null
|
| 1185 |
+
num_images_per_batch: 1
|
| 1186 |
+
num_steps_per_image: null
|
| 1187 |
+
num_warmup_epochs: 3
|
| 1188 |
+
num_workers: 4
|
| 1189 |
+
optimizer:
|
| 1190 |
+
_target_: torch.optim.AdamW
|
| 1191 |
+
lr: 0.0002
|
| 1192 |
+
weight_decay: 1.0e-05
|
| 1193 |
+
orientation_ras: true
|
| 1194 |
+
output_classes: 2
|
| 1195 |
+
pretrained_ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 1196 |
+
quick: false
|
| 1197 |
+
rank: 0
|
| 1198 |
+
resample: true
|
| 1199 |
+
resample_resolution:
|
| 1200 |
+
- 0.48766356436698155
|
| 1201 |
+
- 0.4876635832539761
|
| 1202 |
+
- 2.748479210553717
|
| 1203 |
+
roi_size:
|
| 1204 |
+
- 384
|
| 1205 |
+
- 384
|
| 1206 |
+
- 60
|
| 1207 |
+
sigmoid: false
|
| 1208 |
+
spacing_lower:
|
| 1209 |
+
- 0.42813486948609353
|
| 1210 |
+
- 0.428134856247896
|
| 1211 |
+
- 2.499999978382533
|
| 1212 |
+
spacing_median:
|
| 1213 |
+
- 0.48766356436698155
|
| 1214 |
+
- 0.4876635832539761
|
| 1215 |
+
- 4.770811902267695
|
| 1216 |
+
spacing_upper:
|
| 1217 |
+
- 0.5859375
|
| 1218 |
+
- 0.5859375004856939
|
| 1219 |
+
- 5.012642938162783
|
| 1220 |
+
start_epoch: 0
|
| 1221 |
+
stop_on_lowacc: true
|
| 1222 |
+
validate:
|
| 1223 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 1224 |
+
enabled: false
|
| 1225 |
+
invert: true
|
| 1226 |
+
output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_validation
|
| 1227 |
+
save_mask: false
|
| 1228 |
+
validate_final_original_res: true
|
| 1229 |
+
|
| 1230 |
+
auto_adjust_network_settings no distributed global_rank 0
|
| 1231 |
+
GPU device memory min: 16
|
| 1232 |
+
base_numel 7225344 gpu_factor 1 gpu_factor_init 1
|
| 1233 |
+
input roi [224 224 144] image_size [ 512.000 512.000 61.000] numel 7225344
|
| 1234 |
+
increasing roi step [ 257.600 257.600 61.000]
|
| 1235 |
+
increasing roi result 1 [ 257.600 257.600 61.000]
|
| 1236 |
+
increasing roi step [ 296.240 296.240 61.000]
|
| 1237 |
+
increasing roi result 1 [ 296.240 296.240 61.000]
|
| 1238 |
+
increasing roi step [ 340.676 340.676 61.000]
|
| 1239 |
+
increasing roi result 1 [ 340.676 340.676 61.000]
|
| 1240 |
+
increasing roi step [ 391.777 391.777 61.000]
|
| 1241 |
+
increasing roi result 1 [ 391.777 391.777 61.000]
|
| 1242 |
+
roi_size factored [ 384.000 384.000 60.000] factor [ 16.000 16.000 4.000] extra_levels [ 0.000 0.000 2.000]
|
| 1243 |
+
kept filters the same base_numel 7225344, gpu_factor 1
|
| 1244 |
+
kept batch the same base_numel 7225344, gpu_factor 1, gpu_factor_init 1
|
| 1245 |
+
Suggested network parameters:
|
| 1246 |
+
Batch size 1 => 1
|
| 1247 |
+
ROI size [224, 224, 144] => [384, 384, 60]
|
| 1248 |
+
init_filters 32 => 32
|
| 1249 |
+
aniso: True image_size_mm: [249.68374495589455, 249.68375462603575, 168.30083390623668] spacing: [0.48766356436698155, 0.4876635832539761, 2.748479210553717] levels: 5
|
| 1250 |
+
|
| 1251 |
+
Using anisotropic scales {'_target_': 'SegResNetDS', 'init_filters': 32, 'blocks_down': [1, 2, 2, 4, 4], 'norm': 'INSTANCE', 'in_channels': 1, 'out_channels': 2, 'dsdepth': 4, 'resolution': [0.48766356436698155, 0.4876635832539761, 2.748479210553717]}
|
| 1252 |
+
SegResNetDS(
|
| 1253 |
+
(encoder): SegResEncoder(
|
| 1254 |
+
(conv_init): Conv3d(1, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1255 |
+
(layers): ModuleList(
|
| 1256 |
+
(0): ModuleDict(
|
| 1257 |
+
(blocks): Sequential(
|
| 1258 |
+
(0): SegResBlock(
|
| 1259 |
+
(norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1260 |
+
(act1): ReLU(inplace=True)
|
| 1261 |
+
(conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1262 |
+
(norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1263 |
+
(act2): ReLU(inplace=True)
|
| 1264 |
+
(conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1265 |
+
)
|
| 1266 |
+
)
|
| 1267 |
+
(downsample): Conv3d(32, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
|
| 1268 |
+
)
|
| 1269 |
+
(1): ModuleDict(
|
| 1270 |
+
(blocks): Sequential(
|
| 1271 |
+
(0): SegResBlock(
|
| 1272 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1273 |
+
(act1): ReLU(inplace=True)
|
| 1274 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1275 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1276 |
+
(act2): ReLU(inplace=True)
|
| 1277 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1278 |
+
)
|
| 1279 |
+
(1): SegResBlock(
|
| 1280 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1281 |
+
(act1): ReLU(inplace=True)
|
| 1282 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1283 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1284 |
+
(act2): ReLU(inplace=True)
|
| 1285 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1286 |
+
)
|
| 1287 |
+
)
|
| 1288 |
+
(downsample): Conv3d(64, 128, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
|
| 1289 |
+
)
|
| 1290 |
+
(2): ModuleDict(
|
| 1291 |
+
(blocks): Sequential(
|
| 1292 |
+
(0): SegResBlock(
|
| 1293 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1294 |
+
(act1): ReLU(inplace=True)
|
| 1295 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1296 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1297 |
+
(act2): ReLU(inplace=True)
|
| 1298 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1299 |
+
)
|
| 1300 |
+
(1): SegResBlock(
|
| 1301 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1302 |
+
(act1): ReLU(inplace=True)
|
| 1303 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1304 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1305 |
+
(act2): ReLU(inplace=True)
|
| 1306 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1307 |
+
)
|
| 1308 |
+
)
|
| 1309 |
+
(downsample): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
|
| 1310 |
+
)
|
| 1311 |
+
(3): ModuleDict(
|
| 1312 |
+
(blocks): Sequential(
|
| 1313 |
+
(0): SegResBlock(
|
| 1314 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1315 |
+
(act1): ReLU(inplace=True)
|
| 1316 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1317 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1318 |
+
(act2): ReLU(inplace=True)
|
| 1319 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1320 |
+
)
|
| 1321 |
+
(1): SegResBlock(
|
| 1322 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1323 |
+
(act1): ReLU(inplace=True)
|
| 1324 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1325 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1326 |
+
(act2): ReLU(inplace=True)
|
| 1327 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1328 |
+
)
|
| 1329 |
+
(2): SegResBlock(
|
| 1330 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1331 |
+
(act1): ReLU(inplace=True)
|
| 1332 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1333 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1334 |
+
(act2): ReLU(inplace=True)
|
| 1335 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1336 |
+
)
|
| 1337 |
+
(3): SegResBlock(
|
| 1338 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1339 |
+
(act1): ReLU(inplace=True)
|
| 1340 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1341 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1342 |
+
(act2): ReLU(inplace=True)
|
| 1343 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1344 |
+
)
|
| 1345 |
+
)
|
| 1346 |
+
(downsample): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
|
| 1347 |
+
)
|
| 1348 |
+
(4): ModuleDict(
|
| 1349 |
+
(blocks): Sequential(
|
| 1350 |
+
(0): SegResBlock(
|
| 1351 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1352 |
+
(act1): ReLU(inplace=True)
|
| 1353 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1354 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1355 |
+
(act2): ReLU(inplace=True)
|
| 1356 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1357 |
+
)
|
| 1358 |
+
(1): SegResBlock(
|
| 1359 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1360 |
+
(act1): ReLU(inplace=True)
|
| 1361 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1362 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1363 |
+
(act2): ReLU(inplace=True)
|
| 1364 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1365 |
+
)
|
| 1366 |
+
(2): SegResBlock(
|
| 1367 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1368 |
+
(act1): ReLU(inplace=True)
|
| 1369 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1370 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1371 |
+
(act2): ReLU(inplace=True)
|
| 1372 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1373 |
+
)
|
| 1374 |
+
(3): SegResBlock(
|
| 1375 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1376 |
+
(act1): ReLU(inplace=True)
|
| 1377 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1378 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1379 |
+
(act2): ReLU(inplace=True)
|
| 1380 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1381 |
+
)
|
| 1382 |
+
)
|
| 1383 |
+
(downsample): Identity()
|
| 1384 |
+
)
|
| 1385 |
+
)
|
| 1386 |
+
)
|
| 1387 |
+
(up_layers): ModuleList(
|
| 1388 |
+
(0): ModuleDict(
|
| 1389 |
+
(upsample): UpSample(
|
| 1390 |
+
(deconv): ConvTranspose3d(512, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
|
| 1391 |
+
)
|
| 1392 |
+
(blocks): Sequential(
|
| 1393 |
+
(0): SegResBlock(
|
| 1394 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1395 |
+
(act1): ReLU(inplace=True)
|
| 1396 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1397 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1398 |
+
(act2): ReLU(inplace=True)
|
| 1399 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1400 |
+
)
|
| 1401 |
+
)
|
| 1402 |
+
(head): Conv3d(256, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 1403 |
+
)
|
| 1404 |
+
(1): ModuleDict(
|
| 1405 |
+
(upsample): UpSample(
|
| 1406 |
+
(deconv): ConvTranspose3d(256, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
|
| 1407 |
+
)
|
| 1408 |
+
(blocks): Sequential(
|
| 1409 |
+
(0): SegResBlock(
|
| 1410 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1411 |
+
(act1): ReLU(inplace=True)
|
| 1412 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1413 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1414 |
+
(act2): ReLU(inplace=True)
|
| 1415 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1416 |
+
)
|
| 1417 |
+
)
|
| 1418 |
+
(head): Conv3d(128, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 1419 |
+
)
|
| 1420 |
+
(2): ModuleDict(
|
| 1421 |
+
(upsample): UpSample(
|
| 1422 |
+
(deconv): ConvTranspose3d(128, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
|
| 1423 |
+
)
|
| 1424 |
+
(blocks): Sequential(
|
| 1425 |
+
(0): SegResBlock(
|
| 1426 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1427 |
+
(act1): ReLU(inplace=True)
|
| 1428 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1429 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1430 |
+
(act2): ReLU(inplace=True)
|
| 1431 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1432 |
+
)
|
| 1433 |
+
)
|
| 1434 |
+
(head): Conv3d(64, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 1435 |
+
)
|
| 1436 |
+
(3): ModuleDict(
|
| 1437 |
+
(upsample): UpSample(
|
| 1438 |
+
(deconv): ConvTranspose3d(64, 32, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
|
| 1439 |
+
)
|
| 1440 |
+
(blocks): Sequential(
|
| 1441 |
+
(0): SegResBlock(
|
| 1442 |
+
(norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1443 |
+
(act1): ReLU(inplace=True)
|
| 1444 |
+
(conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1445 |
+
(norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1446 |
+
(act2): ReLU(inplace=True)
|
| 1447 |
+
(conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1448 |
+
)
|
| 1449 |
+
)
|
| 1450 |
+
(head): Conv3d(32, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 1451 |
+
)
|
| 1452 |
+
)
|
| 1453 |
+
)
|
| 1454 |
+
=> loaded checkpoint /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt (epoch 1122) (best_metric 0.843817412853241) setting start_epoch 0
|
| 1455 |
+
Total parameters count: 86278888 distributed: False
|
| 1456 |
+
_meta_: {}
|
| 1457 |
+
acc: null
|
| 1458 |
+
amp: false
|
| 1459 |
+
anisotropic_scales: true
|
| 1460 |
+
auto_scale_allowed: true
|
| 1461 |
+
auto_scale_batch: true
|
| 1462 |
+
auto_scale_filters: false
|
| 1463 |
+
auto_scale_roi: false
|
| 1464 |
+
batch_size: 1
|
| 1465 |
+
bundle_root: /Users/sakshirathi/neurotk/bundles/segresnet
|
| 1466 |
+
cache_class_indices: null
|
| 1467 |
+
cache_rate: null
|
| 1468 |
+
calc_val_loss: false
|
| 1469 |
+
channels_last: true
|
| 1470 |
+
ckpt_path: /Users/sakshirathi/neurotk/bundles/segresnet/model
|
| 1471 |
+
ckpt_save: true
|
| 1472 |
+
class_index: null
|
| 1473 |
+
class_names:
|
| 1474 |
+
- acc_0
|
| 1475 |
+
crop_add_background: true
|
| 1476 |
+
crop_foreground: true
|
| 1477 |
+
crop_mode: ratio
|
| 1478 |
+
crop_ratios: null
|
| 1479 |
+
cuda: false
|
| 1480 |
+
data_file_base_dir: /Users/sakshirathi/neurotk/bundles
|
| 1481 |
+
data_list_file_path: /Users/sakshirathi/Downloads/work_dir/dataset_local.json
|
| 1482 |
+
debug: false
|
| 1483 |
+
determ: false
|
| 1484 |
+
early_stopping_fraction: 0.001
|
| 1485 |
+
extra_modalities: {}
|
| 1486 |
+
finetune:
|
| 1487 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 1488 |
+
enabled: false
|
| 1489 |
+
float32_precision: null
|
| 1490 |
+
fold: 0
|
| 1491 |
+
fork: true
|
| 1492 |
+
global_rank: 0
|
| 1493 |
+
image_size:
|
| 1494 |
+
- 544
|
| 1495 |
+
- 544
|
| 1496 |
+
- 69
|
| 1497 |
+
image_size_mm_90:
|
| 1498 |
+
- 265.61599121093747
|
| 1499 |
+
- 265.6159922216141
|
| 1500 |
+
- 190.12765338720757
|
| 1501 |
+
image_size_mm_median:
|
| 1502 |
+
- 249.68374495589455
|
| 1503 |
+
- 249.68375462603575
|
| 1504 |
+
- 168.30083390623668
|
| 1505 |
+
infer:
|
| 1506 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 1507 |
+
data_list_key: testing
|
| 1508 |
+
enabled: true
|
| 1509 |
+
output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_testing
|
| 1510 |
+
input_channels: 1
|
| 1511 |
+
intensity_bounds:
|
| 1512 |
+
- 39.63595217750186
|
| 1513 |
+
- 97.59593563988095
|
| 1514 |
+
learning_rate: 0.0002
|
| 1515 |
+
log_output_file: /Users/sakshirathi/neurotk/bundles/segresnet/model/training.log
|
| 1516 |
+
loss:
|
| 1517 |
+
_target_: DiceCELoss
|
| 1518 |
+
include_background: true
|
| 1519 |
+
sigmoid: false
|
| 1520 |
+
smooth_dr: 1.0e-05
|
| 1521 |
+
smooth_nr: 0
|
| 1522 |
+
softmax: true
|
| 1523 |
+
squared_pred: true
|
| 1524 |
+
to_onehot_y: true
|
| 1525 |
+
max_samples_per_class: 12500
|
| 1526 |
+
mlflow_experiment_name: Auto3DSeg
|
| 1527 |
+
mlflow_tracking_uri: /Users/sakshirathi/neurotk/bundles/segresnet/model/mlruns/
|
| 1528 |
+
modality: ct
|
| 1529 |
+
network:
|
| 1530 |
+
_target_: SegResNetDS
|
| 1531 |
+
blocks_down:
|
| 1532 |
+
- 1
|
| 1533 |
+
- 2
|
| 1534 |
+
- 2
|
| 1535 |
+
- 4
|
| 1536 |
+
- 4
|
| 1537 |
+
dsdepth: 4
|
| 1538 |
+
in_channels: 1
|
| 1539 |
+
init_filters: 32
|
| 1540 |
+
norm: INSTANCE_NVFUSER
|
| 1541 |
+
out_channels: 2
|
| 1542 |
+
normalize_mode: range
|
| 1543 |
+
notf32: false
|
| 1544 |
+
num_crops_per_image: 2
|
| 1545 |
+
num_epochs: 1250
|
| 1546 |
+
num_epochs_per_saving: 1
|
| 1547 |
+
num_epochs_per_validation: null
|
| 1548 |
+
num_images_per_batch: 1
|
| 1549 |
+
num_steps_per_image: null
|
| 1550 |
+
num_warmup_epochs: 3
|
| 1551 |
+
num_workers: 4
|
| 1552 |
+
optimizer:
|
| 1553 |
+
_target_: torch.optim.AdamW
|
| 1554 |
+
lr: 0.0002
|
| 1555 |
+
weight_decay: 1.0e-05
|
| 1556 |
+
orientation_ras: true
|
| 1557 |
+
output_classes: 2
|
| 1558 |
+
pretrained_ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 1559 |
+
quick: false
|
| 1560 |
+
rank: 0
|
| 1561 |
+
resample: true
|
| 1562 |
+
resample_resolution:
|
| 1563 |
+
- 0.48766356436698155
|
| 1564 |
+
- 0.4876635832539761
|
| 1565 |
+
- 2.748479210553717
|
| 1566 |
+
roi_size:
|
| 1567 |
+
- 384
|
| 1568 |
+
- 384
|
| 1569 |
+
- 60
|
| 1570 |
+
sigmoid: false
|
| 1571 |
+
spacing_lower:
|
| 1572 |
+
- 0.42813486948609353
|
| 1573 |
+
- 0.428134856247896
|
| 1574 |
+
- 2.499999978382533
|
| 1575 |
+
spacing_median:
|
| 1576 |
+
- 0.48766356436698155
|
| 1577 |
+
- 0.4876635832539761
|
| 1578 |
+
- 4.770811902267695
|
| 1579 |
+
spacing_upper:
|
| 1580 |
+
- 0.5859375
|
| 1581 |
+
- 0.5859375004856939
|
| 1582 |
+
- 5.012642938162783
|
| 1583 |
+
start_epoch: 0
|
| 1584 |
+
stop_on_lowacc: true
|
| 1585 |
+
validate:
|
| 1586 |
+
ckpt_name: /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt
|
| 1587 |
+
enabled: false
|
| 1588 |
+
invert: true
|
| 1589 |
+
output_path: /Users/sakshirathi/neurotk/bundles/segresnet/prediction_validation
|
| 1590 |
+
save_mask: false
|
| 1591 |
+
validate_final_original_res: true
|
| 1592 |
+
|
| 1593 |
+
auto_adjust_network_settings no distributed global_rank 0
|
| 1594 |
+
GPU device memory min: 16
|
| 1595 |
+
base_numel 7225344 gpu_factor 1 gpu_factor_init 1
|
| 1596 |
+
input roi [224 224 144] image_size [ 512.000 512.000 61.000] numel 7225344
|
| 1597 |
+
increasing roi step [ 257.600 257.600 61.000]
|
| 1598 |
+
increasing roi result 1 [ 257.600 257.600 61.000]
|
| 1599 |
+
increasing roi step [ 296.240 296.240 61.000]
|
| 1600 |
+
increasing roi result 1 [ 296.240 296.240 61.000]
|
| 1601 |
+
increasing roi step [ 340.676 340.676 61.000]
|
| 1602 |
+
increasing roi result 1 [ 340.676 340.676 61.000]
|
| 1603 |
+
increasing roi step [ 391.777 391.777 61.000]
|
| 1604 |
+
increasing roi result 1 [ 391.777 391.777 61.000]
|
| 1605 |
+
roi_size factored [ 384.000 384.000 60.000] factor [ 16.000 16.000 4.000] extra_levels [ 0.000 0.000 2.000]
|
| 1606 |
+
kept filters the same base_numel 7225344, gpu_factor 1
|
| 1607 |
+
kept batch the same base_numel 7225344, gpu_factor 1, gpu_factor_init 1
|
| 1608 |
+
Suggested network parameters:
|
| 1609 |
+
Batch size 1 => 1
|
| 1610 |
+
ROI size [224, 224, 144] => [384, 384, 60]
|
| 1611 |
+
init_filters 32 => 32
|
| 1612 |
+
aniso: True image_size_mm: [249.68374495589455, 249.68375462603575, 168.30083390623668] spacing: [0.48766356436698155, 0.4876635832539761, 2.748479210553717] levels: 5
|
| 1613 |
+
|
| 1614 |
+
Using anisotropic scales {'_target_': 'SegResNetDS', 'init_filters': 32, 'blocks_down': [1, 2, 2, 4, 4], 'norm': 'INSTANCE', 'in_channels': 1, 'out_channels': 2, 'dsdepth': 4, 'resolution': [0.48766356436698155, 0.4876635832539761, 2.748479210553717]}
|
| 1615 |
+
SegResNetDS(
|
| 1616 |
+
(encoder): SegResEncoder(
|
| 1617 |
+
(conv_init): Conv3d(1, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1618 |
+
(layers): ModuleList(
|
| 1619 |
+
(0): ModuleDict(
|
| 1620 |
+
(blocks): Sequential(
|
| 1621 |
+
(0): SegResBlock(
|
| 1622 |
+
(norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1623 |
+
(act1): ReLU(inplace=True)
|
| 1624 |
+
(conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1625 |
+
(norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1626 |
+
(act2): ReLU(inplace=True)
|
| 1627 |
+
(conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1628 |
+
)
|
| 1629 |
+
)
|
| 1630 |
+
(downsample): Conv3d(32, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
|
| 1631 |
+
)
|
| 1632 |
+
(1): ModuleDict(
|
| 1633 |
+
(blocks): Sequential(
|
| 1634 |
+
(0): SegResBlock(
|
| 1635 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1636 |
+
(act1): ReLU(inplace=True)
|
| 1637 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1638 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1639 |
+
(act2): ReLU(inplace=True)
|
| 1640 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1641 |
+
)
|
| 1642 |
+
(1): SegResBlock(
|
| 1643 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1644 |
+
(act1): ReLU(inplace=True)
|
| 1645 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1646 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1647 |
+
(act2): ReLU(inplace=True)
|
| 1648 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1649 |
+
)
|
| 1650 |
+
)
|
| 1651 |
+
(downsample): Conv3d(64, 128, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), bias=False)
|
| 1652 |
+
)
|
| 1653 |
+
(2): ModuleDict(
|
| 1654 |
+
(blocks): Sequential(
|
| 1655 |
+
(0): SegResBlock(
|
| 1656 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1657 |
+
(act1): ReLU(inplace=True)
|
| 1658 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1659 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1660 |
+
(act2): ReLU(inplace=True)
|
| 1661 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1662 |
+
)
|
| 1663 |
+
(1): SegResBlock(
|
| 1664 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1665 |
+
(act1): ReLU(inplace=True)
|
| 1666 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1667 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1668 |
+
(act2): ReLU(inplace=True)
|
| 1669 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1670 |
+
)
|
| 1671 |
+
)
|
| 1672 |
+
(downsample): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
|
| 1673 |
+
)
|
| 1674 |
+
(3): ModuleDict(
|
| 1675 |
+
(blocks): Sequential(
|
| 1676 |
+
(0): SegResBlock(
|
| 1677 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1678 |
+
(act1): ReLU(inplace=True)
|
| 1679 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1680 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1681 |
+
(act2): ReLU(inplace=True)
|
| 1682 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1683 |
+
)
|
| 1684 |
+
(1): SegResBlock(
|
| 1685 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1686 |
+
(act1): ReLU(inplace=True)
|
| 1687 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1688 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1689 |
+
(act2): ReLU(inplace=True)
|
| 1690 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1691 |
+
)
|
| 1692 |
+
(2): SegResBlock(
|
| 1693 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1694 |
+
(act1): ReLU(inplace=True)
|
| 1695 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1696 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1697 |
+
(act2): ReLU(inplace=True)
|
| 1698 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1699 |
+
)
|
| 1700 |
+
(3): SegResBlock(
|
| 1701 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1702 |
+
(act1): ReLU(inplace=True)
|
| 1703 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1704 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1705 |
+
(act2): ReLU(inplace=True)
|
| 1706 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1707 |
+
)
|
| 1708 |
+
)
|
| 1709 |
+
(downsample): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
|
| 1710 |
+
)
|
| 1711 |
+
(4): ModuleDict(
|
| 1712 |
+
(blocks): Sequential(
|
| 1713 |
+
(0): SegResBlock(
|
| 1714 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1715 |
+
(act1): ReLU(inplace=True)
|
| 1716 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1717 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1718 |
+
(act2): ReLU(inplace=True)
|
| 1719 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1720 |
+
)
|
| 1721 |
+
(1): SegResBlock(
|
| 1722 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1723 |
+
(act1): ReLU(inplace=True)
|
| 1724 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1725 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1726 |
+
(act2): ReLU(inplace=True)
|
| 1727 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1728 |
+
)
|
| 1729 |
+
(2): SegResBlock(
|
| 1730 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1731 |
+
(act1): ReLU(inplace=True)
|
| 1732 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1733 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1734 |
+
(act2): ReLU(inplace=True)
|
| 1735 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1736 |
+
)
|
| 1737 |
+
(3): SegResBlock(
|
| 1738 |
+
(norm1): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1739 |
+
(act1): ReLU(inplace=True)
|
| 1740 |
+
(conv1): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1741 |
+
(norm2): InstanceNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1742 |
+
(act2): ReLU(inplace=True)
|
| 1743 |
+
(conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1744 |
+
)
|
| 1745 |
+
)
|
| 1746 |
+
(downsample): Identity()
|
| 1747 |
+
)
|
| 1748 |
+
)
|
| 1749 |
+
)
|
| 1750 |
+
(up_layers): ModuleList(
|
| 1751 |
+
(0): ModuleDict(
|
| 1752 |
+
(upsample): UpSample(
|
| 1753 |
+
(deconv): ConvTranspose3d(512, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
|
| 1754 |
+
)
|
| 1755 |
+
(blocks): Sequential(
|
| 1756 |
+
(0): SegResBlock(
|
| 1757 |
+
(norm1): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1758 |
+
(act1): ReLU(inplace=True)
|
| 1759 |
+
(conv1): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1760 |
+
(norm2): InstanceNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1761 |
+
(act2): ReLU(inplace=True)
|
| 1762 |
+
(conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1763 |
+
)
|
| 1764 |
+
)
|
| 1765 |
+
(head): Conv3d(256, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 1766 |
+
)
|
| 1767 |
+
(1): ModuleDict(
|
| 1768 |
+
(upsample): UpSample(
|
| 1769 |
+
(deconv): ConvTranspose3d(256, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1), bias=False)
|
| 1770 |
+
)
|
| 1771 |
+
(blocks): Sequential(
|
| 1772 |
+
(0): SegResBlock(
|
| 1773 |
+
(norm1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1774 |
+
(act1): ReLU(inplace=True)
|
| 1775 |
+
(conv1): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1776 |
+
(norm2): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1777 |
+
(act2): ReLU(inplace=True)
|
| 1778 |
+
(conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
|
| 1779 |
+
)
|
| 1780 |
+
)
|
| 1781 |
+
(head): Conv3d(128, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 1782 |
+
)
|
| 1783 |
+
(2): ModuleDict(
|
| 1784 |
+
(upsample): UpSample(
|
| 1785 |
+
(deconv): ConvTranspose3d(128, 64, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
|
| 1786 |
+
)
|
| 1787 |
+
(blocks): Sequential(
|
| 1788 |
+
(0): SegResBlock(
|
| 1789 |
+
(norm1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1790 |
+
(act1): ReLU(inplace=True)
|
| 1791 |
+
(conv1): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1792 |
+
(norm2): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1793 |
+
(act2): ReLU(inplace=True)
|
| 1794 |
+
(conv2): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1795 |
+
)
|
| 1796 |
+
)
|
| 1797 |
+
(head): Conv3d(64, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 1798 |
+
)
|
| 1799 |
+
(3): ModuleDict(
|
| 1800 |
+
(upsample): UpSample(
|
| 1801 |
+
(deconv): ConvTranspose3d(64, 32, kernel_size=(3, 3, 1), stride=(np.int64(2), np.int64(2), np.int64(1)), padding=(1, 1, 0), output_padding=(np.int64(1), np.int64(1), np.int64(0)), bias=False)
|
| 1802 |
+
)
|
| 1803 |
+
(blocks): Sequential(
|
| 1804 |
+
(0): SegResBlock(
|
| 1805 |
+
(norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1806 |
+
(act1): ReLU(inplace=True)
|
| 1807 |
+
(conv1): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1808 |
+
(norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
|
| 1809 |
+
(act2): ReLU(inplace=True)
|
| 1810 |
+
(conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
|
| 1811 |
+
)
|
| 1812 |
+
)
|
| 1813 |
+
(head): Conv3d(32, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
|
| 1814 |
+
)
|
| 1815 |
+
)
|
| 1816 |
+
)
|
| 1817 |
+
=> loaded checkpoint /Users/sakshirathi/neurotk/bundles/segresnet/model/model.pt (epoch 1122) (best_metric 0.843817412853241) setting start_epoch 0
|
| 1818 |
+
Total parameters count: 86278888 distributed: False
|
scripts/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
scripts/__pycache__/segmenter.cpython-310.pyc
ADDED
|
Binary file (53.7 kB). View file
|
|
|
scripts/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (5.08 kB). View file
|
|
|
scripts/segmenter.py
ADDED
|
@@ -0,0 +1,2212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import copy
|
| 15 |
+
import csv
|
| 16 |
+
import gc
|
| 17 |
+
import logging
|
| 18 |
+
import multiprocessing as mp
|
| 19 |
+
import os
|
| 20 |
+
import shutil
|
| 21 |
+
import sys
|
| 22 |
+
import time
|
| 23 |
+
import warnings
|
| 24 |
+
from datetime import datetime, timedelta
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
import psutil
|
| 30 |
+
import torch
|
| 31 |
+
import torch.distributed as dist
|
| 32 |
+
import torch.multiprocessing as mp
|
| 33 |
+
import yaml
|
| 34 |
+
from torch.amp import GradScaler, autocast
|
| 35 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 36 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 37 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 38 |
+
|
| 39 |
+
from monai.apps.auto3dseg.auto_runner import logger
|
| 40 |
+
from monai.apps.auto3dseg.transforms import EnsureSameShaped
|
| 41 |
+
from monai.auto3dseg.utils import datafold_read
|
| 42 |
+
from monai.bundle.config_parser import ConfigParser
|
| 43 |
+
from monai.config import KeysCollection
|
| 44 |
+
from monai.data import CacheDataset, DataLoader, Dataset, DistributedSampler, decollate_batch, list_data_collate
|
| 45 |
+
from monai.inferers import SlidingWindowInfererAdapt
|
| 46 |
+
from monai.losses import DeepSupervisionLoss
|
| 47 |
+
from monai.metrics import CumulativeAverage, DiceHelper
|
| 48 |
+
from monai.networks.layers.factories import split_args
|
| 49 |
+
from monai.optimizers.lr_scheduler import WarmupCosineSchedule
|
| 50 |
+
from monai.transforms import (
|
| 51 |
+
AsDiscreted,
|
| 52 |
+
CastToTyped,
|
| 53 |
+
ClassesToIndicesd,
|
| 54 |
+
Compose,
|
| 55 |
+
ConcatItemsd,
|
| 56 |
+
CopyItemsd,
|
| 57 |
+
CropForegroundd,
|
| 58 |
+
DataStatsd,
|
| 59 |
+
DeleteItemsd,
|
| 60 |
+
EnsureTyped,
|
| 61 |
+
Identityd,
|
| 62 |
+
Invertd,
|
| 63 |
+
Lambdad,
|
| 64 |
+
LoadImaged,
|
| 65 |
+
NormalizeIntensityd,
|
| 66 |
+
Orientationd,
|
| 67 |
+
RandAdjustContrastd,
|
| 68 |
+
RandAffined,
|
| 69 |
+
RandCropByLabelClassesd,
|
| 70 |
+
RandFlipd,
|
| 71 |
+
RandGaussianNoised,
|
| 72 |
+
RandGaussianSmoothd,
|
| 73 |
+
RandHistogramShiftd,
|
| 74 |
+
RandIdentity,
|
| 75 |
+
RandRotate90d,
|
| 76 |
+
RandScaleIntensityd,
|
| 77 |
+
RandScaleIntensityFixedMeand,
|
| 78 |
+
RandShiftIntensityd,
|
| 79 |
+
RandSpatialCropd,
|
| 80 |
+
ResampleToMatchd,
|
| 81 |
+
SaveImaged,
|
| 82 |
+
ScaleIntensityRanged,
|
| 83 |
+
Spacingd,
|
| 84 |
+
SpatialPadd,
|
| 85 |
+
ToDeviced,
|
| 86 |
+
)
|
| 87 |
+
from monai.transforms.transform import MapTransform
|
| 88 |
+
from monai.utils import ImageMetaKey, convert_to_dst_type, optional_import, set_determinism
|
| 89 |
+
|
| 90 |
+
# mlflow is optional: mlflow_is_imported records whether the import succeeded
mlflow, mlflow_is_imported = optional_import("mlflow")


# allocator setting for PyTorch CUDA memory (max_split_size_mb); note that this overwrites
# any value already present in the environment
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:2048"
# NOTE: rebinds the module-level name `print`, so every print(...) call in this file
# is routed to logger.debug instead of stdout
print = logger.debug
# tqdm is optional as well: has_tqdm records whether it is available
tqdm, has_tqdm = optional_import("tqdm", name="tqdm")

# import the sibling utils module both when run as a plain script (no package context)
# and when imported as part of a package
if __package__ in (None, ""):
    from utils import auto_adjust_network_settings, logger_configure
else:
    from .utils import auto_adjust_network_settings, logger_configure
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class LabelEmbedClassIndex(MapTransform):
    """
    Dictionary transform that re-embeds a label tensor according to ``class_index``.

    Every entry of ``class_index`` is a group of label values. For each group the transform
    computes where the label equals any value of the group, and the per-group results are
    concatenated along dim 0.
    """

    def __init__(
        self, keys: KeysCollection = "label", allow_missing_keys: bool = False, class_index: Optional[List] = None
    ) -> None:
        """
        Args:
            keys: keys of the items to re-embed.
            allow_missing_keys: do not raise exception if key is missing.
            class_index: a list of class indices (one group of label values per entry);
                when ``None`` the data passes through unchanged.
        """
        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
        self.class_index = class_index

    def label_mapping(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-group masks of ``x`` concatenated along dim 0, cast back to the dtype of ``x``."""
        group_masks = []
        for group in self.class_index:
            # summing the equality masks marks every position matching any value of the group
            group_masks.append(sum([x == value for value in group]))
        return torch.cat(group_masks, dim=0).to(dtype=x.dtype)

    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
        """Return a shallow copy of ``data`` with every selected key re-embedded."""
        out = dict(data)
        if self.class_index is None:
            return out
        for key in self.key_iterator(out):
            out[key] = self.label_mapping(out[key])
        return out
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def schedule_validation_epochs(num_epochs, num_epochs_per_validation=None, fraction=0.16) -> list:
    """
    Schedule of epochs to validate (progressively more frequently).

    Args:
        num_epochs: total number of epochs.
        num_epochs_per_validation: if provided, use a linear schedule with this step
            instead of the progressive one.
        fraction: the progressive schedule is sampled at ``max(10, int(fraction * num_epochs))``
            points, which bounds how many validation epochs it can produce.

    Returns:
        A list of epoch numbers, or ``[0]`` when no epoch could be scheduled.
    """
    if num_epochs_per_validation is None:
        # quarter sine wave scaled to [0, num_epochs]: its integer increments shrink towards the end
        samples = (np.sin(np.linspace(0, np.pi / 2, max(10, int(fraction * num_epochs)))) * num_epochs).astype(int)
        # largest gaps first, accumulated into epoch numbers -> validation gets denser over time
        x = np.cumsum(np.sort(np.diff(np.unique(samples)))[::-1])
        if x.size > 0:
            # guard: with num_epochs == 0 there are no gaps and indexing [-1] would raise IndexError
            x[-1] = num_epochs
        x = x.tolist()
    elif num_epochs_per_validation >= num_epochs:
        x = [num_epochs_per_validation]
    else:
        x = list(range(num_epochs_per_validation, num_epochs, num_epochs_per_validation))

    if len(x) == 0:
        x = [0]

    return x
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class DataTransformBuilder:
    """
    Assembles the data transform pipeline in stages: load, resample, normalize, crop, augment
    and final.

    Each stage first looks for user-provided transforms in ``custom_transforms`` under the stage
    name (e.g. "load_transforms"); if any are found they replace the built-in stage. Built-in
    stages are additionally followed by an optional ``after_<stage>`` custom stage.
    """

    def __init__(
        self,
        roi_size: list,
        image_key: str = "image",
        label_key: str = "label",
        resample: bool = False,
        resample_resolution: Optional[list] = None,
        normalize_mode: str = "meanstd",
        normalize_params: Optional[dict] = None,
        crop_mode: str = "ratio",
        crop_params: Optional[dict] = None,
        extra_modalities: Optional[dict] = None,
        custom_transforms=None,
        augment_params: Optional[dict] = None,
        debug: bool = False,
        rank: int = 0,
        class_index=None,
        **kwargs,
    ) -> None:
        """
        Args:
            roi_size: spatial size used for padding, cropping and affine augmentation.
            image_key: key of the image item.
            label_key: key of the label item.
            resample: whether to resample to ``resample_resolution``.
            resample_resolution: target spacing, required when ``resample`` is True.
            normalize_mode: normalization mode of the image: "none", "range"/"ct",
                "meanstd"/"mri", "meanstdtanh" or "pet".
            normalize_params: optional "intensity_bounds", "label_dtype" and "image_dtype".
            crop_mode: "ratio" (crop by label classes) or "rand" (random spatial crop).
            crop_params: options of the "ratio" crop mode, see ``get_crop_transforms``.
            extra_modalities: mapping of extra image keys to their normalization mode.
            custom_transforms: mapping of stage name to a list of transforms or transform configs.
            augment_params: optional "augment_mode", "augment_flips" and "augment_rots".
            debug: enables shape-mismatch warnings and debug messages.
            rank: process rank (only stored).
            class_index: optional list of class index groups.
            kwargs: extra options; "orientation_ras", "crop_foreground" and "sigmoid" are read.
        """
        self.roi_size, self.image_key, self.label_key = roi_size, image_key, label_key

        self.resample, self.resample_resolution = resample, resample_resolution
        self.normalize_mode = normalize_mode
        self.normalize_params = normalize_params if normalize_params is not None else {}
        self.crop_mode = crop_mode
        self.crop_params = crop_params if crop_params is not None else {}
        self.augment_params = augment_params if augment_params is not None else {}

        self.extra_modalities = extra_modalities if extra_modalities is not None else {}
        self.custom_transforms = custom_transforms if custom_transforms is not None else {}

        self.extra_options = kwargs
        self.debug = debug
        self.rank = rank
        self.class_index = class_index

    def get_custom(self, name, **kwargs):
        """
        Return the custom transforms registered under ``name`` (an empty list if there are none).

        Dict entries are treated as configs: they are merged with ``kwargs`` and instantiated
        through ``ConfigParser``. Any other entry is returned as is.
        """
        tr = []
        for t in self.custom_transforms.get(name, []):
            if isinstance(t, dict):
                # merge into a copy so the stored (shared) config does not accumulate per-call kwargs
                t = ConfigParser({**t, **kwargs}).get_parsed_content(instantiate=True)
            tr.append(t)

        return tr

    def get_load_transforms(self):
        """Return the loading stage: load, convert to float tensors, match the label shape to the image."""
        ts = self.get_custom("load_transforms")
        if len(ts) > 0:
            return ts

        keys = [self.image_key, self.label_key] + list(self.extra_modalities)
        ts.append(
            LoadImaged(keys=keys, ensure_channel_first=True, dtype=None, allow_missing_keys=True, image_only=True)
        )
        ts.append(EnsureTyped(keys=keys, data_type="tensor", dtype=torch.float, allow_missing_keys=True))
        ts.append(
            EnsureSameShaped(keys=self.label_key, source_key=self.image_key, allow_missing_keys=True, warn=self.debug)
        )

        ts.extend(self.get_custom("after_load_transforms"))

        return ts

    def get_resample_transforms(self, resample_label=True):
        """
        Return the resampling stage: optional RAS orientation, optional foreground crop,
        optional resampling to ``resample_resolution`` and matching of the extra modalities.

        Args:
            resample_label: whether the label is processed together with the image.

        Raises:
            ValueError: if resampling is enabled but ``resample_resolution`` is None.
        """
        ts = self.get_custom("resample_transforms", resample_label=resample_label)
        if len(ts) > 0:
            return ts

        keys = [self.image_key, self.label_key] if resample_label else [self.image_key]
        mode = ["bilinear", "nearest"] if resample_label else ["bilinear"]
        extra_keys = list(self.extra_modalities)

        if self.extra_options.get("orientation_ras", False):
            ts.append(Orientationd(keys=keys, axcodes="RAS", labels=(("L", "R"), ("P", "A"), ("I", "S"))))

        # foreground cropping is skipped when extra modalities are present
        if self.extra_options.get("crop_foreground", False) and len(extra_keys) == 0:
            ts.append(
                CropForegroundd(
                    keys=keys, source_key=self.image_key, allow_missing_keys=True, margin=10, allow_smaller=True
                )
            )
        if self.resample:
            if self.resample_resolution is None:
                raise ValueError("resample_resolution is not provided")

            pixdim = self.resample_resolution
            ts.append(
                Spacingd(
                    keys=keys,
                    pixdim=pixdim,
                    mode=mode,
                    dtype=torch.float,
                    min_pixdim=np.array(pixdim) * 0.75,
                    max_pixdim=np.array(pixdim) * 1.25,
                    allow_missing_keys=True,
                )
            )

            if resample_label:
                ts.append(
                    EnsureSameShaped(
                        keys=self.label_key, source_key=self.image_key, allow_missing_keys=True, warn=self.debug
                    )
                )

        for extra_key in extra_keys:
            ts.append(ResampleToMatchd(keys=extra_key, key_dst=self.image_key, dtype=np.float32))

        ts.extend(self.get_custom("after_resample_transforms", resample_label=resample_label))

        return ts

    def get_normalize_transforms(self):
        """
        Return the normalization stage: optional dtype casts, per-modality intensity
        normalization and concatenation of the extra modalities into the image.

        Raises:
            ValueError: if a normalization mode is not supported.
        """
        ts = self.get_custom("normalize_transforms")
        if len(ts) > 0:
            return ts

        label_dtype = self.normalize_params.get("label_dtype", None)
        if label_dtype is not None:
            ts.append(CastToTyped(keys=self.label_key, dtype=label_dtype, allow_missing_keys=True))
        image_dtype = self.normalize_params.get("image_dtype", None)
        if image_dtype is not None:
            ts.append(CastToTyped(keys=self.image_key, dtype=image_dtype, allow_missing_keys=True))  # for caching
            ts.append(RandIdentity())  # indicate to stop caching after this point
            ts.append(CastToTyped(keys=self.image_key, dtype=torch.float, allow_missing_keys=True))

        modalities = {self.image_key: self.normalize_mode}
        modalities.update(self.extra_modalities)

        for key, normalize_mode in modalities.items():
            if normalize_mode == "none":
                pass
            elif normalize_mode in ["range", "ct"]:
                intensity_bounds = self.normalize_params.get("intensity_bounds", None)
                if intensity_bounds is None:
                    intensity_bounds = [-250, 250]
                    warnings.warn(f"intensity_bounds is not specified, assuming {intensity_bounds}")

                ts.append(
                    ScaleIntensityRanged(
                        keys=key, a_min=intensity_bounds[0], a_max=intensity_bounds[1], b_min=-1, b_max=1, clip=False
                    )
                )
                # soft clipping: values outside the bounds are squashed rather than cut off
                ts.append(Lambdad(keys=key, func=lambda x: torch.sigmoid(x)))
            elif normalize_mode in ["meanstd", "mri"]:
                ts.append(NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True))
            elif normalize_mode in ["meanstdtanh"]:
                ts.append(NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True))
                ts.append(Lambdad(keys=key, func=lambda x: 3 * torch.tanh(x / 3)))
            elif normalize_mode in ["pet"]:
                ts.append(Lambdad(keys=key, func=lambda x: torch.sigmoid((x - x.min()) / x.std())))
            else:
                raise ValueError(f"Unsupported normalize_mode: {normalize_mode}")

        if len(self.extra_modalities) > 0:
            ts.append(ConcatItemsd(keys=list(modalities), name=self.image_key))  # concat
            ts.append(DeleteItemsd(keys=list(self.extra_modalities)))  # release memory

        ts.extend(self.get_custom("after_normalize_transforms"))
        return ts

    def get_crop_transforms(self):
        """
        Return the cropping stage: pad to ``roi_size``, then crop by label classes
        (``crop_mode == "ratio"``) or at a random location (``crop_mode == "rand"``).

        The "ratio" mode reads these ``crop_params``: "output_classes" (required), "crop_ratios",
        "cache_class_indices", "max_samples_per_class", "crop_add_background" and
        "num_crops_per_image".

        Raises:
            ValueError: if ``roi_size`` or the required "output_classes" is missing, or if the
                crop mode is not supported.
        """
        ts = self.get_custom("crop_transforms")
        if len(ts) > 0:
            return ts

        if self.roi_size is None:
            raise ValueError("roi_size is not specified")

        keys = [self.image_key, self.label_key]
        ts.append(SpatialPadd(keys=keys, spatial_size=self.roi_size))

        if self.crop_mode == "ratio":
            output_classes = self.crop_params.get("output_classes", None)
            if output_classes is None:
                raise ValueError("crop_params option output_classes must be specified")

            crop_ratios = self.crop_params.get("crop_ratios", None)
            cache_class_indices = self.crop_params.get("cache_class_indices", False)
            max_samples_per_class = self.crop_params.get("max_samples_per_class", None)
            # a missing value already means "no limit"; comparing None with 0 would raise TypeError
            if max_samples_per_class is not None and max_samples_per_class <= 0:
                max_samples_per_class = None
            indices_key = None

            sigmoid = self.extra_options.get("sigmoid", False)
            crop_add_background = self.crop_params.get("crop_add_background", False)

            if crop_ratios is None:
                crop_classes = output_classes
                if sigmoid and crop_add_background and self.class_index is not None and len(self.class_index) > 1:
                    crop_classes = crop_classes + 1
            else:
                crop_classes = len(crop_ratios)

            if self.debug:
                print(
                    f"Cropping with classes {crop_classes} and crop_add_background {crop_add_background} ratios {crop_ratios}"
                )

            if cache_class_indices:
                ts.append(
                    ClassesToIndicesd(
                        keys=self.label_key,
                        num_classes=crop_classes,
                        indices_postfix="_cls_indices",
                        max_samples_per_class=max_samples_per_class,
                    )
                )
                indices_key = self.label_key + "_cls_indices"

            num_crops_per_image = self.crop_params.get("num_crops_per_image", 1)

            ts.append(
                RandCropByLabelClassesd(
                    keys=keys,
                    label_key=self.label_key,
                    num_classes=crop_classes,
                    spatial_size=self.roi_size,
                    num_samples=num_crops_per_image,
                    ratios=crop_ratios,
                    indices_key=indices_key,
                    warn=False,
                )
            )
        elif self.crop_mode == "rand":
            ts.append(RandSpatialCropd(keys=keys, roi_size=self.roi_size, random_size=False))
        else:
            raise ValueError(f"Unsupported crop mode: {self.crop_mode}")

        ts.extend(self.get_custom("after_crop_transforms"))

        return ts

    def get_augment_transforms(self):
        """
        Return the augmentation stage selected by ``augment_params["augment_mode"]``
        (None/"default", "none", "ct_ax_1" or "mri_1"), followed by random flips and
        optional random 90-degree rotations.

        Raises:
            ValueError: if ``roi_size`` is missing or the augment mode is not supported.
        """
        ts = self.get_custom("augment_transforms")
        if len(ts) > 0:
            return ts

        if self.roi_size is None:
            raise ValueError("roi_size is not specified")

        augment_mode = self.augment_params.get("augment_mode", None)
        augment_flips = self.augment_params.get("augment_flips", None)
        augment_rots = self.augment_params.get("augment_rots", None)

        if self.debug:
            print(f"Using augment_mode {augment_mode}, augment_flips {augment_flips} augment_rots {augment_rots}")

        ts = []

        if augment_mode is None or augment_mode == "default":
            ts.append(
                RandAffined(
                    keys=[self.image_key, self.label_key],
                    prob=0.2,
                    rotate_range=[0.26, 0.26, 0.26],
                    scale_range=[0.2, 0.2, 0.2],
                    mode=["bilinear", "nearest"],
                    spatial_size=self.roi_size,
                    cache_grid=True,
                    padding_mode="border",
                )
            )
            ts.append(
                RandGaussianSmoothd(
                    keys=self.image_key, prob=0.2, sigma_x=[0.5, 1.0], sigma_y=[0.5, 1.0], sigma_z=[0.5, 1.0]
                )
            )
            ts.append(RandScaleIntensityd(keys=self.image_key, prob=0.5, factors=0.3))
            ts.append(RandShiftIntensityd(keys=self.image_key, prob=0.5, offsets=0.1))
            ts.append(RandGaussianNoised(keys=self.image_key, prob=0.2, mean=0.0, std=0.1))

        elif augment_mode == "none":
            # no intensity/affine augmentation, and no flips or rotations either
            augment_flips = []
            augment_rots = []

        elif augment_mode == "ct_ax_1":
            # use the configured image key (was hard-coded to "image")
            ts.append(RandHistogramShiftd(keys=self.image_key, prob=0.5, num_control_points=16))
            ts.append(RandAdjustContrastd(keys=self.image_key, prob=0.2, gamma=[0.5, 3.0]))

            ts.append(
                RandAffined(
                    keys=[self.image_key, self.label_key],
                    prob=0.5,
                    rotate_range=[0, 0, 0.26],
                    scale_range=[0.35, 0.35, 0],
                    mode=["bilinear", "nearest"],
                    spatial_size=self.roi_size,
                    cache_grid=True,
                    padding_mode="border",
                )
            )

        elif augment_mode == "mri_1":
            ts.append(
                RandAffined(
                    keys=[self.image_key, self.label_key],
                    prob=0.2,
                    rotate_range=[0.26, 0.26, 0.26],
                    scale_range=[0.2, 0.2, 0.2],
                    mode=["bilinear", "nearest"],
                    spatial_size=self.roi_size,
                    cache_grid=True,
                    padding_mode="border",
                )
            )

            ts.append(RandGaussianNoised(keys=self.image_key, prob=0.2, mean=0.0, std=0.1))

            ts.append(
                RandGaussianSmoothd(
                    keys=self.image_key, prob=0.2, sigma_x=[0.5, 1.0], sigma_y=[0.5, 1.0], sigma_z=[0.5, 1.0]
                )
            )

            # use the configured image key (was hard-coded to "image")
            ts.append(RandScaleIntensityFixedMeand(keys=self.image_key, prob=0.2, fixed_mean=True, factors=0.3))
            ts.append(
                RandAdjustContrastd(
                    keys=self.image_key, prob=0.2, gamma=[0.7, 1.5], retain_stats=True, invert_image=False
                )
            )
            ts.append(
                RandAdjustContrastd(
                    keys=self.image_key, prob=0.2, gamma=[0.7, 1.5], retain_stats=True, invert_image=True
                )
            )

        else:
            raise ValueError("Unsupported augment_mode: " + str(augment_mode))

        # default to all flips
        if augment_flips is None:
            augment_flips = [0, 1, 2]
        for sa in augment_flips:
            ts.append(RandFlipd(keys=[self.image_key, self.label_key], prob=0.5, spatial_axis=sa))

        # default to no rots
        if augment_rots is not None:
            for sa in augment_rots:
                ts.append(RandRotate90d(keys=[self.image_key, self.label_key], prob=0.5, spatial_axes=sa))

        ts.extend(self.get_custom("after_augment_transforms"))

        return ts

    def get_final_transforms(self):
        """Return the custom "final_transforms" stage (there is no built-in final stage)."""
        return self.get_custom("final_transforms")

    @classmethod
    def get_postprocess_transform(
        cls,
        save_mask=False,
        invert=False,
        transform=None,
        sigmoid=False,
        output_path=None,
        resample=False,
        data_root_dir="",
        output_dtype=np.uint8,
        save_mask_mode=None,
    ) -> Compose:
        """
        Build the post-processing pipeline applied to the "pred" item.

        Args:
            save_mask: save the result to ``output_path`` (needs ``output_path`` to be set).
            invert: invert ``transform`` on the prediction (needs ``transform`` to be set).
            transform: the pre-processing transform to invert.
            sigmoid: discretize by thresholding at 0.5 instead of argmax.
            output_path: output directory of the saved masks.
            resample: currently not used by this method.
            data_root_dir: passed to ``SaveImaged`` as ``data_root_dir``.
            output_dtype: dtype of the saved mask; float32 is used when ``save_mask_mode == "prob"``.
            save_mask_mode: "prob" saves the undiscretized prediction.
        """
        ts = []
        if invert and transform is not None:
            ts.append(Invertd(keys="pred", orig_keys="image", transform=transform, nearest_interp=False))

        if save_mask and output_path is not None:
            # work on a copy so that "pred" itself stays undiscretized
            ts.append(CopyItemsd(keys="pred", times=1, names="seg"))
            if save_mask_mode == "prob":
                output_dtype = np.float32
            else:
                ts.append(
                    AsDiscreted(keys="seg", argmax=True) if not sigmoid else AsDiscreted(keys="seg", threshold=0.5)
                )
            ts.append(
                SaveImaged(
                    keys=["seg"],
                    output_dir=output_path,
                    output_postfix="",
                    data_root_dir=data_root_dir,
                    output_dtype=output_dtype,
                    separate_folder=False,
                    squeeze_end_dims=True,
                    resample=False,
                    print_log=False,
                )
            )

        return Compose(ts)

    def __call__(self, augment=False, resample_label=False) -> Compose:
        """
        Compose the full pipeline: load, resample, normalize, then (only if ``augment``)
        crop and augment, then the final stage.
        """
        ts = []
        ts.extend(self.get_load_transforms())
        ts.extend(self.get_resample_transforms(resample_label=resample_label))
        ts.extend(self.get_normalize_transforms())

        if augment:
            ts.extend(self.get_crop_transforms())
            ts.extend(self.get_augment_transforms())

        ts.extend(self.get_final_transforms())

        return Compose(ts)

    def __repr__(self) -> str:
        """Return a multi-line summary of the builder settings and custom transforms."""
        lines = [
            f"DataTransformBuilder: with image_key: {self.image_key}, label_key: {self.label_key} \n",
            f"roi_size {self.roi_size} resample {self.resample} resample_resolution {self.resample_resolution} \n",
            f"normalize_mode {self.normalize_mode} normalize_params {self.normalize_params} \n",
            f"crop_mode {self.crop_mode} crop_params {self.crop_params} \n",
            f"extra_modalities {self.extra_modalities} \n",
        ]
        lines.extend(f"Custom {k} : {str(trs)} \n" for k, trs in self.custom_transforms.items())
        return "".join(lines)
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
class Segmenter:
|
| 582 |
+
def __init__(
    self,
    config_file: Optional[Union[str, Sequence[str]]] = None,
    config_dict: Optional[Dict] = None,
    rank: int = 0,
    global_rank: int = 0,
) -> None:
    """
    Parse the configuration and set up logging, device, model, loss, metric and inferer.

    Args:
        config_file: path(s) of the config file(s); the first one is kept as ``self.config_file``.
        config_dict: overrides applied on top of the parsed config; ``None`` means no overrides.
        rank: local rank, used to select the device.
        global_rank: global rank; only rank 0 logs the configuration.
    """
    # a fresh dict per call: a shared `{}` default would end up stored on every instance as self.override
    config_dict = {} if config_dict is None else config_dict

    self.rank = rank
    self.global_rank = global_rank
    self.distributed = dist.is_initialized()

    if self.global_rank == 0:
        print(f"Segmenter started config_file: {config_file}, config_dict: {config_dict}")

    np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True)
    logging.getLogger("torch.nn.parallel.distributed").setLevel(logging.WARNING)

    config = self.parse_input_config(config_file=config_file, override=config_dict)
    self.config = config
    self.config_file = config_file if not isinstance(config_file, (list, tuple)) else config_file[0]
    self.override = config_dict

    if config["ckpt_path"] is not None and not os.path.exists(config["ckpt_path"]):
        os.makedirs(config["ckpt_path"], exist_ok=True)

    # by default the log file lives next to the checkpoints
    if config["log_output_file"] is None:
        config["log_output_file"] = os.path.join(self.config["ckpt_path"], "training.log")
    logger_configure(log_output_file=config["log_output_file"], debug=config["debug"], global_rank=self.global_rank)

    if config["fork"] and "fork" in mp.get_all_start_methods():
        mp.set_start_method("fork", force=True)  # lambda functions fail to pickle without it
    else:
        warnings.warn(
            "Multiprocessing method fork is not available, some non-picklable objects (e.g. lambda ) may fail"
        )

    if config["cuda"] and torch.cuda.is_available():
        self.device = torch.device(self.rank)
        if self.distributed and dist.get_backend() == dist.Backend.NCCL:
            torch.cuda.set_device(rank)
    else:
        self.device = torch.device("cpu")

    if self.global_rank == 0:
        print(yaml.dump(config))

    if config["determ"]:
        set_determinism(seed=0)
    elif torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    if config["notf32"]:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        print("!!!disabling tf32")
    if config.get("float32_precision", None) is not None:
        torch.set_float32_matmul_precision(config["float32_precision"])
        print(f"!!!setting matmul precession {config['float32_precision']}")

    # auto adjust network settings
    if config["auto_scale_allowed"]:
        if config["auto_scale_batch"] or config["auto_scale_roi"] or config["auto_scale_filters"]:
            roi_size, _, init_filters, batch_size = auto_adjust_network_settings(
                auto_scale_batch=config["auto_scale_batch"],
                auto_scale_roi=config["auto_scale_roi"],
                auto_scale_filters=config["auto_scale_filters"],
                image_size_mm=config["image_size_mm_median"],
                spacing=config["resample_resolution"],
                anisotropic_scales=config["anisotropic_scales"],
                levels=len(config["network"]["blocks_down"]),
                output_classes=config["output_classes"],
            )

            config["roi_size"] = roi_size
            if config["auto_scale_batch"]:
                config["batch_size"] = batch_size
            # the filter count is left untouched when a pretrained checkpoint is configured
            if config["auto_scale_filters"] and config["pretrained_ckpt_name"] is None:
                config["network"]["init_filters"] = init_filters

    self.model = self.setup_model(pretrained_ckpt_name=config["pretrained_ckpt_name"])

    loss_function = ConfigParser(config["loss"]).get_parsed_content(instantiate=True)
    self.loss_function = DeepSupervisionLoss(loss_function)

    dice_ignore_empty = config.get("dice_ignore_empty", True)
    self.acc_function = DiceHelper(threshold=config["sigmoid"], ignore_empty=dice_ignore_empty)
    self.amp_device_type = "cuda" if torch.cuda.is_available() else "cpu"
    self.grad_scaler = GradScaler(self.amp_device_type, enabled=config["amp"])

    # a custom inferer config takes precedence over the default sliding window inferer
    if config.get("sliding_inferrer") is not None:
        self.sliding_inferrer = ConfigParser(config["sliding_inferrer"]).get_parsed_content()
    else:
        self.sliding_inferrer = SlidingWindowInfererAdapt(
            roi_size=config["roi_size"],
            sw_batch_size=1,
            overlap=0.625,
            mode="gaussian",
            cache_roi_weight_map=True,
            progress=False,
        )

    # created lazily by get_data_transform_builder()
    self._data_transform_builder: DataTransformBuilder = None
    self.lr_scheduler = None
    self.optimizer = None
|
| 686 |
+
|
| 687 |
+
def get_custom_transforms(self):
    """
    Collect the custom transforms listed in the config, grouped by pipeline stage.

    Each entry of the "custom_data_transforms" config option must provide "key" (the stage
    name), "path" (a directory that is appended to ``sys.path``) and "transform" (the
    transform or its config). When "class_index" is a non-empty list, a label embedding
    transform is appended to the "final_transforms" stage.

    Returns:
        A dict mapping stage names to lists of transforms.

    Raises:
        ValueError: if an entry lacks one of the required keys.
    """
    config = self.config

    # check for custom transforms
    custom_transforms = {}
    must_include_keys = ("key", "path", "transform")
    # `or []`: the option may be present in the config but set to null
    for tr in config.get("custom_data_transforms") or []:
        if not all(k in tr for k in must_include_keys):
            raise ValueError("custom transform must include " + str(must_include_keys))

        # make the module that defines the transform importable
        transform_dir = os.path.abspath(tr["path"])
        if transform_dir not in sys.path:
            sys.path.append(transform_dir)

        custom_transforms.setdefault(tr["key"], []).append(tr["transform"])

    if len(custom_transforms) > 0 and self.global_rank == 0:
        print(f"Using custom transforms {custom_transforms}")

    if isinstance(config["class_index"], list) and len(config["class_index"]) > 0:
        # custom label embedding, if class_index provided
        custom_transforms.setdefault("final_transforms", []).append(
            LabelEmbedClassIndex(keys="label", class_index=config["class_index"], allow_missing_keys=True)
        )

    return custom_transforms
|
| 714 |
+
|
| 715 |
+
def get_data_transform_builder(self):
    """
    Return the ``DataTransformBuilder`` of this segmenter, creating it from the config on
    first use and reusing the same instance afterwards.
    """
    if self._data_transform_builder is not None:
        return self._data_transform_builder

    config = self.config
    custom_transforms = self.get_custom_transforms()

    builder_options = dict(
        roi_size=config["roi_size"],
        resample=config["resample"],
        resample_resolution=config["resample_resolution"],
        normalize_mode=config["normalize_mode"],
        normalize_params={
            "intensity_bounds": config["intensity_bounds"],
            # NOTE(review): the label dtype is chosen from input_channels - confirm that this
            # (rather than the number of label classes) is the intended criterion
            "label_dtype": torch.uint8 if config["input_channels"] < 255 else torch.int16,
            "image_dtype": torch.int16 if config.get("cache_image_int16", False) else None,
        },
        crop_mode=config["crop_mode"],
        crop_params={
            "output_classes": config["output_classes"],
            "input_channels": config["input_channels"],
            "crop_ratios": config["crop_ratios"],
            "cache_class_indices": config["cache_class_indices"],
            "num_crops_per_image": config["num_crops_per_image"],
            "max_samples_per_class": config["max_samples_per_class"],
            "crop_add_background": config["crop_add_background"],
        },
        augment_params={
            "augment_mode": config.get("augment_mode", None),
            "augment_flips": config.get("augment_flips", None),
            "augment_rots": config.get("augment_rots", None),
        },
        extra_modalities=config["extra_modalities"],
        custom_transforms=custom_transforms,
        crop_foreground=config.get("crop_foreground", True),
        sigmoid=config["sigmoid"],
        orientation_ras=config.get("orientation_ras", False),
        class_index=config["class_index"],
        debug=config["debug"],
    )
    self._data_transform_builder = DataTransformBuilder(**builder_options)

    return self._data_transform_builder
|
| 755 |
+
|
| 756 |
+
def setup_model(self, pretrained_ckpt_name=None):
    """
    Instantiate the network described by ``config["network"]`` and prepare it for use.

    The network config is adjusted in place (normalization name, activation and, for
    anisotropic SegResNetDS setups, the resolution). The model is then built, optionally
    loaded from ``pretrained_ckpt_name``, moved to the device and, when running
    distributed outside of inference, wrapped in ``DistributedDataParallel``.

    Args:
        pretrained_ckpt_name: optional checkpoint to load into the freshly built model.

    Returns:
        The prepared model.
    """
    config = self.config
    net_cfg = config["network"]
    is_3d = net_cfg.get("spatial_dims", 3) == 3

    norm_name, norm_args = split_args(net_cfg.get("norm", ""))
    norm_name = norm_name.upper()

    if norm_name == "INSTANCE_NVFUSER":
        _, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")
        if not (has_nvfuser and is_3d):
            # fall back to the regular instance norm
            norm_name = "INSTANCE"
        else:
            act = net_cfg.get("act", "relu")
            if isinstance(act, str):
                net_cfg["act"] = [act, {"inplace": False}]

    if norm_name:
        net_cfg["norm"] = [norm_name, norm_args] if len(norm_args) > 0 else norm_name

    if is_3d and config.get("anisotropic_scales", False) and "SegResNetDS" in net_cfg["_target_"]:
        net_cfg["resolution"] = copy.deepcopy(config["resample_resolution"])
        if self.global_rank == 0:
            print(f"Using anisotropic scales {net_cfg}")

    model = ConfigParser(net_cfg).get_parsed_content()

    if self.global_rank == 0:
        print(str(model))

    if pretrained_ckpt_name is not None:
        self.checkpoint_load(ckpt=pretrained_ckpt_name, model=model)

    model = model.to(self.device)

    if is_3d:
        if config["channels_last"]:
            memory_format = torch.channels_last_3d
        else:
            memory_format = torch.preserve_format
        model = model.to(memory_format=memory_format)

    if self.distributed and not config["infer"]["enabled"]:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DistributedDataParallel(
            module=model, device_ids=[self.rank], output_device=self.rank, find_unused_parameters=False
        )

    if self.global_rank == 0:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total parameters count: {trainable_params} distributed: {self.distributed}")

    return model
|
| 805 |
+
|
| 806 |
+
def parse_input_config(
    self, config_file: Optional[Union[str, Sequence[str]]] = None, override: Optional[Dict] = None
) -> Dict:
    """Build the fully resolved run configuration.

    The configuration is taken from a checkpoint (when no ``config_file`` is given or
    ``override["use_ckpt_config"]`` is set) or from ``config_file``, then ``override``
    is applied on top, defaults are filled in and every entry is resolved through the
    parser.

    Args:
        config_file: path (or sequence of paths) of the input config file(s); may be
            None when the config is recovered from a checkpoint.
        override: key/value pairs applied on top of the loaded config. ``None`` is
            treated as "no overrides".

    Returns:
        The resolved configuration dictionary.

    Raises:
        ValueError: if ``data_file_base_dir`` or ``data_list_file_path`` is missing.
    """
    # avoid a shared mutable default argument
    override = {} if override is None else override

    config = {}
    if config_file is None or override.get("use_ckpt_config", False):
        # attempt to load config from model ckpt file
        for ckpt_key in ["pretrained_ckpt_name", "validate#ckpt_name", "infer#ckpt_name", "finetune#ckpt_name"]:
            ckpt = override.get(ckpt_key, None)
            if ckpt and os.path.exists(ckpt):
                # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints
                checkpoint = torch.load(ckpt, map_location="cpu")
                config = checkpoint.get("config", {})
                if self.global_rank == 0:
                    print(f"Initializing config from the checkpoint {ckpt}: {yaml.dump(config)}")

        if len(config) == 0 and config_file is None:
            warnings.warn("No input config_file provided, and no valid checkpoints found")

    if config_file is not None and len(config) == 0:
        config = ConfigParser.load_config_files(config_file)

    # The mode sections are read unconditionally further down (config["validate"]["enabled"], ...),
    # so make sure they exist whatever the config source was (file, checkpoint or overrides only).
    config.setdefault("finetune", {"enabled": False, "ckpt_name": None})
    config.setdefault("validate", {"enabled": False, "ckpt_name": None, "save_mask": False, "output_path": None})
    config.setdefault("infer", {"enabled": False, "ckpt_name": None})

    parser = ConfigParser(config=config)
    parser.update(pairs=override)
    config = parser.config  # just in case

    if config.get("data_file_base_dir", None) is None or config.get("data_list_file_path", None) is None:
        raise ValueError("CONFIG: data_file_base_dir and data_list_file_path must be provided")

    if config.get("bundle_root", None) is None:
        config["bundle_root"] = str(Path(__file__).parent.parent)

    if "modality" not in config:
        if self.global_rank == 0:
            warnings.warn("CONFIG: modality is not provided, assuming MRI")
        config["modality"] = "mri"

    if "normalize_mode" not in config:
        config["normalize_mode"] = "range" if config["modality"].lower() == "ct" else "meanstd"
        if self.global_rank == 0:
            print(f"CONFIG: normalize_mode is not provided, assuming: {config['normalize_mode']}")

    # assign defaults
    config.setdefault("debug", False)

    config.setdefault("loss", None)
    config.setdefault("acc", None)
    config.setdefault("amp", True)
    config.setdefault("cuda", True)
    config.setdefault("fold", 0)
    config.setdefault("batch_size", 1)
    config.setdefault("determ", False)
    config.setdefault("quick", False)
    config.setdefault("sigmoid", False)
    config.setdefault("cache_rate", None)
    config.setdefault("cache_class_indices", None)
    config.setdefault("crop_add_background", True)
    config.setdefault("orientation_ras", False)

    config.setdefault("channels_last", True)
    config.setdefault("fork", True)

    config.setdefault("num_epochs", 300)
    config.setdefault("num_warmup_epochs", 3)
    config.setdefault("num_epochs_per_validation", None)
    config.setdefault("num_epochs_per_saving", 10)
    config.setdefault("num_steps_per_image", None)
    config.setdefault("num_crops_per_image", 1)
    config.setdefault("max_samples_per_class", None)

    config.setdefault("calc_val_loss", False)
    config.setdefault("validate_final_original_res", False)
    config.setdefault("early_stopping_fraction", 0)
    config.setdefault("start_epoch", 0)

    config.setdefault("ckpt_path", None)
    config.setdefault("ckpt_save", True)
    config.setdefault("log_output_file", None)

    config.setdefault("crop_mode", "ratio")
    config.setdefault("crop_ratios", None)
    config.setdefault("resample_resolution", [1.0, 1.0, 1.0])
    config.setdefault("resample", False)
    config.setdefault("roi_size", [128, 128, 128])
    config.setdefault("num_workers", 4)
    config.setdefault("extra_modalities", {})
    config.setdefault("intensity_bounds", [-250, 250])
    config.setdefault("stop_on_lowacc", True)

    config.setdefault("float32_precision", None)
    config.setdefault("notf32", False)

    config.setdefault("class_index", None)
    config.setdefault("class_names", [])
    if not isinstance(config["class_names"], (list, tuple)):
        config["class_names"] = []

    if len(config["class_names"]) == 0:
        # generate placeholder names; without sigmoid one output channel is not counted
        n_foreground_classes = int(config["output_classes"])
        if not config["sigmoid"]:
            n_foreground_classes -= 1
        config["class_names"] = ["acc_" + str(i) for i in range(n_foreground_classes)]

    # pick the checkpoint of the first enabled mode when none is given explicitly
    pretrained_ckpt_name = config.get("pretrained_ckpt_name", None)
    if pretrained_ckpt_name is None:
        if config["validate"]["enabled"]:
            pretrained_ckpt_name = config["validate"]["ckpt_name"]
        elif config["infer"]["enabled"]:
            pretrained_ckpt_name = config["infer"]["ckpt_name"]
        elif config["finetune"]["enabled"]:
            pretrained_ckpt_name = config["finetune"]["ckpt_name"]
    config["pretrained_ckpt_name"] = pretrained_ckpt_name

    config.setdefault("auto_scale_allowed", False)
    config.setdefault("auto_scale_batch", False)
    config.setdefault("auto_scale_roi", False)
    config.setdefault("auto_scale_filters", False)

    if pretrained_ckpt_name is not None:
        # roi/filter auto-scaling is switched off whenever a checkpoint is going to be loaded
        config["auto_scale_roi"] = False
        config["auto_scale_filters"] = False

    if config["max_samples_per_class"] is None:
        config["max_samples_per_class"] = 10 * config["num_epochs"]

    if not torch.cuda.is_available() and config["cuda"]:
        print("No cuda is available.! Running on CPU!!!")
        config["cuda"] = False

    config["amp"] = config["amp"] and config["cuda"]
    config["rank"] = self.rank
    config["global_rank"] = self.global_rank

    # resolve content
    for k, v in config.items():
        if isinstance(v, dict) and "_target_" in v:
            config[k] = parser.get_parsed_content(k, instantiate=False).config
        elif "_target_" in str(v):
            config[k] = copy.deepcopy(v)
        else:
            config[k] = parser.get_parsed_content(k)

    return config
|
| 952 |
+
|
| 953 |
+
def config_save_updated(self, save_path=None):
    """Re-export the input config with the auto-scaled values written back.

    Does nothing unless this is global rank 0 and ``auto_scale_allowed`` is set.
    The input config file is reloaded, the stored overrides are re-applied, and
    ``batch_size``, ``roi_size``, ``num_crops_per_image`` (plus the network's
    ``init_filters`` when present) are copied over from the live config.

    Args:
        save_path: destination file; falls back to the original config file.
    """
    if self.global_rank != 0 or not self.config["auto_scale_allowed"]:
        return

    live = self.config

    # start from a fresh copy of the input config so only selected keys change
    parser = ConfigParser(config=ConfigParser.load_config_files(self.config_file))
    parser.update(pairs=self.override)
    config = parser.config

    for key in ("batch_size", "roi_size", "num_crops_per_image"):
        config[key] = live[key]

    if "init_filters" in live["network"]:
        config["network"]["init_filters"] = live["network"]["init_filters"]

    if save_path is None:
        save_path = self.config_file

    print(f"Re-saving main config to {save_path}.")
    ConfigParser.export_config_file(config, save_path, fmt="yaml", default_flow_style=None, sort_keys=False)
|
| 973 |
+
|
| 974 |
+
def config_with_relpath(self, config=None):
    """Return a deep copy of the config with bundle paths made relative.

    String values that start with ``bundle_root`` (at the top level and inside
    the ``finetune``, ``validate`` and ``infer`` sections) are rewritten as
    ``$@bundle_root + '/<relative path>'`` expressions. ``bundle_root`` itself is
    kept as the original absolute value. The input config is not modified.

    Args:
        config: configuration to convert; defaults to ``self.config``.

    Returns:
        The converted copy.
    """
    relocatable = copy.deepcopy(self.config if config is None else config)
    bundle_root = relocatable["bundle_root"]

    # None stands for the top level; the others are nested sections
    for section_name in (None, "finetune", "validate", "infer"):
        section = relocatable if section_name is None else relocatable[section_name]
        for key, value in section.items():
            if not isinstance(value, str):
                continue
            if value.startswith(bundle_root):
                section[key] = f"$@bundle_root + '/{os.path.relpath(value, bundle_root)}'"

    # the top-level pass also rewrote bundle_root itself; put the real value back
    relocatable["bundle_root"] = bundle_root

    return relocatable
|
| 992 |
+
|
| 993 |
+
def checkpoint_save(self, ckpt: str, model: torch.nn.Module, **kwargs):
    """Write the model weights and the relocatable config to ``ckpt``.

    Args:
        ckpt: destination file path.
        model: network to save; a DistributedDataParallel wrapper is unwrapped first.
        **kwargs: extra entries stored next to ``state_dict`` and ``config``.

    Returns:
        Time spent saving, in seconds.
    """
    started = time.time()

    # save the weights of the underlying module, not of the DDP wrapper
    network = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    payload = {"state_dict": network.state_dict(), "config": self.config_with_relpath(), **kwargs}
    torch.save(payload, ckpt)

    save_time = time.time() - started
    print(f"Saving checkpoint process: {ckpt}, {kwargs}, save_time {save_time:.2f}s")

    return save_time
|
| 1008 |
+
|
| 1009 |
+
def checkpoint_load(self, ckpt: str, model: torch.nn.Module, **kwargs):
    """Load weights from ``ckpt`` into ``model`` and update the resume bookkeeping.

    A missing file only produces a warning (on global rank 0) and leaves the model
    and config untouched. Otherwise the state dict is loaded strictly; when the
    ``continue`` flag is set in the config (it is consumed here), ``start_epoch``
    and ``best_metric`` are taken from the checkpoint. ``start_epoch`` is then
    advanced by one.

    Args:
        ckpt: checkpoint file path.
        model: network receiving the weights.
        **kwargs: accepted for call compatibility; not used.
    """
    if not os.path.isfile(ckpt):
        if self.global_rank == 0:
            warnings.warn("Invalid checkpoint file: " + str(ckpt))
        return

    checkpoint = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"], strict=True)

    # values used for reporting only
    epoch = checkpoint.get("epoch", 0)
    best_metric = checkpoint.get("best_metric", 0)

    resume_requested = self.config.pop("continue", False)
    if resume_requested:
        for ckpt_key, config_key in (("epoch", "start_epoch"), ("best_metric", "best_metric")):
            if ckpt_key in checkpoint:
                self.config[config_key] = checkpoint[ckpt_key]

    print(
        f"=> loaded checkpoint {ckpt} (epoch {epoch}) (best_metric {best_metric}) setting start_epoch {self.config['start_epoch']}"
    )
    self.config["start_epoch"] = int(self.config["start_epoch"]) + 1
|
| 1029 |
+
|
| 1030 |
+
def get_shared_memory_list(self, length=0):
    """Create a manager-backed list of ``length`` ``None`` entries to be used as a runtime cache.

    In a non-distributed run the locally created list is returned. In a distributed
    run the list object is broadcast so that several ranks end up holding the same
    list: from rank 0 to everyone on a single node, or from the first rank of each
    node to the ranks of that node in a multi-node run.

    Args:
        length: number of slots in the list.

    Returns:
        The list (proxy) this rank should use.
    """
    # every process sets the same fixed authkey
    # NOTE(review): presumably so that ranks receiving the broadcast proxy can authenticate to the manager — confirm
    mp.current_process().authkey = np.arange(32, dtype=np.uint8).tobytes()
    shl0 = mp.Manager().list([None] * length)

    if self.distributed:
        # to support multi-node training, we need check for a local process group
        is_multinode = False

        if dist_launched():
            # NOTE(review): int(os.getenv(...)) raises if one of these variables is unset;
            # dist_launched() is assumed to guarantee they exist — verify against its definition
            local_world_size = int(os.getenv("LOCAL_WORLD_SIZE"))
            world_size = int(os.getenv("WORLD_SIZE"))
            group_rank = int(os.getenv("GROUP_RANK"))
            if world_size > local_world_size:
                is_multinode = True
                # we're in multi-node, get local world sizes
                lw = torch.tensor(local_world_size, dtype=torch.int, device=self.device)
                lw_sizes = [torch.zeros_like(lw) for _ in range(world_size)]
                dist.all_gather(tensor_list=lw_sizes, tensor=lw)

                # src walks over the first rank of each node, g_rank counts the nodes visited
                src = g_rank = 0
                while src < world_size:
                    # create sub-groups local to a node, to share memory only within a node
                    # and broadcast shared list within a node
                    # every rank executes new_group for every node; only members of the
                    # matching node take part in the broadcast below
                    group = dist.new_group(ranks=list(range(src, src + local_world_size)))
                    if group_rank == g_rank:
                        shl_list = [shl0]
                        dist.broadcast_object_list(shl_list, src=src, group=group, device=self.device)
                        shl = shl_list[0]
                    dist.destroy_process_group(group)
                    src = src + lw_sizes[src].item()  # rank of first process in the next node
                    g_rank += 1

        if not is_multinode:
            # single node: everyone takes the list created by rank 0
            shl_list = [shl0]
            dist.broadcast_object_list(shl_list, src=0, device=self.device)
            shl = shl_list[0]

    else:
        shl = shl0

    return shl
|
| 1071 |
+
|
| 1072 |
+
def get_train_loader(self, data, cache_rate=0, persistent_workers=False):
    """Build the training data loader.

    Args:
        data: list of training items.
        cache_rate: when greater than 0 a CacheDataset backed by the shared runtime
            cache is used, otherwise a plain Dataset.
        persistent_workers: keep workers alive between epochs (only applied when
            ``num_workers`` is positive).

    Returns:
        The training DataLoader (shuffled; via a DistributedSampler when distributed).
    """
    is_distributed = self.distributed
    workers = self.config["num_workers"]
    samples_per_batch = self.config["batch_size"]

    transform = self.get_data_transform_builder()(augment=True, resample_label=True)

    if cache_rate > 0:
        dataset = CacheDataset(
            data=data,
            transform=transform,
            copy_cache=False,
            cache_rate=cache_rate,
            runtime_cache=self.get_shared_memory_list(length=len(data)),
        )
    else:
        dataset = Dataset(data=data, transform=transform)

    # shuffling is delegated to the sampler when one is used
    sampler = None
    if is_distributed:
        sampler = DistributedSampler(dataset, shuffle=True)

    return DataLoader(
        dataset,
        batch_size=samples_per_batch,
        shuffle=sampler is None,
        num_workers=workers,
        sampler=sampler,
        persistent_workers=persistent_workers and workers > 0,
        pin_memory=True,
    )
|
| 1103 |
+
|
| 1104 |
+
def get_val_loader(self, data, cache_rate=0, resample_label=False, persistent_workers=False):
    """Build the validation data loader (batch size 1, no shuffling, no augmentation).

    Args:
        data: list of validation items.
        cache_rate: when greater than 0 a CacheDataset backed by the shared runtime
            cache is used, otherwise a plain Dataset.
        resample_label: forwarded to the transform builder.
        persistent_workers: keep workers alive between epochs (only applied when
            ``num_workers`` is positive).

    Returns:
        The validation DataLoader.
    """
    is_distributed = self.distributed
    workers = self.config["num_workers"]

    transform = self.get_data_transform_builder()(augment=False, resample_label=resample_label)

    if cache_rate > 0:
        dataset = CacheDataset(
            data=data,
            transform=transform,
            copy_cache=False,
            cache_rate=cache_rate,
            runtime_cache=self.get_shared_memory_list(length=len(data)),
        )
    else:
        dataset = Dataset(data=data, transform=transform)

    sampler = None
    if is_distributed:
        sampler = DistributedSampler(dataset, shuffle=False)

    return DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=workers,
        sampler=sampler,
        persistent_workers=persistent_workers and workers > 0,
        pin_memory=True,
    )
|
| 1130 |
+
|
| 1131 |
+
def train(self):
    """Run the full training loop and return the best validation metric.

    Steps: read the data folds, derive the crop/epoch schedule, build the loaders,
    optimizer and LR scheduler, then iterate over epochs with scheduled validation,
    best/intermediate checkpointing, progress/CSV/Tensorboard (and optional mlflow)
    logging on global rank 0, a low-accuracy sanity check and optional early
    stopping. Optionally finishes with a validation at the original resolution.

    ``config["ckpt_path"]`` may be None: in that case nothing is written to disk and
    the final original-resolution validation is skipped because there is no
    checkpoint to reload.

    Returns:
        The best mean validation accuracy (or the original-resolution accuracy when
        that final validation ran).

    Raises:
        ValueError: if the loss or accuracy function is missing, ``optim_name`` is not
            supported, or the accuracy stays near zero late in training while
            ``stop_on_lowacc`` is set.
    """
    if self.global_rank == 0:
        print("Segmenter train called")

    if self.loss_function is None:
        raise ValueError("CONFIG loss function is not provided")
    if self.acc_function is None:
        raise ValueError("CONFIG accuracy function is not provided")

    config = self.config
    distributed = self.distributed
    sliding_inferrer = self.sliding_inferrer

    loss_function = self.loss_function
    acc_function = self.acc_function
    grad_scaler = self.grad_scaler

    use_amp = config["amp"]
    use_cuda = config["cuda"]
    ckpt_path = config["ckpt_path"]
    sigmoid = config["sigmoid"]
    channels_last = config["channels_last"]
    calc_val_loss = config["calc_val_loss"]

    # a relative data list path is interpreted relative to the bundle root
    data_list_file_path = config["data_list_file_path"]
    if not os.path.isabs(data_list_file_path):
        data_list_file_path = os.path.abspath(os.path.join(config["bundle_root"], data_list_file_path))

    if config.get("validation_key", None) is not None:
        train_files, _ = datafold_read(datalist=data_list_file_path, basedir=config["data_file_base_dir"], fold=-1)
        validation_files, _ = datafold_read(
            datalist=data_list_file_path,
            basedir=config["data_file_base_dir"],
            fold=-1,
            key=config["validation_key"],
        )
    else:
        train_files, validation_files = datafold_read(
            datalist=data_list_file_path, basedir=config["data_file_base_dir"], fold=config["fold"]
        )

    if config["quick"]:  # quick run on a smaller subset of files
        train_files, validation_files = train_files[:8], validation_files[:8]
    if self.global_rank == 0:
        print(f"train_files files {len(train_files)}, validation files {len(validation_files)}")

    if len(validation_files) == 0:
        warnings.warn("No validation files found!")

    cache_rate_train, cache_rate_val = self.get_cache_rate(
        train_cases=len(train_files), validation_cases=len(validation_files)
    )

    if config["cache_class_indices"] is None:
        config["cache_class_indices"] = cache_rate_train > 0

    if self.global_rank == 0:
        print(
            f"Auto setting max_samples_per_class: {config['max_samples_per_class']} cache_class_indices: {config['cache_class_indices']}"
        )

    num_steps_per_image = config["num_steps_per_image"]
    if config["auto_scale_allowed"] and num_steps_per_image is None:
        be = config["batch_size"]

        if config["crop_mode"] == "ratio":
            # trade batch size for several crops of the same image
            config["num_crops_per_image"] = config["batch_size"]
            config["batch_size"] = 1
        else:
            config["num_crops_per_image"] = 1

        if cache_rate_train < 0.75:
            num_steps_per_image = max(1, 4 // be)
        else:
            num_steps_per_image = 1

    elif num_steps_per_image is None:
        num_steps_per_image = 1

    # epoch-based settings are rescaled by the number of crops taken per image
    num_crops_per_image = int(config["num_crops_per_image"])
    num_epochs_per_saving = max(1, config["num_epochs_per_saving"] // num_crops_per_image)
    num_warmup_epochs = max(3, config["num_warmup_epochs"] // num_crops_per_image)
    num_epochs_per_validation = config["num_epochs_per_validation"]
    num_epochs = max(1, config["num_epochs"] // min(3, num_crops_per_image))
    if self.global_rank == 0:
        print(
            f"Given num_crops_per_image {num_crops_per_image}, num_epochs was adjusted {config['num_epochs']} => {num_epochs}"
        )

    if num_epochs_per_validation is not None:
        num_epochs_per_validation = max(1, num_epochs_per_validation // num_crops_per_image)

    val_schedule_list = schedule_validation_epochs(
        num_epochs=num_epochs,
        num_epochs_per_validation=num_epochs_per_validation,
        fraction=min(0.3, 0.16 * num_crops_per_image),
    )
    if self.global_rank == 0:
        print(f"Scheduling validation loops at epochs: {val_schedule_list}")

    train_loader = self.get_train_loader(data=train_files, cache_rate=cache_rate_train, persistent_workers=True)

    val_loader = self.get_val_loader(
        data=validation_files, cache_rate=cache_rate_val, resample_label=True, persistent_workers=True
    )

    optim_name = config.get("optim_name", None)  # experimental
    if optim_name is not None:
        if self.global_rank == 0:
            print(f"Using optimizer: {optim_name}")
        if optim_name == "fusednovograd":
            import apex

            optimizer = apex.optimizers.FusedNovoGrad(
                params=self.model.parameters(), lr=config["learning_rate"], weight_decay=1.0e-5
            )
        elif optim_name == "sgd":
            momentum = config.get("sgd_momentum", 0.9)
            optimizer = torch.optim.SGD(
                params=self.model.parameters(), lr=config["learning_rate"], weight_decay=1.0e-5, momentum=momentum
            )
            if self.global_rank == 0:
                print(f"Using momentum: {momentum}")
        else:
            raise ValueError("Unsupported optim_name" + str(optim_name))

    elif self.optimizer is None:
        optimizer_part = ConfigParser(config["optimizer"]).get_parsed_content(instantiate=False)
        optimizer = optimizer_part.instantiate(params=self.model.parameters())
    else:
        optimizer = self.optimizer

    tb_writer = None
    csv_path = progress_path = None

    if self.global_rank == 0 and ckpt_path is not None:
        # rank 0 is responsible for heavy lifting of logging/saving
        progress_path = os.path.join(ckpt_path, "progress.yaml")

        tb_writer = SummaryWriter(log_dir=ckpt_path)
        print(f"Writing Tensorboard logs to {tb_writer.log_dir}")

        if mlflow_is_imported:
            mlflow.set_tracking_uri(config["mlflow_tracking_uri"])
            mlflow.set_experiment(config["mlflow_experiment_name"])
            mlflow.start_run(run_name=f'segresnet - fold{config["fold"]} - train')

        csv_path = os.path.join(ckpt_path, "accuracy_history.csv")
        self.save_history_csv(
            csv_path=csv_path,
            header=["epoch", "metric", "loss", "iter", "time", "train_time", "validation_time", "epoch_time"],
        )

    do_torch_save = (self.global_rank == 0) and ckpt_path is not None and config["ckpt_save"]
    # ckpt_path is allowed to be None (it is the configured default); os.path.join(None, ...)
    # would raise a TypeError, so the checkpoint file names only exist when a directory is set
    best_ckpt_path = intermediate_ckpt_path = None
    if ckpt_path is not None:
        best_ckpt_path = os.path.join(ckpt_path, "model.pt")
        intermediate_ckpt_path = os.path.join(ckpt_path, "model_final.pt")

    best_metric = -1
    best_metric_epoch = -1
    pre_loop_time = time.time()
    report_num_epochs = num_epochs * num_crops_per_image
    train_time = validation_time = 0
    val_acc_history = []

    start_epoch = config["start_epoch"]
    if "best_metric" in config:
        best_metric = float(config["best_metric"])

    start_epoch = start_epoch // num_crops_per_image
    if start_epoch > 0:
        # resuming: drop validation points that are already behind us
        val_schedule_list = [v for v in val_schedule_list if v >= start_epoch]
        if len(val_schedule_list) == 0:
            val_schedule_list = [start_epoch]
        print(f"adjusted schedule_list {val_schedule_list}")

    if self.global_rank == 0:
        print(
            f"Using num_epochs => {num_epochs}\n "
            f"Using start_epoch => {start_epoch}\n "
            f"batch_size => {config['batch_size']} \n "
            f"num_crops_per_image => {config['num_crops_per_image']} \n "
            f"num_steps_per_image => {num_steps_per_image} \n "
            f"num_warmup_epochs => {num_warmup_epochs} \n "
        )

    if self.lr_scheduler is None:
        lr_scheduler = WarmupCosineSchedule(
            optimizer=optimizer, warmup_steps=num_warmup_epochs, warmup_multiplier=0.1, t_total=num_epochs
        )
    else:
        lr_scheduler = self.lr_scheduler
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.last_epoch = start_epoch

    range_num_epochs = range(start_epoch, num_epochs)
    if self.global_rank == 0 and has_tqdm and not config["debug"]:
        range_num_epochs = tqdm(
            range(start_epoch, num_epochs),
            desc=str(os.path.basename(config["bundle_root"])) + " - training",
            unit="epoch",
        )

    if distributed:
        dist.barrier()
    self.config_save_updated(save_path=self.config_file)  # overwriting main input config

    for epoch in range_num_epochs:
        # epochs are reported in units that account for the crops per image
        report_epoch = epoch * num_crops_per_image

        if distributed:
            if isinstance(train_loader.sampler, DistributedSampler):
                train_loader.sampler.set_epoch(epoch)
            dist.barrier()

        epoch_time = start_time = time.time()

        train_loss, train_acc = 0, 0
        if not config.get("skip_train", False):
            train_loss, train_acc = self.train_epoch(
                model=self.model,
                train_loader=train_loader,
                optimizer=optimizer,
                loss_function=loss_function,
                acc_function=acc_function,
                grad_scaler=grad_scaler,
                epoch=report_epoch,
                rank=self.rank,
                global_rank=self.global_rank,
                num_epochs=report_num_epochs,
                sigmoid=sigmoid,
                use_amp=use_amp,
                use_cuda=use_cuda,
                channels_last=channels_last,
                num_steps_per_image=num_steps_per_image,
            )

        train_time = time.time() - start_time

        if self.global_rank == 0:
            print(
                f"Final training {report_epoch}/{report_num_epochs - 1} "
                f"loss: {train_loss:.4f} acc_avg: {np.mean(train_acc):.4f} "
                f"acc {train_acc} time {train_time:.2f}s "
                f"lr: {optimizer.param_groups[0]['lr']:.4e}"
            )

            if tb_writer is not None:
                tb_writer.add_scalar("train/loss", train_loss, report_epoch)
                tb_writer.add_scalar("train/acc", np.mean(train_acc), report_epoch)
                if mlflow_is_imported:
                    mlflow.log_metric("train/loss", train_loss, step=report_epoch)

        # validate every num_epochs_per_validation epochs (defaults to 1, every epoch)
        val_acc_mean = -1
        if (
            len(val_schedule_list) > 0
            and epoch + 1 >= val_schedule_list[0]
            and val_loader is not None
            and len(val_loader) > 0
        ):
            val_schedule_list.pop(0)

            start_time = time.time()
            torch.cuda.empty_cache()

            val_loss, val_acc = self.val_epoch(
                model=self.model,
                val_loader=val_loader,
                sliding_inferrer=sliding_inferrer,
                loss_function=loss_function,
                acc_function=acc_function,
                epoch=report_epoch,
                rank=self.rank,
                global_rank=self.global_rank,
                num_epochs=report_num_epochs,
                sigmoid=sigmoid,
                use_amp=use_amp,
                use_cuda=use_cuda,
                channels_last=channels_last,
                calc_val_loss=calc_val_loss,
            )

            torch.cuda.empty_cache()
            validation_time = time.time() - start_time

            val_acc_mean = float(np.mean(val_acc))
            val_acc_history.append((report_epoch, val_acc_mean))

            if self.global_rank == 0:
                print(
                    f"Final validation {report_epoch}/{report_num_epochs - 1} "
                    f"loss: {val_loss:.4f} acc_avg: {val_acc_mean:.4f} acc: {val_acc} time: {validation_time:.2f}s"
                )

                if tb_writer is not None:
                    tb_writer.add_scalar("val/acc", val_acc_mean, report_epoch)
                    if mlflow_is_imported:
                        mlflow.log_metric("val/acc", val_acc_mean, step=report_epoch)

                    for i in range(min(len(config["class_names"]), len(val_acc))):  # accuracy per class
                        tb_writer.add_scalar("val_class/" + config["class_names"][i], val_acc[i], report_epoch)
                        if mlflow_is_imported:
                            mlflow.log_metric(
                                "val_class/" + config["class_names"][i], val_acc[i], step=report_epoch
                            )

                    if calc_val_loss:
                        tb_writer.add_scalar("val/loss", val_loss, report_epoch)

                timing_dict = dict(
                    time="{:.2f} hr".format((time.time() - pre_loop_time) / 3600),
                    train_time="{:.2f}s".format(train_time),
                    validation_time="{:.2f}s".format(validation_time),
                    epoch_time="{:.2f}s".format(time.time() - epoch_time),
                )

                if val_acc_mean > best_metric:
                    print(f"New best metric ({best_metric:.6f} --> {val_acc_mean:.6f}). ")
                    best_metric, best_metric_epoch = val_acc_mean, report_epoch
                    save_time = 0
                    if do_torch_save:
                        save_time = self.checkpoint_save(
                            ckpt=best_ckpt_path, model=self.model, epoch=best_metric_epoch, best_metric=best_metric
                        )

                    if progress_path is not None:
                        self.save_progress_yaml(
                            progress_path=progress_path,
                            ckpt=best_ckpt_path if do_torch_save else None,
                            best_avg_dice_score_epoch=best_metric_epoch,
                            best_avg_dice_score=best_metric,
                            save_time=save_time,
                            **timing_dict,
                        )
                if csv_path is not None:
                    self.save_history_csv(
                        csv_path=csv_path,
                        epoch=report_epoch,
                        metric="{:.4f}".format(val_acc_mean),
                        loss="{:.4f}".format(train_loss),
                        iter=report_epoch * len(train_loader.dataset),
                        **timing_dict,
                    )

                # sanity check
                if epoch > max(20, num_epochs / 4) and 0 <= val_acc_mean < 0.01 and config["stop_on_lowacc"]:
                    raise ValueError(
                        f"Accuracy seems very low at epoch {report_epoch}, acc {val_acc_mean}. "
                        f"Most likely optimization diverged, try setting a smaller learning_rate than {config['learning_rate']}"
                    )

            # early stopping
            if config["early_stopping_fraction"] > 0 and epoch > num_epochs / 2 and len(val_acc_history) > 10:
                check_interval = int(0.1 * num_epochs * num_crops_per_image)
                check_stats = [
                    va[1] for va in val_acc_history if report_epoch - va[0] < check_interval
                ]  # at least 10% epochs
                if len(check_stats) < 10:
                    check_stats = [va[1] for va in val_acc_history[-10:]]  # at least 10 sample points
                mac, mic = max(check_stats), min(check_stats)

                # relative spread of the recent validation accuracies
                early_stopping_fraction = (mac - mic) / (abs(mac) + 1e-8)
                if mac > 0 and mic > 0 and early_stopping_fraction < config["early_stopping_fraction"]:
                    if self.global_rank == 0:
                        print(
                            f"Early stopping at epoch {report_epoch} fraction {early_stopping_fraction} !!! max {mac} min {mic} samples count {len(check_stats)} {check_stats[-50:]}"
                        )
                    break
                else:
                    if self.global_rank == 0:
                        print(
                            f"No stopping at epoch {report_epoch} fraction {early_stopping_fraction} !!! max {mac} min {mic} samples count {len(check_stats)} {check_stats[-50:]}"
                        )

        # save intermediate checkpoint every num_epochs_per_saving epochs
        if do_torch_save and ((epoch + 1) % num_epochs_per_saving == 0 or (epoch + 1) >= num_epochs):
            if report_epoch != best_metric_epoch:
                self.checkpoint_save(
                    ckpt=intermediate_ckpt_path, model=self.model, epoch=report_epoch, best_metric=val_acc_mean
                )
            else:
                shutil.copyfile(best_ckpt_path, intermediate_ckpt_path)  # if already saved once

        if lr_scheduler is not None:
            lr_scheduler.step()

        if self.global_rank == 0:
            # report time estimate
            time_remaining_estimate = train_time * (num_epochs - epoch)
            if val_loader is not None and len(val_loader) > 0:
                if validation_time == 0:
                    # no validation has run yet; use the training time as a stand-in
                    validation_time = train_time
                time_remaining_estimate += validation_time * len(val_schedule_list)

            print(
                f"Estimated remaining training time for the current model fold {config['fold']} is "
                f"{time_remaining_estimate/3600:.2f} hr, "
                f"running time {(time.time() - pre_loop_time)/3600:.2f} hr, "
                f"est total time {(time.time() - pre_loop_time + time_remaining_estimate)/3600:.2f} hr \n"
            )

    # end of main epoch loop

    train_loader = val_loader = optimizer = None

    # optionally validate best checkpoint at the original image resolution
    orig_res = config["resample"] == False
    if config["validate_final_original_res"] and config["resample"]:
        # prefer the best checkpoint, fall back to the intermediate one; both are None
        # when no checkpoint directory is configured
        if best_ckpt_path is not None and os.path.exists(best_ckpt_path):
            pretrained_ckpt_name = best_ckpt_path
        else:
            pretrained_ckpt_name = intermediate_ckpt_path
        if pretrained_ckpt_name is not None and os.path.exists(pretrained_ckpt_name):
            # release the training model before it is rebuilt from the checkpoint
            self.model = None
            gc.collect()
            torch.cuda.empty_cache()

            best_metric = self.original_resolution_validate(
                pretrained_ckpt_name=pretrained_ckpt_name,
                progress_path=progress_path,
                best_metric_epoch=best_metric_epoch,
                pre_loop_time=pre_loop_time,
            )
            orig_res = True
        else:
            if self.global_rank == 0:
                print(
                    f"Unable to validate at the original res since no model checkpoints found {best_ckpt_path}, {intermediate_ckpt_path}"
                )

    if tb_writer is not None:
        tb_writer.flush()
        tb_writer.close()

        if mlflow_is_imported:
            mlflow.end_run()

    if self.global_rank == 0:
        print(
            f"=== DONE: best_metric: {best_metric:.4f} at epoch: {best_metric_epoch} of {report_num_epochs} orig_res {orig_res}. Training time {(time.time() - pre_loop_time)/3600:.2f} hr."
        )

    return best_metric
|
| 1571 |
+
|
| 1572 |
+
def original_resolution_validate(self, pretrained_ckpt_name, progress_path, best_metric_epoch, pre_loop_time):
    """Reload a checkpoint, run ``self.validate()`` once and report the mean accuracy.

    Args:
        pretrained_ckpt_name: checkpoint path forwarded to ``self.setup_model``.
        progress_path: progress YAML path; no progress record is written when ``None``.
        best_metric_epoch: epoch number stored next to the score in the progress record.
        pre_loop_time: ``time.time()`` reference used to report the elapsed time in hours.

    Returns:
        ``float(np.mean(val_acc))`` where ``val_acc`` is the accuracy returned by ``self.validate()``.
    """
    if self.global_rank == 0:
        print("Running final best model validation on the original image resolution!")

    self.model = self.setup_model(pretrained_ckpt_name=pretrained_ckpt_name)

    # validate and time the run
    tic = time.time()
    _, val_loss, val_acc = self.validate()
    validation_time = "{:.2f}s".format(time.time() - tic)
    # the mean is recomputed from val_acc instead of reusing the first value returned by validate()
    val_acc_mean = float(np.mean(val_acc))

    if self.global_rank == 0:
        summary = (
            f"Original resolution validation: "
            f"loss: {val_loss:.4f} acc_avg: {val_acc_mean:.4f} "
            f"acc {val_acc} time {validation_time}"
        )
        print(summary)

    if progress_path is None:
        return val_acc_mean

    elapsed_hours = (time.time() - pre_loop_time) / 3600
    self.save_progress_yaml(
        progress_path=progress_path,
        ckpt=pretrained_ckpt_name,
        best_avg_dice_score_epoch=best_metric_epoch,
        best_avg_dice_score=val_acc_mean,
        validation_time=validation_time,
        inverted_best_validation=True,
        time="{:.2f} hr".format(elapsed_hours),
    )
    return val_acc_mean
|
| 1602 |
+
|
| 1603 |
+
def validate(self, validation_files=None):
    """Run one validation pass and return ``(val_acc_mean, val_loss, val_acc)``.

    Args:
        validation_files: optional list of validation items. When ``None`` the list is read
            with ``datafold_read`` from the configured data list, using ``validation_key``
            if it is set in the config and the configured ``fold`` otherwise.

    Returns:
        A tuple ``(val_acc_mean, val_loss, val_acc)``, or ``None`` when the validation
        list is empty (a warning is issued in that case).
    """
    config = self.config
    resample = config["resample"]

    val_config = self.config["validate"]
    output_path = val_config.get("output_path", None)
    # masks can only be saved when an output location is configured
    save_mask = val_config.get("save_mask", False) and output_path is not None
    invert = val_config.get("invert", True)

    # a relative data list path is resolved against the bundle root
    datalist_path = config["data_list_file_path"]
    if not os.path.isabs(datalist_path):
        datalist_path = os.path.abspath(os.path.join(config["bundle_root"], datalist_path))

    if validation_files is None:
        validation_key = config.get("validation_key", None)
        if validation_key is None:
            _, validation_files = datafold_read(
                datalist=datalist_path, basedir=config["data_file_base_dir"], fold=config["fold"]
            )
        else:
            validation_files, _ = datafold_read(
                datalist=datalist_path,
                basedir=config["data_file_base_dir"],
                fold=-1,
                key=validation_key,
            )

    if self.global_rank == 0:
        print(f"validation files {len(validation_files)}")

    if len(validation_files) == 0:
        warnings.warn("No validation files found!")
        return

    val_loader = self.get_val_loader(data=validation_files, resample_label=not invert)
    val_transform = val_loader.dataset.transform

    post_transforms = None
    if save_mask or invert:
        postprocess_kwargs = {
            "save_mask": save_mask,
            "invert": invert,
            "transform": val_transform,
            "sigmoid": self.config["sigmoid"],
            "output_path": output_path,
            "resample": resample,
            "data_root_dir": self.config["data_file_base_dir"],
            "output_dtype": np.uint8 if self.config["output_classes"] < 255 else np.uint16,
            "save_mask_mode": self.config.get("save_mask_mode", None),
        }
        post_transforms = DataTransformBuilder.get_postprocess_transform(**postprocess_kwargs)

    tic = time.time()
    val_loss, val_acc = self.val_epoch(
        model=self.model,
        val_loader=val_loader,
        sliding_inferrer=self.sliding_inferrer,
        loss_function=self.loss_function,
        acc_function=self.acc_function,
        rank=self.rank,
        global_rank=self.global_rank,
        sigmoid=self.config["sigmoid"],
        use_amp=self.config["amp"],
        use_cuda=self.config["cuda"],
        post_transforms=post_transforms,
        channels_last=self.config["channels_last"],
        calc_val_loss=self.config["calc_val_loss"],
    )
    val_acc_mean = float(np.mean(val_acc))

    if self.global_rank == 0:
        print(
            f"Validation complete, loss_avg: {val_loss:.4f} "
            f"acc_avg: {val_acc_mean:.4f} acc {val_acc} time {time.time() - tic:.2f}s"
        )

    return val_acc_mean, val_loss, val_acc
|
| 1678 |
+
|
| 1679 |
+
def infer(self, testing_files=None):
    """Run inference over a list of files; masks are saved through the post-processing transform.

    Args:
        testing_files: optional list of items to run on. When ``None`` the list is read with
            ``datafold_read`` from the configured data list under the key given by
            ``config["infer"]["data_list_key"]`` (default ``"testing"``).

    Returns:
        ``None``. The method returns early when no ``output_path`` is configured or when
        the file list is empty.
    """
    infer_config = self.config["infer"]
    output_path = infer_config.get("output_path", None)
    testing_key = infer_config.get("data_list_key", "testing")

    if output_path is None:
        if self.global_rank == 0:
            print("Inference output_path is not specified")
        return

    if testing_files is None:
        # a relative data list path is resolved against the bundle root
        datalist_path = self.config["data_list_file_path"]
        if not os.path.isabs(datalist_path):
            datalist_path = os.path.abspath(os.path.join(self.config["bundle_root"], datalist_path))

        testing_files, _ = datafold_read(
            datalist=datalist_path, basedir=self.config["data_file_base_dir"], fold=-1, key=testing_key
        )

    if self.global_rank == 0:
        print(f"testing_files files {len(testing_files)}")

    if len(testing_files) == 0:
        warnings.warn("No testing_files files found!")
        return

    inf_loader = self.get_val_loader(data=testing_files, resample_label=False)
    inf_transform = inf_loader.dataset.transform

    postprocess_kwargs = {
        "save_mask": True,
        "invert": True,
        "transform": inf_transform,
        "sigmoid": self.config["sigmoid"],
        "output_path": output_path,
        "resample": self.config["resample"],
        "data_root_dir": self.config["data_file_base_dir"],
        "output_dtype": np.uint8 if self.config["output_classes"] < 255 else np.uint16,
        "save_mask_mode": self.config.get("save_mask_mode", None),
    }
    post_transforms = DataTransformBuilder.get_postprocess_transform(**postprocess_kwargs)

    tic = time.time()
    # no loss/accuracy functions are passed, so val_epoch only predicts and post-processes
    self.val_epoch(
        model=self.model,
        val_loader=inf_loader,
        sliding_inferrer=self.sliding_inferrer,
        rank=self.rank,
        global_rank=self.global_rank,
        sigmoid=self.config["sigmoid"],
        use_amp=self.config["amp"],
        use_cuda=self.config["cuda"],
        post_transforms=post_transforms,
        channels_last=self.config["channels_last"],
        calc_val_loss=self.config["calc_val_loss"],
    )

    if self.global_rank == 0:
        print(f"Inference complete, time {time.time() - tic:.2f}s")
|
| 1736 |
+
|
| 1737 |
+
@torch.no_grad()
def infer_image(self, image_file):
    """Predict a single image and return the post-processed (inverted) prediction.

    Args:
        image_file: the item handed to the inference transform (wrapped in a one-element list).

    Returns:
        The ``"pred"`` entry of the first (only) decollated item after post-processing.
    """
    self.model.eval()

    infer_config = self.config["infer"]
    output_path = infer_config.get("output_path", None)
    # masks can only be saved when an output location is configured
    save_mask = infer_config.get("save_mask", False) and output_path is not None
    invert_on_gpu = infer_config.get("invert_on_gpu", False)

    tic = time.time()
    sigmoid = self.config["sigmoid"]
    resample = self.config["resample"]
    channels_last = self.config["channels_last"]

    transform = self.get_data_transform_builder()(augment=False, resample_label=False)

    # build a batch of one item
    batch = list_data_collate([transform([image_file])])

    memory_format = torch.channels_last_3d if channels_last else torch.preserve_format
    inputs = batch["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=self.device)

    with autocast(self.amp_device_type, enabled=self.config["amp"]):
        logits = self.sliding_inferrer(inputs=inputs, network=self.model)

    inputs = None  # drop the reference to the input tensor

    try:
        pred = self.logits2pred(logits, sigmoid=sigmoid)
    except RuntimeError as e:
        # only a failure on a CUDA tensor is retried; the retry runs on CPU
        if not logits.is_cuda:
            raise e
        print(f"logits2pred failed on GPU pred retrying on CPU {logits.shape}")
        logits = logits.cpu()
        pred = self.logits2pred(logits, sigmoid=sigmoid)

    logits = None

    if not invert_on_gpu:
        pred = pred.cpu()  # invert on cpu (default)

    post_transforms = DataTransformBuilder.get_postprocess_transform(
        save_mask=save_mask,
        invert=True,
        transform=transform,
        sigmoid=sigmoid,
        output_path=output_path,
        resample=resample,
        data_root_dir=self.config["data_file_base_dir"],
        output_dtype=np.uint8 if self.config["output_classes"] < 255 else np.uint16,
        save_mask_mode=self.config.get("save_mask_mode", None),
    )

    # make Meta tensor
    converted = convert_to_dst_type(pred, batch["image"], dtype=pred.dtype, device=pred.device)
    batch["pred"] = converted[0]
    restored = [post_transforms(item)["pred"] for item in decollate_batch(batch)]

    pred = restored[0]

    print(f"Inference complete, time {time.time() - tic:.2f}s shape {pred.shape} {image_file}")

    return pred
|
| 1800 |
+
|
| 1801 |
+
def train_epoch(
    self,
    model,
    train_loader,
    optimizer,
    loss_function,
    acc_function,
    grad_scaler,
    epoch,
    rank,
    global_rank=0,
    num_epochs=0,
    sigmoid=False,
    use_amp=True,
    use_cuda=True,
    channels_last=False,
    num_steps_per_image=1,
):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: network to train; switched to train mode here.
        train_loader: iterable of batches providing ``"image"`` and ``"label"`` entries.
        optimizer: optimizer stepped through ``grad_scaler``.
        loss_function: callable ``(logits, target) -> loss``.
        acc_function: callable ``(pred, target)`` returning either an accuracy value or an
            ``(accuracy, count)`` pair.
        grad_scaler: scaler used for ``scale`` / ``step`` / ``update``.
        epoch: current epoch number, used only in the progress message.
        rank: device index used when ``use_cuda`` is true.
        global_rank: only rank 0 prints progress.
        num_epochs: total epoch count, used only in the progress message.
        sigmoid: forwarded to ``self.logits2pred``.
        use_amp: enables the autocast context around the forward pass.
        use_cuda: selects ``torch.device(rank)`` over the CPU device.
        channels_last: selects ``torch.channels_last_3d`` as the memory format.
        num_steps_per_image: when greater than 1, each batch is split with ``chunk`` and one
            optimizer step is taken per chunk.

    Returns:
        ``(avg_loss, avg_acc)`` as aggregated by the two ``CumulativeAverage`` trackers
        (both stay ``0`` if the loader yields no batches).
    """
    model.train()
    device = torch.device(rank) if use_cuda else torch.device("cpu")
    memory_format = torch.channels_last_3d if channels_last else torch.preserve_format

    run_loss = CumulativeAverage()
    run_acc = CumulativeAverage()

    start_time = time.time()
    avg_loss = avg_acc = 0
    for idx, batch_data in enumerate(train_loader):
        data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device)
        target = batch_data["label"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device)

        # split the batch into sub-batches, each of which gets its own optimizer step
        data_list = data.chunk(num_steps_per_image) if num_steps_per_image > 1 else [data]
        target_list = target.chunk(num_steps_per_image) if num_steps_per_image > 1 else [target]

        # chunk() may return fewer pieces than requested, hence the min()
        for ich in range(min(num_steps_per_image, len(data_list))):
            data = data_list[ich]
            target = target_list[ich]

            # optimizer.zero_grad(set_to_none=True)
            for param in model.parameters():
                param.grad = None

            with autocast(self.amp_device_type, enabled=use_amp):
                logits = model(data)

            loss = loss_function(logits, target)
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()

            with torch.no_grad():
                # skip_softmax=True: without sigmoid the raw logits are handed to acc_function
                pred = self.logits2pred(logits, sigmoid=sigmoid, skip_softmax=True)
                acc = acc_function(pred, target)

            batch_size_adjusted = batch_size = data.shape[0]
            # acc_function may return (accuracy, count); that count then weights the accuracy average
            if isinstance(acc, (list, tuple)):
                acc, batch_size_adjusted = acc

            run_loss.append(loss, count=batch_size)
            run_acc.append(acc, count=batch_size_adjusted)

        avg_loss = run_loss.aggregate()
        avg_acc = run_acc.aggregate()

        if global_rank == 0:
            print(
                f"Epoch {epoch}/{num_epochs} {idx}/{len(train_loader)} "
                f"loss: {avg_loss:.4f} acc {avg_acc} time {time.time() - start_time:.2f}s "
            )
        start_time = time.time()

    # optimizer.zero_grad(set_to_none=True)
    for param in model.parameters():
        param.grad = None

    # drop references to the last batch
    data = None
    target = None
    data_list = None
    target_list = None
    batch_data = None

    return avg_loss, avg_acc
|
| 1883 |
+
|
| 1884 |
+
@torch.no_grad()
def val_epoch(
    self,
    model,
    val_loader,
    sliding_inferrer,
    loss_function=None,
    acc_function=None,
    epoch=0,
    rank=0,
    global_rank=0,
    num_epochs=0,
    sigmoid=False,
    use_amp=True,
    use_cuda=True,
    post_transforms=None,
    channels_last=False,
    calc_val_loss=False,
):
    """Run one pass over ``val_loader`` and return the aggregated loss and accuracy.

    Args:
        model: network to evaluate; switched to eval mode here.
        val_loader: loader whose batches provide ``"image"`` and optionally ``"label"``.
        sliding_inferrer: callable ``(inputs=..., network=...)`` producing logits.
        loss_function: optional callable ``(logits, target) -> loss``.
        acc_function: optional callable ``(pred, target)`` returning an accuracy tensor or an
            ``(accuracy, count)`` pair. Metrics are only collected when both
            ``loss_function`` and ``acc_function`` are given and the batch has a label.
        epoch: used only in progress messages.
        rank: device index used when ``use_cuda`` is true.
        global_rank: only rank 0 prints per-item progress.
        num_epochs: used only in progress messages.
        sigmoid: forwarded to ``self.logits2pred``.
        use_amp: enables the autocast context around inference.
        use_cuda: selects ``torch.device(rank)`` over the CPU device.
        post_transforms: optional callable applied to every decollated item.
        channels_last: selects ``torch.channels_last_3d`` as the memory format.
        calc_val_loss: when true the loss is computed, provided the logits still have the
            same shape as the post-processed prediction.

    Returns:
        ``(avg_loss, avg_acc)`` as aggregated by the two ``CumulativeAverage`` trackers.
    """
    model.eval()
    device = torch.device(rank) if use_cuda else torch.device("cpu")
    memory_format = torch.channels_last_3d if channels_last else torch.preserve_format
    distributed = dist.is_initialized()

    run_loss = CumulativeAverage()
    run_acc = CumulativeAverage()
    # seed the loss tracker with a zero-weight entry so it holds a value even if no loss is appended
    run_loss.append(torch.tensor(0, device=device), count=0)

    avg_loss = avg_acc = 0
    start_time = time.time()

    # In DDP, each replica has a subset of data, but if total data length is not evenly divisible by
    # num_replicas, then some replicas have 1 extra repeated item.
    # For proper validation with batch of 1, we only want to collect metrics for non-repeated items,
    # hence let's compute a proper subset length
    nonrepeated_data_length = len(val_loader.dataset)
    sampler = val_loader.sampler
    # fix: the original tested the function object ``dist.is_initialized`` (always truthy)
    if distributed and isinstance(sampler, DistributedSampler) and not sampler.drop_last:
        nonrepeated_data_length = len(range(sampler.rank, len(sampler.dataset), sampler.num_replicas))

    for idx, batch_data in enumerate(val_loader):
        data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device)
        filename = batch_data["image"].meta[ImageMetaKey.FILENAME_OR_OBJ]
        batch_size = data.shape[0]

        with autocast(self.amp_device_type, enabled=use_amp):
            logits = sliding_inferrer(inputs=data, network=model)

        data = None

        if post_transforms:
            try:
                pred = self.logits2pred(logits, sigmoid=sigmoid)
            except RuntimeError as e:
                # only a failure on a CUDA tensor is retried; the retry runs on CPU
                if not logits.is_cuda:
                    raise e
                print(f"logits2pred failed on GPU pred retrying on CPU {logits.shape} {filename}")
                logits = logits.cpu()
                pred = self.logits2pred(logits, sigmoid=sigmoid)

            if not calc_val_loss:
                logits = None

            batch_data["pred"] = convert_to_dst_type(
                pred, batch_data["image"], dtype=pred.dtype, device=pred.device
            )[0]
            pred = None

            try:
                # inverting on gpu can OOM due inverse resampling or un-cropping
                pred = torch.stack([post_transforms(x)["pred"] for x in decollate_batch(batch_data)])
            except RuntimeError as e:
                if not batch_data["pred"].is_cuda:
                    raise e
                print(f"post_transforms failed on GPU pred retrying on CPU {batch_data['pred'].shape}")
                batch_data["pred"] = batch_data["pred"].cpu()
                pred = torch.stack([post_transforms(x)["pred"] for x in decollate_batch(batch_data)])

            batch_data["pred"] = None
            if logits is not None and pred.shape != logits.shape:
                logits = None  # if shape has changed due to inverse resampling or un-cropping
        else:
            pred = self.logits2pred(logits, sigmoid=sigmoid, skip_softmax=True)

        if "label" in batch_data and loss_function is not None and acc_function is not None:
            loss = acc = None
            target = batch_data["label"].as_subclass(torch.Tensor)

            if calc_val_loss:
                if logits is not None:
                    loss = loss_function(logits, target.to(device=logits.device))
                    run_loss.append(loss.to(device=device), count=batch_size)
                    logits = None

            with torch.no_grad():
                try:
                    acc = acc_function(pred.to(device=device), target.to(device=device))  # try GPU
                except RuntimeError as e:
                    # only out-of-memory errors are retried on CPU
                    if "OutOfMemoryError" not in str(type(e).__name__):
                        raise e
                    print(
                        f"acc_function val failed on GPU pred: {pred.shape} on {pred.device}, target: {target.shape} on {target.device}. retrying on CPU"
                    )
                    acc = acc_function(pred.cpu(), target.cpu())

                batch_size_adjusted = batch_size
                if isinstance(acc, (list, tuple)):
                    acc, batch_size_adjusted = acc
                    acc = acc.detach().clone()

                if idx < nonrepeated_data_length:
                    run_acc.append(acc.to(device=device), count=batch_size_adjusted)
                else:
                    # repeated (padding) item: contribute zeros with zero weight.
                    # fix: the count may be a plain int, which torch.zeros_like rejects
                    if isinstance(batch_size_adjusted, torch.Tensor):
                        zero_count = torch.zeros_like(batch_size_adjusted)
                    else:
                        zero_count = 0
                    run_acc.append(torch.zeros_like(acc, device=device), count=zero_count)

            avg_loss = loss.cpu() if loss is not None else 0
            avg_acc = acc.cpu().numpy() if acc is not None else 0
            pred, target = None, None

            if global_rank == 0:
                print(
                    f"Val {epoch}/{num_epochs} {idx}/{len(val_loader)} loss: {avg_loss:.4f} "
                    f"acc {avg_acc} time {time.time() - start_time:.2f}s {filename}"
                )

        else:
            if global_rank == 0:
                print(f"Val {epoch}/{num_epochs} {idx}/{len(val_loader)} time {time.time() - start_time:.2f}s")

        start_time = time.time()

    pred = target = data = batch_data = None

    if distributed:
        dist.barrier()

    avg_loss = run_loss.aggregate()
    avg_acc = run_acc.aggregate()

    if np.any(avg_acc < 0):
        # fix: a barrier without an initialized process group raises, so guard it
        if distributed:
            dist.barrier()
        warnings.warn("Avg dice accuracy is negative, something went wrong!!!!!")

    return avg_loss, avg_acc
|
| 2027 |
+
|
| 2028 |
+
def logits2pred(self, logits, sigmoid=False, dim=1, skip_softmax=False):
    """Convert network output into predictions.

    Args:
        logits: a tensor, or a list/tuple whose first element is used.
        sigmoid: apply ``torch.sigmoid`` instead of softmax.
        dim: dimension over which softmax is taken.
        skip_softmax: when ``sigmoid`` is false, return the logits unchanged.

    Returns:
        The sigmoid output, the untouched logits, or a softmax computed in double
        precision and cast back to float.
    """
    if isinstance(logits, (list, tuple)):
        logits = logits[0]

    if sigmoid:
        return torch.sigmoid(logits)
    if skip_softmax:
        return logits
    return torch.softmax(logits, dim=dim, dtype=torch.double).float()
|
| 2038 |
+
|
| 2039 |
+
def get_avail_cpu_memory(self):
    """Return the available RAM in bytes, capped by the cgroup memory limit file if it exists."""
    avail_memory = psutil.virtual_memory().available

    # check if in docker
    limit_file = "/sys/fs/cgroup/memory/memory.limit_in_bytes"
    if not os.path.exists(limit_file):
        return avail_memory

    with open(limit_file, "r") as fh:
        docker_limit = int(fh.read())

    # could be lower limit in docker
    return min(docker_limit, avail_memory)
|
| 2050 |
+
|
| 2051 |
+
def get_cache_rate(self, train_cases=0, validation_cases=0, prioritise_train=True):
    """Decide which fraction of the training and validation data to cache in RAM.

    When ``config["cache_rate"]`` is ``None`` the rate is estimated from the available
    memory and an approximate per-case size derived from the image size; otherwise the
    user-specified rate is used as is.

    Args:
        train_cases: number of training cases.
        validation_cases: number of validation cases.
        prioritise_train: when the rate is strictly between 0 and 1, spend the cache
            budget on the training cases first.

    Returns:
        ``(cache_rate_train, cache_rate_val)``.
    """
    config = self.config
    cache_rate = config["cache_rate"]
    avail_memory = None

    total_cases = train_cases + validation_cases

    image_size_mm_90 = config.get("image_size_mm_90", None)
    if config["resample"] and image_size_mm_90 is not None:
        image_size = (
            (np.array(image_size_mm_90) / np.array(config["resample_resolution"])).astype(np.int32).tolist()
        )
    else:
        image_size = config["image_size"]

    if cache_rate is None:
        cache_rate = 0

        if image_size is not None:
            # fix: the estimate is computed only once image_size is known to be set. The
            # original evaluated np.prod(image_size) unconditionally, which raised for None
            # and made the "not provided" branch below unreachable.
            approx_data_cache_required = (4 * config["input_channels"] + 1) * np.prod(image_size) * total_cases
            approx_os_cache_required = 50 * 1024**3  # reserve 50gb

            avail_memory = self.get_avail_cpu_memory()
            cache_rate = min(avail_memory / float(approx_data_cache_required + approx_os_cache_required), 1.0)
            if cache_rate < 0.1:
                cache_rate = 0.0  # don't cache small

            if self.global_rank == 0:
                print(
                    f"Calculating cache required {approx_data_cache_required >> 30}GB, available RAM {avail_memory >> 30}GB given avg image size {image_size}."
                )
                if cache_rate < 1:
                    print(
                        f"Available RAM is not enought to cache full dataset, caching a fraction {cache_rate:.2f}"
                    )
                else:
                    print("Caching full dataset in RAM")
        else:
            print("Cant calculate cache_rate since image_size is not provided!!!!")

    else:
        if self.global_rank == 0:
            print(f"Using user specified cache_rate={cache_rate} to cache data in RAM")

    # allocate cache_rate to training files first
    cache_rate_train = cache_rate_val = cache_rate

    if prioritise_train:
        if cache_rate > 0 and cache_rate < 1:
            cache_num = cache_rate * total_cases
            cache_rate_train = min(1.0, cache_num / train_cases) if train_cases > 0 else 0
            if (cache_rate_train < 1 and train_cases > 0) or validation_cases == 0:
                # training data alone uses up the budget (or there is nothing to validate)
                cache_rate_val = 0
            else:
                # whatever is left after the training cases goes to validation
                cache_rate_val = (cache_num - cache_rate_train * train_cases) / validation_cases

            if self.global_rank == 0:
                print(f"Prioritizing cache_rate training {cache_rate_train} validation {cache_rate_val}")

    return cache_rate_train, cache_rate_val
|
| 2111 |
+
|
| 2112 |
+
def save_history_csv(self, csv_path=None, header=None, **kwargs):
    """Append rows to a tab-separated history file.

    Args:
        csv_path: file to append to; nothing happens when ``None``.
        header: optional row written first.
        **kwargs: when non-empty, their values are written as one row, in keyword order.
    """
    if csv_path is None:
        return

    pending_rows = []
    if header is not None:
        pending_rows.append(header)
    if len(kwargs):
        pending_rows.append(list(kwargs.values()))

    # each row is appended with its own open/close, so the file is untouched when there are no rows
    for row in pending_rows:
        with open(csv_path, "a") as history_file:
            csv.writer(history_file, delimiter="\t").writerow(row)
|
| 2122 |
+
|
| 2123 |
+
def save_progress_yaml(self, progress_path=None, ckpt=None, **report):
    """Append a progress record to a YAML file and print it.

    Args:
        progress_path: YAML file to append to; the file step is skipped when ``None``.
        ckpt: optional checkpoint name stored under the ``"model"`` key.
        **report: the remaining fields of the record; a ``"date"`` field is always added.
    """
    if ckpt is not None:
        report["model"] = ckpt

    # keep only the "YYYY-MM-DD HH:MM:SS" part of the timestamp
    report["date"] = str(datetime.now())[:19]

    if progress_path is not None:

        def _represent_float(dumper, value):
            # floats are written with four decimals
            return dumper.represent_scalar("tag:yaml.org,2002:float", "{0:.4f}".format(value))

        yaml.add_representer(float, _represent_float)
        with open(progress_path, "a") as progress_file:
            yaml.dump([report], stream=progress_file, allow_unicode=True, default_flow_style=None, sort_keys=False)

    fields = ",".join(f" {k}: {v}" for k, v in report.items())
    print("Progress:" + fields)
|
| 2137 |
+
|
| 2138 |
+
def run(self):
    """Run the mode selected in the config and return its result.

    ``config["validate"]["enabled"]`` selects validation, otherwise
    ``config["infer"]["enabled"]`` selects inference, otherwise training runs.

    Returns:
        Whatever the selected method returns. The original discarded these values, so
        callers storing the result of ``run()`` always received ``None``.
    """
    if self.config["validate"]["enabled"]:
        return self.validate()
    if self.config["infer"]["enabled"]:
        return self.infer()
    return self.train()
|
| 2145 |
+
|
| 2146 |
+
|
| 2147 |
+
def run_segmenter_worker(
    rank=0, config_file: Optional[Union[str, Sequence[str]]] = None, override: Optional[Dict] = None
):
    """Set up the (optional) process group, run a ``Segmenter`` and tear the group down.

    Args:
        rank: process rank; also used as the global rank unless the environment variables
            of an external launcher override it.
        config_file: a config path, a comma-separated string of paths, or a sequence of paths.
        override: config overrides passed to ``Segmenter`` as ``config_dict``. An ``"mgpu"``
            entry, if present, is forwarded to ``dist.init_process_group`` as keyword
            arguments. The dict is updated in place with the rank information.

    Returns:
        The value returned by ``Segmenter.run()``.
    """
    # fix: the default used to be a shared ``{}`` that this function writes to
    # (override["mgpu"] = ...), leaking state between calls made without arguments
    if override is None:
        override = {}

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    dist_available = dist.is_available()
    global_rank = rank

    if isinstance(config_file, str) and "," in config_file:
        config_file = config_file.split(",")

    if dist_available:
        mgpu = override.get("mgpu", None)
        if mgpu is not None:
            # processes spawned manually: the caller supplied the process-group settings
            logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.WARNING)
            dist.init_process_group(backend="nccl", rank=rank, timeout=timedelta(seconds=5400), **mgpu)
            mgpu.update({"rank": rank, "global_rank": rank})
            if rank == 0:
                print(f"Distributed: initializing multi-gpu tcp:// process group {mgpu}")

        elif dist_launched() and torch.cuda.device_count() > 1:
            # started by an external launcher: ranks come from the environment
            rank = int(os.getenv("LOCAL_RANK"))
            global_rank = int(os.getenv("RANK"))
            world_size = int(os.getenv("LOCAL_WORLD_SIZE"))
            logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.WARNING)
            dist.init_process_group(backend="nccl", init_method="env://")  # torchrun spawned it
            override["mgpu"] = {"world_size": world_size, "rank": rank, "global_rank": global_rank}

            print(f"Distributed launched: initializing multi-gpu env:// process group {override['mgpu']}")

    segmenter = Segmenter(config_file=config_file, config_dict=override, rank=rank, global_rank=global_rank)
    best_metric = segmenter.run()
    segmenter = None  # drop the reference before destroying the process group

    if dist_available and dist.is_initialized():
        dist.destroy_process_group()

    return best_metric
|
| 2182 |
+
|
| 2183 |
+
|
| 2184 |
+
def dist_launched() -> bool:
    """Return True when started by torchelastic or when ``NGC_ARRAY_SIZE`` is greater than 1."""
    if dist.is_torchelastic_launched():
        return True

    ngc_array_size = os.getenv("NGC_ARRAY_SIZE")
    return ngc_array_size is not None and int(ngc_array_size) > 1
|
| 2188 |
+
|
| 2189 |
+
|
| 2190 |
+
def run_segmenter(config_file: Optional[Union[str, Sequence[str]]] = None, **kwargs):
    """
    if multiple gpu available, start multiprocessing for all gpus

    Args:
        config_file: a config path, a comma-separated string of paths, or a sequence of paths.
        **kwargs: config overrides forwarded to ``run_segmenter_worker``. ``init_method``,
            if given, is also used as the process-group init method when spawning.
    """

    nprocs = torch.cuda.device_count()

    if nprocs > 1 and not dist_launched():
        # fix: the message lacked the f-prefix and printed the literal text "{nprocs}"
        print(f"Manually spawning processes {nprocs}")
        kwargs["mgpu"] = {"world_size": nprocs, "init_method": kwargs.get("init_method", "tcp://127.0.0.1:23456")}
        torch.multiprocessing.spawn(run_segmenter_worker, nprocs=nprocs, args=(config_file, kwargs))
    else:
        # fix: same missing f-prefix as above
        print(f"Not spawning processes, dist is already launched {nprocs}")
        run_segmenter_worker(0, config_file, kwargs)
|
| 2204 |
+
|
| 2205 |
+
|
| 2206 |
+
if __name__ == "__main__":
    # Script entry point: expose run_segmenter through the "fire" command-line parser
    # when that optional package can be imported.
    fire, fire_is_imported = optional_import("fire")
    if fire_is_imported:
        fire.Fire(run_segmenter)
    else:
        # Fallback without command-line parsing: run with a fixed config path.
        # NOTE(review): the upload's file listing shows the config under
        # "configs/hyper_parameters.yaml" while this path uses "config/" - confirm which
        # directory name is intended before relying on this fallback.
        warnings.warn("Fire commandline parser cannot be imported, using options from config/hyper_parameters.yaml")
        run_segmenter(config_file="config/hyper_parameters.yaml")
|
scripts/utils.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
|
| 8 |
+
from monai.apps.auto3dseg.auto_runner import logger
|
| 9 |
+
|
| 10 |
+
print = logger.debug
|
| 11 |
+
roi_size_default = [224, 224, 144]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def logger_configure(log_output_file: str = None, debug=False, global_rank=0) -> None:
    """Configure the ``monai.apps.auto3dseg.auto_runner`` logger with console and file handlers.

    Args:
        log_output_file: path of the log file. When ``None`` the file handler keeps its
            default file name ``runner.log`` and only accepts CRITICAL records.
        debug: lower the console handler level from INFO to DEBUG.
        global_rank: currently unused.
    """
    # fix: ``import logging`` alone does not import the ``logging.config`` submodule, so
    # ``logging.config.dictConfig`` only worked if some other module had imported it first
    import logging.config

    log_config = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {"monai_default": {"format": "%(message)s"}},
        "loggers": {
            "monai.apps.auto3dseg.auto_runner": {"handlers": ["console", "file"], "level": "DEBUG", "propagate": False}
        },
        "handlers": {
            "console": {
                "class": "logging.StreamHandler",
                "level": "INFO",
                "formatter": "monai_default",
            },
            "file": {
                "class": "logging.FileHandler",
                "filename": "runner.log",
                "mode": "a",
                "level": "DEBUG",
                "formatter": "monai_default",
            },
        },
    }

    file_handler = log_config["handlers"]["file"]
    if log_output_file is not None:
        file_handler["filename"] = log_output_file
        file_handler["level"] = "DEBUG"
    else:
        file_handler["level"] = "CRITICAL"

    # NOTE(review): any non-empty SEGRESNET_DEBUG value (including "0") enables debug output
    if debug or bool(os.environ.get("SEGRESNET_DEBUG", False)):
        log_config["handlers"]["console"]["level"] = "DEBUG"

    logging.config.dictConfig(log_config)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_gpu_mem_size():
    """Return the smallest total memory among the visible GPUs in GiB, or 16 when there is none."""
    device_count = torch.cuda.device_count()
    if device_count <= 0:
        return 16

    smallest_bytes = min([torch.cuda.get_device_properties(idx).total_memory for idx in range(device_count)])
    return smallest_bytes / 1024**3
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def auto_adjust_network_settings(
    auto_scale_roi=False,
    auto_scale_batch=False,
    auto_scale_filters=False,
    image_size_mm=None,
    spacing=None,
    output_classes=None,
    levels=None,
    anisotropic_scales=False,
    levels_limit=5,
    gpu_mem=None,
):
    """Suggest ROI size, number of levels, initial filters and batch size.

    Starting from ``roi_size_default``, the ROI is clipped to the image size in
    voxels, optionally grown to use the available GPU memory budget, and rounded
    so that every axis is divisible by the network's downsampling factor.
    Progress is printed on rank 0 only.

    Args:
        auto_scale_roi: let the ROI voxel budget grow with the GPU memory factor.
        auto_scale_batch: derive the batch size from the GPU memory factor when
            the final ROI is smaller than the budget.
        auto_scale_filters: allow raising the initial filter count above 32
            when the final ROI is smaller than the budget.
        image_size_mm: per-axis image extent; divided by ``spacing`` to obtain
            the image size in voxels. Required.
        spacing: per-axis voxel spacing (strictly positive). Required.
        output_classes: number of output classes; above 20 the memory budget
            (and possibly the ROI) is reduced.
        levels: number of network levels; computed from the ROI when ``None``.
        anisotropic_scales: use per-axis downsampling factors derived from the
            spacing ratios instead of one common factor.
        levels_limit: upper bound for the automatically computed ``levels``.
        gpu_mem: GPU memory in GiB; queried with ``get_gpu_mem_size()`` when
            ``None``. Treated as 16 when no ``auto_scale_*`` flag is set.

    Returns:
        Tuple ``(roi_size, levels, init_filters, batch_size)`` where
        ``roi_size`` is a list of ints and the other items are ints.

    Raises:
        ValueError: if ``image_size_mm`` or ``spacing`` is missing, or if
            ``spacing`` contains a non-positive (or NaN) value.
    """
    # Validate required inputs up front instead of midway through the computation.
    if image_size_mm is None or spacing is None:
        raise ValueError("image_size_mm or spacing is not provided, network params may be inaccuracy")

    spacing_arr = np.asarray(spacing, dtype=float)
    if not np.all(spacing_arr > 0):
        raise ValueError(f"spacing must be strictly positive, got {spacing}")

    # An axis thinner than one voxel would otherwise give size 0, and log2(0) = -inf
    # later makes int(levels) raise OverflowError; every image has at least one voxel.
    image_size = np.maximum(1, np.floor(np.asarray(image_size_mm, dtype=float) / spacing_arr))

    global_rank = 0
    if dist.is_available() and dist.is_initialized():
        global_rank = dist.get_rank()
        print(f"auto_adjust_network_settings dist global_rank {global_rank}")
    else:
        print(f"auto_adjust_network_settings no distributed global_rank {global_rank}")

    def log(message):
        # Only rank 0 reports, to avoid duplicated output in distributed runs.
        if global_rank == 0:
            print(message)

    batch_size_default = 1
    init_filters_default = 32

    roi_size = np.array(roi_size_default)
    base_numel = roi_size.prod()

    if gpu_mem is None:
        gpu_mem = get_gpu_mem_size()
    log(f"GPU device memory min: {gpu_mem}")

    # adapting: the budget factor is relative to a 16 GiB reference device
    if auto_scale_batch or auto_scale_roi or auto_scale_filters:
        gpu_factor_init = gpu_factor = max(1, gpu_mem / 16)
        if anisotropic_scales:
            gpu_factor = max(1, 0.8 * gpu_factor)
        log(f"base_numel {base_numel} gpu_factor {gpu_factor} gpu_factor_init {gpu_factor_init}")
    else:
        gpu_mem = 16
        gpu_factor = gpu_factor_init = 1

    # account for output_classes
    output_classes_thresh = 20
    if output_classes is not None and output_classes > output_classes_thresh:
        base_adjust = gpu_mem / (output_classes * 0.2 + 11.5)
        if gpu_mem < 17:
            base_adjust /= 2

        log(f"base_adjust {base_adjust} since output_classes {output_classes} > {output_classes_thresh}")
        if base_adjust < 0.95:  # reduce roi
            base_numel *= base_adjust
            # cubic ROI with an edge that is a multiple of 16
            r = int(base_numel ** (1 / 3) / 2**4)
            if r == 0:
                log(f"Warning: given output_classes {output_classes}, unable to fit any ROI on the gpu {gpu_mem} Gb!")
            roi_size = np.array([max(1, r) * 2**4] * 3)
            gpu_factor = gpu_factor_init = 1
            auto_scale_roi = False
        else:
            gpu_factor_init = gpu_factor = base_adjust

        log(f"base_numel {base_numel} roi_size {roi_size} gpu_factor {gpu_factor}")

    log(f"input roi {roi_size} image_size {image_size} numel {roi_size.prod()}")
    roi_size = np.minimum(roi_size, image_size)

    # adjust ROI
    max_numel = base_numel * gpu_factor if auto_scale_roi else base_numel
    while roi_size.prod() < max_numel:
        old_numel = roi_size.prod()
        roi_size = np.minimum(roi_size * 1.15, image_size)
        log(f"increasing roi step {roi_size}")
        if roi_size.prod() == old_numel:
            # every axis already spans the full image; nothing left to grow
            break
    log(f"increasing roi result 1 {roi_size}")

    # adjust number of network downsize levels
    if not anisotropic_scales:
        if levels is None:
            levels = np.floor(np.log2(roi_size))
            log(f"levels 1 {levels}")
            levels = min(min(levels), levels_limit)  # limit to 5
            log(f"levels 2' {levels}")

        factor = 2 ** (levels - 1)
        roi_size = factor * np.maximum(2, np.floor(roi_size / factor))
        log(f"roi_size factored {roi_size}")

    else:
        # per-axis offset relative to the axis with the largest spacing ratio
        extra_levels = np.floor(np.log2(np.max(spacing_arr) / spacing_arr))
        extra_levels = max(extra_levels) - extra_levels

        if levels is None:
            # calc levels
            levels = np.floor(np.log2(roi_size))
            log(f"levels 1 aniso {levels} extra_levels {extra_levels}")
            levels = min(min(levels + extra_levels), levels_limit)  # limit to 5
            log(f"levels 2 {levels}")

        factor = 2 ** (np.maximum(1, levels - extra_levels) - 1)
        roi_size = factor * np.maximum(2, np.floor(roi_size / factor))
        log(f"roi_size factored {roi_size} factor {factor} extra_levels {extra_levels}")

    # optionally adjust initial filters (above 32)
    if auto_scale_filters and roi_size.prod() < base_numel * gpu_factor:
        init_filters = int(max(32, np.floor(4 * (base_numel / roi_size.prod())) * 8))
        log(f"checking to increase init_filters {init_filters}")
        gpu_factor_init *= init_filters / 32
        gpu_factor *= init_filters / 32
    else:
        log(f"kept filters the same base_numel {base_numel}, gpu_factor {gpu_factor}")
        init_filters = init_filters_default

    # finally scale batch
    if auto_scale_batch and roi_size.prod() < base_numel * gpu_factor_init:
        batch_size = int(1.1 * gpu_factor_init)
        log(
            f"increased batch_size {batch_size} base_numel {base_numel}, gpu_factor {gpu_factor}, gpu_factor_init {gpu_factor_init}"
        )
    else:
        batch_size = batch_size_default
        log(
            f"kept batch the same base_numel {base_numel}, gpu_factor {gpu_factor}, gpu_factor_init {gpu_factor_init}"
        )

    levels = int(levels)
    roi_size = roi_size.astype(int).tolist()

    log(
        f"Suggested network parameters: \n"
        f"Batch size {batch_size_default} => {batch_size} \n"
        f"ROI size {roi_size_default} => {roi_size} \n"
        f"init_filters {init_filters_default} => {init_filters} \n"
        f"aniso: {anisotropic_scales} image_size_mm: {image_size_mm} spacing: {spacing} levels: {levels} \n"
    )

    return roi_size, levels, init_filters, batch_size
|