| """Parser for config files that can be used with absl flags.""" |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import re |
| import sys |
| import traceback |
| from typing import Any |
|
|
| from absl import flags |
| from ml_collections import ConfigDict, FieldReference |
| from ml_collections.config_flags.config_flags import ( |
| _ConfigFlag, |
| _ErrorConfig, |
| _LockConfig, |
| ) |
|
|
| from vis4d.config import copy_and_resolve_references |
| from vis4d.config.registry import get_config_by_name |
|
|
|
|
class ConfigFileParser(flags.ArgumentParser):
    """Flag argument parser that turns a config path into a config object."""

    def __init__(
        self,
        name: str,
        lock_config: bool = True,
        method_name: str = "get_config",
    ) -> None:
        """Initializes the parser.

        Args:
            name (str): The name of the flag (e.g. config for --config flag)
            lock_config (bool, optional): Whether or not to lock the config
                after it has been loaded. Defaults to True.
            method_name (str, optional): Name of the method to call in the
                config. Defaults to "get_config".
        """
        self.name = name
        self._lock_config = lock_config
        self.method_name = method_name

    def parse(self, path: str) -> ConfigDict | _ErrorConfig:
        """Resolves the config referenced by `path` via `method_name`.

        This implementation is based on the original ml_collections and
        modified to allow for a custom method name.

        If a colon is present in `path`, everything to the right of the first
        colon is forwarded as a single extra positional argument. This allows
        the structure of what is returned to be modified, which is useful
        when performing complex hyperparameter sweeps.

        Args:
            path: string, path pointing to the config file to execute. May also
                contain a config_string argument, e.g. be of the form
                "config.py:some_configuration".

        Returns:
            The object produced by `get_config_by_name` for the given path, or
            an `_ErrorConfig` wrapping the error if an IOError occurred while
            loading it.

        Raises:
            TypeError: If loading the config raised a TypeError. The message
                contains the formatted traceback of the original error.
            ValueError: If loading the config raised a ValueError. The message
                contains the formatted traceback of the original error.
        """
        # At most one split: "file.py:arg:more" -> ["file.py", "arg:more"].
        split_path = path.split(":", 1)

        try:
            config = get_config_by_name(
                split_path[0],
                *split_path[1:],
                method_name=self.method_name,
            )
            if config is None:
                logging.warning(
                    "%s:%s() returned None, did you forget a return "
                    "statement?",
                    path,
                    self.method_name,
                )
        except IOError as e:
            # The error is handed back wrapped instead of being raised here.
            # NOTE(review): presumably _ErrorConfig re-raises once the config
            # is actually accessed - confirm against ml_collections.
            return _ErrorConfig(e)
        except (TypeError, ValueError) as e:
            # Embed the full traceback in the message so it is still visible
            # if a caller only reports the message of the re-raised error.
            message = (
                "Error whilst parsing config file:\n\n"
                + traceback.format_exc()
            )
            try:
                wrapped = type(e)(message)
            except TypeError:
                # Some subclasses (e.g. UnicodeDecodeError) cannot be built
                # from a single message argument; fall back to the matching
                # base class so the diagnostic message is not lost.
                base = TypeError if isinstance(e, TypeError) else ValueError
                wrapped = base(message)
            raise wrapped from e

        if self._lock_config:
            _LockConfig(config)

        return config

    def flag_type(self) -> str:
        """Returns the type of the flag."""
        return "config object"
|
|
|
|
def DEFINE_config_file(
    name: str,
    default: str | None = None,
    help_string: str = "path to config file [.py |.yaml].",
    lock_config: bool = False,
    method_name: str = "get_config",
) -> flags.FlagHolder:
    """Defines a command line flag whose value is parsed from a config file.

    Args:
        name (str): Name of the flag, e.g. "config" for a --config flag.
        default (str | None, optional): Default value of the flag.
            Defaults to None.
        help_string (str, optional): Help text attached to the flag.
            Defaults to "path to config file [.py |.yaml].".
        lock_config (bool, optional): Whether or not to lock the returned
            config. Defaults to False.
        method_name (str, optional): Name of the method to call in the config.
            Defaults to "get_config".

    Returns:
        flags.FlagHolder: Flag holder instance.
    """
    config_flag = _ConfigFlag(
        parser=ConfigFileParser(
            name=name, lock_config=lock_config, method_name=method_name
        ),
        serializer=flags.ArgumentSerializer(),
        name=name,
        default=default,
        help_string=help_string,
        flag_values=flags.FLAGS,
    )

    # Frame 1 is the caller of this function; its module name is passed on
    # as the module that defines the flag. This lookup must stay directly in
    # this function body, otherwise the frame depth would be off.
    caller_module = sys._getframe(1).f_globals.get("__name__", None)
    if caller_module == "__main__":
        # For a script entry point use the script path instead.
        caller_module = sys.argv[0]

    return flags.DEFINE_flag(
        config_flag, flags.FLAGS, module_name=caller_module
    )
|
|
|
|
def pprints_config(data: ConfigDict) -> str:
    """Renders a ConfigDict as a string with a YAML like layout.

    Unlike the __repr__ of ConfigDict, python classes are not encoded using
    binary formats; their own __repr__ is printed instead.

    Args:
        data (ConfigDict): Configuration dict to convert to string

    Returns:
        str: A string representation of the ConfigDict
    """
    # Work on a copy with references resolved, then hand it to the
    # recursive formatter.
    resolved = copy_and_resolve_references(data)
    return _pprints_config(resolved)
|
|
|
|
def _pprints_config(
    data: Any, prefix: str = "", n_indents: int = 1
) -> str:
    """Converts a ConfigDict into a string with a YAML like structure.

    This is the recursive implementation of 'pprints_config' and will be called
    recursively for every element in the dict.

    This function differs from __repr__ of ConfigDict in that it will not
    encode python classes using binary formats but just prints the __repr__
    of these classes.

    Args:
        data (Any): Configuration dict or object to convert to
            string. Anything that is not a dict, ConfigDict, list or tuple
            is rendered with str().
        prefix (str): Prefix to print on each new line
        n_indents (int): Number of spaces to append for each nested property.
            The same value is used on every nesting level.

    Returns:
        str: A string representation of the ConfigDict
    """
    if isinstance(data, FieldReference):
        data = data.get()

    # Leaves: everything that is not a mapping or a sequence.
    if not isinstance(data, (dict, ConfigDict, list, tuple)):
        return str(data)

    step = " " * n_indents
    # Collect the pieces and join once instead of repeated string +=.
    parts = ["\n"]

    if isinstance(data, (ConfigDict, dict)):
        for key in data:
            rendered = _pprints_config(
                data[key], prefix=prefix + step, n_indents=n_indents
            )
            # str() so that plain dicts with non-string keys do not raise.
            parts.append(prefix + str(key) + ": " + rendered + "\n")

    elif isinstance(data, (list, tuple)):
        for value in data:
            parts.append(prefix + "- ")
            if isinstance(value, (ConfigDict, dict)):
                parts.append("\n")

            rendered = _pprints_config(
                value, prefix=prefix + " " + step, n_indents=n_indents
            )
            parts.append(rendered + "\n")
        # Extra line (a single space) appended after the list entries.
        parts.append(" \n")

    string_repr = "".join(parts)

    # Collapse runs of consecutive newlines into a single one.
    string_repr = re.sub("\n\n+", "\n", string_repr)
    # Pull content that ended up on the line after a bare "- " marker back
    # onto the marker's line.
    return re.sub("- +\n +", "- ", string_repr)
|
|