| """ |
| PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation |
| |
| Official implementation of the paper: |
| "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" |
| by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis |
| Licensed under a modified MIT license |
| """ |
|
|
| from pathlib import Path |
| from typing import Sequence |
|
|
| import rich |
| import rich.syntax |
| import rich.tree |
| from hydra.core.hydra_config import HydraConfig |
| from omegaconf import DictConfig, OmegaConf, open_dict |
| from pytorch_lightning.utilities import rank_zero_only |
| from rich.prompt import Prompt |
|
|
| from . import pylogger |
|
|
# Module-level logger, created through the project's ``pylogger`` helper with this module's ``__name__``.
log = pylogger.get_pylogger(__name__)
|
|
|
|
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "datamodule",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Print the content of a DictConfig as a Rich tree.

    Top-level fields named in ``print_order`` are printed first, in that order.
    Every remaining top-level field of ``cfg`` follows in the config's own order.
    Fields named in ``print_order`` but absent from ``cfg`` are skipped with a
    warning.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.
        print_order (Sequence[str], optional): Determines in what order config
            components are printed.
        resolve (bool, optional): Whether to resolve interpolations
            (``${...}``) in config containers, both dict- and list-valued
            groups, before printing.
        save_to_file (bool, optional): Whether to also write the tree to
            ``<cfg.paths.output_dir>/config_tree.log``.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    # Fields explicitly requested in print_order come first, if present...
    queue = []
    for field in print_order:
        if field in cfg:
            queue.append(field)
        else:
            log.warning(
                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
            )

    # ...followed by any remaining top-level fields, in config order.
    for field in cfg:
        if field not in queue:
            queue.append(field)

    # One branch per top-level field, with its content rendered as YAML.
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        if OmegaConf.is_config(config_group):
            # Covers DictConfig *and* ListConfig. str() on a ListConfig does not
            # resolve interpolations, so ``resolve`` would be silently ignored
            # for list-valued groups.
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    if save_to_file:
        # Explicit UTF-8: config values may contain non-ASCII text that the
        # platform-default encoding cannot represent.
        log_path = Path(cfg.paths.output_dir, "config_tree.log")
        with open(log_path, "w", encoding="utf-8") as file:
            rich.print(tree, file=file)
|
|
|
|
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompt the user for tags on the command line if the config provides none.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. ``cfg.tags`` is set
            in place when it is missing or empty.
        save_to_file (bool, optional): Whether to write ``cfg.tags`` to
            ``<cfg.paths.output_dir>/tags.log``.

    Raises:
        ValueError: If no tags are configured and the Hydra job config contains
            an ``id`` key, which this function treats as a multirun.
    """
    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        raw_tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Strip *before* filtering. Otherwise inputs such as "a, ,b" or "a, "
        # produce empty-string tags.
        tags = [tag.strip() for tag in raw_tags.split(",") if tag.strip()]

        # open_dict allows adding the ``tags`` key even if cfg is in struct mode.
        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        # Explicit UTF-8 so non-ASCII tags can always be written.
        with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
            rich.print(cfg.tags, file=file)
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: compose the training config and print it as a tree.
    # Hydra's compose API is imported locally because only this entry point needs it.
    from hydra import compose, initialize

    with initialize(version_base="1.2", config_path="../../configs_hydra"):
        train_cfg = compose(
            config_name="train.yaml",
            overrides=[],
            return_hydra_config=False,
        )
        print_config_tree(train_cfg, save_to_file=False, resolve=False)
|
|