| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import json |
| import logging |
|
|
| from omegaconf import OmegaConf |
|
|
|
|
| class Config: |
| def __init__(self, args): |
| self.config = {} |
|
|
| self.args = args |
| user_config = self._build_opt_list(self.args.options) |
| config = OmegaConf.load(self.args.cfg_path) |
| config = OmegaConf.merge(config, user_config) |
| self.config = config |
|
|
| def _convert_to_dot_list(self, opts): |
| if opts is None: |
| opts = [] |
|
|
| if len(opts) == 0: |
| return opts |
|
|
| has_equal = opts[0].find("=") != -1 |
|
|
| if has_equal: |
| return opts |
|
|
| return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])] |
|
|
| def _build_opt_list(self, opts): |
| opts_dot_list = self._convert_to_dot_list(opts) |
| return OmegaConf.from_dotlist(opts_dot_list) |
|
|
| def pretty_print(self): |
| logging.info("\n===== Running Parameters =====") |
| logging.info(self._convert_node_to_json(self.config.run)) |
|
|
| logging.info("\n====== Dataset Attributes ======") |
| logging.info(self._convert_node_to_json(self.config.datasets)) |
|
|
| logging.info(f"\n====== Model Attributes ======") |
| logging.info(self._convert_node_to_json(self.config.model)) |
|
|
| def _convert_node_to_json(self, node): |
| container = OmegaConf.to_container(node, resolve=True) |
| return json.dumps(container, indent=4, sort_keys=True) |
|
|
| def to_dict(self): |
| return OmegaConf.to_container(self.config) |
|
|