Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
from typing import List, Optional, Type

import torch
from torch import Tensor
from torch import nn as nn

from .furthest_point_sample import (furthest_point_sample,
                                    furthest_point_sample_with_dist)
def calc_square_dist(point_feat_a: Tensor,
                     point_feat_b: Tensor,
                     norm: bool = True) -> Tensor:
    """Compute pairwise distances between two batched point-feature sets.

    Args:
        point_feat_a (torch.Tensor): (B, N, C) Feature vector of each point.
        point_feat_b (torch.Tensor): (B, M, C) Feature vector of each point.
        norm (bool, optional): Selects the returned quantity. When True the
            Euclidean distance divided by the channel count C is returned
            (note: NOT squared). When False the squared Euclidean distance is
            returned. Default: True.

    Returns:
        torch.Tensor: (B, N, M) distance between every (a, b) point pair.
    """
    pairwise = torch.cdist(point_feat_a, point_feat_b)
    if not norm:
        return torch.square(pairwise)
    # Scale by the feature dimensionality taken from the last axis of `a`.
    return pairwise / point_feat_a.shape[-1]
def get_sampler_cls(sampler_type: str) -> Type[nn.Module]:
    """Look up the points-sampler class registered under ``sampler_type``.

    Args:
        sampler_type (str): The type of points sampler. Valid values are
            "D-FPS", "F-FPS" and "FS".

    Returns:
        type: The sampler class itself (not an instance); callers are
        expected to instantiate it.

    Raises:
        KeyError: If ``sampler_type`` is not one of the supported names.
    """
    sampler_mappings = {
        'D-FPS': DFPSSampler,
        'F-FPS': FFPSSampler,
        'FS': FSSampler,
    }
    try:
        return sampler_mappings[sampler_type]
    except KeyError:
        # `from None` drops the bare KeyError(sampler_type) from the
        # traceback; this message already says what went wrong. Implicit
        # string concatenation (instead of a backslash continuation inside
        # the literal) keeps source indentation out of the message.
        raise KeyError(
            f'Supported `sampler_type` are {list(sampler_mappings)}, '
            f'but got {sampler_type}') from None
class PointsSampler(nn.Module):
    """Points sampling.

    The input points are split into consecutive index ranges (one per entry
    of ``fps_sample_range_list``); each range is sampled by its own sampler
    and the resulting indices are concatenated.

    Args:
        num_point (list[int]): Number of sample points for each range.
        fps_mod_list (list[str], optional): Type of FPS method for each
            range, valid values are 'F-FPS', 'D-FPS' and 'FS'.
            Default: ['D-FPS'].

            - F-FPS: using feature distances for FPS.
            - D-FPS: using Euclidean distances of points for FPS.
            - FS: using F-FPS and D-FPS simultaneously.
        fps_sample_range_list (list[int], optional): Exclusive end index of
            the point range each sampler works on; -1 means "up to the last
            point". Default: [-1].
    """

    def __init__(self,
                 num_point: List[int],
                 fps_mod_list: Optional[List[str]] = None,
                 fps_sample_range_list: Optional[List[int]] = None) -> None:
        super().__init__()
        # `None` defaults avoid a single mutable default list being shared by
        # (and stored on) every instance created without these arguments.
        if fps_mod_list is None:
            fps_mod_list = ['D-FPS']
        if fps_sample_range_list is None:
            fps_sample_range_list = [-1]
        # FPS would be applied to different fps_mod in the list,
        # so the length of the num_point should be equal to
        # fps_mod_list and fps_sample_range_list.
        assert len(num_point) == len(fps_mod_list) == len(
            fps_sample_range_list)
        self.num_point = num_point
        self.fps_sample_range_list = fps_sample_range_list
        self.samplers = nn.ModuleList(
            [get_sampler_cls(fps_mod)() for fps_mod in fps_mod_list])
        # NOTE(review): presumably consumed by mmcv-style fp16 helpers —
        # confirm before removing.
        self.fp16_enabled = False

    def forward(self, points_xyz: Tensor,
                features: Optional[Tensor]) -> Tensor:
        """Sample point indices range by range.

        Args:
            points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of
                the points.
            features (torch.Tensor | None): (B, C, N) features of the points,
                or None (only valid for samplers that do not need features).

        Returns:
            torch.Tensor: Indices of sampled points (offset back into the
            full point set), concatenated along dim 1 over all samplers.
            Presumably (B, total_npoint); an 'FS' sampler contributes
            2 * npoint indices.
        """
        # Half-precision inputs are upcast before sampling.
        if points_xyz.dtype == torch.half:
            points_xyz = points_xyz.to(torch.float32)
        if features is not None and features.dtype == torch.half:
            features = features.to(torch.float32)

        indices = []
        last_fps_end_index = 0
        for fps_sample_range, sampler, npoint in zip(
                self.fps_sample_range_list, self.samplers, self.num_point):
            assert fps_sample_range < points_xyz.shape[1]
            # -1 means "through the last point"; a `None` slice end does that.
            end_index = None if fps_sample_range == -1 else fps_sample_range
            sample_points_xyz = points_xyz[:, last_fps_end_index:end_index]
            sample_features = (
                features[:, :, last_fps_end_index:end_index]
                if features is not None else None)
            fps_idx = sampler(sample_points_xyz.contiguous(), sample_features,
                              npoint)
            # Indices are relative to the slice; shift them back.
            indices.append(fps_idx + last_fps_end_index)
            # NOTE(review): if -1 is not the last range, the next slice would
            # start at -1; -1 is presumably only used for the final entry.
            last_fps_end_index = fps_sample_range
        return torch.cat(indices, dim=1)
class DFPSSampler(nn.Module):
    """D-FPS: furthest point sampling on Euclidean xyz distances."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
        """Pick ``npoint`` indices with D-FPS.

        Args:
            points (torch.Tensor): Coordinates of the candidate points
                (presumably (B, N, 3) — confirm with caller).
            features (torch.Tensor): Ignored; accepted so every sampler
                shares the same call signature.
            npoint (int): Number of points to sample.

        Returns:
            torch.Tensor: Indices produced by ``furthest_point_sample``.
        """
        xyz = points.contiguous()
        return furthest_point_sample(xyz, npoint)
class FFPSSampler(nn.Module):
    """F-FPS: furthest point sampling on concatenated xyz + feature distances."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
        """Pick ``npoint`` indices with F-FPS.

        Args:
            points (torch.Tensor): Point coordinates; concatenated with the
                features along dim 2.
            features (torch.Tensor): Point features, transposed (dims 1 and 2)
                before concatenation; must not be None.
            npoint (int): Number of points to sample.

        Returns:
            torch.Tensor: Indices produced by
            ``furthest_point_sample_with_dist`` on squared pairwise distances.
        """
        assert features is not None, \
            'feature input to FFPS_Sampler should not be None'
        per_point_features = features.transpose(1, 2)
        combined = torch.cat([points, per_point_features], dim=2)
        # norm=False -> squared Euclidean distances between combined vectors.
        pairwise_sq = calc_square_dist(combined, combined, norm=False)
        return furthest_point_sample_with_dist(pairwise_sq, npoint)
class FSSampler(nn.Module):
    """FS: run F-FPS and D-FPS side by side and concatenate their indices."""

    def __init__(self) -> None:
        super().__init__()
        # Build the two sub-samplers once instead of on every forward call.
        # Their constructors only call `nn.Module.__init__`, so they own no
        # parameters or buffers and `state_dict` is unaffected.
        self.ffps_sampler = FFPSSampler()
        self.dfps_sampler = DFPSSampler()

    def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
        """Sampling points with FS_Sampling.

        Args:
            points (torch.Tensor): Point coordinates passed to both samplers.
            features (torch.Tensor): Point features; must not be None
                (required by the F-FPS half).
            npoint (int): Number of points each sub-sampler selects.

        Returns:
            torch.Tensor: F-FPS indices followed by D-FPS indices,
            concatenated along dim 1 (so 2 * npoint indices per batch item).
        """
        assert features is not None, \
            'feature input to FS_Sampler should not be None'
        fps_idx_ffps = self.ffps_sampler(points, features, npoint)
        fps_idx_dfps = self.dfps_sampler(points, features, npoint)
        return torch.cat([fps_idx_ffps, fps_idx_dfps], dim=1)