# mednist_ddpm/configs/train.yaml
# Source: project-monai, "Upload mednist_ddpm version 1.0.1" (57decc6, verified)
# This defines the training script for the network
imports:
- $import os
- $import datetime
- $import torch
- $import scripts
- $import monai
- $import torch.distributed as dist
- $import operator
# Common elements to all training files
-
# Dictionary keys taken from monai.utils.CommonKeys; referenced by the
# transforms and metric handlers further down.
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED

# Distributed-run detection. `rank` falls back to 0 when torch.distributed
# has not been initialised.
is_dist: $dist.is_initialized()
rank: $dist.get_rank() if @is_dist else 0
# True on every process whose rank is above 0; used as a `_disabled_` switch
# for handlers that should only run on rank 0.
is_not_rank0: "$@rank > 0"

# CUDA device indexed by `rank` when CUDA is available, otherwise the CPU.
# NOTE(review): the index is the value of dist.get_rank(); confirm this is
# what is wanted for multi-node launches.
device: >-
  $torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")
# Definition of the diffusion UNet: 2 spatial dimensions, one input and one
# output channel, three levels with attention switched on for the last two.
network_def:
  _target_: monai.networks.nets.DiffusionModelUNet
  spatial_dims: 2
  in_channels: 1
  out_channels: 1
  channels:
  - 64
  - 128
  - 128
  attention_levels:
  - false
  - true
  - true
  num_res_blocks: 1
  num_head_channels: 128
# The network object moved onto `device`; this is what the optimizer,
# trainer, evaluator and checkpoint handlers reference.
network: $@network_def.to(@device)
# Root directory of the bundle; `ckpt_path` is built by appending to it.
bundle_root: "."
ckpt_path: $@bundle_root + '/models/model.pt'
# Passed as `amp` to both the trainer and the evaluator.
use_amp: true
# Image side length; `image_size` is [1, image_dim, image_dim].
image_dim: 64
image_size:
- 1
- "@image_dim"
- "@image_dim"
# Shared by the scheduler and by the batch-preparation object.
num_train_timesteps: 1000
# Transforms applied to both training and validation data: load the image,
# put the channel dimension first, then map intensities from [0, 255] to
# [0, 1] with clipping enabled.
base_transforms:
- _target_: LoadImaged
  keys: "@image"
  image_only: true
- _target_: EnsureChannelFirstd
  keys: "@image"
- _target_: ScaleIntensityRanged
  keys: "@image"
  a_min: 0.0
  a_max: 255.0
  b_min: 0.0
  b_max: 1.0
  clip: true
# DDPM noise scheduler configured with the shared number of timesteps.
scheduler:
  _target_: monai.networks.schedulers.DDPMScheduler
  num_train_timesteps: "@num_train_timesteps"
# Inferer built around the scheduler above; used by trainer and evaluator.
inferer:
  _target_: monai.inferers.DiffusionInferer
  scheduler: "@scheduler"
# Training-specific settings

# A new, timestamped results directory is chosen for every run
# (format: ./results/output_YYMMDD_HHMMSS).
output_dir: $datetime.datetime.now().strftime('./results/output_%y%m%d_%H%M%S')
# Root directory handed to the dataset objects as `root_dir`.
dataset_dir: ./data
# MedNIST dataset objects for the training and validation sections. Both use
# the same root directory and seed, with downloading enabled and the progress
# display turned off.
train_data:
  _target_: MedNISTDataset
  root_dir: "@dataset_dir"
  section: training
  download: true
  progress: false
  seed: 0
val_data:
  _target_: MedNISTDataset
  root_dir: "@dataset_dir"
  section: validation
  download: true
  progress: false
  seed: 0
# Keep only the items whose class_name is "Hand", reduced to their image entry.
train_datalist: >-
  $[{"image": item["image"]} for item in @train_data.data
  if item["class_name"] == "Hand"]
val_datalist: >-
  $[{"image": item["image"]} for item in @val_data.data
  if item["class_name"] == "Hand"]
# Data-loading parameters.
batch_size: 8
# Passed to the training loader as `repeats`.
num_substeps: 1
num_workers: 4
use_thread_workers: false

# Optimisation and schedule parameters.
lr: 0.000025
# Probability of applying the random training augmentation.
rand_prob: 0.5
num_epochs: 75
# Validation and checkpointing intervals (both handlers are epoch-level).
val_interval: 5
save_interval: 5
# Training-only augmentation, appended to `base_transforms` by `train_ds`.
# A random affine transform (small rotation, translation and scaling) is
# applied with probability `rand_prob`.
train_transforms:
- _target_: RandAffined
  keys: '@image'
  # +/- pi/36 radians (5 degrees)
  rotate_range:
  - ['$-np.pi / 36', '$np.pi / 36']
  - ['$-np.pi / 36', '$np.pi / 36']
  translate_range:
  - [-1, 1]
  - [-1, 1]
  scale_range:
  - [-0.05, 0.05]
  - [-0.05, 0.05]
  # Follows `image_dim` instead of a hard-coded 64, so overriding the image
  # size in one place keeps the augmentation output size in step.
  spatial_size: ['@image_dim', '@image_dim']
  padding_mode: "zeros"
  prob: '@rand_prob'
# Training dataset: the filtered data list run through the base transforms
# followed by the training augmentations.
train_ds:
  _target_: Dataset
  data: $@train_datalist
  transform:
    _target_: Compose
    transforms: "$@base_transforms + @train_transforms"
# Shuffled training loader; `repeats` is driven by `num_substeps`, and
# persistent workers are requested only when `num_workers` is above 0.
train_loader:
  _target_: ThreadDataLoader
  dataset: "@train_ds"
  batch_size: "@batch_size"
  repeats: "@num_substeps"
  num_workers: "@num_workers"
  use_thread_workers: "@use_thread_workers"
  persistent_workers: "$@num_workers > 0"
  shuffle: true

# Validation dataset: the filtered data list with the base transforms only.
val_ds:
  _target_: Dataset
  data: $@val_datalist
  transform:
    _target_: Compose
    transforms: "@base_transforms"
# Unshuffled validation loader with the same batch size and worker count.
val_loader:
  _target_: DataLoader
  dataset: "@val_ds"
  batch_size: "@batch_size"
  num_workers: "@num_workers"
  persistent_workers: "$@num_workers > 0"
  shuffle: false
# Mean-squared-error loss used by the trainer.
lossfn:
  _target_: torch.nn.MSELoss
# Adam over all network parameters with learning rate `lr`.
optimizer:
  _target_: torch.optim.Adam
  params: $@network.parameters()
  lr: "@lr"
# Batch-preparation object shared by trainer and evaluator; configured with
# the same timestep count as the scheduler.
prepare_batch:
  _target_: monai.engines.DiffusionPrepareBatch
  num_train_timesteps: "@num_train_timesteps"
# Handlers attached to the evaluator. The stats handler is disabled on every
# rank other than 0 and its output transform always returns None.
val_handlers:
- _target_: StatsHandler
  name: train_log
  output_transform: "$lambda x: None"
  _disabled_: "@is_not_rank0"
# Evaluator run over the validation loader, sharing network, inferer and
# batch preparation with the trainer. The key metric is mean absolute error
# between prediction and label, compared with `operator.lt`.
evaluator:
  _target_: SupervisedEvaluator
  device: "@device"
  val_data_loader: "@val_loader"
  network: "@network"
  amp: "@use_amp"
  inferer: "@inferer"
  prepare_batch: "@prepare_batch"
  key_val_metric:
    val_mean_abs_error:
      _target_: MeanAbsoluteError
      output_transform: '$monai.handlers.from_engine([@pred, @label])'
  metric_cmp_fn: $operator.lt
  # filter(bool, ...) drops falsy entries from the handler list; presumably
  # disabled handlers resolve to None (confirm against the bundle parser).
  val_handlers: "$list(filter(bool, @val_handlers))"
# Handlers attached to the trainer.
handlers:
# Load network weights from `ckpt_path`; disabled when that file is missing.
- _target_: CheckpointLoader
  _disabled_: '$not os.path.exists(@ckpt_path)'
  load_path: "@ckpt_path"
  load_dict:
    model: "@network"
# Run the evaluator every `val_interval` epochs.
- _target_: ValidationHandler
  validator: "@evaluator"
  epoch_level: true
  interval: "@val_interval"
# Save the network into `output_dir` every `save_interval` epochs and once
# more at the end; disabled on every rank other than 0.
- _target_: CheckpointSaver
  save_dir: "@output_dir"
  save_dict:
    model: "@network"
  save_interval: "@save_interval"
  save_final: true
  epoch_level: true
  _disabled_: "@is_not_rank0"
# Trainer wiring together the loader, network, loss, optimizer, inferer and
# batch preparation defined above, for `num_epochs` epochs.
trainer:
  _target_: SupervisedTrainer
  max_epochs: "@num_epochs"
  device: "@device"
  train_data_loader: "@train_loader"
  network: "@network"
  loss_function: "@lossfn"
  optimizer: "@optimizer"
  inferer: "@inferer"
  prepare_batch: "@prepare_batch"
  # NOTE(review): the metric is registered under the name `train_acc` but is
  # a MeanSquaredError; confirm the name before relying on it downstream.
  key_train_metric:
    train_acc:
      _target_: MeanSquaredError
      output_transform: '$monai.handlers.from_engine([@pred, @label])'
  metric_cmp_fn: $operator.lt
  # filter(bool, ...) drops falsy entries from the handler list; presumably
  # disabled handlers resolve to None (confirm against the bundle parser).
  train_handlers: "$list(filter(bool, @handlers))"
  amp: "@use_amp"
# Entry point: call monai.utils.set_determinism(0), then run the trainer.
training:
- $monai.utils.set_determinism(0)
- $@trainer.run()