|
|
import torch
import numpy as np
from numba import njit

|
__all__ = [
    'tensor_idx', 'is_sorted', 'has_duplicates', 'is_dense', 'is_permutation',
    'arange_interleave', 'print_tensor_info', 'cast_to_optimal_integer_type',
    'cast_numpyfy', 'numpyfy', 'torchify', 'torch_to_numpy', 'fast_randperm',
    'fast_zeros', 'fast_repeat', 'string_to_dtype']


def tensor_idx(idx, device=None):
    """Convert an int, slice, list, numpy array or boolean mask to a
    torch.LongTensor of indices. None is converted to an empty LongTensor.
    """
    if device is None and hasattr(idx, 'device'):
        device = idx.device
    elif device is None:
        device = 'cpu'

    if idx is None:
        idx = torch.tensor([], device=device, dtype=torch.long)
    elif isinstance(idx, int):
        idx = torch.tensor([idx], device=device, dtype=torch.long)
    elif isinstance(idx, list):
        idx = torch.tensor(idx, device=device, dtype=torch.long)
    elif isinstance(idx, slice):
        # NB: the slice must carry an explicit `stop` value
        idx = torch.arange(idx.stop, device=device)[idx]
    elif isinstance(idx, np.ndarray):
        idx = torch.from_numpy(idx).to(device)

    # Convert boolean masks to long indices. NB: checking the dtype rather
    # than `isinstance(idx, torch.BoolTensor)`, which is False for CUDA
    # tensors
    if idx.dtype == torch.bool:
        idx = torch.where(idx)[0]

    assert idx.dtype is torch.int64, \
        f"Expected LongTensor but got {idx.dtype} instead."

    return idx
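

# Illustrative usage for `tensor_idx` (a sketch of the supported input types;
# outputs assume the conversions above, on CPU):
#
#   tensor_idx(3)                                  # tensor([3])
#   tensor_idx([0, 2])                             # tensor([0, 2])
#   tensor_idx(slice(1, 4))                        # tensor([1, 2, 3])
#   tensor_idx(np.array([1, 3]))                   # tensor([1, 3])
#   tensor_idx(torch.tensor([True, False, True]))  # tensor([0, 2])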
|
|
def is_sorted(a: torch.Tensor, increasing=True, strict=False):
    """Checks whether a 1D tensor of indices is sorted."""
    assert a.dim() == 1, "Only supports 1D tensors"
    assert not a.is_floating_point(), "Float tensors are not supported"
    if increasing and strict:
        f = torch.gt
    elif increasing and not strict:
        f = torch.ge
    elif not increasing and strict:
        f = torch.lt
    else:
        f = torch.le
    return f(a[1:], a[:-1]).all()
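

# Illustrative usage for `is_sorted` (a sketch; the comparison is pairwise
# over consecutive elements):
#
#   is_sorted(torch.tensor([0, 1, 1, 2]))                 # tensor(True)
#   is_sorted(torch.tensor([0, 1, 1, 2]), strict=True)    # tensor(False)
#   is_sorted(torch.tensor([2, 1, 0]), increasing=False)  # tensor(True)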
|
|
def has_duplicates(a: torch.Tensor):
    """Checks whether a 1D tensor of indices contains duplicates."""
    assert a.dim() == 1, "Only supports 1D tensors"
    assert not a.is_floating_point(), "Float tensors are not supported"
    return a.unique().numel() != a.numel()


def is_dense(a: torch.Tensor):
    """Checks whether a 1D tensor of indices contains dense indices.
    That is to say, all values in [0, a.max()] appear at least once in a.
    """
    assert a.dim() == 1, "Only supports 1D tensors"
    assert not a.is_floating_point(), "Float tensors are not supported"
    assert a.numel() > 0, "Empty tensors are not supported"
    unique = a.unique()
    return a.min() == 0 and unique.size(0) == a.max().long() + 1


def is_permutation(a: torch.Tensor):
    """Checks whether a 1D tensor of indices is a permutation of
    [0, a.numel() - 1].
    """
    assert a.dim() == 1, "Only supports 1D tensors"
    assert not a.is_floating_point(), "Float tensors are not supported"
    return a.sort().values.long().equal(
        torch.arange(a.numel(), device=a.device))
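

# Illustrative usage for the index predicates above (a sketch):
#
#   a = torch.tensor([1, 0, 2, 2])
#   has_duplicates(a)   # True (2 appears twice)
#   is_dense(a)         # tensor(True) (0, 1 and 2 all appear)
#   is_permutation(a)   # False (4 elements but a.max() is 2)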
|
|
def arange_interleave(width, start=None):
    """Vectorized equivalent of:

    >>> torch.cat([torch.arange(s, s + w) for w, s in zip(width, start)])
    """
    assert isinstance(width, torch.Tensor), 'Only supports Tensors'
    assert width.dim() == 1, 'Only supports 1D tensors'
    assert not width.is_floating_point(), 'Only supports Tensors of integers'
    assert width.ge(0).all(), 'Only supports non-negative integers'
    start = start if start is not None else torch.zeros_like(width)
    assert isinstance(start, torch.Tensor), 'Only supports Tensors'
    assert start.dim() == 1, 'Only supports 1D tensors'
    assert not start.is_floating_point(), 'Only supports Tensors of integers'
    assert width.shape == start.shape
    width = width.long()
    start = start.long()
    device = width.device
    # Offset of each group: its desired start minus its position in a
    # contiguous arange over the cumulated widths
    a = torch.cat((torch.zeros(1, device=device).long(), width[:-1]))
    offsets = (start - a.cumsum(0)).repeat_interleave(width)
    return torch.arange(width.sum(), device=device) + offsets
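

# Worked example for `arange_interleave` (a sketch):
#
#   width = torch.tensor([2, 3, 1])
#   start = torch.tensor([0, 10, 20])
#   arange_interleave(width, start=start)  # tensor([0, 1, 10, 11, 12, 20])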
|
|
def print_tensor_info(a, name=None):
    """Print some info about a tensor. Used for debugging."""
    is_1d = a.dim() == 1
    is_int = not a.is_floating_point()

    msg = f'{name}: ' if name is not None else ''

    msg += f'shape={a.shape} '
    msg += f'dtype={a.dtype} '
    msg += f'min={a.min()} '
    msg += f'max={a.max()} '

    if is_1d and is_int:
        msg += f'duplicates={has_duplicates(a)} '
        msg += f'sorted={is_sorted(a)} '
        msg += f'dense={is_dense(a)} '
        msg += f'permutation={is_permutation(a)} '

    print(msg)
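

# Illustrative output for `print_tensor_info` (a sketch; printed as a single
# line, wrapped here for readability):
#
#   print_tensor_info(torch.tensor([1, 0, 2, 2]), name='idx')
#   # idx: shape=torch.Size([4]) dtype=torch.int64 min=0 max=2
#   #     duplicates=True sorted=False dense=True permutation=False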
|
|
def string_to_dtype(string):
    """Convert a string such as 'float32' or 'long' to the matching torch
    dtype. torch.dtype inputs are passed through unchanged.
    """
    if isinstance(string, torch.dtype):
        return string
    assert isinstance(string, str)
    if string in ('half', 'float16'):
        return torch.float16
    if string in ('float', 'float32'):
        return torch.float32
    if string in ('double', 'float64'):
        return torch.float64
    if string == 'bool':
        return torch.bool
    if string in ('byte', 'uint8'):
        return torch.uint8
    if string in ('char', 'int8'):
        return torch.int8
    if string in ('short', 'int16'):
        return torch.int16
    if string in ('int', 'int32'):
        return torch.int32
    if string in ('long', 'int64'):
        return torch.int64
    raise ValueError(f"Unknown dtype='{string}'")
|
|
def cast_to_optimal_integer_type(a):
    """Cast an integer tensor to the smallest possible integer dtype
    preserving its precision.
    """
    assert isinstance(a, torch.Tensor), \
        f"Expected a Tensor input, but received {type(a)} instead"
    assert not a.is_floating_point(), \
        f"Expected an integer-like input, but received dtype={a.dtype} instead"

    if a.numel() == 0:
        return a.byte()

    for dtype in [torch.uint8, torch.int16, torch.int32, torch.int64]:
        low_enough = torch.iinfo(dtype).min <= a.min()
        high_enough = a.max() <= torch.iinfo(dtype).max
        if low_enough and high_enough:
            return a.to(dtype)

    raise ValueError(f"Could not cast dtype={a.dtype} to integer.")
|
|
def cast_numpyfy(a, fp_dtype=torch.float):
    """Convert torch.Tensor to numpy while respecting some constraints
    on output dtype. Integer tensors will be cast to the smallest
    possible integer dtype preserving their precision. Floating point
    tensors will be cast to `fp_dtype`.
    """
    if not isinstance(a, torch.Tensor):
        return numpyfy(a)

    fp_dtype = string_to_dtype(fp_dtype)

    if not a.is_floating_point():
        return numpyfy(cast_to_optimal_integer_type(a))

    return numpyfy(a.to(fp_dtype))
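

# Illustrative usage for `cast_numpyfy` (a sketch):
#
#   cast_numpyfy(torch.tensor([0, 200])).dtype          # dtype('uint8')
#   cast_numpyfy(torch.tensor([0.5]), 'float64').dtype  # dtype('float64')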
|
|
def numpyfy(a):
    """Convert torch.Tensor to numpy while respecting some constraints
    on output dtype.
    """
    if not isinstance(a, torch.Tensor):
        return a

    return a.cpu().numpy()


def torchify(x):
    """Convert np.ndarray to torch.Tensor."""
    return torch.from_numpy(x) if isinstance(x, np.ndarray) else x
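

# `numpyfy` and `torchify` round-trip (a sketch; non-Tensor and non-ndarray
# inputs pass through unchanged):
#
#   torchify(numpyfy(torch.arange(3)))  # tensor([0, 1, 2])
#   torchify('not an array')            # 'not an array'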
|
|
def torch_to_numpy(func):
    """Decorator intended for numpy-based functions to be fed and return
    torch.Tensor arguments.
    """

    def wrapper_torch_to_numpy(*args, **kwargs):
        args_numpy = [numpyfy(x) for x in args]
        kwargs_numpy = {k: numpyfy(v) for k, v in kwargs.items()}
        out = func(*args_numpy, **kwargs_numpy)
        if isinstance(out, list):
            out = [torchify(x) for x in out]
        elif isinstance(out, tuple):
            out = tuple(torchify(x) for x in out)
        elif isinstance(out, dict):
            out = {k: torchify(v) for k, v in out.items()}
        else:
            out = torchify(out)
        return out

    return wrapper_torch_to_numpy
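

# Illustrative usage of the `torch_to_numpy` decorator (a sketch;
# `double_array` is a hypothetical function introduced here for illustration):
#
#   @torch_to_numpy
#   def double_array(x):
#       return x * 2  # x arrives as a np.ndarray
#
#   double_array(torch.arange(3))  # tensor([0, 2, 4])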
|
|
@torch_to_numpy
@njit(cache=True, nogil=True)
def numba_randperm(n):
    """Same as torch.randperm but leveraging numba on CPU.

    NB: slightly faster than `np.random.permutation(np.arange(n))`
    """
    a = np.arange(n)
    np.random.shuffle(a)
    return a


def fast_randperm(n, device='cpu'):
    """Same as torch.randperm, but relies on numba for CPU tensors. This
    may bring a x2 speedup on CPU for n >= 1e5.

    ```
    from time import time
    import torch
    from src.utils.tensor import fast_randperm

    n = 100000

    start = time()
    a = torch.randperm(n)
    print(f'torch.randperm : {time() - start:0.5f}s')

    start = time()
    b = fast_randperm(n)
    print(f'fast_randperm: {time() - start:0.5f}s')
    ```
    """
    if device == 'cuda' or (
            isinstance(device, torch.device) and device.type == 'cuda'):
        return torch.randperm(n, device=device)
    return numba_randperm(n)


def fast_zeros(*args, dtype=None, device='cpu'):
    """Same as torch.zeros but relies on numpy on CPU. This may be x40
    faster when manipulating large tensors on CPU.

    ```
    from time import time
    import torch
    import numpy as np
    from src.utils.tensor import fast_zeros

    n = 1000000
    m = 20

    start = time()
    a = torch.zeros(n, m)
    print(f'torch.zeros : {time() - start:0.4f}s')

    start = time()
    b = torch.from_numpy(np.zeros((n, m), dtype='float32'))
    print(f'np.zeros: {time() - start:0.4f}s')

    start = time()
    c = fast_zeros(n, m)
    print(f'fast_zeros: {time() - start:0.4f}s')

    print(torch.equal(a, b), torch.equal(a, c))
    ```
    """
    if device == 'cuda' or (
            isinstance(device, torch.device) and device.type == 'cuda'):
        return torch.zeros(*args, dtype=dtype, device=device)
    out = torchify(np.zeros(tuple(args), dtype='float32'))
    if dtype is not None:
        out = out.to(dtype)
    return out


def fast_repeat(x, repeats):
    """Same as torch.repeat_interleave but relies on numpy on CPU. This
    saves a little bit of time when manipulating large tensors on CPU.

    ```
    from time import time
    import torch
    import numpy as np
    from src.utils.tensor import fast_repeat

    n = 1000000
    rmax = 50
    values = torch.arange(n)
    repeats = torch.randint(low=0, high=rmax, size=(n,))

    start = time()
    a = values.repeat_interleave(repeats)
    print(f'torch.repeat_interleave : {time() - start:0.4f}s')

    start = time()
    b = torch.from_numpy(np.repeat(values.numpy(), repeats.numpy()))
    print(f'np.repeat: {time() - start:0.4f}s')

    start = time()
    c = fast_repeat(values, repeats)
    print(f'fast_repeat: {time() - start:0.4f}s')

    print(torch.equal(a, b), torch.equal(a, c))
    ```
    """
    assert isinstance(x, torch.Tensor)
    assert isinstance(repeats, int) or x.device == repeats.device
    if x.is_cuda:
        return torch.repeat_interleave(x, repeats)
    if isinstance(repeats, int):
        return torchify(np.repeat(numpyfy(x), repeats))
    return torchify(np.repeat(numpyfy(x), numpyfy(repeats)))