Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| import warp as wp | |
| wp.config.quiet = True | |
@wp.kernel
def _bvh_query_distance(
    mesh_id: wp.uint64,
    points: wp.array(dtype=wp.vec3f),
    max_dist: wp.float32,
    sdf: wp.array(dtype=wp.float32),
    sdf_hit_point: wp.array(dtype=wp.vec3f),
    use_sign_winding_number: bool = False,
):
    """
    Warp kernel computing, for each query point, the signed distance to a
    mesh and the closest point on that mesh.

    Each thread handles the query point at index ``wp.tid()`` and writes one
    entry of `sdf` and one entry of `sdf_hit_point`.

    Parameters:
        mesh_id (wp.uint64): The identifier of the mesh.
        points (wp.array): An array of 3D points for which to compute the
            signed distance.
        max_dist (wp.float32): The maximum distance within which to search
            for the closest point on the mesh.
        sdf (wp.array): Output array receiving the signed distances.
        sdf_hit_point (wp.array): Output array receiving the closest points
            on the mesh ("hit points").
        use_sign_winding_number (bool): If True, query with
            ``wp.mesh_query_point_sign_winding_number``; otherwise query
            with ``wp.mesh_query_point_sign_normal``.

    Returns:
        None. Results are written into `sdf` and `sdf_hit_point`.

    NOTE(review): the `result` flag of the query is not inspected here, so
    the outputs for a point with no face within `max_dist` are whatever the
    returned query struct yields -- confirm against the Warp documentation
    if small `max_dist` values are used.
    """
    tid = wp.tid()
    # Load the query point once; it is needed for the query and the distance.
    query_point = points[tid]

    if use_sign_winding_number:
        res = wp.mesh_query_point_sign_winding_number(mesh_id, query_point, max_dist)
    else:
        res = wp.mesh_query_point_sign_normal(mesh_id, query_point, max_dist)

    mesh = wp.mesh_get(mesh_id)

    # Vertices of the face returned by the query (3 indices per face).
    p0 = mesh.points[mesh.indices[3 * res.face + 0]]
    p1 = mesh.points[mesh.indices[3 * res.face + 1]]
    p2 = mesh.points[mesh.indices[3 * res.face + 2]]

    # Rebuild the closest point from the barycentric coordinates (u, v).
    p_closest = res.u * p0 + res.v * p1 + (1.0 - res.u - res.v) * p2

    # wp.length is already non-negative, so only the sign needs applying.
    sdf[tid] = res.sign * wp.length(query_point - p_closest)
    sdf_hit_point[tid] = p_closest
def signed_distance_field(
    mesh_vertices: torch.Tensor,
    mesh_indices: torch.Tensor,
    input_points: torch.Tensor,
    max_dist: float = 1e8,
    use_sign_winding_number: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the signed distance field (SDF) for a given mesh and input points.

    The mesh must be a surface mesh consisting of all triangles. Uses NVIDIA
    Warp; the kernel is launched on the current torch CUDA stream when
    `input_points` lives on a CUDA device, and on the CPU otherwise.

    Parameters:
    ----------
        mesh_vertices (torch.Tensor): Coordinates of the vertices of the mesh;
            shape: (n_vertices, 3)
        mesh_indices (torch.Tensor): Indices corresponding to the faces of the
            mesh; shape: (n_faces, 3) or already flattened (3 * n_faces,)
        input_points (torch.Tensor): Coordinates of the points for which to
            compute the SDF; shape: (..., 3). Leading dimensions are
            flattened internally and restored on output.
        max_dist (float, optional): Maximum distance within which
            to search for the closest point on the mesh. Default is 1e8.
        use_sign_winding_number (bool, optional): Whether to use sign winding
            number method for SDF. Default is False. If False, your mesh should
            be watertight to obtain correct results.

    Returns:
    -------
        tuple[torch.Tensor, torch.Tensor] of:
            - signed distance to the mesh, per input point;
              shape: input_points.shape[:-1]
            - hit point, per input point; shape: input_points.shape.
              "hit points" are the points on the mesh that are closest to
              the input points, and hence, are defining the SDF.
        Both tensors are on the device of `input_points` and are cast to its
        dtype (the computation itself runs in float32).

    Raises:
    ------
        ValueError: If the last dimension of `input_points` is not 3.

    Example:
    -------
    >>> mesh_vertices = torch.tensor([(0., 0., 0.), (1., 0., 0.), (0., 1., 0.)])
    >>> mesh_indices = torch.tensor((0, 1, 2))
    >>> input_points = torch.tensor([(0.25, 0.25, 0.5)])
    >>> sdf, hit = signed_distance_field(mesh_vertices, mesh_indices, input_points)
    >>> sdf.shape, hit.shape
    (torch.Size([1]), torch.Size([1, 3]))
    """
    if input_points.shape[-1] != 3:
        raise ValueError("Input points must be a tensor with last dimension of size 3")

    input_shape = input_points.shape
    output_dtype = input_points.dtype
    device = input_points.device

    # Flatten any leading dimensions so the kernel sees a plain list of points.
    flat_points = input_points.reshape(-1, 3)
    n_points = flat_points.shape[0]

    # Allocate output tensors with torch; Warp writes into them in place.
    sdf = torch.zeros(n_points, dtype=torch.float32, device=device)
    sdf_hit_point = torch.zeros(n_points, 3, dtype=torch.float32, device=device)

    if n_points == 0:
        # Nothing to query: skip the mesh build and the kernel launch.
        return (
            sdf.reshape(input_shape[:-1]).to(output_dtype),
            sdf_hit_point.reshape(input_shape).to(output_dtype),
        )

    wp.init()

    if device.type == "cuda":
        wp_launch_stream = wp.stream_from_torch(torch.cuda.current_stream(device))
        wp_launch_device = None  # We explicitly pass None if using the stream.
    else:
        wp_launch_stream = None
        wp_launch_device = "cpu"  # CPUs have no streams

    with wp.ScopedStream(wp_launch_stream):
        # Hand the vertices, indices, and input points to warp. The tensors
        # are made contiguous first so the memory can be viewed as packed
        # vec3 / int32 arrays.
        wp_vertices = wp.from_torch(
            mesh_vertices.to(torch.float32).contiguous(), dtype=wp.vec3
        )
        # Flatten (n_faces, 3) indices to a 1D array of 3 * n_faces entries;
        # already-flat indices are unchanged by the reshape.
        wp_indices = wp.from_torch(
            mesh_indices.to(torch.int32).reshape(-1).contiguous(), dtype=wp.int32
        )
        wp_input_points = wp.from_torch(
            flat_points.to(torch.float32).contiguous(), dtype=wp.vec3
        )

        # Wrap the output tensors (no copy; the kernel fills them directly):
        wp_sdf = wp.from_torch(sdf, dtype=wp.float32)
        wp_sdf_hit_point = wp.from_torch(sdf_hit_point, dtype=wp.vec3f)

        mesh = wp.Mesh(
            points=wp_vertices,
            indices=wp_indices,
            support_winding_number=use_sign_winding_number,
        )

        wp.launch(
            kernel=_bvh_query_distance,
            dim=n_points,
            inputs=[
                mesh.id,
                wp_input_points,
                max_dist,
                wp_sdf,
                wp_sdf_hit_point,
                use_sign_winding_number,
            ],
            device=wp_launch_device,
            stream=wp_launch_stream,
        )

    # Unflatten the output to be like the input:
    sdf = sdf.reshape(input_shape[:-1])
    sdf_hit_point = sdf_hit_point.reshape(input_shape)
    return sdf.to(output_dtype), sdf_hit_point.to(output_dtype)
def _(
    mesh_vertices: torch.Tensor,
    mesh_indices: torch.Tensor,
    input_points: torch.Tensor,
    max_dist: float = 1e8,
    use_sign_winding_number: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Shape-only counterpart of `signed_distance_field`.

    Validates that all inputs share a device and returns uninitialized
    tensors with the shapes, dtype, and device that `signed_distance_field`
    produces, without running any mesh query.

    NOTE(review): the signature and body suggest this is meant to be
    registered as the fake/meta implementation of `signed_distance_field`;
    no registration is visible here -- confirm.

    Parameters:
    ----------
        mesh_vertices (torch.Tensor): Mesh vertices (only the device is used).
        mesh_indices (torch.Tensor): Mesh indices (only the device is used).
        input_points (torch.Tensor): Query points; shape: (..., 3).
        max_dist (float, optional): Unused; kept for signature parity.
        use_sign_winding_number (bool, optional): Unused; kept for signature
            parity.

    Returns:
    -------
        tuple[torch.Tensor, torch.Tensor] of:
            - uninitialized tensor of shape input_points.shape[:-1]
            - uninitialized tensor of shape input_points.shape

    Raises:
    ------
        RuntimeError: If the inputs are not all on the same device.
    """
    if mesh_vertices.device != input_points.device:
        raise RuntimeError("mesh_vertices and input_points must be on the same device")
    if mesh_vertices.device != mesh_indices.device:
        raise RuntimeError("mesh_vertices and mesh_indices must be on the same device")

    # Mirror the real implementation: one distance per point (all leading
    # dimensions kept, trailing coordinate axis dropped) and one hit point
    # with the same shape as the input.
    point_shape = input_points.shape
    sdf_output = torch.empty(
        point_shape[:-1], device=input_points.device, dtype=input_points.dtype
    )
    sdf_hit_point_output = torch.empty(
        point_shape, device=input_points.device, dtype=input_points.dtype
    )
    return sdf_output, sdf_hit_point_output