| | import torch |
| | from torch.nn import Parameter |
| | from ..models.factory import create_model_from_config |
| |
|
def _create_ema_copy(model_config, model):
    """Build a fresh model from ``model_config`` and copy ``model``'s weights into it.

    Args:
        model_config: Model configuration passed to ``create_model_from_config``.
        model: Module whose ``state_dict()`` entries are copied into the new model.

    Returns:
        The newly created model, holding the same state values as ``model``.
    """
    ema_copy = create_model_from_config(model_config)

    # Build the destination state dict once instead of once per entry; the
    # copies below are in place, so they still land in ``ema_copy``.
    ema_state = ema_copy.state_dict()

    for name, param in model.state_dict().items():
        if isinstance(param, Parameter):
            # Unwrap Parameter objects to their underlying tensor before copying.
            param = param.data
        ema_state[name].copy_(param)

    return ema_copy


def create_training_wrapper_from_config(model_config, model):
    """Create the training wrapper that matches ``model_config['model_type']``.

    Supported model types are ``'autoencoder'``, ``'diffusion_uncond'``,
    ``'diffusion_cond'``, ``'diffusion_prior'``, ``'diffusion_cond_inpaint'``,
    ``'diffusion_autoencoder'`` and ``'lm'``. The wrapper class for the selected
    type is imported lazily, so only the chosen branch's module is loaded.

    Args:
        model_config: Mapping with at least ``'model_type'`` and ``'training'``
            entries. The ``'training'`` mapping supplies the wrapper options
            (for example ``learning_rate``, ``use_ema``, ``optimizer_configs``).
        model: The model instance to wrap.

    Returns:
        The training wrapper instance for the requested model type.

    Raises:
        AssertionError: If ``'model_type'`` or ``'training'`` is missing.
        KeyError: If a required entry (such as ``learning_rate`` or
            ``sample_rate`` for some model types) is absent.
        ValueError: If a teacher model is configured without
            ``teacher_model_ckpt``, or if ``prior_type`` is not recognised.
        NotImplementedError: If ``model_type`` is not one of the supported types.
    """
    model_type = model_config.get('model_type', None)
    assert model_type is not None, 'model_type must be specified in model config'

    training_config = model_config.get('training', None)
    assert training_config is not None, 'training config must be specified in model config'

    if model_type == 'autoencoder':
        from .autoencoders import AutoencoderTrainingWrapper

        use_ema = training_config.get("use_ema", False)

        ema_copy = None
        if use_ema:
            # NOTE(review): this branch has always built the model twice and
            # discarded the first instance. It looks redundant, but it is kept
            # in case construction has side effects this path depends on --
            # confirm before removing.
            create_model_from_config(model_config)
            ema_copy = _create_ema_copy(model_config, model)

        latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0)

        teacher_model = training_config.get("teacher_model", None)
        if teacher_model is not None:
            # Validate before building the teacher so a bad config fails fast.
            teacher_model_ckpt = training_config.get("teacher_model_ckpt", None)
            if teacher_model_ckpt is None:
                raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified")

            teacher_model = create_model_from_config(teacher_model)
            teacher_model = teacher_model.eval().requires_grad_(False)

            # Load onto CPU so a checkpoint saved on an accelerator can be
            # restored on any host; load_state_dict copies the values into the
            # teacher's own tensors.
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            checkpoint = torch.load(teacher_model_ckpt, map_location="cpu")
            teacher_model.load_state_dict(checkpoint["state_dict"])

        return AutoencoderTrainingWrapper(
            model,
            lr=training_config["learning_rate"],
            warmup_steps=training_config.get("warmup_steps", 0),
            encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False),
            sample_rate=model_config["sample_rate"],
            loss_config=training_config.get("loss_configs", None),
            optimizer_configs=training_config.get("optimizer_configs", None),
            use_ema=use_ema,
            ema_copy=ema_copy,
            force_input_mono=training_config.get("force_input_mono", False),
            latent_mask_ratio=latent_mask_ratio,
            teacher_model=teacher_model,
        )
    elif model_type == 'diffusion_uncond':
        from .diffusion import DiffusionUncondTrainingWrapper

        return DiffusionUncondTrainingWrapper(
            model,
            lr=training_config["learning_rate"],
            pre_encoded=training_config.get("pre_encoded", False),
        )
    elif model_type == 'diffusion_cond':
        from .diffusion import DiffusionCondTrainingWrapper

        return DiffusionCondTrainingWrapper(
            model,
            lr=training_config.get("learning_rate", None),
            mask_padding=training_config.get("mask_padding", False),
            mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0),
            use_ema=training_config.get("use_ema", True),
            log_loss_info=training_config.get("log_loss_info", False),
            optimizer_configs=training_config.get("optimizer_configs", None),
            pre_encoded=training_config.get("pre_encoded", False),
            cfg_dropout_prob=training_config.get("cfg_dropout_prob", 0.1),
            timestep_sampler=training_config.get("timestep_sampler", "uniform"),
        )
    elif model_type == 'diffusion_prior':
        from .diffusion import DiffusionPriorTrainingWrapper
        from ..models.diffusion_prior import PriorType

        ema_copy = _create_ema_copy(model_config, model)

        prior_type = training_config.get("prior_type", "mono_stereo")

        if prior_type == "mono_stereo":
            prior_type_enum = PriorType.MonoToStereo
        else:
            raise ValueError(f"Unknown prior type: {prior_type}")

        return DiffusionPriorTrainingWrapper(
            model,
            lr=training_config["learning_rate"],
            ema_copy=ema_copy,
            prior_type=prior_type_enum,
            log_loss_info=training_config.get("log_loss_info", False),
            use_reconstruction_loss=training_config.get("use_reconstruction_loss", False),
        )
    elif model_type == 'diffusion_cond_inpaint':
        from .diffusion import DiffusionCondInpaintTrainingWrapper

        return DiffusionCondInpaintTrainingWrapper(
            model,
            lr=training_config.get("learning_rate", None),
            max_mask_segments=training_config.get("max_mask_segments", 10),
            log_loss_info=training_config.get("log_loss_info", False),
            optimizer_configs=training_config.get("optimizer_configs", None),
            use_ema=training_config.get("use_ema", True),
            pre_encoded=training_config.get("pre_encoded", False),
            cfg_dropout_prob=training_config.get("cfg_dropout_prob", 0.1),
            timestep_sampler=training_config.get("timestep_sampler", "uniform"),
        )
    elif model_type == 'diffusion_autoencoder':
        from .diffusion import DiffusionAutoencoderTrainingWrapper

        ema_copy = _create_ema_copy(model_config, model)

        return DiffusionAutoencoderTrainingWrapper(
            model,
            ema_copy=ema_copy,
            lr=training_config["learning_rate"],
            use_reconstruction_loss=training_config.get("use_reconstruction_loss", False),
        )
    elif model_type == 'lm':
        from .lm import AudioLanguageModelTrainingWrapper

        # The EMA copy is built regardless of ``use_ema``; the wrapper receives
        # both values.
        ema_copy = _create_ema_copy(model_config, model)

        return AudioLanguageModelTrainingWrapper(
            model,
            ema_copy=ema_copy,
            lr=training_config.get("learning_rate", None),
            use_ema=training_config.get("use_ema", False),
            optimizer_configs=training_config.get("optimizer_configs", None),
            pre_encoded=training_config.get("pre_encoded", False),
        )
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')
| |
|
def create_demo_callback_from_config(model_config, **kwargs):
    """Create the demo callback that matches ``model_config['model_type']``.

    Demo options are read from ``model_config['training']['demo']`` (an empty
    mapping when absent). The callback class is imported lazily inside the
    branch that needs it.

    Args:
        model_config: Mapping with at least ``'model_type'`` and ``'training'``
            entries; most model types also read ``'sample_size'`` and
            ``'sample_rate'`` from it.
        **kwargs: Extra keyword arguments forwarded to the callback for every
            model type except ``'diffusion_uncond'`` and ``'diffusion_cond'``,
            which do not receive them.

    Returns:
        The demo callback instance for the requested model type.

    Raises:
        AssertionError: If ``'model_type'`` or ``'training'`` is missing.
        KeyError: If an entry required by the selected branch is absent.
        NotImplementedError: If ``model_type`` is not one of the supported types.
    """
    model_type = model_config.get('model_type')
    assert model_type is not None, 'model_type must be specified in model config'

    train_cfg = model_config.get('training')
    assert train_cfg is not None, 'training config must be specified in model config'

    demo_cfg = train_cfg.get("demo", {})

    # Each supported type is handled by its own guard clause that returns
    # immediately; anything that falls through is unsupported.
    if model_type == 'autoencoder':
        from .autoencoders import AutoencoderDemoCallback

        return AutoencoderDemoCallback(
            demo_every=demo_cfg.get("demo_every", 2000),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            **kwargs,
        )

    if model_type == 'diffusion_uncond':
        from .diffusion import DiffusionUncondDemoCallback

        return DiffusionUncondDemoCallback(
            demo_every=demo_cfg.get("demo_every", 2000),
            demo_steps=demo_cfg.get("demo_steps", 250),
            sample_rate=model_config["sample_rate"],
        )

    if model_type == "diffusion_autoencoder":
        from .diffusion import DiffusionAutoencoderDemoCallback

        return DiffusionAutoencoderDemoCallback(
            demo_every=demo_cfg.get("demo_every", 2000),
            demo_steps=demo_cfg.get("demo_steps", 250),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            **kwargs,
        )

    if model_type == "diffusion_prior":
        from .diffusion import DiffusionPriorDemoCallback

        return DiffusionPriorDemoCallback(
            demo_every=demo_cfg.get("demo_every", 2000),
            demo_steps=demo_cfg.get("demo_steps", 250),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            **kwargs,
        )

    if model_type == "diffusion_cond":
        from .diffusion import DiffusionCondDemoCallback

        # ``num_demos`` and ``demo_cfg_scales`` have no default in this branch.
        return DiffusionCondDemoCallback(
            demo_every=demo_cfg.get("demo_every", 2000),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            demo_steps=demo_cfg.get("demo_steps", 250),
            num_demos=demo_cfg["num_demos"],
            demo_cfg_scales=demo_cfg["demo_cfg_scales"],
            demo_conditioning=demo_cfg.get("demo_cond", {}),
            demo_cond_from_batch=demo_cfg.get("demo_cond_from_batch", False),
            display_audio_cond=demo_cfg.get("display_audio_cond", False),
        )

    if model_type == "diffusion_cond_inpaint":
        from .diffusion import DiffusionCondInpaintDemoCallback

        # ``demo_cfg_scales`` has no default in this branch.
        return DiffusionCondInpaintDemoCallback(
            demo_every=demo_cfg.get("demo_every", 2000),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            demo_steps=demo_cfg.get("demo_steps", 250),
            demo_cfg_scales=demo_cfg["demo_cfg_scales"],
            **kwargs,
        )

    if model_type == "lm":
        from .lm import AudioLanguageModelDemoCallback

        return AudioLanguageModelDemoCallback(
            demo_every=demo_cfg.get("demo_every", 2000),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            demo_cfg_scales=demo_cfg.get("demo_cfg_scales", [1]),
            demo_conditioning=demo_cfg.get("demo_cond", None),
            num_demos=demo_cfg.get("num_demos", 8),
            **kwargs,
        )

    raise NotImplementedError(f'Unknown model type: {model_type}')