0713-1718
Browse files- diffusion.py +6 -8
diffusion.py
CHANGED
|
@@ -619,7 +619,7 @@ def generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_siz
|
|
| 619 |
samples_list = [np.empty_like(samples) for _ in range(world_size)]
|
| 620 |
dist.all_gather_object(samples_list, samples)
|
| 621 |
|
| 622 |
-
if rank ==
|
| 623 |
all_samples = np.concatenate(samples_list, axis=0)
|
| 624 |
return all_samples
|
| 625 |
else:
|
|
@@ -631,12 +631,10 @@ def sample(rank, world_size, config, num_new_img, max_num_img_per_gpu, return_di
|
|
| 631 |
|
| 632 |
samples = generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size)
|
| 633 |
|
| 634 |
-
print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}, samples.shape = {np.shape(samples)}")
|
| 635 |
-
|
| 636 |
-
if rank == 1:
|
| 637 |
return_dict['samples'] = samples
|
| 638 |
-
|
| 639 |
-
print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}")
|
| 640 |
|
| 641 |
dist.destroy_process_group()
|
| 642 |
|
|
@@ -664,8 +662,8 @@ if __name__ == "__main__":
|
|
| 664 |
|
| 665 |
mp.spawn(sample, args=(world_size, config, num_new_img, max_num_img_per_gpu, return_dict), nprocs=world_size, join=True)
|
| 666 |
|
| 667 |
-
print("---"*30)
|
| 668 |
-
print(f"device {torch.cuda.current_device()}, keys = {return_dict.keys()}")
|
| 669 |
if "samples" in return_dict:
|
| 670 |
samples = return_dict["samples"]
|
| 671 |
print(f"device {torch.cuda.current_device()} generated samples shape: {samples.shape}")
|
|
|
|
| 619 |
samples_list = [np.empty_like(samples) for _ in range(world_size)]
|
| 620 |
dist.all_gather_object(samples_list, samples)
|
| 621 |
|
| 622 |
+
if rank == 0:
|
| 623 |
all_samples = np.concatenate(samples_list, axis=0)
|
| 624 |
return all_samples
|
| 625 |
else:
|
|
|
|
| 631 |
|
| 632 |
samples = generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size)
|
| 633 |
|
| 634 |
+
# print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}, samples.shape = {np.shape(samples)}")
|
| 635 |
+
if rank == 0:
|
|
|
|
| 636 |
return_dict['samples'] = samples
|
| 637 |
+
# print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}")
|
|
|
|
| 638 |
|
| 639 |
dist.destroy_process_group()
|
| 640 |
|
|
|
|
| 662 |
|
| 663 |
mp.spawn(sample, args=(world_size, config, num_new_img, max_num_img_per_gpu, return_dict), nprocs=world_size, join=True)
|
| 664 |
|
| 665 |
+
# print("---"*30)
|
| 666 |
+
# print(f"device {torch.cuda.current_device()}, keys = {return_dict.keys()}")
|
| 667 |
if "samples" in return_dict:
|
| 668 |
samples = return_dict["samples"]
|
| 669 |
print(f"device {torch.cuda.current_device()} generated samples shape: {samples.shape}")
|