import torch
import src
from src.utils.tensor import is_dense, is_sorted, fast_repeat, tensor_idx, \
    arange_interleave, fast_randperm
from torch_scatter import scatter_mean


__all__ = [
    'indices_to_pointers', 'sizes_to_pointers', 'dense_to_csr', 'csr_to_dense',
    'sparse_sort', 'sparse_sort_along_direction', 'sparse_sample']


def indices_to_pointers(indices: torch.Tensor):
    """Convert pre-sorted dense indices to CSR format."""
    device = indices.device
    assert len(indices.shape) == 1, "Only 1D indices are accepted."
    assert indices.shape[0] >= 1, "At least one group index is required."
    assert is_dense(indices), "Indices must be dense"

    # Sort indices if need be
    order = torch.arange(indices.shape[0], device=device)
    if not is_sorted(indices):
        indices, order = indices.sort()

    # Convert sorted indices to pointers
    pointers = torch.cat([
        torch.LongTensor([0]).to(device),
        torch.where(indices[1:] > indices[:-1])[0] + 1,
        torch.LongTensor([indices.shape[0]]).to(device)])

    return pointers, order


def sizes_to_pointers(sizes: torch.Tensor):
    """Convert a tensor of sizes into the corresponding pointers. This
    is a trivial but often-required operation.
    """
    assert sizes.dim() == 1
    assert sizes.dtype == torch.long
    zero = torch.zeros(1, device=sizes.device, dtype=torch.long)
    return torch.cat((zero, sizes)).cumsum(dim=0)


def dense_to_csr(a):
    """Convert a dense matrix to its CSR counterpart."""
    assert a.dim() == 2
    index = a.nonzero(as_tuple=True)
    values = a[index]
    columns = index[1]
    pointers = indices_to_pointers(index[0])[0]
    return pointers, columns, values


def csr_to_dense(pointers, columns, values, shape=None):
    """Convert a CSR matrix to its dense counterpart of a given shape.
    """
    assert pointers.dim() == 1
    assert columns.dim() == 1
    assert values.dim() == 1
    assert shape is None or len(shape) == 2
    assert pointers.device == columns.device == values.device

    device = pointers.device

    # Guess the smallest shape compatible with the CSR data: one row per
    # pointer interval and enough columns to hold the largest column
    # index. A user-provided shape may only grow this guess
    shape_guess = (pointers.shape[0] - 1, columns.max() + 1)
    if shape is None:
        shape = shape_guess
    else:
        shape = (max(shape[0], shape_guess[0]), max(shape[1], shape_guess[1]))

    n, m = shape
    a = torch.zeros(n, m, dtype=values.dtype, device=device)
    i = torch.arange(n, device=device)
    i = fast_repeat(i, pointers[1:] - pointers[:-1])
    j = columns.long()
    a[i, j] = values

    return a


def sparse_sort(src, index, dim=0, descending=False, eps=1e-6):
    """Lexicographic sort 1D src points based on index first and src
    values second.

    Credit: https://github.com/rusty1s/pytorch_scatter/issues/48
    """
    # NB: we use double precision here to make sure we can capture
    # fine-grained src changes even with very large index values.
    f_src = src.double()
    f_min, f_max = f_src.min(dim)[0], f_src.max(dim)[0]
    norm = (f_src - f_min) / (f_max - f_min + eps) \
        + index.double() * (-1) ** int(descending)
    perm = norm.argsort(dim=dim, descending=descending)

    return src[perm], perm


def sparse_sort_along_direction(src, index, direction, descending=False):
    """Lexicographic sort N-dimensional src points based on index first
    and the projection of the src values along a direction second.
    """
    assert src.dim() == 2
    assert index.dim() == 1
    assert src.shape[0] == index.shape[0]
    assert direction.dim() == 2 or direction.dim() == 1

    if direction.dim() == 1:
        direction = direction.view(1, -1)

    # If only 1 direction is provided, apply the same direction to all
    # points
    if direction.shape[0] == 1:
        direction = direction.repeat(src.shape[0], 1)

    # If the direction is provided group-wise, expand it to the points
    if direction.shape[0] != src.shape[0]:
        direction = direction[index]

    # Compute the centroid for each group. This is not mandatory, but
    # may help avoid precision errors if absolute src coordinates are
    # too large
    centroid = scatter_mean(src, index, dim=0)[index]

    # Project the points along the associated direction
    projection = torch.einsum('ed, ed -> e', src - centroid, direction)

    # Sort the projections
    _, perm = sparse_sort(projection, index, descending=descending)

    return src[perm], perm


def sparse_sample(idx, n_max=32, n_min=1, mask=None, return_pointers=False):
    """Compute indices to sample elements in a set of size `idx.shape`,
    based on which segment they belong to in `idx`.

    The sampling operation is run without replacement and each
    segment is sampled at least `n_min` and at most `n_max` times,
    within the limits allowed by its actual size.

    Optionally, a `mask` can be passed to filter out some elements.

    :param idx: LongTensor of size N
        Segment indices for each of the N elements
    :param n_max: int
        Maximum number of elements to sample in each segment
    :param n_min: int
        Minimum number of elements to sample in each segment, within the
        limits of its size (i.e. no oversampling)
    :param mask: list, np.ndarray, torch.Tensor
        Indicates a subset of elements to consider. This allows ignoring
        some segments
    :param return_pointers: bool
        Whether pointers should be returned along with sampling
        indices. These indicate which sampled element belongs to which
        segment
    """
    assert 0 <= n_min <= n_max

    # Initialization
    device = idx.device
    size = idx.bincount()
    num_elements = size.sum()
    num_segments = idx.max() + 1

    # Compute the number of elements that will be sampled from each
    # segment, based on a heuristic
    if n_max > 0:
        # k * tanh(x / k) is bounded by k, is ~x for x~0 and starts
        # saturating at x~k
        n_samples = (n_max * torch.tanh(size / n_max)).floor().long()
    else:
        # Fallback to sqrt sampling
        n_samples = size.sqrt().round().long()

    # Make sure each segment is sampled at least 'n_min' times and not
    # more than its size (we sample without replacement). If a segment
    # has fewer than 'n_min' elements, it will be entirely sampled (no
    # randomness for that segment), which is why we successively apply
    # clamp min and clamp max
    n_samples = n_samples.clamp(min=n_min).clamp(max=size)

    # Sanity check
    if src.is_debug_enabled():
        assert n_samples.le(size).all(), \
            "Cannot sample more than the segment sizes."

    # Prepare the sampled elements indices
    sample_idx = torch.arange(num_elements, device=device)

    # If a mask is provided, only keep the corresponding elements.
    # This also requires updating the `size` and `n_samples`
    mask = tensor_idx(mask, device=device)
    if mask.shape[0] > 0:
        sample_idx = sample_idx[mask]
        idx = idx[mask]
        size = idx.bincount(minlength=num_segments)
        n_samples = n_samples.clamp(max=size)

    # Sanity check
    if src.is_debug_enabled():
        assert n_samples.le(size).all(), \
            "Cannot sample more than the segment sizes."

    # TODO: IMPORTANT the randperm-sort approach here is a huge
    #  BOTTLENECK for the sampling operation on CPU. Can we do any
    #  better?

    # Shuffle the order of elements to introduce randomness
    perm = fast_randperm(sample_idx.shape[0], device=device)
    idx = idx[perm]
    sample_idx = sample_idx[perm]

    # Sort by idx. Combined with the previous shuffling, this ensures
    # randomness in the elements selected from each segment
    idx, order = idx.sort()
    sample_idx = sample_idx[order]

    # Build the indices of the elements we will sample from
    # sample_idx. Note this could easily be expressed with a for
    # loop, but we need to use a vectorized formulation to ensure
    # reasonable processing time
    offset = sizes_to_pointers(size[:-1])
    idx_samples = sample_idx[arange_interleave(n_samples, start=offset)]

    # Return here if sampling pointers are not required
    if not return_pointers:
        return idx_samples

    # Compute the pointers
    ptr_samples = sizes_to_pointers(n_samples)

    return idx_samples, ptr_samples.contiguous()