# mednist_ddpm / configs / train_multigpu.yaml
# Uploaded by project-monai as mednist_ddpm version 1.0.1 (commit 57decc6).
# This file can be mixed in with the training config to enable multi-GPU
# (DistributedDataParallel) training.
# Modules made available to every `$` expression in this config.
imports:
- $import os
- $import datetime
- $import torch
- $import scripts
- $import monai
- $import torch.distributed as dist
- $import operator
# Common elements to all training files
# Standard dictionary keys (monai.utils.CommonKeys) used to address batch data.
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED
# Distributed context. `rank` is the global process rank, falling back to 0
# when no process group has been initialised; it identifies the "main" process.
is_dist: '$dist.is_initialized()'
rank: '$dist.get_rank() if @is_dist else 0'
is_not_rank0: '$@rank > 0'
# The CUDA device index must be the rank *within the node*, not the global rank
# (on multi-node runs the global rank exceeds the per-node GPU count). Launchers
# such as torchrun export LOCAL_RANK; when it is absent, fall back to the global
# rank, which reproduces the previous single-node behaviour.
local_rank: '$int(os.environ.get("LOCAL_RANK", @rank))'
device: '$torch.device(f"cuda:{@local_rank}" if torch.cuda.is_available() else "cpu")'
# Un-wrapped denoising network: a 2-D, single-channel-in / single-channel-out
# DiffusionModelUNet.
network_def:
  _target_: monai.networks.nets.DiffusionModelUNet
  spatial_dims: 2
  in_channels: 1
  out_channels: 1
  # `channels` and `attention_levels` are parallel lists (same length);
  # presumably one entry per UNet resolution level.
  channels: [64, 128, 128]
  attention_levels: [false, true, true]
  num_res_blocks: 1
  num_head_channels: 128
# Base preprocessing applied to the `@image` key: load the image (data only),
# move the channel dimension first, then linearly rescale intensities from
# [0, 255] to [0, 1], clipping anything outside that range.
base_transforms:
- _target_: LoadImaged
  keys: '@image'
  image_only: true
- _target_: EnsureChannelFirstd
  keys: '@image'
- _target_: ScaleIntensityRanged
  keys: '@image'
  # source range
  a_min: 0.0
  a_max: 255.0
  # target range
  b_min: 0.0
  b_max: 1.0
  clip: true
# DDPM noise scheduler. `num_train_timesteps` is not defined in this file as
# shown; presumably it is supplied by the config this file is merged with.
scheduler:
  _target_: monai.networks.schedulers.DDPMScheduler
  num_train_timesteps: '@num_train_timesteps'
# Diffusion inferer, driven by the `scheduler` entry.
inferer:
  _target_: monai.inferers.DiffusionInferer
  scheduler: '@scheduler'
# Training specific
# Wrap the network in DistributedDataParallel. The module is first moved to
# this process's device, and that same device is DDP's only device id.
network:
  _target_: torch.nn.parallel.DistributedDataParallel
  module: $@network_def.to(@device)
  device_ids:
  - '@device'
  # Tolerate parameters that receive no gradient in a step; DDP otherwise
  # errors out when some parameters go unused.
  find_unused_parameters: true
# Shard the training set across processes. `even_divisible` pads so that every
# rank gets the same number of samples. Shuffling is delegated to the sampler,
# so the DataLoader's own `shuffle` must be off (the two options are mutually
# exclusive in torch.utils.data.DataLoader).
tsampler:
  _target_: DistributedSampler
  dataset: '@train_ds'
  shuffle: true
  even_divisible: true
train_loader#sampler: '@tsampler'
train_loader#shuffle: false
# Shard the validation set across processes in a fixed order. No padding is
# applied (even_divisible: false), so ranks may receive different sample counts.
vsampler:
  _target_: DistributedSampler
  dataset: '@val_ds'
  shuffle: false
  even_divisible: false
val_loader#sampler: '@vsampler'
# Entry-point expression list, evaluated in order: initialise the NCCL process
# group, pin this process to its GPU, seed, train, then tear the group down.
training:
# Re-imports `dist` for this list (it is also listed under `imports`).
- $import torch.distributed as dist
- $dist.init_process_group(backend='nccl')
# `@device` depends on `@rank`, which reads `dist.is_initialized()`, so it must
# first be referenced only after the process group exists.
- $torch.cuda.set_device(@device)
# No trailing comma: `expr,` would silently build and discard a 1-tuple.
- $monai.utils.set_determinism(seed=123)
- $@trainer.run()
- $dist.destroy_process_group()