| | import importlib |
| | import re |
| |
|
| | from coqpit import Coqpit |
| |
|
| |
|
def to_camel(text):
    """Convert a snake_case identifier into CamelCase.

    The whole string is first passed through ``str.capitalize`` (first
    character upper-cased, the rest lower-cased); then every underscore that
    is not at position 0 and is followed by an ASCII letter is dropped and
    that letter is upper-cased. Other underscores are left in place.

    Example: ``"hifigan_generator"`` -> ``"HifiganGenerator"``.
    """

    def _upper_following(match):
        # group(1) is the single letter that came after the underscore.
        return match.group(1).upper()

    capitalized = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", _upper_following, capitalized)
| |
|
| |
|
def setup_model(config: Coqpit):
    """Load the vocoder model class selected by ``config`` and build it.

    When ``config`` contains both ``discriminator_model`` and ``generator_model``
    the ``GAN`` class from ``TTS.vocoder.models.gan`` is used. Otherwise the
    module ``TTS.vocoder.models.<config.model lower-cased>`` is imported and
    the class is looked up by name: ``Wavernn``, ``GAN`` and ``Wavegrad`` for
    the ``wavernn``, ``gan`` and ``wavegrad`` models, and
    ``to_camel(config.model)`` for everything else.

    Args:
        config: vocoder configuration.

    Returns:
        The result of ``init_from_config(config)`` on the selected class.

    Raises:
        ValueError: if no module exists for ``config.model`` or the module
            does not define the expected class.
    """
    if "discriminator_model" in config and "generator_model" in config:
        MyModel = getattr(importlib.import_module("TTS.vocoder.models.gan"), "GAN")
    else:
        model_name = config.model.lower()
        module_path = "TTS.vocoder.models." + model_name
        try:
            module = importlib.import_module(module_path)
        except ModuleNotFoundError as e:
            # Translate only "this model's module does not exist". A missing
            # parent package or a failed import *inside* the model module is
            # re-raised untouched so the real cause stays visible.
            if e.name != module_path:
                raise
            raise ValueError(f"Model {config.model} not exist!") from e
        # Known model names map to fixed class names; everything else is
        # derived from the configured name with to_camel.
        class_name = {"wavernn": "Wavernn", "gan": "GAN", "wavegrad": "Wavegrad"}.get(model_name)
        if class_name is None:
            class_name = to_camel(config.model)
        try:
            MyModel = getattr(module, class_name)
        except AttributeError as e:
            raise ValueError(f"Model {config.model} not exist!") from e
    print(" > Vocoder Model: {}".format(config.model))
    return MyModel.init_from_config(config)
| |
|
| |
|
def setup_generator(c):
    """Build the GAN vocoder generator selected by ``c.generator_model``.

    The class is imported from ``TTS.vocoder.models.<name>`` (``<name>`` being
    ``c.generator_model`` lower-cased) and looked up as
    ``to_camel(c.generator_model)``. Model names are matched exactly and
    case-insensitively.

    TODO: use config object as arguments

    Args:
        c: config exposing ``generator_model``, and, depending on the model,
            ``generator_model_params`` and ``audio["num_mels"]``.

    Returns:
        The instantiated generator.

    Raises:
        ValueError: for the retired name ``melgan_fb_generator``.
        NotImplementedError: if the name is not a supported generator.
    """
    print(" > Generator Model: {}".format(c.generator_model))
    name = c.generator_model.lower()
    # Checked before the import so the rename hint is reported even when no
    # module with the old name exists.
    if name == "melgan_fb_generator":
        raise ValueError("melgan_fb_generator is now fullband_melgan_generator")

    MyModel = importlib.import_module("TTS.vocoder.models." + name)
    MyModel = getattr(MyModel, to_camel(c.generator_model))

    if name == "hifigan_generator":
        return MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params)
    if name in ("melgan_generator", "fullband_melgan_generator"):
        # Both variants are built with the same arguments.
        return MyModel(
            in_channels=c.audio["num_mels"],
            out_channels=1,
            proj_kernel=7,
            base_channels=512,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )
    if name == "multiband_melgan_generator":
        return MyModel(
            in_channels=c.audio["num_mels"],
            out_channels=4,
            proj_kernel=7,
            base_channels=384,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )
    if name == "parallel_wavegan_generator":
        return MyModel(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
            stacks=c.generator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            aux_channels=c.audio["num_mels"],
            dropout=0.0,
            bias=True,
            use_weight_norm=True,
            upsample_factors=c.generator_model_params["upsample_factors"],
        )
    if name == "univnet_generator":
        return MyModel(**c.generator_model_params)
    raise NotImplementedError(f"Model {c.generator_model} not implemented!")
| |
|
| |
|
def setup_discriminator(c):
    """Build the GAN vocoder discriminator selected by ``c.discriminator_model``.

    The class is looked up as ``to_camel(<name>)``, where ``<name>`` is
    ``c.discriminator_model`` lower-cased. Names containing
    ``parallel_wavegan`` (plain and residual variants) are loaded from
    ``TTS.vocoder.models.parallel_wavegan_discriminator``; every other name
    from ``TTS.vocoder.models.<name>``. Model names are matched exactly and
    case-insensitively.

    TODO: use config object as arguments

    Args:
        c: config exposing ``discriminator_model``, and, depending on the
            model, ``discriminator_model_params`` and ``audio`` entries.

    Returns:
        The instantiated discriminator.

    Raises:
        NotImplementedError: if the class imports but the name is not one of
            the discriminators handled here.
    """
    print(" > Discriminator Model: {}".format(c.discriminator_model))
    name = c.discriminator_model.lower()
    module_name = "parallel_wavegan_discriminator" if "parallel_wavegan" in name else name
    MyModel = importlib.import_module("TTS.vocoder.models." + module_name)
    MyModel = getattr(MyModel, to_camel(name))

    if name in ("hifigan_discriminator", "univnet_discriminator"):
        return MyModel()
    if name == "random_window_discriminator":
        # Keyword and config key names (including "donwsample") are kept
        # verbatim; they must match the constructor signature and configs.
        return MyModel(
            cond_channels=c.audio["num_mels"],
            hop_length=c.audio["hop_length"],
            uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"],
            cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"],
            cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"],
            window_sizes=c.discriminator_model_params["window_sizes"],
        )
    if name == "melgan_multiscale_discriminator":
        return MyModel(
            in_channels=1,
            out_channels=1,
            kernel_sizes=(5, 3),
            base_channels=c.discriminator_model_params["base_channels"],
            max_channels=c.discriminator_model_params["max_channels"],
            downsample_factors=c.discriminator_model_params["downsample_factors"],
        )
    if name == "residual_parallel_wavegan_discriminator":
        return MyModel(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=c.discriminator_model_params["num_layers"],
            stacks=c.discriminator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            dropout=0.0,
            bias=True,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
        )
    if name == "parallel_wavegan_discriminator":
        return MyModel(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=c.discriminator_model_params["num_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
        )
    # Previously an unhandled name fell through to `return model` with
    # `model` unbound; fail explicitly instead, as setup_generator does.
    raise NotImplementedError(f"Model {c.discriminator_model} not implemented!")
| |
|