Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import warnings | |
| from importlib.util import find_spec | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, Tuple | |
| import gdown | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import wget | |
| from omegaconf import DictConfig | |
| from matcha.utils import pylogger, rich_utils | |
# Module-level logger, created once via `pylogger.get_pylogger` and used by the helpers below.
| log = pylogger.get_pylogger(__name__) | |
def extras(cfg: DictConfig) -> None:
    """Run the optional pre-task utilities selected under ``cfg.extras``.

    Each utility is switched on by its own flag:

    - ``ignore_warnings``: silence python warnings
    - ``enforce_tags``: make sure tags are set (via ``rich_utils.enforce_tags``)
    - ``print_config``: print the config tree with Rich

    :param cfg: A DictConfig object containing the config tree.
    """
    # Nothing to do when the `extras` section is missing or empty.
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    options = cfg.extras

    # Silence python warnings.
    if options.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # Require tags to be present in the config.
    if options.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # Pretty-print the resolved config tree.
    if options.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:

    - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
    - save the exception to a `.log` file
    - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
    - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    :param task_func: The task function to be wrapped. It is called as ``task_func(cfg=cfg)`` and
        must return a ``(metric_dict, object_dict)`` pair.
    :return: The wrapped task function. It carries the name and docstring of ``task_func``.
    """
    # Local import so this block does not depend on the file's import section.
    from functools import wraps

    @wraps(task_func)  # keep the task's __name__/__doc__ visible to logs and introspection
    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception:
            # save exception (with traceback) to the `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise  # bare raise re-raises the active exception with its original traceback

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
    """Look up a logged metric by name and return it as a plain python number.

    :param metric_dict: A dict containing metric values; each value must expose ``.item()``.
    :param metric_name: The name of the metric to retrieve. A falsy name skips the lookup.
    :return: The value of the metric, or ``None`` when ``metric_name`` is falsy.
    :raises ValueError: If ``metric_name`` is not a key of ``metric_dict``.
    """
    # No metric requested: report it and bail out.
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name in metric_dict:
        # `.item()` unwraps the stored value into a python scalar.
        metric_value = metric_dict[metric_name].item()
        log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
        return metric_value

    raise ValueError(
        f"Metric value not found! <metric_name={metric_name}>\n"
        "Make sure metric name logged in LightningModule is correct!\n"
        "Make sure `optimized_metric` name in `hparams_search` config is correct!"
    )
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after the elements of ``lst``.

    The output has ``2 * len(lst) + 1`` entries: ``item`` sits at every even index
    and the elements of ``lst`` fill the odd indices (this is how the blank symbol is added).
    """
    padded = [item for _ in range(2 * len(lst) + 1)]
    # Overwrite the odd positions with the original elements.
    padded[1::2] = lst
    return padded
def save_figure_to_numpy(fig):
    """Convert an already-drawn figure into an ``(height, width, 3)`` uint8 RGB array.

    The canvas must have been rendered (``fig.canvas.draw()``) before calling this.

    ``canvas.buffer_rgba()`` is preferred because ``canvas.tostring_rgb()`` is deprecated
    and has been removed from recent Matplotlib releases. The binary mode of
    ``np.fromstring`` is likewise deprecated, so the fallback uses ``np.frombuffer``.

    :param fig: A figure whose ``canvas`` exposes ``buffer_rgba()`` or, on older
        Matplotlib versions, ``tostring_rgb()`` and ``get_width_height()``.
    :return: A writable, C-contiguous ``np.uint8`` array of shape ``(height, width, 3)``.
    """
    canvas = fig.canvas

    if hasattr(canvas, "buffer_rgba"):
        # (height, width, 4) view onto the renderer buffer; drop alpha and copy so the
        # result stays valid (and writable) after the figure is closed.
        rgba = np.asarray(canvas.buffer_rgba())
        return np.array(rgba[..., :3], dtype=np.uint8, order="C")

    # Legacy path for canvases that only provide the packed RGB byte string.
    flat = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    # get_width_height() is (width, height); images are indexed (height, width, channel).
    return flat.reshape(canvas.get_width_height()[::-1] + (3,)).copy()
def plot_tensor(tensor):
    """Render a 2-D array as a heatmap with a colorbar and return the image as a numpy array.

    :param tensor: Array-like accepted by ``Axes.imshow``.
    :return: The rendered figure as returned by ``save_figure_to_numpy``.
    """
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
        # Use the figure's own methods rather than pyplot's implicit "current figure".
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        # The canvas must be drawn before its pixel buffer can be read.
        fig.canvas.draw()
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure even when rendering fails, so figures do not
        # accumulate in pyplot's registry.
        plt.close(fig)
def save_plot(tensor, savepath):
    """Render a 2-D array as a heatmap with a colorbar and write it to ``savepath``.

    :param tensor: Array-like accepted by ``Axes.imshow``.
    :param savepath: Destination passed to ``Figure.savefig``.
    """
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
        # Use the figure's own methods rather than pyplot's implicit "current figure".
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        fig.canvas.draw()
        fig.savefig(savepath)
    finally:
        # Close this specific figure even when rendering or saving fails, so figures
        # do not accumulate in pyplot's registry.
        plt.close(fig)
def to_numpy(tensor):
    """Convert ``tensor`` to a numpy array.

    - ``np.ndarray`` is returned unchanged (same object).
    - ``torch.Tensor`` is detached, moved to CPU and converted.
    - ``list`` is wrapped with ``np.array``.

    :raises TypeError: For any other input type.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, torch.Tensor):
        # Detach from the autograd graph and move to host memory before converting.
        host_tensor = tensor.detach().cpu()
        return host_tensor.numpy()
    if isinstance(tensor, list):
        return np.array(tensor)
    raise TypeError("Unsupported type for conversion to numpy array")
def get_user_data_dir(appname="matcha_tts"):
    """Return the per-user data directory for ``appname``, creating it if needed.

    The base directory is chosen in this order:

    1. the ``MATCHA_HOME`` environment variable, when set;
    2. on Windows, the "Local AppData" shell folder read from the registry;
    3. on macOS, ``~/Library/Application Support/``;
    4. otherwise ``~/.local/share``.

    Args:
        appname (str): Name of application

    Returns:
        Path: path to user data directory
    """
    matcha_home = os.environ.get("MATCHA_HOME")
    if matcha_home is not None:
        base_dir = Path(matcha_home).expanduser().resolve(strict=False)
    elif sys.platform == "win32":
        import winreg  # pylint: disable=import-outside-toplevel

        # Registry key handles are context managers; `with` closes the handle
        # instead of leaving it open until garbage collection.
        with winreg.OpenKey(
            winreg.HKEY_CURRENT_USER,
            r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
        ) as key:
            dir_, _ = winreg.QueryValueEx(key, "Local AppData")
        base_dir = Path(dir_).resolve(strict=False)
    elif sys.platform == "darwin":
        base_dir = Path("~/Library/Application Support/").expanduser()
    else:
        base_dir = Path.home().joinpath(".local/share")

    final_path = base_dir.joinpath(appname)
    final_path.mkdir(parents=True, exist_ok=True)
    return final_path
def assert_model_downloaded(checkpoint_path, url, use_wget=True):
    """Make sure a model checkpoint exists locally, downloading it when it does not.

    If ``checkpoint_path`` already exists nothing is downloaded. Otherwise the parent
    directory is created (so the download cannot fail on a missing folder) and the
    file is fetched from ``url``.

    Args:
        checkpoint_path (str | Path): Where the checkpoint is expected / will be stored.
        url (str): Download location of the checkpoint.
        use_wget (bool): Download with ``wget`` when True, otherwise with ``gdown``.
    """
    target = Path(checkpoint_path)
    if target.exists():
        log.debug(f"[+] Model already present at {checkpoint_path}!")
        print(f"[+] Model already present at {checkpoint_path}!")
        return

    log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
    print(f"[-] Model not found at {checkpoint_path}! Will download it")

    # Both downloaders write into the destination folder and fail if it is missing.
    target.parent.mkdir(parents=True, exist_ok=True)

    # The downloaders expect a plain string path.
    checkpoint_path = str(checkpoint_path)
    if not use_wget:
        gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
    else:
        wget.download(url=url, out=checkpoint_path)