|
|
---
# Import statements for this bundle config.
# `scripts.monai_utils` and `scripts.lr_scheduler` are referenced by
# `_target_` entries elsewhere in this file; the remaining modules are kept
# as-is because other config files or overrides may rely on them.
imports:
  - "$import glob"
  - "$import json"
  - "$import os"
  - "$import ignite"
  - "$from scipy import ndimage"
  - "$import scripts.monai_utils"
  - "$import scripts.lr_scheduler"
  - "$import monai.apps.deepedit.transforms"
|
|
# Workflow kind plus the channel/class counts.
# `input_channels` and `output_channels` are consumed by `network_def`;
# `output_classes` is consumed by the `to_onehot` post-processing setting.
workflow_type: train
input_channels: 1
output_channels: 4
output_classes: 4
|
|
|
|
|
|
|
|
# Filesystem layout: everything below is derived from `bundle_root`.
bundle_root: "."
ckpt_dir: "$@bundle_root + '/models'"
output_dir: "$@bundle_root + '/eval'"

# NOTE(review): `dataset_dir` is not referenced by any other entry in the
# visible config (the datalists are loaded without a base directory) --
# confirm whether it is still needed.
dataset_dir: "/processed/Public/CT_TotalSegmentator/TS_split/test/"

# Decathlon-style datalist file; its 'training' and 'validation' sections
# feed the train and validate datasets respectively.
data_list_file_path: "$@bundle_root + '/configs/TS_test.json'"
train_datalist: "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='training')"
val_datalist: "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='validation')"

# `n_gpu` is the CUDA device index used when CUDA is available;
# otherwise the expression below falls back to the CPU.
n_gpu: 0
device: "$torch.device('cuda:' + str(@n_gpu) if torch.cuda.is_available() else 'cpu')"
|
|
# Patch size used for random cropping, padding and the sliding-window ROI.
spatial_size: [96, 96, 96]
# Dimensionality is derived from the patch size rather than stated twice.
spatial_dims: "$len(@spatial_size)"

# Label name -> integer value in the segmentation masks.
labels:
  background: 0
  liver: 1
  spleen: 2
  pancreas: 3
|
|
# DynUNet definition.
# `kernel_size` and `strides` each have six entries (one per level);
# `upsample_kernel_size` has five entries that mirror `strides[1:]`.
# The last level uses per-axis values [2, 2, 1].
network_def:
  _target_: monai.networks.nets.DynUNet
  spatial_dims: "@spatial_dims"
  in_channels: "@input_channels"
  out_channels: "@output_channels"
  kernel_size: [3, 3, 3, 3, 3, 3]
  strides: [1, 2, 2, 2, 2, [2, 2, 1]]
  upsample_kernel_size: [2, 2, 2, 2, [2, 2, 1]]
  norm_name: "instance"
  deep_supervision: false
  res_block: true

# The instantiated network moved onto the selected device.
network: "$@network_def.to(@device)"
|
|
# Training loss. Labels are one-hot encoded by the loss itself
# (`to_onehot_y`) and softmax is applied to the raw network output;
# the background channel is excluded.
loss:
  _target_: DiceCELoss
  include_background: false
  to_onehot_y: true
  softmax: true
  squared_pred: true
  batch: true
  smooth_nr: 1.0e-06
  smooth_dr: 1.0e-06
|
|
# AdamW over all network parameters.
optimizer:
  _target_: torch.optim.AdamW
  params: "$@network.parameters()"
  weight_decay: 1.0e-05
  lr: 5.0e-05

# Total number of training epochs; shared by the trainer and the scheduler.
max_epochs: 15

# Project-local scheduler (scripts/lr_scheduler.py).
# Note that 10 of the 15 epochs are configured as warm-up.
lr_scheduler:
  _target_: scripts.lr_scheduler.LinearWarmupCosineAnnealingLR
  optimizer: "@optimizer"
  warmup_epochs: 10
  warmup_start_lr: 5.0e-07
  eta_min: 1.0e-08
  max_epochs: "@max_epochs"
|
|
# Dictionary keys used by the transforms for the image and its label map.
image_key: image
label_key: label

# Interval passed to the ValidationHandler (epoch-level).
val_interval: 2
|
|
# Training pipeline: transforms -> dataset -> dataloader -> trainer.
train:
  # Transforms without randomness. They are also reused verbatim by
  # `validate#preprocessing` through the `%train#deterministic_transforms`
  # macro, so changes here affect validation as well.
  deterministic_transforms:
    - _target_: LoadImaged
      keys: ["@image_key", "@label_key"]
      reader: ITKReader
    - _target_: EnsureChannelFirstd
      keys: ["@image_key", "@label_key"]
    - _target_: Orientationd
      keys: ["@image_key", "@label_key"]
      axcodes: RAS
    - _target_: Spacingd
      keys: ["@image_key", "@label_key"]
      pixdim: [1.5, 1.5, 3.0]
      # One interpolation mode per key: image, then label.
      mode: [bilinear, nearest]
    - _target_: scripts.monai_utils.AddLabelNamesd
      keys: "@label_key"
      label_names: "@labels"
    - _target_: ScaleIntensityRanged
      keys: "@image_key"
      a_min: -250
      a_max: 400
      b_min: 0
      b_max: 1
      clip: true
    - _target_: CropForegroundd
      keys: ["@image_key", "@label_key"]
      source_key: "@image_key"
      mode: ["minimum", "minimum"]
    - _target_: EnsureTyped
      keys: ["@image_key", "@label_key"]
    - _target_: CastToTyped
      keys: "@image_key"
      dtype: "$torch.float32"

  # Random augmentation; appended after the deterministic transforms.
  random_transforms:
    - _target_: RandCropByLabelClassesd
      keys: ["@image_key", "@label_key"]
      label_key: "@label_key"
      spatial_size: "@spatial_size"
      # Taken from the shared setting instead of a literal so the crop
      # sampler cannot drift from the one-hot class count used in
      # post-processing when `output_classes` is overridden.
      num_classes: "@output_classes"
      ratios: null
      allow_smaller: true
      num_samples: 8
    # `allow_smaller` above can yield patches smaller than `spatial_size`;
    # pad them back up to the full patch size.
    - _target_: SpatialPadd
      keys: ["@image_key", "@label_key"]
      spatial_size: "@spatial_size"
      method: "symmetric"
      mode: ["minimum", "minimum"]
      allow_missing_keys: false
    - _target_: RandRotate90d
      keys: ["@image_key", "@label_key"]
      prob: 0.5
      max_k: 3
      allow_missing_keys: false
    - _target_: CastToTyped
      keys: ["@image_key", "@label_key"]
      # One dtype per key: image, then label.
      dtype: ["$torch.float32", "$torch.uint8"]
    - _target_: ToTensord
      keys: ["@image_key", "@label_key"]

  preprocessing:
    _target_: Compose
    transforms: "$@train#deterministic_transforms + @train#random_transforms"

  dataset:
    _target_: PersistentDataset
    data: "@train_datalist"
    transform: "@train#preprocessing"
    # Same cache directory as the validation dataset.
    cache_dir: "$@bundle_root + '/cache'"

  dataloader:
    _target_: DataLoader
    dataset: "@train#dataset"
    batch_size: 1
    shuffle: true
    num_workers: 4

  inferer:
    _target_: SimpleInferer

  # Also reused by validation through the `%train#postprocessing` macro.
  postprocessing:
    _target_: Compose
    transforms:
      - _target_: Activationsd
        keys: pred
        softmax: true
      - _target_: AsDiscreted
        keys: [pred, label]
        # Per key (pred, label): only the prediction is arg-maxed,
        # both are one-hot encoded.
        argmax: [true, false]
        to_onehot: ["@output_classes", "@output_classes"]
      # Project-local transform; the per-organ metrics below read the
      # `pred_<name>` / `label_<name>` keys, presumably produced here --
      # verify against scripts/monai_utils.py.
      - _target_: scripts.monai_utils.SplitPredsLabeld
        keys: pred

  handlers:
    - _target_: LrScheduleHandler
      lr_scheduler: "@lr_scheduler"
      print_lr: true
    - _target_: ValidationHandler
      validator: "@validate#evaluator"
      epoch_level: true
      interval: "@val_interval"
    - _target_: StatsHandler
      tag_name: train_loss
      output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
    - _target_: TensorBoardStatsHandler
      log_dir: "@output_dir"
      tag_name: train_loss
      output_transform: "$monai.handlers.from_engine(['loss'], first=True)"

  key_metric:
    train_dice:
      _target_: MeanDice
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      include_background: false

  # Per-organ Dice, computed on the split prediction/label keys.
  additional_metrics:
    liver_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_liver', 'label_liver'])"
      include_background: false
    spleen_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_spleen', 'label_spleen'])"
      include_background: false
    pancreas_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_pancreas', 'label_pancreas'])"
      include_background: false

  trainer:
    _target_: SupervisedTrainer
    device: "@device"
    max_epochs: "@max_epochs"
    train_data_loader: "@train#dataloader"
    network: "@network"
    loss_function: "@loss"
    optimizer: "@optimizer"
    inferer: "@train#inferer"
    postprocessing: "@train#postprocessing"
    key_train_metric: "@train#key_metric"
    additional_metrics: "@train#additional_metrics"
    train_handlers: "@train#handlers"
    amp: true
|
|
# Validation pipeline, triggered from training by the ValidationHandler.
validate:
  # Reuses the training deterministic transforms (no random augmentation).
  preprocessing:
    _target_: Compose
    transforms: "%train#deterministic_transforms"

  dataset:
    _target_: PersistentDataset
    data: "@val_datalist"
    transform: "@validate#preprocessing"
    # Same cache directory as the training dataset.
    cache_dir: "$@bundle_root + '/cache'"

  dataloader:
    _target_: DataLoader
    dataset: "@validate#dataset"
    batch_size: 1
    shuffle: false
    num_workers: 4

  # Validation inputs are not cropped to `spatial_size`, so inference runs
  # as overlapping windows of that size.
  inferer:
    _target_: SlidingWindowInferer
    roi_size: "@spatial_size"
    sw_batch_size: 4
    mode: "constant"
    overlap: 0.5

  postprocessing: "%train#postprocessing"

  handlers:
    - _target_: StatsHandler
      iteration_log: false
    - _target_: TensorBoardStatsHandler
      log_dir: "@output_dir"
      iteration_log: false
    # Saves the network under `ckpt_dir` based on the key validation metric.
    - _target_: CheckpointSaver
      save_dir: "@ckpt_dir"
      save_dict:
        model: "@network"
      save_key_metric: true
      key_metric_filename: model_latest.pt

  key_metric:
    val_dice:
      _target_: MeanDice
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      include_background: false

  # NOTE(review): these per-organ metrics are defined but are not passed to
  # `validate#evaluator` in this config -- confirm whether that is
  # intentional before relying on them in logs.
  additional_metrics:
    val_liver_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_liver', 'label_liver'])"
      include_background: false
    val_spleen_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_spleen', 'label_spleen'])"
      include_background: false
    val_pancreas_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_pancreas', 'label_pancreas'])"
      include_background: false

  evaluator:
    _target_: SupervisedEvaluator
    device: "@device"
    val_data_loader: "@validate#dataloader"
    network: "@network"
    inferer: "@validate#inferer"
    postprocessing: "@validate#postprocessing"
    key_val_metric: "@validate#key_metric"
    val_handlers: "@validate#handlers"
    amp: true
|
|
# Expressions listed under `initialize`: fix the random seed.
initialize:
  - "$monai.utils.set_determinism(seed=123)"

# Expressions listed under `run`: print a short summary of the resolved
# settings, then start the trainer.
run:
  - "$print('Training started... ')"
  - "$print('output_channels: ', @output_channels )"
  - "$print('spatial_dims: ', @spatial_dims)"
  - "$print('Labels dict: ', @labels)"
  - "$@train#trainer.run()"
|
|
|