File size: 1,921 Bytes
57decc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Mix this in with the training config to enable multi-GPU (DistributedDataParallel) training.
imports:
- $import os
- $import datetime
- $import torch
- $import scripts
- $import monai
- $import torch.distributed as dist
- $import operator

# Common elements to all training files: the MONAI CommonKeys names used as
# dictionary keys (e.g. the `keys` argument of the dictionary transforms).
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED

# NOTE: `is_dist` is only meaningful once `dist.init_process_group` has run;
# the `training` sequence initializes the process group before it first
# references `@device`, which pulls in these values.
is_dist: '$dist.is_initialized()'
# Global rank across all processes (0 when not running distributed).
rank: '$dist.get_rank() if @is_dist else 0'
is_not_rank0: '$@rank > 0'
# Node-local rank picks the GPU on this machine. The global rank is only a
# valid CUDA ordinal on a single node. torchrun exports LOCAL_RANK; fall back
# to the global rank when it is absent.
local_rank: '$int(os.environ.get("LOCAL_RANK", @rank))'
device: '$torch.device(f"cuda:{@local_rank}" if torch.cuda.is_available() else "cpu")'

# 2D diffusion U-Net with one input and one output channel. The `channels`
# and `attention_levels` lists have one entry per resolution level;
# attention is enabled on the last two.
network_def:
  _target_: monai.networks.nets.DiffusionModelUNet
  spatial_dims: 2
  in_channels: 1
  out_channels: 1
  num_res_blocks: 1
  num_head_channels: 128
  channels:
  - 64
  - 128
  - 128
  attention_levels:
  - false
  - true
  - true

# Base dictionary transforms applied to the `@image` entry:
#   1. load the image (data only, no metadata),
#   2. move the channel axis first,
#   3. linearly map intensities from [0, 255] to [0, 1], clipping outliers.
base_transforms:
- {_target_: LoadImaged, keys: '@image', image_only: true}
- {_target_: EnsureChannelFirstd, keys: '@image'}
- _target_: ScaleIntensityRanged
  keys: '@image'
  clip: true
  # source range
  a_min: 0.0
  a_max: 255.0
  # target range
  b_min: 0.0
  b_max: 1.0

# DDPM noise scheduler. NOTE(review): `num_train_timesteps` is not defined in
# this file; presumably it comes from the config this is merged with.
scheduler:
  _target_: monai.networks.schedulers.DDPMScheduler
  num_train_timesteps: '@num_train_timesteps'

# Diffusion inferer driven by the scheduler defined above.
inferer:
  scheduler: '@scheduler'
  _target_: monai.inferers.DiffusionInferer

# Training specific

# Wrap `network_def`, moved onto this process's device, in
# DistributedDataParallel. `find_unused_parameters` lets DDP cope with
# parameters that get no gradient in a step, at some per-iteration cost.
network:
  _target_: torch.nn.parallel.DistributedDataParallel
  module: $@network_def.to(@device)
  find_unused_parameters: true
  device_ids:
  - '@device'

# Training sampler: each rank gets its own shuffled shard of `train_ds`,
# padded so that all ranks see the same number of samples.
tsampler:
  _target_: DistributedSampler
  dataset: '@train_ds'
  shuffle: true
  even_divisible: true
# The sampler now owns shuffling; a DataLoader may not combine shuffle=True
# with an explicit sampler.
train_loader#shuffle: false
train_loader#sampler: '@tsampler'

# Validation sampler: deterministic (unshuffled) shards of `val_ds`, with no
# padding, so ranks may receive different numbers of samples.
vsampler:
  _target_: DistributedSampler
  dataset: '@val_ds'
  shuffle: false
  even_divisible: false
val_loader#sampler: '@vsampler'

# Multi-GPU training entry point: a sequence of expressions run in order.
# The process group is initialized before `@device` is first referenced,
# because `device` is derived from the distributed rank.
# NOTE(review): if `trainer.run()` raises, `destroy_process_group` is never
# reached; this sequence has no try/finally equivalent.
training:
- $import torch.distributed as dist
- $dist.init_process_group(backend='nccl')
- $torch.cuda.set_device(@device)
- $monai.utils.set_determinism(seed=123)
- $@trainer.run()
- $dist.destroy_process_group()