# LibContinual/core/config/config.py
# Source: uploaded by boringKey, commit 5fee096 ("Upload 236 files").
import argparse
import os
import random
import yaml
import re
def get_cur_path():
    """Return the absolute path of the directory that contains this module.

    ``__file__`` is normalised with ``os.path.abspath`` because it is not
    guaranteed to be absolute in every invocation (for example, a script run
    via a relative path on Python < 3.9). A relative value would make
    ``DEFAULT_FILE`` depend on the current working directory.

    Returns:
        str: Absolute path of the directory holding this file (config.py).
    """
    return os.path.dirname(os.path.abspath(__file__))


# Built-in default configuration, stored next to this module.
DEFAULT_FILE = os.path.join(get_cur_path(), "default.yaml")
class Config(object):
    """The config parser of `LibContinual`.

    `Config` parses *.yaml files and console params into a single python dict.
    Conflicts between sources are resolved as follows:

    1. Merging is shallow: a top-level key that a higher-priority source does
       not specify keeps its existing value, while a top-level key that it
       does specify replaces the whole value (nested dicts are NOT merged
       key by key).
    2. The merge priority is: console params > run_*.py > user defined yaml
       (/LibContinual/config/*.yaml) > default.yaml
       (/LibContinual/core/config/*.yaml)
    """

    def __init__(self, config_file=None):
        """Load every config source and merge them into ``self.config_dict``.

        Args:
            config_file (str, optional): Path to the user yaml file
                (/LibContinual/config/*.yaml). If None, no user yaml is loaded
                and only the defaults (plus console params) are used.
        """
        self.config_file = config_file
        self.default_dict = self._load_config_files(DEFAULT_FILE)
        self.file_dict = self._load_config_files(config_file)
        self.console_dict = self._load_console_dict()
        self.config_dict = self._merge_config_dict()

    def get_config_dict(self):
        """Return the merged dict.

        Returns:
            dict: A dict of LibContinual setting.
        """
        return self.config_dict

    @staticmethod
    def _make_loader():
        """Build a safe yaml loader that also resolves exponent floats like ``1e-3``.

        PyYAML follows YAML 1.1, whose float pattern requires a dot, so a value
        such as ``1e-3`` would otherwise be loaded as a string. The extra
        resolver is registered on a fresh ``SafeLoader`` subclass. The global
        ``yaml.SafeLoader`` is left untouched, and resolvers do not pile up
        across calls.

        Returns:
            type: A ``yaml.SafeLoader`` subclass.
        """

        class _FloatSafeLoader(yaml.SafeLoader):
            pass

        _FloatSafeLoader.add_implicit_resolver(
            "tag:yaml.org,2002:float",
            re.compile(
                """^(?:
                [-+]?[0-9][0-9_]*\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
                |[-+]?[0-9][0-9_]*[eE][-+]?[0-9]+
                |\\.[0-9_]+(?:[eE][-+][0-9]+)?
                |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
                |[-+]?\\.(?:inf|Inf|INF)
                |\\.(?:nan|NaN|NAN))$""",
                re.X,
            ),
            list("-+0123456789."),
        )
        return _FloatSafeLoader

    @staticmethod
    def _read_yaml(path, loader):
        """Read one yaml file and return its top-level mapping.

        Args:
            path (str): Path to the yaml file.
            loader (type): yaml loader class to parse with.

        Returns:
            dict: The parsed mapping; ``{}`` for an empty (or comment-only) file.

        Raises:
            TypeError: If the document's top level is not a mapping.
        """
        with open(path, "r", encoding="utf-8") as fin:
            data = yaml.load(fin.read(), Loader=loader)
        if data is None:
            return dict()
        if not isinstance(data, dict):
            raise TypeError(
                "Config file {} must contain a mapping at the top level, "
                "got {}".format(path, type(data).__name__)
            )
        return data

    @staticmethod
    def _load_config_files(config_file):
        """Parse a YAML config file, resolving its top-level ``includes`` list.

        Files listed under ``includes`` are joined onto ``./config/`` (so they
        are relative to the current working directory) and merged in order.
        The including file's own keys are then applied on top, so they win
        over anything coming from an include. Includes are not followed
        recursively, and the ``includes`` key is not kept in the result.

        Args:
            config_file (str or None): Path to yaml file. ``None`` yields ``{}``.

        Returns:
            dict: A dict of LibContinual setting.

        Raises:
            TypeError: If a loaded yaml document is not a mapping.
        """
        config_dict = dict()
        if config_file is None:
            return config_dict

        loader = Config._make_loader()
        file_dict = Config._read_yaml(config_file, loader)
        # `includes: null` is treated like an absent key.
        includes = file_dict.pop("includes", None) or []

        # Seed with the file's own keys first so key order follows the file.
        config_dict.update(file_dict)
        for include in includes:
            config_dict.update(
                Config._read_yaml(os.path.join("./config/", include), loader)
            )
        # An included file may declare its own `includes`; they are not followed.
        config_dict.pop("includes", None)
        # The including file's own values take precedence over its includes.
        config_dict.update(file_dict)
        return config_dict

    @staticmethod
    def _load_console_dict():
        """Parse command line parameters.

        Not implemented yet: this always returns ``None``. Console params
        therefore contribute nothing to the merged config, because
        ``_update`` treats ``None`` as "no overrides".

        Returns:
            None
        """
        return None

    @staticmethod
    def _update(dic1, dic2):
        """Shallow-merge `dic2` into `dic1`.

        Every top-level key of `dic2` overwrites the same key in `dic1`.
        Nested dicts are replaced as a whole, not merged recursively.
        `dic1` is modified in place when it is not None.

        Args:
            dic1 (dict or None): The dict to be overwritten (low priority).
            dic2 (dict or None): The dict to overwrite with (high priority).

        Returns:
            dict: The merged dict (`dic1` itself, or a new dict if `dic1`
            was None).
        """
        if dic1 is None:
            dic1 = dict()
        if dic2 is not None:
            dic1.update(dic2)
        return dic1

    def _merge_config_dict(self):
        """Merge all config sources into one new dict.

        Sources are applied from lowest to highest priority: default.yaml,
        then the user yaml, then console params. Each later source overwrites
        the top-level keys of the earlier ones (shallow merge). The source
        dicts themselves are not modified.

        Returns:
            dict: A complete dict of LibContinual setting.
        """
        config_dict = dict()
        for source in (self.default_dict, self.file_dict, self.console_dict):
            config_dict = self._update(config_dict, source)
        return config_dict