Image-to-Video
zzwustc's picture
Upload folder using huggingface_hub
ef296aa verified
raw
history blame
2.12 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple, TYPE_CHECKING, Union
import torch
def masked_gather(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    Gather rows of ``points`` along dim 1 using indices that may contain
    -1 as a padding marker.

    torch.gather cannot take negative indices, so every -1 is temporarily
    mapped to index 0 for the lookup, and the corresponding output entries
    are then overwritten with 0.0. Neither ``points`` nor ``idx`` is
    modified.

    Args:
        points: (N, P, D) tensor of points to select from.
        idx: (N, K) or (N, P', K) tensor of indices into dim 1 of
            ``points``; entries equal to -1 denote padding.

    Returns:
        A new tensor with the selected points: (N, K, D) when ``idx`` is
        2-dimensional, (N, P', K, D) when ``idx`` is 3-dimensional.
        Positions that correspond to a -1 index hold 0.0.

    Raises:
        ValueError: if the leading (batch) sizes of ``points`` and ``idx``
            differ, or if ``idx`` is neither 2- nor 3-dimensional.
    """
    if len(points) != len(idx):
        raise ValueError("points and idx must have the same batch dimension")

    # Unpacking also enforces that points is exactly 3-dimensional.
    _, _, feat_dim = points.shape

    rank = idx.ndim
    if rank not in (2, 3):
        raise ValueError("idx format is not supported %s" % repr(idx.shape))

    if rank == 2:
        # One index per selected point (e.g. farthest point sampling):
        # replicate each index across the feature dimension.
        gather_index = idx.unsqueeze(-1).expand(-1, -1, feat_dim)
        source = points
    else:
        # K indices per query point (e.g. KNN / ball query). P' need not
        # equal P because the queries may come from another point cloud.
        num_neighbors = idx.shape[2]
        gather_index = idx.unsqueeze(-1).expand(-1, -1, -1, feat_dim)
        # Give points a matching K axis so gather can run along dim 1.
        source = points.unsqueeze(2).expand(-1, -1, num_neighbors, -1)

    padding = gather_index == -1
    # Out-of-place fill: yields a fresh index tensor with -1 mapped to 0.
    safe_index = gather_index.masked_fill(padding, 0)

    gathered = source.gather(1, safe_index)
    # Blank out the entries that were looked up through a padding index.
    gathered[padding] = 0.0
    return gathered