| from collections import namedtuple |
|
|
| import torch |
|
|
|
|
# Return container of knn_points: (dists, idx, knn). knn is None unless the
# caller asked for neighbour coordinates.
KNNResult = namedtuple("KNNResult", ("dists", "idx", "knn"))
# Return container of ball_query; same field layout as KNNResult.
BallQueryResult = namedtuple("BallQueryResult", ("dists", "idx", "knn"))
|
|
|
|
def knn_points(p1, p2, lengths1=None, lengths2=None, K=1, norm=2, return_nn=False, return_sorted=True):
    """Find the K nearest points of ``p2`` for every point of ``p1``.

    Distances are computed densely with ``torch.cdist``, which materialises an
    ``(N, P1, P2)`` distance tensor.

    Args:
        p1: Query points of shape ``(N, P1, D)``.
        p2: Reference points of shape ``(N, P2, D)``.
        lengths1: Optional ``(N,)`` tensor with the number of valid points in
            each ``p1`` batch entry; rows at or beyond it are returned as
            padding.
        lengths2: Optional ``(N,)`` tensor with the number of valid points in
            each ``p2`` batch entry; points at or beyond it are never returned
            as neighbours.
        K: Number of neighbours per query. ``K <= 0`` yields an empty last
            dimension.
        norm: ``p`` passed to ``torch.cdist``. For ``norm == 2`` the returned
            distances are *squared* Euclidean; otherwise they are the raw
            p-norm distances.
        return_nn: If True, also gather the neighbour coordinates.
        return_sorted: Passed to ``torch.topk``; if True, neighbours are
            ordered by increasing distance.

    Returns:
        KNNResult(dists, idx, knn):
            dists: ``(N, P1, K)`` distances cast to ``p1.dtype``; ``inf`` in
                slots without a valid neighbour.
            idx: ``(N, P1, K)`` long indices into ``p2``; ``-1`` in slots
                without a valid neighbour (fewer than K valid p2 points, or a
                padded query row).
            knn: ``(N, P1, K, D)`` neighbour coordinates (zeros where ``idx``
                is -1) if ``return_nn`` else None.

    Raises:
        ValueError: if the inputs are not 3-D or their batch / feature sizes
            differ.
    """
    if p1.dim() != 3 or p2.dim() != 3:
        raise ValueError("p1 and p2 must have shape (N, P, D)")
    if p1.shape[0] != p2.shape[0] or p1.shape[2] != p2.shape[2]:
        raise ValueError("p1 and p2 must have matching batch and point dimensions")

    batch, p1_count, dim = p1.shape
    p2_count = p2.shape[1]
    K = int(K)
    if K <= 0:
        empty_dists = p1.new_empty((batch, p1_count, 0))
        empty_idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)
        empty_nn = p2.new_empty((batch, p1_count, 0, dim)) if return_nn else None
        return KNNResult(empty_dists, empty_idx, empty_nn)

    # Upcast half-precision / non-float inputs to float32, but keep float64
    # inputs in float64 instead of silently losing precision.
    compute_dtype = torch.promote_types(p1.dtype, p2.dtype)
    if compute_dtype not in (torch.float32, torch.float64):
        compute_dtype = torch.float32

    k = min(K, p2_count)
    if k > 0:
        dists = torch.cdist(p1.to(compute_dtype), p2.to(compute_dtype), p=norm)
        if norm == 2:
            dists = dists.square()

        if lengths2 is not None:
            lengths2 = lengths2.to(device=p2.device).view(batch, 1, 1)
            arange = torch.arange(p2_count, device=p2.device).view(1, 1, p2_count)
            dists = dists.masked_fill(arange >= lengths2, torch.inf)

        dists_k, idx = torch.topk(dists, k=k, dim=-1, largest=False, sorted=return_sorted)

        if lengths2 is not None:
            # When a batch entry has fewer than k valid points, topk still
            # returns padded p2 points (at inf distance); mark them missing so
            # they follow the same -1 convention as the K > P2 padding below.
            idx = idx.masked_fill(idx >= lengths2, -1)
    else:
        # p2 has no points at all: every slot is filled by the padding below.
        dists_k = torch.empty((batch, p1_count, 0), dtype=compute_dtype, device=p1.device)
        idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)

    # Pad up to K when p2 has fewer than K points.
    if K > k:
        pad = K - k
        dists_k = torch.cat([dists_k, dists_k.new_full((batch, p1_count, pad), torch.inf)], dim=-1)
        idx = torch.cat([idx, idx.new_full((batch, p1_count, pad), -1)], dim=-1)

    # Query rows beyond lengths1 are padding: no neighbours.
    if lengths1 is not None:
        arange = torch.arange(p1_count, device=p1.device).view(1, p1_count, 1)
        valid = arange < lengths1.to(device=p1.device).view(batch, 1, 1)
        dists_k = dists_k.masked_fill(~valid, torch.inf)
        idx = idx.masked_fill(~valid, -1)

    knn = None
    if return_nn:
        if p2_count == 0:
            # Nothing to gather from; every slot is padding.
            knn = p2.new_zeros((batch, p1_count, K, dim))
        else:
            # Clamp -1 to a valid index for gather, then zero those rows.
            safe_idx = idx.clamp_min(0)
            gather_idx = safe_idx.unsqueeze(-1).expand(-1, -1, -1, dim)
            points = p2.unsqueeze(1).expand(-1, p1_count, -1, -1)
            knn = torch.gather(points, 2, gather_idx)
            knn = knn.masked_fill(idx.unsqueeze(-1) < 0, 0)

    return KNNResult(dists_k.to(dtype=p1.dtype), idx, knn)
|
|
|
|
def ball_query(p1, p2, lengths1=None, lengths2=None, K=1, radius=0.2, return_nn=False):
    """Find up to K points of ``p2`` within ``radius`` of each point of ``p1``.

    The returned neighbours are the (up to) K nearest in-radius points, sorted
    by increasing distance. The search is a dense ``torch.cdist`` followed by
    ``torch.topk``, so an ``(N, P1, P2)`` distance tensor is materialised.

    Args:
        p1: Query points of shape ``(N, P1, D)``.
        p2: Reference points of shape ``(N, P2, D)``.
        lengths1: Optional ``(N,)`` count of valid query points per batch
            entry; rows at or beyond it are returned as padding.
        lengths2: Optional ``(N,)`` count of valid reference points per batch
            entry; points at or beyond it are never returned.
        K: Maximum number of neighbours per query; negative values are
            treated as 0, which yields an empty last dimension.
        radius: Search radius. A point is inside when its squared Euclidean
            distance is ``<= radius * radius`` (boundary inclusive).
        return_nn: If True, also gather the neighbour coordinates.

    Returns:
        BallQueryResult(dists, idx, knn):
            dists: ``(N, P1, K)`` squared distances cast to ``p1.dtype``; 0 in
                empty slots.
            idx: ``(N, P1, K)`` long indices into ``p2``; -1 in empty slots.
            knn: ``(N, P1, K, D)`` neighbour coordinates (zeros in empty
                slots) if ``return_nn`` else None.

    Raises:
        ValueError: if the inputs are not 3-D or their batch / feature sizes
            differ.
    """
    if p1.dim() != 3 or p2.dim() != 3:
        raise ValueError("p1 and p2 must have shape (N, P, D)")
    if p1.shape[0] != p2.shape[0] or p1.shape[2] != p2.shape[2]:
        raise ValueError("p1 and p2 must have matching batch and point dimensions")

    batch, p1_count, dim = p1.shape
    p2_count = p2.shape[1]
    k = max(int(K), 0)
    if k == 0:
        empty_dists = p1.new_empty((batch, p1_count, 0))
        empty_idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)
        empty_nn = p2.new_empty((batch, p1_count, 0, dim)) if return_nn else None
        return BallQueryResult(empty_dists, empty_idx, empty_nn)

    # Upcast half-precision / non-float inputs to float32, but keep float64
    # inputs in float64 so radius membership is not decided at lower precision.
    compute_dtype = torch.promote_types(p1.dtype, p2.dtype)
    if compute_dtype not in (torch.float32, torch.float64):
        compute_dtype = torch.float32

    take = min(k, p2_count)
    if take > 0:
        dists = torch.cdist(p1.to(compute_dtype), p2.to(compute_dtype), p=2).square()
        max_dist = float(radius) * float(radius)
        valid = dists <= max_dist

        if lengths2 is not None:
            arange = torch.arange(p2_count, device=p2.device).view(1, 1, p2_count)
            valid = valid & (arange < lengths2.to(device=p2.device).view(batch, 1, 1))

        # Push out-of-radius / padded points to the end of the ascending topk.
        masked = dists.masked_fill(~valid, torch.inf)
        dists_k, idx = torch.topk(masked, k=take, dim=-1, largest=False, sorted=True)
        # Decide emptiness from the validity mask rather than isinf, so an
        # in-radius infinite distance (possible with radius=inf) is kept.
        invalid = ~torch.gather(valid, -1, idx)
        idx = idx.masked_fill(invalid, -1)
        dists_k = dists_k.masked_fill(invalid, 0)
    else:
        # p2 has no points at all: every slot is filled by the padding below.
        dists_k = torch.empty((batch, p1_count, 0), dtype=compute_dtype, device=p1.device)
        idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)

    # Pad up to K when p2 has fewer than K points.
    if k > take:
        pad = k - take
        dists_k = torch.cat([dists_k, dists_k.new_zeros((batch, p1_count, pad))], dim=-1)
        idx = torch.cat([idx, idx.new_full((batch, p1_count, pad), -1)], dim=-1)

    # Query rows beyond lengths1 are padding: no neighbours.
    if lengths1 is not None:
        arange = torch.arange(p1_count, device=p1.device).view(1, p1_count, 1)
        valid_p1 = arange < lengths1.to(device=p1.device).view(batch, 1, 1)
        dists_k = dists_k.masked_fill(~valid_p1, 0)
        idx = idx.masked_fill(~valid_p1, -1)

    knn = None
    if return_nn:
        if p2_count == 0:
            # Nothing to gather from; every slot is padding.
            knn = p2.new_zeros((batch, p1_count, k, dim))
        else:
            # Clamp -1 to a valid index for gather, then zero those rows.
            safe_idx = idx.clamp_min(0)
            gather_idx = safe_idx.unsqueeze(-1).expand(-1, -1, -1, dim)
            points = p2.unsqueeze(1).expand(-1, p1_count, -1, -1)
            knn = torch.gather(points, 2, gather_idx)
            knn = knn.masked_fill(idx.unsqueeze(-1) < 0, 0)

    return BallQueryResult(dists_k.to(dtype=p1.dtype), idx, knn)
|
|