0712-1558
Browse files- diffusion.py +8 -8
diffusion.py
CHANGED
|
@@ -528,9 +528,9 @@ class DDPM21CM:
|
|
| 528 |
|
| 529 |
# nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)
|
| 530 |
if ema:
|
| 531 |
-
self.nn_model.load_state_dict(torch.load(file)['ema_unet_state_dict'])
|
| 532 |
else:
|
| 533 |
-
self.nn_model.load_state_dict(torch.load(file)['unet_state_dict'])
|
| 534 |
print(f"nn_model resumed from {file}")
|
| 535 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 536 |
# nn_model.train()
|
|
@@ -601,15 +601,15 @@ if __name__ == "__main__":
|
|
| 601 |
|
| 602 |
# %%
|
| 603 |
|
| 604 |
-
def generate_samples(
|
| 605 |
samples = []
|
| 606 |
for _ in range(num_new_img // max_num_img_per_gpu):
|
| 607 |
-
sample =
|
| 608 |
samples.append(sample)
|
| 609 |
-
#
|
| 610 |
-
#
|
| 611 |
-
#
|
| 612 |
-
#
|
| 613 |
samples = np.concatenate(samples, axis=0)
|
| 614 |
|
| 615 |
samples_list = [np.empty_like(samples) for _ in range(world_size)]
|
|
|
|
| 528 |
|
| 529 |
# nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)
|
| 530 |
if ema:
|
| 531 |
+
self.nn_model.module.load_state_dict(torch.load(file)['ema_unet_state_dict'])
|
| 532 |
else:
|
| 533 |
+
self.nn_model.module.load_state_dict(torch.load(file)['unet_state_dict'])
|
| 534 |
print(f"nn_model resumed from {file}")
|
| 535 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 536 |
# nn_model.train()
|
|
|
|
| 601 |
|
| 602 |
# %%
|
| 603 |
|
| 604 |
+
def generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size):
|
| 605 |
samples = []
|
| 606 |
for _ in range(num_new_img // max_num_img_per_gpu):
|
| 607 |
+
sample = ddpm21cm.sample(filename, params=torch.tensor([4.4, 131.341]), num_new_img=max_num_img_per_gpu)
|
| 608 |
samples.append(sample)
|
| 609 |
+
# ddpm21cm.sample(filename, params=torch.tensor((5.6, 19.037)), num_new_img=max_num_img_per_gpu)
|
| 610 |
+
# ddpm21cm.sample(filename, params=torch.tensor((4.699, 30)), num_new_img=max_num_img_per_gpu)
|
| 611 |
+
# ddpm21cm.sample(filename, params=torch.tensor((5.477, 200)), num_new_img=max_num_img_per_gpu)
|
| 612 |
+
# ddpm21cm.sample(filename, params=torch.tensor((4.8, 131.341)), num_new_img=max_num_img_per_gpu)
|
| 613 |
samples = np.concatenate(samples, axis=0)
|
| 614 |
|
| 615 |
samples_list = [np.empty_like(samples) for _ in range(world_size)]
|