|
|
|
|
|
|
|
|
|
|
|
|
| from typing import Optional, Tuple, TYPE_CHECKING, Union
|
| import torch
|
|
|
|
|
def masked_gather(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    Gather rows of ``points`` along dim 1 using indices that may contain
    -1 as a padding marker.

    Padded indices are temporarily redirected to 0 so that the gather is
    valid; the corresponding output entries are then overwritten with 0.0.
    The input ``idx`` tensor itself is never modified.

    Args:
        points: (N, P, D) tensor of points.
        idx: (N, K) or (N, P, K) long tensor of indices into dim 1 of
            ``points``, where -1 marks a padded entry.

    Returns:
        Tensor with the shape of ``idx`` plus a trailing dimension of size
        D: (N, K, D) for 2-D ``idx`` and (N, P, K, D) for 3-D ``idx``.
        Entries that correspond to padded indices are 0.0.

    Raises:
        ValueError: if the batch sizes of ``points`` and ``idx`` differ,
            if ``points`` is not 3-dimensional, or if ``idx`` is neither
            2- nor 3-dimensional.
    """
    if len(idx) != len(points):
        raise ValueError("points and idx must have the same batch dimension")

    # Unpacking three values also rejects a `points` tensor whose rank is
    # not 3 (tuple unpacking raises ValueError).
    _, _, feat_dim = points.shape

    if idx.ndim not in (2, 3):
        raise ValueError("idx format is not supported %s" % repr(idx.shape))

    # Add a trailing axis so the index can be broadcast over the D features.
    gather_index = idx.unsqueeze(-1)
    if idx.ndim == 3:
        num_per_point = idx.shape[2]
        gather_index = gather_index.expand(-1, -1, -1, feat_dim)
        # Give points a matching K axis so gather sees tensors of equal rank.
        points = points.unsqueeze(2).expand(-1, -1, num_per_point, -1)
    else:
        gather_index = gather_index.expand(-1, -1, feat_dim)

    padding_mask = gather_index.eq(-1)
    # Out-of-place fill: yields a fresh tensor with every -1 replaced by 0,
    # leaving the caller's idx untouched.
    gather_index = gather_index.masked_fill(padding_mask, 0)

    gathered = points.gather(dim=1, index=gather_index)
    # Zero out the entries that were gathered through a padded index.
    gathered[padding_mask] = 0.0
    return gathered