File size: 4,967 Bytes
9627ce0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
import numpy as np
from joblib import Parallel, delayed, parallel_backend
from multiprocessing import Pool, cpu_count
from tqdm import tqdm

# 顶层函数,multiprocessing 才能 pickle
def _worker(args):
    fn, d, kwargs = args
    return fn(*d, **kwargs)  # d 必须是 tuple;如果是单参数就传成 (d,) 即可


from joblib import Parallel, delayed, cpu_count
from tqdm import tqdm

def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, desc=None, **kwargs):
    """Map *pickleable_fn* over *data* in parallel with joblib.

    Each element of *data* is unpacked as positional arguments
    (``pickleable_fn(*d, **kwargs)``), so single-argument items must be
    passed as 1-tuples. Results are returned as a list in input order.

    :param pickleable_fn: function applied to every item
    :param data: iterable of argument tuples
    :param n_jobs: worker count; defaults to ``cpu_count()``
    :param verbose: joblib verbosity level
    :param desc: tqdm description shown while materializing *data*
    :param kwargs: extra keyword arguments forwarded to *pickleable_fn*
    :return: list of per-item results
    """
    jobs = cpu_count() if n_jobs is None else n_jobs

    # A named inner function instead of a lambda; the loky backend
    # serializes it with cloudpickle, so the closure is safe here.
    def _call(item):
        return pickleable_fn(*item, **kwargs)

    # Materialize up front — tqdm tracks iteration over *data* here,
    # not the parallel execution itself.
    materialized = list(tqdm(data, desc=desc))

    with parallel_backend('loky'):  # or 'multiprocessing'
        out = Parallel(n_jobs=jobs, verbose=verbose, timeout=None)(
            delayed(_call)(item) for item in materialized
        )
    return out


def modulo_with_wrapped_range(
    vals, range_min: float = -np.pi, range_max: float = np.pi
):
    """Wrap *vals* into the half-open interval ``[range_min, range_max)``.

    Handles a negative range minimum, and works elementwise on anything
    supporting ``-``, ``%``, and ``+`` (scalars, numpy arrays, tensors).

    >>> modulo_with_wrapped_range(3, -2, 2)
    -1
    """
    assert range_min <= 0.0
    assert range_min < range_max

    # Shift into [0, span), take the modulo, then shift back down.
    span = range_max - range_min
    return ((vals - range_min) % span) + range_min



def flatten_dict(d, parent_key='', sep='.', level=0):
    """Recursively flatten a nested dict into a single-level dict.

    Keys at nesting levels 0 and 1 stay bare — i.e. the first-level
    parent key is dropped — while deeper keys are joined to their
    accumulated parent key with *sep*. Note that dropping the parent
    prefix means colliding child keys overwrite each other.

    :param d: nested input dict
    :param parent_key: accumulated parent key (used during recursion)
    :param sep: separator between joined keys, default ``'.'``
    :param level: current recursion depth (controls prefix dropping)
    :return: flattened single-level dict
    """
    flat = {}
    for key, value in d.items():
        # Levels 0 and 1 drop the parent prefix entirely.
        if level <= 1:
            joined = key
        elif parent_key:
            joined = f"{parent_key}{sep}{key}"
        else:
            joined = key

        if isinstance(value, dict):
            # Recurse into nested dicts, deepening the level counter.
            flat.update(flatten_dict(value, joined, sep=sep, level=level + 1))
        else:
            flat[joined] = value
    return flat

def process_args(parser, config_path):
    """Merge argparse CLI arguments with a Hydra/OmegaConf config file.

    Precedence, highest first: values the user explicitly passed on the
    command line > values from the config file named by ``args.config_name``
    > argparse defaults. Returns the parsed ``args`` Namespace updated
    in place with the merged values.

    :param parser: a configured ``argparse.ArgumentParser``
    :param config_path: Hydra config directory passed to ``initialize``
    """
    from hydra import initialize, compose
    from omegaconf import DictConfig, OmegaConf
    import sys
    # NOTE(review): this resolver evaluates arbitrary expressions from the
    # config file via eval() — only use with trusted config sources.
    def eval_resolver(expr: str):
        return eval(expr, {}, {})

    OmegaConf.register_new_resolver("eval", eval_resolver, use_cache=False)

    # ----------------------------------------------------------------------------
    # 1) Parse with an empty argv first to capture the pure parser defaults:
    defaults = parser.parse_args([])
    defaults_dict = vars(defaults)

    # 2) Now really parse the command line (CLI + defaults):
    args = parser.parse_args()
    args_dict = vars(args)

    # 3) Load the Hydra config and flatten it into a plain dict:
    with initialize(config_path=config_path):
        cfg: DictConfig = compose(config_name=args.config_name)
    config_dict = flatten_dict(OmegaConf.to_container(cfg, resolve=True))

    # ----------------------------------------------------------------------------
    # 4) Work out which keys were explicitly supplied on the command line:
    passed = set()
    for tok in sys.argv[1:]:
        if not tok.startswith('--'):
            continue
        # Supports both the --foo=bar and --foo bar forms
        key = tok.lstrip('-').split('=')[0].replace('-', '_')
        passed.add(key)

    # 5) Final merge: CLI > config_file > parser_default
    merged = {}
    for key in set(list(defaults_dict.keys())+list(config_dict.keys())):
        if key in passed:
            # Explicitly passed by the user on the CLI
            merged[key] = args_dict[key]
        elif key in config_dict:
            # Present in the config file and not overridden on the CLI
            merged[key] = config_dict[key]
        else:
            # Specified nowhere — fall back to the parser default
            merged[key] = defaults_dict[key]

    # Update the args Namespace with the merged result
    args.__dict__.update(merged)
    return args