{
    "device": "$torch.device('cuda:' + os.environ['LOCAL_RANK'])",
    "gnetwork": {
        "_target_": "torch.nn.parallel.DistributedDataParallel",
        "module": "$@autoencoder_def.to(@device)",
        "device_ids": [
            "@device"
        ]
    },
    "dnetwork": {
        "_target_": "torch.nn.parallel.DistributedDataParallel",
        "module": "$@discriminator_def.to(@device)",
        "device_ids": [
            "@device"
        ]
    },
    "train#sampler": {
        "_target_": "DistributedSampler",
        "dataset": "@train#dataset",
        "even_divisible": true,
        "shuffle": true
    },
    "train#dataloader#sampler": "@train#sampler",
    "train#dataloader#shuffle": false,
    "train#trainer#train_handlers": "$@train#handlers[: -2 if dist.get_rank() > 0 else None]",
    "initialize": [
        "$import torch.distributed as dist",
        "$import os",
        "$dist.is_initialized() or dist.init_process_group(backend='nccl')",
        "$torch.cuda.set_device(@device)",
        "$monai.utils.set_determinism(seed=123)",
        "$import logging",
        "$@train#trainer.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)"
    ],
    "run": [
        "$@train#trainer.run()"
    ],
    "finalize": [
        "$dist.is_initialized() and dist.destroy_process_group()"
    ]
}