# ArthurY's picture
# update source
# c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import warp as wp
# Set Warp's `quiet` config flag at import time, before any kernel below is
# defined or launched.
# NOTE(review): presumably this silences Warp's informational console
# output -- confirm against the Warp configuration docs.
wp.config.quiet = True
@wp.kernel
def _bvh_query_distance(
    mesh_id: wp.uint64,
    points: wp.array(dtype=wp.vec3f),
    max_dist: wp.float32,
    sdf: wp.array(dtype=wp.float32),
    sdf_hit_point: wp.array(dtype=wp.vec3f),
    use_sign_winding_number: bool = False,
):
    """
    Computes the signed distance from each point in the given array `points`
    to the mesh identified by `mesh_id`, within the maximum distance
    `max_dist`. One thread handles one point and writes the signed distance
    to `sdf` and the closest mesh point to `sdf_hit_point`.

    Parameters:
        mesh_id (wp.uint64): The identifier of the mesh.
        points (wp.array): An array of 3D points for which to compute the
            signed distance.
        max_dist (wp.float32): The maximum distance within which to search
            for the closest point on the mesh.
        sdf (wp.array): An array to store the computed signed distances.
        sdf_hit_point (wp.array): An array to store the computed hit points
            (the closest point on the mesh for each query point).
        use_sign_winding_number (bool): If True, query with
            `wp.mesh_query_point_sign_winding_number`; otherwise query with
            `wp.mesh_query_point_sign_normal`.

    Returns:
        None
    """
    tid = wp.tid()
    query_point = points[tid]

    if use_sign_winding_number:
        res = wp.mesh_query_point_sign_winding_number(mesh_id, query_point, max_dist)
    else:
        res = wp.mesh_query_point_sign_normal(mesh_id, query_point, max_dist)

    # NOTE(review): `res.result` is not checked here. If no mesh point lies
    # within `max_dist`, the values written below come from whatever the
    # query left in `res` -- confirm whether callers rely on that.
    mesh = wp.mesh_get(mesh_id)

    # Vertices of the triangle that contains the closest point.
    p0 = mesh.points[mesh.indices[3 * res.face + 0]]
    p1 = mesh.points[mesh.indices[3 * res.face + 1]]
    p2 = mesh.points[mesh.indices[3 * res.face + 2]]

    # Rebuild the closest point from the barycentric coordinates (u, v).
    p_closest = res.u * p0 + res.v * p1 + (1.0 - res.u - res.v) * p2

    # A vector length is already non-negative, so only the sign is applied.
    sdf[tid] = res.sign * wp.length(query_point - p_closest)
    sdf_hit_point[tid] = p_closest
@torch.library.custom_op("physicsnemo::signed_distance_field", mutates_args=())
def signed_distance_field(
    mesh_vertices: torch.Tensor,
    mesh_indices: torch.Tensor,
    input_points: torch.Tensor,
    max_dist: float = 1e8,
    use_sign_winding_number: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the signed distance field (SDF) for a given mesh and input points.

    The mesh must be a surface mesh consisting of all triangles. Uses NVIDIA
    Warp to run the query, on the current CUDA stream for CUDA tensors and on
    the CPU otherwise.

    Parameters:
    ----------
        mesh_vertices (torch.Tensor): Coordinates of the vertices of the mesh;
            shape: (n_vertices, 3)
        mesh_indices (torch.Tensor): Indices corresponding to the faces of the
            mesh; shape: (n_faces, 3) or already flattened to (3 * n_faces,)
        input_points (torch.Tensor): Coordinates of the points for which to
            compute the SDF; shape: (..., 3)
        max_dist (float, optional): Maximum distance within which
            to search for the closest point on the mesh. Default is 1e8.
        use_sign_winding_number (bool, optional): Whether to use sign winding
            number method for SDF. Default is False. If False, your mesh should
            be watertight to obtain correct results.

    Returns:
    -------
        tuple[torch.Tensor, torch.Tensor] of:
            - signed distance to the mesh, per input point;
              shape: input_points.shape[:-1]
            - hit point, per input point; shape: input_points.shape.
              "hit points" are the points on the mesh that are closest to
              the input points, and hence, are defining the SDF.
        Both outputs are cast to the dtype of `input_points`.

    Raises:
    ------
        ValueError: If the last dimension of `input_points` is not 3.
        RuntimeError: If the three tensors are not all on the same device.

    Example:
    -------
    >>> mesh_vertices = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    >>> mesh_indices = torch.tensor([[0, 1, 2]])
    >>> input_points = torch.tensor([[0.25, 0.25, 0.5]])
    >>> sdf, hit = signed_distance_field(mesh_vertices, mesh_indices, input_points)
    >>> sdf.shape, hit.shape
    (torch.Size([1]), torch.Size([1, 3]))
    """
    if input_points.shape[-1] != 3:
        raise ValueError("Input points must be a tensor with last dimension of size 3")
    # Same device checks as the fake implementation, so eager and traced
    # execution reject the same inputs.
    if mesh_vertices.device != input_points.device:
        raise RuntimeError("mesh_vertices and input_points must be on the same device")
    if mesh_vertices.device != mesh_indices.device:
        raise RuntimeError("mesh_vertices and mesh_indices must be on the same device")

    input_shape = input_points.shape
    output_dtype = input_points.dtype
    device = input_points.device

    # Warp consumes flat, contiguous buffers: float32 xyz rows for the points
    # and vertices, and a 1D int32 index list (three entries per triangle).
    flat_points = input_points.reshape(-1, 3).to(torch.float32).contiguous()
    vertices_f32 = mesh_vertices.to(torch.float32).contiguous()
    indices_i32 = mesh_indices.reshape(-1).to(torch.int32).contiguous()

    N = len(flat_points)

    # Allocate output tensors with torch:
    sdf = torch.zeros(N, dtype=torch.float32, device=device)
    sdf_hit_point = torch.zeros(N, 3, dtype=torch.float32, device=device)

    if device.type == "cuda":
        # Launch on torch's current stream so the kernel is ordered with the
        # surrounding torch work.
        wp_launch_stream = wp.stream_from_torch(torch.cuda.current_stream(device))
        wp_launch_device = None  # We explicitly pass None if using the stream.
    else:
        wp_launch_stream = None
        wp_launch_device = "cpu"  # CPUs have no streams

    with wp.ScopedStream(wp_launch_stream):
        wp.init()

        # Wrap the torch tensors as warp arrays:
        wp_vertices = wp.from_torch(vertices_f32, dtype=wp.vec3)
        wp_indices = wp.from_torch(indices_i32, dtype=wp.int32)
        wp_input_points = wp.from_torch(flat_points, dtype=wp.vec3)

        # Wrap the output tensors so the kernel writes straight into them:
        wp_sdf = wp.from_torch(sdf, dtype=wp.float32)
        wp_sdf_hit_point = wp.from_torch(sdf_hit_point, dtype=wp.vec3f)

        mesh = wp.Mesh(
            points=wp_vertices,
            indices=wp_indices,
            support_winding_number=use_sign_winding_number,
        )

        wp.launch(
            kernel=_bvh_query_distance,
            dim=N,
            inputs=[
                mesh.id,
                wp_input_points,
                max_dist,
                wp_sdf,
                wp_sdf_hit_point,
                use_sign_winding_number,
            ],
            device=wp_launch_device,
            stream=wp_launch_stream,
        )

    # Unflatten the output to be like the input:
    sdf = sdf.reshape(input_shape[:-1])
    sdf_hit_point = sdf_hit_point.reshape(input_shape)

    return sdf.to(output_dtype), sdf_hit_point.to(output_dtype)
@signed_distance_field.register_fake
def _(
    mesh_vertices: torch.Tensor,
    mesh_indices: torch.Tensor,
    input_points: torch.Tensor,
    max_dist: float = 1e8,
    use_sign_winding_number: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fake (shape/dtype-only) implementation of `signed_distance_field`.

    Returns uninitialised tensors whose shapes, dtype and device match what
    the real operator returns, without running any computation.

    Parameters:
    ----------
        mesh_vertices (torch.Tensor): Mesh vertex coordinates.
        mesh_indices (torch.Tensor): Mesh face indices.
        input_points (torch.Tensor): Query points; shape: (..., 3)
        max_dist (float, optional): Unused here; kept to mirror the real
            operator's signature.
        use_sign_winding_number (bool, optional): Unused here; kept to mirror
            the real operator's signature.

    Returns:
    -------
        tuple[torch.Tensor, torch.Tensor] of:
            - signed distance placeholder; shape: input_points.shape[:-1]
            - hit point placeholder; shape: input_points.shape

    Raises:
    ------
        RuntimeError: If the three tensors are not all on the same device.
    """
    if mesh_vertices.device != input_points.device:
        raise RuntimeError("mesh_vertices and input_points must be on the same device")
    if mesh_vertices.device != mesh_indices.device:
        raise RuntimeError("mesh_vertices and mesh_indices must be on the same device")

    # The real operator reshapes its outputs to `input_shape[:-1]` (sdf) and
    # `input_shape` (hit points); mirror that exactly, including any leading
    # batch dimensions.
    sdf_output = torch.empty(
        input_points.shape[:-1],
        device=input_points.device,
        dtype=input_points.dtype,
    )
    sdf_hit_point_output = torch.empty(
        input_points.shape,
        device=input_points.device,
        dtype=input_points.dtype,
    )

    return sdf_output, sdf_hit_point_output