0712-1610
Browse files- diffusion.py +3 -3
diffusion.py
CHANGED
|
@@ -150,7 +150,7 @@ class DDPMScheduler(nn.Module):
|
|
| 150 |
# for i in range(self.num_timesteps, 0, -1):
|
| 151 |
# print(f'sampling!!!')
|
| 152 |
pbar_sample = tqdm(total=self.num_timesteps)
|
| 153 |
-
pbar_sample.set_description("
|
| 154 |
for i in reversed(range(0, self.num_timesteps)):
|
| 155 |
# print(f'sampling timestep {i:4d}',end='\r')
|
| 156 |
t_is = torch.tensor([i]).to(device)
|
|
@@ -517,7 +517,7 @@ class DDPM21CM:
|
|
| 517 |
params_backup = params.numpy().copy()
|
| 518 |
params = self.rescale(params, self.ranges_dict['params'], to=[0,1])
|
| 519 |
|
| 520 |
-
print(f"sampling {num_new_img} images with normalized params = {params}")
|
| 521 |
params = params.repeat(num_new_img,1)
|
| 522 |
assert params.dim() == 2, "params must be a 2D torch.tensor"
|
| 523 |
# print("params =", params)
|
|
@@ -640,7 +640,7 @@ if __name__ == "__main__":
|
|
| 640 |
world_size = torch.cuda.device_count()
|
| 641 |
# num_image_list = [1600,3200,6400,12800,25600]
|
| 642 |
num_image_list = [100]
|
| 643 |
-
num_new_img =
|
| 644 |
max_num_img_per_gpu = 2
|
| 645 |
|
| 646 |
# print("config = TrainConfig()")
|
|
|
|
| 150 |
# for i in range(self.num_timesteps, 0, -1):
|
| 151 |
# print(f'sampling!!!')
|
| 152 |
pbar_sample = tqdm(total=self.num_timesteps)
|
| 153 |
+
pbar_sample.set_description(f"device {torch.cuda.current_device()} sampling")
|
| 154 |
for i in reversed(range(0, self.num_timesteps)):
|
| 155 |
# print(f'sampling timestep {i:4d}',end='\r')
|
| 156 |
t_is = torch.tensor([i]).to(device)
|
|
|
|
| 517 |
params_backup = params.numpy().copy()
|
| 518 |
params = self.rescale(params, self.ranges_dict['params'], to=[0,1])
|
| 519 |
|
| 520 |
+
print(f"device {torch.cuda.current_device()} sampling {num_new_img} images with normalized params = {params}")
|
| 521 |
params = params.repeat(num_new_img,1)
|
| 522 |
assert params.dim() == 2, "params must be a 2D torch.tensor"
|
| 523 |
# print("params =", params)
|
|
|
|
| 640 |
world_size = torch.cuda.device_count()
|
| 641 |
# num_image_list = [1600,3200,6400,12800,25600]
|
| 642 |
num_image_list = [100]
|
| 643 |
+
num_new_img = 6
|
| 644 |
max_num_img_per_gpu = 2
|
| 645 |
|
| 646 |
# print("config = TrainConfig()")
|