| | import torch |
| | import math |
| | from torch_scatter import scatter_min, scatter_max, scatter_mean |
| | from torch_geometric.utils import coalesce, remove_self_loops |
| | from torch_geometric.nn.pool.consecutive import consecutive_cluster |
| | from src.utils.tensor import arange_interleave |
| | from src.utils.geometry import base_vectors_3d |
| | from src.utils.sparse import sizes_to_pointers, sparse_sort, \ |
| | sparse_sort_along_direction |
| | from src.utils.scatter import scatter_pca, scatter_nearest_neighbor, \ |
| | idx_preserving_mask |
| | from src.utils.edge import edge_wise_points |
| |
|
| | __all__ = [ |
| | 'is_pyg_edge_format', 'isolated_nodes', 'edge_to_superedge', 'subedges', |
| | 'to_trimmed', 'is_trimmed'] |
| |
|
| |
|
def is_pyg_edge_format(edge_index):
    """Check whether `edge_index` follows the PyTorch Geometric graph
    edge format, ie is a [2, N] torch.LongTensor.
    """
    # Guard clause: anything that is not a torch.Tensor fails outright
    if not isinstance(edge_index, torch.Tensor):
        return False
    has_edge_shape = edge_index.dim() == 2 and edge_index.shape[0] == 2
    return has_edge_shape and edge_index.dtype == torch.long
| |
|
| |
|
def isolated_nodes(edge_index, num_nodes=None):
    """Return a boolean mask of size `num_nodes` indicating which nodes
    have no edge in `edge_index`.

    :param edge_index: 2xE LongTensor
        Edges in `torch_geometric` format
    :param num_nodes: int, optional
        Number of nodes. If None, inferred as `edge_index.max() + 1`
        (0 for an empty graph)
    :return: BoolTensor of shape [num_nodes], True for isolated nodes
    """
    assert is_pyg_edge_format(edge_index)

    # Infer the number of nodes from the edges if not provided. The
    # `numel()` guard avoids calling `.max()` on an empty tensor (which
    # raises), and `int()` ensures a plain Python size for `torch.ones`
    if num_nodes is None:
        num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0

    device = edge_index.device
    mask = torch.ones(num_nodes, dtype=torch.bool, device=device)
    # Any node appearing in at least one edge is not isolated
    mask[edge_index.unique()] = False
    return mask
| |
|
| |
|
def edge_to_superedge(edges, super_index, edge_attr=None):
    """Convert point-level edges into superedges between clusters,
    based on the point-to-cluster indexing `super_index`. Optionally,
    `edge_attr` may carry point-edge attributes, which are returned
    filtered and ordered consistently with the superedges.

    NB: (i, j) and (j, i) superedges are treated as identical. Output
    superedges are expressed with i <= j.

    :param edges: 2xE LongTensor of point-level edges
    :param super_index: LongTensor mapping each point to its cluster
    :param edge_attr: optional ExC Tensor of point-edge attributes
    :return: (se, se_id, edges_inter, edge_attr) where `se` holds the
        superedges, `se_id` maps each inter-cluster point edge to its
        superedge, and `edges_inter` / `edge_attr` are the filtered
        point edges and attributes
    """
    # Lift point edges to cluster space and keep only inter-cluster
    # edges: intra-cluster edges carry no superedge information
    cluster_pair = super_index[edges]
    is_inter = cluster_pair[0] != cluster_pair[1]
    edges_inter = edges[:, is_inter]
    if edge_attr is not None:
        edge_attr = edge_attr[is_inter]
    cluster_pair = cluster_pair[:, is_inter]

    # Canonicalize each pair so the smaller cluster index comes first,
    # making (i, j) and (j, i) indistinguishable
    cluster_pair = torch.stack((
        torch.minimum(cluster_pair[0], cluster_pair[1]),
        torch.maximum(cluster_pair[0], cluster_pair[1])))

    # Encode each (i, j) pair into a single unique integer, then
    # compress those ids into consecutive superedge indices
    base = max(cluster_pair[0].max(), cluster_pair[1].max()) + 1
    raw_id = cluster_pair[0] * base + cluster_pair[1]
    se_id, perm = consecutive_cluster(raw_id)
    se = cluster_pair[:, perm]

    return se, se_id, edges_inter, edge_attr
| |
|
| |
|
def subedges(
        points,
        index,
        edge_index,
        ratio=0.2,
        k_min=20,
        cycles=3,
        pca_on_cpu=True,
        margin=0.2,
        halfspace_filter=True,
        bbox_filter=True,
        target_pc_flip=True,
        source_pc_sort=False,
        chunk_size=None):
    """Compute the subedges making up each edge between segments. These
    are needed for superedge features computation. This approach relies
    on heuristics to avoid the Delaunay triangulation or any other O(N²)
    operation.

    NB: the input edges will be trimmed (see `to_trimmed`) in the first
    place and the returned edge_index will reflect this change. This is
    because subedge computation relies on costly operations. To save
    compute and memory, we only build subedges for the trimmed graph.

    :param points:
        Level-0 points
    :param index:
        Index of the segment each point belongs to
    :param edge_index:
        Edges of the graph between segments
    :param ratio:
        Maximum ratio of a segment's points than can be used in a
        superedge's subedges
    :param k_min:
        Minimum of subedges per superedge
    :param cycles:
        Number of iterations for nearest neighbor search between
        segments
    :param pca_on_cpu:
        Whether PCA should be computed on CPU if need be. Should be kept
        as True
    :param margin:
        Tolerance margin used for selecting subedges points and
        excluding segment points from potential subedge candidates
    :param halfspace_filter:
        Whether the halfspace filtering should be applied
    :param bbox_filter:
        Whether the bounding box filtering should be applied
    :param target_pc_flip:
        Whether the subedge point pairs should be carefully ordered
    :param source_pc_sort:
        Whether the source and target subedge point pairs should be
        ordered along the same vector
    :param chunk_size: int, float
        Allows mitigating memory use when computing the subedges. If
        `chunk_size > 1`, `edge_index` will be processed into chunks of
        `chunk_size`. If `0 < chunk_size < 1`, then `edge_index` will be
        divided into parts of `edge_index.shape[1] * chunk_size` or less
    :return: (edge_index, ST_pairs, ST_uid) — the trimmed edges, the
        [2, K] point-index pairs forming the subedges, and the index of
        the trimmed edge each subedge belongs to
    """
    # Trim the graph first: subedges are only built for the trimmed
    # (deduplicated, self-loop-free, i<j) edges
    edge_index = to_trimmed(edge_index)

    # Number of segments. Assumes `index` holds contiguous segment
    # indices starting at 0 — TODO confirm against callers
    num_segments = index.max() + 1

    # Optionally process the edges in chunks to mitigate memory use,
    # recursing with chunk_size=None on each slice, then aggregating
    if chunk_size is not None and chunk_size > 0:

        # chunk_size > 1 is an absolute edge count; 0 < chunk_size < 1
        # is a fraction of the total number of edges
        chunk_size = int(chunk_size) if chunk_size > 1 \
            else math.ceil(edge_index.shape[1] * chunk_size)
        num_chunks = math.ceil(edge_index.shape[1] / chunk_size)
        out_list = []
        for i_chunk in range(num_chunks):
            start = i_chunk * chunk_size
            end = (i_chunk + 1) * chunk_size
            out_list.append(subedges(
                points,
                index,
                edge_index[:, start:end],
                ratio=ratio,
                k_min=k_min,
                cycles=cycles,
                pca_on_cpu=pca_on_cpu,
                margin=margin,
                halfspace_filter=halfspace_filter,
                bbox_filter=bbox_filter,
                target_pc_flip=target_pc_flip,
                source_pc_sort=source_pc_sort,
                chunk_size=None))

        # Aggregate chunk outputs. Each chunk's ST_uid indexes edges
        # local to that chunk, so uids are shifted by the cumulative
        # number of edges in the preceding chunks
        device = points.device
        edge_index = torch.cat([elt[0] for elt in out_list], dim=1)
        ST_pairs = torch.cat([elt[1] for elt in out_list], dim=1)
        size = torch.tensor([o[0].shape[1] for o in out_list], device=device)
        offset = sizes_to_pointers(size[:-1])
        ST_uid = torch.cat([elt[2] + o for elt, o in zip(out_list, offset)])

        return edge_index, ST_pairs, ST_uid

    # For each edge, search a pair of mutually-close anchor points, one
    # in the source segment and one in the target segment
    _, edge_anchor_idx = scatter_nearest_neighbor(
        points, index, edge_index, cycles=cycles)

    # For each edge, build an orthonormal base whose first vector runs
    # from the source anchor to the target anchor
    s_anchor = points[edge_anchor_idx[0]]
    t_anchor = points[edge_anchor_idx[1]]
    anchor_base = base_vectors_3d(t_anchor - s_anchor)

    # Number of points in each edge's source and target segments
    s_size, t_size = index.bincount(minlength=num_segments)[edge_index]

    # Expand to edge-wise point sets: for each edge, the points of its
    # source (S) and target (T) segments, their index in `points`, and
    # the index (uid) of the edge they relate to
    (S_points, S_points_idx, S_uid), (T_points, T_points_idx, T_uid) = \
        edge_wise_points(points, index, edge_index)

    # Express the edge-wise points in their edge's anchor base, with
    # the corresponding anchor as origin. The first coordinate then
    # measures the position along the source->target anchor direction
    def to_anchor_base(source=True):
        if source:
            x_size, x_anchor, X_points = s_size, s_anchor, S_points
        else:
            x_size, x_anchor, X_points = t_size, t_anchor, T_points

        # Center each edge's points on its anchor
        X_points = X_points - x_anchor.repeat_interleave(x_size, dim=0)

        # Project the centered points on each base vector
        X_proj = []
        for i in range(3):
            v = anchor_base[:, i].repeat_interleave(x_size, dim=0)
            X_proj.append(torch.einsum('nd, nd -> n', X_points, v))

        return torch.vstack(X_proj).T

    # Convert both point sets, then release anchor data
    S_points = to_anchor_base(source=True)
    T_points = to_anchor_base(source=False)
    del s_anchor, t_anchor, anchor_base

    # Halfspace filtering: only keep source points on the source side
    # of the anchor direction (first coordinate <= margin) and target
    # points on the target side (>= -margin). `idx_preserving_mask`
    # presumably ensures every uid keeps at least one point so no edge
    # vanishes — TODO confirm
    if halfspace_filter:
        in_S_halfspace = S_points[:, 0] <= margin
        in_S_halfspace = idx_preserving_mask(in_S_halfspace, S_uid)
        in_S_halfspace = torch.where(in_S_halfspace)[0]
        S_points = S_points[in_S_halfspace]
        S_points_idx = S_points_idx[in_S_halfspace]
        S_uid = S_uid[in_S_halfspace]
        del in_S_halfspace
        in_T_halfspace = T_points[:, 0] >= -margin
        in_T_halfspace = idx_preserving_mask(in_T_halfspace, T_uid)
        in_T_halfspace = torch.where(in_T_halfspace)[0]
        T_points = T_points[in_T_halfspace]
        T_points_idx = T_points_idx[in_T_halfspace]
        T_uid = T_uid[in_T_halfspace]
        del in_T_halfspace

    # Bounding box filtering: in the 2 coordinates orthogonal to the
    # anchor direction, keep points inside the intersection of the
    # per-edge source and target bounding boxes, inflated so the box
    # spans at least [-margin, margin]
    if bbox_filter:
        s_min, _ = scatter_min(S_points[:, 1:], S_uid, dim=0)
        s_max, _ = scatter_max(S_points[:, 1:], S_uid, dim=0)
        t_min, _ = scatter_min(T_points[:, 1:], T_uid, dim=0)
        t_max, _ = scatter_max(T_points[:, 1:], T_uid, dim=0)
        st_min = torch.max(s_min, t_min).clamp(max=-margin)
        st_max = torch.min(s_max, t_max).clamp(min=margin)
        del s_min, s_max, t_min, t_max

        # Select the points of either side falling inside the bbox
        def select_in_bbox(source=True):
            if source:
                X_points, X_points_idx, X_uid = S_points, S_points_idx, S_uid
            else:
                X_points, X_points_idx, X_uid = T_points, T_points_idx, T_uid

            in_bbox = (X_points[:, 1:] >= st_min[X_uid]).all(dim=1) & \
                (X_points[:, 1:] <= st_max[X_uid]).all(dim=1)
            in_bbox = idx_preserving_mask(in_bbox, X_uid)
            in_bbox = torch.where(in_bbox)[0]

            return X_points[in_bbox], X_points_idx[in_bbox], X_uid[in_bbox]

        S_points, S_points_idx, S_uid = select_in_bbox(source=True)
        T_points, T_points_idx, T_uid = select_in_bbox(source=False)

    # Sort, per edge, source points by decreasing and target points by
    # increasing first coordinate, ie by proximity to the other segment
    # along the anchor direction
    _, perm = sparse_sort(S_points[:, 0], S_uid, descending=True)
    S_points = S_points[perm]
    S_points_idx = S_points_idx[perm]
    S_uid = S_uid[perm]
    del perm
    _, perm = sparse_sort(T_points[:, 0], T_uid, descending=False)
    T_points = T_points[perm]
    T_points_idx = T_points_idx[perm]
    T_uid = T_uid[perm]
    del perm

    # Number k of subedge points to keep for each edge: `ratio` of the
    # remaining candidates, at least `k_min`, capped by the candidate
    # count on each side, then the min of both sides so S and T contain
    # the same number of points per edge
    s_size = S_uid.bincount()
    t_size = T_uid.bincount()
    s_k = (s_size * ratio).long().clamp(min=k_min).clamp(max=s_size)
    t_k = (t_size * ratio).long().clamp(min=k_min).clamp(max=t_size)
    st_k = torch.min(s_k, t_k)
    del s_k, t_k

    # Keep only each edge's first k candidates (closest to the other
    # segment thanks to the sort above)
    S_k_idx = arange_interleave(st_k, start=sizes_to_pointers(s_size[:-1]))
    S_points = S_points[S_k_idx]
    S_points_idx = S_points_idx[S_k_idx]
    S_uid = S_uid[S_k_idx]
    del S_k_idx
    T_k_idx = arange_interleave(st_k, start=sizes_to_pointers(t_size[:-1]))
    T_points = T_points[T_k_idx]
    T_points_idx = T_points_idx[T_k_idx]
    T_uid = T_uid[T_k_idx]
    del T_k_idx

    # First principal component of each edge's source / target point
    # set, used as the direction along which the points will be paired.
    # `scatter_pca(...)[1]` presumably returns eigenvectors sorted by
    # increasing eigenvalue, hence `[:, :, -1]` — TODO confirm
    def first_component(source=True):
        if source:
            X_points, X_uid = S_points, S_uid
        else:
            X_points, X_uid = T_points, T_uid
        return scatter_pca(X_points, X_uid, on_cpu=pca_on_cpu)[1][:, :, -1]

    # Principal direction of each edge's source and target sets
    s_v = first_component(source=True)
    t_v = first_component(source=False)

    # Orient the target component consistently with the source one so
    # the positional pairing below matches facing points, flipping t_v
    # where needed. With `source_pc_sort`, both sides simply share the
    # source direction instead
    if target_pc_flip and not source_pc_sort:
        T_proj = (T_points * t_v.repeat_interleave(st_k, dim=0)).sum(dim=1)
        s_mean = scatter_mean(S_points, S_uid, dim=0)
        t_min = T_points[scatter_min(T_proj, T_uid, dim=0)[1]]
        st_u = t_min - s_mean
        st_u /= st_u.norm(dim=1).view(-1, 1)
        to_flip = torch.where((s_v * t_v).sum(dim=1) <= (s_v * st_u).sum(dim=1))[0]
        t_v[to_flip] *= -1
    elif source_pc_sort:
        t_v = s_v

    # Sort each edge's points along its (possibly flipped) principal
    # direction, so the n-th source point pairs with the n-th target one
    def sort_by_first_component(source=True):
        if source:
            X_points, X_points_idx, X_uid, x_v = \
                S_points, S_points_idx, S_uid, s_v
        else:
            X_points, X_points_idx, X_uid, x_v = \
                T_points, T_points_idx, T_uid, t_v

        # Per-edge sort of the points along the direction x_v
        X_points, perm = sparse_sort_along_direction(X_points, X_uid, x_v)

        return X_points, X_points_idx[perm], X_uid[perm]

    S_points, S_points_idx, S_uid = sort_by_first_component(source=True)
    T_points, T_points_idx, T_uid = sort_by_first_component(source=False)

    # Pair source and target points positionally to form the subedges;
    # S_uid == T_uid here since both sides hold st_k points per edge
    ST_pairs = torch.vstack((S_points_idx, T_points_idx))
    ST_uid = S_uid

    return edge_index, ST_pairs, ST_uid
| |
|
| |
|
def to_trimmed(edge_index, edge_attr=None, reduce='mean'):
    """Convert to 'trimmed' graph: same as coalescing with the
    additional constraint that (i, j) and (j, i) edges are duplicates.

    If edge attributes are passed, 'reduce' will indicate how to fuse
    duplicate edges' attributes.

    NB: returned edges are expressed with i<j by default.

    :param edge_index: 2xE LongTensor
        Edges in `torch_geometric` format
    :param edge_attr: ExC Tensor
        Edge attributes
    :param reduce: str
        Reduction modes supported by `torch_geometric.utils.coalesce`
    :return: trimmed edge_index, or (edge_index, edge_attr) if
        `edge_attr` was provided
    """
    # Work on a copy: the original implementation flipped columns of
    # the caller's tensor in place, silently reordering edges for any
    # caller reusing `edge_index` afterwards (e.g. `is_trimmed`)
    edge_index = edge_index.clone()

    # Express all edges with i <= j, so (j, i) duplicates of (i, j)
    # become exact duplicates that coalescing can fuse
    s_larger_t = edge_index[0] > edge_index[1]
    edge_index[:, s_larger_t] = edge_index[:, s_larger_t].flip(0)

    # Sort and remove duplicate edges, fusing attributes with `reduce`
    if edge_attr is None:
        edge_index = coalesce(edge_index)
    else:
        edge_index, edge_attr = coalesce(
            edge_index, edge_attr=edge_attr, reduce=reduce)

    # Drop self-loops, which carry no inter-node information
    edge_index, edge_attr = remove_self_loops(
        edge_index, edge_attr=edge_attr)

    if edge_attr is None:
        return edge_index
    return edge_index, edge_attr
| |
|
| |
|
def is_trimmed(edge_index, return_trimmed=False):
    """Check if the graph is 'trimmed': same as coalescing with the
    additional constraint that (i, j) and (j, i) edges are duplicates.

    :param edge_index: 2xE LongTensor
        Edges in `torch_geometric` format
    :param return_trimmed: bool
        If True, the trimmed graph will also be returned. Since checking
        if the graph is trimmed requires computing the actual trimmed
        graph, this may save some compute in certain situations
    :return: bool, or (bool, trimmed edge_index) if `return_trimmed`
    """
    # Trimming never adds edges, so the graph was already trimmed iff
    # trimming leaves the edge count unchanged
    trimmed_edge_index = to_trimmed(edge_index)
    same_shape = edge_index.shape == trimmed_edge_index.shape
    return (same_shape, trimmed_edge_index) if return_trimmed else same_shape
| |
|