Spaces:
Runtime error
Runtime error
remove unnecessary check
Browse files- scripts/train_unconditional.py +32 -32
scripts/train_unconditional.py
CHANGED
|
@@ -1,11 +1,8 @@
|
|
| 1 |
# based on https://github.com/huggingface/diffusers/blob/main/examples/train_unconditional.py
|
| 2 |
|
| 3 |
-
# TODO
|
| 4 |
-
# Migrate to diffusers
|
| 5 |
-
# from diffusers.hub_utils import Repository
|
| 6 |
-
|
| 7 |
import argparse
|
| 8 |
import os
|
|
|
|
| 9 |
|
| 10 |
import torch
|
| 11 |
import torch.nn.functional as F
|
|
@@ -14,13 +11,14 @@ from accelerate import Accelerator
|
|
| 14 |
from accelerate.logging import get_logger
|
| 15 |
from datasets import load_from_disk, load_dataset
|
| 16 |
from diffusers import (
|
| 17 |
-
|
|
|
|
| 18 |
DDPMScheduler,
|
| 19 |
UNet2DModel,
|
| 20 |
DDIMScheduler,
|
| 21 |
AutoencoderKL,
|
| 22 |
)
|
| 23 |
-
from
|
| 24 |
from diffusers.optimization import get_scheduler
|
| 25 |
from diffusers.training_utils import EMAModel
|
| 26 |
from torchvision.transforms import (
|
|
@@ -32,12 +30,21 @@ import numpy as np
|
|
| 32 |
from tqdm.auto import tqdm
|
| 33 |
from librosa.util import normalize
|
| 34 |
|
| 35 |
-
#from diffusers import Mel, AudioDiffusionPipeline
|
| 36 |
-
from audiodiffusion import Mel, AudioDiffusionPipeline
|
| 37 |
-
|
| 38 |
logger = get_logger(__name__)
|
| 39 |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
def main(args):
|
| 42 |
output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir
|
| 43 |
logging_dir = os.path.join(output_dir, args.logging_dir)
|
|
@@ -94,8 +101,7 @@ def main(args):
|
|
| 94 |
try:
|
| 95 |
vqvae = AutoencoderKL.from_pretrained(args.vae)
|
| 96 |
except EnvironmentError:
|
| 97 |
-
vqvae = AudioDiffusionPipeline.from_pretrained(
|
| 98 |
-
args.vae).vqvae
|
| 99 |
# Determine latent resolution
|
| 100 |
with torch.no_grad():
|
| 101 |
latent_resolution = (vqvae.encode(
|
|
@@ -169,7 +175,12 @@ def main(args):
|
|
| 169 |
)
|
| 170 |
|
| 171 |
if args.push_to_hub:
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
if accelerator.is_main_process:
|
| 175 |
run = os.path.split(__file__)[-1].split(".")[0]
|
|
@@ -265,24 +276,17 @@ def main(args):
|
|
| 265 |
pipeline = AudioDiffusionPipeline(
|
| 266 |
vqvae=vqvae,
|
| 267 |
unet=accelerator.unwrap_model(
|
| 268 |
-
ema_model.averaged_model if args.use_ema else model
|
| 269 |
-
),
|
| 270 |
mel=mel,
|
| 271 |
scheduler=noise_scheduler,
|
| 272 |
)
|
|
|
|
| 273 |
|
| 274 |
# save the model
|
| 275 |
if args.push_to_hub:
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
pipeline,
|
| 280 |
-
repo,
|
| 281 |
-
commit_message=f"Epoch {epoch}",
|
| 282 |
-
blocking=False,
|
| 283 |
-
)
|
| 284 |
-
except NameError: # current version of diffusers has a little bug
|
| 285 |
-
pass
|
| 286 |
else:
|
| 287 |
pipeline.save_pretrained(output_dir)
|
| 288 |
|
|
@@ -290,11 +294,10 @@ def main(args):
|
|
| 290 |
generator = torch.Generator(
|
| 291 |
device=clean_images.device).manual_seed(42)
|
| 292 |
# run pipeline in inference (sample random noise and denoise)
|
| 293 |
-
images, (sample_rate,
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
)
|
| 298 |
|
| 299 |
# denormalize the images and save to tensorboard
|
| 300 |
images = np.array([
|
|
@@ -390,8 +393,5 @@ if __name__ == "__main__":
|
|
| 390 |
raise ValueError(
|
| 391 |
"You must specify either a dataset name from the hub or a train data directory."
|
| 392 |
)
|
| 393 |
-
if args.dataset_name is not None and args.dataset_name == args.hub_model_id:
|
| 394 |
-
raise ValueError(
|
| 395 |
-
"The local dataset name must be different from the hub model id.")
|
| 396 |
|
| 397 |
main(args)
|
|
|
|
| 1 |
# based on https://github.com/huggingface/diffusers/blob/main/examples/train_unconditional.py
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import argparse
|
| 4 |
import os
|
| 5 |
+
from typing import Optional
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.nn.functional as F
|
|
|
|
| 11 |
from accelerate.logging import get_logger
|
| 12 |
from datasets import load_from_disk, load_dataset
|
| 13 |
from diffusers import (
|
| 14 |
+
AudioDiffusionPipeline,
|
| 15 |
+
Mel,
|
| 16 |
DDPMScheduler,
|
| 17 |
UNet2DModel,
|
| 18 |
DDIMScheduler,
|
| 19 |
AutoencoderKL,
|
| 20 |
)
|
| 21 |
+
from huggingface_hub import HfFolder, Repository, whoami
|
| 22 |
from diffusers.optimization import get_scheduler
|
| 23 |
from diffusers.training_utils import EMAModel
|
| 24 |
from torchvision.transforms import (
|
|
|
|
| 30 |
from tqdm.auto import tqdm
|
| 31 |
from librosa.util import normalize
|
| 32 |
|
|
|
|
|
|
|
|
|
|
| 33 |
logger = get_logger(__name__)
|
| 34 |
|
| 35 |
|
| 36 |
+
def get_full_repo_name(model_id: str,
                       organization: Optional[str] = None,
                       token: Optional[str] = None) -> str:
    """Return the namespaced repository name ``<namespace>/<model_id>``.

    Args:
        model_id: Repository name without a namespace prefix.
        organization: Namespace to prepend. When ``None``, the namespace is
            the ``"name"`` field of the ``whoami(token)`` response.
        token: Token handed to ``whoami``. When ``None`` it is read with
            ``HfFolder.get_token()``. It is only consulted when
            ``organization`` is ``None``.

    Returns:
        The string ``f"{namespace}/{model_id}"``.
    """
    if organization is not None:
        # Namespace supplied explicitly: no token lookup or whoami call needed.
        return f"{organization}/{model_id}"

    if token is None:
        token = HfFolder.get_token()
    username = whoami(token)["name"]
    return f"{username}/{model_id}"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
def main(args):
|
| 49 |
output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir
|
| 50 |
logging_dir = os.path.join(output_dir, args.logging_dir)
|
|
|
|
| 101 |
try:
|
| 102 |
vqvae = AutoencoderKL.from_pretrained(args.vae)
|
| 103 |
except EnvironmentError:
|
| 104 |
+
vqvae = AudioDiffusionPipeline.from_pretrained(args.vae).vqvae
|
|
|
|
| 105 |
# Determine latent resolution
|
| 106 |
with torch.no_grad():
|
| 107 |
latent_resolution = (vqvae.encode(
|
|
|
|
| 175 |
)
|
| 176 |
|
| 177 |
if args.push_to_hub:
|
| 178 |
+
if args.hub_model_id is None:
|
| 179 |
+
repo_name = get_full_repo_name(Path(args.output_dir).name,
|
| 180 |
+
token=args.hub_token)
|
| 181 |
+
else:
|
| 182 |
+
repo_name = args.hub_model_id
|
| 183 |
+
repo = Repository(args.output_dir, clone_from=repo_name)
|
| 184 |
|
| 185 |
if accelerator.is_main_process:
|
| 186 |
run = os.path.split(__file__)[-1].split(".")[0]
|
|
|
|
| 276 |
pipeline = AudioDiffusionPipeline(
|
| 277 |
vqvae=vqvae,
|
| 278 |
unet=accelerator.unwrap_model(
|
| 279 |
+
ema_model.averaged_model if args.use_ema else model),
|
|
|
|
| 280 |
mel=mel,
|
| 281 |
scheduler=noise_scheduler,
|
| 282 |
)
|
| 283 |
+
pipeline.save_pretrained(args.output_dir)
|
| 284 |
|
| 285 |
# save the model
|
| 286 |
if args.push_to_hub:
|
| 287 |
+
repo.push_to_hub(commit_message=f"Epoch {epoch}",
|
| 288 |
+
blocking=False,
|
| 289 |
+
auto_lfs_prune=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
else:
|
| 291 |
pipeline.save_pretrained(output_dir)
|
| 292 |
|
|
|
|
| 294 |
generator = torch.Generator(
|
| 295 |
device=clean_images.device).manual_seed(42)
|
| 296 |
# run pipeline in inference (sample random noise and denoise)
|
| 297 |
+
images, (sample_rate,
|
| 298 |
+
audios) = pipeline(generator=generator,
|
| 299 |
+
batch_size=args.eval_batch_size,
|
| 300 |
+
return_dict=False)
|
|
|
|
| 301 |
|
| 302 |
# denormalize the images and save to tensorboard
|
| 303 |
images = np.array([
|
|
|
|
| 393 |
raise ValueError(
|
| 394 |
"You must specify either a dataset name from the hub or a train data directory."
|
| 395 |
)
|
|
|
|
|
|
|
|
|
|
| 396 |
|
| 397 |
main(args)
|