# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Blocks for Panoptic Recon 3D."""

from typing import Optional

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import MinkowskiEngine as Me


class ProjectionBlock(nn.Module):
    """Projection block for depth projection."""

    def __init__(self, in_feature, out_feature):
        """Init."""
        super().__init__()
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_feature),
            nn.ReLU(True)
        )
        self.conv_block2 = nn.Conv2d(
            out_feature, out_feature, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x, target_size):
        """Forward."""
        x = self.conv_block1(x)
        x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
        x = self.conv_block2(x)
        return x


class ConvBlock(nn.Module):
    """Conv block for depth projection."""

    def __init__(self, in_feature, out_feature):
        """Init."""
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_feature),
            nn.ReLU(True)
        )

    def forward(self, x):
        """Forward."""
        return self.conv_block(x)


class DepthProjector(nn.Module):
    """Depth projector module."""

    def __init__(
        self,
        in_channels: int = 256,
        out_channels: int = 256,
        num_proj_convs: int = 4,
        **kwargs
    ):
        """Init."""
        super().__init__()
        self.proj_convs1 = nn.ModuleList([
            ConvBlock(in_channels, in_channels) for _ in range(num_proj_convs)
        ])
        self.proj_convs2 = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
            for _ in range(num_proj_convs)
        ])

    def forward(self, depth_features, depth_feature_shape, size_list):
        """Forward."""
        output_list = []
        size_list.append(depth_feature_shape)
        for i, (_, feat_shape) in enumerate(zip(self.proj_convs1, size_list[::-1])):
            feat = depth_features[i]
            output = self.proj_convs1[i](feat)
            output = F.interpolate(output, feat_shape, mode="bilinear", align_corners=False)
            output = self.proj_convs2[i](output)
            output_list.append(output)
        return depth_features[-1], output_list[1:][::-1]
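
# Usage sketch for DepthProjector (illustrative only, not part of the original
# module): the channel counts, feature shapes, and target sizes below are
# assumptions chosen to show how `forward` consumes its inputs. Note that
# `size_list` is mutated (the `depth_feature_shape` is appended) and then read
# back to front, so the first projection targets `depth_feature_shape` itself.
#
#   import torch
#   projector = DepthProjector(in_channels=32, out_channels=32, num_proj_convs=4)
#   depth_features = [
#       torch.randn(1, 32, 64, 64),
#       torch.randn(1, 32, 32, 32),
#       torch.randn(1, 32, 16, 16),
#       torch.randn(1, 32, 8, 8),
#   ]
#   size_list = [(16, 16), (32, 32), (64, 64)]
#   last_feat, projected = projector(depth_features, (8, 8), size_list)
#   # `projected` holds num_proj_convs - 1 maps, ordered coarse to fine.
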

class SelfAttentionLayer(nn.Module):
    """Self attention layer."""

    def __init__(
        self, d_model, nhead, dropout=0.0,
        activation="relu", normalize_before=False, export=False
    ):
        """Init."""
        super().__init__()
        self.export = export
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Reset parameters."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add positional embedding."""
        return tensor if pos is None else tensor + pos

    def forward_post(
        self, tgt,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None
    ):
        """Forward post norm."""
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(
        self, tgt,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None
    ):
        """Forward pre norm."""
        tgt2 = self.norm(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(
        self, tgt,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None
    ):
        """Forward."""
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask, query_pos)
        return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask, query_pos)


class CrossAttentionLayer(nn.Module):
    """Cross attention layer."""

    def __init__(
        self, d_model, nhead, dropout=0.0,
        activation="relu", normalize_before=False, export=False
    ):
        """Init."""
        super().__init__()
        self.export = export
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Reset parameters."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add positional embedding."""
        return tensor if pos is None else tensor + pos

    def forward_post(
        self, tgt, memory,
        memory_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None
    ):
        """Forward post norm."""
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(
        self, tgt, memory,
        memory_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None
    ):
        """Forward pre norm."""
        tgt2 = self.norm(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory, attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask
        )[0]
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(
        self, tgt, memory,
        memory_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None
    ):
        """Forward pass."""
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)
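
# Usage sketch (illustrative only): both attention layers wrap
# nn.MultiheadAttention with its default sequence-first layout, i.e. inputs are
# (sequence_length, batch, d_model). The query/memory sizes below are assumptions.
#
#   import torch
#   self_attn = SelfAttentionLayer(d_model=256, nhead=8)
#   cross_attn = CrossAttentionLayer(d_model=256, nhead=8)
#   queries = torch.randn(100, 2, 256)    # (num_queries, batch, d_model)
#   memory = torch.randn(1024, 2, 256)    # (num_tokens, batch, d_model)
#   queries = self_attn(queries)              # queries attend to each other
#   queries = cross_attn(queries, memory)     # queries attend to the memory
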

class FFNLayer(nn.Module):
    """Feedforward layer."""

    def __init__(
        self, d_model, dim_feedforward=2048, dropout=0.0,
        activation="relu", normalize_before=False
    ):
        """Init."""
        super().__init__()
        # Implementation of the feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Reset parameters."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add positional embedding."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt):
        """Forward post norm."""
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt):
        """Forward pre norm."""
        tgt2 = self.norm(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt):
        """Forward."""
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise NotImplementedError(f"activation should be relu/gelu/glu, not {activation}.")


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        """Init."""
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        """Forward pass."""
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


# 3D blocks
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, sparse=False):
    """3x3x3 convolution with padding (dense or sparse)."""
    if sparse:
        return Me.MinkowskiConvolution(
            in_planes, out_planes, kernel_size=3, stride=stride,
            dilation=dilation, bias=False, dimension=3
        )
    return nn.Conv3d(
        in_planes, out_planes, kernel_size=3, stride=stride,
        padding=dilation, groups=groups, bias=False, dilation=dilation
    )


class BasicBlock3D(nn.Module):
    """Basic block for 3D."""

    def __init__(
        self, inplanes, planes, stride=1, downsample=None, groups=1,
        base_width=64, dilation=1, norm_layer=None, sparse=False
    ):
        """Init."""
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.InstanceNorm3d if not sparse else Me.MinkowskiInstanceNorm
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = conv3x3(inplanes, planes, stride, sparse=sparse)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True) if not sparse else Me.MinkowskiReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, sparse=sparse)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Forward."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
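
# Usage sketch (illustrative only): a dense residual block keeps the voxel grid
# shape when stride=1 and no downsample module is given. Shapes are assumptions.
#
#   import torch
#   block = BasicBlock3D(32, 32)
#   voxels = torch.randn(1, 32, 16, 16, 16)   # (batch, channels, D, H, W)
#   out = block(voxels)                       # same shape as the input
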

class SparseBasicBlock3D(BasicBlock3D):
    """Sparse basic block for 3D."""

    def __init__(
        self, inplanes, planes, stride=1, downsample=None, groups=1,
        base_width=64, dilation=1, norm_layer=None
    ):
        """Init."""
        super().__init__(
            inplanes, planes, stride=stride, downsample=downsample,
            groups=groups, base_width=base_width, dilation=dilation,
            norm_layer=norm_layer, sparse=True
        )
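
# Usage sketch (illustrative only, assuming MinkowskiEngine is installed): the
# sparse block consumes a Me.SparseTensor whose coordinates carry a leading
# batch index. The point count and feature width below are assumptions.
#
#   import torch
#   import MinkowskiEngine as Me
#   xyz = torch.randint(0, 32, (1000, 3), dtype=torch.int32)
#   coords = torch.cat([torch.zeros(1000, 1, dtype=torch.int32), xyz], dim=1)
#   feats = torch.randn(1000, 32)
#   x = Me.SparseTensor(features=feats, coordinates=coords)
#   block = SparseBasicBlock3D(32, 32)
#   out = block(x)   # SparseTensor defined on the same coordinates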