Xsmos committed on
Commit
d5b0597
·
verified ·
1 Parent(s): 3ceda46
Files changed (1) hide show
  1. diffusion.py +15 -12
diffusion.py CHANGED
@@ -274,13 +274,12 @@ class TrainConfig:
274
  # save_period = 1 #10 # the period of saving model
275
  # cond = True # if training using the conditional information
276
  # lr_decay = False #True# if using the learning rate decay
277
- resume = save_name # if resume from the trained checkpoints
278
  # params_single = torch.tensor([0.2,0.80000023])
279
  # params = torch.tile(params_single,(n_sample,1)).to(device)
280
  # params = params
281
  # data_dir = './data' # data directory
282
 
283
-
284
  mixed_precision = "fp16"
285
  gradient_accumulation_steps = 1
286
 
@@ -322,10 +321,6 @@ class DDPM21CM:
322
  # initialize the unet
323
  self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)
324
 
325
- if config.resume and os.path.exists(config.resume):
326
- # resume_file = os.path.join(config.output_dir, f"{config.resume}")
327
- self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])
328
- print(f"resumed nn_model from {config.resume}")
329
  # nn_model = ContextUnet(n_param=1, image_size=28)
330
  self.nn_model.train()
331
  # print("self.ddpm.device =", self.ddpm.device)
@@ -333,6 +328,14 @@ class DDPM21CM:
333
  self.nn_model = DDP(self.nn_model, device_ids=[self.ddpm.device])
334
  # print("nn_model.device =", ddpm.device)
335
  # number of parameters to be trained
 
 
 
 
 
 
 
 
336
  self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
337
  print(f"Number of parameters for nn_model: {self.number_of_params}")
338
 
@@ -508,7 +511,7 @@ class DDPM21CM:
508
 
509
  def sample(self, params:torch.tensor=None, num_new_img=192, ema=False, entire=False, save=False):
510
  # n_sample = params.shape[0]
511
- file = self.config.resume
512
 
513
  if params is None:
514
  params = torch.tensor([0.20000000000000018, 0.5055875000000001])
@@ -528,11 +531,11 @@ class DDPM21CM:
528
  # params = torch.tile(params, (n_sample,1)).to(device)
529
 
530
  # nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)
531
- if ema:
532
- self.nn_model.module.load_state_dict(torch.load(file)['ema_unet_state_dict'])
533
- else:
534
- self.nn_model.module.load_state_dict(torch.load(file)['unet_state_dict'])
535
- print(f"device {torch.cuda.current_device()} resumed nn_model from {file}")
536
  # nn_model = ContextUnet(n_param=1, image_size=28)
537
  # nn_model.train()
538
  # self.nn_model.to(self.ddpm.device)
 
274
  # save_period = 1 #10 # the period of saving model
275
  # cond = True # if training using the conditional information
276
  # lr_decay = False #True# if using the learning rate decay
277
+ resume = False # if resume from the trained checkpoints
278
  # params_single = torch.tensor([0.2,0.80000023])
279
  # params = torch.tile(params_single,(n_sample,1)).to(device)
280
  # params = params
281
  # data_dir = './data' # data directory
282
 
 
283
  mixed_precision = "fp16"
284
  gradient_accumulation_steps = 1
285
 
 
321
  # initialize the unet
322
  self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)
323
 
 
 
 
 
324
  # nn_model = ContextUnet(n_param=1, image_size=28)
325
  self.nn_model.train()
326
  # print("self.ddpm.device =", self.ddpm.device)
 
328
  self.nn_model = DDP(self.nn_model, device_ids=[self.ddpm.device])
329
  # print("nn_model.device =", ddpm.device)
330
  # number of parameters to be trained
331
+
332
+ if config.resume and os.path.exists(config.resume):
333
+ # resume_file = os.path.join(config.output_dir, f"{config.resume}")
334
+ # self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])
335
+ # print(f"resumed nn_model from {config.resume}")
336
+ self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
337
+ print(f"device {torch.cuda.current_device()} resumed nn_model from {config.resume}")
338
+
339
  self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
340
  print(f"Number of parameters for nn_model: {self.number_of_params}")
341
 
 
511
 
512
  def sample(self, params:torch.tensor=None, num_new_img=192, ema=False, entire=False, save=False):
513
  # n_sample = params.shape[0]
514
+ # file = self.config.resume
515
 
516
  if params is None:
517
  params = torch.tensor([0.20000000000000018, 0.5055875000000001])
 
531
  # params = torch.tile(params, (n_sample,1)).to(device)
532
 
533
  # nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)
534
+ # if ema:
535
+ # self.nn_model.module.load_state_dict(torch.load(file)['ema_unet_state_dict'])
536
+ # else:
537
+ # self.nn_model.module.load_state_dict(torch.load(file)['unet_state_dict'])
538
+ # print(f"device {torch.cuda.current_device()} resumed nn_model from {file}")
539
  # nn_model = ContextUnet(n_param=1, image_size=28)
540
  # nn_model.train()
541
  # self.nn_model.to(self.ddpm.device)