| | import torch |
| | import src |
| | from src.utils.tensor import is_dense, is_sorted, fast_repeat, tensor_idx, \ |
| | arange_interleave, fast_randperm |
| | from torch_scatter import scatter_mean |
| |
|
| |
|
| | __all__ = [ |
| | 'indices_to_pointers', 'sizes_to_pointers', 'dense_to_csr', 'csr_to_dense', |
| | 'sparse_sort', 'sparse_sort_along_direction', 'sparse_sample'] |
| |
|
| |
|
| | def indices_to_pointers(indices: torch.Tensor): |
| | """Convert pre-sorted dense indices to CSR format.""" |
| | device = indices.device |
| | assert len(indices.shape) == 1, "Only 1D indices are accepted." |
| | assert indices.shape[0] >= 1, "At least one group index is required." |
| | assert is_dense(indices), "Indices must be dense" |
| |
|
| | |
| | order = torch.arange(indices.shape[0], device=device) |
| | if not is_sorted(indices): |
| | indices, order = indices.sort() |
| |
|
| | |
| | pointers = torch.cat([ |
| | torch.LongTensor([0]).to(device), |
| | torch.where(indices[1:] > indices[:-1])[0] + 1, |
| | torch.LongTensor([indices.shape[0]]).to(device)]) |
| |
|
| | return pointers, order |
| |
|
| |
|
| | def sizes_to_pointers(sizes: torch.Tensor): |
| | """Convert a tensor of sizes into the corresponding pointers. This |
| | is a trivial but often-required operation. |
| | """ |
| | assert sizes.dim() == 1 |
| | assert sizes.dtype == torch.long |
| | zero = torch.zeros(1, device=sizes.device, dtype=torch.long) |
| | return torch.cat((zero, sizes)).cumsum(dim=0) |
| |
|
| |
|
| | def dense_to_csr(a): |
| | """Convert a dense matrix to its CSR counterpart.""" |
| | assert a.dim() == 2 |
| | index = a.nonzero(as_tuple=True) |
| | values = a[index] |
| | columns = index[1] |
| | pointers = indices_to_pointers(index[0])[0] |
| | return pointers, columns, values |
| |
|
| |
|
| | def csr_to_dense(pointers, columns, values, shape=None): |
| | """Convert a CSR matrix to its dense counterpart of a given shape. |
| | """ |
| | assert pointers.dim() == 1 |
| | assert columns.dim() == 1 |
| | assert values.dim() == 1 |
| | assert shape is None or len(shape) == 2 |
| | assert pointers.device == columns.device == values.device |
| |
|
| | device = pointers.device |
| |
|
| | shape_guess = (pointers.shape[0] - 1, columns.max() + 1) |
| | if shape is None: |
| | shape = shape_guess |
| | else: |
| | shape = (max(shape[0], shape_guess[0]), max(shape[1], shape_guess[1])) |
| |
|
| | n, m = shape |
| | a = torch.zeros(n, m, dtype=values.dtype, device=device) |
| | i = torch.arange(n, device=device) |
| | i = fast_repeat(i, pointers[1:] - pointers[:-1]) |
| | j = columns.long() |
| | a[i, j] = values |
| |
|
| | return a |
| |
|
| |
|
| | def sparse_sort(src, index, dim=0, descending=False, eps=1e-6): |
| | """Lexicographic sort 1D src points based on index first and src |
| | values second. |
| | |
| | Credit: https://github.com/rusty1s/pytorch_scatter/issues/48 |
| | """ |
| | |
| | |
| | f_src = src.double() |
| | f_min, f_max = f_src.min(dim)[0], f_src.max(dim)[0] |
| | norm = (f_src - f_min)/(f_max - f_min + eps) + index.double()*(-1)**int(descending) |
| | perm = norm.argsort(dim=dim, descending=descending) |
| |
|
| | return src[perm], perm |
| |
|
| |
|
| | def sparse_sort_along_direction(src, index, direction, descending=False): |
| | """Lexicographic sort N-dimensional src points based on index first |
| | and the projection of the src values along a direction second. |
| | """ |
| | assert src.dim() == 2 |
| | assert index.dim() == 1 |
| | assert src.shape[0] == index.shape[0] |
| | assert direction.dim() == 2 or direction.dim() == 1 |
| |
|
| | if direction.dim() == 1: |
| | direction = direction.view(1, -1) |
| |
|
| | |
| | |
| | if direction.shape[0] == 1: |
| | direction = direction.repeat(src.shape[0], 1) |
| |
|
| | |
| | if direction.shape[0] != src.shape[0]: |
| | direction = direction[index] |
| |
|
| | |
| | |
| | |
| | centroid = scatter_mean(src, index, dim=0)[index] |
| |
|
| | |
| | projection = torch.einsum('ed, ed -> e', src - centroid, direction) |
| |
|
| | |
| | _, perm = sparse_sort(projection, index, descending=descending) |
| |
|
| | return src[perm], perm |
| |
|
| |
|
| | def sparse_sample(idx, n_max=32, n_min=1, mask=None, return_pointers=False): |
| | """Compute indices to sample elements in a set of size `idx.shape`, |
| | based on which segment they belong to in `idx`. |
| | |
| | The sampling operation is run without replacement and each |
| | segment is sampled at least `n_min` and at most `n_max` times, |
| | within the limits allowed by its actual size. |
| | |
| | Optionally, a `mask` can be passed to filter out some elements. |
| | |
| | :param idx: LongTensor of size N |
| | Segment indices for each of the N elements |
| | :param n_max: int |
| | Maximum number of elements to sample in each segment |
| | :param n_min: int |
| | Minimum number of elements to sample in each segment, within the |
| | limits of its size (i.e. no oversampling) |
| | :param mask: list, np.ndarray, torch.Tensor |
| | Indicates a subset of elements to consider. This allows ignoring |
| | some segments |
| | :param return_pointers: bool |
| | Whether pointers should be returned along with sampling |
| | indices. These indicate which sampled element belongs to which |
| | segment |
| | """ |
| | assert 0 <= n_min <= n_max |
| |
|
| | |
| | device = idx.device |
| | size = idx.bincount() |
| | num_elements = size.sum() |
| | num_segments = idx.max() + 1 |
| |
|
| | |
| | |
| | if n_max > 0: |
| | |
| | |
| | n_samples = (n_max * torch.tanh(size / n_max)).floor().long() |
| | else: |
| | |
| | n_samples = size.sqrt().round().long() |
| |
|
| | |
| | |
| | |
| | |
| | |
| | n_samples = n_samples.clamp(min=n_min).clamp(max=size) |
| |
|
| | |
| | if src.is_debug_enabled(): |
| | assert n_samples.le(size).all(), \ |
| | "Cannot sample more than the segment sizes." |
| |
|
| | |
| | sample_idx = torch.arange(num_elements, device=device) |
| |
|
| | |
| | |
| | mask = tensor_idx(mask, device=device) |
| | if mask.shape[0] > 0: |
| | sample_idx = sample_idx[mask] |
| | idx = idx[mask] |
| | size = idx.bincount(minlength=num_segments) |
| | n_samples = n_samples.clamp(max=size) |
| |
|
| | |
| | if src.is_debug_enabled(): |
| | assert n_samples.le(size).all(), \ |
| | "Cannot sample more than the segment sizes." |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | perm = fast_randperm(sample_idx.shape[0], device=device) |
| | idx = idx[perm] |
| | sample_idx = sample_idx[perm] |
| |
|
| | |
| | |
| | |
| | idx, order = idx.sort() |
| | sample_idx = sample_idx[order] |
| |
|
| | |
| | |
| | |
| | |
| | offset = sizes_to_pointers(size[:-1]) |
| | idx_samples = sample_idx[arange_interleave(n_samples, start=offset)] |
| |
|
| | |
| | if not return_pointers: |
| | return idx_samples |
| |
|
| | |
| | ptr_samples = sizes_to_pointers(n_samples) |
| |
|
| | return idx_samples, ptr_samples.contiguous() |
| |
|