|
|
import math |
|
|
import torch |
|
|
from torch_scatter import scatter_add, scatter_mean, scatter_min |
|
|
from itertools import combinations_with_replacement |
|
|
from src.utils.edge import edge_wise_points |
|
|
from torch_geometric.utils import coalesce |
|
|
|
|
|
|
|
|
# Public API: scatter-based helpers operating on grouped (indexed) point sets
__all__ = [
    'scatter_mean_weighted', 'scatter_pca', 'scatter_nearest_neighbor',
    'idx_preserving_mask', 'scatter_mean_orientation']
|
|
|
|
|
|
|
|
def scatter_mean_weighted(x, idx, w, dim_size=None):
    """Compute, for each group in ``idx``, the weighted mean of the
    rows of ``x`` using the per-row weights ``w``.

    Groups receiving zero total weight produce a zero mean rather than
    a NaN.
    """
    assert w.ge(0).all(), "Only positive weights are accepted"
    assert w.dim() == idx.dim() == 1, "w and idx should be 1D Tensors"
    assert x.shape[0] == w.shape[0] == idx.shape[0], \
        "Only supports weighted mean along the first dimension"

    # Work on 2D data so the weights broadcast over feature columns
    if x.dim() == 1:
        x = x.view(-1, 1)
    weights = w.float().view(-1, 1)

    # Accumulate the weight sums and the weighted feature sums in a
    # single scatter pass: column 0 carries the weights, the remaining
    # columns the weighted features
    stacked = torch.cat((weights, x * weights), dim=1)
    summed = scatter_add(stacked, idx, dim=0, dim_size=dim_size)

    weight_sum = summed[:, 0]
    feature_sum = summed[:, 1:]

    # Guard against division by zero for groups with all-zero weight
    # (their feature sums are zero too, so their mean stays zero)
    weight_sum[weight_sum == 0] = 1

    return feature_sum / weight_sum.view(-1, 1)
|
|
|
|
|
|
|
|
def scatter_pca(x, idx, on_cpu=True):
    """Scatter implementation for PCA.

    Returns eigenvalues and eigenvectors for each group in idx.
    If x has shape N1xD and idx covers indices in [0, N2], the
    eigenvalues will have shape N2xD and the eigenvectors will
    have shape N2xDxD. The eigenvalues and eigenvectors are
    sorted by increasing eigenvalue.

    :param x: (N1, D) tensor
        Points
    :param idx: (N1) LongTensor
        Group index, for each point
    :param on_cpu: bool
        If True, the eigendecomposition is computed on CPU before
        moving the results back to the input device
    """
    assert idx.dim() == 1
    assert x.dim() == 2
    assert idx.shape[0] == x.shape[0]
    assert x.shape[1] > 1

    d = x.shape[1]
    device = x.device

    # Center each group around its own mean
    mean = scatter_mean(x, idx, dim=0)
    x = x - mean[idx]

    # Only accumulate the upper triangle of the symmetric covariance
    # matrix: one column per (i, j) pair with i <= j
    ij = torch.tensor(
        list(combinations_with_replacement(range(d), 2)), device=device)
    upper_triangle = x[:, ij[:, 0]] * x[:, ij[:, 1]]

    # NOTE(review): normalization by `d` (the dimension) rather than by
    # the per-group point counts scales the eigenvalues but leaves the
    # eigenvectors and eigenvalue ratios unaffected — confirm intended
    upper_triangle = scatter_add(upper_triangle, idx, dim=0) / d
    cov = torch.empty((upper_triangle.shape[0], d, d), device=device)
    cov[:, ij[:, 0], ij[:, 1]] = upper_triangle

    # UPLO='U' means eigh only reads the upper triangle, so the lower
    # triangle of `cov` may remain uninitialized. The CPU path exists
    # because batched eigh of many small matrices is often faster there
    if on_cpu:
        cov = cov.cpu()
        evals, evecs = torch.linalg.eigh(cov, UPLO='U')
        evals = evals.to(device)
        evecs = evecs.to(device)
    else:
        evals, evecs = torch.linalg.eigh(cov, UPLO='U')

    # Groups whose decomposition produced NaNs fall back to unit
    # eigenvalues and the identity basis. Use `d` (not a hardcoded 3)
    # so the fallback is valid for any input dimension
    idx_nan = torch.where(torch.logical_and(
        evals.isnan().any(1), evecs.flatten(1).isnan().any(1)))
    evals[idx_nan] = torch.ones(d, dtype=evals.dtype, device=device)
    evecs[idx_nan] = torch.eye(d, dtype=evecs.dtype, device=device)

    # Clamp numerically-negative eigenvalues to zero: the covariance
    # matrix is positive semi-definite by construction
    evals[torch.where(evals < 0)] = 0

    return evals, evecs
|
|
|
|
|
|
|
|
def scatter_nearest_neighbor(
        points, index, edge_index, cycles=3, chunk_size=None):
    """For each pair of segments indicated in edge_index, find the 2
    closest points between the two segments.

    NB: this is an approximate, iterative process.

    :param points: (N, D) tensor
        Points
    :param index: (N) LongTensor
        Segment index, for each point
    :param edge_index: (2, E) LongTensor
        Segment pairs for which to compute the nearest neighbors
    :param cycles int
        Number of iterations. Starting from a point X in set A, one
        cycle accounts for searching the nearest neighbor, in A, of the
        nearest neighbor of X in set B
    :param chunk_size: int, float
        Allows mitigating memory use when computing the neighbors. If
        `chunk_size > 1`, `edge_index` will be processed into chunks of
        `chunk_size`. If `0 < chunk_size < 1`, then `edge_index` will be
        divided into parts of `edge_index.shape[1] * chunk_size` or less
    :return: candidate (2*E, D) tensor, candidate_idx (2, E) LongTensor
        Source-side candidates come first (rows [0, E) of `candidate`,
        row 0 of `candidate_idx`), target-side candidates second
    """
    assert edge_index.shape == coalesce(edge_index).shape, \
        "Does not support duplicate edges, please coalesce the edges" \
        " before calling this function"

    # Memory-mitigation mode: split the edges into chunks, recurse on
    # each chunk with chunking disabled, then concatenate the results
    if chunk_size is not None and chunk_size > 0:

        # chunk_size > 1 is an absolute chunk size, 0 < chunk_size < 1
        # a fraction of the total number of edges
        chunk_size = int(chunk_size) if chunk_size > 1 \
            else math.ceil(edge_index.shape[1] * chunk_size)
        num_chunks = math.ceil(edge_index.shape[1] / chunk_size)
        out_list = []
        for i_chunk in range(num_chunks):
            start = i_chunk * chunk_size
            end = (i_chunk + 1) * chunk_size
            out_list.append(scatter_nearest_neighbor(
                points, index, edge_index[:, start:end], cycles=cycles,
                chunk_size=None))

        # Per-chunk candidates are (2*E_chunk, D) -> cat along dim 0;
        # per-chunk indices are (2, E_chunk) -> cat along dim 1
        candidate = torch.cat([elt[0] for elt in out_list], dim=0)
        candidate_idx = torch.cat([elt[1] for elt in out_list], dim=1)

        return candidate, candidate_idx

    # Source and target segment of each edge
    s_idx = edge_index[0]
    t_idx = edge_index[1]

    # Gather, for each edge, the points of its source (S_*) and target
    # (T_*) segments. NOTE(review): exact semantics defined in
    # src.utils.edge — presumably X_points are the per-edge segment
    # points (ordered edge by edge), X_points_idx their indices into
    # `points`, and X_uid a per-edge group id; confirm against the helper
    (S_points, S_points_idx, S_uid), (T_points, T_points_idx, T_uid) = \
        edge_wise_points(points, index, edge_index)

    # Initialize each side's candidate with the segment centroid;
    # candidate indices start at -1, meaning "no point selected yet"
    segment_centroid = scatter_mean(points, index, dim=0)
    segment_size = index.bincount()
    s_candidate = segment_centroid[s_idx]
    t_candidate = segment_centroid[t_idx]
    s_candidate_idx = -torch.ones_like(s_idx)
    t_candidate_idx = -torch.ones_like(s_idx)

    def step(source=True):
        # One half-cycle: for each edge, pick the point on one side
        # that is nearest to the other side's current candidate
        if source:
            x_idx, y_candidate, X_points, X_points_idx, X_uid = \
                s_idx, t_candidate, S_points, S_points_idx, S_uid
        else:
            x_idx, y_candidate, X_points, X_points_idx, X_uid = \
                t_idx, s_candidate, T_points, T_points_idx, T_uid

        # Repeat each edge's opposite candidate once per point of this
        # side's segment, so points and candidates align row-wise
        size = segment_size[x_idx]
        Y_candidate = y_candidate.repeat_interleave(size, dim=0)

        # Euclidean distance from each point to the paired candidate
        X_dist = (X_points - Y_candidate).norm(dim=1)

        # Per-edge argmin over distances yields the new candidate
        _, X_argmin = scatter_min(X_dist, X_uid)
        x_candidate_idx = X_points_idx[X_argmin]
        x_candidate = points[x_candidate_idx]

        return x_candidate, x_candidate_idx

    # Alternate target/source refinements; each full cycle updates both
    # candidates of every edge
    for _ in range(cycles):
        t_candidate, t_candidate_idx = step(source=False)
        s_candidate, s_candidate_idx = step(source=True)

    # Stack source results first, then target results
    candidate = torch.vstack((s_candidate, t_candidate))
    candidate_idx = torch.vstack((s_candidate_idx, t_candidate_idx))

    return candidate, candidate_idx
|
|
|
|
|
|
|
|
def idx_preserving_mask(mask, idx, dim=0):
    """Helper to pass a boolean mask and an index, to make sure indexing
    using the mask will not entirely discard all elements of index.

    Groups of `idx` in which `mask` selects nothing are force-kept in
    their entirety, so every group survives the masking.
    """
    # Count, per group, how many elements the mask keeps; groups with a
    # zero count would disappear if the mask were applied as-is
    kept_per_group = scatter_add(mask.float(), idx, dim=dim)
    group_is_empty = kept_per_group == 0

    # Re-enable all elements of the groups the mask would empty
    return mask | group_is_empty[idx]
|
|
|
|
|
|
|
|
def scatter_mean_orientation(orientation, idx):
    """Scatter computation of the mean normal orientation per group.

    Normals are defined up to a sign: only the orientation, not the
    sense, matters. Naively averaging sign-ambiguous vectors within a
    group may cancel them out. This method estimates a per-group mean
    orientation, expressed in the Z+ halfspace by default.

    :param orientation: (N, D) tensor
        Orientations vectors. Do not need to be normalized but are
        assumed to be expressed with 0 as their origin
    :param idx: (N) LongTensor
        Group index, for each vector
    """
    epsilon = 1e-4

    # Detach and copy so the input tensor is never modified in place
    vectors = orientation.detach().clone()

    # Normalize each vector (epsilon guards against division by zero)
    # and clamp components into [-1, 1]
    norms = vectors.norm(dim=1).view(-1, 1).add_(epsilon)
    vectors = (vectors / norms).clamp(min=-1, max=1)

    # Elevation angle of each vector with respect to the XY plane
    elevation = vectors[:, 2].arcsin()

    # Groups with low mean elevation are near-horizontal; these are the
    # ones at risk of sign cancellation when averaging
    mean_elevation = scatter_mean(elevation, idx, dim=0)
    near_horizontal = (mean_elevation < torch.pi / 4)[idx]

    # Take each group's lowest-elevation vector as the reference and
    # flag vectors pointing away from it (negative dot product)
    _, ref = scatter_min(elevation, idx, dim=0)
    opposing = (vectors * vectors[ref[idx]]).sum(dim=1) < 0

    # Flip opposing vectors in near-horizontal groups before averaging
    vectors[near_horizontal & opposing] *= -1

    # Average per group, then re-normalize and clamp
    mean = scatter_mean(vectors, idx, dim=0)
    mean = (mean / mean.norm(dim=1).view(-1, 1).add_(epsilon)).clamp(
        min=-1, max=1)

    # Express all mean orientations in the Z+ halfspace
    mean[mean[:, -1] < 0] *= -1

    return mean
|
|
|