Spaces:
Running
Running
| import datetime | |
| import importlib | |
| import logging | |
| import os | |
| import re | |
| import unicodedata | |
| import warnings | |
| from collections.abc import Callable | |
| from pathlib import Path | |
| from typing import Any, TextIO, TypeVar | |
| import torch | |
| from packaging.version import Version | |
| from typing_extensions import TypeIs | |
| logger = logging.getLogger(__name__) | |
| _T = TypeVar("_T") | |
| def exists(val: _T | None) -> TypeIs[_T]: | |
| return val is not None | |
| def default(val: _T | None, d: _T | Callable[[], _T]) -> _T: | |
| if exists(val): | |
| return val | |
| return d() if callable(d) else d | |
| def to_camel(text): | |
| text = text.capitalize() | |
| text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) | |
| text = text.replace("Tts", "TTS") | |
| text = text.replace("vc", "VC") | |
| text = text.replace("Knn", "KNN") | |
| return text | |
| def slugify(text: str) -> str: | |
| """Convert a string (e.g. speaker IDs) into a safe filename base.""" | |
| # Normalize to ASCII (e.g., Zoë -> Zoe) | |
| normalized = unicodedata.normalize("NFKD", text) | |
| ascii_str = normalized.encode("ascii", "ignore").decode("ascii") | |
| # Replace unsafe characters with underscores | |
| safe = re.sub(r"[^\w\-]", "_", ascii_str) | |
| # Collapse repeated underscores | |
| return re.sub(r"_+", "_", safe).strip("_") | |
| def find_module(module_path: str, module_name: str) -> type[Any]: | |
| module_name = module_name.lower() | |
| module = importlib.import_module(module_path + "." + module_name) | |
| class_name = to_camel(module_name) | |
| return getattr(module, class_name) | |
| def import_class(module_path: str) -> type[Any]: | |
| """Import a class from a module path. | |
| Args: | |
| module_path (str): The module path of the class. | |
| Returns: | |
| object: The imported class. | |
| """ | |
| class_name = module_path.split(".")[-1] | |
| module_path = ".".join(module_path.split(".")[:-1]) | |
| module = importlib.import_module(module_path) | |
| return getattr(module, class_name) | |
| def get_import_path(obj: object) -> str: | |
| """Get the import path of a class. | |
| Args: | |
| obj (object): The class object. | |
| Returns: | |
| str: The import path of the class. | |
| """ | |
| return ".".join([type(obj).__module__, type(obj).__name__]) | |
| def format_aux_input(def_args: dict, kwargs: dict) -> dict: | |
| """Format kwargs to hande auxilary inputs to models. | |
| Args: | |
| def_args (Dict): A dictionary of argument names and their default values if not defined in `kwargs`. | |
| kwargs (Dict): A `dict` or `kwargs` that includes auxilary inputs to the model. | |
| Returns: | |
| Dict: arguments with formatted auxilary inputs. | |
| """ | |
| kwargs = kwargs.copy() | |
| for name, arg in def_args.items(): | |
| if name not in kwargs or kwargs[name] is None: | |
| kwargs[name] = arg | |
| return kwargs | |
| def get_timestamp() -> str: | |
| return datetime.datetime.now().strftime("%y%m%d-%H%M%S") | |
| class ConsoleFormatter(logging.Formatter): | |
| """Custom formatter that prints logging.INFO messages without the level name. | |
| Source: https://stackoverflow.com/a/62488520 | |
| """ | |
| def format(self, record): | |
| if record.levelno == logging.INFO: | |
| self._style._fmt = "%(message)s" | |
| else: | |
| self._style._fmt = "%(levelname)s: %(message)s" | |
| return super().format(record) | |
| def setup_logger( | |
| logger_name: str, | |
| level: int = logging.INFO, | |
| *, | |
| formatter: logging.Formatter | None = None, | |
| stream: TextIO | None = None, | |
| log_dir: str | os.PathLike[Any] | None = None, | |
| log_name: str = "log", | |
| ) -> None: | |
| """Set up a logger. | |
| Args: | |
| logger_name: Name of the logger to set up | |
| level: Logging level | |
| formatter: Formatter for the logger | |
| stream: Add a StreamHandler for the given stream, e.g. sys.stderr or sys.stdout | |
| log_dir: Folder to write the log file (no file created if None) | |
| log_name: Prefix of the log file name | |
| """ | |
| lg = logging.getLogger(logger_name) | |
| if formatter is None: | |
| formatter = logging.Formatter( | |
| "%(asctime)s.%(msecs)03d - %(levelname)-8s - %(name)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S" | |
| ) | |
| lg.setLevel(level) | |
| if log_dir is not None: | |
| Path(log_dir).mkdir(exist_ok=True, parents=True) | |
| log_file = Path(log_dir) / f"{log_name}_{get_timestamp()}.log" | |
| fh = logging.FileHandler(log_file, mode="w") | |
| fh.setFormatter(formatter) | |
| lg.addHandler(fh) | |
| if stream is not None: | |
| sh = logging.StreamHandler(stream) | |
| sh.setFormatter(formatter) | |
| lg.addHandler(sh) | |
| def is_pytorch_at_least_2_4() -> bool: | |
| """Check if the installed Pytorch version is 2.4 or higher.""" | |
| return Version(torch.__version__) >= Version("2.4") | |
| def optional_to_str(x: Any | None) -> str: | |
| """Convert input to string, using empty string if input is None.""" | |
| return "" if x is None else str(x) | |
| def warn_synthesize_config_deprecated() -> None: | |
| warnings.warn( | |
| "The `config` argument of synthesize() is deprecated and will be removed soon. You can safely leave it out.", | |
| DeprecationWarning, | |
| ) | |
| def warn_synthesize_speaker_id_deprecated() -> None: | |
| warnings.warn( | |
| "The `speaker_id` argument of synthesize() is deprecated and will be removed soon. Use `speaker` instead.", | |
| DeprecationWarning, | |
| ) | |