| """ |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. |
| """ |
|
|
| import os |
| import copy |
| import yaml |
| from typing import Any, Dict, Optional, List |
|
|
| from .workspace import GLOBAL_CONFIG |
|
|
| __all__ = [ |
| 'load_config', |
| 'merge_config', |
| 'merge_dict', |
| 'parse_cli', |
| ] |
|
|
|
|
| INCLUDE_KEY = '__include__' |
|
|
|
|
| def load_config(file_path, cfg=dict()): |
| """load config |
| """ |
| _, ext = os.path.splitext(file_path) |
| assert ext in ['.yml', '.yaml'], "only support yaml files" |
|
|
| with open(file_path) as f: |
| file_cfg = yaml.load(f, Loader=yaml.Loader) |
| if file_cfg is None: |
| return {} |
|
|
| if INCLUDE_KEY in file_cfg: |
| base_yamls = list(file_cfg[INCLUDE_KEY]) |
| for base_yaml in base_yamls: |
| if base_yaml.startswith('~'): |
| base_yaml = os.path.expanduser(base_yaml) |
|
|
| if not base_yaml.startswith('/'): |
| base_yaml = os.path.join(os.path.dirname(file_path), base_yaml) |
|
|
| with open(base_yaml) as f: |
| base_cfg = load_config(base_yaml, cfg) |
| merge_dict(cfg, base_cfg) |
|
|
| return merge_dict(cfg, file_cfg) |
|
|
|
|
| def merge_dict(dct, another_dct, inplace=True) -> Dict: |
| """merge another_dct into dct |
| """ |
| def _merge(dct, another) -> Dict: |
| for k in another: |
| if (k in dct and isinstance(dct[k], dict) and isinstance(another[k], dict)): |
| _merge(dct[k], another[k]) |
| else: |
| dct[k] = another[k] |
|
|
| return dct |
|
|
| if not inplace: |
| dct = copy.deepcopy(dct) |
|
|
| return _merge(dct, another_dct) |
|
|
|
|
| def dictify(s: str, v: Any) -> Dict: |
| if '.' not in s: |
| return {s: v} |
| key, rest = s.split('.', 1) |
| return {key: dictify(rest, v)} |
|
|
|
|
| def parse_cli(nargs: List[str]) -> Dict: |
| """ |
| parse command-line arguments |
| convert `a.c=3 b=10` to `{'a': {'c': 3}, 'b': 10}` |
| """ |
| cfg = {} |
| if nargs is None or len(nargs) == 0: |
| return cfg |
|
|
| for s in nargs: |
| s = s.strip() |
| k, v = s.split('=', 1) |
| d = dictify(k, yaml.load(v, Loader=yaml.Loader)) |
| cfg = merge_dict(cfg, d) |
|
|
| return cfg |
|
|
|
|
|
|
| def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool=False, overwrite: bool=False): |
| """ |
| Merge another_cfg into cfg, return the merged config |
| |
| Example: |
| |
| cfg1 = load_config('./dfine_r18vd_6x_coco.yml') |
| cfg1 = merge_config(cfg, inplace=True) |
| |
| cfg2 = load_config('./dfine_r50vd_6x_coco.yml') |
| cfg2 = merge_config(cfg2, inplace=True) |
| |
| model1 = create(cfg1['model'], cfg1) |
| model2 = create(cfg2['model'], cfg2) |
| """ |
| def _merge(dct, another): |
| for k in another: |
| if k not in dct: |
| dct[k] = another[k] |
|
|
| elif isinstance(dct[k], dict) and isinstance(another[k], dict): |
| _merge(dct[k], another[k]) |
|
|
| elif overwrite: |
| dct[k] = another[k] |
|
|
| return cfg |
|
|
| if not inplace: |
| cfg = copy.deepcopy(cfg) |
|
|
| return _merge(cfg, another_cfg) |
|
|