Spaces:
Runtime error
Runtime error
log after x epochs, not after 1
Browse files
scripts/train_unconditional.py
CHANGED
|
@@ -244,7 +244,9 @@ def main(args):
|
|
| 244 |
|
| 245 |
# Generate sample images for visual inspection
|
| 246 |
if accelerator.is_main_process:
|
| 247 |
-
if
|
|
|
|
|
|
|
| 248 |
if vqvae is not None:
|
| 249 |
pipeline = LatentAudioDiffusionPipeline(
|
| 250 |
unet=accelerator.unwrap_model(
|
|
@@ -275,7 +277,9 @@ def main(args):
|
|
| 275 |
else:
|
| 276 |
pipeline.save_pretrained(output_dir)
|
| 277 |
|
| 278 |
-
if
|
|
|
|
|
|
|
| 279 |
generator = torch.manual_seed(42)
|
| 280 |
# run pipeline in inference (sample random noise and denoise)
|
| 281 |
images, (sample_rate, audios) = pipeline(
|
|
|
|
| 244 |
|
| 245 |
# Generate sample images for visual inspection
|
| 246 |
if accelerator.is_main_process:
|
| 247 |
+
if (
|
| 248 |
+
epoch + 1
|
| 249 |
+
) % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
|
| 250 |
if vqvae is not None:
|
| 251 |
pipeline = LatentAudioDiffusionPipeline(
|
| 252 |
unet=accelerator.unwrap_model(
|
|
|
|
| 277 |
else:
|
| 278 |
pipeline.save_pretrained(output_dir)
|
| 279 |
|
| 280 |
+
if (
|
| 281 |
+
epoch + 1
|
| 282 |
+
) % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
|
| 283 |
generator = torch.manual_seed(42)
|
| 284 |
# run pipeline in inference (sample random noise and denoise)
|
| 285 |
images, (sample_rate, audios) = pipeline(
|