# Training configuration for the diffusion network.
# Values starting with '$' are evaluated as Python expressions and '@name'
# refers to another entry of this file (MONAI bundle config syntax).

imports:
  - '$import os'
  - '$import datetime'
  - '$import torch'
  - '$import scripts'  # NOTE(review): presumably the bundle's own scripts package - confirm
  - '$import monai'
  - '$import torch.distributed as dist'
  - '$import operator'

# ---------------------------------------------------------------------------
# Common elements to all training files
# ---------------------------------------------------------------------------

# Dictionary key names, taken from monai.utils.CommonKeys.
image: '$monai.utils.CommonKeys.IMAGE'
label: '$monai.utils.CommonKeys.LABEL'
pred: '$monai.utils.CommonKeys.PRED'

# Distributed state: rank falls back to 0 when torch.distributed is not
# initialised; is_not_rank0 is used below to disable handlers on ranks > 0.
is_dist: '$dist.is_initialized()'
rank: '$dist.get_rank() if @is_dist else 0'
is_not_rank0: '$@rank > 0'
# Selects cuda:<rank> when CUDA is available, otherwise the CPU.
device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")'

# Network definition; `network` is the same object moved to `device`.
network_def:
  _target_: monai.networks.nets.DiffusionModelUNet
  spatial_dims: 2
  in_channels: 1
  out_channels: 1
  channels:
    - 64
    - 128
    - 128
  attention_levels:
    - false
    - true
    - true
  num_res_blocks: 1
  num_head_channels: 128

network: '$@network_def.to(@device)'

bundle_root: '.'
# Checkpoint restored by the CheckpointLoader handler when this file exists.
ckpt_path: "$@bundle_root + '/models/model.pt'"
use_amp: true
image_dim: 64
image_size:
  - 1
  - '@image_dim'
  - '@image_dim'
num_train_timesteps: 1000

# Transforms shared by the training and validation datasets: load the image,
# make it channel-first, then rescale intensities from [0, 255] to [0, 1]
# with clipping.
base_transforms:
  - _target_: LoadImaged
    keys: '@image'
    image_only: true
  - _target_: EnsureChannelFirstd
    keys: '@image'
  - _target_: ScaleIntensityRanged
    keys: '@image'
    a_min: 0.0
    a_max: 255.0
    b_min: 0.0
    b_max: 1.0
    clip: true

scheduler:
  _target_: monai.networks.schedulers.DDPMScheduler
  num_train_timesteps: '@num_train_timesteps'

inferer:
  _target_: monai.inferers.DiffusionInferer
  scheduler: '@scheduler'

# ---------------------------------------------------------------------------
# Training-specific
# ---------------------------------------------------------------------------

# A timestamped directory name, so every run gets a new output directory.
output_dir: "$datetime.datetime.now().strftime('./results/output_%y%m%d_%H%M%S')"
dataset_dir: ./data

# MedNIST training / validation sections; both use download: true and seed 0.
train_data:
  _target_: MedNISTDataset
  root_dir: '@dataset_dir'
  section: training
  download: true
  progress: false
  seed: 0

val_data:
  _target_: MedNISTDataset
  root_dir: '@dataset_dir'
  section: validation
  download: true
  progress: false
  seed: 0

# Keep only items whose class_name is "Hand", reduced to their "image" entry.
train_datalist: '$[{"image": item["image"]} for item in @train_data.data if item["class_name"] == "Hand"]'
val_datalist: '$[{"image": item["image"]} for item in @val_data.data if item["class_name"] == "Hand"]'

# Data loading parameters.
batch_size: 8
num_substeps: 1  # passed to the training loader as `repeats`
num_workers: 4
use_thread_workers: false

# Optimisation and schedule parameters.
lr: 0.000025
rand_prob: 0.5  # probability given to the random affine augmentation
num_epochs: 75
val_interval: 5
save_interval: 5

# Augmentation appended to base_transforms for the training dataset only.
train_transforms:
  - _target_: RandAffined
    keys: '@image'
    rotate_range:
      - ['$-np.pi / 36', '$np.pi / 36']
      - ['$-np.pi / 36', '$np.pi / 36']
    translate_range:
      - [-1, 1]
      - [-1, 1]
    scale_range:
      - [-0.05, 0.05]
      - [-0.05, 0.05]
    spatial_size:
      - 64
      - 64
    padding_mode: zeros
    prob: '@rand_prob'

train_ds:
  _target_: Dataset
  data: '$@train_datalist'
  transform:
    _target_: Compose
    transforms: '$@base_transforms + @train_transforms'

train_loader:
  _target_: ThreadDataLoader
  dataset: '@train_ds'
  batch_size: '@batch_size'
  repeats: '@num_substeps'
  num_workers: '@num_workers'
  use_thread_workers: '@use_thread_workers'
  persistent_workers: '$@num_workers > 0'
  shuffle: true

val_ds:
  _target_: Dataset
  data: '$@val_datalist'
  transform:
    _target_: Compose
    transforms: '@base_transforms'

val_loader:
  _target_: DataLoader
  dataset: '@val_ds'
  batch_size: '@batch_size'
  num_workers: '@num_workers'
  persistent_workers: '$@num_workers > 0'
  shuffle: false

lossfn:
  _target_: torch.nn.MSELoss

optimizer:
  _target_: torch.optim.Adam
  params: '$@network.parameters()'
  lr: '@lr'

# Batch preparation shared by the trainer and the evaluator.
prepare_batch:
  _target_: monai.engines.DiffusionPrepareBatch
  num_train_timesteps: '@num_train_timesteps'

# Validation handlers; the stats handler is disabled on ranks other than 0 and
# its output_transform always returns None.
val_handlers:
  - _target_: StatsHandler
    name: train_log
    output_transform: '$lambda x: None'
    _disabled_: '@is_not_rank0'

evaluator:
  _target_: SupervisedEvaluator
  device: '@device'
  val_data_loader: '@val_loader'
  network: '@network'
  amp: '@use_amp'
  inferer: '@inferer'
  prepare_batch: '@prepare_batch'
  key_val_metric:
    val_mean_abs_error:
      _target_: MeanAbsoluteError
      output_transform: '$monai.handlers.from_engine([@pred, @label])'
  # NOTE(review): operator.lt presumably makes lower metric values count as
  # better - confirm against the engine documentation.
  metric_cmp_fn: '$operator.lt'
  # filter(bool, ...) drops falsy entries from the handler list
  # (presumably the ones switched off through _disabled_ - confirm).
  val_handlers: '$list(filter(bool, @val_handlers))'

# Training handlers: optional checkpoint restore, periodic validation and
# periodic checkpoint saving (saving is disabled on ranks other than 0).
handlers:
  - _target_: CheckpointLoader
    _disabled_: '$not os.path.exists(@ckpt_path)'  # skipped when no checkpoint file exists
    load_path: '@ckpt_path'
    load_dict:
      model: '@network'
  - _target_: ValidationHandler
    validator: '@evaluator'
    epoch_level: true
    interval: '@val_interval'
  - _target_: CheckpointSaver
    save_dir: '@output_dir'
    save_dict:
      model: '@network'
    save_interval: '@save_interval'
    save_final: true
    epoch_level: true
    _disabled_: '@is_not_rank0'

trainer:
  _target_: SupervisedTrainer
  max_epochs: '@num_epochs'
  device: '@device'
  train_data_loader: '@train_loader'
  network: '@network'
  loss_function: '@lossfn'
  optimizer: '@optimizer'
  inferer: '@inferer'
  prepare_batch: '@prepare_batch'
  key_train_metric:
    # The key is named train_acc although the metric is a mean squared error.
    train_acc:
      _target_: MeanSquaredError
      output_transform: '$monai.handlers.from_engine([@pred, @label])'
  metric_cmp_fn: '$operator.lt'
  train_handlers: '$list(filter(bool, @handlers))'
  amp: '@use_amp'

# Entry point: fix the random seed to 0, then run the trainer.
training:
  - '$monai.utils.set_determinism(0)'
  - '$@trainer.run()'