Spaces:
Sleeping
Sleeping
| import argparse | |
| import yaml | |
| from functools import wraps | |
| from types import SimpleNamespace | |
class SafeNamespace(SimpleNamespace):
    """A SimpleNamespace whose missing attributes read as ``None``.

    Accessing an attribute that was never set returns ``None`` instead of
    raising ``AttributeError``.

    Dunder names (``__xxx__``) are excluded and still raise ``AttributeError``.
    Protocol lookups such as ``hasattr(obj, "__setstate__")`` (used by
    ``copy.deepcopy`` and by unpickling) would otherwise receive ``None`` and
    try to call it, failing with ``TypeError``.
    """

    def __getattr__(self, name):
        # Only invoked after normal attribute lookup has failed.
        if name.startswith("__") and name.endswith("__"):
            # Keep Python protocol lookups (copy/pickle/etc.) working.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return None  # missing regular attribute -> None
def dict_to_namespace(d: dict):
    """Recursively convert nested dicts into ``SafeNamespace`` objects.

    Dicts become ``SafeNamespace`` instances (their values converted
    recursively). Lists are rebuilt with each element converted. Any other
    value is returned unchanged.
    """
    if isinstance(d, list):
        return [dict_to_namespace(item) for item in d]
    if not isinstance(d, dict):
        return d
    converted = {key: dict_to_namespace(value) for key, value in d.items()}
    return SafeNamespace(**converted)
def load_config(config_path: str):
    """Load a YAML or JSON config file and return the parsed object.

    The parser is chosen from the file extension, matched case-insensitively:
    ``.yaml``/``.yml`` use ``yaml.safe_load`` and ``.json`` uses ``json.load``.

    :param config_path: Path to the config file.
    :return: The parsed content. This is usually a dict; ``yaml.safe_load``
        returns ``None`` for an empty YAML document.
    :raises ValueError: If the extension is not supported. The extension is
        checked before the file is opened, so this is raised even when the
        path does not exist.
    """
    lowered = config_path.lower()
    if lowered.endswith((".yaml", ".yml")):
        # safe_load only constructs plain Python objects.
        parse = yaml.safe_load
    elif lowered.endswith(".json"):
        import json
        parse = json.load
    else:
        raise ValueError(f"Unsupported config file type: {config_path}")
    # Read as UTF-8 explicitly instead of relying on the platform's locale
    # encoding, since config files may contain non-ASCII text.
    with open(config_path, "r", encoding="utf-8") as f:
        return parse(f)
def merge_config(base: dict, overrides: dict):
    """Recursively merge ``overrides`` into ``base`` and return ``base``.

    ``base`` is modified in place. For every key in ``overrides``:

    - if both sides hold a dict, the two are merged recursively;
    - otherwise the value from ``overrides`` replaces the one in ``base``.

    Dicts coming from ``overrides`` are copied recursively (as plain dicts)
    rather than inserted by reference. Later merges into ``base`` therefore
    never mutate ``overrides``. Non-dict values (scalars, lists, objects) are
    inserted as-is.

    :param base: Dict to merge into; it is mutated.
    :param overrides: Dict whose values take precedence. ``None`` means there
        is nothing to merge.
    :return: ``base``.
    """
    if overrides is None:
        return base
    for key, value in overrides.items():
        if isinstance(value, dict):
            target = base.get(key)
            if not isinstance(target, dict):
                # Merge into a fresh dict so base never aliases a nested dict
                # owned by the caller's overrides.
                target = {}
            base[key] = merge_config(target, value)
        else:
            base[key] = value
    return base
def config_entry(config_arg_name="config"):
    """Decorator factory wrapping parse_args + load_config + merge_config.

    The decorated function must return an object accepted by ``vars()``, e.g.
    the ``argparse.Namespace`` from ``parser.parse_args()``. The wrapper then:

    1. removes the ``config_arg_name`` attribute and uses it as the config
       file path;
    2. turns the remaining non-``None`` arguments into a nested dict,
       splitting dotted names (``"train.lr"`` -> ``{"train": {"lr": ...}}``);
    3. if a config path was given, loads it with ``load_config`` and merges it
       on top with ``merge_config``. Values from the file take precedence over
       argparse values, including argparse defaults;
    4. returns the result of ``dict_to_namespace`` on the merged dict.

    :param config_arg_name: Attribute name (argparse ``dest``) that holds the
        config file path.
    :return: A decorator. The wrapped function keeps the original function's
        name and docstring via ``functools.wraps``.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            parsed_args = func(*args, **kwargs)
            cli_values = vars(parsed_args).copy()
            cfg_path = cli_values.pop(config_arg_name, None)

            # Build a nested dict from the argparse values; it is the base layer.
            cfg = {}
            for k, v in cli_values.items():
                if v is None:
                    # Argument has no value (not given and no default); skip it.
                    continue
                keys = k.split(".")
                d = cfg
                for subkey in keys[:-1]:
                    d = d.setdefault(subkey, {})
                d[keys[-1]] = v

            # A config file, when given, overrides the argparse values.
            if cfg_path is not None:
                file_cfg = load_config(cfg_path)
                cfg = merge_config(cfg, file_cfg)
            return dict_to_namespace(cfg)
        return wrapper
    return decorator
@config_entry(config_arg_name="cfg_file")
def example_parse_args(argv=None):
    """Example CLI: parse training arguments and merge an optional config file.

    Decorated with ``config_entry`` using ``config_arg_name="cfg_file"`` so that
    ``--cfg_file`` is treated as the config path. The remaining dotted options
    (``--train.lr``, ...) become nested attributes of the returned namespace.

    :param argv: Argument list to parse. ``None`` (the default) makes argparse
        read ``sys.argv[1:]``.
    :return: The namespace built by ``config_entry``'s wrapper.
    """
    parser = argparse.ArgumentParser(description="Training Config")
    parser.add_argument("--cfg_file", type=str, help="Path to config file (optional)")
    parser.add_argument("--train.batch_size", type=int, default=1, help="Override batch size")
    parser.add_argument("--train.lr", type=float, help="Override learning rate")
    parser.add_argument("--model.name", type=str, help="Override model name")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Assumes example_parse_args is wrapped by config_entry, so args is a
    # SafeNamespace and missing top-level keys read as None.
    args = example_parse_args()
    print("Final Config:")
    print(args.rank)  # missing -> None, no error
    # A missing *parent* section is itself None, and None has no attributes.
    # Use getattr with a default for nested lookups that may not exist.
    print(getattr(args.train, "batch_size", None))
    print(getattr(args.model, "hidden_dim", None))  # missing -> None