Xsmos commited on
Commit
7cb1435
·
verified ·
1 Parent(s): 8ffd5a4
Files changed (1) hide show
  1. diffusion.py +3 -3
diffusion.py CHANGED
@@ -150,7 +150,7 @@ class DDPMScheduler(nn.Module):
150
  # for i in range(self.num_timesteps, 0, -1):
151
  # print(f'sampling!!!')
152
  pbar_sample = tqdm(total=self.num_timesteps)
153
- pbar_sample.set_description("Sampling")
154
  for i in reversed(range(0, self.num_timesteps)):
155
  # print(f'sampling timestep {i:4d}',end='\r')
156
  t_is = torch.tensor([i]).to(device)
@@ -517,7 +517,7 @@ class DDPM21CM:
517
  params_backup = params.numpy().copy()
518
  params = self.rescale(params, self.ranges_dict['params'], to=[0,1])
519
 
520
- print(f"sampling {num_new_img} images with normalized params = {params}")
521
  params = params.repeat(num_new_img,1)
522
  assert params.dim() == 2, "params must be a 2D torch.tensor"
523
  # print("params =", params)
@@ -640,7 +640,7 @@ if __name__ == "__main__":
640
  world_size = torch.cuda.device_count()
641
  # num_image_list = [1600,3200,6400,12800,25600]
642
  num_image_list = [100]
643
- num_new_img = 12
644
  max_num_img_per_gpu = 2
645
 
646
  # print("config = TrainConfig()")
 
150
  # for i in range(self.num_timesteps, 0, -1):
151
  # print(f'sampling!!!')
152
  pbar_sample = tqdm(total=self.num_timesteps)
153
+ pbar_sample.set_description(f"device {torch.cuda.current_device()} sampling")
154
  for i in reversed(range(0, self.num_timesteps)):
155
  # print(f'sampling timestep {i:4d}',end='\r')
156
  t_is = torch.tensor([i]).to(device)
 
517
  params_backup = params.numpy().copy()
518
  params = self.rescale(params, self.ranges_dict['params'], to=[0,1])
519
 
520
+ print(f"device {torch.cuda.current_device()} sampling {num_new_img} images with normalized params = {params}")
521
  params = params.repeat(num_new_img,1)
522
  assert params.dim() == 2, "params must be a 2D torch.tensor"
523
  # print("params =", params)
 
640
  world_size = torch.cuda.device_count()
641
  # num_image_list = [1600,3200,6400,12800,25600]
642
  num_image_list = [100]
643
+ num_new_img = 6
644
  max_num_img_per_gpu = 2
645
 
646
  # print("config = TrainConfig()")