import torch
from torch_geometric.nn.pool.consecutive import consecutive_cluster

from src.utils.sparse import indices_to_pointers
from src.utils.tensor import arange_interleave
| |
|
| |
|
# Names exported by `from <this module> import *`
__all__ = [
    'edge_index_to_uid',
    'edge_wise_points',
]
| |
|
| |
|
def edge_index_to_uid(edge_index):
    """Compute consecutive unique identifiers for the edges. This may be
    needed for scatter operations.

    Two edges receive the same identifier if and only if they share the
    same (source, target) pair. Identifiers are consecutive integers
    starting at 0, ordered by increasing (source, target).

    :param edge_index: (2, E) integer tensor
        Edges, source indices in the first row and target indices in
        the second row. Indices are expected to be non-negative
    :return: (E) LongTensor
        Consecutive identifier of each edge. Empty if there are no edges
    """
    assert edge_index.dim() == 2
    assert edge_index.shape[0] == 2

    # No edges: nothing to identify. Calling `max()` on an empty tensor
    # would raise, so return an empty result directly
    if edge_index.shape[1] == 0:
        return edge_index.new_empty(0, dtype=torch.long)

    # Promote to int64 so the pair encoding below cannot overflow for
    # narrower integer dtypes (int32 overflows past ~46k nodes)
    edge_index = edge_index.long()
    source = edge_index[0]
    target = edge_index[1]

    # Encode each (source, target) pair as a single integer. Since
    # `num_nodes` is strictly larger than any index, this encoding is
    # injective on non-negative pairs and preserves lexicographic order
    num_nodes = edge_index.max() + 1
    edge_uid = source * num_nodes + target

    # Map the sparse pair codes to consecutive identifiers in [0, K)
    return torch.unique(edge_uid, sorted=True, return_inverse=True)[1]
| |
|
| |
|
def edge_wise_points(points, index, edge_index):
    """Given a graph of point segments, compute the concatenation of
    points belonging to either source or target segments for each edge
    of the segment graph. This operation arises when dealing with
    pairwise relationships between point segments.

    Warning: the output tensors might be memory-intensive

    :param points: (N, D) tensor
        Points
    :param index: (N) LongTensor
        Segment index, for each point
    :param edge_index: (2, E) LongTensor
        Edges of the segment graph, with non-negative segment indices
    :return: two tuples `(S_points, S_points_idx, S_uid)` and
        `(T_points, T_points_idx, T_uid)`, for source and target
        segments respectively. For each edge `e = (s, t)`, the `S_*`
        tensors hold one row per point of segment `s`: the row of
        `points`, its position in `points`, and the identifier of `e`
        as computed by `edge_index_to_uid`. The `T_*` tensors hold the
        same for the points of segment `t`. Rows are grouped edge by
        edge. If `edge_index` has no edges, all outputs are empty.
    """
    assert points.dim() == 2
    assert index.dim() == 1
    assert points.shape[0] == index.shape[0]
    assert edge_index.dim() == 2
    assert edge_index.shape[0] == 2

    # An empty segment graph has no point pairs. Return empty outputs
    # directly, since `max()` below would raise on an empty tensor
    if edge_index.shape[1] == 0:
        def empty():
            return (
                points[:0].clone(),
                edge_index.new_empty(0).long(),
                edge_index.new_empty(0).long())
        return empty(), empty()

    # Edges must reference existing segments. Negative indices would
    # otherwise silently wrap around when indexing per-segment tensors
    assert edge_index.min() >= 0
    assert edge_index.max() <= index.max()

    # Segments in the first row of `edge_index` are 'source' segments,
    # those in the second row are 'target' segments
    s_idx = edge_index[0]
    t_idx = edge_index[1]

    # Identifier of each edge; duplicate (source, target) pairs share
    # the same identifier
    uid = edge_index_to_uid(edge_index)

    # CSR-like description of the segments: the points of segment `i`
    # are presumably `order[pointers[i]:pointers[i + 1]]` (see
    # `indices_to_pointers` to confirm)
    pointers, order = indices_to_pointers(index)

    # Number of points in each segment
    segment_size = index.bincount()

    def expand(x_idx):
        """For each edge `e`, gather all points of segment `x_idx[e]`,
        along with their indices and the identifier of `e`.
        """
        size = segment_size[x_idx]
        start = pointers[:-1][x_idx]
        # Point indices of every (edge, point) pair, grouped by edge
        X_points_idx = order[arange_interleave(size, start=start)]
        X_points = points[X_points_idx]
        # Repeat each edge identifier once per point of its segment
        X_uid = uid.repeat_interleave(size, dim=0)
        return X_points, X_points_idx, X_uid

    S_points, S_points_idx, S_uid = expand(s_idx)
    T_points, T_points_idx, T_uid = expand(t_idx)

    return (S_points, S_points_idx, S_uid), (T_points, T_points_idx, T_uid)
| |
|