Xsmos committed on
Commit
433e497
·
verified ·
1 Parent(s): cb6ef38
Files changed (1) hide show
  1. diffusion.py +9 -8
diffusion.py CHANGED
@@ -506,9 +506,10 @@ class DDPM21CM:
506
  value = value * (to[1]-to[0]) + to[0]
507
  return value
508
 
509
- def sample(self, file, params:torch.tensor=None, num_new_img=192, ema=False, entire=False, save=False):
510
  # n_sample = params.shape[0]
511
-
 
512
  if params is None:
513
  params = torch.tensor([0.20000000000000018, 0.5055875000000001])
514
  params_backup = params.numpy().copy()
@@ -604,12 +605,12 @@ if __name__ == "__main__":
604
  def generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size):
605
  samples = []
606
  for _ in range(num_new_img // max_num_img_per_gpu):
607
- sample = ddpm21cm.sample(filename, params=torch.tensor([4.4, 131.341]), num_new_img=max_num_img_per_gpu)
608
  samples.append(sample)
609
- # ddpm21cm.sample(filename, params=torch.tensor((5.6, 19.037)), num_new_img=max_num_img_per_gpu)
610
- # ddpm21cm.sample(filename, params=torch.tensor((4.699, 30)), num_new_img=max_num_img_per_gpu)
611
- # ddpm21cm.sample(filename, params=torch.tensor((5.477, 200)), num_new_img=max_num_img_per_gpu)
612
- # ddpm21cm.sample(filename, params=torch.tensor((4.8, 131.341)), num_new_img=max_num_img_per_gpu)
613
  samples = np.concatenate(samples, axis=0)
614
 
615
  samples_list = [np.empty_like(samples) for _ in range(world_size)]
@@ -648,8 +649,8 @@ if __name__ == "__main__":
648
  # print("config.world_size = world_size")
649
 
650
  for num_image in num_image_list:
651
- filename = f"./outputs/model_state-N{num_image}-epoch9-device0"
652
  config.num_image = num_image
 
653
 
654
  # print("ddpm21cm = DDPM21CM(config)")
655
  manager = mp.Manager()
 
506
  value = value * (to[1]-to[0]) + to[0]
507
  return value
508
 
509
+ def sample(self, params:torch.tensor=None, num_new_img=192, ema=False, entire=False, save=False):
510
  # n_sample = params.shape[0]
511
+ file = self.config.resume
512
+
513
  if params is None:
514
  params = torch.tensor([0.20000000000000018, 0.5055875000000001])
515
  params_backup = params.numpy().copy()
 
605
  def generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size):
606
  samples = []
607
  for _ in range(num_new_img // max_num_img_per_gpu):
608
+ sample = ddpm21cm.sample(params=torch.tensor([4.4, 131.341]), num_new_img=max_num_img_per_gpu)
609
  samples.append(sample)
610
+ # ddpm21cm.sample(params=torch.tensor((5.6, 19.037)), num_new_img=max_num_img_per_gpu)
611
+ # ddpm21cm.sample(params=torch.tensor((4.699, 30)), num_new_img=max_num_img_per_gpu)
612
+ # ddpm21cm.sample(params=torch.tensor((5.477, 200)), num_new_img=max_num_img_per_gpu)
613
+ # ddpm21cm.sample(params=torch.tensor((4.8, 131.341)), num_new_img=max_num_img_per_gpu)
614
  samples = np.concatenate(samples, axis=0)
615
 
616
  samples_list = [np.empty_like(samples) for _ in range(world_size)]
 
649
  # print("config.world_size = world_size")
650
 
651
  for num_image in num_image_list:
 
652
  config.num_image = num_image
653
+ config.resume = f"./outputs/model_state-N{num_image}-epoch9-device0"
654
 
655
  # print("ddpm21cm = DDPM21CM(config)")
656
  manager = mp.Manager()