File size: 9,433 Bytes
f4a0919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sparse tensor utils."""

import torch
import MinkowskiEngine as Me
import torch.nn.functional as F
from typing import Optional, Tuple, Dict


def sparse_cat_union(a: Me.SparseTensor, b: Me.SparseTensor):
    """Concatenate two sparse tensors feature-wise over the union of their coordinates.

    Every output voxel carries ``a``'s features in the leading channels and
    ``b``'s features in the trailing channels; positions present in only one
    input are zero-filled on the other side. Both inputs must share a
    coordinate manager and a tensor stride.
    """
    manager = a.coordinate_manager
    stride = a.tensor_stride
    assert manager == b.coordinate_manager, "different coords_man"
    assert a.tensor_stride == b.tensor_stride, "different tensor_stride"

    # Degenerate inputs: the union with an empty tensor is just the other one.
    if a.F.size(0) == 0 or a.F.numel() == 0:
        return b
    if b.F.size(0) == 0 or b.F.numel() == 0:
        return a

    # Zero-pad each feature matrix up to a.channels + b.channels columns —
    # a's features on the left, b's on the right. Best-effort: if padding
    # fails, fall back to the corresponding input unchanged.
    try:
        padded_a = F.pad(a.F, (0, b.F.shape[1]))
    except Exception as e:
        print("Warning: Got error in feats_a:", e)
        return a
    try:
        padded_b = F.pad(b.F, (a.F.shape[1], 0))
    except Exception as e:
        print("Warning: Got error in feats_b:", e)
        return b

    lifted_a = Me.SparseTensor(
        features=padded_a,
        coordinate_map_key=a.coordinate_key,
        coordinate_manager=manager,
        tensor_stride=stride,
    )
    lifted_b = Me.SparseTensor(
        features=padded_b,
        coordinate_map_key=b.coordinate_key,
        coordinate_manager=manager,
        tensor_stride=stride,
    )

    # Adding tensors on a shared coordinate manager yields the coordinate union.
    return lifted_a + lifted_b


def to_dense(
    tensor: Me.SparseTensor,
    shape: Optional[torch.Size] = None,
    min_coordinate: Optional[torch.IntTensor] = None,
    contract_stride: bool = True,
    default_value: float = 0.0
) -> Tuple[torch.Tensor, torch.IntTensor, torch.IntTensor]:
    """Convert the :attr:`MinkowskiEngine.SparseTensor` to a torch dense
    tensor.
    Args:
        :attr:`shape` (torch.Size, optional): The size of the output tensor.
        :attr:`min_coordinate` (torch.IntTensor, optional): The min
        coordinates of the output sparse tensor. Must be divisible by the
        current :attr:`tensor_stride`. If 0 is given, it will use the origin for the min coordinate.
        :attr:`contract_stride` (bool, optional): The output coordinates
        will be divided by the tensor stride to make features spatially
        contiguous. True by default.
        :attr:`default_value` (float, optional): The value written to dense
        cells that have no corresponding sparse voxel. 0.0 by default.
    Returns:
        :attr:`tensor` (torch.Tensor): the torch tensor with size `[Batch
        Dim, Feature Dim, Spatial Dim..., Spatial Dim]`. The coordinate of
        each feature can be accessed via `min_coordinate + tensor_stride *
        [the coordinate of the dense tensor]`.
        :attr:`min_coordinate` (torch.IntTensor): the D-dimensional vector
        defining the minimum coordinate of the output tensor.
        :attr:`tensor_stride` (torch.IntTensor): the D-dimensional vector
        defining the stride between tensor elements.
    """
    if min_coordinate is not None:
        if isinstance(min_coordinate, int):
            # Fix: the docstring promises that 0 selects the origin and a
            # dedicated branch below handles it, but the unconditional
            # IntTensor assert previously rejected any int, making that
            # branch unreachable.
            assert min_coordinate == 0, "only 0 is supported as an int min_coordinate"
        else:
            assert isinstance(min_coordinate, torch.IntTensor)
            assert min_coordinate.numel() == tensor._D
    if shape is not None:
        assert isinstance(shape, torch.Size)
        assert len(shape) == tensor._D + 2  # batch and channel
        # Silently correct a mismatched channel dimension to the tensor's own.
        if shape[1] != tensor._F.size(1):
            shape = torch.Size([shape[0], tensor._F.size(1), *[s for s in shape[2:]]])

    # Exception handling for an empty tensor: densify to a constant grid.
    if tensor.__len__() == 0:
        assert shape is not None, "shape is required to densify an empty tensor"
        return (
            # Fix: honor default_value (was hard-coded to zeros).
            torch.full(shape, default_value, dtype=tensor.dtype, device=tensor.device),
            torch.zeros(tensor._D, dtype=torch.int32, device=tensor.device),
            # Fix: return an IntTensor as on the non-empty path (was the raw
            # tensor_stride attribute, typically a plain list).
            torch.IntTensor(tensor.tensor_stride),
        )

    # Use an int tensor on the tensor's device for coordinate arithmetic.
    tensor_stride = torch.IntTensor(tensor.tensor_stride).to(tensor.device)

    # First coordinate column is the batch index; the rest are spatial dims.
    batch_indices = tensor.C[:, 0]

    if min_coordinate is None:
        # Derive the per-dimension minimum from the coordinates themselves.
        min_coordinate, _ = tensor.C.min(0, keepdim=True)
        min_coordinate = min_coordinate[:, 1:]
        if not torch.all(min_coordinate >= 0):
            raise ValueError(
                f"Coordinate has a negative value: {min_coordinate}. Please provide min_coordinate argument"
            )
        coords = tensor.C[:, 1:]
    elif isinstance(min_coordinate, int) and min_coordinate == 0:
        # Use the origin. Keep min_coordinate a tensor so the return type is
        # consistent across all branches.
        min_coordinate = torch.zeros(
            (1, tensor._D), dtype=torch.int32, device=tensor.device
        )
        coords = tensor.C[:, 1:]
    else:
        min_coordinate = min_coordinate.to(tensor.device)
        if min_coordinate.ndim == 1:
            min_coordinate = min_coordinate.unsqueeze(0)
        coords = tensor.C[:, 1:] - min_coordinate

    assert (
        min_coordinate % tensor_stride
    ).sum() == 0, "The minimum coordinates must be divisible by the tensor stride."

    if coords.ndim == 1:
        coords = coords.unsqueeze(1)

    # Contract coordinates so strided voxels become contiguous dense cells.
    if contract_stride:
        coords = torch.div(coords, tensor_stride, rounding_mode="floor")

    nchannels = tensor.F.size(1)
    if shape is None:
        # Infer the spatial extent from the largest (contracted) coordinate.
        size = coords.max(0)[0] + 1
        shape = torch.Size(
            [batch_indices.max() + 1, nchannels, *size.cpu().numpy()]
        )

    dense_F = torch.full(
        shape, dtype=tensor.F.dtype,
        device=tensor.F.device, fill_value=default_value
    )

    # Scatter features into the dense grid, indexed as [batch, :, x, y, z].
    tcoords = coords.t().long()
    batch_indices = batch_indices.long()

    indices = (batch_indices, slice(None), *tcoords)
    dense_F[indices] = tensor.F

    tensor_stride = torch.IntTensor(tensor.tensor_stride)
    return dense_F, min_coordinate, tensor_stride


def _thicken_grid(grid, grid_dims, frustum_mask):
    """Thicken grid."""
    device = frustum_mask.device
    offsets = torch.nonzero(torch.ones(3, 3, 3, device=device)).long()
    locs_grid = grid.nonzero(as_tuple=False)
    locs = locs_grid.unsqueeze(1).repeat(1, 27, 1)
    locs += offsets
    locs = locs.view(-1, 3)
    masks = ((locs >= 0) & (locs < torch.as_tensor(grid_dims, device=device))).all(-1)
    locs = locs[masks]

    thicken = torch.zeros(grid_dims, dtype=torch.bool, device=device)
    thicken[locs[:, 0], locs[:, 1], locs[:, 2]] = True
    # frustum culling
    thicken = thicken & frustum_mask

    return thicken


def prepare_instance_masks_thicken(
    instances: torch.Tensor,
    semantic_mapping: Dict[int, int],
    distance_field: torch.Tensor,
    frustum_mask: torch.Tensor,
    iso_value: float = 1.0,
    truncation: float = 3.0,
    downsample_factor: int = 1
) -> Dict[int, Tuple[torch.Tensor, int]]:
    """Build a thickened occupancy grid for every instance id.

    For each ``(instance_id, semantic_class)`` pair in ``semantic_mapping``,
    the instance's distance field is thresholded at ``iso_value``, optionally
    downsampled, dilated via ``_thicken_grid`` and moved to the CPU.

    Returns:
        Dict mapping instance_id to ``(thickened_grid, semantic_class)``.
    """
    # The downsample factor must evenly divide the fixed 256^3 grid.
    assert isinstance(downsample_factor, int) and 256 % downsample_factor == 0
    dims = [256, 256, 256]
    rescale = downsample_factor != 1
    if rescale:
        dims = (torch.as_tensor(dims) // downsample_factor).tolist()
        # Nearest-neighbor resize keeps the mask strictly boolean.
        frustum_mask = F.interpolate(frustum_mask[None, None].float(),
                                     size=dims, mode="nearest").squeeze(0, 1).bool()

    results: Dict[int, Tuple[torch.Tensor, int]] = {}

    for inst_id, sem_class in semantic_mapping.items():
        selected: torch.Tensor = (instances == inst_id)
        # Fill with the truncation value everywhere, then copy the true
        # distances inside the instance's own voxels.
        inst_df = torch.full_like(
            selected,
            dtype=torch.float,
            fill_value=truncation
        )
        inst_df[selected] = distance_field.squeeze()[selected]
        near_surface = inst_df.abs() < iso_value

        if rescale:
            # Max-pooling keeps a coarse voxel occupied if any covered fine
            # voxel was occupied.
            near_surface = F.max_pool3d(
                near_surface[None, None].float(),
                kernel_size=downsample_factor + 1,
                stride=downsample_factor,
                padding=1
            ).squeeze(0, 1).bool()

        thick_grid = _thicken_grid(near_surface, dims, frustum_mask)
        thick_grid = thick_grid.to(torch.device("cpu"), non_blocking=True)
        results[inst_id] = thick_grid, sem_class

    return results


def mask_invalid_sparse_voxels(
    grid: Me.SparseTensor,
    mask: Optional[torch.Tensor] = None, frustum_dim=[256, 256, 256]
) -> Me.SparseTensor:
    """Prune voxels whose coordinates fall outside the frustum grid.

    Args:
        grid: sparse tensor; C[:, 1:4] are the spatial x/y/z coordinates.
        mask: optional extra boolean mask combined (AND) with the bounds
            check, one entry per voxel.
        frustum_dim: spatial extents of the valid grid.
            NOTE(review): mutable default argument — harmless here since it
            is never mutated, but a tuple would be safer.

    Returns:
        The (possibly pruned) sparse tensor — except when no voxel survives,
        in which case ``({}, {})`` is returned. NOTE(review): that empty-case
        return contradicts the annotated return type and the normal
        single-value return; callers presumably special-case it — verify.
    """
    # Mask out voxels which are outside of the grid
    # NOTE(review): the upper bound `< frustum_dim[i] - 1` also drops the last
    # in-range index (e.g. 255 on a 256 grid) — presumably a one-voxel border
    # margin for later neighborhood operations; confirm this is intentional.
    valid_mask = (grid.C[:, 1] < frustum_dim[0] - 1) & (grid.C[:, 1] >= 0) & \
                 (grid.C[:, 2] < frustum_dim[1] - 1) & (grid.C[:, 2] >= 0) & \
                 (grid.C[:, 3] < frustum_dim[2] - 1) & (grid.C[:, 3] >= 0)
    if mask is not None:
        # Elementwise multiplication acts as a boolean AND with the extra mask.
        valid_mask = valid_mask * mask
    num_valid_coordinates = valid_mask.sum()

    if num_valid_coordinates == 0:
        return {}, {}

    num_masked_voxels = grid.C.size(0) - num_valid_coordinates
    grids_needs_to_be_pruned = num_masked_voxels > 0

    # Fix: Only prune if there are invalid voxels
    if grids_needs_to_be_pruned:
        grid = Me.MinkowskiPruning()(grid, valid_mask)

    return grid