# Provenance (residue of the Hugging Face page this file was copied from):
# project-monai upload of pediatric_abdominal_ct_segmentation version 0.4.5,
# commit a0ae4d2 (verified).
---
imports:
- "$import glob"
- "$import json"
- "$import os"
- "$import ignite"
- "$from scipy import ndimage"
- "$import scripts.monai_utils"
- "$import scripts.lr_scheduler"
- "$import monai.apps.deepedit.transforms"
workflow_type: train
# One input channel (CT intensities); one output channel per entry in `labels`.
input_channels: 1
output_channels: 4
# Number of classes used for one-hot encoding in post-processing and for
# label-guided cropping. Referenced from output_channels instead of repeating
# the literal 4: the network's argmax is one-hot encoded to this many classes,
# so the two values must always agree (and should equal len(labels)).
output_classes: "@output_channels"
# arch_ckpt_path: "$@bundle_root + '/models/dynunet_FT.pt'"
# arch_ckpt: "$torch.load(@arch_ckpt_path, map_location=torch.device('cuda'))"
bundle_root: "."
ckpt_dir: "$@bundle_root + '/models'"
output_dir: "$@bundle_root + '/eval'"
# Alternative location: "/workspace/data"
dataset_dir: "/processed/Public/CT_TotalSegmentator/TS_split/test/"
data_list_file_path: "$@bundle_root + '/configs/TS_test.json'"
train_datalist: "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='training')"
val_datalist: "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='validation')"
# CUDA device index (not a GPU count); falls back to CPU when CUDA is unavailable.
n_gpu: 0
device: "$torch.device('cuda:' + str(@n_gpu) if torch.cuda.is_available() else 'cpu')"
# Patch size used for training crops and sliding-window inference.
spatial_size: [96, 96, 96]
spatial_dims: "$len(@spatial_size)"
# Segmentation label ids.
labels:
  background: 0
  liver: 1
  spleen: 2
  pancreas: 3
# DynUNet backbone. kernel_size and strides list one entry per resolution level;
# upsample_kernel_size lists one entry fewer. The deepest level uses [2, 2, 1],
# i.e. it downsamples only the first two axes, matching the coarser third-axis
# spacing (pixdim 1.5, 1.5, 3.0) used in pre-processing.
network_def:
  _target_: monai.networks.nets.DynUNet
  spatial_dims: "@spatial_dims"
  in_channels: "@input_channels"
  out_channels: "@output_channels"
  kernel_size: [3, 3, 3, 3, 3, 3]
  strides: [1, 2, 2, 2, 2, [2, 2, 1]]
  upsample_kernel_size: [2, 2, 2, 2, [2, 2, 1]]
  norm_name: instance
  deep_supervision: false
  res_block: true
# Instantiated network moved to the selected device.
network: "$@network_def.to(@device)"
# Combined Dice + cross-entropy loss.
loss:
  _target_: DiceCELoss
  # Input/target handling: softmax over network outputs, one-hot encode labels.
  softmax: true
  to_onehot_y: true
  # Drop channel 0 (background) from the Dice computation.
  include_background: false
  # Dice-term options.
  squared_pred: true
  batch: true
  smooth_nr: 1.0e-06
  smooth_dr: 1.0e-06
# AdamW over all network parameters.
optimizer:
  _target_: torch.optim.AdamW
  params: "$@network.parameters()"
  lr: 5.0e-05
  weight_decay: 1.0e-05
max_epochs: 15
# Warm-up + cosine schedule implemented in scripts/lr_scheduler.py (not shown
# here); presumably ramps linearly from warmup_start_lr over warmup_epochs,
# then cosine-anneals towards eta_min by max_epochs.
lr_scheduler:
  _target_: scripts.lr_scheduler.LinearWarmupCosineAnnealingLR
  optimizer: "@optimizer"
  max_epochs: "@max_epochs"
  warmup_epochs: 10
  warmup_start_lr: 5.0e-07
  eta_min: 1.0e-08
# Data-dictionary keys used by every transform, engine and metric below.
image_key: "image"
label_key: "label"
# Run the validation evaluator every `val_interval` training epochs.
val_interval: 2
train:
  # Deterministic pre-processing; reused verbatim by validation through
  # "%train#deterministic_transforms".
  deterministic_transforms:
  - _target_: LoadImaged
    keys: ["@image_key", "@label_key"]
    reader: ITKReader
  - _target_: EnsureChannelFirstd
    keys: ["@image_key", "@label_key"]
  - _target_: Orientationd
    keys: ["@image_key", "@label_key"]
    axcodes: RAS
  # Resample; linear interpolation for intensities, nearest for label ids.
  - _target_: Spacingd
    keys: ["@image_key", "@label_key"]
    pixdim: [1.5, 1.5, 3.0]
    mode: [bilinear, nearest]
  - _target_: scripts.monai_utils.AddLabelNamesd
    keys: "@label_key"
    label_names: "@labels"
  # Clip intensities to [-250, 400] and rescale to [0, 1].
  - _target_: ScaleIntensityRanged
    keys: "@image_key"
    a_min: -250
    a_max: 400
    b_min: 0
    b_max: 1
    clip: true
  - _target_: CropForegroundd
    keys: ["@image_key", "@label_key"]
    source_key: "@image_key"
    mode: [minimum, minimum]
  - _target_: EnsureTyped
    keys: ["@image_key", "@label_key"]
  - _target_: CastToTyped
    keys: "@image_key"
    dtype: "$torch.float32"
  # Random augmentation applied after the deterministic stage.
  random_transforms:
  # Draw num_samples patches per volume, centred on voxels of the label
  # classes (ratios: null -> every class sampled with equal weight).
  - _target_: RandCropByLabelClassesd
    keys: ["@image_key", "@label_key"]
    label_key: "@label_key"  # label4crop
    spatial_size: "@spatial_size"
    # Referenced instead of a literal 4 so the crop sampler always agrees with
    # the number of segmentation classes used by post-processing.
    num_classes: "@output_classes"
    ratios: null
    allow_smaller: true
    num_samples: 8
  # - _target_: RandSpatialCropSamplesd
  #   keys:
  #   - "@image_key"
  #   - "@label_key"
  #   roi_size: "$[int(x * 0.75) for x in @spatial_size]"
  #   num_samples: 1
  #   max_roi_size: "@spatial_size"
  #   random_center: true
  #   random_size: true
  #   allow_missing_keys: false
  # Pad crops back up to spatial_size (crops may be smaller: allow_smaller).
  - _target_: SpatialPadd
    keys: ["@image_key", "@label_key"]
    spatial_size: "@spatial_size"
    method: symmetric
    mode: [minimum, minimum]
    allow_missing_keys: false
  - _target_: RandRotate90d
    keys: ["@image_key", "@label_key"]
    prob: 0.5
    max_k: 3
    allow_missing_keys: false
  # - _target_: SelectItemsd
  #   keys:
  #   - "@image_key"
  #   - "@label_key"
  #   - "label_names"
  - _target_: CastToTyped
    keys: ["@image_key", "@label_key"]
    dtype: ["$torch.float32", "$torch.uint8"]
  - _target_: ToTensord
    keys: ["@image_key", "@label_key"]
  preprocessing:
    _target_: Compose
    transforms: "$@train#deterministic_transforms + @train#random_transforms"
  dataset:
    _target_: PersistentDataset
    data: "@train_datalist"
    transform: "@train#preprocessing"
    cache_dir: "$@bundle_root + '/cache'"
  dataloader:
    _target_: DataLoader
    dataset: "@train#dataset"
    batch_size: 1
    shuffle: true
    num_workers: 4
  inferer:
    _target_: SimpleInferer
  # Softmax + argmax + one-hot for predictions, one-hot for labels; then
  # SplitPredsLabeld presumably emits the pred_<organ>/label_<organ> keys read
  # by the per-organ metrics below (verify in scripts/monai_utils.py).
  postprocessing:
    _target_: Compose
    transforms:
    - _target_: Activationsd
      keys: pred
      softmax: true
    - _target_: AsDiscreted
      keys: [pred, label]
      argmax: [true, false]
      to_onehot: ["@output_classes", "@output_classes"]
    - _target_: scripts.monai_utils.SplitPredsLabeld
      keys: pred
  # dice_function:
  #   _target_: "$engine.state.metrics['train_dice']"
  handlers:
  - _target_: LrScheduleHandler
    lr_scheduler: "@lr_scheduler"
    print_lr: true
    # step_transform: "@dice_function"
  - _target_: ValidationHandler
    validator: "@validate#evaluator"
    epoch_level: true
    interval: "@val_interval"
  - _target_: StatsHandler
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
  - _target_: TensorBoardStatsHandler
    log_dir: "@output_dir"
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
  key_metric:
    train_dice:
      _target_: MeanDice
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      include_background: false
  additional_metrics:
    liver_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_liver', 'label_liver'])"
      include_background: false
    spleen_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_spleen', 'label_spleen'])"
      include_background: false
    pancreas_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_pancreas', 'label_pancreas'])"
      include_background: false
  trainer:
    _target_: SupervisedTrainer
    device: "@device"
    max_epochs: "@max_epochs"
    train_data_loader: "@train#dataloader"
    network: "@network"
    loss_function: "@loss"
    # train_interaction: null
    optimizer: "@optimizer"
    inferer: "@train#inferer"
    postprocessing: "@train#postprocessing"
    key_train_metric: "@train#key_metric"
    additional_metrics: "@train#additional_metrics"
    train_handlers: "@train#handlers"
    amp: true
validate:
  # Validation applies only the deterministic half of the training pipeline.
  preprocessing:
    _target_: Compose
    transforms: "%train#deterministic_transforms"
  dataset:
    # Disabled alternative (in-memory cache):
    # _target_: CacheDataset
    # data: "@val_datalist"
    # transform: "@validate#preprocessing"
    # cache_rate: 0.025
    _target_: PersistentDataset
    data: "@val_datalist"
    transform: "@validate#preprocessing"
    cache_dir: "$@bundle_root + '/cache'"
  dataloader:
    _target_: DataLoader
    dataset: "@validate#dataset"
    batch_size: 1
    shuffle: false
    num_workers: 4
  # Whole-volume inference by tiling spatial_size windows with 50% overlap.
  inferer:
    _target_: SlidingWindowInferer
    roi_size: "@spatial_size"
    sw_batch_size: 4
    mode: constant
    overlap: 0.5
  postprocessing: "%train#postprocessing"
  handlers:
  - _target_: StatsHandler
    iteration_log: false
  - _target_: TensorBoardStatsHandler
    log_dir: "@output_dir"
    iteration_log: false
  # Saves the model whenever the evaluator's key metric (val_dice) improves.
  # NOTE(review): despite its name, model_latest.pt therefore holds the
  # best-val_dice weights, not necessarily the last validated epoch.
  - _target_: CheckpointSaver
    save_dir: "@ckpt_dir"
    save_dict:
      model: "@network"
    save_key_metric: true
    key_metric_filename: model_latest.pt
  key_metric:
    val_dice:
      _target_: MeanDice
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      include_background: false
  # Per-organ validation metrics. NOTE(review): defined but not currently
  # passed to the evaluator (its additional_metrics entry is commented out).
  additional_metrics:
    val_liver_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_liver', 'label_liver'])"
      include_background: false
    val_spleen_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_spleen', 'label_spleen'])"
      include_background: false
    val_pancreas_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_pancreas', 'label_pancreas'])"
      include_background: false
  evaluator:
    _target_: SupervisedEvaluator
    device: "@device"
    val_data_loader: "@validate#dataloader"
    network: "@network"
    inferer: "@validate#inferer"
    postprocessing: "@validate#postprocessing"
    key_val_metric: "@validate#key_metric"
    # additional_metrics: "@validate#additional_metrics"
    val_handlers: "@validate#handlers"
    amp: true
# Seed all random number generators before the workflow runs.
initialize: ["$monai.utils.set_determinism(seed=123)"]
# Print a short configuration summary, then start the training engine.
run:
- "$print('Training started... ')"
- "$print('output_channels: ', @output_channels )"
- "$print('spatial_dims: ', @spatial_dims)"
- "$print('Labels dict: ', @labels)"
- "$@train#trainer.run()"