Update inner.py
Browse files
inner.py
CHANGED
|
@@ -59,12 +59,12 @@ class TrainingConfig:
|
|
| 59 |
image_size = 256 # the generated image resolution
|
| 60 |
train_batch_size = 10
|
| 61 |
eval_batch_size = 16 # how many images to sample during evaluation
|
| 62 |
-
num_epochs =
|
| 63 |
gradient_accumulation_steps = 1
|
| 64 |
learning_rate = 1e-4
|
| 65 |
lr_warmup_steps = 250
|
| 66 |
-
save_image_epochs =
|
| 67 |
-
save_model_epochs =
|
| 68 |
mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
|
| 69 |
output_dir = "Inner1730_10Real" # the model name locally and on the HF Hub
|
| 70 |
|
|
@@ -378,7 +378,7 @@ def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_s
|
|
| 378 |
if accelerator.is_main_process:
|
| 379 |
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
|
| 380 |
|
| 381 |
-
if ((epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1) and epoch >
|
| 382 |
evalfirst(config, epoch, pipeline)
|
| 383 |
|
| 384 |
model_dir = os.path.join(config.output_dir, str(epoch))
|
|
|
|
| 59 |
image_size = 256 # the generated image resolution
|
| 60 |
train_batch_size = 10
|
| 61 |
eval_batch_size = 16 # how many images to sample during evaluation
|
| 62 |
+
num_epochs = 2000
|
| 63 |
gradient_accumulation_steps = 1
|
| 64 |
learning_rate = 1e-4
|
| 65 |
lr_warmup_steps = 250
|
| 66 |
+
save_image_epochs = 100
|
| 67 |
+
save_model_epochs = 2000
|
| 68 |
mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
|
| 69 |
output_dir = "Inner1730_10Real" # the model name locally and on the HF Hub
|
| 70 |
|
|
|
|
| 378 |
if accelerator.is_main_process:
|
| 379 |
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
|
| 380 |
|
| 381 |
+
if ((epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1) and epoch > 195: #Change if want to not evaluate before a certain epoch
|
| 382 |
evalfirst(config, epoch, pipeline)
|
| 383 |
|
| 384 |
model_dir = os.path.join(config.output_dir, str(epoch))
|