# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder for Panoptic Recon 3D."""
from typing import Optional, List

import torch
from torch import nn

import MinkowskiEngine as Me

from ..utils.sparse_tensor import sparse_cat_union
from ..blocks import BasicBlock3D, SparseBasicBlock3D


class SparseToDense(nn.Module):
    """Densify a MinkowskiEngine sparse tensor onto a fixed 3D grid.

    Voxels whose coordinates fall outside ``input_size`` (or are negative)
    are pruned before densification so that ``SparseTensor.dense`` can
    produce a tensor of the requested shape.
    """

    def __init__(self, input_size):
        """Initialize the sparse to dense module.

        Args:
            input_size: Full-resolution grid extents, a 3-element sequence
                (x, y, z). The dense output is this size divided by the
                incoming tensor stride.
        """
        super().__init__()
        assert len(input_size) == 3
        self.input_size = input_size

    def forward(self, feature: Me.SparseTensor) -> torch.Tensor:
        """Convert ``feature`` to a dense tensor of shape (B, C, X, Y, Z)."""
        # decomposed_coordinates_and_features -> (coords_per_batch, feats_per_batch);
        # the length of the coordinate list is the batch size.
        batch_size = len(feature.decomposed_coordinates_and_features[0])
        # BUGFIX: the channel count must come from the feature matrix F (N x C),
        # not the coordinate matrix C (always N x 4 for 3D sparse tensors).
        # The original `feature.C.shape[-1]` hard-wired feat_dim to 4, which
        # mismatches SparseTensor.dense(shape=...) whenever C != 4.
        feat_dim = feature.F.shape[-1]
        # Output spatial extent at the current tensor stride.
        out_size = (
            torch.div(
                torch.tensor(self.input_size),
                torch.tensor(feature.tensor_stride),
                rounding_mode="floor",
            )
        ).tolist()
        shape = torch.Size([batch_size, feat_dim, *out_size])
        min_coordinate = torch.IntTensor([0, 0, 0])
        # Drop voxels outside the [0, input_size) box; C[:, 0] is the batch index.
        mask = (feature.C[:, 1] < self.input_size[0]) & \
               (feature.C[:, 2] < self.input_size[1]) & \
               (feature.C[:, 3] < self.input_size[2])
        mask = mask & (feature.C[:, 1] >= 0) & (feature.C[:, 2] >= 0) & (feature.C[:, 3] >= 0)
        feature = Me.MinkowskiPruning()(feature, mask)
        # dense() returns (tensor, min_coordinate, stride); keep the tensor only.
        dense = feature.dense(shape, min_coordinate=min_coordinate)[0]
        return dense


class FrustumDecoder(nn.Module):
    """Frustum decoder: a hybrid sparse/dense 3D U-Net.

    Encodes projected 2D features into a 3D frustum volume and decodes them
    coarse-to-fine, predicting per-level occupancy, panoptic segmentation
    logits, and (at full resolution) truncated distance-field geometry.
    The two finest levels operate on sparse tensors; coarser levels are dense.
    """

    def __init__(self, cfg) -> None:
        """Initialize the frustum decoder module.

        Args:
            cfg: Hierarchical experiment config; reads ``model.frustum3d``,
                ``model.projection``, ``model.sem_seg_head`` and
                ``dataset.name``.
        """
        super().__init__()
        num_output_features = cfg.model.frustum3d.unet_output_channels
        num_features = cfg.model.frustum3d.unet_features
        sign_channel = cfg.model.projection.sign_channel
        mask_dim = cfg.model.sem_seg_head.mask_dim
        depth_dim = cfg.model.sem_seg_head.depth_dim
        num_classes = cfg.model.sem_seg_head.num_classes
        frustum_dims = cfg.model.frustum3d.grid_dimensions
        frustum_dims = [frustum_dims] * 3  # cubic frustum grid
        self.use_ms_features = cfg.model.frustum3d.use_multi_scale
        self.truncation = cfg.model.frustum3d.truncation

        # Multi-scale 2D feature channels differ per dataset: matterport
        # provides only conv features; otherwise class logits and the signed
        # channel are concatenated on top.
        if cfg.dataset.name == 'matterport':
            ms_feature_channels = cfg.model.sem_seg_head.convs_dim
        else:
            ms_feature_channels = cfg.model.sem_seg_head.convs_dim + \
                cfg.model.sem_seg_head.num_classes + cfg.model.frustum3d.signed_channel

        # input encoding: one encoder per input group
        # (signed-distance channel(s), mask+depth embedding, class logits).
        self.input_dims = [2 if sign_channel else 1, mask_dim + depth_dim, num_classes]
        self.input_encoders = nn.ModuleList()
        for input_dim in self.input_dims:
            downsample = nn.Sequential(
                Me.MinkowskiConvolution(
                    input_dim, num_features, kernel_size=1, stride=1, bias=True, dimension=3
                ),
                Me.MinkowskiInstanceNorm(num_features),
            )
            self.input_encoders.append(
                SparseBasicBlock3D(
                    input_dim, num_features, downsample=downsample
                )
            )

        # Encoder pyramid: two sparse levels, then three dense levels.
        self.level_encoders = nn.ModuleList([
            self.make_encoder(len(self.input_encoders) * num_features, num_features),
            self.make_encoder(num_features, num_features * 2),
            self.make_encoder(num_features * 2, num_features * 4, is_sparse=False),
            self.make_encoder(num_features * 4, num_features * 8, is_sparse=False),
            self.make_encoder(num_features * 8, num_features * 8, is_sparse=False),
        ])

        sparse_to_dense = SparseToDense(frustum_dims)
        if self.use_ms_features:
            # Adapters project 2D multi-scale features into each encoder level;
            # the third adapter densifies to match the dense encoder levels.
            self.feature_adapters = nn.ModuleList([
                self.make_adapter(ms_feature_channels, num_features),
                self.make_adapter(ms_feature_channels, num_features * 2),
                self.make_adapter(ms_feature_channels, num_features * 4, [sparse_to_dense]),
            ])
        else:
            self.feature_adapters = None

        # Conversion applied between encoder levels; level 1 -> 2 crosses the
        # sparse/dense boundary.
        self.enc_level_conversion = nn.ModuleList([
            nn.Identity(),
            sparse_to_dense,
            nn.Identity(),
            nn.Identity(),
        ])

        # Decoder pyramid (index i decodes level i). Input channels account for
        # the skip concatenations performed in forward().
        self.level_decoders = nn.ModuleList([
            self.make_decoder(num_features * 3, num_output_features),
            self.make_decoder(
                num_features * 6, num_features * 2,
                extra_layers=[SparseBasicBlock3D(num_features * 2, num_features * 2)]
            ),
            self.make_decoder(num_features * 8, num_features * 2, is_sparse=False),
            self.make_decoder(num_features * 16, num_features * 4, is_sparse=False),
            self.make_decoder(num_features * 8, num_features * 8, is_sparse=False),
        ])

        # occupancy heads (finest to coarsest of the three supervised levels)
        self.level_occupancy_heads = nn.ModuleList([
            nn.Sequential(
                Me.MinkowskiInstanceNorm(num_output_features),
                Me.MinkowskiReLU(inplace=True),
                SparseBasicBlock3D(num_output_features, num_output_features),
                Me.MinkowskiConvolution(num_output_features, 1, kernel_size=3, bias=True, dimension=3),
            ),
            Me.MinkowskiLinear(num_features * 2, 1),
            nn.Linear(num_features * 4, 1),
        ])

        # panoptic heads: embed voxel features, then dot with projected queries.
        self.level_segm_embeddings = nn.ModuleList([
            nn.Sequential(
                Me.MinkowskiInstanceNorm(num_output_features),
                Me.MinkowskiReLU(inplace=True),
                SparseBasicBlock3D(num_output_features, num_output_features),
            ),
            SparseBasicBlock3D(num_features * 3, num_features * 3),
            nn.Sequential(
                BasicBlock3D(num_features * 4, num_features * 4),
                BasicBlock3D(num_features * 4, num_features * 4),
            )
        ])
        self.level_segm_query_projection = nn.ModuleList([
            nn.Linear(mask_dim, num_output_features),
            nn.Linear(mask_dim, num_features * 3),
            nn.Linear(mask_dim, num_features * 4),
        ])

        # geometry head (full-resolution truncated distance field)
        self.geometry_head = nn.Sequential(
            Me.MinkowskiInstanceNorm(num_output_features),
            Me.MinkowskiReLU(inplace=True),
            SparseBasicBlock3D(num_output_features, num_output_features),
            Me.MinkowskiConvolution(num_output_features, 1, kernel_size=3, bias=True, dimension=3),
        )

        # Non-persistent: derived from config, so it is rebuilt on load and
        # kept out of checkpoints.
        self.register_buffer("frustum_dimensions", torch.tensor(frustum_dims), persistent=False)

    @staticmethod
    def forward_sparse_segm(segm_features, queries):
        """Compute per-query segmentation logits on a sparse tensor.

        Args:
            segm_features: Sparse voxel embeddings (N_total x C).
            queries: Per-batch query embeddings; ``queries[b]`` is (Q x C).

        Returns:
            Sparse tensor on the same coordinates with Q logit channels.
        """
        features = segm_features.decomposed_features
        # Per-batch inner product of voxel features with that batch's queries,
        # re-concatenated in decomposition order.
        segms = torch.cat(
            [torch.mm(features[idx], queries[idx].T) for idx in range(len(features))],
            dim=0
        )
        return Me.SparseTensor(
            segms,
            coordinate_manager=segm_features.coordinate_manager,
            coordinate_map_key=segm_features.coordinate_map_key,
        )

    @staticmethod
    def make_encoder(input_dim, output_dim, is_sparse=True):
        """Build a stride-2 downsampling encoder stage (sparse or dense)."""
        if is_sparse:
            downsample = nn.Sequential(
                Me.MinkowskiConvolution(
                    input_dim, output_dim, kernel_size=4, stride=2, bias=True, dimension=3
                ),
                Me.MinkowskiInstanceNorm(output_dim),
            )
            module = nn.Sequential(
                SparseBasicBlock3D(input_dim, output_dim, stride=2, downsample=downsample),
                SparseBasicBlock3D(output_dim, output_dim),
            )
        else:
            downsample = nn.Conv3d(
                input_dim, output_dim, kernel_size=4, stride=2, padding=1, bias=False
            )
            module = nn.Sequential(
                BasicBlock3D(input_dim, output_dim, stride=2, downsample=downsample),
                BasicBlock3D(output_dim, output_dim),
            )
        return module

    @staticmethod
    def make_decoder(input_dim, output_dim, is_sparse=True, extra_layers: Optional[List] = None):
        """Build a stride-2 upsampling decoder stage (sparse or dense).

        Args:
            extra_layers: Optional modules appended after the transposed
                convolution + norm + activation.
        """
        if extra_layers is None:
            extra_layers = []
        if is_sparse:
            return nn.Sequential(
                Me.MinkowskiConvolutionTranspose(
                    input_dim, output_dim, kernel_size=4, stride=2, bias=False,
                    dimension=3, expand_coordinates=True
                ),
                Me.MinkowskiInstanceNorm(output_dim),
                Me.MinkowskiReLU(inplace=True),
                *extra_layers,
            )
        return nn.Sequential(
            nn.ConvTranspose3d(input_dim, output_dim, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm3d(output_dim),
            nn.ReLU(inplace=True),
            *extra_layers,
        )

    @staticmethod
    def make_adapter(input_dim, output_dim, extra_layers: Optional[List] = None):
        """Build a 1x1x1 sparse projection block for multi-scale features."""
        if extra_layers is None:
            extra_layers = []
        downsample = nn.Sequential(
            Me.MinkowskiConvolution(input_dim, output_dim, kernel_size=1, stride=1, bias=True, dimension=3),
            Me.MinkowskiInstanceNorm(output_dim),
        )
        return nn.Sequential(
            SparseBasicBlock3D(input_dim, output_dim, downsample=downsample),
            *extra_layers,
        )

    def forward(
        self, ms_features: List[Me.SparseTensor], features: Me.SparseTensor, segm_queries, frustum_mask
    ):
        """Run the frustum U-Net.

        Args:
            ms_features: Multi-scale sparse features, one per adapter level.
            features: Sparse input whose channels are the concatenation of
                ``self.input_dims`` groups.
            segm_queries: Per-batch query embeddings for the panoptic heads.
            frustum_mask: Dense boolean mask of valid frustum voxels at the
                level-64 resolution (shape assumed (B, 1, H, W, D) given the
                ``squeeze(1)`` usage below).

        Returns:
            Dict with ``pred_geometry`` (sparse TDF), ``pred_occupancies``
            (coarse-to-fine occupancy logits) and ``pred_segms``
            (coarse-to-fine segmentation logits).
        """
        # Split the concatenated input channels into their groups and encode
        # each on the shared coordinate map.
        start_dim = 0
        encoded_inputs = []
        cm = features.coordinate_manager
        key = features.coordinate_map_key
        for dim, encoder in zip(self.input_dims, self.input_encoders):
            encoded_inputs.append(
                encoder(Me.SparseTensor(
                    features.F[:, start_dim:start_dim + dim],
                    coordinate_manager=cm, coordinate_map_key=key
                ))
            )
            start_dim += dim
        encoded_inputs = Me.cat(*encoded_inputs)

        lvls = len(self.level_encoders)

        # high to low resolution
        encoder_outputs = []
        encoder_inputs = [encoded_inputs]
        for idx in range(len(self.level_encoders)):
            encoded = self.level_encoders[idx](encoder_inputs[idx])
            if self.use_ms_features and idx < len(self.feature_adapters):
                feat = self.feature_adapters[idx](ms_features[idx])
                if isinstance(encoded, torch.Tensor):
                    # Dense level: adapter output was densified by SparseToDense.
                    encoded = encoded + feat
                else:
                    # Sparse level: rebuild on the encoder's coordinate manager
                    # so the sparse addition is well-defined.
                    feat = Me.SparseTensor(
                        feat.F, coordinates=feat.C, tensor_stride=feat.tensor_stride,
                        coordinate_manager=encoded.coordinate_manager
                    )
                    encoded = encoded + feat
            encoder_outputs.append(encoded)
            if idx < lvls - 1:
                encoder_inputs.append(self.enc_level_conversion[idx](encoded))

        # low to high resolution
        decoder_outputs = []
        decoder_inputs = [encoder_outputs[-1]]
        pred_occupancies = []
        pred_segms = []
        pred_geometry = None

        # U-Net
        for idx in reversed(range(lvls)):
            decoded = self.level_decoders[idx](decoder_inputs[lvls - 1 - idx])
            decoder_outputs.append(decoded)
            if idx <= 1:  # sparse levels 128, 256
                occupancy = self.level_occupancy_heads[idx](decoded)
                # mask invalid voxels outside of frustum
                valid_mask = (
                    (occupancy.C[:, 1:] >= 0) & (occupancy.C[:, 1:] < self.frustum_dimensions)
                ).all(-1)
                pred_occupancies.append(Me.MinkowskiPruning()(occupancy, valid_mask))
                # Keep only voxels predicted occupied (p > 0.5) and in-bounds.
                pruning_mask = (Me.MinkowskiSigmoid()(occupancy).F.squeeze(-1) > 0.5) & valid_mask
                sparse_out = Me.MinkowskiPruning()(decoded, pruning_mask)
                if idx > 0:  # level 128: skip connection, feed next decoder
                    sparse_out = sparse_cat_union(encoder_outputs[idx - 1], sparse_out)
                    valid_mask = (
                        (sparse_out.C[:, 1:] >= 0) & (sparse_out.C[:, 1:] < self.frustum_dimensions)
                    ).all(-1)
                    decoder_inputs.append(Me.MinkowskiPruning()(sparse_out, valid_mask))
                else:  # level 256: predict clamped distance-field geometry
                    pred_geometry = self.geometry_head(sparse_out)
                    predicted_values = pred_geometry.F
                    predicted_values = torch.clamp(predicted_values, 0.0, self.truncation)
                    pred_geometry = Me.SparseTensor(
                        predicted_values,
                        coordinate_manager=pred_geometry.coordinate_manager,
                        coordinate_map_key=pred_geometry.coordinate_map_key,
                    )
                    valid_mask = (
                        (pred_geometry.C[:, 1:] >= 0) & (pred_geometry.C[:, 1:] < self.frustum_dimensions)
                    ).all(-1)
                    pred_geometry = Me.MinkowskiPruning()(pred_geometry, valid_mask)
                # Panoptic logits for this level (runs for idx 0 and 1); note
                # segm_features are taken from the un-pruned skip output.
                queries = self.level_segm_query_projection[idx](segm_queries)
                segm_features = self.level_segm_embeddings[idx](sparse_out)
                pred_segm = self.forward_sparse_segm(segm_features, queries)
                valid_mask = (
                    (pred_segm.C[:, 1:] >= 0) & (pred_segm.C[:, 1:] < self.frustum_dimensions)
                ).all(-1)
                pred_segms.append(Me.MinkowskiPruning()(pred_segm, valid_mask))
            elif idx == 2:  # level 64: dense level, then convert to sparse
                decoded = torch.cat([encoder_inputs[idx], decoded], dim=1)
                occupancy = self.level_occupancy_heads[idx](decoded.permute(0, 2, 3, 4, 1)).squeeze(-1)
                pred_occupancies.append(occupancy.masked_fill(~frustum_mask.squeeze(1), -torch.inf))
                queries = self.level_segm_query_projection[idx](segm_queries)
                segm_features = self.level_segm_embeddings[idx](decoded)
                pred_segm = torch.einsum("bqc,bchwd->bqhwd", queries, segm_features)
                pred_segms.append(pred_segm.masked_fill(~frustum_mask, -torch.inf))
                # Sparsify occupied in-frustum voxels for the sparse decoders.
                pruning_mask = (occupancy.sigmoid() > 0.5) & frustum_mask.squeeze(1)
                coords = pruning_mask.nonzero()
                sparse_out = decoded[coords[:, 0], :, coords[:, 1], coords[:, 2], coords[:, 3]]
                encoded = encoder_outputs[idx - 1]
                stride = encoded.tensor_stride
                # Scale dense voxel indices by the target tensor stride so the
                # coordinates land on the sparse level's grid.
                coords = coords.clone()
                coords[:, 1:] *= torch.tensor(stride, device=coords.device)
                sparse_out = Me.SparseTensor(
                    sparse_out, coordinates=coords.int().contiguous(),
                    tensor_stride=stride, coordinate_manager=cm
                )
                decoder_inputs.append(sparse_cat_union(encoded, sparse_out))
            else:  # dense levels 16, 32: plain skip concatenation
                decoder_inputs.append(torch.cat([encoder_inputs[idx], decoded], dim=1))

        return {
            "pred_geometry": pred_geometry,
            "pred_occupancies": pred_occupancies,
            "pred_segms": pred_segms,
        }