File size: 4,470 Bytes
c008121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from collections import namedtuple

import torch


# Return type of knn_points: distances, neighbour indices into p2, and the
# gathered neighbour coordinates (None when they were not requested).
KNNResult = namedtuple("KNNResult", "dists idx knn")
# Return type of ball_query; same field layout as KNNResult.
BallQueryResult = namedtuple("BallQueryResult", "dists idx knn")


def knn_points(p1, p2, lengths1=None, lengths2=None, K=1, norm=2, return_nn=False, return_sorted=True):
    """Find the ``K`` nearest neighbours in ``p2`` of every point in ``p1``.

    Args:
        p1: Query points, shape ``(N, P1, D)``.
        p2: Reference points, shape ``(N, P2, D)``.
        lengths1: Optional per-batch count of valid query points (presumably a
            shape ``(N,)`` integer tensor). Query rows at or beyond it are
            reported as padding.
        lengths2: Optional per-batch count of valid reference points (presumably
            a shape ``(N,)`` integer tensor). Reference points at or beyond it
            are never returned as neighbours.
        K: Number of neighbours per query point. ``K <= 0`` yields empty results.
        norm: ``p`` forwarded to ``torch.cdist``. When ``norm == 2`` the returned
            distances are *squared* Euclidean distances.
        return_nn: If True, also gather the neighbour coordinates into ``knn``.
        return_sorted: If True, neighbours are ordered by increasing distance.

    Returns:
        ``KNNResult(dists, idx, knn)``:
            dists: shape ``(N, P1, K)``, cast to ``p1.dtype``.
            idx: long tensor of shape ``(N, P1, K)`` indexing into ``p2``.
            knn: shape ``(N, P1, K, D)``, or ``None`` if ``return_nn`` is False.
        Padded slots hold ``inf`` in ``dists``, ``-1`` in ``idx`` and zeros in
        ``knn``. A slot is padded when the cloud has fewer than ``K`` valid
        reference points, or when the query row is beyond ``lengths1``.

    Raises:
        ValueError: If the inputs are not 3-D or their batch/feature sizes differ.
    """
    if p1.dim() != 3 or p2.dim() != 3:
        raise ValueError("p1 and p2 must have shape (N, P, D)")
    if p1.shape[0] != p2.shape[0] or p1.shape[2] != p2.shape[2]:
        raise ValueError("p1 and p2 must have matching batch and point dimensions")

    batch, p1_count, dim = p1.shape
    p2_count = p2.shape[1]
    if K <= 0:
        empty_dists = p1.new_empty((batch, p1_count, 0))
        empty_idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)
        empty_nn = p2.new_empty((batch, p1_count, 0, dim)) if return_nn else None
        return KNNResult(empty_dists, empty_idx, empty_nn)

    # Compute in at least float32 (low-precision/integer inputs are upcast as
    # before). Do not downcast float64: an unconditional ``.float()`` silently
    # drops precision for double inputs.
    compute_dtype = torch.promote_types(torch.promote_types(p1.dtype, p2.dtype), torch.float32)
    k = min(K, p2_count)

    if k > 0:
        dists = torch.cdist(p1.to(compute_dtype), p2.to(compute_dtype), p=norm)
        if norm == 2:
            dists = dists.square()
        if lengths2 is not None:
            valid_counts2 = lengths2.to(device=dists.device).view(batch, 1, 1)
            arange2 = torch.arange(p2_count, device=dists.device).view(1, 1, p2_count)
            dists = dists.masked_fill(arange2 >= valid_counts2, torch.inf)
        dists_k, idx = torch.topk(dists, k=k, dim=-1, largest=False, sorted=return_sorted)
        if lengths2 is not None:
            # If a cloud has fewer than k valid points, topk still picks some of
            # the masked padding points (distance inf). Report those slots as
            # padding rather than leaking indices of invalid points.
            idx = idx.masked_fill(idx >= valid_counts2, -1)
    else:
        # p2 holds no points at all: every one of the K slots becomes padding below.
        dists_k = torch.empty((batch, p1_count, 0), dtype=compute_dtype, device=p1.device)
        idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)

    if K > k:
        pad = K - k
        dists_k = torch.cat([dists_k, dists_k.new_full((batch, p1_count, pad), torch.inf)], dim=-1)
        idx = torch.cat([idx, idx.new_full((batch, p1_count, pad), -1)], dim=-1)

    if lengths1 is not None:
        arange1 = torch.arange(p1_count, device=p1.device).view(1, p1_count, 1)
        valid_rows = arange1 < lengths1.to(device=p1.device).view(batch, 1, 1)
        dists_k = dists_k.masked_fill(~valid_rows, torch.inf)
        idx = idx.masked_fill(~valid_rows, -1)

    knn = None
    if return_nn:
        if p2_count == 0:
            # Nothing to gather from (index 0 would be out of range); all slots are padding.
            knn = p2.new_zeros((batch, p1_count, idx.shape[-1], dim))
        else:
            # Padding indices (-1) are clamped to a valid index for gather, then zeroed.
            safe_idx = idx.clamp_min(0)
            gather_idx = safe_idx.unsqueeze(-1).expand(-1, -1, -1, dim)
            points = p2.unsqueeze(1).expand(-1, p1_count, -1, -1)
            knn = torch.gather(points, 2, gather_idx)
            knn = knn.masked_fill(idx.unsqueeze(-1) < 0, 0)

    return KNNResult(dists_k.to(dtype=p1.dtype), idx, knn)


def ball_query(p1, p2, lengths1=None, lengths2=None, K=1, radius=0.2, return_nn=False):
    """Return up to ``K`` points of ``p2`` within ``radius`` of each point in ``p1``.

    Neighbours are the in-radius reference points with the smallest squared
    Euclidean distance, ordered by increasing distance (not by index).

    Args:
        p1: Query points, shape ``(N, P1, D)``.
        p2: Reference points, shape ``(N, P2, D)``.
        lengths1: Optional per-batch count of valid query points (presumably a
            shape ``(N,)`` integer tensor). Query rows at or beyond it are
            reported as padding.
        lengths2: Optional per-batch count of valid reference points (presumably
            a shape ``(N,)`` integer tensor). Reference points at or beyond it
            are never returned.
        K: Maximum number of neighbours per query point (converted with ``int``).
            ``K <= 0`` yields empty results.
        radius: A point qualifies when its squared distance is ``<= radius**2``.
        return_nn: If True, also gather the neighbour coordinates into ``knn``.

    Returns:
        ``BallQueryResult(dists, idx, knn)``:
            dists: squared distances, shape ``(N, P1, K)``, cast to ``p1.dtype``.
            idx: long tensor of shape ``(N, P1, K)`` indexing into ``p2``.
            knn: shape ``(N, P1, K, D)``, or ``None`` if ``return_nn`` is False.
        Unfilled slots hold ``0`` in ``dists``, ``-1`` in ``idx`` and zeros in
        ``knn``.

    Raises:
        ValueError: If the inputs are not 3-D or their batch/feature sizes differ.
    """
    if p1.dim() != 3 or p2.dim() != 3:
        raise ValueError("p1 and p2 must have shape (N, P, D)")
    if p1.shape[0] != p2.shape[0] or p1.shape[2] != p2.shape[2]:
        raise ValueError("p1 and p2 must have matching batch and point dimensions")

    batch, p1_count, dim = p1.shape
    p2_count = p2.shape[1]
    k = max(int(K), 0)
    if k == 0:
        empty_dists = p1.new_empty((batch, p1_count, 0))
        empty_idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)
        empty_nn = p2.new_empty((batch, p1_count, 0, dim)) if return_nn else None
        return BallQueryResult(empty_dists, empty_idx, empty_nn)

    # Compute in at least float32, but keep float64 inputs in float64 so the
    # radius test and the returned distances do not lose precision.
    compute_dtype = torch.promote_types(torch.promote_types(p1.dtype, p2.dtype), torch.float32)
    take = min(k, p2_count)

    if take > 0:
        dists = torch.cdist(p1.to(compute_dtype), p2.to(compute_dtype), p=2).square()
        max_dist = float(radius) * float(radius)
        valid = dists <= max_dist
        if lengths2 is not None:
            arange2 = torch.arange(p2_count, device=dists.device).view(1, 1, p2_count)
            valid = valid & (arange2 < lengths2.to(device=dists.device).view(batch, 1, 1))

        # Out-of-radius / padding points get inf so topk ranks them last; any inf
        # that still lands in the top-k marks an unfilled slot.
        masked = dists.masked_fill(~valid, torch.inf)
        dists_k, idx = torch.topk(masked, k=take, dim=-1, largest=False, sorted=True)
        invalid = torch.isinf(dists_k)
        idx = idx.masked_fill(invalid, -1)
        dists_k = dists_k.masked_fill(invalid, 0)
    else:
        # p2 holds no points at all: every one of the k slots becomes padding below.
        dists_k = torch.empty((batch, p1_count, 0), dtype=compute_dtype, device=p1.device)
        idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)

    if k > take:
        pad = k - take
        dists_k = torch.cat([dists_k, dists_k.new_zeros((batch, p1_count, pad))], dim=-1)
        idx = torch.cat([idx, idx.new_full((batch, p1_count, pad), -1)], dim=-1)

    if lengths1 is not None:
        arange1 = torch.arange(p1_count, device=p1.device).view(1, p1_count, 1)
        valid_p1 = arange1 < lengths1.to(device=p1.device).view(batch, 1, 1)
        dists_k = dists_k.masked_fill(~valid_p1, 0)
        idx = idx.masked_fill(~valid_p1, -1)

    knn = None
    if return_nn:
        if p2_count == 0:
            # gather would index 0 into an empty dimension and raise; all slots are padding.
            knn = p2.new_zeros((batch, p1_count, idx.shape[-1], dim))
        else:
            # Padding indices (-1) are clamped to a valid index for gather, then zeroed.
            safe_idx = idx.clamp_min(0)
            gather_idx = safe_idx.unsqueeze(-1).expand(-1, -1, -1, dim)
            points = p2.unsqueeze(1).expand(-1, p1_count, -1, -1)
            knn = torch.gather(points, 2, gather_idx)
            knn = knn.masked_fill(idx.unsqueeze(-1) < 0, 0)

    return BallQueryResult(dists_k.to(dtype=p1.dtype), idx, knn)