| import os |
| import sys |
| import warnings |
| from importlib.util import find_spec |
| from math import ceil |
| from pathlib import Path |
| from typing import Any, Callable, Dict, Tuple |
|
|
| import gdown |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import wget |
| from omegaconf import DictConfig |
|
|
| from matcha.utils import pylogger, rich_utils |
|
|
| log = pylogger.get_pylogger(__name__) |
|
|
|
|
def extras(cfg: DictConfig) -> None:
    """Run the optional pre-task utilities enabled under ``cfg.extras``.

    Each utility is switched on by its own flag inside ``cfg.extras``:
        - ``ignore_warnings``: install an "ignore" filter for Python warnings
        - ``enforce_tags``: delegate to ``rich_utils.enforce_tags``
        - ``print_config``: print the config tree via ``rich_utils.print_config_tree``

    :param cfg: A DictConfig object containing the config tree.
    """
    # Nothing to do when the whole `extras` section is missing or empty.
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    options = cfg.extras

    # Silence python warnings.
    if options.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # Tag enforcement is handled entirely by rich_utils.
    if options.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # Pretty-print the resolved config tree.
    if options.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
|
|
|
|
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    :param task_func: The task function to be wrapped. It is called as ``task_func(cfg=cfg)``
        and must return a ``(metric_dict, object_dict)`` pair.

    :return: The wrapped task function.
    """
    # Imported locally so this block does not depend on a new file-level import.
    from functools import wraps

    # `wraps` keeps the task's __name__/__doc__ visible on the returned wrapper.
    @wraps(task_func)
    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # Execute the task.
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # Things to do if an exception occurs.
        except Exception:
            # Record the full traceback through the module logger.
            log.exception("")

            # Re-raise the active exception unchanged so the caller still sees the failure.
            raise

        # Things to always do after either success or exception.
        finally:
            # Display the output dir path in the terminal.
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # Close an active wandb run, but only if wandb is installed at all.
            if find_spec("wandb"):
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
|
|
|
|
def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
    """Look up a logged metric by name and return it as a plain Python number.

    :param metric_dict: A dict containing metric values; the stored value must
        expose an ``.item()`` method.
    :param metric_name: The name of the metric to retrieve. A falsy name
        (``None`` or ``""``) skips the lookup.
    :return: The value of the metric, or ``None`` when no metric name was given.
    :raises ValueError: If ``metric_name`` is not a key of ``metric_dict``.
    """
    # No metric requested: nothing to retrieve.
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name in metric_dict:
        value = metric_dict[metric_name].item()
        log.info(f"Retrieved metric value! <{metric_name}={value}>")
        return value

    raise ValueError(
        f"Metric value not found! <metric_name={metric_name}>\n"
        "Make sure metric name logged in LightningModule is correct!\n"
        "Make sure `optimized_metric` name in `hparams_search` config is correct!"
    )
|
|
|
|
def intersperse(lst, item):
    """Return a new list with ``item`` placed before, between and after the elements of ``lst``.

    :param lst: A sized iterable of elements.
    :param item: The filler value to insert.
    :return: A list of length ``2 * len(lst) + 1``; e.g. ``[a, b]`` -> ``[item, a, item, b, item]``.
    """
    count = len(lst)
    # Start from all fillers, then drop the real elements into the odd slots.
    padded = [item] * (2 * count + 1)
    for position, element in enumerate(lst):
        padded[2 * position + 1] = element
    return padded
|
|
|
|
def save_figure_to_numpy(fig):
    """Read the pixels of a figure's canvas into an RGB ``uint8`` array.

    :param fig: A figure-like object whose ``canvas`` provides ``get_width_height()``
        and either ``tostring_rgb()`` or ``buffer_rgba()``.
        NOTE(review): presumably the canvas must already be drawn by the caller - confirm.
    :return: A writable ``numpy.ndarray`` of shape ``(height, width, 3)`` and dtype ``uint8``.
    """
    canvas = fig.canvas
    if hasattr(canvas, "tostring_rgb"):
        width, height = canvas.get_width_height()
        # np.frombuffer replaces the deprecated binary mode of np.fromstring(..., sep="").
        flat = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
        image = flat.reshape((height, width, 3))
    else:
        # Canvases without tostring_rgb: take the RGB channels of the RGBA buffer instead.
        image = np.asarray(canvas.buffer_rgba())[..., :3]
    # frombuffer/asarray give views (possibly read-only); copy so the caller
    # gets an independent, writable array as np.fromstring used to return.
    return image.copy()
|
|
|
|
def plot_tensor(tensor):
    """Draw ``tensor`` as an image with a colorbar and return the rendering as an array.

    The plot uses the "default" Matplotlib style on a 12x3 inch figure, with the
    origin in the lower-left corner and no interpolation.

    :param tensor: Image data passed straight to ``ax.imshow``.
    :return: Whatever ``save_figure_to_numpy`` returns for the drawn figure.
    """
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
        plt.colorbar(im, ax=ax)
        plt.tight_layout()
        # Render the canvas before its pixel buffer is read.
        fig.canvas.draw()
        data = save_figure_to_numpy(fig)
    finally:
        # Close this specific figure even when plotting fails, so figures do not
        # pile up in pyplot's registry.
        plt.close(fig)
    return data
|
|
|
|
def save_plot(tensor, savepath):
    """Draw ``tensor`` as an image with a colorbar and write the figure to ``savepath``.

    The plot uses the "default" Matplotlib style on a 12x3 inch figure, with the
    origin in the lower-left corner and no interpolation.

    :param tensor: Image data passed straight to ``ax.imshow``.
    :param savepath: Destination handed to ``plt.savefig``.
    """
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
        plt.colorbar(im, ax=ax)
        plt.tight_layout()
        fig.canvas.draw()
        plt.savefig(savepath)
    finally:
        # Close this specific figure even when plotting or saving fails, so
        # figures do not pile up in pyplot's registry.
        plt.close(fig)
|
|
|
|
def to_numpy(tensor):
    """Convert an ndarray, torch tensor or list into a NumPy array.

    :param tensor: A ``numpy.ndarray`` (returned as-is), a ``torch.Tensor``
        (detached and moved to CPU first) or a ``list``.
    :return: The corresponding ``numpy.ndarray``.
    :raises TypeError: For any other input type.
    """
    # Already a NumPy array: hand back the very same object.
    if isinstance(tensor, np.ndarray):
        return tensor

    if isinstance(tensor, torch.Tensor):
        detached = tensor.detach()
        return detached.cpu().numpy()

    if isinstance(tensor, list):
        return np.array(tensor)

    raise TypeError("Unsupported type for conversion to numpy array")
|
|
|
|
def get_user_data_dir(appname="matcha_tts"):
    """Return (and create if needed) the per-user data directory for ``appname``.

    The base directory is the ``MATCHA_HOME`` environment variable when set;
    otherwise it is the "Local AppData" shell folder on Windows,
    ``~/Library/Application Support/`` on macOS and ``~/.local/share`` elsewhere.

    Args:
        appname (str): Name of application

    Returns:
        Path: path to user data directory
    """
    override = os.environ.get("MATCHA_HOME")
    platform = sys.platform

    if override is not None:
        # An explicit override wins on every platform.
        base_dir = Path(override).expanduser().resolve(strict=False)
    elif platform == "win32":
        import winreg

        shell_folders = winreg.OpenKey(
            winreg.HKEY_CURRENT_USER,
            r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
        )
        local_appdata, _ = winreg.QueryValueEx(shell_folders, "Local AppData")
        base_dir = Path(local_appdata).resolve(strict=False)
    elif platform == "darwin":
        base_dir = Path("~/Library/Application Support/").expanduser()
    else:
        base_dir = Path.home() / ".local/share"

    app_dir = base_dir / appname
    app_dir.mkdir(parents=True, exist_ok=True)
    return app_dir
|
|
|
|
def assert_model_downloaded(checkpoint_path, url, use_wget=True):
    """Download the checkpoint from ``url`` unless ``checkpoint_path`` already exists.

    The outcome is reported both through the module logger and on stdout.

    :param checkpoint_path: Where the checkpoint is expected / should be stored.
    :param url: Download location of the checkpoint.
    :param use_wget: Download with ``wget`` when true, with ``gdown`` otherwise.
    """
    if not Path(checkpoint_path).exists():
        missing_msg = f"[-] Model not found at {checkpoint_path}! Will download it"
        log.info(missing_msg)
        print(missing_msg)

        # Both downloaders are handed the destination as a plain string.
        destination = str(checkpoint_path)
        if use_wget:
            wget.download(url=url, out=destination)
        else:
            gdown.download(url=url, output=destination, quiet=False, fuzzy=True)
        return

    present_msg = f"[+] Model already present at {checkpoint_path}!"
    log.debug(present_msg)
    print(present_msg)
|
|
|
|
def get_phoneme_durations(durations, phones):
    """Fold a duration sequence into per-phone start/end times.

    Entries at odd indices of ``durations`` are matched one-to-one with ``phones``.
    Each even-indexed entry is shared by its two neighbours: the left one takes
    ``ceil(value / 2)`` and the right one takes the rest. The first phone also
    takes all of ``durations[0]`` and the last phone all of the final entry.
    NOTE(review): presumably the even-indexed entries are the blank tokens
    inserted by ``intersperse`` - confirm against the caller.

    :param durations: Sequence of ``2 * len(phones) + 1`` numeric durations.
    :param phones: Sequence of phone labels, used as dictionary keys.
    :return: A list with one ``{phone: {"starttime", "endtime", "duration"}}``
        dict per phone, with cumulative start and end times.
    """
    total = len(durations)
    carry = durations[0]
    per_phone = []

    for idx in range(1, total, 2):
        following = durations[idx + 1]
        # The last phone absorbs the whole trailing entry; every other phone
        # only takes the (rounded-up) first half of the entry that follows it.
        share = following if idx == total - 2 else ceil(following / 2)

        per_phone.append(carry + durations[idx] + share)
        # Whatever is left of the shared entry goes to the next phone.
        carry = following - share

    assert len(phones) == len(per_phone)
    assert len(per_phone) == (total - 1) // 2

    # The cumulative sum turns per-phone lengths into end times.
    end_times = torch.cumsum(torch.tensor(per_phone), 0, dtype=torch.long)

    duration_json = []
    begin = 0
    for position, end_tensor in enumerate(end_times):
        end = end_tensor.item()
        timing = {
            "starttime": begin,
            "endtime": end,
            "duration": end - begin,
        }
        duration_json.append({phones[position]: timing})
        begin = end

    # Sanity check: the last end time must account for every input duration.
    last_end = list(duration_json[-1].values())[0]["endtime"]
    expected_end = sum(durations)
    assert last_end == expected_end, f"{last_end, expected_end}"
    return duration_json
|
|