File size: 4,387 Bytes
5fee096 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 | import argparse
import os
import random
import yaml
import re
def get_cur_path():
"""Get the absolute path of current file.
Returns: The absolute path of this file (Config.py).
"""
return os.path.dirname(__file__)
DEFAULT_FILE = os.path.join(get_cur_path(), "default.yaml")
class Config(object):
""" The config parser of `LibContinual`
`Config` is used to parser *.yaml, console params to python dict. The rules for resolving merge conflicts are as follow
1. The merging is recursive, if a key is not be specified, the existing value will be used.
2. The merge priority is: console params > run_*.py > user defined yaml (/LibContinual/config/*.yaml) > default.yaml(/LibContinual/core/config/*.yaml)
"""
def __init__(self, config_file=None):
"""Initializing the parameter dictionary, completes the merging of all parameter.
Args:
config_file: Configuration file name. (/LibContinual/config/*.yaml)
"""
self.config_file = config_file
self.default_dict = self._load_config_files(DEFAULT_FILE)
self.file_dict = self._load_config_files(config_file)
self.console_dict = self._load_console_dict()
self.config_dict = self._merge_config_dict()
def get_config_dict(self):
""" Return the merged dict.
Returns:
dict: A dict of LibContinual setting.
"""
return self.config_dict
@staticmethod
def _load_config_files(config_file):
"""Parse a YAML file.
Args:
config_file (str): Path to yaml file.
Returns:
dict: A dict of LibContinual setting.
"""
config_dict = dict()
loader = yaml.SafeLoader
loader.add_implicit_resolver(
"tag:yaml.org,2002:float",
re.compile(
"""^(?:
[-+]?[0-9][0-9_]*\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?[0-9][0-9_]*[eE][-+]?[0-9]+
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$""",
re.X,
),
list("-+0123456789."),
)
if config_file is not None:
with open(config_file, "r", encoding="utf-8") as fin:
config_dict.update(yaml.load(fin.read(), Loader=loader))
config_file_dict = config_dict.copy()
for include in config_dict.get("includes", []):
with open(os.path.join("./config/", include), "r", encoding="utf-8") as fin:
config_dict.update(yaml.load(fin.read(), Loader=loader))
if config_dict.get("includes") is not None:
config_dict.pop("includes")
config_dict.update(config_file_dict)
return config_dict
@staticmethod
def _load_console_dict():
"""Parsing command line parameters
Returns:
dict: A dict of LibContinual console setting.
"""
pass
@staticmethod
def _update(dic1, dic2):
"""Merge dictionaries.
Used to merge two dictionaries (profiles), `dic2` will overwrite the value of the same key in `dic1`.
Args:
dic1 (dict): The dict to be overwritten. (low priority)
dic2 (dict): The dict to overwrite. (high priority)
Returns:
dict: Merged dict.
"""
if dic1 is None:
dic1 = dict()
if dic2 is not None:
for k in dic2.keys():
dic1[k] = dic2[k]
return dic1
def _merge_config_dict(self):
"""Merge all dictionaries. Merge rules are as follow
1. The merging is recursive, if a key is not be specified, the existing value will be used.
2. The merge priority is: console params > run_*.py > user defined yaml (/LibContinual/config/*.yaml) > default.yaml(/LibContinual/core/config/*.yaml)
Returns:
dict: A complete dict of LibContinual setting.
"""
config_dict = dict()
config_dict = self._update(config_dict, self.default_dict)
config_dict = self._update(config_dict, self.file_dict)
config_dict = self._update(config_dict, self.console_dict)
return config_dict |