0712-1603
Browse files- diffusion.py +9 -8
diffusion.py
CHANGED
|
@@ -506,9 +506,10 @@ class DDPM21CM:
|
|
| 506 |
value = value * (to[1]-to[0]) + to[0]
|
| 507 |
return value
|
| 508 |
|
| 509 |
-
def sample(self,
|
| 510 |
# n_sample = params.shape[0]
|
| 511 |
-
|
|
|
|
| 512 |
if params is None:
|
| 513 |
params = torch.tensor([0.20000000000000018, 0.5055875000000001])
|
| 514 |
params_backup = params.numpy().copy()
|
|
@@ -604,12 +605,12 @@ if __name__ == "__main__":
|
|
| 604 |
def generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size):
|
| 605 |
samples = []
|
| 606 |
for _ in range(num_new_img // max_num_img_per_gpu):
|
| 607 |
-
sample = ddpm21cm.sample(
|
| 608 |
samples.append(sample)
|
| 609 |
-
# ddpm21cm.sample(
|
| 610 |
-
# ddpm21cm.sample(
|
| 611 |
-
# ddpm21cm.sample(
|
| 612 |
-
# ddpm21cm.sample(
|
| 613 |
samples = np.concatenate(samples, axis=0)
|
| 614 |
|
| 615 |
samples_list = [np.empty_like(samples) for _ in range(world_size)]
|
|
@@ -648,8 +649,8 @@ if __name__ == "__main__":
|
|
| 648 |
# print("config.world_size = world_size")
|
| 649 |
|
| 650 |
for num_image in num_image_list:
|
| 651 |
-
filename = f"./outputs/model_state-N{num_image}-epoch9-device0"
|
| 652 |
config.num_image = num_image
|
|
|
|
| 653 |
|
| 654 |
# print("ddpm21cm = DDPM21CM(config)")
|
| 655 |
manager = mp.Manager()
|
|
|
|
| 506 |
value = value * (to[1]-to[0]) + to[0]
|
| 507 |
return value
|
| 508 |
|
| 509 |
+
def sample(self, params:torch.tensor=None, num_new_img=192, ema=False, entire=False, save=False):
|
| 510 |
# n_sample = params.shape[0]
|
| 511 |
+
file = self.config.resume
|
| 512 |
+
|
| 513 |
if params is None:
|
| 514 |
params = torch.tensor([0.20000000000000018, 0.5055875000000001])
|
| 515 |
params_backup = params.numpy().copy()
|
|
|
|
| 605 |
def generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size):
|
| 606 |
samples = []
|
| 607 |
for _ in range(num_new_img // max_num_img_per_gpu):
|
| 608 |
+
sample = ddpm21cm.sample(params=torch.tensor([4.4, 131.341]), num_new_img=max_num_img_per_gpu)
|
| 609 |
samples.append(sample)
|
| 610 |
+
# ddpm21cm.sample(params=torch.tensor((5.6, 19.037)), num_new_img=max_num_img_per_gpu)
|
| 611 |
+
# ddpm21cm.sample(params=torch.tensor((4.699, 30)), num_new_img=max_num_img_per_gpu)
|
| 612 |
+
# ddpm21cm.sample(params=torch.tensor((5.477, 200)), num_new_img=max_num_img_per_gpu)
|
| 613 |
+
# ddpm21cm.sample(params=torch.tensor((4.8, 131.341)), num_new_img=max_num_img_per_gpu)
|
| 614 |
samples = np.concatenate(samples, axis=0)
|
| 615 |
|
| 616 |
samples_list = [np.empty_like(samples) for _ in range(world_size)]
|
|
|
|
| 649 |
# print("config.world_size = world_size")
|
| 650 |
|
| 651 |
for num_image in num_image_list:
|
|
|
|
| 652 |
config.num_image = num_image
|
| 653 |
+
config.resume = f"./outputs/model_state-N{num_image}-epoch9-device0"
|
| 654 |
|
| 655 |
# print("ddpm21cm = DDPM21CM(config)")
|
| 656 |
manager = mp.Manager()
|