English
File size: 18,134 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
import torch
import math
from torch_scatter import scatter_min, scatter_max, scatter_mean
from torch_geometric.utils import coalesce, remove_self_loops
from torch_geometric.nn.pool.consecutive import consecutive_cluster
from src.utils.tensor import arange_interleave
from src.utils.geometry import base_vectors_3d
from src.utils.sparse import sizes_to_pointers, sparse_sort, \
    sparse_sort_along_direction
from src.utils.scatter import scatter_pca, scatter_nearest_neighbor, \
    idx_preserving_mask
from src.utils.edge import edge_wise_points

# Public names of this module, i.e. what `from <module> import *`
# exposes
__all__ = [
    'is_pyg_edge_format', 'isolated_nodes', 'edge_to_superedge', 'subedges',
    'to_trimmed', 'is_trimmed']


def is_pyg_edge_format(edge_index):
    """Tell whether `edge_index` is a valid pytorch geometric edge
    tensor, i.e. a torch.LongTensor of shape [2, N].

    :param edge_index:
        Object to check
    :return: bool
        True if `edge_index` is a 2D long tensor with 2 rows
    """
    # Anything that is not a tensor cannot hold edges
    if not isinstance(edge_index, torch.Tensor):
        return False

    # Expected layout: 2D, with one row for sources and one for targets
    has_edge_shape = edge_index.dim() == 2 and edge_index.shape[0] == 2

    return has_edge_shape and edge_index.dtype == torch.long


def isolated_nodes(edge_index, num_nodes=None):
    """Return a boolean mask of size num_nodes indicating which node has
    no edge in edge_index.

    :param edge_index: 2xE LongTensor
        Edges in `torch_geometric` format
    :param num_nodes: int
        Total number of nodes in the graph. If None, it is inferred as
        the largest node index found in `edge_index` plus one, or 0 if
        `edge_index` holds no edge
    :return: BoolTensor of size `num_nodes`
        True for nodes that appear in no edge, False otherwise
    """
    assert is_pyg_edge_format(edge_index)

    # Infer the number of nodes from the edges if need be. An empty
    # edge_index has no max(), so we explicitly fall back to an empty
    # graph rather than letting the reduction raise
    if num_nodes is None:
        num_nodes = \
            int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0

    # Start from 'all isolated' and clear every node touched by an edge
    mask = torch.ones(num_nodes, dtype=torch.bool, device=edge_index.device)
    mask[edge_index.unique()] = False
    return mask


def edge_to_superedge(edges, super_index, edge_attr=None):
    """Turn point-level edges into superedges linking clusters, using
    the point-to-cluster assignment 'super_index'. If 'edge_attr' is
    provided, it is filtered the same way as the edges, so that it
    describes the edges contributing to the superedges.

    NB: (i, j) and (j, i) superedges are considered to be the same.
    Returned superedges are expressed with i <= j

    :param edges:
        Point-level edges, indexed as `edges[0]` for sources and
        `edges[1]` for targets
    :param super_index:
        Cluster index of each point
    :param edge_attr:
        Optional edge attributes, one row per edge in `edges`
    :return: se, se_id, edges_inter, edge_attr
        - se: the unique superedges, as returned by indexing the
          oriented cluster pairs with the `consecutive_cluster`
          permutation
        - se_id: compact superedge identifier of each inter-cluster
          edge
        - edges_inter: the edges of `edges` joining two different
          clusters
        - edge_attr: attributes of `edges_inter`, or None
    """
    # Express each edge endpoint as the cluster it belongs to
    clusters = super_index[edges]

    # Edges inside a cluster do not contribute to any superedge, so
    # only the edges joining two distinct clusters are kept
    inter_idx = torch.where(clusters[0] != clusters[1])[0]
    edges_inter = edges[:, inter_idx]
    if edge_attr is not None:
        edge_attr = edge_attr[inter_idx]
    clusters = clusters[:, inter_idx]

    # Orient every cluster pair as (smallest, largest), so that (i, j)
    # and (j, i) collapse onto the same superedge
    low = torch.minimum(clusters[0], clusters[1])
    high = torch.maximum(clusters[0], clusters[1])
    se = torch.stack((low, high))

    # Several inter-cluster edges may share the same cluster pair. To
    # aggregate them with torch_scatter operations, each pair is
    # encoded into a single integer, which is then compacted into
    # consecutive identifiers. Since low <= high, the largest cluster
    # index is necessarily found in 'high'
    base = high.max() + 1
    se_id, perm = consecutive_cluster(low * base + high)

    return se[:, perm], se_id, edges_inter, edge_attr


def subedges(
        points,
        index,
        edge_index,
        ratio=0.2,
        k_min=20,
        cycles=3,
        pca_on_cpu=True,
        margin=0.2,
        halfspace_filter=True,
        bbox_filter=True,
        target_pc_flip=True,
        source_pc_sort=False,
        chunk_size=None):
    """Compute the subedges making up each edge between segments. These
    are needed for superedge features computation. This approach relies
    on heuristics to avoid the Delaunay triangulation or any other O(N²)
    operation.

    NB: the input edges will be trimmed (see `to_trimmed`) in the first
    place and the returned edge_index will reflect this change. This is
    because subedge computation relies on costly operations. To save
    compute and memory, we only build subedges for the trimmed graph.

    :param points:
        Level-0 points
    :param index:
        Index of the segment each point belongs to
    :param edge_index:
        Edges of the graph between segments
    :param ratio:
        Maximum ratio of a segment's points that can be used in a
        superedge's subedges
    :param k_min:
        Minimum of subedges per superedge
    :param cycles:
        Number of iterations for nearest neighbor search between
        segments
    :param pca_on_cpu:
        Whether PCA should be computed on CPU if need be. Should be kept
        as True
    :param margin:
        Tolerance margin used for selecting subedges points and
        excluding segment points from potential subedge candidates
    :param halfspace_filter:
        Whether the halfspace filtering should be applied
    :param bbox_filter:
        Whether the bounding box filtering should be applied
    :param target_pc_flip:
        Whether the subedge point pairs should be carefully ordered.
        Ignored when `source_pc_sort` is True
    :param source_pc_sort:
        Whether the source and target subedge point pairs should be
        ordered along the same vector
    :param chunk_size: int, float
        Allows mitigating memory use when computing the subedges. If
        `chunk_size > 1`, `edge_index` will be processed into chunks of
        `chunk_size`. If `0 < chunk_size < 1`, then `edge_index` will be
        divided into parts of `edge_index.shape[1] * chunk_size` or less
    :return: edge_index, ST_pairs, ST_uid
        - edge_index: the trimmed version of the input edges
        - ST_pairs: 2xK tensor stacking, for each subedge, the index of
          its source point and the index of its target point
          (presumably indices into `points`, as produced by
          `edge_wise_points` -- TODO confirm)
        - ST_uid: for each subedge, the identifier of the edge it
          belongs to (presumably the column index in the returned
          `edge_index` -- TODO confirm against `edge_wise_points`)
    """
    # Trim the graph
    edge_index = to_trimmed(edge_index)

    # Number of segments
    num_segments = index.max() + 1

    # Recursive call in case chunk is specified. Chunk allows limiting
    # the number of edges processed at once. This might alleviate
    # memory use
    if chunk_size is not None and chunk_size > 0:

        # Recursive call on smaller edge_index chunks
        # NOTE(review): `chunk_size == 1` falls in the 'else' branch
        # below and is hence treated as a ratio of 1, i.e. a single
        # chunk holding all edges
        # NOTE(review): if `edge_index` has no edge and `chunk_size`
        # is fractional, the computed chunk_size is 0 and the
        # `num_chunks` division below raises ZeroDivisionError
        chunk_size = int(chunk_size) if chunk_size > 1 \
            else math.ceil(edge_index.shape[1] * chunk_size)
        num_chunks = math.ceil(edge_index.shape[1] / chunk_size)
        out_list = []
        for i_chunk in range(num_chunks):
            start = i_chunk * chunk_size
            end = (i_chunk + 1) * chunk_size
            out_list.append(subedges(
                points,
                index,
                edge_index[:, start:end],
                ratio=ratio,
                k_min=k_min,
                cycles=cycles,
                pca_on_cpu=pca_on_cpu,
                margin=margin,
                halfspace_filter=halfspace_filter,
                bbox_filter=bbox_filter,
                target_pc_flip=target_pc_flip,
                source_pc_sort=source_pc_sort,
                chunk_size=None))

        # Combine outputs. The edge identifiers returned by each chunk
        # are shifted by the number of edges held in the preceding
        # chunks, so that they remain distinct across chunks once
        # concatenated
        device = points.device
        edge_index = torch.cat([elt[0] for elt in out_list], dim=1)
        ST_pairs = torch.cat([elt[1] for elt in out_list], dim=1)
        size = torch.tensor([o[0].shape[1] for o in out_list], device=device)
        offset = sizes_to_pointers(size[:-1])
        ST_uid = torch.cat([elt[2] + o for elt, o in zip(out_list, offset)])

        return edge_index, ST_pairs, ST_uid

    # Compute the nearest neighbors between superedge segments. This
    # pair of points will be crucial in finding the other level-0
    # points making up the superedge
    _, edge_anchor_idx = scatter_nearest_neighbor(
        points, index, edge_index, cycles=cycles)

    # Compute base vectors based on the anchor points source->target
    # direction
    s_anchor = points[edge_anchor_idx[0]]
    t_anchor = points[edge_anchor_idx[1]]
    anchor_base = base_vectors_3d(t_anchor - s_anchor)

    # Recover the number of points in source and target segments. 's_'
    # and 't_' indicate we are dealing with edge-wise values
    s_size, t_size = index.bincount(minlength=num_segments)[edge_index]

    # Expand the points to point-edge values. That is, the concatenation
    # of all the source --or target-- points for each edge. The
    # corresponding variables are prepended with 'S_' and 'T_' for
    # clarity
    (S_points, S_points_idx, S_uid), (T_points, T_points_idx, T_uid) = \
        edge_wise_points(points, index, edge_index)

    # Local helper function to convert absolute points coordinates to
    # their local edge coordinate system. This system is defined as
    # such: the origin is the source --target-- anchor point, the 1st
    # axis is given by the source->target direction of the anchor
    # points, and the 2nd and 3rd axes are constructed in the orthogonal
    # plane. NB: the base construction has a degree of freedom in
    # rotation around the 1st axis, but we do not care too much about it
    # here
    def to_anchor_base(source=True):
        if source:
            x_size, x_anchor, X_points = s_size, s_anchor, S_points
        else:
            x_size, x_anchor, X_points = t_size, t_anchor, T_points

        # Center the points wrt their anchor
        X_points = X_points - x_anchor.repeat_interleave(x_size, dim=0)

        # Project on the base vectors
        X_proj = []
        for i in range(3):
            v = anchor_base[:, i].repeat_interleave(x_size, dim=0)
            X_proj.append(torch.einsum('nd, nd -> n', X_points, v))

        return torch.vstack(X_proj).T

    # Project points in their local edge coordinate system
    S_points = to_anchor_base(source=True)
    T_points = to_anchor_base(source=False)
    del s_anchor, t_anchor, anchor_base

    # Select points that are in the half-space before their anchor.
    # Since subedge points (level-0 point pairs making up the superedge
    # between two segments) are searched along the nearest-neighbors
    # (i.e. anchor points) direction, this operation aims at dealing with
    # edges located in concave regions of the segment boundaries
    if halfspace_filter:
        in_S_halfspace = S_points[:, 0] <= margin
        in_S_halfspace = idx_preserving_mask(in_S_halfspace, S_uid)
        in_S_halfspace = torch.where(in_S_halfspace)[0]
        S_points = S_points[in_S_halfspace]
        S_points_idx = S_points_idx[in_S_halfspace]
        S_uid = S_uid[in_S_halfspace]
        del in_S_halfspace
        in_T_halfspace = T_points[:, 0] >= -margin
        in_T_halfspace = idx_preserving_mask(in_T_halfspace, T_uid)
        in_T_halfspace = torch.where(in_T_halfspace)[0]
        T_points = T_points[in_T_halfspace]
        T_points_idx = T_points_idx[in_T_halfspace]
        T_uid = T_uid[in_T_halfspace]
        del in_T_halfspace

    # Compute the bbox intersection in the 2nd and 3rd coordinates
    # plane. This is a proxy for computing the intersection of the
    # projection areas of the segments in the 2nd and 3rd coordinates
    # plane. This operation prevents subedge points from lying too far
    # from the source segment's projection on the target segment along
    # the anchor direction (and conversely)
    if bbox_filter:
        s_min, _ = scatter_min(S_points[:, 1:], S_uid, dim=0)
        s_max, _ = scatter_max(S_points[:, 1:], S_uid, dim=0)
        t_min, _ = scatter_min(T_points[:, 1:], T_uid, dim=0)
        t_max, _ = scatter_max(T_points[:, 1:], T_uid, dim=0)
        # The clamping guarantees the bbox extends at least 'margin'
        # on each side of the anchor, which sits at the origin of the
        # local coordinate system
        st_min = torch.max(s_min, t_min).clamp(max=-margin)
        st_max = torch.min(s_max, t_max).clamp(min=margin)
        del s_min, s_max, t_min, t_max

        # Local helper to select points inside the bbox intersection
        def select_in_bbox(source=True):
            if source:
                X_points, X_points_idx, X_uid = S_points, S_points_idx, S_uid
            else:
                X_points, X_points_idx, X_uid = T_points, T_points_idx, T_uid

            in_bbox = (X_points[:, 1:] >= st_min[X_uid]).all(dim=1) & \
                      (X_points[:, 1:] <= st_max[X_uid]).all(dim=1)
            in_bbox = idx_preserving_mask(in_bbox, X_uid)
            in_bbox = torch.where(in_bbox)[0]

            return X_points[in_bbox], X_points_idx[in_bbox], X_uid[in_bbox]

        # Select points inside the bbox intersection
        S_points, S_points_idx, S_uid = select_in_bbox(source=True)
        T_points, T_points_idx, T_uid = select_in_bbox(source=False)

    # Sort points along the edge direction, the first point being the
    # anchor point and subsequent points farther and farther away from
    # the anchor
    _, perm = sparse_sort(S_points[:, 0], S_uid, descending=True)
    S_points = S_points[perm]
    S_points_idx = S_points_idx[perm]
    S_uid = S_uid[perm]
    del perm
    _, perm = sparse_sort(T_points[:, 0], T_uid, descending=False)
    T_points = T_points[perm]
    T_points_idx = T_points_idx[perm]
    T_uid = T_uid[perm]
    del perm

    # Update the number of selected points in the source/target segments
    # and compute the number of points to keep for each edge. The
    # heuristic we use here is: the top ratio points, with a minimum
    # of k_min, within the limits of the cluster
    s_size = S_uid.bincount()
    t_size = T_uid.bincount()
    s_k = (s_size * ratio).long().clamp(min=k_min).clamp(max=s_size)
    t_k = (t_size * ratio).long().clamp(min=k_min).clamp(max=t_size)
    # Source and target must contribute the same number of points to
    # each edge, since they are paired one-to-one in the end
    st_k = torch.min(s_k, t_k)
    del s_k, t_k

    # Select only the first k points for each edge
    S_k_idx = arange_interleave(st_k, start=sizes_to_pointers(s_size[:-1]))
    S_points = S_points[S_k_idx]
    S_points_idx = S_points_idx[S_k_idx]
    S_uid = S_uid[S_k_idx]
    del S_k_idx
    T_k_idx = arange_interleave(st_k, start=sizes_to_pointers(t_size[:-1]))
    T_points = T_points[T_k_idx]
    T_points_idx = T_points_idx[T_k_idx]
    T_uid = T_uid[T_k_idx]
    del T_k_idx

    # Local helper to compute, for each edge, the first eigen vector of
    # the selected subedge points for the source --target,
    # respectively-- segment
    # TODO: scatter_pca is the bottleneck of subedges(), we could
    #  accelerate things by randomly sampling in the clusters
    def first_component(source=True):
        if source:
            X_points, X_uid = S_points, S_uid
        else:
            X_points, X_uid = T_points, T_uid
        # NOTE(review): presumably scatter_pca returns eigenvectors
        # sorted by increasing eigenvalue, making the last one the
        # first principal component -- confirm against scatter_pca
        return scatter_pca(X_points, X_uid, on_cpu=pca_on_cpu)[1][:, :, -1]

    # Compute the first component of the source and target subedge
    # points, to be used to sort the points and eventually build the
    # subedge pair
    s_v = first_component(source=True)
    t_v = first_component(source=False)

    # Flip the target first component direction when needed. This is to
    # limit subedge crossings. This is motivated by the desire to mimic
    # Delaunay's visibility-based edges
    if target_pc_flip and not source_pc_sort:
        T_proj = (T_points * t_v.repeat_interleave(st_k, dim=0)).sum(dim=1)
        s_mean = scatter_mean(S_points, S_uid, dim=0)
        t_min = T_points[scatter_min(T_proj, T_uid, dim=0)[1]]
        st_u = t_min - s_mean
        st_u /= st_u.norm(dim=1).view(-1, 1)
        to_flip = torch.where((s_v * t_v).sum(dim=1) <= (s_v * st_u).sum(dim=1))[0]
        t_v[to_flip] *= -1
    elif source_pc_sort:
        # Sort both source and target points along the source's first
        # component
        t_v = s_v

    # Local helper to sort points along their first component
    def sort_by_first_component(source=True):
        if source:
            X_points, X_points_idx, X_uid, x_v = \
                S_points, S_points_idx, S_uid, s_v
        else:
            X_points, X_points_idx, X_uid, x_v = \
                T_points, T_points_idx, T_uid, t_v

        # Sort points along the first component
        X_points, perm = sparse_sort_along_direction(X_points, X_uid, x_v)

        return X_points, X_points_idx[perm], X_uid[perm]

    # Sort the subedge points along their first component
    S_points, S_points_idx, S_uid = sort_by_first_component(source=True)
    T_points, T_points_idx, T_uid = sort_by_first_component(source=False)

    # Bring the subedge points together to make up the final pairs.
    # Source and target hold the same number of points for each edge
    # (see 'st_k'), so the i-th source point is paired with the i-th
    # target point
    ST_pairs = torch.vstack((S_points_idx, T_points_idx))
    ST_uid = S_uid

    return edge_index, ST_pairs, ST_uid


def to_trimmed(edge_index, edge_attr=None, reduce='mean'):
    """Convert to 'trimmed' graph: same as coalescing with the
    additional constraint that (i, j) and (j, i) edges are duplicates.

    If edge attributes are passed, 'reduce' will indicate how to fuse
    duplicate edges' attributes.

    NB: returned edges are expressed with i<j by default. The input
    `edge_index` is left untouched.

    :param edge_index: 2xE LongTensor
        Edges in `torch_geometric` format
    :param edge_attr: ExC Tensor
        Edge attributes
    :param reduce: str
        Reduction modes supported by `torch_geometric.utils.coalesce`
    :return:
        The trimmed `edge_index` if `edge_attr` is None, else a
        `(edge_index, edge_attr)` tuple holding the trimmed edges and
        their fused attributes
    """
    # Search for undirected edges, i.e. edges with (i,j) and (j,i)
    # both present in edge_index. Flip (j,i) into (i,j) to make them
    # redundant. We operate on a copy, because the masked assignment
    # below writes in place and would otherwise silently reorient the
    # edges of the tensor passed by the caller
    s_larger_t = edge_index[0] > edge_index[1]
    edge_index = edge_index.clone()
    edge_index[:, s_larger_t] = edge_index[:, s_larger_t].flip(0)

    # Sort edges by row and remove duplicates
    if edge_attr is None:
        edge_index = coalesce(edge_index)
    else:
        edge_index, edge_attr = coalesce(
            edge_index, edge_attr=edge_attr, reduce=reduce)

    # Remove self loops
    edge_index, edge_attr = remove_self_loops(
        edge_index, edge_attr=edge_attr)

    if edge_attr is None:
        return edge_index
    return edge_index, edge_attr


def is_trimmed(edge_index, return_trimmed=False):
    """Check if the graph is 'trimmed': same as coalescing with the
    additional constraint that (i, j) and (j, i) edges are duplicates.

    NB: the check compares the number of edges before and after
    trimming. It hence detects duplicate edges and self loops, but not
    the orientation or the ordering of the edges.

    :param edge_index: 2xE LongTensor
        Edges in `torch_geometric` format. Left untouched by the check
    :param return_trimmed: bool
        If True, the trimmed graph will also be returned. Since checking
        if the graph is trimmed requires computing the actual trimmed
        graph, this may save some compute in certain situations
    :return: bool, or (bool, 2xE' LongTensor) if `return_trimmed`
        Whether the graph is trimmed, optionally followed by the
        trimmed edges
    """
    # Trim a copy of the edges: a mere check must not reorient the
    # edges of the tensor passed by the caller as a side effect
    edge_index_trimmed = to_trimmed(edge_index.clone())

    # The graph is trimmed if trimming did not remove any edge
    trimmed = edge_index.shape == edge_index_trimmed.shape

    if return_trimmed:
        return trimmed, edge_index_trimmed
    return trimmed