File size: 9,433 Bytes
f4a0919 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 |
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sparse tensor utils."""
import torch
import MinkowskiEngine as Me
import torch.nn.functional as F
from typing import Optional, Tuple, Dict
def sparse_cat_union(a: Me.SparseTensor, b: Me.SparseTensor):
    """Concatenate two sparse tensors channel-wise over their combined coordinates.

    ``a``'s features are right-padded with ``b.F.shape[1]`` zero columns and
    ``b``'s features are left-padded with ``a.F.shape[1]`` zero columns. Each
    padded feature matrix is re-wrapped on its source's coordinate map key,
    and the two resulting tensors are added with ``+``, so ``a``'s channels
    come first and ``b``'s last. (NOTE(review): the function relies on
    MinkowskiEngine's ``+`` to merge the two coordinate sets, hence "union".)

    Both inputs must share one coordinate manager and one tensor stride
    (checked with ``assert``). Early exits:

    * if ``a`` has no features, ``b`` is returned unchanged;
    * otherwise, if ``b`` has no features, ``a`` is returned unchanged;
    * if padding ``a``'s features raises, a warning is printed and ``a`` is
      returned; if padding ``b``'s features raises, a warning is printed and
      ``b`` is returned.
    """
    manager = a.coordinate_manager
    shared_stride = a.tensor_stride
    assert manager == b.coordinate_manager, "different coords_man"
    assert a.tensor_stride == b.tensor_stride, "different tensor_stride"

    # A side without features contributes nothing: hand back the other one.
    a_is_empty = a.F.size(0) == 0 or a.F.numel() == 0
    if a_is_empty:
        return b
    b_is_empty = b.F.size(0) == 0 or b.F.numel() == 0
    if b_is_empty:
        return a

    # Best-effort padding: on any failure, report it and fall back to the
    # corresponding unpadded input instead of raising.
    try:
        padded_a = F.pad(a.F, (0, b.F.shape[1]))
    except Exception as err:
        print("Warning: Got error in feats_a:", err)
        return a
    try:
        padded_b = F.pad(b.F, (a.F.shape[1], 0))
    except Exception as err:
        print("Warning: Got error in feats_b:", err)
        return b

    def _rewrap(features, source):
        # Keep ``source``'s coordinates; swap in the padded features.
        return Me.SparseTensor(
            features=features,
            coordinate_map_key=source.coordinate_key,
            coordinate_manager=manager,
            tensor_stride=shared_stride,
        )

    return _rewrap(padded_a, a) + _rewrap(padded_b, b)
def to_dense(
    tensor: Me.SparseTensor,
    shape: Optional[torch.Size] = None,
    min_coordinate: Optional[torch.IntTensor] = None,
    contract_stride: bool = True,
    default_value: float = 0.0
) -> Tuple[torch.Tensor, torch.IntTensor, torch.IntTensor]:
    """Convert the :attr:`MinkowskiEngine.SparseTensor` to a torch dense
    tensor.

    Args:
        :attr:`tensor` (MinkowskiEngine.SparseTensor): the sparse tensor to
            densify.
        :attr:`shape` (torch.Size, optional): The size of the output tensor,
            ``[Batch, Channels, Spatial Dim..., Spatial Dim]``. Its channel
            entry is replaced by the feature size of :attr:`tensor` if they
            differ. Required when :attr:`tensor` is empty; otherwise inferred
            from the coordinates when omitted.
        :attr:`min_coordinate` (torch.IntTensor or int, optional): The min
            coordinates of the output dense tensor. Must be divisible by the
            current :attr:`tensor_stride`. If ``0`` is given, the origin is
            used as the min coordinate. If ``None``, the per-axis minimum of
            the sparse coordinates is used, which must be non-negative.
        :attr:`contract_stride` (bool, optional): The output coordinates
            will be divided by the tensor stride to make features spatially
            contiguous. True by default.
        :attr:`default_value` (float, optional): Fill value for every dense
            cell that receives no sparse feature. 0.0 by default.

    Returns:
        :attr:`tensor` (torch.Tensor): the torch tensor with size `[Batch
        Dim, Feature Dim, Spatial Dim..., Spatial Dim]`. The coordinate of
        each feature can be accessed via `min_coordinate + tensor_stride *
        [the coordinate of the dense tensor]`.

        :attr:`min_coordinate` (torch.IntTensor): the minimum coordinate of
        the output tensor, shaped ``(1, D)``. For an empty input it is a
        zero vector of shape ``(D,)``.

        :attr:`tensor_stride` (torch.IntTensor): the D-dimensional vector
        defining the stride between tensor elements.

    Raises:
        ValueError: if ``min_coordinate`` is ``None`` and the sparse
            coordinates contain a negative value.
    """
    # A plain int ``0`` is the documented shorthand for "use the origin"; it
    # must bypass the IntTensor validation below or that branch is unreachable.
    use_origin = isinstance(min_coordinate, int) and min_coordinate == 0
    if min_coordinate is not None and not use_origin:
        assert isinstance(min_coordinate, torch.IntTensor)
        assert min_coordinate.numel() == tensor._D
    if shape is not None:
        assert isinstance(shape, torch.Size)
        assert len(shape) == tensor._D + 2  # batch and channel
        if shape[1] != tensor._F.size(1):
            shape = torch.Size([shape[0], tensor._F.size(1), *shape[2:]])

    # Empty tensor: nothing to scatter, so the output is entirely
    # ``default_value`` (same fill as the non-empty path below).
    if len(tensor) == 0:
        assert shape is not None, "shape is required to densify an empty tensor"
        return (
            torch.full(
                shape,
                fill_value=default_value,
                dtype=tensor.dtype,
                device=tensor.device,
            ),
            torch.zeros(tensor._D, dtype=torch.int32, device=tensor.device),
            torch.IntTensor(tensor.tensor_stride),
        )

    # use int tensor for all operations
    tensor_stride = torch.IntTensor(tensor.tensor_stride).to(tensor.device)
    batch_indices = tensor.C[:, 0]
    if min_coordinate is None:
        min_coordinate, _ = tensor.C.min(0, keepdim=True)
        min_coordinate = min_coordinate[:, 1:]
        if not torch.all(min_coordinate >= 0):
            raise ValueError(
                f"Coordinate has a negative value: {min_coordinate}. Please provide min_coordinate argument"
            )
        coords = tensor.C[:, 1:]
    elif use_origin:
        # Return a tensor (shape (1, D)) like the other branches, as documented.
        min_coordinate = torch.zeros(
            1, tensor._D, dtype=torch.int32, device=tensor.device
        )
        coords = tensor.C[:, 1:]
    else:
        min_coordinate = min_coordinate.to(tensor.device)
        if min_coordinate.ndim == 1:
            min_coordinate = min_coordinate.unsqueeze(0)
        coords = tensor.C[:, 1:] - min_coordinate
        assert (
            min_coordinate % tensor_stride
        ).sum() == 0, "The minimum coordinates must be divisible by the tensor stride."

    # return the contracted tensor
    if contract_stride:
        coords = torch.div(coords, tensor_stride, rounding_mode="floor")

    nchannels = tensor.F.size(1)
    if shape is None:
        size = coords.max(0)[0] + 1
        shape = torch.Size(
            [int(batch_indices.max()) + 1, nchannels, *size.tolist()]
        )

    dense_F = torch.full(
        shape, dtype=tensor.F.dtype,
        device=tensor.F.device, fill_value=default_value
    )
    tcoords = coords.t().long()
    batch_indices = batch_indices.long()
    # Advanced indexing: one (batch, all channels, x, y, z, ...) slot per row.
    dense_F[(batch_indices, slice(None), *tcoords)] = tensor.F
    return dense_F, min_coordinate, torch.IntTensor(tensor.tensor_stride)
def _thicken_grid(grid, grid_dims, frustum_mask):
"""Thicken grid."""
device = frustum_mask.device
offsets = torch.nonzero(torch.ones(3, 3, 3, device=device)).long()
locs_grid = grid.nonzero(as_tuple=False)
locs = locs_grid.unsqueeze(1).repeat(1, 27, 1)
locs += offsets
locs = locs.view(-1, 3)
masks = ((locs >= 0) & (locs < torch.as_tensor(grid_dims, device=device))).all(-1)
locs = locs[masks]
thicken = torch.zeros(grid_dims, dtype=torch.bool, device=device)
thicken[locs[:, 0], locs[:, 1], locs[:, 2]] = True
# frustum culling
thicken = thicken & frustum_mask
return thicken
def prepare_instance_masks_thicken(
    instances: torch.Tensor,
    semantic_mapping: Dict[int, int],
    distance_field: torch.Tensor,
    frustum_mask: torch.Tensor,
    iso_value: float = 1.0,
    truncation: float = 3.0,
    downsample_factor: int = 1,
    grid_dims: Tuple[int, int, int] = (256, 256, 256),
) -> Dict[int, Tuple[torch.Tensor, int]]:
    """Build a thickened, frustum-culled boolean mask for every instance.

    For each ``instance_id -> semantic_class`` entry of ``semantic_mapping``:

    1. select the voxels where ``instances == instance_id``;
    2. keep those whose ``|distance_field|`` is below ``iso_value`` (voxels
       outside the instance get ``truncation`` and are kept only if
       ``|truncation| < iso_value``);
    3. if ``downsample_factor > 1``, max-pool that mask down to
       ``grid_dims // downsample_factor`` (``frustum_mask`` is resampled to
       the same size with nearest-neighbour interpolation);
    4. thicken it with ``_thicken_grid`` and intersect with ``frustum_mask``;
    5. copy the result to CPU.

    Args:
        instances: per-voxel instance ids, compared element-wise against each
            id. Presumably shaped ``grid_dims`` (TODO confirm against callers).
        semantic_mapping: instance id -> semantic class.
        distance_field: distance values; after ``squeeze()`` it must be
            indexable by the boolean instance mask.
        frustum_mask: boolean mask at full resolution (``grid_dims``).
        iso_value: voxels with ``|distance| < iso_value`` are kept.
        truncation: fill value used for voxels outside the instance.
        downsample_factor: positive int dividing every entry of ``grid_dims``.
        grid_dims: full-resolution grid size; 256^3 by default.

    Returns:
        dict mapping instance id -> ``(cpu_bool_mask, semantic_class)``.
    """
    grid_dims = [int(d) for d in grid_dims]
    # Positivity is checked first so 0 fails the assert instead of raising
    # ZeroDivisionError, and negative factors are rejected.
    assert (
        isinstance(downsample_factor, int)
        and downsample_factor > 0
        and all(d % downsample_factor == 0 for d in grid_dims)
    )
    need_rescale = downsample_factor != 1
    if need_rescale:
        grid_dims = [d // downsample_factor for d in grid_dims]
        frustum_mask = F.interpolate(
            frustum_mask[None, None].float(), size=grid_dims, mode="nearest"
        ).squeeze(0, 1).bool()

    # Loop-invariant: squeeze once instead of once per instance.
    squeezed_distance = distance_field.squeeze()

    instance_information = {}
    for instance_id, semantic_class in semantic_mapping.items():
        instance_mask = instances == instance_id
        instance_distance_field = torch.full_like(
            instance_mask, dtype=torch.float, fill_value=truncation
        )
        instance_distance_field[instance_mask] = squeezed_distance[instance_mask]
        surface_mask = instance_distance_field.abs() < iso_value
        if need_rescale:
            # kernel d+1, stride d, padding 1 gives floor((D + 1) / d) == D // d
            # cells per axis whenever d > 1 divides D.
            surface_mask = F.max_pool3d(
                surface_mask[None, None].float(),
                kernel_size=downsample_factor + 1,
                stride=downsample_factor,
                padding=1,
            ).squeeze(0, 1).bool()
        instance_grid = _thicken_grid(surface_mask, grid_dims, frustum_mask)
        instance_information[instance_id] = (
            instance_grid.to(torch.device("cpu"), non_blocking=True),
            semantic_class,
        )

    # Device-to-host copies issued with non_blocking=True may still be in
    # flight when the loop ends; wait once so the returned CPU tensors are
    # guaranteed to hold their data. ``_thicken_grid`` output lives on
    # ``frustum_mask``'s device.
    if frustum_mask.device.type == "cuda":
        torch.cuda.synchronize(frustum_mask.device)
    return instance_information
def mask_invalid_sparse_voxels(
    grid: Me.SparseTensor,
    mask=None, frustum_dim=(256, 256, 256)
) -> Me.SparseTensor:
    """Drop voxels of ``grid`` whose spatial coordinates fall outside the frustum.

    A voxel (row of ``grid.C``) is kept when, for each spatial column
    ``i`` in 1..3, ``0 <= grid.C[:, i] < frustum_dim[i - 1] - 1``. Because of
    the ``- 1``, the last index along each axis is rejected too (this matches
    the original behavior and is intentionally preserved).

    Args:
        grid: sparse tensor; columns 1..3 of ``grid.C`` are checked, column 0
            is not inspected.
        mask: optional extra per-voxel mask (one entry per row of ``grid.C``);
            voxels where it is false/zero are dropped as well.
        frustum_dim: per-axis frustum size; only the first three entries are
            used.

    Returns:
        ``grid`` unchanged when every voxel is valid, otherwise the output of
        ``Me.MinkowskiPruning`` with the invalid voxels removed. If no voxel
        is valid, the tuple ``({}, {})`` is returned instead; callers must
        check for this sentinel.
    """
    # Mask out voxels which are outside of the grid (vectorised over axes).
    spatial = grid.C[:, 1:4]
    upper = torch.as_tensor(
        [frustum_dim[0], frustum_dim[1], frustum_dim[2]], device=spatial.device
    ) - 1
    valid_mask = ((spatial >= 0) & (spatial < upper)).all(dim=1)
    if mask is not None:
        # Combine as booleans so the pruning mask stays a bool tensor even
        # when ``mask`` is supplied as 0/1 integers or floats.
        valid_mask = valid_mask & torch.as_tensor(
            mask, dtype=torch.bool, device=valid_mask.device
        )

    num_valid_coordinates = int(valid_mask.sum())
    if num_valid_coordinates == 0:
        return {}, {}

    # Only prune if there are invalid voxels
    if num_valid_coordinates < grid.C.size(0):
        grid = Me.MinkowskiPruning()(grid, valid_mask)
    return grid
|