English
File size: 11,296 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
import time
import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.linear_model import RANSACRegressor
from torch_scatter import scatter_min
from src.utils.partition import xy_partition
from src.utils.point import is_xyz_tensor
from src.utils.neighbors import knn_2



# Public API of this ground-modeling module: three point filters and
# three elevation-model factories (each returning a predict callable)
__all__ = [
    'filter_by_z_distance_of_global_min', 'filter_by_local_z_min',
    'filter_by_verticality', 'single_plane_model',
    'neighbor_interpolation_model', 'mlp_model']


def filter_by_z_distance_of_global_min(pos, threshold):
    """Select the points lying within `threshold` along Z of the lowest
    point of the input cloud.

    This can be used to filter out points far from the ground, with some
    limitations:
    - if the point cloud contains below-ground points
    - if the ground is not even and involves stairs, slopes, ...

    :param pos: Tensor
        Input 3D point cloud
    :param threshold: float
        Z-distance threshold. Points whose Z-offset above the lowest
        point (i.e. smallest Z) is strictly below `threshold` are
        selected
    :return:
    """
    assert is_xyz_tensor(pos)
    z = pos[:, 2]
    return (z - z.min()) < threshold


def filter_by_local_z_min(pos, grid):
    """Flag the lowest point inside each cell of a horizontal XY grid.

    This can be used to filter out points far from the ground, with some
    limitations:
    - if the point cloud contains below-ground points
    - if the ground has slopes, the size of the `grid` may produce
    downstream staircasing effects if the local Z-min points are used as
    Z reference for local ground altitude

    :param pos: Tensor
        Input 3D point cloud
    :param grid: float
        Size of the grid "XY-voxel"
    :return:
    """
    assert is_xyz_tensor(pos)

    # Assign each point to a cell of the horizontal grid
    cell_index = xy_partition(pos, grid, consecutive=True)

    # Locate, for each cell, the index of its lowest point
    _, argmin_per_cell = scatter_min(pos[:, 2], cell_index, dim=0)

    # Build a boolean mask flagging only those per-cell lowest points
    mask = torch.zeros(pos.shape[0], dtype=torch.bool, device=pos.device)
    mask[argmin_per_cell] = True

    return mask


def filter_by_verticality(verticality, threshold):
    """Search for the points with low verticality.

    For verticality computation, see the `PointFeatures`.

    This can be used to filter out non-ground points, with some
    limitations:
    - if the point cloud is very noisy, or if the verticality
      was computed on too-small, or too-large neighborhoods, the
      verticality may not be sufficiently discriminative
    - if the ground has slopes, the steepest areas may be filtered out
    - if other non-ground horizontal surfaces are present in the point
      cloud, these will also be preserved (e.g. table, ceiling,
      horizontal building roof, ...)

    :param verticality: Tensor
        1D tensor holding verticality values as computed by
        `PointFeatures`
    :param threshold: float
        Verticality threshod below which points are considered
        "horizontal" enough
    :return:
    """
    return verticality.squeeze() < threshold


def single_plane_model(pos, random_state=0, residual_threshold=1e-3):
    """Model the ground as a single plane using RANSAC.

    Returns a callable taking an XYZ tensor as input and returning the
    pointwise elevation.

    :param pos: Tensor
        Input 3D point cloud
    :param random_state: int
        Seed for RANSAC
    :param residual_threshold: float
        Residual threshold for RANSAC
    :return:
    """
    assert is_xyz_tensor(pos)

    # Fit a robust plane z = f(x, y) on the input cloud with RANSAC
    xy_np = pos[:, :2].cpu().numpy()
    z_np = pos[:, 2].cpu().numpy()
    regressor = RANSACRegressor(
        random_state=random_state,
        residual_threshold=residual_threshold)
    regressor.fit(xy_np, z_np)

    def predict_elevation(pos_query):
        """Elevation = query Z minus the fitted plane height at (x, y)."""
        assert is_xyz_tensor(pos_query)
        device = pos_query.device
        z_plane = regressor.predict(pos_query[:, :2].cpu().numpy())
        return pos_query[:, 2] - torch.from_numpy(z_plane).to(device)

    return predict_elevation


def neighbor_interpolation_model(pos, k=3, r_max=1):
    """Model the ground based on a trimmed point cloud carrying ground
    points only. At inference, a point is associated with its nearest
    neighbors in L2 XY distance in the reference ground cloud. The
    ground surface is estimated as a linear interpolation of the
    neighboring reference points. The elevation is then computed as the
    corresponding gap in Z-coordinates.

    Returns a callable taking an XYZ tensor as input and returning the
    pointwise elevation.

    :param pos: Tensor
        Input 3D point cloud
    :param k: int
        Number of neighbors to consider for interpolation
    :param r_max: float
        Maximum radius for the neighbor search
    :return:
    """
    assert is_xyz_tensor(pos)

    def predict_elevation(pos_query):
        # Neighbor search in XY space
        # Zero-pad the XY coordinates to 3D (x, y, 0) so the 3D search
        # effectively measures L2 distance in the XY plane only
        xy0 = F.pad(input=pos[:, :2], pad=(0, 1), mode='constant', value=0)
        xy0_query = F.pad(
            input=pos_query[:, :2], pad=(0, 1), mode='constant', value=0)
        # NOTE(review): assumes `knn_2` returns per-query neighbor
        # indices and distances, with -1 marking missing entries —
        # confirm against `src.utils.neighbors.knn_2`
        neighbors, distances = knn_2(xy0, xy0_query, k, r_max=r_max)

        # In case some points received 0 neighbors, we search again for
        # those, with a radius so large that no point should be left
        # without a neighbor
        has_no_neighbor = (neighbors == -1).all(dim=1)
        if has_no_neighbor.any():
            # Diagonal of the bounding box spanning both clouds — an
            # upper bound on any query-to-reference distance
            high = xy0.max(dim=0).values
            low = xy0.min(dim=0).values
            high_query = xy0_query.max(dim=0).values
            low_query = xy0_query.min(dim=0).values
            r_max_ = max((high_query - low).norm(), (high - low_query).norm())

            neighbors_, distances_ = knn_2(
                xy0, xy0_query[has_no_neighbor], k, r_max=r_max_)

            neighbors[has_no_neighbor] = neighbors_
            distances[has_no_neighbor] = distances_

        # If only 1 neighbor is needed, no need for interpolation
        # NOTE(review): if `neighbors` has shape (N, 1) here, then
        # `pos[neighbors]` is (N, 1, 3) and `[:, 2]` indexes a size-1
        # dimension with index 2 — this relies on knn_2 squeezing its
        # output for k=1; confirm, else this branch raises
        if k == 1:
            return pos_query[:, 2] - pos[neighbors][:, 2]

        # Note there might still be some missing neighbors here and
        # there, but no completely empty neighborhood. We treat these
        # missing neighbors by attributing a 0-weight
        # Inverse-distance weights; the 1e-3 epsilon avoids division by
        # zero for exact-match neighbors
        weights = 1 / (distances + 1e-3)
        weights[distances == -1] = 0
        weights = weights / weights.sum(dim=1).view(-1, 1)

        # Estimate the ground height as the weighted combination of the
        # neighbors' height
        z_query = (pos[:, 2][neighbors] * weights).sum(dim=1)

        return pos_query[:, 2] - z_query

    return predict_elevation


def mlp_model(
        pos,
        layers=(32, 16, 8),
        batch_ratio=1,
        lr=0.01,
        lr_decay=1,
        weight_decay=0.01,
        criterion='l2',
        steps=1000,
        check_every_n_steps=50,
        device='cuda',
        verbose=False):
    """Fit an MLP to a point cloud. Assuming the point cloud mostly
    contains ground points, this function will train an MLP to model the
    ground surface as a piecewise-planar function.

    :param pos: Tensor
        Input 3D point cloud
    :param layers: int or Sequence[int]
        Hidden layers for the MLP. Too many weights may let the model
        overfit to non-ground patterns. Not enough weights will underfit
        the ground and miss some patterns. Having more neurons in the
        first layer allows faster convergence
    :param batch_ratio: float
        Ratio of points to sample from the cloud at each training
        iteration. Allows adding some stochasticity to the training.
        In practice, `batch_ratio=1` gives better results if the entire
        cloud fits in memory
    :param lr: float
        Initial learning rate
    :param lr_decay: float
        Multiplicative factor applied to the learning rate after each
        iteration
    :param weight_decay: float
        Weight decay for regularization
    :param criterion: str
        Loss, either 'l1' or 'l2'
    :param steps: int
        Number of training steps. This largely affects overall compute
        time
    :param check_every_n_steps: int
        If `verbose=True` the loss will be logged every n iteration for
        final visualization
    :param device: str or torch.device
        Device on which to do the training and inference
    :param verbose: bool
        If True, a plot of the training loss and some stats will be
        printed at the end of the training
    :return:
    """
    # Local imports to avoid import loop errors
    from src.nn import MLP
    from src.nn.norm import BatchNorm

    assert is_xyz_tensor(pos)

    # Only synchronize when CUDA is available, so the timing code does
    # not crash on CPU-only machines or when device='cpu'
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()

    # Normalize the XYZ coordinates to live in a manageable range
    pos = pos.to(device)
    num_points = pos.shape[0]
    means = pos.mean(dim=0)
    stds = pos.std(dim=0)
    pos = (pos - means) / stds

    # Prepare the training. Copy `layers` to a fresh list so a
    # caller-provided sequence is never aliased
    batch_size = min(num_points, round(num_points * batch_ratio))
    layers = [layers] if isinstance(layers, int) else list(layers)
    model = MLP(
        [2] + layers + [1],
        activation=nn.ReLU(),
        last_activation=False,
        norm=BatchNorm,
        last_norm=False,
        drop=None).to(device).train()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay)
    # Uniform sampling weights for the optional stochastic batching
    weights = torch.ones(num_points, device=device)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, lr_decay)

    # For drawing the loss plot for debugging purposes
    if verbose:
        loss_log = []
        step_log = []

    # Training loop
    for step in range(steps):
        # Optionally, randomly drop some data points for augmentation
        if 0 < batch_ratio < 1:
            idx = torch.multinomial(weights, batch_size, replacement=False)
            pos_ = pos[idx]
        else:
            pos_ = pos

        # Forward pass: predict Z from the normalized XY coordinates
        xy = pos_[:, :2]
        z = pos_[:, 2]
        z_hat = model(xy)

        # Loss computation
        if criterion == 'l2':
            loss = ((z - z_hat.squeeze()) ** 2).mean()
        elif criterion == 'l1':
            loss = (z - z_hat.squeeze()).abs().mean()
        else:
            raise NotImplementedError(
                f"Unknown criterion '{criterion}', expected 'l1' or 'l2'")

        # Gradient computation and backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Periodically keep track of the loss for the final plot
        if verbose and step % check_every_n_steps == check_every_n_steps - 1:
            step_log.append(step)
            loss_log.append(loss.detach().cpu().item())

    if verbose:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print(f"Training time: {time.time() - start:0.1f} sec")
        # Guard against an empty log when steps < check_every_n_steps
        if loss_log:
            print(f"Loss: {loss_log[-1]:0.3f}")
            plt.plot(step_log, loss_log)
            plt.show()

    # Training is finished, set the model to inference mode
    model = model.eval()

    def predict_elevation(pos):
        """Return the elevation of `pos` w.r.t. the learned ground
        surface, on the same device as the input."""
        input_device = pos.device
        pos = pos.to(device)

        # Normalize XY with the training statistics before inference
        xy = (pos[:, :2] - means[:2]) / stds[:2]
        z = model(xy).squeeze().detach()
        # De-normalize the predicted ground height back to world Z
        z = z * stds[2] + means[2]

        elevation = pos[:, 2] - z

        return elevation.to(input_device)

    return predict_elevation