File size: 4,967 Bytes
9627ce0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | import torch
import numpy as np
from joblib import Parallel, delayed, parallel_backend
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
# Top-level function so that multiprocessing can pickle it.
def _worker(args):
    """Invoke a packed ``(fn, positional_args, keyword_args)`` triple.

    ``positional_args`` is splatted with ``*`` -- it must be a tuple; for a
    single argument pass ``(d,)``.
    """
    func, positional, keyword = args
    result = func(*positional, **keyword)
    return result
from joblib import Parallel, delayed, cpu_count
from tqdm import tqdm
def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, desc=None, **kwargs):
    """Apply ``pickleable_fn`` to every item of ``data`` in parallel via joblib.

    Each item ``d`` is unpacked as positional arguments, i.e. the call made is
    ``pickleable_fn(*d, **kwargs)``; wrap a single argument as ``(d,)``.

    :param pickleable_fn: function to apply; it is shipped to worker
        processes, so it must be picklable (e.g. a module-level function).
    :param data: iterable of positional-argument tuples.
    :param n_jobs: number of workers; ``None`` means ``cpu_count()``.
    :param verbose: verbosity forwarded to ``joblib.Parallel``.
    :param desc: label for the tqdm progress bar.
    :param kwargs: extra keyword arguments forwarded to every call.
    :return: list of results, in the same order as ``data``.
    """
    if n_jobs is None:
        n_jobs = cpu_count()
    # NOTE: this tqdm bar only tracks materialising ``data`` into a list, not
    # job completion; joblib's ``verbose`` output reports actual progress.
    data_list = list(tqdm(data, desc=desc))
    # Hand the user function to ``delayed`` directly rather than wrapping it in
    # a nested closure: closures are not picklable by the standard ``pickle``
    # module, which non-loky backends such as 'multiprocessing' may rely on
    # (loky only tolerated the closure because it uses cloudpickle).
    with parallel_backend('loky'):  # or 'multiprocessing'
        results = Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)(
            delayed(pickleable_fn)(*d, **kwargs) for d in data_list
        )
    return results
def modulo_with_wrapped_range(
    vals, range_min: float = -np.pi, range_max: float = np.pi
):
    """
    Wrap ``vals`` into the half-open interval ``[range_min, range_max)``.

    Unlike a plain ``%``, this handles ranges whose minimum is negative (the
    default is ``[-pi, pi)``). Only ``-``, ``+`` and ``%`` are used, so it
    applies equally to Python scalars, NumPy arrays and torch tensors.

    >>> modulo_with_wrapped_range(3, -2, 2)
    -1
    >>> modulo_with_wrapped_range(4, 1, 3)
    2

    :param vals: value(s) to wrap.
    :param range_min: inclusive lower bound; must be strictly below range_max.
    :param range_max: exclusive upper bound.
    :return: ``vals`` shifted by a multiple of ``range_max - range_min`` so it
        lies in the range.
    """
    # range_min may be negative, zero or positive; only a positive span is
    # needed for the modulo below to be well defined.
    assert range_min < range_max
    span = range_max - range_min
    # Shift into [0, span), wrap with a modulo whose result follows the sign of
    # the (positive) divisor -- true for Python, NumPy and torch ``%`` -- and
    # shift back down.
    # NOTE: with floats, a tiny negative shifted value can round up to exactly
    # ``span`` (Python's ``%`` docs note this), so the result may equal
    # range_max.
    return (vals - range_min) % span + range_min
def flatten_dict(d, parent_key='', sep='.', level=0):
    """
    Recursively flatten a nested dict into a single-level dict, dropping the
    first-level parent keys.

    Keys at depth 0 and 1 are kept bare, so ``{'a': {'b': {'c': 1}}}`` becomes
    ``{'b.c': 1}``. Deeper keys are joined to their parent with ``sep``. When
    two branches produce the same flattened key, the later one wins.

    :param d: the nested input dict
    :param parent_key: accumulated key of the enclosing dict (used by recursion)
    :param sep: separator placed between joined keys, '.' by default
    :param level: current recursion depth (used to drop the first-level key)
    :return: the flattened single-level dict
    """
    flat = {}
    for key, value in d.items():
        # Only below depth 1, and only when there is a parent, is the key
        # prefixed; otherwise it is used as-is.
        if level > 1 and parent_key:
            full_key = f"{parent_key}{sep}{key}"
        else:
            full_key = key
        if not isinstance(value, dict):
            flat[full_key] = value
            continue
        # Nested dict: flatten it one level deeper and merge the result in.
        flat.update(flatten_dict(value, full_key, sep=sep, level=level + 1))
    return flat
def process_args(parser, config_path):
    """Parse CLI arguments and merge them with a Hydra config.

    Precedence, highest first:
      1. options the user explicitly gave on the command line,
      2. values from the composed (and flattened) Hydra config,
      3. the parser's own defaults.

    :param parser: ``argparse.ArgumentParser`` used to parse ``sys.argv``. It
        must define a ``config_name`` option, used as the Hydra config name.
    :param config_path: config directory passed to ``hydra.initialize``.
    :return: the ``argparse.Namespace`` from ``parser.parse_args()``, updated
        in place with the merged values. Config keys with no matching parser
        option are added as new attributes.
    """
    import sys

    from hydra import compose, initialize
    from omegaconf import DictConfig, OmegaConf

    def eval_resolver(expr: str):
        # SECURITY: evaluates arbitrary Python taken from the config file. The
        # empty globals/locals are not a sandbox (builtins stay reachable), so
        # only use this with trusted configs.
        return eval(expr, {}, {})

    # replace=True: OmegaConf raises if a resolver name is registered twice,
    # which made a second call in the same process (tests, notebooks) crash.
    OmegaConf.register_new_resolver(
        "eval", eval_resolver, replace=True, use_cache=False
    )

    # 1) Pure parser defaults, ignoring the real command line.
    defaults_dict = vars(parser.parse_args([]))
    # 2) The real parse: CLI values on top of the parser defaults.
    args = parser.parse_args()
    args_dict = vars(args)
    # 3) Compose the Hydra config and flatten it into a plain dict.
    with initialize(config_path=config_path):
        cfg: DictConfig = compose(config_name=args.config_name)
        config_dict = flatten_dict(OmegaConf.to_container(cfg, resolve=True))

    # 4) Find which dests the user explicitly set on the command line.
    #    Option strings are mapped to their argparse ``dest`` so that aliases,
    #    custom ``dest=`` and short flags are recognised; unknown ``--`` tokens
    #    fall back to the "--foo-bar -> foo_bar" guess.
    option_actions = getattr(parser, "_option_string_actions", {})
    passed = set()
    for tok in sys.argv[1:]:
        if tok == '--':
            break  # everything after a bare "--" is positional
        if not tok.startswith('-'):
            continue
        opt = tok.split('=', 1)[0]  # supports both "--foo=bar" and "--foo bar"
        action = option_actions.get(opt)
        if action is not None:
            passed.add(action.dest)
        elif tok.startswith('--'):
            passed.add(opt.lstrip('-').replace('-', '_'))

    # 5) Merge: CLI > config file > parser default.
    merged = {}
    for key in defaults_dict.keys() | config_dict.keys():
        if key in passed and key in args_dict:
            # Explicitly given by the user on the command line.
            merged[key] = args_dict[key]
        elif key in config_dict:
            # Present in the config and not overridden on the CLI.
            merged[key] = config_dict[key]
        else:
            # Neither source set it: fall back to the parser default.
            merged[key] = defaults_dict[key]

    # Update the parsed Namespace with the merged result.
    args.__dict__.update(merged)
    return args
|