Xsmos commited on
Commit
7cdcb6e
·
verified ·
1 Parent(s): 38a0ad7
Files changed (1) hide show
  1. diffusion.py +6 -6
diffusion.py CHANGED
@@ -622,8 +622,11 @@ def generate_samples(model, num_new_img, max_num_img_per_gpu, rank, world_size):
622
  return None
623
 
624
 
625
- def sample(rank, world_size, model, num_new_img, max_num_img_per_gpu, return_dict):
626
- samples = generate_samples(model, num_new_img, max_num_img_per_gpu, rank, world_size)
 
 
 
627
 
628
  if rank == 0:
629
  return_dict['samples'] = samples
@@ -648,14 +651,11 @@ if __name__ == "__main__":
648
  filename = f"./outputs/model_state-N{num_image}-epoch9-device0"
649
  config.num_image = num_image
650
 
651
- ddp_setup(rank, world_size)
652
- ddpm21cm = DDPM21CM(config)
653
  # print("ddpm21cm = DDPM21CM(config)")
654
-
655
  manager = np.Manager()
656
  return_dict = manager.dict()
657
 
658
- mp.spawn(sample, args=(world_size, ddpm21cm, num_new_img, max_num_img_per_gpu, return_dict), nprocs=world_size, join=True)
659
 
660
  if "samples" in return_dict:
661
  samples = return_dict["samples"]
 
622
  return None
623
 
624
 
625
+ def sample(rank, world_size, config, num_new_img, max_num_img_per_gpu, return_dict):
626
+ ddp_setup(rank, world_size)
627
+ ddpm21cm = DDPM21CM(config)
628
+
629
+ samples = generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size)
630
 
631
  if rank == 0:
632
  return_dict['samples'] = samples
 
651
  filename = f"./outputs/model_state-N{num_image}-epoch9-device0"
652
  config.num_image = num_image
653
 
 
 
654
  # print("ddpm21cm = DDPM21CM(config)")
 
655
  manager = np.Manager()
656
  return_dict = manager.dict()
657
 
658
+ mp.spawn(sample, args=(world_size, config, num_new_img, max_num_img_per_gpu, return_dict), nprocs=world_size, join=True)
659
 
660
  if "samples" in return_dict:
661
  samples = return_dict["samples"]