Xsmos commited on
Commit
e8b4c7f
·
verified ·
1 Parent(s): 6e6d516
Files changed (1) hide show
  1. diffusion.py +3 -2
diffusion.py CHANGED
@@ -556,7 +556,7 @@ class DDPM21CM:
556
  # print("device =", config.device)
557
 
558
  # %%
559
- def single_main(rank, world_size):
560
  config = TrainConfig()
561
  ddp_setup(rank, world_size)
562
 
@@ -569,6 +569,7 @@ def single_main(rank, world_size):
569
  print(f" num_image = {ddpm21cm.config.num_image} ".center(50, '-'))
570
  print(f"run_name = {ddpm21cm.config.run_name}")
571
  ddpm21cm.train()
 
572
 
573
 
574
  if __name__ == "__main__":
@@ -576,7 +577,7 @@ if __name__ == "__main__":
576
  # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
577
  world_size = 2#torch.cuda.device_count()
578
 
579
- mp.spawn(single_main, args=(world_size,), nprocs=world_size)
580
  # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
581
 
582
  # %%
 
556
  # print("device =", config.device)
557
 
558
  # %%
559
+ def main(rank, world_size):
560
  config = TrainConfig()
561
  ddp_setup(rank, world_size)
562
 
 
569
  print(f" num_image = {ddpm21cm.config.num_image} ".center(50, '-'))
570
  print(f"run_name = {ddpm21cm.config.run_name}")
571
  ddpm21cm.train()
572
+ destroy_process_group()
573
 
574
 
575
  if __name__ == "__main__":
 
577
  # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
578
  world_size = 2#torch.cuda.device_count()
579
 
580
+ mp.spawn(main, args=(world_size,), nprocs=world_size)
581
  # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
582
 
583
  # %%