# Import statements evaluated by the MONAI config parser so that these names
# are usable inside `$`-expressions elsewhere in the config.
imports:
  - $import os
  - $import datetime
  - $import torch
  - $import scripts
  - $import monai
  - $import torch.distributed as dist
  - $import operator
# Data dictionary keys, taken from MONAI's CommonKeys constants.
image: "$monai.utils.CommonKeys.IMAGE"
label: "$monai.utils.CommonKeys.LABEL"
pred: "$monai.utils.CommonKeys.PRED"
# Distributed-run helpers. `is_dist` (and therefore `rank`) only reports the
# real process group once `dist.init_process_group` has been called.
is_dist: '$dist.is_initialized()'
rank: '$dist.get_rank() if @is_dist else 0'
is_not_rank0: '$@rank > 0'
# Select the GPU by the node-local rank. torchrun exports LOCAL_RANK before the
# process starts, so this does not depend on when the process group is set up;
# the global rank would exceed the per-node GPU count on multi-node runs and is
# only used as a fallback when LOCAL_RANK is absent.
device: '$torch.device("cuda", int(os.environ.get("LOCAL_RANK", @rank))) if torch.cuda.is_available() else torch.device("cpu")'
# Single-channel 2D DiffusionModelUNet; the `network` entry moves it to
# `@device` and wraps it for distributed training.
network_def:
  _target_: monai.networks.nets.DiffusionModelUNet
  spatial_dims: 2
  in_channels: 1
  out_channels: 1
  channels:
    - 64
    - 128
    - 128
  # one flag per resolution level in `channels`
  attention_levels:
    - false
    - true
    - true
  num_res_blocks: 1
  num_head_channels: 128
# Shared preprocessing for the image key: load it, move channels first, then
# linearly rescale intensities from [0, 255] to [0, 1], clipping out-of-range values.
base_transforms:
  - {_target_: LoadImaged, keys: "@image", image_only: true}
  - {_target_: EnsureChannelFirstd, keys: "@image"}
  - {_target_: ScaleIntensityRanged, keys: "@image", a_min: 0.0, a_max: 255.0, b_min: 0.0, b_max: 1.0, clip: true}
# DDPM noise scheduler. `num_train_timesteps` is not defined in this section;
# presumably it comes from a config merged with this one — confirm.
scheduler:
  _target_: monai.networks.schedulers.DDPMScheduler
  num_train_timesteps: "@num_train_timesteps"
# Diffusion inferer driven by the scheduler defined above.
inferer: {_target_: monai.inferers.DiffusionInferer, scheduler: "@scheduler"}
# DistributedDataParallel wrapper around the UNet, placed on this process's device.
# find_unused_parameters lets DDP tolerate parameters that get no gradient in a step.
network:
  _target_: torch.nn.parallel.DistributedDataParallel
  module: "$@network_def.to(@device)"
  device_ids:
    - "@device"
  find_unused_parameters: true
# Shard the training dataset across ranks. even_divisible pads the shards so every
# rank sees the same number of samples; shuffling is done by the sampler.
tsampler:
  _target_: DistributedSampler
  dataset: "@train_ds"
  even_divisible: true
  shuffle: true
# Overrides for `train_loader`, which is presumably defined in a config merged with this one.
train_loader#sampler: "@tsampler"
train_loader#shuffle: false  # a DataLoader cannot combine shuffle=True with a sampler
# Shard the validation dataset across ranks without padding or shuffling, so no
# sample is duplicated (shards may differ in size by one).
vsampler:
  _target_: DistributedSampler
  dataset: "@val_ds"
  even_divisible: false
  shuffle: false
# Override for `val_loader`, which is presumably defined in a config merged with this one.
val_loader#sampler: "@vsampler"
# Training entry point, evaluated in list order: create the NCCL process group,
# bind this process to its device, seed RNGs, run the trainer, then tear the
# process group down.
training:
  - $import torch.distributed as dist
  - $dist.init_process_group(backend='nccl')
  - $torch.cuda.set_device(@device)
  # no trailing comma: it would turn this expression into a 1-tuple
  - $monai.utils.set_determinism(seed=123)
  - $@trainer.run()
  - $dist.destroy_process_group()