| import argparse |
| import json |
| import torch |
| from torch.nn.parameter import Parameter |
| from stable_audio_tools.models import create_model_from_config |
|
|
def _create_ema_copy(model, model_config):
    """Build a fresh model from *model_config* and copy *model*'s weights into it.

    The copy seeds the EMA (exponential-moving-average) weights that some
    training wrappers expect when EMA was enabled during training.
    """
    ema_copy = create_model_from_config(model_config)

    for name, param in model.state_dict().items():
        if isinstance(param, Parameter):
            # Unwrap nn.Parameter so copy_ operates on the underlying tensor.
            param = param.data
        ema_copy.state_dict()[name].copy_(param)

    return ema_copy


def main():
    """Load a training checkpoint into its matching training wrapper and
    export the inference weights to ``<name>.ckpt`` or ``<name>.safetensors``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-config', type=str, default=None)
    parser.add_argument('--ckpt-path', type=str, default=None)
    parser.add_argument('--name', type=str, default='exported_model')
    parser.add_argument('--use-safetensors', action='store_true')
    args = parser.parse_args()

    with open(args.model_config) as f:
        model_config = json.load(f)

    model = create_model_from_config(model_config)

    model_type = model_config.get('model_type', None)

    assert model_type is not None, 'model_type must be specified in model config'

    # NOTE(review): most branches below assume a 'training' section exists in
    # the config; a missing section raises AttributeError — confirm intended.
    training_config = model_config.get('training', None)

    # Training wrappers are imported lazily per model type so that only the
    # dependencies for the selected model are loaded.
    if model_type == 'autoencoder':
        from stable_audio_tools.training.autoencoders import AutoencoderTrainingWrapper

        use_ema = training_config.get("use_ema", False)

        # Build the EMA copy once, only when EMA weights were tracked.
        # (Original code constructed the copy twice and discarded the first.)
        ema_copy = _create_ema_copy(model, model_config) if use_ema else None

        training_wrapper = AutoencoderTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            autoencoder=model,
            strict=False,
            loss_config=training_config["loss_configs"],
            use_ema=use_ema,
            ema_copy=ema_copy
        )
    elif model_type == 'diffusion_uncond':
        from stable_audio_tools.training.diffusion import DiffusionUncondTrainingWrapper

        training_wrapper = DiffusionUncondTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False)
    elif model_type == 'diffusion_autoencoder':
        from stable_audio_tools.training.diffusion import DiffusionAutoencoderTrainingWrapper

        ema_copy = _create_ema_copy(model, model_config)

        training_wrapper = DiffusionAutoencoderTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, ema_copy=ema_copy, strict=False)
    elif model_type == 'diffusion_cond':
        from stable_audio_tools.training.diffusion import DiffusionCondTrainingWrapper

        # Conditioned diffusion defaults EMA to True, unlike the other types.
        use_ema = training_config.get("use_ema", True)

        training_wrapper = DiffusionCondTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            use_ema=use_ema,
            lr=training_config.get("learning_rate", None),
            optimizer_configs=training_config.get("optimizer_configs", None),
            strict=False
        )
    elif model_type == 'diffusion_cond_inpaint':
        from stable_audio_tools.training.diffusion import DiffusionCondInpaintTrainingWrapper

        training_wrapper = DiffusionCondInpaintTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False)
    elif model_type == 'diffusion_prior':
        from stable_audio_tools.training.diffusion import DiffusionPriorTrainingWrapper

        ema_copy = _create_ema_copy(model, model_config)

        training_wrapper = DiffusionPriorTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False, ema_copy=ema_copy)
    elif model_type == 'lm':
        from stable_audio_tools.training.lm import AudioLanguageModelTrainingWrapper

        ema_copy = None

        if training_config.get("use_ema", False):
            ema_copy = _create_ema_copy(model, model_config)

        training_wrapper = AudioLanguageModelTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            strict=False,
            ema_copy=ema_copy,
            optimizer_configs=training_config.get("optimizer_configs", None)
        )
    else:
        raise ValueError(f"Unknown model type {model_type}")

    print(f"Loaded model from {args.ckpt_path}")

    ckpt_path = f"{args.name}.safetensors" if args.use_safetensors else f"{args.name}.ckpt"

    training_wrapper.export_model(ckpt_path, use_safetensors=args.use_safetensors)

    print(f"Exported model to {ckpt_path}")


if __name__ == '__main__':
    main()