File size: 4,470 Bytes
c008121 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 | from collections import namedtuple
import torch
# Result records returned by the neighbour searches in this module. Both carry
# the same three fields:
#   dists - per-neighbour distances,
#   idx   - indices into ``p2`` (-1 marks a slot with no neighbour),
#   knn   - gathered neighbour coordinates, or None when not requested.
_RESULT_FIELDS = "dists idx knn"
KNNResult = namedtuple("KNNResult", _RESULT_FIELDS)
BallQueryResult = namedtuple("BallQueryResult", _RESULT_FIELDS)
def knn_points(p1, p2, lengths1=None, lengths2=None, K=1, norm=2, return_nn=False, return_sorted=True):
    """Find the ``K`` nearest points of ``p2`` for every point of ``p1``.

    Distances are computed with ``torch.cdist`` in float32 and cast back to
    ``p1.dtype`` on return. For ``norm == 2`` the returned distances are
    squared; for any other ``norm`` the ``cdist`` value is returned as is.

    Slots that have no neighbour are padded with distance ``inf``, index
    ``-1`` and (when ``return_nn`` is set) an all-zero point. A slot has no
    neighbour when ``K`` exceeds the number of points in ``p2``, when
    ``lengths2`` leaves fewer than ``K`` valid points in a batch element, or
    when the query row lies beyond ``lengths1``.

    Args:
        p1: Query points, shape (N, P1, D).
        p2: Reference points, shape (N, P2, D).
        lengths1: Optional tensor with N entries giving the number of valid
            rows of ``p1`` per batch element.
        lengths2: Optional tensor with N entries giving the number of valid
            rows of ``p2`` per batch element.
        K: Number of neighbours to return per query point.
        norm: Value forwarded as ``p`` to ``torch.cdist``.
        return_nn: When true, also gather the neighbour coordinates.
        return_sorted: Forwarded to ``torch.topk`` as ``sorted``.

    Returns:
        ``KNNResult(dists, idx, knn)`` where ``dists`` and ``idx`` have shape
        (N, P1, K) and ``knn`` has shape (N, P1, K, D), or is ``None`` when
        ``return_nn`` is false. For ``K <= 0`` the K axis has length 0.

    Raises:
        ValueError: If the inputs are not 3-D or their batch / point
            dimensions disagree.
    """
    if p1.dim() != 3 or p2.dim() != 3:
        raise ValueError("p1 and p2 must have shape (N, P, D)")
    if p1.shape[0] != p2.shape[0] or p1.shape[2] != p2.shape[2]:
        raise ValueError("p1 and p2 must have matching batch and point dimensions")

    batch, p1_count, point_dim = p1.shape
    p2_count = p2.shape[1]

    if K <= 0:
        empty_dists = p1.new_empty((batch, p1_count, 0))
        empty_idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)
        empty_nn = p2.new_empty((batch, p1_count, 0, point_dim)) if return_nn else None
        return KNNResult(empty_dists, empty_idx, empty_nn)

    k = min(K, p2_count)
    if k == 0:
        # p2 holds no points at all: every one of the K slots is padding. This
        # keeps the output width at K, exactly as in the 0 < P2 < K case below.
        pad_dists = torch.full((batch, p1_count, K), torch.inf, dtype=torch.float32, device=p1.device)
        pad_idx = torch.full((batch, p1_count, K), -1, dtype=torch.long, device=p1.device)
        pad_nn = p2.new_zeros((batch, p1_count, K, point_dim)) if return_nn else None
        return KNNResult(pad_dists.to(dtype=p1.dtype), pad_idx, pad_nn)

    dists = torch.cdist(p1.float(), p2.float(), p=norm)
    if norm == 2:
        dists = dists.square()

    p2_valid = None
    if lengths2 is not None:
        arange = torch.arange(p2_count, device=p2.device).view(1, 1, p2_count)
        p2_valid = arange < lengths2.to(device=p2.device).view(batch, 1, 1)
        # Padded p2 rows get an infinite distance so they are selected last.
        dists = dists.masked_fill(~p2_valid, torch.inf)

    dists_k, idx = torch.topk(dists, k=k, dim=-1, largest=False, sorted=return_sorted)

    if p2_valid is not None:
        # topk still returns padded rows when fewer than k valid rows exist;
        # look the mask up at the chosen indices and flag those slots as empty
        # instead of letting idx point at padding.
        chosen_valid = torch.gather(p2_valid.expand(batch, p1_count, p2_count), 2, idx)
        idx = idx.masked_fill(~chosen_valid, -1)

    if K > k:
        pad = K - k
        dists_k = torch.cat([dists_k, dists_k.new_full((batch, p1_count, pad), torch.inf)], dim=-1)
        idx = torch.cat([idx, idx.new_full((batch, p1_count, pad), -1)], dim=-1)

    if lengths1 is not None:
        arange = torch.arange(p1_count, device=p1.device).view(1, p1_count, 1)
        p1_valid = arange < lengths1.to(device=p1.device).view(batch, 1, 1)
        dists_k = dists_k.masked_fill(~p1_valid, torch.inf)
        idx = idx.masked_fill(~p1_valid, -1)

    knn = None
    if return_nn:
        # gather cannot take -1, so read row 0 for empty slots and zero the
        # gathered point afterwards.
        safe_idx = idx.clamp_min(0)
        gather_idx = safe_idx.unsqueeze(-1).expand(-1, -1, -1, point_dim)
        points = p2.unsqueeze(1).expand(-1, p1_count, -1, -1)
        knn = torch.gather(points, 2, gather_idx)
        knn = knn.masked_fill(idx.unsqueeze(-1) < 0, 0)

    return KNNResult(dists_k.to(dtype=p1.dtype), idx, knn)
def ball_query(p1, p2, lengths1=None, lengths2=None, K=1, radius=0.2, return_nn=False):
    """Return up to ``K`` points of ``p2`` within ``radius`` of each ``p1`` point.

    Squared Euclidean distances are computed with ``torch.cdist`` in float32
    and compared against ``radius * radius`` (so the sign of ``radius`` is
    ignored). Among the points inside the ball, the ``K`` closest are kept,
    ordered by increasing distance. Distances are cast back to ``p1.dtype``
    on return.

    Slots with no neighbour are padded with distance ``0``, index ``-1`` and
    (when ``return_nn`` is set) an all-zero point. This covers fewer than
    ``K`` points inside the ball, rows of ``p2`` beyond ``lengths2``, and
    query rows beyond ``lengths1``.

    Args:
        p1: Query points, shape (N, P1, D).
        p2: Reference points, shape (N, P2, D).
        lengths1: Optional tensor with N entries giving the number of valid
            rows of ``p1`` per batch element.
        lengths2: Optional tensor with N entries giving the number of valid
            rows of ``p2`` per batch element.
        K: Maximum number of neighbours per query point; converted with
            ``int`` and clamped to be non-negative.
        radius: Ball radius.
        return_nn: When true, also gather the neighbour coordinates.

    Returns:
        ``BallQueryResult(dists, idx, knn)`` where ``dists`` (squared
        distances) and ``idx`` have shape (N, P1, K) and ``knn`` has shape
        (N, P1, K, D), or is ``None`` when ``return_nn`` is false.

    Raises:
        ValueError: If the inputs are not 3-D or their batch / point
            dimensions disagree.
    """
    if p1.dim() != 3 or p2.dim() != 3:
        raise ValueError("p1 and p2 must have shape (N, P, D)")
    if p1.shape[0] != p2.shape[0] or p1.shape[2] != p2.shape[2]:
        raise ValueError("p1 and p2 must have matching batch and point dimensions")

    batch, p1_count, point_dim = p1.shape
    p2_count = p2.shape[1]

    k = max(int(K), 0)
    if k == 0:
        empty_dists = p1.new_empty((batch, p1_count, 0))
        empty_idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)
        empty_nn = p2.new_empty((batch, p1_count, 0, point_dim)) if return_nn else None
        return BallQueryResult(empty_dists, empty_idx, empty_nn)

    if p2_count == 0:
        # Nothing to search: every slot is padding. Returning here also keeps
        # the return_nn gather below from indexing into a zero-length p2 axis.
        pad_dists = p1.new_zeros((batch, p1_count, k))
        pad_idx = torch.full((batch, p1_count, k), -1, dtype=torch.long, device=p1.device)
        pad_nn = p2.new_zeros((batch, p1_count, k, point_dim)) if return_nn else None
        return BallQueryResult(pad_dists, pad_idx, pad_nn)

    dists = torch.cdist(p1.float(), p2.float(), p=2).square()
    max_dist = float(radius) * float(radius)
    valid = dists <= max_dist
    if lengths2 is not None:
        arange = torch.arange(p2_count, device=p2.device).view(1, 1, p2_count)
        valid = valid & (arange < lengths2.to(device=p2.device).view(batch, 1, 1))

    # Points outside the ball (or beyond lengths2) get an infinite distance so
    # topk ranks them last and they can be recognised afterwards.
    masked = dists.masked_fill(~valid, torch.inf)
    take = min(k, p2_count)
    dists_k, idx = torch.topk(masked, k=take, dim=-1, largest=False, sorted=True)
    invalid = torch.isinf(dists_k)
    idx = idx.masked_fill(invalid, -1)
    dists_k = dists_k.masked_fill(invalid, 0)

    if k > take:
        pad = k - take
        dists_k = torch.cat([dists_k, dists_k.new_zeros((batch, p1_count, pad))], dim=-1)
        idx = torch.cat([idx, idx.new_full((batch, p1_count, pad), -1)], dim=-1)

    if lengths1 is not None:
        arange = torch.arange(p1_count, device=p1.device).view(1, p1_count, 1)
        valid_p1 = arange < lengths1.to(device=p1.device).view(batch, 1, 1)
        dists_k = dists_k.masked_fill(~valid_p1, 0)
        idx = idx.masked_fill(~valid_p1, -1)

    knn = None
    if return_nn:
        # gather cannot take -1, so read row 0 for empty slots and zero the
        # gathered point afterwards.
        safe_idx = idx.clamp_min(0)
        gather_idx = safe_idx.unsqueeze(-1).expand(-1, -1, -1, point_dim)
        points = p2.unsqueeze(1).expand(-1, p1_count, -1, -1)
        knn = torch.gather(points, 2, gather_idx)
        knn = knn.masked_fill(idx.unsqueeze(-1) < 0, 0)

    return BallQueryResult(dists_k.to(dtype=p1.dtype), idx, knn)
|