import argparse
import json

import torch
from torch.nn.parameter import Parameter
from stable_audio_tools.models import create_model_from_config


def parse_args(argv=None):
    """Parse the command-line options of the export script.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse
            read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``model_config``, ``ckpt_path``,
        ``name`` and ``use_safetensors``.
    """
    parser = argparse.ArgumentParser()
    # Both paths are mandatory: without them the script previously failed
    # later with an opaque error from ``open(None)``.
    parser.add_argument('--model-config', type=str, required=True)
    parser.add_argument('--ckpt-path', type=str, required=True)
    parser.add_argument('--name', type=str, default='exported_model')
    parser.add_argument('--use-safetensors', action='store_true')
    return parser.parse_args(argv)


def load_model_config(path):
    """Read the JSON model config stored at *path* and return it as a dict."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def require_model_type(model_config):
    """Return ``model_config['model_type']``.

    Raises:
        ValueError: If the key is missing or ``None``. An explicit raise is
            used instead of ``assert`` so the check survives ``python -O``.
    """
    model_type = model_config.get('model_type', None)
    if model_type is None:
        raise ValueError('model_type must be specified in model config')
    return model_type


def copy_weights_to_ema(model, ema_copy):
    """Copy every ``state_dict`` entry of *model* into *ema_copy* in place.

    Entries are matched by name; each tensor in ``ema_copy.state_dict()`` is
    overwritten with ``copy_`` from the same-named tensor of *model*.
    """
    ema_state = ema_copy.state_dict()
    for name, tensor in model.state_dict().items():
        if isinstance(tensor, Parameter):
            # Unwrap Parameters to their raw tensor before copying.
            tensor = tensor.data
        ema_state[name].copy_(tensor)


def build_ema_copy(model, model_config):
    """Create a second model from *model_config* holding *model*'s weights."""
    ema_copy = create_model_from_config(model_config)
    copy_weights_to_ema(model, ema_copy)
    return ema_copy


def load_training_wrapper(model_type, model, model_config, training_config, ckpt_path):
    """Restore the training wrapper matching *model_type* from *ckpt_path*.

    Args:
        model_type: Value of ``model_type`` in the model config.
        model: Model built from *model_config*.
        model_config: Full model config dict (used to build EMA copies).
        training_config: The ``training`` section of the config; may be an
            empty dict when the config has no such section.
        ckpt_path: Path of the training checkpoint to load.

    Returns:
        The object returned by the wrapper class's ``load_from_checkpoint``.

    Raises:
        ValueError: If *model_type* is not one of the handled types, or if an
            autoencoder config lacks ``training.loss_configs``.
    """
    if model_type == 'autoencoder':
        from stable_audio_tools.training.autoencoders import AutoencoderTrainingWrapper

        if "loss_configs" not in training_config:
            raise ValueError("autoencoder export requires 'loss_configs' in the training config")

        # Read once with a default so a config without "use_ema" no longer
        # raises KeyError when the flag is forwarded to the wrapper.
        use_ema = training_config.get("use_ema", False)

        ema_copy = None
        if use_ema:
            # NOTE(review): the original script built the EMA model twice in
            # this branch only, discarding the first instance; every other
            # branch builds it once. The extra build is kept because nothing
            # visible here shows it is safe to drop - confirm and remove.
            create_model_from_config(model_config)
            ema_copy = build_ema_copy(model, model_config)

        return AutoencoderTrainingWrapper.load_from_checkpoint(
            ckpt_path,
            autoencoder=model,
            strict=False,
            loss_config=training_config["loss_configs"],
            use_ema=use_ema,
            ema_copy=ema_copy,
        )

    if model_type == 'diffusion_uncond':
        from stable_audio_tools.training.diffusion import DiffusionUncondTrainingWrapper

        return DiffusionUncondTrainingWrapper.load_from_checkpoint(ckpt_path, model=model, strict=False)

    if model_type == 'diffusion_autoencoder':
        from stable_audio_tools.training.diffusion import DiffusionAutoencoderTrainingWrapper

        ema_copy = build_ema_copy(model, model_config)
        return DiffusionAutoencoderTrainingWrapper.load_from_checkpoint(
            ckpt_path, model=model, ema_copy=ema_copy, strict=False
        )

    if model_type == 'diffusion_cond':
        from stable_audio_tools.training.diffusion import DiffusionCondTrainingWrapper

        # Unlike the other branches, EMA defaults to enabled here.
        use_ema = training_config.get("use_ema", True)

        return DiffusionCondTrainingWrapper.load_from_checkpoint(
            ckpt_path,
            model=model,
            use_ema=use_ema,
            lr=training_config.get("learning_rate", None),
            optimizer_configs=training_config.get("optimizer_configs", None),
            strict=False,
        )

    if model_type == 'diffusion_cond_inpaint':
        from stable_audio_tools.training.diffusion import DiffusionCondInpaintTrainingWrapper

        return DiffusionCondInpaintTrainingWrapper.load_from_checkpoint(ckpt_path, model=model, strict=False)

    if model_type == 'diffusion_prior':
        from stable_audio_tools.training.diffusion import DiffusionPriorTrainingWrapper

        ema_copy = build_ema_copy(model, model_config)
        return DiffusionPriorTrainingWrapper.load_from_checkpoint(
            ckpt_path, model=model, strict=False, ema_copy=ema_copy
        )

    if model_type == 'lm':
        from stable_audio_tools.training.lm import AudioLanguageModelTrainingWrapper

        ema_copy = None
        if training_config.get("use_ema", False):
            ema_copy = build_ema_copy(model, model_config)

        return AudioLanguageModelTrainingWrapper.load_from_checkpoint(
            ckpt_path,
            model=model,
            strict=False,
            ema_copy=ema_copy,
            optimizer_configs=training_config.get("optimizer_configs", None),
        )

    raise ValueError(f"Unknown model type {model_type}")


def resolve_export_path(name, use_safetensors):
    """Return the output file name: ``<name>.safetensors`` or ``<name>.ckpt``."""
    if use_safetensors:
        return f"{name}.safetensors"
    return f"{name}.ckpt"


def main(argv=None):
    """Load a training checkpoint and export the wrapped model to a file.

    Reads the model config, builds the model, restores the matching training
    wrapper from ``--ckpt-path`` and calls its ``export_model`` with the
    output path derived from ``--name`` and ``--use-safetensors``.
    """
    args = parse_args(argv)

    model_config = load_model_config(args.model_config)
    # Validate before building the model so a bad config fails fast.
    model_type = require_model_type(model_config)

    model = create_model_from_config(model_config)

    # A config without a "training" section yields an empty dict so the
    # optional ``.get`` lookups fall back to their defaults instead of
    # raising AttributeError on None.
    training_config = model_config.get('training', None) or {}

    training_wrapper = load_training_wrapper(
        model_type, model, model_config, training_config, args.ckpt_path
    )

    print(f"Loaded model from {args.ckpt_path}")

    ckpt_path = resolve_export_path(args.name, args.use_safetensors)

    training_wrapper.export_model(ckpt_path, use_safetensors=args.use_safetensors)

    print(f"Exported model to {ckpt_path}")


if __name__ == '__main__':
    main()