0712-1532
Browse files- diffusion.py +6 -6
diffusion.py
CHANGED
|
@@ -622,8 +622,11 @@ def generate_samples(model, num_new_img, max_num_img_per_gpu, rank, world_size):
|
|
| 622 |
return None
|
| 623 |
|
| 624 |
|
| 625 |
-
def sample(rank, world_size,
|
| 626 |
-
|
|
|
|
|
|
|
|
|
|
| 627 |
|
| 628 |
if rank == 0:
|
| 629 |
return_dict['samples'] = samples
|
|
@@ -648,14 +651,11 @@ if __name__ == "__main__":
|
|
| 648 |
filename = f"./outputs/model_state-N{num_image}-epoch9-device0"
|
| 649 |
config.num_image = num_image
|
| 650 |
|
| 651 |
-
ddp_setup(rank, world_size)
|
| 652 |
-
ddpm21cm = DDPM21CM(config)
|
| 653 |
# print("ddpm21cm = DDPM21CM(config)")
|
| 654 |
-
|
| 655 |
manager = np.Manager()
|
| 656 |
return_dict = manager.dict()
|
| 657 |
|
| 658 |
-
mp.spawn(sample, args=(world_size,
|
| 659 |
|
| 660 |
if "samples" in return_dict:
|
| 661 |
samples = return_dict["samples"]
|
|
|
|
| 622 |
return None
|
| 623 |
|
| 624 |
|
| 625 |
+
def sample(rank, world_size, config, num_new_img, max_num_img_per_gpu, return_dict):
|
| 626 |
+
ddp_setup(rank, world_size)
|
| 627 |
+
ddpm21cm = DDPM21CM(config)
|
| 628 |
+
|
| 629 |
+
samples = generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size)
|
| 630 |
|
| 631 |
if rank == 0:
|
| 632 |
return_dict['samples'] = samples
|
|
|
|
| 651 |
filename = f"./outputs/model_state-N{num_image}-epoch9-device0"
|
| 652 |
config.num_image = num_image
|
| 653 |
|
|
|
|
|
|
|
| 654 |
# print("ddpm21cm = DDPM21CM(config)")
|
|
|
|
| 655 |
manager = np.Manager()
|
| 656 |
return_dict = manager.dict()
|
| 657 |
|
| 658 |
+
mp.spawn(sample, args=(world_size, config, num_new_img, max_num_img_per_gpu, return_dict), nprocs=world_size, join=True)
|
| 659 |
|
| 660 |
if "samples" in return_dict:
|
| 661 |
samples = return_dict["samples"]
|