from collections import namedtuple

import torch

# Both result types carry the same three fields:
#   dists - per-query distances to the selected neighbours,
#   idx   - indices into p2 (-1 marks "no neighbour"),
#   knn   - gathered neighbour coordinates, or None when return_nn is False.
KNNResult = namedtuple("KNNResult", ["dists", "idx", "knn"])
BallQueryResult = namedtuple("BallQueryResult", ["dists", "idx", "knn"])


def _check_point_clouds(p1, p2):
    """Raise ValueError unless p1 and p2 are 3-D with equal batch and last dims."""
    if p1.dim() != 3 or p2.dim() != 3:
        raise ValueError("p1 and p2 must have shape (N, P, D)")
    if p1.shape[0] != p2.shape[0] or p1.shape[2] != p2.shape[2]:
        raise ValueError("p1 and p2 must have matching batch and point dimensions")


def _length_mask(lengths, batch, count, device):
    """Return a (batch, count) bool mask that is True where position < lengths[b]."""
    positions = torch.arange(count, device=device).view(1, count)
    return positions < lengths.to(device=device).view(batch, 1)


def _gather_neighbors(p2, idx):
    """Gather p2 points at idx (N, P1, K) into (N, P1, K, D); idx < 0 yields zeros."""
    batch, p1_count, width = idx.shape
    dim = p2.shape[2]
    if p2.shape[1] == 0:
        # Nothing to gather from: every index is -1 here, and gathering along
        # a size-0 dimension would raise an out-of-bounds error.
        return p2.new_zeros((batch, p1_count, width, dim))
    # Clamp so that -1 placeholders are legal gather indices; the rows they
    # produce are zeroed out right after.
    gather_idx = idx.clamp_min(0).unsqueeze(-1).expand(-1, -1, -1, dim)
    points = p2.unsqueeze(1).expand(-1, p1_count, -1, -1)
    neighbors = torch.gather(points, 2, gather_idx)
    return neighbors.masked_fill(idx.unsqueeze(-1) < 0, 0)


def knn_points(
    p1,
    p2,
    lengths1=None,
    lengths2=None,
    K=1,
    norm=2,
    return_nn=False,
    return_sorted=True,
):
    """Find, for every point in p1, its K nearest points in p2.

    Args:
        p1: Tensor of shape (N, P1, D) holding the query points.
        p2: Tensor of shape (N, P2, D) holding the candidate points.
        lengths1: Optional tensor with N entries; query rows at positions
            >= lengths1[b] are treated as padding.
        lengths2: Optional tensor with N entries; p2 points at positions
            >= lengths2[b] are treated as padding and never selected.
        K: Number of neighbours to return per query point.
        norm: Passed to ``torch.cdist`` as ``p``. For ``norm == 2`` the
            returned distances are squared.
        return_nn: When True, also gather the neighbour coordinates.
        return_sorted: Forwarded to ``torch.topk`` as ``sorted``.

    Returns:
        KNNResult(dists, idx, knn). ``dists`` and ``idx`` have shape
        (N, P1, K) (width 0 when K <= 0). Slots without a real neighbour
        (fewer than K valid p2 points, or a padded query row) hold
        ``inf`` in ``dists`` and ``-1`` in ``idx``. ``knn`` has shape
        (N, P1, K, D) with zeros in those slots, or is None when
        ``return_nn`` is False.

    Raises:
        ValueError: If the inputs are not 3-D or their batch / point
            dimensions differ.
    """
    _check_point_clouds(p1, p2)
    batch, p1_count, dim = p1.shape
    p2_count = p2.shape[1]

    if K <= 0:
        empty_dists = p1.new_empty((batch, p1_count, 0))
        empty_idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)
        empty_nn = p2.new_empty((batch, p1_count, 0, dim)) if return_nn else None
        return KNNResult(empty_dists, empty_idx, empty_nn)

    k = min(K, p2_count)
    if k > 0:
        dists = torch.cdist(p1.float(), p2.float(), p=norm)
        if norm == 2:
            dists = dists.square()

        p2_valid = None
        if lengths2 is not None:
            p2_valid = _length_mask(lengths2, batch, p2_count, p2.device).unsqueeze(1)
            dists = dists.masked_fill(~p2_valid, torch.inf)

        dists_k, idx = torch.topk(dists, k=k, dim=-1, largest=False, sorted=return_sorted)

        if p2_valid is not None:
            # topk still returns indices for masked (inf) entries when a cloud
            # has fewer than k valid points; look the mask up at the chosen
            # indices so padded p2 points are reported as -1 instead.
            chosen_valid = torch.gather(p2_valid.expand(-1, p1_count, -1), 2, idx)
            idx = idx.masked_fill(~chosen_valid, -1)
    else:
        # p2 has no points at all: start from zero-width results and let the
        # padding below bring them up to width K.
        dists_k = torch.empty((batch, p1_count, 0), dtype=torch.float32, device=p1.device)
        idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)

    if K > k:
        pad = K - k
        dists_k = torch.cat([dists_k, dists_k.new_full((batch, p1_count, pad), torch.inf)], dim=-1)
        idx = torch.cat([idx, idx.new_full((batch, p1_count, pad), -1)], dim=-1)

    if lengths1 is not None:
        p1_valid = _length_mask(lengths1, batch, p1_count, p1.device).unsqueeze(2)
        dists_k = dists_k.masked_fill(~p1_valid, torch.inf)
        idx = idx.masked_fill(~p1_valid, -1)

    knn = _gather_neighbors(p2, idx) if return_nn else None
    return KNNResult(dists_k.to(dtype=p1.dtype), idx, knn)


def ball_query(p1, p2, lengths1=None, lengths2=None, K=1, radius=0.2, return_nn=False):
    """Find, for every point in p1, up to K points of p2 within ``radius``.

    Candidates are the p2 points whose squared Euclidean distance is
    <= radius**2; among them the K closest are returned, nearest first.

    Args:
        p1: Tensor of shape (N, P1, D) holding the query points.
        p2: Tensor of shape (N, P2, D) holding the candidate points.
        lengths1: Optional tensor with N entries; query rows at positions
            >= lengths1[b] are treated as padding.
        lengths2: Optional tensor with N entries; p2 points at positions
            >= lengths2[b] are never selected.
        K: Maximum number of neighbours per query (converted with ``int``;
            negative values behave like 0).
        radius: Ball radius. It is squared before comparison, so its sign
            does not matter.
        return_nn: When True, also gather the neighbour coordinates.

    Returns:
        BallQueryResult(dists, idx, knn). ``dists`` (squared distances) and
        ``idx`` have shape (N, P1, K). Slots without a neighbour hold ``0``
        in ``dists`` and ``-1`` in ``idx``. ``knn`` has shape (N, P1, K, D)
        with zeros in those slots, or is None when ``return_nn`` is False.

    Raises:
        ValueError: If the inputs are not 3-D or their batch / point
            dimensions differ.
    """
    _check_point_clouds(p1, p2)
    batch, p1_count, dim = p1.shape
    p2_count = p2.shape[1]

    k = max(int(K), 0)
    if k == 0:
        empty_dists = p1.new_empty((batch, p1_count, 0))
        empty_idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)
        empty_nn = p2.new_empty((batch, p1_count, 0, dim)) if return_nn else None
        return BallQueryResult(empty_dists, empty_idx, empty_nn)

    take = min(k, p2_count)
    if take > 0:
        dists = torch.cdist(p1.float(), p2.float(), p=2).square()
        max_dist = float(radius) * float(radius)
        in_ball = dists <= max_dist
        if lengths2 is not None:
            in_ball = in_ball & _length_mask(lengths2, batch, p2_count, p2.device).unsqueeze(1)

        # Push everything outside the ball (or in the padding) to inf so that
        # topk prefers real candidates, then turn leftover inf slots into
        # the (-1, 0) "no neighbour" markers.
        masked = dists.masked_fill(~in_ball, torch.inf)
        dists_k, idx = torch.topk(masked, k=take, dim=-1, largest=False, sorted=True)
        missing = torch.isinf(dists_k)
        idx = idx.masked_fill(missing, -1)
        dists_k = dists_k.masked_fill(missing, 0)
    else:
        # p2 has no points: zero-width results, padded to width k below.
        dists_k = torch.empty((batch, p1_count, 0), dtype=torch.float32, device=p1.device)
        idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)

    if k > take:
        pad = k - take
        dists_k = torch.cat([dists_k, dists_k.new_zeros((batch, p1_count, pad))], dim=-1)
        idx = torch.cat([idx, idx.new_full((batch, p1_count, pad), -1)], dim=-1)

    if lengths1 is not None:
        p1_valid = _length_mask(lengths1, batch, p1_count, p1.device).unsqueeze(2)
        dists_k = dists_k.masked_fill(~p1_valid, 0)
        idx = idx.masked_fill(~p1_valid, -1)

    knn = _gather_neighbors(p2, idx) if return_nn else None
    return BallQueryResult(dists_k.to(dtype=p1.dtype), idx, knn)