Xsmos committed on
Commit
cb6ef38
·
verified ·
1 Parent(s): fef6e9a
Files changed (1) hide show
  1. diffusion.py +8 -8
diffusion.py CHANGED
@@ -528,9 +528,9 @@ class DDPM21CM:
528
 
529
  # nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)
530
  if ema:
531
- self.nn_model.load_state_dict(torch.load(file)['ema_unet_state_dict'])
532
  else:
533
- self.nn_model.load_state_dict(torch.load(file)['unet_state_dict'])
534
  print(f"nn_model resumed from {file}")
535
  # nn_model = ContextUnet(n_param=1, image_size=28)
536
  # nn_model.train()
@@ -601,15 +601,15 @@ if __name__ == "__main__":
601
 
602
  # %%
603
 
604
- def generate_samples(model, num_new_img, max_num_img_per_gpu, rank, world_size):
605
  samples = []
606
  for _ in range(num_new_img // max_num_img_per_gpu):
607
- sample = model.module.sample(filename, params=torch.tensor([4.4, 131.341]), num_new_img=max_num_img_per_gpu)
608
  samples.append(sample)
609
- # model.sample(filename, params=torch.tensor((5.6, 19.037)), num_new_img=max_num_img_per_gpu)
610
- # model.sample(filename, params=torch.tensor((4.699, 30)), num_new_img=max_num_img_per_gpu)
611
- # model.sample(filename, params=torch.tensor((5.477, 200)), num_new_img=max_num_img_per_gpu)
612
- # model.sample(filename, params=torch.tensor((4.8, 131.341)), num_new_img=max_num_img_per_gpu)
613
  samples = np.concatenate(samples, axis=0)
614
 
615
  samples_list = [np.empty_like(samples) for _ in range(world_size)]
 
528
 
529
  # nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)
530
  if ema:
531
+ self.nn_model.module.load_state_dict(torch.load(file)['ema_unet_state_dict'])
532
  else:
533
+ self.nn_model.module.load_state_dict(torch.load(file)['unet_state_dict'])
534
  print(f"nn_model resumed from {file}")
535
  # nn_model = ContextUnet(n_param=1, image_size=28)
536
  # nn_model.train()
 
601
 
602
  # %%
603
 
604
+ def generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size):
605
  samples = []
606
  for _ in range(num_new_img // max_num_img_per_gpu):
607
+ sample = ddpm21cm.sample(filename, params=torch.tensor([4.4, 131.341]), num_new_img=max_num_img_per_gpu)
608
  samples.append(sample)
609
+ # ddpm21cm.sample(filename, params=torch.tensor((5.6, 19.037)), num_new_img=max_num_img_per_gpu)
610
+ # ddpm21cm.sample(filename, params=torch.tensor((4.699, 30)), num_new_img=max_num_img_per_gpu)
611
+ # ddpm21cm.sample(filename, params=torch.tensor((5.477, 200)), num_new_img=max_num_img_per_gpu)
612
+ # ddpm21cm.sample(filename, params=torch.tensor((4.8, 131.341)), num_new_img=max_num_img_per_gpu)
613
  samples = np.concatenate(samples, axis=0)
614
 
615
  samples_list = [np.empty_like(samples) for _ in range(world_size)]