GaumlessGraham commited on
Commit
01dd18b
·
verified ·
1 Parent(s): fdb6ab5

Update inner.py

Browse files
Files changed (1) hide show
  1. inner.py +4 -4
inner.py CHANGED
@@ -59,12 +59,12 @@ class TrainingConfig:
59
  image_size = 256 # the generated image resolution
60
  train_batch_size = 10
61
  eval_batch_size = 16 # how many images to sample during evaluation
62
- num_epochs = 500
63
  gradient_accumulation_steps = 1
64
  learning_rate = 1e-4
65
  lr_warmup_steps = 250
66
- save_image_epochs = 25
67
- save_model_epochs = 500
68
  mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
69
  output_dir = "Inner1730_10Real" # the model name locally and on the HF Hub
70
 
@@ -378,7 +378,7 @@ def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_s
378
  if accelerator.is_main_process:
379
  pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
380
 
381
- if ((epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1) and epoch > 0: #Change if want to not evaluate before a certain epoch
382
  evalfirst(config, epoch, pipeline)
383
 
384
  model_dir = os.path.join(config.output_dir, str(epoch))
 
59
  image_size = 256 # the generated image resolution
60
  train_batch_size = 10
61
  eval_batch_size = 16 # how many images to sample during evaluation
62
+ num_epochs = 2000
63
  gradient_accumulation_steps = 1
64
  learning_rate = 1e-4
65
  lr_warmup_steps = 250
66
+ save_image_epochs = 100
67
+ save_model_epochs = 2000
68
  mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
69
  output_dir = "Inner1730_10Real" # the model name locally and on the HF Hub
70
 
 
378
  if accelerator.is_main_process:
379
  pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
380
 
381
+ if ((epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1) and epoch > 195: #Change if want to not evaluate before a certain epoch
382
  evalfirst(config, epoch, pipeline)
383
 
384
  model_dir = os.path.join(config.output_dir, str(epoch))