0715-1359
Browse files- diffusion.py +2 -2
diffusion.py
CHANGED
|
@@ -589,10 +589,10 @@ def train(rank, world_size):
|
|
| 589 |
|
| 590 |
|
| 591 |
if __name__ == "__main__":
|
| 592 |
-
|
|
|
|
| 593 |
# torch.multiprocessing.set_start_method("spawn")
|
| 594 |
# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
| 595 |
-
world_size = torch.cuda.device_count()
|
| 596 |
|
| 597 |
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
|
| 598 |
# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
|
|
|
|
| 589 |
|
| 590 |
|
| 591 |
if __name__ == "__main__":
|
| 592 |
+
world_size = torch.cuda.device_count()
|
| 593 |
+
print(f" training, world_size = {world_size} ".center(100,'-'))
|
| 594 |
# torch.multiprocessing.set_start_method("spawn")
|
| 595 |
# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
|
|
|
| 596 |
|
| 597 |
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
|
| 598 |
# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
|