# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Sparse tensor utils.""" import torch import MinkowskiEngine as Me import torch.nn.functional as F from typing import Optional, Tuple, Dict def sparse_cat_union(a: Me.SparseTensor, b: Me.SparseTensor): """Sparse cat union two sparse tensors.""" cm = a.coordinate_manager stride = a.tensor_stride assert cm == b.coordinate_manager, "different coords_man" assert a.tensor_stride == b.tensor_stride, "different tensor_stride" # handle empty tensors - if one is empty, return the other if a.F.size(0) == 0 or a.F.numel() == 0: return b if b.F.size(0) == 0 or b.F.numel() == 0: return a # handle the error try: feats_a = F.pad(a.F, (0, b.F.shape[1])) except Exception as e: print("Warning: Got error in feats_a:", e) return a try: feats_b = F.pad(b.F, (a.F.shape[1], 0)) except Exception as e: print("Warning: Got error in feats_b:", e) return b new_a = Me.SparseTensor( features=feats_a, coordinate_map_key=a.coordinate_key, coordinate_manager=cm, tensor_stride=stride, ) new_b = Me.SparseTensor( features=feats_b, coordinate_map_key=b.coordinate_key, coordinate_manager=cm, tensor_stride=stride, ) return new_a + new_b def to_dense( tensor: Me.SparseTensor, shape: Optional[torch.Size] = None, min_coordinate: Optional[torch.IntTensor] = None, contract_stride: bool = True, default_value: float = 0.0 ) -> Tuple[torch.Tensor, torch.IntTensor, torch.IntTensor]: """Convert the :attr:`MinkowskiEngine.SparseTensor` to a torch dense tensor. Args: :attr:`shape` (torch.Size, optional): The size of the output tensor. :attr:`min_coordinate` (torch.IntTensor, optional): The min coordinates of the output sparse tensor. Must be divisible by the current :attr:`tensor_stride`. If 0 is given, it will use the origin for the min coordinate. :attr:`contract_stride` (bool, optional): The output coordinates will be divided by the tensor stride to make features spatially contiguous. True by default. Returns: :attr:`tensor` (torch.Tensor): the torch tensor with size `[Batch Dim, Feature Dim, Spatial Dim..., Spatial Dim]`. The coordinate of each feature can be accessed via `min_coordinate + tensor_stride * [the coordinate of the dense tensor]`. :attr:`min_coordinate` (torch.IntTensor): the D-dimensional vector defining the minimum coordinate of the output tensor. :attr:`tensor_stride` (torch.IntTensor): the D-dimensional vector defining the stride between tensor elements. """ if min_coordinate is not None: assert isinstance(min_coordinate, torch.IntTensor) assert min_coordinate.numel() == tensor._D if shape is not None: assert isinstance(shape, torch.Size) assert len(shape) == tensor._D + 2 # batch and channel if shape[1] != tensor._F.size(1): shape = torch.Size([shape[0], tensor._F.size(1), *[s for s in shape[2:]]]) # exception handling for empty tensor if tensor.__len__() == 0: assert shape is not None, "shape is required to densify an empty tensor" return ( torch.zeros(shape, dtype=tensor.dtype, device=tensor.device), torch.zeros(tensor._D, dtype=torch.int32, device=tensor.device), tensor.tensor_stride, ) # use int tensor for all operations tensor_stride = torch.IntTensor(tensor.tensor_stride).to(tensor.device) # new coordinates batch_indices = tensor.C[:, 0] if min_coordinate is None: min_coordinate, _ = tensor.C.min(0, keepdim=True) min_coordinate = min_coordinate[:, 1:] if not torch.all(min_coordinate >= 0): raise ValueError( f"Coordinate has a negative value: {min_coordinate}. Please provide min_coordinate argument" ) coords = tensor.C[:, 1:] elif isinstance(min_coordinate, int) and min_coordinate == 0: coords = tensor.C[:, 1:] else: min_coordinate = min_coordinate.to(tensor.device) if min_coordinate.ndim == 1: min_coordinate = min_coordinate.unsqueeze(0) coords = tensor.C[:, 1:] - min_coordinate assert ( min_coordinate % tensor_stride ).sum() == 0, "The minimum coordinates must be divisible by the tensor stride." if coords.ndim == 1: coords = coords.unsqueeze(1) # return the contracted tensor if contract_stride: coords = torch.div(coords, tensor_stride, rounding_mode="floor") nchannels = tensor.F.size(1) if shape is None: size = coords.max(0)[0] + 1 shape = torch.Size( [batch_indices.max() + 1, nchannels, *size.cpu().numpy()] ) dense_F = torch.full( shape, dtype=tensor.F.dtype, device=tensor.F.device, fill_value=default_value ) tcoords = coords.t().long() batch_indices = batch_indices.long() indices = (batch_indices, slice(None), *tcoords) dense_F[indices] = tensor.F tensor_stride = torch.IntTensor(tensor.tensor_stride) return dense_F, min_coordinate, tensor_stride def _thicken_grid(grid, grid_dims, frustum_mask): """Thicken grid.""" device = frustum_mask.device offsets = torch.nonzero(torch.ones(3, 3, 3, device=device)).long() locs_grid = grid.nonzero(as_tuple=False) locs = locs_grid.unsqueeze(1).repeat(1, 27, 1) locs += offsets locs = locs.view(-1, 3) masks = ((locs >= 0) & (locs < torch.as_tensor(grid_dims, device=device))).all(-1) locs = locs[masks] thicken = torch.zeros(grid_dims, dtype=torch.bool, device=device) thicken[locs[:, 0], locs[:, 1], locs[:, 2]] = True # frustum culling thicken = thicken & frustum_mask return thicken def prepare_instance_masks_thicken( instances: torch.Tensor, semantic_mapping: Dict[int, int], distance_field: torch.Tensor, frustum_mask: torch.Tensor, iso_value: float = 1.0, truncation: float = 3.0, downsample_factor: int = 1 ) -> Dict[int, Tuple[torch.Tensor, int]]: """Prepare instance masks thicken.""" # check if downsample factor is valid assert isinstance(downsample_factor, int) and 256 % downsample_factor == 0 grid_dims = [256, 256, 256] need_rescale = downsample_factor != 1 if need_rescale: grid_dims = (torch.as_tensor(grid_dims) // downsample_factor).tolist() frustum_mask = F.interpolate(frustum_mask[None, None].float(), size=grid_dims, mode="nearest").squeeze(0, 1).bool() instance_information = {} for instance_id, semantic_class in semantic_mapping.items(): instance_mask: torch.Tensor = (instances == instance_id) instance_distance_field = torch.full_like( instance_mask, dtype=torch.float, fill_value=truncation ) instance_distance_field[instance_mask] = distance_field.squeeze()[instance_mask] instance_distance_field_masked = instance_distance_field.abs() < iso_value if need_rescale: instance_distance_field_masked = F.max_pool3d( instance_distance_field_masked[None, None].float(), kernel_size=downsample_factor + 1, stride=downsample_factor, padding=1 ).squeeze(0, 1).bool() # instance_grid = instance_grid & frustum_mask instance_grid = _thicken_grid( instance_distance_field_masked, grid_dims, frustum_mask ) instance_grid: torch.Tensor = instance_grid.to(torch.device("cpu"), non_blocking=True) instance_information[instance_id] = instance_grid, semantic_class return instance_information def mask_invalid_sparse_voxels( grid: Me.SparseTensor, mask=None, frustum_dim=[256, 256, 256] ) -> Me.SparseTensor: """Mask invalid sparse voxels.""" # Mask out voxels which are outside of the grid valid_mask = (grid.C[:, 1] < frustum_dim[0] - 1) & (grid.C[:, 1] >= 0) & \ (grid.C[:, 2] < frustum_dim[1] - 1) & (grid.C[:, 2] >= 0) & \ (grid.C[:, 3] < frustum_dim[2] - 1) & (grid.C[:, 3] >= 0) if mask is not None: valid_mask = valid_mask * mask num_valid_coordinates = valid_mask.sum() if num_valid_coordinates == 0: return {}, {} num_masked_voxels = grid.C.size(0) - num_valid_coordinates grids_needs_to_be_pruned = num_masked_voxels > 0 # Fix: Only prune if there are invalid voxels if grids_needs_to_be_pruned: grid = Me.MinkowskiPruning()(grid, valid_mask) return grid