| |
| |
| |
| import os, sys |
| import os.path as osp |
| import ast |
| import tempfile |
| import shutil |
| from importlib import import_module |
|
|
| from argparse import Action |
|
|
| from addict import Dict |
| from yapf.yapflib.yapf_api import FormatCode |
|
|
| import platform |
# Host-OS flags: each is True only when platform.system() reports that OS.
MACOS, LINUX, WINDOWS = [
    platform.system() == system_name
    for system_name in ('Darwin', 'Linux', 'Windows')
]


# Config key naming the parent config file(s) to inherit from.
BASE_KEY = '_base_'
# Marker inside a child dict: replace the base entry instead of merging into it.
DELETE_KEY = '_delete_'
# Top-level config keys that SLConfig refuses (they clash with its own API).
RESERVED_KEYS = [
    'filename', 'text', 'pretty_text', 'get', 'dump', 'merge_from_dict'
]
|
|
|
|
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    """Raise ``FileNotFoundError`` unless *filename* is an existing regular file.

    Args:
        filename: Path to check.
        msg_tmpl: ``str.format`` template for the error message; ``{}`` is
            replaced by *filename*.

    Raises:
        FileNotFoundError: if *filename* is missing or is not a regular file
            (e.g. a directory).
    """
    if osp.isfile(filename):
        return
    raise FileNotFoundError(msg_tmpl.format(filename))
|
|
class ConfigDict(Dict):
    """``addict.Dict`` variant whose missing keys are errors.

    ``__missing__`` is overridden to raise ``KeyError``, replacing addict's
    default missing-key handling. Attribute-style access to an absent key is
    reported as ``AttributeError``, so ``hasattr``/``getattr(obj, name,
    default)`` behave normally.
    """

    def __missing__(self, name):
        # Never auto-create entries: an absent key is a lookup error.
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            return super(ConfigDict, self).__getattr__(name)
        except KeyError:
            # Translate the dict-style error into the attribute protocol's one;
            # any other exception propagates unchanged.
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no "
                f"attribute '{name}'") from None
|
|
|
|
class SLConfig(object):
    """Configuration object built from a dict or a config file.

    Supported files are ``.py`` (executed as a module; every top-level name
    not starting with ``__`` becomes an entry) and ``.yml``/``.yaml``/``.json``
    (loaded with ``slload``). A config may name parent file(s) under the
    ``_base_`` key: bases are loaded first (duplicate top-level keys across
    bases raise ``KeyError``) and the child is merged on top of them. A child
    dict containing ``_delete_=True`` replaces the base entry instead of being
    merged into it.

    Adapted from ``mmcv.utils.config``.

    Example:
        >>> cfg = SLConfig(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg
        Config (path: None): {'a': 1, 'b': {'b1': [0, 1]}}
        >>> cfg = SLConfig.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        'tests/data/config/a.py'
    """

    @staticmethod
    def _validate_py_syntax(filename):
        """Raise ``SyntaxError`` naming *filename* if it is not valid Python."""
        with open(filename) as f:
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError as e:
            # Chain the parser error so its line/offset details are kept.
            raise SyntaxError('There are syntax errors in config '
                              f'file {filename}') from e

    @staticmethod
    def _file2dict(filename):
        """Load *filename*, resolving ``_base_`` inheritance, into a dict.

        Returns:
            tuple: ``(cfg_dict, cfg_text)``; ``cfg_text`` joins the text of
            every involved file (bases first), each prefixed by its path.

        Raises:
            FileNotFoundError: if *filename* or one of its bases is missing.
            SyntaxError: if a ``.py`` config cannot be parsed.
            IOError: for unsupported file extensions.
            KeyError: if two base files define the same top-level key.
        """
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        if filename.lower().endswith('.py'):
            # Validate before touching sys.path so a syntax error cannot leave
            # the interpreter's import state modified.
            SLConfig._validate_py_syntax(filename)
            with tempfile.TemporaryDirectory() as temp_config_dir:
                # Import a copy of the config saved under a random temp-file
                # name rather than importing the config by its own name.
                temp_config_file = tempfile.NamedTemporaryFile(
                    dir=temp_config_dir, suffix='.py')
                temp_config_name = osp.basename(temp_config_file.name)
                temp_module_name = osp.splitext(temp_config_name)[0]
                try:
                    if WINDOWS:
                        # On Windows an open NamedTemporaryFile cannot be
                        # reopened, so close it before copyfile writes to it.
                        temp_config_file.close()
                    shutil.copyfile(filename,
                                    osp.join(temp_config_dir, temp_config_name))
                    sys.path.insert(0, temp_config_dir)
                    try:
                        mod = import_module(temp_module_name)
                    finally:
                        # Remove by value: the config itself may edit sys.path,
                        # so index 0 is not guaranteed to still be ours.
                        if temp_config_dir in sys.path:
                            sys.path.remove(temp_config_dir)
                    cfg_dict = {
                        name: value
                        for name, value in mod.__dict__.items()
                        if not name.startswith('__')
                    }
                finally:
                    # Always forget the throw-away module, even on failure.
                    sys.modules.pop(temp_module_name, None)
                    temp_config_file.close()
        elif filename.lower().endswith(('.yml', '.yaml', '.json')):
            from .slio import slload
            cfg_dict = slload(filename)
        else:
            raise IOError('Only py/yml/yaml/json type are supported now!')

        # Raw text of this file, prefixed by its path, for SLConfig.text.
        cfg_text = filename + '\n'
        with open(filename, 'r') as f:
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(
                base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                # Relative base paths are resolved against this file's folder.
                _cfg_dict, _cfg_text = SLConfig._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError('Duplicate key is not allowed among bases')
                base_cfg_dict.update(c)

            # The child config overrides its combined bases.
            base_cfg_dict = SLConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            cfg_text_list.append(cfg_text)
            cfg_text = '\n'.join(cfg_text_list)

        return cfg_dict, cfg_text

    @staticmethod
    def _merge_a_into_b(a, b):
        """Merge dict *a* into a copy of *b* and return the copy.

        Values in *a* overwrite those in *b*. When both sides hold a dict for
        the same key the two are merged recursively, unless the child dict
        sets ``_delete_`` to a truthy value, in which case it replaces the
        base entry. When *b* is a list, the keys of *a* are list indices
        (anything ``int()`` accepts). Neither *a* nor *b* is modified.

        Args:
            a: Overriding value; non-dicts are returned as-is.
            b (dict | list): Base value.

        Returns:
            The merged copy of *b* (or *a* itself if *a* is not a dict).

        Raises:
            TypeError: if a dict in *a* would be merged into a base value
                that is neither a dict nor a list, or if *b* is a list and a
                key of *a* is not an integer index.
        """
        if not isinstance(a, dict):
            return a

        b = b.copy()
        for k, v in a.items():
            merge_into_base = False
            if isinstance(v, dict) and k in b:
                # Pop the marker from a shallow copy so the caller's ``a`` is
                # left untouched.
                v = v.copy()
                merge_into_base = not v.pop(DELETE_KEY, False)

            if merge_into_base:
                if not isinstance(b[k], (dict, list)):
                    raise TypeError(
                        f'{k}={v} in child config cannot inherit from base '
                        f'because {k} is a dict in the child config but is of '
                        f'type {type(b[k])} in base config. You may set '
                        f'`{DELETE_KEY}=True` to ignore the base config')
                b[k] = SLConfig._merge_a_into_b(v, b[k])
            elif isinstance(b, list):
                try:
                    index = int(k)
                except (TypeError, ValueError):
                    raise TypeError(
                        f'b is a list, '
                        f'index {k} should be an int when input but {type(k)}'
                    ) from None
                b[index] = SLConfig._merge_a_into_b(v, b[index])
            else:
                b[k] = v

        return b

    @staticmethod
    def fromfile(filename):
        """Build an SLConfig from a ``.py``/``.yml``/``.yaml``/``.json`` file."""
        cfg_dict, cfg_text = SLConfig._file2dict(filename)
        return SLConfig(cfg_dict, cfg_text=cfg_text, filename=filename)

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        """
        Args:
            cfg_dict (dict, optional): Config contents; defaults to empty.
            cfg_text (str, optional): Raw text kept for :attr:`text`. When
                falsy and *filename* is given, the file's contents are read.
            filename (str, optional): Source path, exposed as :attr:`filename`.

        Raises:
            TypeError: if *cfg_dict* is not a dict.
            KeyError: if *cfg_dict* uses a key from ``RESERVED_KEYS``.
        """
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but '
                            f'got {type(cfg_dict)}')
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f'{key} is reserved for config file')

        # Bypass our own __setattr__, which writes into _cfg_dict.
        super(SLConfig, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(SLConfig, self).__setattr__('_filename', filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, 'r') as f:
                text = f.read()
        else:
            text = ''
        super(SLConfig, self).__setattr__('_text', text)

    @property
    def filename(self):
        """Path the config was created from, or ``None``."""
        return self._filename

    @property
    def text(self):
        """Raw source text the config was created from."""
        return self._text

    @property
    def pretty_text(self):
        """The config rendered as Python source and formatted with yapf."""
        indent = 4

        def _indent(s_, num_spaces):
            # Indent every line except the first by num_spaces.
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)
            s = first + '\n' + s
            return s

        def _format_basic_types(k, v, use_mapping=False):
            if isinstance(v, str):
                v_str = f"'{v}'"
            else:
                v_str = str(v)

            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f'{k_str}: {v_str}'
            else:
                attr_str = f'{str(k)}={v_str}'
            attr_str = _indent(attr_str, indent)

            return attr_str

        def _format_list(k, v, use_mapping=False):
            # Lists made only of dicts are expanded as ``[dict(...), ...]``.
            if all(isinstance(_, dict) for _ in v):
                v_str = '[\n'
                v_str += '\n'.join(
                    f'dict({_indent(_format_dict(v_), indent)}),'
                    for v_ in v).rstrip(',')
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f'{k_str}: {v_str}'
                else:
                    attr_str = f'{str(k)}={v_str}'
                attr_str = _indent(attr_str, indent) + ']'
            else:
                attr_str = _format_basic_types(k, v, use_mapping)
            return attr_str

        def _contain_invalid_identifier(dict_str):
            # Keys that are not identifiers force ``{'k': v}`` mapping syntax.
            return any(not str(key_name).isidentifier()
                       for key_name in dict_str)

        def _format_dict(input_dict, outest_level=False):
            r = ''
            s = []

            use_mapping = _contain_invalid_identifier(input_dict)
            if use_mapping:
                r += '{'
            for idx, (k, v) in enumerate(input_dict.items()):
                is_last = idx >= len(input_dict) - 1
                # Top-level entries are separate statements: no commas.
                end = '' if outest_level or is_last else ','
                if isinstance(v, dict):
                    v_str = '\n' + _format_dict(v)
                    if use_mapping:
                        k_str = f"'{k}'" if isinstance(k, str) else str(k)
                        attr_str = f'{k_str}: dict({v_str}'
                    else:
                        attr_str = f'{str(k)}=dict({v_str}'
                    attr_str = _indent(attr_str, indent) + ')' + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v, use_mapping) + end
                else:
                    attr_str = _format_basic_types(k, v, use_mapping) + end

                s.append(attr_str)
            r += '\n'.join(s)
            if use_mapping:
                r += '}'
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)

        yapf_style = dict(
            based_on_style='pep8',
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True)
        try:
            text, _ = FormatCode(text, style_config=yapf_style, verify=True)
        except TypeError:
            # Newer yapf releases (0.40.2+) no longer accept ``verify``.
            text, _ = FormatCode(text, style_config=yapf_style)

        return text

    def __repr__(self):
        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        # Only reached when normal lookup fails; unknown names are looked up
        # in the wrapped ConfigDict. Fetch _cfg_dict without re-entering
        # __getattr__ so an instance lacking it (e.g. created via __new__)
        # raises AttributeError instead of recursing forever.
        cfg_dict = super(SLConfig, self).__getattribute__('_cfg_dict')
        return getattr(cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        # Attribute assignment writes config entries; plain dicts are wrapped.
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def dump(self, file=None):
        """Return :attr:`pretty_text`, or write it to *file* when given."""
        if file is None:
            return self.pretty_text
        else:
            with open(file, 'w') as f:
                f.write(self.pretty_text)

    def merge_from_dict(self, options):
        """Merge dotted-key overrides into this config.

        ``options`` is typically the dict parsed by ``DictAction``. Each key
        is a ``.``-separated path into the config; its value overwrites the
        entry at that path (dict values are merged recursively).

        Examples:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp': True}
            >>> cfg = SLConfig(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg_dict = super(SLConfig, cfg).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(model=dict(
            ...     backbone=dict(type='ResNet', depth=50, with_cp=True)))

        Args:
            options (dict): dict of configs to merge from.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split('.')
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(SLConfig, self).__getattribute__('_cfg_dict')
        super(SLConfig, self).__setattr__(
            '_cfg_dict', SLConfig._merge_a_into_b(option_cfg_dict, cfg_dict))

    def __getstate__(self):
        # Defined explicitly: otherwise Python >= 3.11 pickles ``__dict__``
        # via object.__getstate__, which __setstate__ would misread as config
        # contents.
        return (self._cfg_dict, self._filename, self._text)

    def __setstate__(self, state):
        """Restore from :meth:`__getstate__` output or a legacy state.

        Accepted shapes: the ``(cfg_dict, filename, text)`` tuple; a legacy
        ``__dict__`` snapshot holding exactly ``_cfg_dict``/``_filename``/
        ``_text``; or a bare config dict (passed to ``__init__``).
        """
        if isinstance(state, tuple):
            cfg_dict, filename, text = state
        elif isinstance(state, dict) and set(state) == {
                '_cfg_dict', '_filename', '_text'}:
            cfg_dict = state['_cfg_dict']
            filename = state['_filename']
            text = state['_text']
        else:
            self.__init__(state)
            return
        super(SLConfig, self).__setattr__('_cfg_dict', cfg_dict)
        super(SLConfig, self).__setattr__('_filename', filename)
        super(SLConfig, self).__setattr__('_text', text)

    def __deepcopy__(self, memo):
        # Must exist on the class: copy.deepcopy looks ``__deepcopy__`` up on
        # the instance, and __getattr__ would otherwise hand back the inner
        # ConfigDict's method, returning a ConfigDict instead of an SLConfig.
        import copy as copy_module
        cls = self.__class__
        other = cls.__new__(cls)
        memo[id(self)] = other
        other.__setstate__(copy_module.deepcopy(self.__getstate__(), memo))
        return other

    def copy(self):
        """Return a new SLConfig built from a shallow copy of the contents."""
        return SLConfig(self._cfg_dict.copy())

    def deepcopy(self):
        """Return a new SLConfig built from a deep copy of the contents."""
        return SLConfig(self._cfg_dict.deepcopy())
|
|
|
|
class DictAction(Action):
    """argparse action that collects ``KEY=VALUE`` arguments into a dict.

    Each argument is split on its first ``=``. The value is then split on
    commas and every piece is converted by :meth:`_parse_int_float_bool`; a
    single piece is stored as a scalar, several as a list (so
    ``KEY=V1,V2,V3`` gives ``{'KEY': [V1, V2, V3]}``). The resulting dict
    replaces ``namespace.<dest>``. An argument without ``=`` is reported
    through ``parser.error``.
    """

    @staticmethod
    def _parse_int_float_bool(val):
        """Convert *val* to int, float, bool or None when it looks like one.

        Tried in order: ``int``, ``float``, ``'true'``/``'false'``
        (case-insensitive) -> bool, ``'none'``/``'null'`` (case-insensitive)
        -> None. Anything else is returned unchanged.
        """
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        lowered = val.lower()
        if lowered in ('true', 'false'):
            return lowered == 'true'
        if lowered in ('none', 'null'):
            return None
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        # Without ``nargs`` argparse passes one string; iterating it directly
        # would walk its characters, so treat it as a one-item list.
        if isinstance(values, str):
            values = [values]
        options = {}
        for kv in values:
            key, sep, val = kv.partition('=')
            if not sep:
                # Usage message + exit status 2 instead of a bare ValueError
                # traceback from tuple unpacking.
                parser.error(f'{option_string or self.dest}: expected '
                             f'KEY=VALUE, got {kv!r}')
            parsed = [self._parse_int_float_bool(v) for v in val.split(',')]
            options[key] = parsed[0] if len(parsed) == 1 else parsed
        setattr(namespace, self.dest, options)
|
|
|
|