ArthurY's picture
update source
c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This layer is a compilable, ball-query operation.
By default, it will project a grid of points to a 1D set of points.
It does not support batch size > 1.
"""
import torch
import torch.nn as nn
from einops import rearrange
from physicsnemo.utils.neighbors import radius_search
class BQWarp(nn.Module):
"""
Warp-based ball-query layer for finding neighboring points within a specified radius.
This layer uses an accelerated ball query implementation to efficiently find points
within a specified radius of query points.
Only supports batch size 1.
"""
def __init__(
self,
radius: float = 0.25,
neighbors_in_radius: int | None = 10,
):
"""
Initialize the BQWarp layer.
Args:
radius: Radius for ball query operation
neighbors_in_radius: Maximum number of neighbors to return within radius. If None, all neighbors will be returned.
"""
super().__init__()
self.radius = radius
self.neighbors_in_radius = neighbors_in_radius
def forward(
self, x: torch.Tensor, p_grid: torch.Tensor, reverse_mapping: bool = True
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Performs ball query operation to find neighboring points and their features.
This method uses the Warp-accelerated ball query implementation to find points
within a specified radius. It can operate in two modes:
- Forward mapping: Find points from x that are near p_grid points (reverse_mapping=False)
- Reverse mapping: Find points from p_grid that are near x points (reverse_mapping=True)
Args:
x: Tensor of shape (batch_size, num_points, 3+features) containing point coordinates
and their features
p_grid: Tensor of shape (batch_size, grid_x, grid_y, grid_z, 3) containing grid point
coordinates
reverse_mapping: Boolean flag to control the direction of the mapping:
- True: Find p_grid points near x points
- False: Find x points near p_grid points
Returns:
tuple containing:
- mapping: Tensor containing indices of neighboring points
- outputs: Tensor containing coordinates of the neighboring points
"""
if x.shape[0] != 1 or p_grid.shape[0] != 1:
raise ValueError("BQWarp only supports batch size 1")
if p_grid.shape[-1] != x.shape[-1] or x.shape[-1] != 3:
raise ValueError("The last dimension of p_grid and x must be 3")
if p_grid.ndim != 3:
if p_grid.ndim == 4:
p_grid = rearrange(p_grid, "b nx ny c -> b (nx ny) c")
elif p_grid.ndim == 5:
p_grid = rearrange(p_grid, "b nx ny nz c -> b (nx ny nz) c")
else:
raise ValueError("p_grid must be 3D, 4D, 5D only")
if reverse_mapping:
mapping, outputs = radius_search(
x[0],
p_grid[0],
self.radius,
self.neighbors_in_radius,
return_points=True,
)
mapping = mapping.unsqueeze(0)
outputs = outputs.unsqueeze(0)
else:
mapping, outputs = radius_search(
p_grid[0],
x[0],
self.radius,
self.neighbors_in_radius,
return_points=True,
)
mapping = mapping.unsqueeze(0)
outputs = outputs.unsqueeze(0)
return mapping, outputs