Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| This layer is a compilable ball-query operation. | |
| Grid-shaped inputs (a 4D or 5D `p_grid`) are flattened to a 1D list of points before the query. | |
| It does not support batch size > 1. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from physicsnemo.utils.neighbors import radius_search | |
| class BQWarp(nn.Module): | |
| """ | |
| Warp-based ball-query layer for finding neighboring points within a specified radius. | |
| This layer uses an accelerated ball query implementation to efficiently find points | |
| within a specified radius of query points. | |
| Only supports batch size 1. | |
| """ | |
| def __init__( | |
| self, | |
| radius: float = 0.25, | |
| neighbors_in_radius: int | None = 10, | |
| ): | |
| """ | |
| Initialize the BQWarp layer. | |
| Args: | |
| radius: Radius for ball query operation | |
| neighbors_in_radius: Maximum number of neighbors to return within radius. If None, all neighbors will be returned. | |
| """ | |
| super().__init__() | |
| self.radius = radius | |
| self.neighbors_in_radius = neighbors_in_radius | |
| def forward( | |
| self, x: torch.Tensor, p_grid: torch.Tensor, reverse_mapping: bool = True | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Performs ball query operation to find neighboring points and their features. | |
| This method uses the Warp-accelerated ball query implementation to find points | |
| within a specified radius. It can operate in two modes: | |
| - Forward mapping: Find points from x that are near p_grid points (reverse_mapping=False) | |
| - Reverse mapping: Find points from p_grid that are near x points (reverse_mapping=True) | |
| Args: | |
| x: Tensor of shape (batch_size, num_points, 3+features) containing point coordinates | |
| and their features | |
| p_grid: Tensor of shape (batch_size, grid_x, grid_y, grid_z, 3) containing grid point | |
| coordinates | |
| reverse_mapping: Boolean flag to control the direction of the mapping: | |
| - True: Find p_grid points near x points | |
| - False: Find x points near p_grid points | |
| Returns: | |
| tuple containing: | |
| - mapping: Tensor containing indices of neighboring points | |
| - outputs: Tensor containing coordinates of the neighboring points | |
| """ | |
| if x.shape[0] != 1 or p_grid.shape[0] != 1: | |
| raise ValueError("BQWarp only supports batch size 1") | |
| if p_grid.shape[-1] != x.shape[-1] or x.shape[-1] != 3: | |
| raise ValueError("The last dimension of p_grid and x must be 3") | |
| if p_grid.ndim != 3: | |
| if p_grid.ndim == 4: | |
| p_grid = rearrange(p_grid, "b nx ny c -> b (nx ny) c") | |
| elif p_grid.ndim == 5: | |
| p_grid = rearrange(p_grid, "b nx ny nz c -> b (nx ny nz) c") | |
| else: | |
| raise ValueError("p_grid must be 3D, 4D, 5D only") | |
| if reverse_mapping: | |
| mapping, outputs = radius_search( | |
| x[0], | |
| p_grid[0], | |
| self.radius, | |
| self.neighbors_in_radius, | |
| return_points=True, | |
| ) | |
| mapping = mapping.unsqueeze(0) | |
| outputs = outputs.unsqueeze(0) | |
| else: | |
| mapping, outputs = radius_search( | |
| p_grid[0], | |
| x[0], | |
| self.radius, | |
| self.neighbors_in_radius, | |
| return_points=True, | |
| ) | |
| mapping = mapping.unsqueeze(0) | |
| outputs = outputs.unsqueeze(0) | |
| return mapping, outputs | |