Spaces:
Sleeping
Sleeping
File size: 3,125 Bytes
56d35ce | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 | import argparse
import yaml
from functools import wraps
from types import SimpleNamespace
class SafeNamespace(SimpleNamespace):
    """A SimpleNamespace whose missing attributes read as None.

    Reading an attribute that was never set returns None instead of raising
    AttributeError, so optional config keys can be read without guards.

    Dunder names (``__xxx__``) are excluded and still raise AttributeError.
    The copy and pickle machinery probe for hooks such as ``__setstate__``
    and ``__deepcopy__`` with ``hasattr``/``getattr``. If those probes got
    None back, they would try to call it, and copying or unpickling the
    namespace would fail.
    """

    def __getattr__(self, name):
        # __getattr__ only runs after normal attribute lookup has failed, so
        # attributes that were actually set (and inherited methods) are
        # unaffected by this fallback.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return None  # missing regular attribute -> None
def dict_to_namespace(d: dict):
    """Recursively convert nested dicts (and lists of them) into SafeNamespace.

    - dict  -> SafeNamespace whose attributes are the converted values
      (keys must be strings, since they are passed as keyword arguments);
    - list  -> new list with every element converted;
    - other -> returned unchanged (scalars, tuples, None, ...).
    """
    if not isinstance(d, dict):
        if isinstance(d, list):
            return list(map(dict_to_namespace, d))
        return d
    fields = {}
    for key, value in d.items():
        fields[key] = dict_to_namespace(value)
    return SafeNamespace(**fields)
def load_config(config_path: str):
    """Load a YAML or JSON config file and return its parsed content.

    :param config_path: path to a ``.yaml``/``.yml`` or ``.json`` file.
        The extension is matched case-insensitively, and ``os.PathLike``
        objects are accepted as well as ``str``.
    :return: the parsed document, as produced by ``yaml.safe_load`` or
        ``json.load``.
    :raises ValueError: if the file extension is not supported. The check
        happens before the file is opened, so this error is reported even
        when the path does not exist.
    """
    import json
    import os

    path = os.fspath(config_path)
    suffix = os.path.splitext(path)[1].lower()
    if suffix not in (".yaml", ".yml", ".json"):
        raise ValueError(f"Unsupported config file type: {config_path}")

    # Read as UTF-8 explicitly rather than the locale default (which may be
    # a non-UTF-8 codepage on some systems). "utf-8-sig" also strips a
    # leading BOM, which json.load would otherwise reject.
    with open(path, "r", encoding="utf-8-sig") as f:
        if suffix == ".json":
            return json.load(f)
        # safe_load only constructs plain Python objects (no arbitrary tags).
        return yaml.safe_load(f)
def merge_config(base: dict, overrides: dict):
    """Recursively merge ``overrides`` into ``base`` in place and return ``base``.

    For every key in ``overrides``: if both the override value and the
    existing ``base`` value are dicts, they are merged recursively. Otherwise
    the override value replaces the ``base`` entry. Override values are
    inserted by reference, not copied. A ``None`` ``overrides`` leaves
    ``base`` untouched.
    """
    if overrides is None:
        return base
    for key, new_value in overrides.items():
        old_value = base.get(key)
        both_dicts = isinstance(new_value, dict) and isinstance(old_value, dict)
        base[key] = merge_config(old_value, new_value) if both_dicts else new_value
    return base
def config_entry(config_arg_name="config"):
    """Decorator factory: turn an argparse-returning function into a config loader.

    The decorated function must return an object that ``vars()`` accepts,
    such as the ``argparse.Namespace`` from ``parse_args()``. The wrapper then:

    1. removes ``config_arg_name`` from the parsed arguments and uses its
       value as the config file path;
    2. drops arguments whose value is ``None`` and expands dotted names such
       as ``train.lr`` into nested dicts (``{"train": {"lr": ...}}``);
    3. if a config file path was given, loads it with ``load_config`` and
       merges it on top with ``merge_config``, so values from the file take
       precedence over command-line values;
    4. returns the result converted by ``dict_to_namespace``.

    :param config_arg_name: argparse ``dest`` of the config-file argument.
    """
    def _expand_dotted(flat):
        # Build a nested dict from {"a.b.c": value}-style keys, skipping
        # entries whose value is None (i.e. options that were not set).
        nested = {}
        for dotted_key, value in flat.items():
            if value is None:
                continue
            *parents, leaf = dotted_key.split(".")
            node = nested
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = value
        return nested

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            arg_values = dict(vars(func(*args, **kwargs)))
            cfg_path = arg_values.pop(config_arg_name, None)
            cfg = _expand_dotted(arg_values)
            if cfg_path is None:
                return dict_to_namespace(cfg)
            # File values override the argparse-derived values.
            return dict_to_namespace(merge_config(cfg, load_config(cfg_path)))
        return wrapper
    return decorator
@config_entry("cfg_file")
def example_parse_args():
    """Build the demo command-line parser and return the parsed arguments.

    Dotted option names such as ``--train.lr`` keep the dot in their argparse
    ``dest``; ``config_entry`` splits on it to build nested config sections.
    """
    # NOTE(review): the help strings say "Override", but config_entry merges
    # the config file on top of these values, so a config file wins over the
    # command line. Confirm which precedence is intended.
    cli = argparse.ArgumentParser(description="Training Config")
    option_specs = (
        ("--cfg_file", {"type": str, "help": "Path to config file (optional)"}),
        ("--train.batch_size", {"type": int, "default": 1, "help": "Override batch size"}),
        ("--train.lr", {"type": float, "help": "Override learning rate"}),
        ("--model.name", {"type": str, "help": "Override model name"}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
if __name__ == "__main__":
    args = example_parse_args()
    print("Final Config:")
    # A missing top-level key reads as None on SafeNamespace, so this never raises.
    print(args.rank)
    # A missing *section* is itself None, and None has no attributes. Plain
    # chained access (args.model.hidden_dim) raises AttributeError when no
    # model.* option or config section was given, e.g. when the script runs
    # with no arguments. getattr with a default keeps such lookups at None.
    print(getattr(args.train, "batch_size", None))
    print(getattr(args.model, "hidden_dim", None))  # missing -> None
|