Xsmos committed on
Commit
ad2bb20
·
verified ·
1 Parent(s): f2109c7
Files changed (1) hide show
  1. diffusion.py +6 -8
diffusion.py CHANGED
@@ -619,7 +619,7 @@ def generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_siz
619
  samples_list = [np.empty_like(samples) for _ in range(world_size)]
620
  dist.all_gather_object(samples_list, samples)
621
 
622
- if rank == 1:
623
  all_samples = np.concatenate(samples_list, axis=0)
624
  return all_samples
625
  else:
@@ -631,12 +631,10 @@ def sample(rank, world_size, config, num_new_img, max_num_img_per_gpu, return_di
631
 
632
  samples = generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size)
633
 
634
- print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}, samples.shape = {np.shape(samples)}")
635
-
636
- if rank == 1:
637
  return_dict['samples'] = samples
638
-
639
- print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}")
640
 
641
  dist.destroy_process_group()
642
 
@@ -664,8 +662,8 @@ if __name__ == "__main__":
664
 
665
  mp.spawn(sample, args=(world_size, config, num_new_img, max_num_img_per_gpu, return_dict), nprocs=world_size, join=True)
666
 
667
- print("---"*30)
668
- print(f"device {torch.cuda.current_device()}, keys = {return_dict.keys()}")
669
  if "samples" in return_dict:
670
  samples = return_dict["samples"]
671
  print(f"device {torch.cuda.current_device()} generated samples shape: {samples.shape}")
 
619
  samples_list = [np.empty_like(samples) for _ in range(world_size)]
620
  dist.all_gather_object(samples_list, samples)
621
 
622
+ if rank == 0:
623
  all_samples = np.concatenate(samples_list, axis=0)
624
  return all_samples
625
  else:
 
631
 
632
  samples = generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size)
633
 
634
+ # print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}, samples.shape = {np.shape(samples)}")
635
+ if rank == 0:
 
636
  return_dict['samples'] = samples
637
+ # print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}")
 
638
 
639
  dist.destroy_process_group()
640
 
 
662
 
663
  mp.spawn(sample, args=(world_size, config, num_new_img, max_num_img_per_gpu, return_dict), nprocs=world_size, join=True)
664
 
665
+ # print("---"*30)
666
+ # print(f"device {torch.cuda.current_device()}, keys = {return_dict.keys()}")
667
  if "samples" in return_dict:
668
  samples = return_dict["samples"]
669
  print(f"device {torch.cuda.current_device()} generated samples shape: {samples.shape}")