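# NOTE: the four lines below are residue from the hosting web page (not part
# of the module); they are kept as comments so the file remains valid Python.
# English
# Shanci's picture
# Upload folder using huggingface_hub
# 26225c5 verified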
import torch
import src
from src.utils.tensor import is_dense, is_sorted, fast_repeat, tensor_idx, \
arange_interleave, fast_randperm
from torch_scatter import scatter_mean
# Names exported by `from <this module> import *`.
__all__ = [
'indices_to_pointers', 'sizes_to_pointers', 'dense_to_csr', 'csr_to_dense',
'sparse_sort', 'sparse_sort_along_direction', 'sparse_sample']
def indices_to_pointers(indices: torch.Tensor):
    """Build CSR-style pointers from a 1D tensor of dense group indices.

    If `indices` is not already sorted, it is sorted first and the
    sorting permutation is returned so the caller can reorder the
    associated data accordingly.

    :param indices: 1D Tensor
        Group index of each element. Must hold at least one element
        and pass the `is_dense` check
    :return: pointers, order
        `pointers` is a 1D LongTensor starting at 0, ending at
        `indices.shape[0]`, with one entry at each position where the
        sorted group index increases. `order` is the permutation that
        sorts `indices` (the identity if already sorted)
    """
    device = indices.device
    assert len(indices.shape) == 1, "Only 1D indices are accepted."
    assert indices.shape[0] >= 1, "At least one group index is required."
    assert is_dense(indices), "Indices must be dense"

    num_items = indices.shape[0]

    # Only pay for a sort when the input is not already ordered
    if is_sorted(indices):
        order = torch.arange(num_items, device=device)
    else:
        indices, order = indices.sort()

    # A new group starts wherever the sorted index strictly increases
    group_starts = torch.where(indices[1:] > indices[:-1])[0] + 1
    first = torch.zeros(1, dtype=torch.long, device=device)
    last = torch.full((1,), num_items, dtype=torch.long, device=device)
    pointers = torch.cat([first, group_starts, last])

    return pointers, order
def sizes_to_pointers(sizes: torch.Tensor):
    """Turn per-group sizes into cumulative pointers.

    :param sizes: 1D LongTensor of size G
        Number of elements in each group
    :return: 1D LongTensor of size G + 1
        `[0, sizes[0], sizes[0] + sizes[1], ...]`, on the same device
        as `sizes`
    """
    assert sizes.dim() == 1
    assert sizes.dtype == torch.long

    # Leading 0 followed by the running total of the sizes
    pointers = sizes.new_zeros(sizes.shape[0] + 1)
    pointers[1:] = sizes.cumsum(dim=0)
    return pointers
def dense_to_csr(a):
    """Convert a dense matrix to its CSR counterpart.

    The row pointers are computed from the number of non-zero entries
    of every row of `a`, so rows without any non-zero entry (including
    an entirely zero matrix) are represented by an empty pointer
    interval instead of being dropped or triggering an error.

    :param a: 2D Tensor of shape (R, C)
    :return: pointers, columns, values
        `pointers` is a 1D LongTensor of size R + 1 such that the
        entries of row `r` are stored in
        `columns[pointers[r]:pointers[r + 1]]` and
        `values[pointers[r]:pointers[r + 1]]`. `columns` holds the
        column index of each non-zero entry and `values` the
        corresponding value, both in row-major order
    """
    assert a.dim() == 2

    # `nonzero` returns the coordinates in row-major order, so the
    # entries of each row are already stored contiguously
    rows, columns = a.nonzero(as_tuple=True)
    values = a[rows, columns]

    # Count the entries of every row, including empty ones, and
    # accumulate the counts into pointers
    counts = torch.bincount(rows, minlength=a.shape[0])
    zero = torch.zeros(1, dtype=torch.long, device=a.device)
    pointers = torch.cat((zero, counts.cumsum(dim=0)))

    return pointers, columns, values
def csr_to_dense(pointers, columns, values, shape=None):
    """Convert a CSR matrix to its dense counterpart of a given shape.

    :param pointers: 1D Tensor of size R + 1
        Row pointers: the entries of row `r` are stored in
        `columns[pointers[r]:pointers[r + 1]]`
    :param columns: 1D Tensor
        Column index of each stored entry
    :param values: 1D Tensor
        Value of each stored entry
    :param shape: sequence of 2 ints, optional
        Requested output shape. It is treated as a lower bound: the
        output is enlarged if the CSR data needs more rows or columns
    :return: 2D Tensor
        Dense matrix with the dtype of `values`, on the device of the
        inputs, zero everywhere except at the stored entries
    """
    assert pointers.dim() == 1
    assert columns.dim() == 1
    assert values.dim() == 1
    assert shape is None or len(shape) == 2
    assert pointers.device == columns.device == values.device
    device = pointers.device

    # Smallest shape able to hold the CSR data. `max()` is undefined on
    # an empty tensor, so a matrix without any stored entry needs no
    # column at all
    min_rows = pointers.shape[0] - 1
    min_cols = int(columns.max()) + 1 if columns.numel() > 0 else 0

    if shape is None:
        n, m = min_rows, min_cols
    else:
        n, m = max(int(shape[0]), min_rows), max(int(shape[1]), min_cols)

    a = torch.zeros(n, m, dtype=values.dtype, device=device)

    # Expand the row pointers into one row index per stored entry. The
    # row indices are built from the number of CSR rows (not from the
    # possibly larger output shape) so that they match the number of
    # repeat counts.
    # NOTE(review): `fast_repeat` is a project helper, presumably
    # equivalent to `torch.repeat_interleave` - confirm
    i = torch.arange(min_rows, device=device)
    i = fast_repeat(i, pointers[1:] - pointers[:-1])
    j = columns.long()
    a[i, j] = values

    return a
def sparse_sort(src, index, dim=0, descending=False, eps=1e-6):
    """Sort 1D `src` values group-wise: elements are ordered by `index`
    first and by `src` value second.

    The values are rescaled to [0, 1) and offset by their group index,
    so that a single argsort yields the lexicographic order. Groups
    always come out in increasing `index` order; `descending` only
    affects the order of the values inside each group.

    Credit: https://github.com/rusty1s/pytorch_scatter/issues/48

    :param src: 1D Tensor
        Values to sort
    :param index: 1D Tensor
        Group index of each value
    :param dim: int
        Dimension along which min, max and argsort are computed
    :param descending: bool
        Whether values are sorted in decreasing order within groups
    :param eps: float
        Added to the value range to keep rescaled values below 1
    :return: sorted_src, perm
        `perm` is the sorting permutation and `sorted_src` is
        `src[perm]`
    """
    # Double precision keeps small differences in src distinguishable
    # even once large group indices are added
    values = src.double()
    lowest = values.min(dim)[0]
    highest = values.max(dim)[0]
    rescaled = (values - lowest) / (highest - lowest + eps)

    # When sorting in descending order, the group offset is negated so
    # that groups still appear in increasing index order
    sign = -1 if descending else 1
    key = rescaled + index.double() * sign

    perm = key.argsort(dim=dim, descending=descending)
    return src[perm], perm
def sparse_sort_along_direction(src, index, direction, descending=False):
    """Sort N-dimensional `src` points group-wise: points are ordered
    by `index` first and by their projection along `direction` second.

    :param src: 2D Tensor of shape (P, D)
        Points to sort
    :param index: 1D Tensor of size P
        Group index of each point
    :param direction: 1D Tensor of size D, or 2D Tensor
        Projection direction. A 1D tensor or a single row is applied
        to all points. A 2D tensor with P rows gives one direction per
        point. Any other number of rows is indexed by `index`, i.e.
        treated as one direction per group
    :param descending: bool
        Forwarded to `sparse_sort`
    :return: sorted_src, perm
        `perm` is the sorting permutation and `sorted_src` is
        `src[perm]`
    """
    assert src.dim() == 2
    assert index.dim() == 1
    assert src.shape[0] == index.shape[0]
    assert direction.dim() in (1, 2)

    num_points = src.shape[0]

    # Bring the direction to a 2D, one-row-per-point layout
    if direction.dim() == 1:
        direction = direction.view(1, -1)
    if direction.shape[0] == 1:
        # Single direction shared by all points
        direction = direction.repeat(num_points, 1)
    elif direction.shape[0] != num_points:
        # Group-wise directions, distributed to the points
        direction = direction[index]

    # Center each group on its centroid before projecting. Not strictly
    # needed, but limits precision issues when the absolute
    # coordinates are large
    centered = src - scatter_mean(src, index, dim=0)[index]

    # Row-wise dot product between centered points and directions
    projection = torch.einsum('ed, ed -> e', centered, direction)

    # Group-wise sort of the scalar projections
    perm = sparse_sort(projection, index, descending=descending)[1]
    return src[perm], perm
def sparse_sample(idx, n_max=32, n_min=1, mask=None, return_pointers=False):
    """Randomly select, without replacement, a bounded number of
    elements from each segment described by `idx`.

    Each segment contributes at least `n_min` and at most `n_max`
    elements, within the limits of its actual size. The returned
    values are positions in the input set of N elements.

    :param idx: LongTensor of size N
        Segment index of each of the N elements
    :param n_max: int
        Maximum number of elements sampled per segment. If 0, the
        number of samples is the rounded square root of the segment
        size instead
    :param n_min: int
        Minimum number of elements sampled per segment, within the
        limits of its size (i.e. no oversampling)
    :param mask: list, np.ndarray, torch.Tensor
        Subset of elements to sample from. Segments left without any
        element after masking are not sampled
    :param return_pointers: bool
        Whether pointers delimiting the samples of each segment should
        be returned along with the sampled indices
    :return: the sampled indices, or `(indices, pointers)` if
        `return_pointers` is set
    """
    assert 0 <= n_min <= n_max

    device = idx.device

    # Per-segment sizes and global counts
    size = idx.bincount()
    num_elements = size.sum()
    num_segments = idx.max() + 1

    # Heuristic for the number of samples per segment
    if n_max > 0:
        # k * tanh(x / k) behaves like x for small x and saturates
        # towards k for large x
        budget = n_max * torch.tanh(size / n_max)
        n_samples = budget.floor().long()
    else:
        # Square-root sampling fallback
        n_samples = size.sqrt().round().long()

    # Enforce the lower bound first, then cap by the segment size: a
    # segment smaller than `n_min` ends up entirely sampled
    n_samples = n_samples.clamp(min=n_min)
    n_samples = n_samples.clamp(max=size)

    if src.is_debug_enabled():
        assert n_samples.le(size).all(), \
            "Cannot sample more than the segment sizes."

    # Candidate positions, initially every element
    sample_idx = torch.arange(num_elements, device=device)

    # Restrict the candidates to the mask, if any, and recompute the
    # segment sizes and sample counts accordingly.
    # NOTE(review): `tensor_idx` is a project helper, presumably
    # returning an empty index tensor when `mask` is None - confirm
    mask = tensor_idx(mask, device=device)
    if mask.shape[0] > 0:
        sample_idx = sample_idx[mask]
        idx = idx[mask]
        size = idx.bincount(minlength=num_segments)
        n_samples = n_samples.clamp(max=size)

    if src.is_debug_enabled():
        assert n_samples.le(size).all(), \
            "Cannot sample more than the segment sizes."

    # TODO: IMPORTANT the randperm-sort approach here is a huge
    #  BOTTLENECK for the sampling operation on CPU. Can we do any
    #  better ?

    # Random shuffle of the candidates...
    perm = fast_randperm(sample_idx.shape[0], device=device)
    shuffled_idx = idx[perm]
    sample_idx = sample_idx[perm]

    # ...followed by a sort on the segment index: candidates end up
    # grouped by segment, in random order within each segment
    _, order = shuffled_idx.sort()
    sample_idx = sample_idx[order]

    # Select `n_samples[s]` candidates for each segment `s`, starting
    # at the segment's first position in the grouped candidates.
    # Vectorized to keep the processing time reasonable
    offset = sizes_to_pointers(size[:-1])
    picked = arange_interleave(n_samples, start=offset)
    idx_samples = sample_idx[picked]

    if not return_pointers:
        return idx_samples

    # Pointers delimiting the samples of each segment
    ptr_samples = sizes_to_pointers(n_samples)
    return idx_samples, ptr_samples.contiguous()