| | import datetime |
| | import importlib |
| | import logging |
| | import os |
| | import re |
| | import unicodedata |
| | import warnings |
| | from collections.abc import Callable |
| | from pathlib import Path |
| | from typing import Any, TextIO, TypeVar |
| |
|
| | import torch |
| | from packaging.version import Version |
| | from typing_extensions import TypeIs |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | _T = TypeVar("_T") |
| |
|
| |
|
| | def exists(val: _T | None) -> TypeIs[_T]: |
| | return val is not None |
| |
|
| |
|
| | def default(val: _T | None, d: _T | Callable[[], _T]) -> _T: |
| | if exists(val): |
| | return val |
| | return d() if callable(d) else d |
| |
|
| |
|
| | def to_camel(text): |
| | text = text.capitalize() |
| | text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) |
| | text = text.replace("Tts", "TTS") |
| | text = text.replace("vc", "VC") |
| | text = text.replace("Knn", "KNN") |
| | return text |
| |
|
| |
|
| | def slugify(text: str) -> str: |
| | """Convert a string (e.g. speaker IDs) into a safe filename base.""" |
| | |
| | normalized = unicodedata.normalize("NFKD", text) |
| | ascii_str = normalized.encode("ascii", "ignore").decode("ascii") |
| |
|
| | |
| | safe = re.sub(r"[^\w\-]", "_", ascii_str) |
| |
|
| | |
| | return re.sub(r"_+", "_", safe).strip("_") |
| |
|
| |
|
| | def find_module(module_path: str, module_name: str) -> type[Any]: |
| | module_name = module_name.lower() |
| | module = importlib.import_module(module_path + "." + module_name) |
| | class_name = to_camel(module_name) |
| | return getattr(module, class_name) |
| |
|
| |
|
| | def import_class(module_path: str) -> type[Any]: |
| | """Import a class from a module path. |
| | |
| | Args: |
| | module_path (str): The module path of the class. |
| | |
| | Returns: |
| | object: The imported class. |
| | """ |
| | class_name = module_path.split(".")[-1] |
| | module_path = ".".join(module_path.split(".")[:-1]) |
| | module = importlib.import_module(module_path) |
| | return getattr(module, class_name) |
| |
|
| |
|
| | def get_import_path(obj: object) -> str: |
| | """Get the import path of a class. |
| | |
| | Args: |
| | obj (object): The class object. |
| | |
| | Returns: |
| | str: The import path of the class. |
| | """ |
| | return ".".join([type(obj).__module__, type(obj).__name__]) |
| |
|
| |
|
| | def format_aux_input(def_args: dict, kwargs: dict) -> dict: |
| | """Format kwargs to hande auxilary inputs to models. |
| | |
| | Args: |
| | def_args (Dict): A dictionary of argument names and their default values if not defined in `kwargs`. |
| | kwargs (Dict): A `dict` or `kwargs` that includes auxilary inputs to the model. |
| | |
| | Returns: |
| | Dict: arguments with formatted auxilary inputs. |
| | """ |
| | kwargs = kwargs.copy() |
| | for name, arg in def_args.items(): |
| | if name not in kwargs or kwargs[name] is None: |
| | kwargs[name] = arg |
| | return kwargs |
| |
|
| |
|
| | def get_timestamp() -> str: |
| | return datetime.datetime.now().strftime("%y%m%d-%H%M%S") |
| |
|
| |
|
| | class ConsoleFormatter(logging.Formatter): |
| | """Custom formatter that prints logging.INFO messages without the level name. |
| | |
| | Source: https://stackoverflow.com/a/62488520 |
| | """ |
| |
|
| | def format(self, record): |
| | if record.levelno == logging.INFO: |
| | self._style._fmt = "%(message)s" |
| | else: |
| | self._style._fmt = "%(levelname)s: %(message)s" |
| | return super().format(record) |
| |
|
| |
|
| | def setup_logger( |
| | logger_name: str, |
| | level: int = logging.INFO, |
| | *, |
| | formatter: logging.Formatter | None = None, |
| | stream: TextIO | None = None, |
| | log_dir: str | os.PathLike[Any] | None = None, |
| | log_name: str = "log", |
| | ) -> None: |
| | """Set up a logger. |
| | |
| | Args: |
| | logger_name: Name of the logger to set up |
| | level: Logging level |
| | formatter: Formatter for the logger |
| | stream: Add a StreamHandler for the given stream, e.g. sys.stderr or sys.stdout |
| | log_dir: Folder to write the log file (no file created if None) |
| | log_name: Prefix of the log file name |
| | """ |
| | lg = logging.getLogger(logger_name) |
| | if formatter is None: |
| | formatter = logging.Formatter( |
| | "%(asctime)s.%(msecs)03d - %(levelname)-8s - %(name)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S" |
| | ) |
| | lg.setLevel(level) |
| | if log_dir is not None: |
| | Path(log_dir).mkdir(exist_ok=True, parents=True) |
| | log_file = Path(log_dir) / f"{log_name}_{get_timestamp()}.log" |
| | fh = logging.FileHandler(log_file, mode="w") |
| | fh.setFormatter(formatter) |
| | lg.addHandler(fh) |
| | if stream is not None: |
| | sh = logging.StreamHandler(stream) |
| | sh.setFormatter(formatter) |
| | lg.addHandler(sh) |
| |
|
| |
|
| | def is_pytorch_at_least_2_4() -> bool: |
| | """Check if the installed Pytorch version is 2.4 or higher.""" |
| | return Version(torch.__version__) >= Version("2.4") |
| |
|
| |
|
| | def optional_to_str(x: Any | None) -> str: |
| | """Convert input to string, using empty string if input is None.""" |
| | return "" if x is None else str(x) |
| |
|
| |
|
| | def warn_synthesize_config_deprecated() -> None: |
| | warnings.warn( |
| | "The `config` argument of synthesize() is deprecated and will be removed soon. You can safely leave it out.", |
| | DeprecationWarning, |
| | ) |
| |
|
| |
|
| | def warn_synthesize_speaker_id_deprecated() -> None: |
| | warnings.warn( |
| | "The `speaker_id` argument of synthesize() is deprecated and will be removed soon. Use `speaker` instead.", |
| | DeprecationWarning, |
| | ) |
| |
|