0710-1805
Browse files- diffusion.py +3 -3
diffusion.py
CHANGED
|
@@ -259,7 +259,7 @@ class TrainConfig:
|
|
| 259 |
n_param = 2
|
| 260 |
guide_w = 0#-1#0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance
|
| 261 |
drop_prob = 0#0.28 # only takes effect when guide_w != -1
|
| 262 |
-
ema=
|
| 263 |
ema_rate=0.995
|
| 264 |
|
| 265 |
# seed = 0
|
|
@@ -482,8 +482,8 @@ class DDPM21CM:
|
|
| 482 |
if self.config.save_name:
|
| 483 |
model_state = {
|
| 484 |
'epoch': ep,
|
| 485 |
-
'unet_state_dict': self.nn_model.state_dict(),
|
| 486 |
-
'ema_unet_state_dict': self.ema_model.state_dict(),
|
| 487 |
}
|
| 488 |
torch.save(model_state, self.config.save_name+f"-N{self.config.num_image}")
|
| 489 |
print('saved model at ' + self.config.save_name+f"-N{self.config.num_image}")
|
|
|
|
| 259 |
n_param = 2
|
| 260 |
guide_w = 0#-1#0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance
|
| 261 |
drop_prob = 0#0.28 # only takes effect when guide_w != -1
|
| 262 |
+
ema=False # whether to use ema
|
| 263 |
ema_rate=0.995
|
| 264 |
|
| 265 |
# seed = 0
|
|
|
|
| 482 |
if self.config.save_name:
|
| 483 |
model_state = {
|
| 484 |
'epoch': ep,
|
| 485 |
+
'unet_state_dict': self.nn_model.module.state_dict(),
|
| 486 |
+
# 'ema_unet_state_dict': self.ema_model.state_dict(),
|
| 487 |
}
|
| 488 |
torch.save(model_state, self.config.save_name+f"-N{self.config.num_image}")
|
| 489 |
print('saved model at ' + self.config.save_name+f"-N{self.config.num_image}")
|