Spaces:
Runtime error
Runtime error
fix when save_model_epochs != save_images_epochs
Browse files
scripts/train_unconditional.py
CHANGED
|
@@ -176,11 +176,11 @@ def main(args):
|
|
| 176 |
|
| 177 |
if args.push_to_hub:
|
| 178 |
if args.hub_model_id is None:
|
| 179 |
-
repo_name = get_full_repo_name(Path(
|
| 180 |
token=args.hub_token)
|
| 181 |
else:
|
| 182 |
repo_name = args.hub_model_id
|
| 183 |
-
repo = Repository(
|
| 184 |
|
| 185 |
if accelerator.is_main_process:
|
| 186 |
run = os.path.split(__file__)[-1].split(".")[0]
|
|
@@ -270,9 +270,9 @@ def main(args):
|
|
| 270 |
|
| 271 |
# Generate sample images for visual inspection
|
| 272 |
if accelerator.is_main_process:
|
| 273 |
-
if (
|
| 274 |
epoch + 1
|
| 275 |
-
) % args.
|
| 276 |
pipeline = AudioDiffusionPipeline(
|
| 277 |
vqvae=vqvae,
|
| 278 |
unet=accelerator.unwrap_model(
|
|
@@ -280,15 +280,17 @@ def main(args):
|
|
| 280 |
mel=mel,
|
| 281 |
scheduler=noise_scheduler,
|
| 282 |
)
|
| 283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
# save the model
|
| 286 |
if args.push_to_hub:
|
| 287 |
repo.push_to_hub(commit_message=f"Epoch {epoch}",
|
| 288 |
blocking=False,
|
| 289 |
auto_lfs_prune=True)
|
| 290 |
-
else:
|
| 291 |
-
pipeline.save_pretrained(output_dir)
|
| 292 |
|
| 293 |
if (epoch + 1) % args.save_images_epochs == 0:
|
| 294 |
generator = torch.Generator(
|
|
|
|
| 176 |
|
| 177 |
if args.push_to_hub:
|
| 178 |
if args.hub_model_id is None:
|
| 179 |
+
repo_name = get_full_repo_name(Path(output_dir).name,
|
| 180 |
token=args.hub_token)
|
| 181 |
else:
|
| 182 |
repo_name = args.hub_model_id
|
| 183 |
+
repo = Repository(output_dir, clone_from=repo_name)
|
| 184 |
|
| 185 |
if accelerator.is_main_process:
|
| 186 |
run = os.path.split(__file__)[-1].split(".")[0]
|
|
|
|
| 270 |
|
| 271 |
# Generate sample images for visual inspection
|
| 272 |
if accelerator.is_main_process:
|
| 273 |
+
if (epoch + 1) % args.save_model_epochs == 0 or (
|
| 274 |
epoch + 1
|
| 275 |
+
) % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
|
| 276 |
pipeline = AudioDiffusionPipeline(
|
| 277 |
vqvae=vqvae,
|
| 278 |
unet=accelerator.unwrap_model(
|
|
|
|
| 280 |
mel=mel,
|
| 281 |
scheduler=noise_scheduler,
|
| 282 |
)
|
| 283 |
+
|
| 284 |
+
if (
|
| 285 |
+
epoch + 1
|
| 286 |
+
) % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
|
| 287 |
+
pipeline.save_pretrained(output_dir)
|
| 288 |
|
| 289 |
# save the model
|
| 290 |
if args.push_to_hub:
|
| 291 |
repo.push_to_hub(commit_message=f"Epoch {epoch}",
|
| 292 |
blocking=False,
|
| 293 |
auto_lfs_prune=True)
|
|
|
|
|
|
|
| 294 |
|
| 295 |
if (epoch + 1) % args.save_images_epochs == 0:
|
| 296 |
generator = torch.Generator(
|