Spaces:
Runtime error
Runtime error
| import sys | |
| import json | |
| import rich | |
| import rich.text | |
| import rich.tree | |
| import rich.syntax | |
| import hydra | |
| from typing import List, Optional, Union, Any | |
| from pathlib import Path | |
| from omegaconf import OmegaConf, DictConfig, ListConfig | |
| from pytorch_lightning.utilities import rank_zero_only | |
| from lib.info.log import get_logger | |
| from .proj_manager import ProjManager as PM | |
def get_PM_info_dict():
    '''
    Collect the ProjManager paths into an OmegaConf object.
    ### Returns
    - DictConfig shaped as `{'_pm_': {'root': ..., 'inputs': ..., 'outputs': ...}}`,
      where each value is `str()` of the matching `PM` attribute.
    '''
    pm_paths = {
        name: str(getattr(PM, name))
        for name in ('root', 'inputs', 'outputs')
    }
    return OmegaConf.create({'_pm_': pm_paths})
def get_PM_info_list():
    '''
    Collect the ProjManager paths as Hydra command-line override strings.
    ### Returns
    - List[str]: `['_pm_.root=<root>', '_pm_.inputs=<inputs>', '_pm_.outputs=<outputs>']`,
      where each value is `str()` of the matching `PM` attribute.
    '''
    return [
        f'_pm_.{name}={str(getattr(PM, name))}'
        for name in ('root', 'inputs', 'outputs')
    ]
def entrypoint_with_args(*args, log_cfg=True, **kwargs):
    '''
    Build a decorator that extends the `hydra.main` decorator in these parts:
    - Inject some runtime-known arguments, i.e. `_pm_.root`, `_pm_.inputs` and `_pm_.outputs`.
    - Enable additional arguments that don't need to be specified on the command line.
        - Positional arguments are added to the command line arguments directly, so make sure they are valid.
            - e.g., 'exp=<...>', '+extra=<...>', etc.
        - Keyword arguments have the same effect as command line arguments `{k}={v}`.
    - Check that the experiment name (`exp_name`) is given.

    Overrides are passed to Hydra in this order (later ones have higher priority):
    ProjManager info, positional `args`, keyword `kwargs`, then the original command line arguments.

    Note: applying the returned decorator rewrites `sys.argv` immediately (at decoration time).

    ### Args
    - args: str, extra Hydra override strings.
    - log_cfg: bool, default True
        - If True and the config does not set `show_cfg` to a falsy value, print the resolved config.
    - kwargs: extra overrides, each rendered as `{k}={v}`.
    ### Returns
    - A decorator; the decorated function is called with the resolved, editable config.
    '''
    cli_args = sys.argv[1:]
    overrides = get_PM_info_list()
    overrides.extend(args)
    overrides.extend(f'{k}={v}' for k, v in kwargs.items())
    # The args from command line have higher priority, so put them in the back.
    overrides.extend(cli_args)

    def entrypoint_wrapper(func):
        # Make `hydra.main` parse the injected overrides. `overrides` already ends with the
        # original command line arguments, so they must NOT be appended a second time:
        # duplicating e.g. `exp=...` would hand the same override to Hydra twice.
        sys.argv = sys.argv[:1] + overrides
        _log_exp_info(func.__name__, overrides)

        def entrypoint_preprocess(cfg: DictConfig):
            # Resolve the references and make it editable.
            cfg = unfold_cfg(cfg)
            # Print out the configuration files unless disabled by the caller or by `show_cfg`.
            if log_cfg and cfg.get('show_cfg', True):
                sum_keys = ['output_dir', 'pipeline.name', 'data.name', 'exp_name', 'exp_tag']
                print_cfg(cfg, sum_keys=sum_keys)
            # Check the validation of experiment name.
            if cfg.get('exp_name') is None:
                get_logger(brief=True).fatal('`exp_name` is not given! You may need to add `exp=<certain_exp>` to the command line.')
                raise ValueError('`exp_name` is not given!')
            # Bind config.
            PM.init_with_cfg(cfg)
            try:
                with PM.time_monitor('exp', f'Main part of experiment `{cfg.exp_name}`.'):
                    # Enter the main function.
                    func(cfg)
            finally:
                # Report timings even if the main function raised; any exception propagates unchanged.
                PM.time_monitor.report(level='global')
            # TODO: Wrap a notifier here.

        return entrypoint_preprocess

    return entrypoint_wrapper
| #! This implementation can't dump the config files in default ways. In order to keep c | |
| # def entrypoint_wrapper(func): | |
| # def entrypoint_preprocess(): | |
| # # Initialize the configuration module. | |
| # with hydra.initialize_config_dir(version_base=None, config_dir=str(PM.configs)): | |
| # get_logger(brief=True).info(f'Exp entry `{func.__name__}` is called with overrides: {overrides}') | |
| # cfg = hydra.compose(config_name='base', overrides=overrides) | |
| # cfg4dump_raw = cfg.copy() # store the folded raw configuration files | |
| # # Resolve the references and make it editable. | |
| # cfg = unfold_cfg(cfg) | |
| # # Print out the configuration files. | |
| # if log_cfg: | |
| # sum_keys = ['pipeline.name', 'data.name', 'exp_name'] | |
| # print_cfg(cfg, sum_keys=sum_keys) | |
| # # Check the validation of experiment name. | |
| # if cfg.get('exp_name') is None: | |
| # get_logger().fatal(f'`exp_name` is not given! You may need to add `exp=<certain_exp>` to the command line.') | |
| # raise ValueError('`exp_name` is not given!') | |
| # # Enter the main function. | |
| # func(cfg) | |
| # return entrypoint_preprocess | |
| # return entrypoint_wrapper | |
def entrypoint(func):
    '''
    Shortcut for `entrypoint_with_args()` without extra overrides. It extends the
    `hydra.main` decorator by:
    - Injecting some runtime-known arguments (the ProjManager paths under `_pm_`).
    - Checking that the experiment name is given.
    '''
    decorate = entrypoint_with_args()
    return decorate(func)
def unfold_cfg(
    cfg : Union[DictConfig, Any],
):
    '''
    Resolve every interpolation in `cfg` and rebuild it as a fresh, editable config.
    The config is converted to a plain container (with references resolved) and then
    recreated, which also drops any structured-config typing.
    ### Args
    - cfg: DictConfig or None
    ### Returns
    - cfg: a new DictConfig, or None when `cfg` is None
    '''
    if cfg is not None:
        resolved = OmegaConf.to_container(cfg, resolve=True)
        cfg = OmegaConf.create(resolved)
    return cfg
def recursively_simplify_cfg(
    node      : DictConfig,
    hide_misc : bool = True,
):
    '''
    Simplify a config tree for display. Container nodes are modified in place.
    - If `hide_misc`, keys named `_hub_`, `hydra` or `job_logging` are removed at every
      DictConfig level.
    - A non-empty ListConfig whose items are all non-container values is replaced by the
      string `'*' + str(list(node))` when that string is shorter than 320 characters.
    - Any other container is processed recursively; plain values are returned unchanged.
    ### Returns
    - The simplified node (the same container object, or the folded string).
    '''
    misc_keys = ['_hub_', 'hydra', 'job_logging']
    if isinstance(node, DictConfig):
        for key in list(node.keys()):
            if hide_misc and key in misc_keys:
                # These entries are rarely of interest when reading a config.
                node.__delattr__(key)
            else:
                node[key] = recursively_simplify_cfg(node[key], hide_misc)
        return node
    if isinstance(node, ListConfig):
        has_nested = any([isinstance(item, (DictConfig, ListConfig)) for item in node])
        if len(node) == 0 or has_nested:
            for idx in range(len(node)):
                node[idx] = recursively_simplify_cfg(node[idx], hide_misc)
            return node
        # Fold lists of basic elements (int, float, ...) into a single line if short enough.
        folded = '*' + str(list(node))
        return folded if len(folded) < 320 else node
    return node
def _get_by_dotted_key(cfg, k_str):
    ''' Walk `cfg` along the '.'-separated keys of `k_str`; return None if a step is missing or not a DictConfig. '''
    tgt = cfg
    for k in k_str.split('.'):
        if not isinstance(tgt, DictConfig):
            # e.g. an intermediate value is None, a plain value, or a folded list string.
            return None
        tgt = tgt.get(k)
    return tgt


def print_cfg(
    cfg     : Optional[DictConfig],
    title   : str = 'cfg',
    sum_keys: Optional[List[str]] = None,
    show_all: bool = False
):
    '''
    Print configuration files using rich.
    ### Args
    - cfg: DictConfig or None
        - If None, print a tree whose only entry is `None`.
    - title: str, default 'cfg'
        - Label shown at the root of the printed tree.
    - sum_keys: List[str] or None, default None (no summary)
        - Dotted keys (e.g. 'pipeline.name') looked up in the displayed configuration.
          Keys found with a non-None value are printed in an extra summary part;
          the others are reported with a warning.
    - show_all: bool, default False
        - If False, hide the `_hub_`, `hydra` and `job_logging` entries (at any level)
          and fold short lists of basic values into a single line.
    '''
    theme = 'coffee'
    style = 'dim'
    tf_dict = {True: '◼', False: '◻'}
    print_setting = f'<< {tf_dict[show_all]} SHOW_ALL >>'
    tree = rich.tree.Tree(f'⌾ {title} - {print_setting}', style=style, guide_style=style)
    if cfg is None:
        tree.add('None')
        rich.print(tree)
        return
    # Work on a rebuilt copy so the in-place simplification never touches the caller's config.
    cfg = unfold_cfg(cfg.copy())
    if not show_all:
        cfg = recursively_simplify_cfg(cfg)
    cfg_yaml = OmegaConf.to_yaml(cfg)
    tree.add(rich.syntax.Syntax(cfg_yaml, 'yaml', theme=theme, line_numbers=True))
    # Add a summary containing only the commonly concerned information.
    if sum_keys:
        concerned = {}
        for k_str in sum_keys:
            tgt = _get_by_dotted_key(cfg, k_str)
            if tgt is not None:
                concerned[k_str] = tgt
            else:
                get_logger().warning(f'Key `{k_str}` is not found in the configuration files.')
        tree.add(rich.syntax.Syntax(OmegaConf.to_yaml(concerned), 'yaml', theme=theme))
    rich.print(tree)
def _log_exp_info(
    func_name : str,
    overrides : List[str],
):
    '''
    Log, through the brief logger, which experiment entry function is called and with
    which Hydra overrides.
    '''
    logger = get_logger(brief=True)
    message = f'Exp entry `{func_name}` is called with overrides: {overrides}'
    logger.info(message)