# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable

import numpy as np
import torch
import torch.distributed as dist

from physicsnemo.utils.neighbors.knn._cuml_impl import knn_impl
from physicsnemo.utils.version_check import check_module_requirements

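# ShardTensor depends on optional distributed packages; this guard fails fast
# with a clear error before the imports below if they are unavailable.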
check_module_requirements("physicsnemo.distributed.shard_tensor")
from physicsnemo.distributed import ShardTensor # noqa: E402
from physicsnemo.distributed.shard_utils.patch_core import ( # noqa: E402
MissingShardPatch,
)
from physicsnemo.distributed.shard_utils.ring import ( # noqa: E402
RingPassingConfig,
perform_ring_iteration,
)


def ring_knn(
points: ShardTensor, queries: ShardTensor, k: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
    Ring-based kNN implementation where points travel around a ring and queries stay local.

    This function performs a k-nearest-neighbor search using a distributed,
    ring-based algorithm. The points are passed around the devices in a ring
    topology while the queries remain local to each device.

    Parameters
    ----------
    points : ShardTensor
        The point cloud data tensor that will be distributed around the ring.
        Must be sharded on the same mesh as queries.
    queries : ShardTensor
        The query points tensor that stays local on each device.
        Must be sharded on the same mesh as points.
    k : int
        Number of nearest neighbors to find for each query point.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        A tuple containing:

        - shard_idx : torch.Tensor
            Indices of the k nearest neighbors for each query point.
        - shard_distances : torch.Tensor
            Distances to the k nearest neighbors for each query point.

    Raises
    ------
    NotImplementedError
        If the points and queries tensors are not sharded on the same mesh.
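
    Examples
    --------
    A minimal sketch (the names below are illustrative, not defined in this
    module), assuming a 1D ``DeviceMesh`` and clouds sharded along dim 0::

        points = ShardTensor.from_local(local_points, mesh, [Shard(0)])
        queries = ShardTensor.from_local(local_queries, mesh, [Shard(0)])
        idx, dists = ring_knn(points, queries, k=8)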
"""
# Each tensor has a _spec attribute, which contains information about the tensor's placement
# and the devices it lives on:
points_spec = points._spec
queries_spec = queries._spec
    # **In general** you want to do some checking on the placements, since each
    # point cloud might be sharded differently. By construction, I know they're both
    # sharded along the points axis here (and not, say, replicated).
    if points_spec.mesh != queries_spec.mesh:
raise NotImplementedError("Tensors must be sharded on the same mesh")
mesh = points_spec.mesh
local_group = mesh.get_group(0)
local_size = dist.get_world_size(group=local_group)
mesh_rank = mesh.get_local_rank()
    # points and queries are both sharded - and since we're returning the nearest
    # neighbors for each query, let's make sure the output keeps the query sharding too.
    # One memory-efficient way to do this is with a ring computation:
    # we'll compute the knn on the local tensors, get the distances and indices,
    # then shuffle the points shards along the mesh.
    # We'll need to sort the results and keep just the top-k,
    # which is a little extra computation.
    # PhysicsNeMo has a ring-passing utility we can use.
ring_config = RingPassingConfig(
mesh_dim=0,
mesh_size=local_size,
ring_direction="forward",
communication_method="p2p",
)
local_points, local_queries = points.to_local(), queries.to_local()
current_dists = None
current_topk_idx = None
points_sharding_shapes = points_spec.sharding_shapes()[0]
sharding_dim = points_spec.placements[0].dim
    # These cumulative offsets map indices that are local to each shard
    # onto the global tensor.
    points_strides_along_ring = [s[sharding_dim] for s in points_sharding_shapes]
    points_strides_along_ring = np.cumsum(points_strides_along_ring)
    points_strides_along_ring = [0] + list(points_strides_along_ring[:-1])
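    # For example, shard sizes [4, 4, 3] along the ring give cumulative sums
    # [4, 8, 11], so the per-rank global offsets become [0, 4, 8].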
for i in range(local_size):
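        # With a forward ring, after i passes the shard we hold locally
        # originated on rank (mesh_rank - i); track it so offsets stay correct.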
source_rank = (mesh_rank - i) % local_size
# For point clouds, we need to pass the size of the incoming shard.
next_source_rank = (source_rank - 1) % local_size
recv_shape = points_sharding_shapes[next_source_rank]
if i != local_size - 1:
            # Skip the send/recv on the final iteration - the data isn't needed again.
next_local_points = perform_ring_iteration(
local_points,
mesh,
ring_config,
recv_shape=recv_shape,
)
# Compute the knn on the local tensors:
local_idx, local_distances = knn_impl(local_points, local_queries, k)
# The local_idx indexes into the _local_ tensor, but for
        # correctness we need it to index into the _global_ tensor.
# Make sure to index using the rank the points came from!
offset = points_strides_along_ring[source_rank]
local_idx = local_idx + offset
if current_dists is None:
current_dists = local_distances
current_topk_idx = local_idx
else:
# Combine with the topk so far:
current_dists = torch.cat([current_dists, local_distances], dim=1)
current_topk_idx = torch.cat([current_topk_idx, local_idx], dim=1)
# And take the topk again:
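            # largest=False keeps the k smallest distances; running_indexes are
            # their positions in the concatenated buffers, used by the gather below.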
current_dists, running_indexes = torch.topk(
current_dists, k=k, dim=1, sorted=True, largest=False
)
            # Use those positions to pull the matching global indices along dim 1:
current_topk_idx = torch.gather(current_topk_idx, 1, running_indexes)
if i != local_size - 1:
            # Adopt the shard received at the top of this iteration (no pass on the last).
local_points = next_local_points
return current_topk_idx, current_dists


def extract_knn_args(
points: torch.Tensor, queries: torch.Tensor, k: int, *args, **kwargs
):
"""
Minimal function to use python's argument unpacking to extract the points, queries, and k values.
"""
return points, queries, k


def knn_sharded_wrapper(
func: Callable, types: Any, args: tuple, kwargs: dict
) -> tuple[ShardTensor, ShardTensor]:
"""
    Dispatch the proper kNN implementation based on the input sharding.

    `args` and `kwargs` are passed to `extract_knn_args` to extract
    the points, queries, and k values needed for the kNN operation.

    Parameters
    ----------
    func : Callable
        The function to dispatch.
    types : Any
        The types of the inputs.
    args : tuple
        The positional arguments.
    kwargs : dict
        The keyword arguments.

    Returns
    -------
    tuple[ShardTensor, ShardTensor]
        A tuple containing the shard_idx and shard_distances.

    Raises
    ------
    MissingShardPatch
        If the points and queries tensors are not sharded on the same mesh,
        or if the mesh is not one-dimensional.
"""
points, queries, k = extract_knn_args(*args, **kwargs)
    # Sharded kNN only works when both inputs live on the same 1D mesh:
if points._spec.mesh != queries._spec.mesh:
raise MissingShardPatch(
"sharded knn: All point inputs must be on the same mesh"
)
# make sure all meshes are 1D
if points._spec.mesh.ndim != 1:
raise MissingShardPatch(
"point_cloud_ops.radius_search_wrapper: All point inputs must be on 1D meshes"
)
# Do we need a ring?
points_placement = points._spec.placements[0]
if points_placement.is_shard():
# We need a ring
idx, distances = ring_knn(points, queries, k)
else:
# No ring is needed. Get the local tensors and compute directly:
        local_points = points.to_local()  # Replicated - this grabs the full tensor
        local_queries = queries.to_local()  # The query sharding doesn't matter here!
idx, distances = knn_impl(local_points, local_queries, k)
    # The output sharding depends only on the local queries' shape.
input_queries_spec = queries._spec
# The global output tensor will be (N_q, k)
output_queries_shard_shapes = {}
    for mesh_dim, shapes in input_queries_spec.sharding_shapes().items():
        shard_shapes = tuple(torch.Size((s[0], k)) for s in shapes)
        output_queries_shard_shapes[mesh_dim] = shard_shapes
# Convert the selected points and indexes to shards:
shard_idx = ShardTensor.from_local(
idx,
queries._spec.mesh,
queries._spec.placements,
sharding_shapes=output_queries_shard_shapes,
)
shard_distances = ShardTensor.from_local(
distances,
queries._spec.mesh,
queries._spec.placements,
sharding_shapes=output_queries_shard_shapes,
)
return shard_idx, shard_distances


ShardTensor.register_named_function_handler(
"physicsnemo.knn_cuml.default", knn_sharded_wrapper
)
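
# Usage sketch (illustrative - `knn` below stands in for whichever public entry
# point routes to the cuML kNN implementation; it is an assumption, not a name
# defined here). Once the handler above is registered, calling that entry point
# on ShardTensor inputs dispatches to knn_sharded_wrapper:
#
#     points = ShardTensor.from_local(local_points, mesh, [Shard(0)])
#     queries = ShardTensor.from_local(local_queries, mesh, [Shard(0)])
#     idx, dists = knn(points, queries, k=16)
#
# The results come back as ShardTensors of global shape (N_q, k), sharded the
# same way as the queries.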