Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import numpy as np | |
| from collections.abc import Mapping, Sequence | |
| from tqdm.auto import tqdm | |
| from joblib import Parallel, delayed, cpu_count | |
| from joblib.externals.loky import set_loky_pickler | |
| import yaml | |
def load_yaml_config(filepath):
    """
    Load a YAML configuration file with ``yaml.safe_load``.

    The file is opened in binary mode so PyYAML detects the stream encoding
    itself (UTF-8, or UTF-16 via BOM) instead of the file being decoded with
    the platform's locale-dependent default encoding.

    Parameters
    ----------
    filepath : str or os.PathLike
        Path to the YAML file.

    Returns
    -------
    object
        The parsed document (presumably a dict for config files; ``None`` for
        an empty file).
    """
    with open(filepath, "rb") as file:
        return yaml.safe_load(file)
def cuda(obj, *args, **kwargs):
    """
    Transfer any nested container of tensors to CUDA.

    Objects exposing a ``cuda`` method (e.g. ``torch.Tensor``) are moved via
    ``obj.cuda(*args, **kwargs)``. Mappings and sequences are rebuilt with the
    same container type, recursing into their values/items; strings and bytes
    are returned unchanged. NumPy arrays are converted with ``torch.tensor``
    and then moved. Any other object (numbers, ``None``, ...) is returned as-is.

    Parameters
    ----------
    obj : object
        Tensor, NumPy array, or an arbitrarily nested mapping/sequence of them.
    *args, **kwargs
        Forwarded to every ``.cuda(...)`` call (e.g. ``device``, ``non_blocking``).

    Returns
    -------
    object
        A structure shaped like ``obj`` with every transferable leaf moved.
    """
    if isinstance(obj, (str, bytes, bytearray)):
        # Text/bytes are Sequences too; don't explode them into elements.
        return obj
    if hasattr(obj, "cuda"):
        return obj.cuda(*args, **kwargs)
    if isinstance(obj, Mapping):
        return type(obj)({k: cuda(v, *args, **kwargs) for k, v in obj.items()})
    if isinstance(obj, Sequence):
        items = [cuda(x, *args, **kwargs) for x in obj]
        if hasattr(obj, "_fields"):
            # namedtuple constructors take their fields positionally.
            return type(obj)(*items)
        return type(obj)(items)
    if isinstance(obj, np.ndarray):
        # torch.tensor only builds a CPU copy; .cuda() performs the transfer
        # and receives the same arguments as every other leaf.
        return torch.tensor(obj).cuda(*args, **kwargs)
    # Leaf with nothing to transfer.
    return obj
def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, desc=None, **kwargs):
    """
    Parallel map using joblib.

    Each element of ``data`` is unpacked as positional arguments, i.e. the
    function is called as ``pickleable_fn(*data[i], **kwargs)``.

    Parameters
    ----------
    pickleable_fn : callable
        Function to map over data. Must be picklable so joblib can ship it to
        its workers.
    data : iterable
        Data over which we want to parallelize the function call. Each item
        must itself be an iterable of positional arguments (e.g. a tuple).
    n_jobs : int, optional
        The maximum number of concurrently running jobs. By default, it is one
        less than the number of CPUs, but never less than 1.
    verbose : int, optional
        The verbosity level. If nonzero, the function prints the progress messages.
        The frequency of the messages increases with the verbosity level. If above 10,
        it reports all iterations. If above 50, it sends the output to stdout.
    desc : str, optional
        Description shown on the tqdm progress bar.
    **kwargs
        Additional keyword arguments passed to every call of :attr:`pickleable_fn`.

    Returns
    -------
    list
        The i-th element of the list is ``pickleable_fn(*data[i], **kwargs)``.
    """
    if n_jobs is None:
        # Leave one core free, but joblib rejects n_jobs=0, which the plain
        # subtraction would produce on a single-CPU machine.
        n_jobs = max(1, cpu_count() - 1)
    # tqdm wraps the input iterable (not the results), so the bar advances as
    # tasks are handed to joblib; wrapping ``data`` directly keeps its length
    # available as the bar's total when ``data`` is sized.
    return Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)(
        delayed(pickleable_fn)(*args, **kwargs) for args in tqdm(data, desc=desc)
    )
def modulo_with_wrapped_range(
    vals, range_min: float = -np.pi, range_max: float = np.pi
):
    """
    Wrap ``vals`` into the half-open interval ``[range_min, range_max)``.

    Unlike a plain ``%``, this handles ranges whose minimum is not zero, in
    particular negative minima such as the default ``[-pi, pi)`` used for
    angles. It works element-wise on anything supporting ``-``, ``%`` and
    ``+`` with floored (Python-style) modulo semantics, e.g. Python scalars,
    NumPy arrays and torch tensors. NaNs propagate unchanged.

    >>> modulo_with_wrapped_range(3, -2, 2)
    -1
    >>> modulo_with_wrapped_range(5, 1, 3)
    1

    Parameters
    ----------
    vals : scalar, np.ndarray or torch.Tensor
        Values to wrap.
    range_min, range_max : float
        Bounds of the target interval; ``range_min`` must be < ``range_max``.

    Returns
    -------
    Values congruent to ``vals`` modulo ``range_max - range_min`` lying in
    ``[range_min, range_max)``. Floating-point rounding can occasionally
    yield exactly ``range_max`` for inputs just below a period boundary.

    Raises
    ------
    ValueError
        If ``range_min`` is not strictly less than ``range_max``.
    """
    # Explicit check (not assert) so validation survives ``python -O``.
    if not range_min < range_max:
        raise ValueError(
            f"range_min must be less than range_max, got [{range_min}, {range_max})"
        )
    period = range_max - range_min
    # Shift so the interval starts at 0, take the floored modulo into
    # [0, period), then shift back down.
    return (vals - range_min) % period + range_min