"""Factory functions for building models, pretransforms, and bottlenecks from config dicts."""

import json


def create_model_from_config(model_config):
    """Build a top-level model from a parsed config dict, dispatching on 'model_type'."""
    model_type = model_config.get('model_type', None)

    assert model_type is not None, 'model_type must be specified in model config'

    if model_type == 'autoencoder':
        from .autoencoders import create_autoencoder_from_config
        return create_autoencoder_from_config(model_config)
    elif model_type == 'diffusion_uncond':
        from .diffusion import create_diffusion_uncond_from_config
        return create_diffusion_uncond_from_config(model_config)
    elif model_type in ('diffusion_cond', 'diffusion_cond_inpaint', 'diffusion_prior'):
        from .diffusion import create_diffusion_cond_from_config
        return create_diffusion_cond_from_config(model_config)
    elif model_type == 'diffusion_autoencoder':
        from .autoencoders import create_diffAE_from_config
        return create_diffAE_from_config(model_config)
    elif model_type == 'lm':
        from .lm import create_audio_lm_from_config
        return create_audio_lm_from_config(model_config)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')
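
# Illustrative usage (a sketch: only "model_type" is read here; the remaining
# fields are assumptions and depend on the sub-factory for that model type):
#
#     model_config = {
#         "model_type": "diffusion_uncond",
#         "sample_rate": 44100,  # hypothetical field consumed by the sub-factory
#         "model": {...},        # architecture-specific block, elided
#     }
#     model = create_model_from_config(model_config)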

def create_model_from_config_path(model_config_path):
    """Load a JSON model config from disk and build the model it describes."""
    with open(model_config_path) as f:
        model_config = json.load(f)

    return create_model_from_config(model_config)
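
# Illustrative usage (hypothetical path; the file must contain the same JSON
# structure that create_model_from_config expects):
#
#     model = create_model_from_config_path("configs/model_config.json")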

def create_pretransform_from_config(pretransform_config, sample_rate):
    """Build a pretransform (e.g. a frozen latent encoder/decoder) from a config dict."""
    pretransform_type = pretransform_config.get('type', None)

    assert pretransform_type is not None, 'type must be specified in pretransform config'

    if pretransform_type == 'autoencoder':
        from .autoencoders import create_autoencoder_from_config
        from .pretransforms import AutoencoderPretransform

        # Wrap the pretransform's "config" block in a top-level autoencoder config
        # so the sample rate is threaded through to the autoencoder factory
        autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
        autoencoder = create_autoencoder_from_config(autoencoder_config)

        scale = pretransform_config.get("scale", 1.0)
        model_half = pretransform_config.get("model_half", False)
        iterate_batch = pretransform_config.get("iterate_batch", False)
        chunked = pretransform_config.get("chunked", False)

        pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
    elif pretransform_type == 'wavelet':
        from .pretransforms import WaveletPretransform

        wavelet_config = pretransform_config["config"]
        channels = wavelet_config["channels"]
        levels = wavelet_config["levels"]
        wavelet = wavelet_config["wavelet"]

        pretransform = WaveletPretransform(channels, levels, wavelet)
    elif pretransform_type == 'pqmf':
        from .pretransforms import PQMFPretransform
        pqmf_config = pretransform_config["config"]
        pretransform = PQMFPretransform(**pqmf_config)
    elif pretransform_type == 'dac_pretrained':
        from .pretransforms import PretrainedDACPretransform
        pretrained_dac_config = pretransform_config["config"]
        pretransform = PretrainedDACPretransform(**pretrained_dac_config)
    elif pretransform_type == "audiocraft_pretrained":
        from .pretransforms import AudiocraftCompressionPretransform

        audiocraft_config = pretransform_config["config"]
        pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
    else:
        raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')

    # Pretransforms default to frozen, eval-mode modules; set 'enable_grad' to train through them
    enable_grad = pretransform_config.get('enable_grad', False)
    pretransform.enable_grad = enable_grad

    pretransform.eval().requires_grad_(pretransform.enable_grad)

    return pretransform
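
# Illustrative pretransform config (a sketch: "type", "scale", "model_half",
# "iterate_batch", "chunked", and "enable_grad" are the keys read above; the
# nested "config" block is consumed by the chosen sub-factory and is elided):
#
#     pretransform_config = {
#         "type": "autoencoder",
#         "model_half": True,
#         "chunked": True,
#         "config": {...},  # autoencoder "model" block
#     }
#     pretransform = create_pretransform_from_config(pretransform_config, sample_rate=44100)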

def create_bottleneck_from_config(bottleneck_config):
    """Build an autoencoder bottleneck module from a config dict, dispatching on 'type'."""
    bottleneck_type = bottleneck_config.get('type', None)

    assert bottleneck_type is not None, 'type must be specified in bottleneck config'

    if bottleneck_type == 'tanh':
        from .bottleneck import TanhBottleneck
        bottleneck = TanhBottleneck()
    elif bottleneck_type == 'vae':
        from .bottleneck import VAEBottleneck
        bottleneck = VAEBottleneck()
    elif bottleneck_type in ('rvq', 'rvq_vae'):
        from .bottleneck import RVQBottleneck, RVQVAEBottleneck

        # Default residual VQ settings; values in the config block override them
        quantizer_params = {
            "dim": 128,
            "codebook_size": 1024,
            "num_quantizers": 8,
            "decay": 0.99,
            "kmeans_init": True,
            "kmeans_iters": 50,
            "threshold_ema_dead_code": 2,
        }

        quantizer_params.update(bottleneck_config["config"])

        if bottleneck_type == 'rvq':
            bottleneck = RVQBottleneck(**quantizer_params)
        else:
            bottleneck = RVQVAEBottleneck(**quantizer_params)
    elif bottleneck_type == 'dac_rvq':
        from .bottleneck import DACRVQBottleneck
        bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == 'dac_rvq_vae':
        from .bottleneck import DACRVQVAEBottleneck
        bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == 'l2_norm':
        from .bottleneck import L2Bottleneck
        bottleneck = L2Bottleneck()
    elif bottleneck_type == 'wasserstein':
        from .bottleneck import WassersteinBottleneck
        bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
    elif bottleneck_type == 'fsq':
        from .bottleneck import FSQBottleneck
        bottleneck = FSQBottleneck(**bottleneck_config["config"])
    else:
        raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')

    requires_grad = bottleneck_config.get('requires_grad', True)
    if not requires_grad:
        for param in bottleneck.parameters():
            param.requires_grad = False

    return bottleneck
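
# Illustrative bottleneck configs (sketches based on the branches above; the
# "config" contents for the dac_rvq/fsq variants depend on the corresponding
# bottleneck classes and are not spelled out here):
#
#     bottleneck = create_bottleneck_from_config({"type": "vae"})
#
#     bottleneck = create_bottleneck_from_config({
#         "type": "rvq",
#         "config": {"num_quantizers": 4},  # overrides the defaults above
#     })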