Xsmos commited on
Commit
f409e64
·
verified ·
1 Parent(s): d0e1ae0
Files changed (1) hide show
  1. diffusion.py +3 -3
diffusion.py CHANGED
@@ -259,7 +259,7 @@ class TrainConfig:
259
  n_param = 2
260
  guide_w = 0#-1#0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance
261
  drop_prob = 0#0.28 # only takes effect when guide_w != -1
262
- ema=True # whether to use ema
263
  ema_rate=0.995
264
 
265
  # seed = 0
@@ -482,8 +482,8 @@ class DDPM21CM:
482
  if self.config.save_name:
483
  model_state = {
484
  'epoch': ep,
485
- 'unet_state_dict': self.nn_model.state_dict(),
486
- 'ema_unet_state_dict': self.ema_model.state_dict(),
487
  }
488
  torch.save(model_state, self.config.save_name+f"-N{self.config.num_image}")
489
  print('saved model at ' + self.config.save_name+f"-N{self.config.num_image}")
 
259
  n_param = 2
260
  guide_w = 0#-1#0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance
261
  drop_prob = 0#0.28 # only takes effect when guide_w != -1
262
+ ema=False # whether to use ema
263
  ema_rate=0.995
264
 
265
  # seed = 0
 
482
  if self.config.save_name:
483
  model_state = {
484
  'epoch': ep,
485
+ 'unet_state_dict': self.nn_model.module.state_dict(),
486
+ # 'ema_unet_state_dict': self.ema_model.state_dict(),
487
  }
488
  torch.save(model_state, self.config.save_name+f"-N{self.config.num_image}")
489
  print('saved model at ' + self.config.save_name+f"-N{self.config.num_image}")