| | from pathlib import Path |
| | from typing import Sequence |
| |
|
| | import rich |
| | import rich.syntax |
| | import rich.tree |
| | from hydra.core.hydra_config import HydraConfig |
| | from lightning.pytorch.utilities import rank_zero_only |
| | from omegaconf import DictConfig, OmegaConf, open_dict |
| | from rich.prompt import Prompt |
| |
|
| | from fish_speech.utils import logger as log |
| |
|
| |
|
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Print the content of a DictConfig as a Rich tree.

    Top-level fields named in ``print_order`` are printed first, in that
    order; every remaining top-level field of ``cfg`` follows in the config's
    own iteration order. ``DictConfig`` sub-trees are rendered as YAML, any
    other value via ``str()``.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.
        print_order (Sequence[str], optional): Top-level fields to print
            first, in this order. Fields not present in ``cfg`` are skipped
            with a warning.
        resolve (bool, optional): Whether to resolve interpolations when
            rendering ``DictConfig`` sub-trees to YAML.
        save_to_file (bool, optional): If True, also write the tree to
            ``<cfg.paths.output_dir>/config_tree.log``.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # Requested fields first, in the requested order.
    for field in print_order:
        if field in cfg:
            queue.append(field)
        else:
            log.warning(
                f"Field '{field}' not found in config. "
                f"Skipping '{field}' config printing..."
            )

    # Then every other top-level field not already queued.
    for field in cfg:
        if field not in queue:
            queue.append(field)

    # One branch per top-level field, holding its YAML-highlighted content.
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    if save_to_file:
        # Explicit UTF-8: config values may contain non-ASCII text, and the
        # platform's default encoding is not guaranteed to represent it.
        log_path = Path(cfg.paths.output_dir, "config_tree.log")
        with open(log_path, "w", encoding="utf-8") as file:
            rich.print(tree, file=file)
| |
|
| |
|
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompt the user for tags on the command line if the config has none.

    When ``cfg.tags`` is missing or empty, the user is asked for a
    comma-separated list. Each entry is stripped of surrounding whitespace,
    and blank entries are dropped. If nothing usable remains, the prompt's
    default tag ``"dev"`` is used. The result is stored in ``cfg.tags`` in
    place (under ``open_dict`` so the key may be added).

    Args:
        cfg (DictConfig): Configuration composed by Hydra.
        save_to_file (bool, optional): If True, write the final tags to
            ``<cfg.paths.output_dir>/tags.log``.

    Raises:
        ValueError: If no tags are configured and the Hydra job config has an
            ``id`` field, which this function treats as a multirun.
    """
    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        raw_tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Filter on the *stripped* entry so input like "a, ,b" or " " does
        # not produce empty-string tags; fall back to the prompt default if
        # nothing usable was entered.
        tags = [t.strip() for t in raw_tags.split(",") if t.strip()] or ["dev"]

        with open_dict(cfg):
            cfg.tags = tags

    log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        # Explicit UTF-8 so non-ASCII tags are written portably.
        tags_path = Path(cfg.paths.output_dir, "tags.log")
        with open(tags_path, "w", encoding="utf-8") as file:
            rich.print(cfg.tags, file=file)
| |
|