|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Tuple, TYPE_CHECKING, Union
|
|
|
import torch
|
|
|
|
|
|
|
|
|
def masked_gather(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    Helper function for torch.gather to collect the points at
    the given indices in idx where some of the indices might be -1 to
    indicate padding. These indices are first replaced with 0.
    Then the points are gathered after which the padded values
    are set to 0.0.

    Args:
        points: (N, P, D) tensor of points.
        idx: (N, K) or (N, P, K) long tensor of indices into the second
            dimension of points, where some indices are -1 to indicate
            padding. idx itself is not modified.

    Returns:
        selected_points: (N, K, D) tensor when idx is (N, K), or
            (N, P, K, D) tensor when idx is (N, P, K), holding the points
            at the given indices and 0.0 wherever the index was -1.

    Raises:
        ValueError: if points and idx have different batch sizes, if points
            is not 3-dimensional, or if idx is neither 2- nor 3-dimensional.
    """
    if len(idx) != len(points):
        raise ValueError("points and idx must have the same batch dimension")

    # Unpacking also rejects points that are not 3-dimensional.
    _, _, D = points.shape

    if idx.ndim == 3:
        K = idx.shape[2]
        # Broadcast points along a new K axis (a view, no copy) so a gather
        # over dim=1 picks points[n, idx[n, p, k], :] for every (p, k).
        points = points[:, :, None, :].expand(-1, -1, K, -1)
    elif idx.ndim != 2:
        raise ValueError("idx format is not supported %s" % repr(idx.shape))

    # Build the padding mask at idx's own resolution instead of (…, D).
    pad_mask = idx.eq(-1)
    # -1 is not a valid gather index; map it to 0 (those rows are zeroed
    # afterwards). Out-of-place so the caller's idx is left untouched.
    safe_idx = idx.masked_fill(pad_mask, 0)
    # Broadcast the index over the feature axis as a view; gather accepts
    # stride-0 index tensors, so no N*K*D copy is materialized.
    idx_expanded = safe_idx.unsqueeze(-1).expand(*idx.shape, D)

    selected_points = points.gather(dim=1, index=idx_expanded)

    # Zero padded slots; the trailing singleton broadcasts the mask over D.
    return selected_points.masked_fill(pad_mask.unsqueeze(-1), 0.0)