# This can be mixed in with the training script to enable multi-GPU training

# Modules made available to the `$` expressions below.
imports:
- $import os
- $import datetime
- $import torch
- $import scripts
- $import monai
- $import torch.distributed as dist
- $import operator

# Common elements to all training files

# Dictionary keys used for the image, label and prediction entries.
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED

# Distributed state: rank falls back to 0 when no process group is initialized.
is_dist: '$dist.is_initialized()'
rank: '$dist.get_rank() if @is_dist else 0'
is_not_rank0: '$@rank > 0'

# One CUDA device per rank when CUDA is available, otherwise the CPU.
# NOTE(review): the device index is the value of `rank` as returned by
# dist.get_rank(); confirm this matches the intended GPU index when launching
# across more than one node.
device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")'

# Bare network definition; it is wrapped for distributed training by `network`.
network_def:
  _target_: monai.networks.nets.DiffusionModelUNet
  spatial_dims: 2
  in_channels: 1
  out_channels: 1
  channels: [64, 128, 128]
  attention_levels: [false, true, true]
  num_res_blocks: 1
  num_head_channels: 128

# Transforms applied to the `@image` key: load, move channel first, then map
# intensities from [0, 255] to [0, 1] with clipping.
base_transforms:
- _target_: LoadImaged
  keys: '@image'
  image_only: true
- _target_: EnsureChannelFirstd
  keys: '@image'
- _target_: ScaleIntensityRanged
  keys: '@image'
  a_min: 0.0
  a_max: 255.0
  b_min: 0.0
  b_max: 1.0
  clip: true

# `num_train_timesteps` is not defined in this file; it is expected to come
# from the configuration this file is mixed in with.
scheduler:
  _target_: monai.networks.schedulers.DDPMScheduler
  num_train_timesteps: '@num_train_timesteps'

inferer:
  _target_: monai.inferers.DiffusionInferer
  scheduler: '@scheduler'

# Training specific

# The network definition moved to `@device` and wrapped in DistributedDataParallel.
network:
  _target_: torch.nn.parallel.DistributedDataParallel
  module: $@network_def.to(@device)
  device_ids: ['@device']
  find_unused_parameters: true

# Sampler for the training dataset (`train_ds` is defined outside this file).
tsampler:
  _target_: DistributedSampler
  dataset: '@train_ds'
  even_divisible: true
  shuffle: true

# Attach the sampler to the training loader and turn off the loader's own
# shuffling, since the sampler already has shuffle enabled.
train_loader#sampler: '@tsampler'
train_loader#shuffle: false

# Sampler for the validation dataset (`val_ds` is defined outside this file).
vsampler:
  _target_: DistributedSampler
  dataset: '@val_ds'
  even_divisible: false
  shuffle: false

val_loader#sampler: '@vsampler'

# Steps listed in order: set up the NCCL process group, select the CUDA
# device, fix the random seed, run the trainer, then tear down the group.
training:
- $import torch.distributed as dist
- $dist.init_process_group(backend='nccl')
- $torch.cuda.set_device(@device)
- $monai.utils.set_determinism(seed=123)
- $@trainer.run()
- $dist.destroy_process_group()