| | import torch |
| | import numpy as np |
| | from joblib import Parallel, delayed, parallel_backend |
| | from multiprocessing import Pool, cpu_count |
| | from tqdm import tqdm |
| |
|
| | |
| | def _worker(args): |
| | fn, d, kwargs = args |
| | return fn(*d, **kwargs) |
| |
|
| |
|
| | from joblib import Parallel, delayed, cpu_count |
| | from tqdm import tqdm |
| |
|
def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, desc=None, **kwargs):
    """
    Apply ``pickleable_fn`` to every argument tuple in ``data`` in parallel
    using joblib's loky (process) backend.

    :param pickleable_fn: callable applied to each element; must be picklable
        by the worker backend. Each element ``d`` is unpacked as
        ``pickleable_fn(*d, **kwargs)``.
    :param data: iterable of argument tuples, one per call.
    :param n_jobs: number of worker processes; defaults to all available CPUs.
    :param verbose: joblib verbosity level (this is what reports progress of
        the actual parallel work).
    :param desc: label for the tqdm bar shown while ``data`` is materialized.
        NOTE: that bar only tracks reading the input iterable, not job
        completion — rely on ``verbose`` for work progress.
    :param kwargs: extra keyword arguments forwarded to every call.
    :return: list of results in the same order as ``data``.
    """
    if n_jobs is None:
        n_jobs = cpu_count()

    # Materialize the input up front so joblib knows the task count; the
    # tqdm bar here measures input consumption only.
    data_list = list(tqdm(data, desc=desc))

    # Hand fn/kwargs straight to ``delayed`` instead of wrapping them in a
    # local closure: a closure forces the loky backend to cloudpickle the
    # enclosing frame, while ``delayed(fn)(*d, **kwargs)`` pickles only the
    # function and its arguments.
    with parallel_backend('loky'):
        results = Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)(
            delayed(pickleable_fn)(*d, **kwargs) for d in data_list
        )
    return results
| |
|
| |
|
def modulo_with_wrapped_range(
    vals, range_min: float = -np.pi, range_max: float = np.pi
):
    """
    Wrap ``vals`` into the half-open interval ``[range_min, range_max)``,
    correctly handling a negative lower bound.

    Uses only subtraction and the ``%`` operator, so it works on Python
    scalars as well as numpy arrays and torch tensors.

    >>> modulo_with_wrapped_range(3, -2, 2)
    -1
    """
    assert range_min <= 0.0
    assert range_min < range_max

    # Shift so the interval starts at zero, reduce modulo its width, then
    # shift back to the original origin.
    width = range_max - range_min
    return ((vals - range_min) % width) + range_min
| |
|
| |
|
| |
|
def flatten_dict(d, parent_key='', sep='.', level=0):
    """
    Recursively flatten a nested dict into a single-level dict, dropping
    the first-level parent key from the generated names.

    :param d: nested input dict
    :param parent_key: accumulated key prefix (used during recursion)
    :param sep: separator between joined key parts, default '.'
    :param level: current recursion depth; keys produced at depth <= 1 keep
        their bare name (i.e. the first-level parent is never prefixed)
    :return: flat single-level dict (later keys silently overwrite earlier
        ones on collision)
    """
    flat = {}
    for key, value in d.items():
        # Depths 0 and 1 use the bare key; deeper levels join with a prefix.
        if level <= 1:
            compound = key
        else:
            compound = f"{parent_key}{sep}{key}" if parent_key else key

        if isinstance(value, dict):
            flat.update(flatten_dict(value, compound, sep=sep, level=level + 1))
        else:
            flat[compound] = value
    return flat
| |
|
def process_args(parser, config_path):
    """
    Merge argparse defaults, a Hydra/OmegaConf config file, and explicit
    command-line flags into a single namespace.

    Precedence (highest first):
      1. flags explicitly present on the command line,
      2. values from the composed config file,
      3. argparse defaults.

    :param parser: a configured ``argparse.ArgumentParser``; its parsed
        namespace must expose ``config_name`` naming the config to compose.
    :param config_path: path handed to ``hydra.initialize`` (relative to
        this module, per Hydra convention).
    :return: the parsed ``argparse.Namespace`` with merged values.
    """
    from hydra import initialize, compose
    from omegaconf import DictConfig, OmegaConf
    import sys

    def eval_resolver(expr: str):
        # SECURITY: ``eval`` executes arbitrary Python from the config file.
        # Only use with trusted, locally-authored configs.
        return eval(expr, {}, {})

    # ``replace=True`` keeps repeated calls from raising "resolver already
    # registered" when this function runs more than once per process.
    OmegaConf.register_new_resolver(
        "eval", eval_resolver, use_cache=False, replace=True
    )

    # Parser defaults, obtained by parsing an empty command line.
    defaults = parser.parse_args([])
    defaults_dict = vars(defaults)

    # Actual command-line values.
    args = parser.parse_args()
    args_dict = vars(args)

    # Compose the config and flatten it so its keys line up with argparse
    # destinations.
    with initialize(config_path=config_path):
        cfg: DictConfig = compose(config_name=args.config_name)
        config_dict = flatten_dict(OmegaConf.to_container(cfg, resolve=True))

    # Flags explicitly present on the command line. Argparse maps '-' to '_'
    # in destination names, so mirror that here.
    passed = set()
    for tok in sys.argv[1:]:
        if not tok.startswith('--'):
            continue
        key = tok.lstrip('-').split('=')[0].replace('-', '_')
        passed.add(key)

    # Merge with CLI > config > defaults precedence.
    merged = {}
    for key in set(defaults_dict) | set(config_dict):
        if key in passed:
            merged[key] = args_dict[key]
        elif key in config_dict:
            merged[key] = config_dict[key]
        else:
            merged[key] = defaults_dict[key]

    args.__dict__.update(merged)
    return args
| |
|