# initial commit 7968cb0 (Honzus24)
import torch
import numpy as np
from collections.abc import Mapping, Sequence
from tqdm.auto import tqdm
from joblib import Parallel, delayed, cpu_count
from joblib.externals.loky import set_loky_pickler
import yaml
def load_yaml_config(filepath):
    """Read the YAML file at ``filepath`` and return its parsed contents."""
    with open(filepath, 'r') as fh:
        parsed = yaml.safe_load(fh)
    return parsed
def cuda(obj, *args, **kwargs):
    """
    Recursively transfer any nested container of tensors to CUDA.

    Parameters
    ----------
    obj : object
        A tensor (or anything exposing a ``.cuda`` method), a numpy array,
        or an arbitrarily nested mapping/sequence of such objects.
        Strings and objects of any other type are returned unchanged.
    *args, **kwargs
        Forwarded to ``.cuda()`` (for tensor-like objects) or to
        ``torch.tensor()`` (for numpy arrays).

    Returns
    -------
    object
        The same structure as ``obj`` with tensor-like leaves moved to CUDA.
    """
    if hasattr(obj, "cuda"):
        # Tensors, modules, and anything else with a .cuda method.
        return obj.cuda(*args, **kwargs)
    elif isinstance(obj, Mapping):
        return type(obj)({k: cuda(v, *args, **kwargs) for k, v in obj.items()})
    elif isinstance(obj, Sequence):
        # Strings are sequences but must not be recursed into element-wise.
        if isinstance(obj, str):
            return obj
        return type(obj)(cuda(x, *args, **kwargs) for x in obj)
    elif isinstance(obj, np.ndarray):
        # NOTE: torch.tensor() leaves the result on CPU unless a device
        # argument is passed through *args/**kwargs.
        return torch.tensor(obj, *args, **kwargs)
    else:
        # Scalars and other plain Python objects pass through untouched.
        # (The original had an `isinstance(obj, T)` branch referencing an
        # undefined name `T`, plus an unreachable `raise TypeError` after the
        # else-return; both removed.)
        return obj
def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, desc=None, **kwargs):
    """
    Parallel map using joblib.

    Parameters
    ----------
    pickleable_fn : callable
        Function to map over data.
    data : iterable
        Data over which we want to parallelize the function call. Each
        element is unpacked as positional arguments (``pickleable_fn(*d)``).
    n_jobs : int, optional
        The maximum number of concurrently running jobs. By default, it is one less than
        the number of CPUs.
    verbose : int, optional
        The verbosity level. If nonzero, the function prints the progress messages.
        The frequency of the messages increases with the verbosity level. If above 10,
        it reports all iterations. If above 50, it sends the output to stdout.
    desc : str, optional
        Label shown on the tqdm progress bar.
    kwargs
        Additional keyword arguments for :attr:`pickleable_fn`.

    Returns
    -------
    list
        The i-th element of the list corresponds to the output of applying
        :attr:`pickleable_fn` to :attr:`data[i]`.
    """
    if n_jobs is None:
        n_jobs = cpu_count() - 1
    # Give tqdm a total when data is sized so the bar can show a fraction;
    # generators fall back to a count-only bar.
    total = len(data) if hasattr(data, "__len__") else None
    results = Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)(
        delayed(pickleable_fn)(*d, **kwargs)
        for d in tqdm(data, desc=desc, total=total)
    )
    return results
def modulo_with_wrapped_range(
    vals, range_min: float = -np.pi, range_max: float = np.pi
):
    """
    Modulo with wrapped range -- capable of handling a range with a negative min.

    Works on scalars, numpy arrays, and torch tensors alike, since it relies
    only on ``-``, ``%``, and ``+``.

    Parameters
    ----------
    vals : scalar or array-like
        Value(s) to wrap into ``[range_min, range_max)``.
    range_min : float, optional
        Lower bound of the target range; must satisfy ``range_min <= 0`` and
        ``range_min < range_max``. Defaults to ``-pi``.
    range_max : float, optional
        Upper (exclusive) bound of the target range. Defaults to ``pi``.

    Returns
    -------
    Same type as ``vals``, wrapped into ``[range_min, range_max)``.

    >>> modulo_with_wrapped_range(3, -2, 2)
    -1
    """
    assert range_min <= 0.0
    assert range_min < range_max
    # Width of the target interval.
    top_end = range_max - range_min
    # Shift so the interval starts at 0, wrap with %, then shift back down.
    vals_shifted = vals - range_min
    vals_shifted_mod = vals_shifted % top_end
    retval = vals_shifted_mod + range_min
    return retval