| | import json |
| | import os |
| | import re |
| | from typing import Dict |
| |
|
| | import fsspec |
| | import yaml |
| | from coqpit import Coqpit |
| |
|
| | from TTS.config.shared_configs import * |
| | from TTS.utils.generic_utils import find_module |
| |
|
| |
|
def read_json_with_comments(json_path):
    """Load a JSON file that may contain ``//`` line comments (for backward compat.).

    The file is read through ``fsspec.open``, backslash-newline line
    continuations are joined, ``//`` comments are removed and the remaining
    text is parsed with ``json.loads``.

    ``//`` sequences inside double-quoted string values (e.g. ``"https://host"``)
    are preserved, and a comment on the last line is removed even when the file
    has no trailing newline.

    Args:
        json_path (str): Path of the JSON file, as accepted by ``fsspec.open``.

    Raises:
        json.JSONDecodeError: The text is still not valid JSON after the
            comments are removed.

    Returns:
        The parsed JSON data.
    """
    with fsspec.open(json_path, "r", encoding="utf-8") as f:
        input_str = f.read()

    # Join lines that end with a backslash continuation.
    input_str = re.sub(r"\\\n", "", input_str)

    # String literals are matched first (and put back unchanged) so that a `//`
    # inside a value is never mistaken for the start of a comment. Anything
    # matched by the second alternative is a real comment and is dropped.
    input_str = re.sub(
        r'("(?:[^"\\\n]|\\.)*")|//[^\n]*',
        lambda match: match.group(1) or "",
        input_str,
    )
    return json.loads(input_str)
| |
|
| |
|
def register_config(model_name: str) -> Coqpit:
    """Look up the config class that belongs to a model name.

    The config module is expected to be called ``<model_name>_config`` and is
    searched for in every known config package. The search does not stop at the
    first hit, so the last package that provides the module wins.

    Args:
        model_name (str): Model name.

    Raises:
        ModuleNotFoundError: No matching config for the model name.

    Returns:
        Coqpit: config class.
    """
    config_name = model_name + "_config"
    resolved = None

    # Special case for xtts. Note that the package search below still runs and
    # replaces this value whenever it finds a matching module.
    if model_name == "xtts":
        from TTS.tts.configs.xtts_config import XttsConfig

        resolved = XttsConfig

    search_packages = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs", "TTS.vc.configs"]
    for package in search_packages:
        try:
            candidate = find_module(package, config_name)
        except ModuleNotFoundError:
            # This package has no such config module; try the next one.
            continue
        resolved = candidate

    if resolved is not None:
        return resolved
    raise ModuleNotFoundError(f" [!] Config for {model_name} cannot be found.")
| |
|
| |
|
| | def _process_model_name(config_dict: Dict) -> str: |
| | """Format the model name as expected. It is a band-aid for the old `vocoder` model names. |
| | |
| | Args: |
| | config_dict (Dict): A dictionary including the config fields. |
| | |
| | Returns: |
| | str: Formatted modelname. |
| | """ |
| | model_name = config_dict["model"] if "model" in config_dict else config_dict["generator_model"] |
| | model_name = model_name.replace("_generator", "").replace("_discriminator", "") |
| | return model_name |
| |
|
| |
|
def load_config(config_path: str) -> Coqpit:
    """Build a TTS config object from a `json` or `yaml` file.

    The file is parsed into a `dict`, the model name found in it selects the
    matching Config class, and an instance of that class is populated from the
    parsed values.

    Args:
        config_path (str): path to the config file.

    Raises:
        TypeError: given config file has an unknown type.

    Returns:
        Coqpit: TTS config object.
    """
    _, ext = os.path.splitext(config_path)

    if ext == ".json":
        try:
            with fsspec.open(config_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except json.decoder.JSONDecodeError:
            # Backward compatibility: retry with the comment-tolerant reader.
            data = read_json_with_comments(config_path)
    elif ext in (".yml", ".yaml"):
        with fsspec.open(config_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    else:
        raise TypeError(f" [!] Unknown config file type {ext}")

    # Work on a plain dict copy of whatever the parser returned.
    config_dict = dict(data)

    model_name = _process_model_name(config_dict)
    config_class = register_config(model_name.lower())

    config = config_class()
    config.from_dict(config_dict)
    return config
| |
|
| |
|
def check_config_and_model_args(config, arg_name, value):
    """Tell whether an argument of the config equals the given value.

    The argument is looked up in `config.model_args` first and, when it is not
    there, on `config` itself.

    Return False if the argument does not exist in `config.model_args` or `config`.
    This is to patch up the compatibility between models with and without `model_args`.

    TODO: Remove this in the future with a unified approach.
    """
    # `model_args` takes precedence whenever it defines the argument.
    if hasattr(config, "model_args") and arg_name in config.model_args:
        return config.model_args[arg_name] == value

    if not hasattr(config, arg_name):
        return False
    return config[arg_name] == value
| |
|
| |
|
def get_from_config_or_model_args(config, arg_name):
    """Return an argument from `config.model_args` when it is defined there, else from `config`."""
    owns_model_args = hasattr(config, "model_args")
    if owns_model_args and arg_name in config.model_args:
        return config.model_args[arg_name]
    # No fallback value here: the lookup on `config` itself decides what happens
    # when the argument is missing.
    return config[arg_name]
| |
|
| |
|
def get_from_config_or_model_args_with_default(config, arg_name, def_val):
    """Return an argument from `config.model_args` or `config`, or `def_val` when neither has it.

    `config.model_args` is consulted first; `config` itself is only used when
    `model_args` is absent or does not contain the argument.
    """
    if hasattr(config, "model_args") and arg_name in config.model_args:
        return config.model_args[arg_name]
    if not hasattr(config, arg_name):
        return def_val
    return config[arg_name]
| |
|