|
|
import torch |
|
|
import src |
|
|
from src.utils.tensor import is_dense, is_sorted, fast_repeat, tensor_idx, \ |
|
|
arange_interleave, fast_randperm |
|
|
from torch_scatter import scatter_mean |
|
|
|
|
|
|
|
|
# Public API of this module
__all__ = [
    'indices_to_pointers', 'sizes_to_pointers', 'dense_to_csr', 'csr_to_dense',
    'sparse_sort', 'sparse_sort_along_direction', 'sparse_sample']
|
|
|
|
|
|
|
|
def indices_to_pointers(indices: torch.Tensor):
    """Convert pre-sorted dense indices to CSR format.

    :param indices: 1D LongTensor of group indices, one per element.
        Must be dense (see `is_dense`); if not already sorted, it will
        be sorted internally
    :return: (pointers, order) where `pointers` is the CSR pointer
        tensor of size num_groups + 1 and `order` is the permutation
        that sorts `indices` (identity if `indices` was already sorted)
    """
    device = indices.device
    assert len(indices.shape) == 1, "Only 1D indices are accepted."
    assert indices.shape[0] >= 1, "At least one group index is required."
    assert is_dense(indices), "Indices must be dense"

    # Sort the indices if needed, tracking the permutation so callers
    # can reorder any associated values accordingly
    order = torch.arange(indices.shape[0], device=device)
    if not is_sorted(indices):
        indices, order = indices.sort()

    # Group boundaries are the positions where consecutive indices
    # differ, bracketed by the first (0) and one-past-last (N)
    # positions. Build the bracket tensors directly on the target
    # device rather than via the legacy CPU `torch.LongTensor`
    # constructor followed by a device copy.
    pointers = torch.cat([
        torch.zeros(1, dtype=torch.long, device=device),
        torch.where(indices[1:] > indices[:-1])[0] + 1,
        torch.tensor([indices.shape[0]], dtype=torch.long, device=device)])

    return pointers, order
|
|
|
|
|
|
|
|
def sizes_to_pointers(sizes: torch.Tensor):
    """Convert a tensor of sizes into the corresponding pointers. This
    is a trivial but often-required operation.
    """
    assert sizes.dim() == 1
    assert sizes.dtype == torch.long
    # pointers[i] is the sum of the first i sizes, with a leading zero.
    # `new_zeros` keeps the dtype and device of `sizes`
    pointers = sizes.new_zeros(sizes.numel() + 1)
    pointers[1:] = sizes.cumsum(dim=0)
    return pointers
|
|
|
|
|
|
|
|
def dense_to_csr(a):
    """Convert a dense matrix to its CSR counterpart."""
    assert a.dim() == 2
    # Coordinates of the non-zero entries, as (rows, cols) index tensors
    rows, cols = a.nonzero(as_tuple=True)
    # `nonzero` returns row-major (i.e. row-sorted) coordinates, so the
    # row indices convert directly to CSR pointers
    pointers = indices_to_pointers(rows)[0]
    return pointers, cols, a[rows, cols]
|
|
|
|
|
|
|
|
def csr_to_dense(pointers, columns, values, shape=None):
    """Convert a CSR matrix to its dense counterpart of a given shape.

    :param pointers: 1D tensor of row pointers, of size num_rows + 1
    :param columns: 1D tensor of column indices, one per value
    :param values: 1D tensor of stored (non-zero) values
    :param shape: optional (num_rows, num_cols) tuple. The output shape
        is the element-wise max of `shape` and the smallest shape able
        to hold the CSR data
    :return: dense 2D tensor of size `shape`
    """
    assert pointers.dim() == 1
    assert columns.dim() == 1
    assert values.dim() == 1
    assert shape is None or len(shape) == 2
    assert pointers.device == columns.device == values.device

    device = pointers.device

    # Smallest shape able to hold the CSR data. Guard the column guess:
    # `columns.max()` raises on an empty tensor (all-zero matrix)
    num_cols_guess = int(columns.max()) + 1 if columns.numel() > 0 else 0
    shape_guess = (pointers.shape[0] - 1, num_cols_guess)
    if shape is None:
        shape = shape_guess
    else:
        shape = (max(shape[0], shape_guess[0]), max(shape[1], shape_guess[1]))

    n, m = shape
    a = torch.zeros(n, m, dtype=values.dtype, device=device)

    # One row index per stored element. Use the CSR row count rather
    # than `n`: when the caller passes a taller `shape`, `n` exceeds the
    # number of per-row sizes and the repeat lengths would not line up
    i = torch.arange(pointers.shape[0] - 1, device=device)
    i = fast_repeat(i, pointers[1:] - pointers[:-1])
    j = columns.long()
    a[i, j] = values

    return a
|
|
|
|
|
|
|
|
def sparse_sort(src, index, dim=0, descending=False, eps=1e-6):
    """Lexicographic sort 1D src points based on index first and src
    values second.

    Credit: https://github.com/rusty1s/pytorch_scatter/issues/48
    """
    # Map values into [0, 1) so that adding the integer index makes the
    # index dominate the comparison, with values breaking ties
    values = src.double()
    lo = values.min(dim)[0]
    hi = values.max(dim)[0]
    frac = (values - lo) / (hi - lo + eps)

    # Negate the index contribution when sorting in descending order,
    # so segments still come out in ascending index order
    sign = -1.0 if descending else 1.0
    keys = frac + index.double() * sign

    perm = keys.argsort(dim=dim, descending=descending)
    return src[perm], perm
|
|
|
|
|
|
|
|
def sparse_sort_along_direction(src, index, direction, descending=False):
    """Lexicographic sort N-dimensional src points based on index first
    and the projection of the src values along a direction second.
    """
    assert src.dim() == 2
    assert index.dim() == 1
    assert src.shape[0] == index.shape[0]
    assert direction.dim() in (1, 2)

    num_points = src.shape[0]

    # A single direction vector: promote to 2D
    if direction.dim() == 1:
        direction = direction.view(1, -1)

    # One shared direction: broadcast it to every point
    if direction.shape[0] == 1:
        direction = direction.repeat(num_points, 1)

    # One direction per segment: expand to one per point via the index
    if direction.shape[0] != num_points:
        direction = direction[index]

    # Center each point on its segment's centroid before projecting
    centroid = scatter_mean(src, index, dim=0)[index]
    offsets = src - centroid

    # Signed projection of each centered point onto its direction
    projection = offsets.mul(direction).sum(dim=1)

    # Sort by segment first, projection second
    _, perm = sparse_sort(projection, index, descending=descending)
    return src[perm], perm
|
|
|
|
|
|
|
|
def sparse_sample(
        idx: torch.Tensor,
        n_max: int = 32,
        n_min: int = 1,
        mask=None,
        return_pointers: bool = False):
    """Compute indices to sample elements in a set of size `idx.shape`,
    based on which segment they belong to in `idx`.

    The sampling operation is run without replacement and each
    segment is sampled at least `n_min` and at most `n_max` times,
    within the limits allowed by its actual size.

    Optionally, a `mask` can be passed to filter out some elements.

    :param idx: LongTensor of size N
        Segment indices for each of the N elements
    :param n_max: int
        Maximum number of elements to sample in each segment
    :param n_min: int
        Minimum number of elements to sample in each segment, within the
        limits of its size (i.e. no oversampling)
    :param mask: list, np.ndarray, torch.Tensor
        Indicates a subset of elements to consider. This allows ignoring
        some segments
    :param return_pointers: bool
        Whether pointers should be returned along with sampling
        indices. These indicate which sampled element belongs to which
        segment
    """
    assert 0 <= n_min <= n_max

    # NOTE(review): an empty `idx` would make `idx.max()` raise — the
    # function implicitly assumes at least one element
    device = idx.device
    size = idx.bincount()
    num_elements = size.sum()
    num_segments = idx.max() + 1

    if n_max > 0:
        # Saturating sample budget: n_max * tanh(size / n_max) grows
        # roughly linearly for small segments and asymptotes at n_max
        # for large ones
        n_samples = (n_max * torch.tanh(size / n_max)).floor().long()
    else:
        # Since the assert above guarantees n_max >= 0, this branch
        # means n_max == 0: budget each segment by the square root of
        # its size instead
        n_samples = size.sqrt().round().long()

    # Enforce the floor `n_min`, but never sample more than a segment
    # actually holds (no oversampling)
    n_samples = n_samples.clamp(min=n_min).clamp(max=size)

    if src.is_debug_enabled():
        assert n_samples.le(size).all(), \
            "Cannot sample more than the segment sizes."

    # Candidate element indices to draw from: initially all N elements
    sample_idx = torch.arange(num_elements, device=device)

    # Optionally restrict the candidates to the masked subset.
    # `tensor_idx` normalizes list/ndarray/tensor masks into an index
    # tensor; an empty result means "no mask"
    mask = tensor_idx(mask, device=device)
    if mask.shape[0] > 0:
        sample_idx = sample_idx[mask]
        idx = idx[mask]
        # Recount segment sizes on the masked set; `minlength` keeps
        # fully-masked-out trailing segments in the count
        size = idx.bincount(minlength=num_segments)
        n_samples = n_samples.clamp(max=size)

    if src.is_debug_enabled():
        assert n_samples.le(size).all(), \
            "Cannot sample more than the segment sizes."

    # Shuffle the candidates so that, after grouping by segment below,
    # taking the first n_samples elements of each segment amounts to a
    # random draw without replacement
    perm = fast_randperm(sample_idx.shape[0], device=device)
    idx = idx[perm]
    sample_idx = sample_idx[perm]

    # Group the shuffled candidates by segment; within each segment the
    # order remains random
    idx, order = idx.sort()
    sample_idx = sample_idx[order]

    # Start offset of each segment in the grouped candidate list, then
    # gather the first n_samples[s] candidates of every segment s.
    # `arange_interleave(n_samples, start=offset)` presumably builds the
    # concatenation of ranges [offset[s], offset[s] + n_samples[s])
    offset = sizes_to_pointers(size[:-1])
    idx_samples = sample_idx[arange_interleave(n_samples, start=offset)]

    if not return_pointers:
        return idx_samples

    # CSR-style pointers mapping each sampled element to its segment
    ptr_samples = sizes_to_pointers(n_samples)

    return idx_samples, ptr_samples.contiguous()
|
|
|