# This can be mixed in with the training script to enable multi-GPU training
imports:
- $import os
- $import datetime
- $import torch
- $import scripts
- $import monai
- $import torch.distributed as dist
- $import operator
# Common elements to all training files
-
# Dictionary key names, resolved from monai.utils.CommonKeys at parse time.
image: '$monai.utils.CommonKeys.IMAGE'
label: '$monai.utils.CommonKeys.LABEL'
pred: '$monai.utils.CommonKeys.PRED'
# Distributed context: rank is 0 whenever no process group is initialized.
is_dist: '$dist.is_initialized()'
rank: '$dist.get_rank() if @is_dist else 0'
is_not_rank0: '$@rank > 0'
# CUDA device indexed by rank when CUDA is available, otherwise the CPU.
# NOTE(review): the index is the global rank; a multi-node launch may need the
# local rank instead -- confirm against how this bundle is launched.
device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")'
# Network definition: a 2D DiffusionModelUNet with one input and one output
# channel. Attention is disabled on the first level and enabled on the other two.
# All constructor arguments must be nested under the key for MONAI to
# instantiate the `_target_` with them.
network_def:
  _target_: monai.networks.nets.DiffusionModelUNet
  spatial_dims: 2
  in_channels: 1
  out_channels: 1
  channels: [64, 128, 128]
  attention_levels: [false, true, true]
  num_res_blocks: 1
  num_head_channels: 128
# Transform list applied to the '@image' key: load the image, make it
# channel-first, then rescale intensities from [0, 255] to [0, 1] with clipping.
# Each list item is one transform; its arguments are nested under that item.
base_transforms:
- _target_: LoadImaged
  keys: '@image'
  image_only: true
- _target_: EnsureChannelFirstd
  keys: '@image'
- _target_: ScaleIntensityRanged
  keys: '@image'
  a_min: 0.0
  a_max: 255.0
  b_min: 0.0
  b_max: 1.0
  clip: true
# DDPM noise scheduler. '@num_train_timesteps' is not defined in this section
# and must be supplied by the config this is combined with.
scheduler:
  _target_: monai.networks.schedulers.DDPMScheduler
  num_train_timesteps: '@num_train_timesteps'
# Diffusion inferer driven by the scheduler defined under `scheduler`.
inferer:
  _target_: monai.inferers.DiffusionInferer
  scheduler: '@scheduler'
# Training specific
# Multi-GPU network: '@network_def' is moved to '@device' and wrapped in
# DistributedDataParallel, pinned to that single device.
network:
  _target_: torch.nn.parallel.DistributedDataParallel
  module: $@network_def.to(@device)
  device_ids: ['@device']
  find_unused_parameters: true
# Distributed sampler over '@train_ds'. Shuffling is done by the sampler, so
# the loader's own shuffle flag is switched off.
tsampler:
  _target_: DistributedSampler
  dataset: '@train_ds'
  even_divisible: true
  shuffle: true
# Overrides applied to the `train_loader` definition of the training config.
train_loader#sampler: '@tsampler'
train_loader#shuffle: false
# Distributed sampler over '@val_ds', with no shuffling and no even-division
# requirement.
vsampler:
  _target_: DistributedSampler
  dataset: '@val_ds'
  even_divisible: false
  shuffle: false
# Override applied to the `val_loader` definition of the training config.
val_loader#sampler: '@vsampler'
# Entry-point expressions, evaluated in order: set up the NCCL process group,
# bind this process to '@device', seed for determinism, run the trainer, then
# tear the process group down.
training:
- $import torch.distributed as dist
- $dist.init_process_group(backend='nccl')
- $torch.cuda.set_device(@device)
- $monai.utils.set_determinism(seed=123)
- $@trainer.run()
- $dist.destroy_process_group()