English
File size: 10,079 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
import math
import torch
from torch_scatter import scatter_add, scatter_mean, scatter_min
from itertools import combinations_with_replacement
from src.utils.edge import edge_wise_points
from torch_geometric.utils import coalesce


# Public API of this module: the names exported by `from <module> import *`.
# Each entry corresponds to a function defined below in this file.
__all__ = [
    'scatter_mean_weighted', 'scatter_pca', 'scatter_nearest_neighbor',
    'idx_preserving_mask', 'scatter_mean_orientation']


def scatter_mean_weighted(x, idx, w, dim_size=None):
    """Group-wise weighted mean of `x` along its first dimension.

    :param x: tensor whose first dimension has the same length as
        `idx` and `w`. A 1D `x` is treated as a single-column 2D tensor
    :param idx: 1D tensor
        Group index of each row of `x`
    :param w: 1D tensor
        Non-negative weight of each row of `x`
    :param dim_size: int, optional
        Forwarded to `scatter_add` as the output size along dim 0
    :return: 2D tensor holding, for each group, the weighted mean of
        its rows. Groups whose total weight is zero are divided by 1
        instead of 0
    """
    assert w.ge(0).all(), "Only positive weights are accepted"
    assert w.dim() == idx.dim() == 1, "w and idx should be 1D Tensors"
    assert x.shape[0] == w.shape[0] == idx.shape[0], \
        "Only supports weighted mean along the first dimension"

    # Bring values and weights to column form
    values = x if x.dim() != 1 else x.view(-1, 1)
    weights = w.view(-1, 1).float()

    # Pack the weights and the weighted values side by side, so that a
    # single scatter call accumulates both the group weight (column 0)
    # and the group weighted sum (remaining columns)
    packed = torch.cat((weights, values * weights), dim=1)
    accumulated = scatter_add(packed, idx, dim=0, dim_size=dim_size)
    weight_sum = accumulated[:, :1]
    weighted_sum = accumulated[:, 1:]

    # Guard against division by zero for groups with null total weight
    safe_weight_sum = torch.where(
        weight_sum == 0, torch.ones_like(weight_sum), weight_sum)

    return weighted_sum / safe_weight_sum


def scatter_pca(x, idx, on_cpu=True):
    """Scatter implementation for PCA.

    Returns eigenvalues and eigenvectors for each group in idx.
    If x has shape N1xD and idx holds N2 groups (ie the scatter
    operations produce N2 rows), the eigenvalues will have shape N2xD
    and the eigenvectors will have shape N2xDxD. The eigenvalues and
    eigenvectors are sorted by increasing eigenvalue.

    NOTE(review): the group-wise sums of centered products are divided
    by the dimension D, not by the group size. This scaling is kept
    as-is because callers may rely on it -- confirm the intent.

    :param x: (N1, D) tensor, with D > 1
        Points
    :param idx: (N1) LongTensor
        Group index, for each point
    :param on_cpu: bool
        If True, the eigendecomposition is run on CPU and the results
        are moved back to the device of `x`
    :return: eigenvalues, eigenvectors
        Groups for which the decomposition produced NaN values are
        assigned unit eigenvalues and identity eigenvectors. Negative
        eigenvalues are set to zero
    """
    assert idx.dim() == 1
    assert x.dim() == 2
    assert idx.shape[0] == x.shape[0]
    assert x.shape[1] > 1

    d = x.shape[1]
    device = x.device

    # Subtract the group-wise mean
    mean = scatter_mean(x, idx, dim=0)
    x = x - mean[idx]

    # Compute the pointwise products for the upper triangle of the
    # covariance matrix, as a N_1x(D(D+1)/2) matrix
    ij = torch.tensor(
        list(combinations_with_replacement(range(d), 2)), device=device)
    upper_triangle = x[:, ij[:, 0]] * x[:, ij[:, 1]]

    # Aggregate the products group-wise with scatter_add and convert
    # the result to a N_2xDxD batch of matrices. Only the upper
    # triangle is populated, which is all `eigh(UPLO='U')` reads. The
    # buffer uses the dtype of the aggregated values, so that inputs
    # which are not of the default dtype can be assigned into it
    upper_triangle = scatter_add(upper_triangle, idx, dim=0) / d
    cov = torch.zeros(
        (upper_triangle.shape[0], d, d), dtype=upper_triangle.dtype,
        device=device)
    cov[:, ij[:, 0], ij[:, 1]] = upper_triangle

    # Eigendecomposition, optionally on CPU. Moving the results back
    # with `.to(device)` is a no-op when they already live on `device`
    eigenvalues, eigenvectors = torch.linalg.eigh(
        cov.cpu() if on_cpu else cov, UPLO='U')
    eigenvalues = eigenvalues.to(device)
    eigenvectors = eigenvectors.to(device)

    # If NaN values are computed for a group, in either its eigenvalues
    # or its eigenvectors, return equal eigenvalues and identity
    # eigenvectors for that group. The replacement values are sized
    # from D rather than hard-coded for 3D inputs
    is_nan = eigenvalues.isnan().any(1) \
        | eigenvectors.flatten(1).isnan().any(1)
    eigenvalues[is_nan] = torch.ones(
        d, dtype=eigenvalues.dtype, device=device)
    eigenvectors[is_nan] = torch.eye(
        d, dtype=eigenvectors.dtype, device=device)

    # Precision errors may cause close-to-zero eigenvalues to be
    # negative. Hard-code these to zero
    eigenvalues[eigenvalues < 0] = 0

    return eigenvalues, eigenvectors


def scatter_nearest_neighbor(
        points, index, edge_index, cycles=3, chunk_size=None):
    """For each pair of segments indicated in edge_index, find the 2
    closest points between the two segments.

    NB: this is an approximate, iterative process.

    :param points: (N, D) tensor
        Points
    :param index: (N) LongTensor
        Segment index, for each point
    :param edge_index: (2, E) LongTensor
        Segment pairs for which to compute the nearest neighbors
    :param cycles int
        Number of iterations. Starting from a point X in set A, one
        cycle accounts for searching the nearest neighbor, in A, of the
        nearest neighbor of X in set B
    :param chunk_size: int, float
        Allows mitigating memory use when computing the neighbors. If
        `chunk_size > 1`, `edge_index` will be processed into chunks of
        `chunk_size`. If `0 < chunk_size < 1`, then `edge_index` will be
        divided into parts of `edge_index.shape[1] * chunk_size` or less
    :return: candidate, candidate_idx
        `candidate` vertically stacks the selected source points (one
        row per edge) on top of the selected target points (one row per
        edge). `candidate_idx` vertically stacks the corresponding
        source and target indices into `points`, one column per edge.
        The layout is the same whether or not chunking is used
    """
    assert edge_index.shape == coalesce(edge_index).shape, \
        "Does not support duplicate edges, please coalesce the edges" \
        " before calling this function"

    num_edges = edge_index.shape[1]

    # Recursive call in case chunk is specified. Chunk allows limiting
    # the number of edges processed at once. This might alleviate
    # memory use. Chunking is skipped when there are no edges, since
    # there would be nothing to split (and a fractional chunk_size
    # would resolve to chunks of size 0)
    if chunk_size is not None and chunk_size > 0 and num_edges > 0:

        # Recursive call on smaller edge_index chunks
        chunk_size = int(chunk_size) if chunk_size > 1 \
            else math.ceil(num_edges * chunk_size)
        out_list = [
            scatter_nearest_neighbor(
                points, index, edge_index[:, start:start + chunk_size],
                cycles=cycles, chunk_size=None)
            for start in range(0, num_edges, chunk_size)]

        # Combine outputs. Each chunk returns its source candidates
        # stacked on top of its target candidates, so the two halves
        # must be gathered separately to recover the same
        # [all sources; all targets] layout as the unchunked output
        s_chunks = []
        t_chunks = []
        for chunk_candidate, chunk_candidate_idx in out_list:
            num_chunk_edges = chunk_candidate_idx.shape[1]
            s_chunks.append(chunk_candidate[:num_chunk_edges])
            t_chunks.append(chunk_candidate[num_chunk_edges:])
        candidate = torch.cat(s_chunks + t_chunks, dim=0)
        candidate_idx = torch.cat([elt[1] for elt in out_list], dim=1)

        return candidate, candidate_idx

    # We define the segments in the first row of edge_index as 'source'
    # segments, while the elements of the second row are 'target'
    # segments. The corresponding variables are prepended with 's_' and
    # 't_' for clarity
    s_idx = edge_index[0]
    t_idx = edge_index[1]

    # Expand the edge variables to point-edge values. That is, the
    # concatenation of all the source --or target-- points for each
    # edge. The corresponding variables are prepended with 'S_' and 'T_'
    # for clarity
    (S_points, S_points_idx, S_uid), (T_points, T_points_idx, T_uid) = \
        edge_wise_points(points, index, edge_index)

    # Initialize the candidate points as the centroid of each segment.
    # The candidate indices are initialized to -1, which is what is
    # returned if no update step is run (ie `cycles=0`)
    segment_centroid = scatter_mean(points, index, dim=0)
    segment_size = index.bincount()
    s_candidate = segment_centroid[s_idx]
    t_candidate = segment_centroid[t_idx]
    s_candidate_idx = -torch.ones_like(s_idx)
    t_candidate_idx = -torch.ones_like(s_idx)

    # Step operation will update the source --target, respectively--
    # candidate based on the current target --source, respectively--
    # candidate. It reads the current `s_candidate` and `t_candidate`
    # from the enclosing scope at call time
    def step(source=True):
        if source:
            x_idx, y_candidate, X_points, X_points_idx, X_uid = \
                s_idx, t_candidate, S_points, S_points_idx, S_uid
        else:
            x_idx, y_candidate, X_points, X_points_idx, X_uid = \
                t_idx, s_candidate, T_points, T_points_idx, T_uid

        # Expand the other segments' candidates to point-edge values
        size = segment_size[x_idx]
        Y_candidate = y_candidate.repeat_interleave(size, dim=0)

        # Compute the distance between the points and the other segment's
        # candidate and update the segment's candidate as the point with
        # the smallest distance to the candidate
        X_dist = (X_points - Y_candidate).norm(dim=1)

        # Update the candidate as the point with the smallest distance
        # for each edge
        # TODO: this is the bottleneck of scatter_nearest_neighbor
        _, X_argmin = scatter_min(X_dist, X_uid)
        x_candidate_idx = X_points_idx[X_argmin]
        x_candidate = points[x_candidate_idx]

        return x_candidate, x_candidate_idx

    # Iteratively update the target and source candidates
    for _ in range(cycles):
        t_candidate, t_candidate_idx = step(source=False)
        s_candidate, s_candidate_idx = step(source=True)

    # Stack for output
    candidate = torch.vstack((s_candidate, t_candidate))
    candidate_idx = torch.vstack((s_candidate_idx, t_candidate_idx))

    return candidate, candidate_idx


def idx_preserving_mask(mask, idx, dim=0):
    """Relax a boolean mask so that no group of `idx` is entirely
    discarded when indexing with the mask.

    For each group, the number of elements selected by `mask` is
    accumulated with `scatter_add`. Every element belonging to a group
    with no selected element is then switched on in the returned mask.

    :param mask: boolean tensor
    :param idx: index tensor, giving the group of each mask element
    :param dim: int
        Dimension passed to `scatter_add`
    :return: `mask`, OR-ed with the membership to fully-masked groups
    """
    selected_per_group = scatter_add(mask.float(), idx, dim=dim)
    group_is_empty = selected_per_group == 0
    return mask | group_is_empty[idx]


def scatter_mean_orientation(orientation, idx):
    """Scatter implementation for mean normal orientation computation.

    Normals carry an orientation but their sense is arbitrary: a normal
    and its opposite describe the same orientation. Naively averaging
    the vectors of a group may hence cancel opposing vectors out. This
    function flips the vectors that need it before averaging, and
    returns the mean orientation of each group expressed in the Z+
    halfspace.

    :param orientation: (N, D) tensor
        Orientations vectors. Do not need to be normalized but are
        assumed to be expressed with 0 as their origin
    :param idx: (N) LongTensor
        Group index, for each vector
    :return: tensor with one row per group, holding the normalized
        mean orientation, with a non-negative last coordinate
    """
    epsilon = 1e-4

    def _normalize(v):
        # Row-wise normalization, with epsilon guarding against null
        # vectors, followed by a safety clamp to [-1, 1]
        v /= v.norm(dim=1).view(-1, 1) + epsilon
        return v.clamp(min=-1, max=1)

    # Operate on a detached copy, the input is left untouched
    x = _normalize(orientation.detach().clone())

    # Elevation angle of each vector, from its third coordinate. The
    # arcsin takes values in [-π/2, π/2]
    phi = torch.asin(x[:, 2])

    # A group whose mean phi is below π/4 is considered 'horizontal'.
    # The group-wise flag is broadcast back to each vector
    group_phi = scatter_mean(phi, idx, dim=0)
    is_horizontal = (group_phi < torch.pi / 4)[idx]

    # In each group, the vector with the smallest phi serves as
    # reference: the vectors with a negative dot product with it are
    # considered as opposing the group
    _, reference = scatter_min(phi, idx, dim=0)
    dot_with_reference = (x * x[reference[idx]]).sum(dim=1)
    is_opposing = dot_with_reference < 0

    # Only the opposing vectors of horizontal groups are flipped
    to_flip = is_horizontal & is_opposing
    x[to_flip] = x[to_flip] * -1

    # Group-wise mean of the flipped vectors, normalized
    x_mean = _normalize(scatter_mean(x, idx, dim=0))

    # Canonical sense: mean vectors whose last coordinate is negative
    # are flipped, so that all point towards z+
    points_down = x_mean[:, -1] < 0
    x_mean[points_down] = x_mean[points_down] * -1

    return x_mean