0710-1820
Browse files- diffusion.py +3 -2
diffusion.py
CHANGED
|
@@ -556,7 +556,7 @@ class DDPM21CM:
|
|
| 556 |
# print("device =", config.device)
|
| 557 |
|
| 558 |
# %%
|
| 559 |
-
def
|
| 560 |
config = TrainConfig()
|
| 561 |
ddp_setup(rank, world_size)
|
| 562 |
|
|
@@ -569,6 +569,7 @@ def single_main(rank, world_size):
|
|
| 569 |
print(f" num_image = {ddpm21cm.config.num_image} ".center(50, '-'))
|
| 570 |
print(f"run_name = {ddpm21cm.config.run_name}")
|
| 571 |
ddpm21cm.train()
|
|
|
|
| 572 |
|
| 573 |
|
| 574 |
if __name__ == "__main__":
|
|
@@ -576,7 +577,7 @@ if __name__ == "__main__":
|
|
| 576 |
# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
| 577 |
world_size = 2#torch.cuda.device_count()
|
| 578 |
|
| 579 |
-
mp.spawn(
|
| 580 |
# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
|
| 581 |
|
| 582 |
# %%
|
|
|
|
| 556 |
# print("device =", config.device)
|
| 557 |
|
| 558 |
# %%
|
| 559 |
+
def main(rank, world_size):
|
| 560 |
config = TrainConfig()
|
| 561 |
ddp_setup(rank, world_size)
|
| 562 |
|
|
|
|
| 569 |
print(f" num_image = {ddpm21cm.config.num_image} ".center(50, '-'))
|
| 570 |
print(f"run_name = {ddpm21cm.config.run_name}")
|
| 571 |
ddpm21cm.train()
|
| 572 |
+
destroy_process_group()
|
| 573 |
|
| 574 |
|
| 575 |
if __name__ == "__main__":
|
|
|
|
| 577 |
# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
| 578 |
world_size = 2#torch.cuda.device_count()
|
| 579 |
|
| 580 |
+
mp.spawn(main, args=(world_size,), nprocs=world_size)
|
| 581 |
# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
|
| 582 |
|
| 583 |
# %%
|