File size: 3,125 Bytes
56d35ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import yaml
from functools import wraps
from types import SimpleNamespace


class SafeNamespace(SimpleNamespace):
    """SimpleNamespace whose missing (non-dunder) attributes read as ``None``.

    Accessing an attribute that was never set returns ``None`` instead of
    raising ``AttributeError``. Dunder names (``__xxx__``) are deliberately
    excluded and still raise ``AttributeError``. Otherwise protocol probes
    break: for example ``copy.copy``/``copy.deepcopy`` and unpickling check
    ``hasattr(obj, "__setstate__")``, would see ``None`` and then try to call
    it, and ``hasattr(obj, "__len__")`` would wrongly report True.
    """

    def __getattr__(self, name):
        # __getattr__ only runs after normal attribute lookup has failed, so
        # attributes that were set (and real methods) are unaffected.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return None  # missing regular attribute -> None


def dict_to_namespace(d: dict):
    """Recursively convert nested dicts into ``SafeNamespace`` objects.

    Dicts become ``SafeNamespace`` instances whose values are converted in
    turn. Lists are rebuilt element by element. Any other value (including
    tuples) is returned unchanged.

    :param d: a dict, a list, or any leaf value.
    :return: the converted structure.
    """
    if isinstance(d, list):
        return [dict_to_namespace(item) for item in d]
    if not isinstance(d, dict):
        return d
    fields = {}
    for key, value in d.items():
        fields[key] = dict_to_namespace(value)
    return SafeNamespace(**fields)


def load_config(config_path: str):
    """Load a YAML or JSON config file and return the parsed document.

    The parser is chosen from the file extension (``.yaml``/``.yml`` or
    ``.json``, case-insensitive). The extension is validated *before* the
    file is opened. The file is always decoded as UTF-8, independent of the
    platform's locale encoding, so non-ASCII values load consistently.

    :param config_path: path to the config file (``str`` or path-like).
    :return: the parsed document, typically a ``dict``. YAML parsing goes
        through ``yaml.safe_load``, which returns ``None`` for an empty file.
    :raises ValueError: if the extension is not supported.
    :raises OSError: if the file cannot be opened.
    """
    lowered = str(config_path).lower()
    if lowered.endswith((".yaml", ".yml")):
        parse = yaml.safe_load  # safe_load: no arbitrary object construction
    elif lowered.endswith(".json"):
        import json
        parse = json.load
    else:
        raise ValueError(f"Unsupported config file type: {config_path}")

    with open(config_path, "r", encoding="utf-8") as f:
        return parse(f)


def merge_config(base: dict, overrides: dict):
    """Recursively merge *overrides* into *base*, in place.

    For every key in *overrides*:
    - if both the existing value and the new value are dicts, they are merged
      recursively;
    - otherwise the override value replaces the existing one. It is stored by
      reference, not copied.

    :param base: dict to update; it is mutated.
    :param overrides: dict whose values win on conflict, or ``None`` for a
        no-op.
    :return: *base* itself (the same object).
    """
    if overrides is not None:
        for key, new_value in overrides.items():
            old_value = base.get(key)
            both_dicts = isinstance(new_value, dict) and isinstance(old_value, dict)
            base[key] = merge_config(old_value, new_value) if both_dicts else new_value
    return base


def _expand_dotted_keys(flat: dict) -> dict:
    """Expand ``{"train.lr": 0.1}`` into ``{"train": {"lr": 0.1}}``, skipping ``None`` values.

    Raises ``ValueError`` when one name is used both as a leaf and as a
    section (e.g. ``train`` together with ``train.lr``), instead of crashing
    with an opaque ``TypeError`` or silently dropping values.
    """
    nested = {}
    for dotted_key, value in flat.items():
        if value is None:
            continue  # option not given and without a default
        *parents, leaf = dotted_key.split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
            if not isinstance(node, dict):
                raise ValueError(
                    f"Config key '{dotted_key}' conflicts with the value already set for '{part}'"
                )
        if isinstance(node.get(leaf), dict):
            raise ValueError(
                f"Config key '{dotted_key}' conflicts with nested keys already set under it"
            )
        node[leaf] = value
    return nested


def config_entry(config_arg_name="config"):
    """Decorator that wraps ``parse_args`` + ``load_config`` + ``merge_config``.

    The decorated function must return an ``argparse.Namespace`` (anything
    accepted by ``vars()``). Dotted destination names such as ``train.lr``
    are expanded into nested dicts, and ``None`` values are dropped.

    If the argument named *config_arg_name* holds a path, that file is
    loaded and merged **on top of** the command-line values, so the file
    wins on conflicts. NOTE: this includes flags passed explicitly on the
    command line, because an argparse Namespace cannot distinguish an
    explicit value from a default.

    Usable as ``@config_entry("cfg_file")`` or as a bare ``@config_entry``
    (which then uses the default name ``"config"``).

    :param config_arg_name: argparse ``dest`` name of the config-file option.
    :return: a decorator whose wrapped function returns a ``SafeNamespace``.
    :raises ValueError: (from the wrapped function) if two options collide
        on a dotted path.
    """
    if callable(config_arg_name):
        # Bare ``@config_entry``: the decorated function itself arrives here.
        return config_entry()(config_arg_name)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            parsed_args = func(*args, **kwargs)

            overrides = vars(parsed_args).copy()
            cfg_path = overrides.pop(config_arg_name, None)

            # Command-line values (including argparse defaults) form the base layer.
            cfg = _expand_dotted_keys(overrides)

            # A config file, when given, is merged on top and wins on conflicts.
            if cfg_path is not None:
                cfg = merge_config(cfg, load_config(cfg_path))

            return dict_to_namespace(cfg)
        return wrapper
    return decorator


@config_entry("cfg_file")
def example_parse_args(argv=None):
    """Build the example training CLI and parse it.

    The ``config_entry("cfg_file")`` decorator turns the returned Namespace
    into a nested ``SafeNamespace``. If ``--cfg_file`` is given, it also
    merges that file on top of the parsed values.

    :param argv: argument list to parse. With ``None`` (the default),
        argparse reads ``sys.argv[1:]``, exactly as before. Passing a list
        lets tests or notebooks drive the parser without touching
        ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Training Config")

    parser.add_argument("--cfg_file", type=str, help="Path to config file (optional)")
    parser.add_argument("--train.batch_size", type=int, default=1, help="Override batch size")
    parser.add_argument("--train.lr", type=float, help="Override learning rate")
    parser.add_argument("--model.name", type=str, help="Override model name")

    return parser.parse_args(argv)


if __name__ == "__main__":
    args = example_parse_args()

    print("Final Config:")
    # A missing top-level key does not raise; it reads as None.
    print(args.rank)
    # train.batch_size has an argparse default, so the "train" section exists
    # unless the config file replaces "train" with a non-dict value.
    print(args.train.batch_size)
    # A missing *section* is itself None (e.g. no --model.name and no "model"
    # key in the file). Guard before descending, or None.hidden_dim would
    # raise AttributeError.
    print(args.model.hidden_dim if args.model is not None else None)