|
|
|
|
|
import datetime |
|
|
import importlib |
|
|
import logging |
|
|
import os |
|
|
import re |
|
|
from pathlib import Path |
|
|
from typing import Any, Callable, Dict, Optional, TextIO, TypeVar, Union |
|
|
|
|
|
import torch |
|
|
from packaging.version import Version |
|
|
from typing_extensions import TypeIs |
|
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Generic type variable used in the annotations of `exists` and `default`.
_T = TypeVar("_T")
|
|
|
|
|
|
|
|
def exists(val: Union[_T, None]) -> TypeIs[_T]:
    """Tell whether ``val`` holds a value, i.e. is anything other than ``None``.

    The ``TypeIs`` return annotation lets type checkers narrow ``val`` to
    ``_T`` in branches guarded by this check.
    """
    is_missing = val is None
    return not is_missing
|
|
|
|
|
|
|
|
def default(val: Union[_T, None], d: Union[_T, Callable[[], _T]]) -> _T:
    """Return ``val`` when it is set, otherwise fall back to ``d``.

    Args:
        val: Candidate value; used as-is unless it is ``None``.
        d: Fallback value, or a zero-argument callable producing it. A callable
            is only invoked when the fallback is actually needed.

    Returns:
        ``val`` if it is not ``None``, else ``d()`` for a callable ``d`` and
        ``d`` itself otherwise.
    """
    if not exists(val):
        # Any callable fallback is treated as a factory and called.
        return d() if callable(d) else d
    return val
|
|
|
|
|
|
|
|
def to_camel(text):
    """Convert a snake_case name to CamelCase with project-specific acronyms.

    The string is first passed through ``str.capitalize`` (first character
    upper-cased, the rest lower-cased), then every ``_x`` that is not at the
    very start is collapsed to ``X``. Finally ``Tts`` becomes ``TTS`` and every
    remaining lower-case ``vc`` becomes ``VC``.
    """
    camel = re.sub(
        r"(?!^)_([a-zA-Z])",
        lambda match: match.group(1).upper(),
        text.capitalize(),
    )
    # Acronym fix-ups are applied in this fixed order.
    for plain, acronym in (("Tts", "TTS"), ("vc", "VC")):
        camel = camel.replace(plain, acronym)
    return camel
|
|
|
|
|
|
|
|
def find_module(module_path: str, module_name: str) -> object:
    """Import ``<module_path>.<module_name>`` and return its CamelCase-named attribute.

    Args:
        module_path: Dotted path of the package containing the module.
        module_name: Name of the module; it is lower-cased before use.

    Returns:
        The attribute of the imported module whose name is
        ``to_camel(module_name.lower())``.
    """
    lowered = module_name.lower()
    qualified_name = ".".join((module_path, lowered))
    imported = importlib.import_module(qualified_name)
    return getattr(imported, to_camel(lowered))
|
|
|
|
|
|
|
|
def import_class(module_path: str) -> object:
    """Import a class (or any attribute) from its fully qualified dotted path.

    Args:
        module_path (str): Dotted path whose last component is the attribute
            name and whose preceding components form the module to import,
            e.g. ``"package.module.ClassName"``.

    Returns:
        object: The attribute looked up on the imported module.

    Raises:
        ValueError: If ``module_path`` has no module part (no dot, or a
            leading dot only), so there is nothing to import.
    """
    module_name, _, class_name = module_path.rpartition(".")
    if not module_name:
        # importlib would raise an opaque "Empty module name" ValueError here;
        # keep the exception type but say which path was wrong.
        raise ValueError(
            f"Expected a dotted path like 'package.module.ClassName', got {module_path!r}"
        )
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
|
|
|
|
|
|
|
|
def get_import_path(obj: object) -> str:
    """Get the dotted import path of an object's class.

    Args:
        obj (object): An instance; the path is built from ``type(obj)``.

    Returns:
        str: ``"<module>.<class name>"`` of the object's type.
    """
    cls = type(obj)
    return cls.__module__ + "." + cls.__name__
|
|
|
|
|
|
|
|
def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
    """Fill in missing auxiliary model inputs with their default values.

    Args:
        def_args (Dict): Argument names mapped to the default used when the
            name is absent from `kwargs` or set to ``None`` there.
        kwargs (Dict): A `dict` or `kwargs` holding the auxiliary inputs to the
            model. It is not modified.

    Returns:
        Dict: A copy of `kwargs` with the defaults applied.
    """
    merged = kwargs.copy()
    # `.get(...) is None` covers both "key absent" and "key explicitly None".
    fallbacks = {
        key: fallback
        for key, fallback in def_args.items()
        if merged.get(key) is None
    }
    merged.update(fallbacks)
    return merged
|
|
|
|
|
|
|
|
def get_timestamp() -> str:
    """Return the current local time formatted as ``yymmdd-HHMMSS``."""
    stamp_format = "%y%m%d-%H%M%S"
    now = datetime.datetime.now()
    return now.strftime(stamp_format)
|
|
|
|
|
|
|
|
class ConsoleFormatter(logging.Formatter):
    """Formatter that omits the level name for logging.INFO records.

    Records at any other level are rendered as ``LEVELNAME: message``.

    Source: https://stackoverflow.com/a/62488520
    """

    def format(self, record):
        # Swap the format string on the underlying style object per record.
        is_info = record.levelno == logging.INFO
        self._style._fmt = "%(message)s" if is_info else "%(levelname)s: %(message)s"
        return super().format(record)
|
|
|
|
|
|
|
|
def setup_logger(
    logger_name: str,
    level: int = logging.INFO,
    *,
    formatter: Optional[logging.Formatter] = None,
    stream: Optional[TextIO] = None,
    log_dir: Optional[Union[str, os.PathLike[Any]]] = None,
    log_name: str = "log",
) -> None:
    """Configure the level and handlers of a named logger.

    Args:
        logger_name: Name of the logger to configure
        level: Logging level set on the logger
        formatter: Formatter applied to every handler added here; a
            timestamped default is used when None
        stream: If given, a StreamHandler for this stream is added,
            e.g. sys.stderr or sys.stdout
        log_dir: Folder for the log file, created if missing (no file
            handler is added if None)
        log_name: Prefix of the log file name; a timestamp and ``.log`` are
            appended
    """
    target = logging.getLogger(logger_name)
    if formatter is None:
        formatter = logging.Formatter(
            "%(asctime)s.%(msecs)03d - %(levelname)-8s - %(name)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S"
        )
    target.setLevel(level)

    def _attach(handler: logging.Handler) -> None:
        # Every handler added here shares the same formatter.
        handler.setFormatter(formatter)
        target.addHandler(handler)

    if log_dir is not None:
        folder = Path(log_dir)
        folder.mkdir(parents=True, exist_ok=True)
        # mode="w": the file is opened for writing from scratch.
        _attach(logging.FileHandler(folder / f"{log_name}_{get_timestamp()}.log", mode="w"))

    if stream is not None:
        _attach(logging.StreamHandler(stream))
|
|
|
|
|
|
|
|
def is_pytorch_at_least_2_4() -> bool:
    """Report whether the installed Pytorch version is 2.4 or newer."""
    minimum = Version("2.4")
    installed = Version(torch.__version__)
    return installed >= minimum
|
|
|
|
|
|
|
|
def optional_to_str(x: Optional[Any]) -> str:
    """Return ``str(x)``, or an empty string when ``x`` is None."""
    if x is None:
        return ""
    return str(x)
|
|
|