English
Shanci's picture
Upload folder using huggingface_hub
26225c5 verified
import torch
from torch_geometric.nn.pool.consecutive import consecutive_cluster

from src.utils.sparse import indices_to_pointers
from src.utils.tensor import arange_interleave
# Public API of this module: the names exported by `from ... import *`
__all__ = ['edge_index_to_uid', 'edge_wise_points']
def edge_index_to_uid(edge_index):
    """Compute consecutive unique identifiers for the edges. This may be
    needed for scatter operations.

    Each `(source, target)` pair is mapped to an integer in
    `[0, num_unique_edges)`. Duplicate edges share the same identifier,
    and identifiers follow the lexicographic `(source, target)` order of
    the edges.

    :param edge_index: (2, E) integer tensor
        Edges, first row holding sources and second row targets. Indices
        are expected to be non-negative
    :return: (E) LongTensor
        Consecutive identifier of each edge
    """
    assert edge_index.dim() == 2
    assert edge_index.shape[0] == 2

    # An empty edge set has no identifiers. Handle it up front because
    # `max()` raises on an empty tensor
    if edge_index.shape[1] == 0:
        return torch.empty(0, dtype=torch.long, device=edge_index.device)

    source = edge_index[0]
    target = edge_index[1]

    # Encode each (source, target) pair as a single integer. Using the
    # number of nodes as base makes the encoding injective for
    # non-negative indices and preserves the lexicographic order of the
    # pairs. NB: the product may overflow for extremely large graphs
    # (roughly > 3e9 nodes with int64)
    num_nodes = edge_index.max() + 1
    edge_key = source * num_nodes + target

    # `return_inverse` gives, for every key, its rank among the sorted
    # unique keys, i.e. a consecutive identifier. This equals the first
    # output of torch_geometric's `consecutive_cluster`, without the
    # extra permutation that function also computes and we do not need
    edge_uid = torch.unique(edge_key, sorted=True, return_inverse=True)[1]
    return edge_uid
def edge_wise_points(points, index, edge_index):
    """Given a graph of point segments, compute the concatenation of
    points belonging to either source or target segments for each edge
    of the segment graph. This operation arises when dealing with
    pairwise relationships between point segments.

    Warning: the output tensors might be memory-intensive

    :param points: (N, D) tensor
        Points
    :param index: (N) LongTensor
        Segment index, for each point
    :param edge_index: (2, E) LongTensor
        Edges of the segment graph. Segment indices must be non-negative
        and not exceed `index.max()`
    :return: two tuples `(S_points, S_points_idx, S_uid)` and
        `(T_points, T_points_idx, T_uid)`, for the source and target
        segments respectively. For every edge, the points of its source
        (resp. target) segment are gathered and concatenated:
        - `X_points`: (M, D) tensor of the gathered points
        - `X_points_idx`: (M) tensor of their row indices in `points`
        - `X_uid`: (M) tensor holding, for each gathered point, the
          identifier of the edge it was gathered for, as computed by
          `edge_index_to_uid` (duplicate edges share an identifier)
        If `edge_index` holds no edges, all outputs are empty (M = 0)
    """
    assert points.dim() == 2
    assert index.dim() == 1
    assert points.shape[0] == index.shape[0]
    assert edge_index.dim() == 2
    assert edge_index.shape[0] == 2

    # With no edges there is nothing to gather. Return empty outputs
    # right away, because the range checks below call `max()`, which
    # raises on an empty tensor
    if edge_index.shape[1] == 0:
        def empty():
            return (
                points.new_empty((0, points.shape[1])),
                edge_index.new_empty(0),
                edge_index.new_empty(0))
        return empty(), empty()

    # Segment indices used as edge endpoints must exist in `index`.
    # Negative values would otherwise silently wrap around when indexing
    # the per-segment tensors below
    assert edge_index.min() >= 0
    assert edge_index.max() <= index.max()

    # We define the segments in the first row of edge_index as 'source'
    # segments, while the elements of the second row are 'target'
    # segments. The corresponding variables are prepended with 's_' and
    # 't_' for clarity
    s_idx = edge_index[0]
    t_idx = edge_index[1]

    # Compute consecutive unique identifiers for the edges
    uid = edge_index_to_uid(edge_index)

    # Compute the pointers and ordering to express the segments and the
    # points they hold in CSR format. `pointers[:-1]` presumably holds
    # the start offset of each segment in `order` (CSR convention)
    pointers, order = indices_to_pointers(index)
    segment_start = pointers[:-1]

    # Compute the size of each segment
    segment_size = index.bincount()

    # Expand the edge variables to point-edge values. That is, the
    # concatenation of all the source -or target- points for each edge.
    # The corresponding variables are prepended with 'S_' and 'T_' for
    # clarity
    def expand(x_idx):
        size = segment_size[x_idx]
        start = segment_start[x_idx]
        X_points_idx = order[arange_interleave(size, start=start)]
        X_points = points[X_points_idx]
        # Each edge identifier is repeated once per gathered point
        X_uid = uid.repeat_interleave(size, dim=0)
        return X_points, X_points_idx, X_uid

    return expand(s_idx), expand(t_idx)