# NOTE: non-Python page-scrape residue removed here (was: "xjsc0's picture", "1", commit "64ec292")
import json
def create_model_from_config(model_config):
    """Build a model instance from a parsed top-level model config dict.

    The config must carry a ``model_type`` key; the matching sub-package
    factory is imported lazily so that only the code for the requested
    model family is loaded.

    Raises:
        AssertionError: if ``model_type`` is missing from the config.
        NotImplementedError: if ``model_type`` is not a known model family.
    """
    model_type = model_config.get("model_type", None)
    assert model_type is not None, "model_type must be specified in model config"

    # Resolve the family-specific factory, importing it only when needed.
    if model_type == "autoencoder":
        from .autoencoders import create_autoencoder_from_config

        factory = create_autoencoder_from_config
    elif model_type == "diffusion_uncond":
        from .diffusion import create_diffusion_uncond_from_config

        factory = create_diffusion_uncond_from_config
    elif model_type in ("diffusion_cond", "diffusion_cond_inpaint", "diffusion_prior"):
        # All three conditioned-diffusion variants share one factory.
        from .diffusion import create_diffusion_cond_from_config

        factory = create_diffusion_cond_from_config
    elif model_type == "diffusion_autoencoder":
        from .autoencoders import create_diffAE_from_config

        factory = create_diffAE_from_config
    elif model_type == "lm":
        from .lm import create_audio_lm_from_config

        factory = create_audio_lm_from_config
    else:
        raise NotImplementedError(f"Unknown model type: {model_type}")

    return factory(model_config)
def create_model_from_config_path(model_config_path):
    """Read a JSON model config from *model_config_path* and build the model it describes."""
    with open(model_config_path) as f:
        return create_model_from_config(json.load(f))
def create_pretransform_from_config(pretransform_config, sample_rate):
    """Build a pretransform module from its config dict.

    Args:
        pretransform_config: dict with a required ``type`` key plus a
            type-specific ``config`` sub-dict and optional flags.
        sample_rate: audio sample rate forwarded to autoencoder construction.

    Raises:
        AssertionError: if ``type`` is missing from the config.
        NotImplementedError: if ``type`` is not a known pretransform kind.
    """
    pretransform_type = pretransform_config.get("type", None)
    assert pretransform_type is not None, (
        "type must be specified in pretransform config"
    )

    if pretransform_type == "autoencoder":
        from .autoencoders import create_autoencoder_from_config
        from .pretransforms import AutoencoderPretransform

        # Wrap the inner config in a fake top-level config so the autoencoder
        # factory sees the sample rate without it being re-declared inside the
        # pretransform config (a hack, but it avoids duplication).
        autoencoder = create_autoencoder_from_config(
            {
                "sample_rate": sample_rate,
                "model": pretransform_config["config"],
            }
        )
        pretransform = AutoencoderPretransform(
            autoencoder,
            scale=pretransform_config.get("scale", 1.0),
            model_half=pretransform_config.get("model_half", False),
            iterate_batch=pretransform_config.get("iterate_batch", False),
            chunked=pretransform_config.get("chunked", False),
        )
    elif pretransform_type == "wavelet":
        from .pretransforms import WaveletPretransform

        wavelet_config = pretransform_config["config"]
        pretransform = WaveletPretransform(
            wavelet_config["channels"],
            wavelet_config["levels"],
            wavelet_config["wavelet"],
        )
    elif pretransform_type == "pqmf":
        from .pretransforms import PQMFPretransform

        pretransform = PQMFPretransform(**pretransform_config["config"])
    elif pretransform_type == "dac_pretrained":
        from .pretransforms import PretrainedDACPretransform

        pretransform = PretrainedDACPretransform(**pretransform_config["config"])
    elif pretransform_type == "audiocraft_pretrained":
        from .pretransforms import AudiocraftCompressionPretransform

        pretransform = AudiocraftCompressionPretransform(**pretransform_config["config"])
    else:
        raise NotImplementedError(f"Unknown pretransform type: {pretransform_type}")

    # Pretransforms are frozen in eval mode unless the config opts into gradients.
    pretransform.enable_grad = pretransform_config.get("enable_grad", False)
    pretransform.eval().requires_grad_(pretransform.enable_grad)

    return pretransform
def create_bottleneck_from_config(bottleneck_config):
    """Build a bottleneck module from its config dict.

    The config must carry a ``type`` key; most types also require a
    ``config`` sub-dict of constructor kwargs. Fixes the duplicated
    residual-VQ default-parameter table by sharing it between the
    "rvq" and "rvq_vae" branches.

    Raises:
        AssertionError: if ``type`` is missing from the config.
        NotImplementedError: if ``type`` is not a known bottleneck kind.
        KeyError: if a required ``config`` sub-dict is missing.
    """
    bottleneck_type = bottleneck_config.get("type", None)
    assert bottleneck_type is not None, "type must be specified in bottleneck config"

    def _rvq_params():
        # Shared defaults for both residual-VQ variants; keys from the user
        # config override these. Built fresh per call so the defaults are
        # never mutated across invocations.
        params = {
            "dim": 128,
            "codebook_size": 1024,
            "num_quantizers": 8,
            "decay": 0.99,
            "kmeans_init": True,
            "kmeans_iters": 50,
            "threshold_ema_dead_code": 2,
        }
        params.update(bottleneck_config["config"])
        return params

    if bottleneck_type == "tanh":
        from .bottleneck import TanhBottleneck

        bottleneck = TanhBottleneck()
    elif bottleneck_type == "vae":
        from .bottleneck import VAEBottleneck

        bottleneck = VAEBottleneck()
    elif bottleneck_type == "rvq":
        from .bottleneck import RVQBottleneck

        bottleneck = RVQBottleneck(**_rvq_params())
    elif bottleneck_type == "dac_rvq":
        from .bottleneck import DACRVQBottleneck

        bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == "rvq_vae":
        from .bottleneck import RVQVAEBottleneck

        bottleneck = RVQVAEBottleneck(**_rvq_params())
    elif bottleneck_type == "dac_rvq_vae":
        from .bottleneck import DACRVQVAEBottleneck

        bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == "l2_norm":
        from .bottleneck import L2Bottleneck

        bottleneck = L2Bottleneck()
    elif bottleneck_type == "wasserstein":
        from .bottleneck import WassersteinBottleneck

        # "config" is optional for this type only (preserves original behavior).
        bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
    elif bottleneck_type == "fsq":
        from .bottleneck import FSQBottleneck

        bottleneck = FSQBottleneck(**bottleneck_config["config"])
    else:
        raise NotImplementedError(f"Unknown bottleneck type: {bottleneck_type}")

    # Optionally freeze the bottleneck's parameters (default: trainable).
    requires_grad = bottleneck_config.get("requires_grad", True)
    if not requires_grad:
        for param in bottleneck.parameters():
            param.requires_grad = False

    return bottleneck