# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Blocks for Panoptic Recon 3D."""
import torch.nn as nn
from torch import Tensor
from typing import Optional
import torch.nn.functional as F
import MinkowskiEngine as Me
class ProjectionBlock(nn.Module):
"""Projection block for depth projection."""
def __init__(self, in_feature, out_feature):
"""Init"""
super().__init__()
self.conv_block1 = nn.Sequential(
nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(out_feature),
            nn.ReLU(inplace=True)
)
self.conv_block2 = nn.Conv2d(
out_feature, out_feature,
kernel_size=1, stride=1,
padding=0
)
def forward(self, x, target_size):
"""Forward"""
x = self.conv_block1(x)
x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
x = self.conv_block2(x)
return x
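
# Shape sketch for ProjectionBlock (illustrative; the shapes are assumptions,
# not taken from an actual caller):
#   x: (N, in_feature, H, W)
#   -> conv_block1 (3x3, stride 2)      -> (N, out_feature, ceil(H/2), ceil(W/2))
#   -> bilinear resize to target_size   -> (N, out_feature, *target_size)
#   -> conv_block2 (1x1)                -> (N, out_feature, *target_size)
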
class ConvBlock(nn.Module):
"""Conv block for depth projection."""
def __init__(self, in_feature, out_feature):
"""Init"""
super().__init__()
self.conv_block = nn.Sequential(
nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(out_feature),
            nn.ReLU(inplace=True)
)
def forward(self, x):
"""Forward"""
return self.conv_block(x)
class DepthProjector(nn.Module):
"""Depth projector module."""
def __init__(
self,
in_channels: int = 256,
out_channels: int = 256,
num_proj_convs: int = 4,
**kwargs
):
"""Init"""
super(DepthProjector, self).__init__()
self.proj_convs1 = nn.ModuleList([
ConvBlock(in_channels, in_channels) for _ in range(num_proj_convs)
])
self.proj_convs2 = nn.ModuleList([
nn.Conv2d(
in_channels, out_channels,
kernel_size=1, stride=1,
padding=0
) for _ in range(num_proj_convs)
])
    def forward(self, depth_features, depth_feature_shape, size_list):
        """Forward."""
        output_list = []
        # Note: this appends to the caller's size_list in place.
        size_list.append(depth_feature_shape)
        # Pair each input feature with a target shape from the reversed size list.
        for i, (proj_conv, feat_shape) in enumerate(zip(
            self.proj_convs1,
            size_list[::-1]
        )):
            output = proj_conv(depth_features[i])
            output = F.interpolate(output, feat_shape, mode="bilinear", align_corners=False)
            output = self.proj_convs2[i](output)
            output_list.append(output)
        # Drop the first output (the one resized to depth_feature_shape) and
        # restore the original size_list ordering.
        return depth_features[-1], output_list[1:][::-1]
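
# Usage sketch for DepthProjector (hypothetical shapes; the real caller passes
# pyramid features from the depth branch):
#   projector = DepthProjector(in_channels=256, out_channels=256, num_proj_convs=4)
#   feats = [torch.randn(1, 256, s, s) for s in (64, 32, 16, 8)]
#   last, outs = projector(feats, (8, 8), [(64, 64), (32, 32), (16, 16)])
#   # outs holds three maps resized to (64, 64), (32, 32) and (16, 16).
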
class SelfAttentionLayer(nn.Module):
"""Self Attention Layer."""
def __init__(
self, d_model, nhead, dropout=0.0,
activation="relu", normalize_before=False, export=False
):
"""Init."""
super().__init__()
        self.export = export
        # Both export and non-export paths currently use the same attention module.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self._reset_parameters()
def _reset_parameters(self):
"""Reset parameters."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
"""Add positional embedding."""
return tensor if pos is None else tensor + pos
def forward_post(
self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None
):
"""Forward post norm."""
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout(tgt2)
tgt = self.norm(tgt)
return tgt
def forward_pre(
self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None
):
"""Forward pre norm."""
tgt2 = self.norm(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout(tgt2)
return tgt
def forward(
self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None
):
"""Forward."""
if self.normalize_before:
return self.forward_pre(
tgt, tgt_mask, tgt_key_padding_mask, query_pos
)
return self.forward_post(
tgt, tgt_mask, tgt_key_padding_mask, query_pos
)
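
# Usage sketch (illustrative): nn.MultiheadAttention defaults to a
# sequence-first layout, so queries are (num_queries, batch, d_model):
#   layer = SelfAttentionLayer(d_model=256, nhead=8)
#   tgt = torch.randn(100, 2, 256)   # hypothetical query embeddings
#   out = layer(tgt)                 # -> (100, 2, 256)
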
class CrossAttentionLayer(nn.Module):
"""Cross attention layer."""
def __init__(self, d_model, nhead, dropout=0.0,
activation="relu", normalize_before=False, export=False):
"""Init."""
super().__init__()
        self.export = export
        # Both export and non-export paths currently use the same attention module.
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self._reset_parameters()
def _reset_parameters(self):
"""Reset parameters."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
"""Add positional embedding."""
return tensor if pos is None else tensor + pos
def forward_post(
self, tgt, memory,
memory_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None
):
"""Forward post norm."""
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout(tgt2)
tgt = self.norm(tgt)
return tgt
def forward_pre(
self, tgt, memory,
memory_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None
):
"""Forward pre norm."""
tgt2 = self.norm(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask
)[0]
tgt = tgt + self.dropout(tgt2)
return tgt
def forward(
self, tgt, memory,
memory_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None
):
"""Forward pass."""
if self.normalize_before:
return self.forward_pre(tgt, memory, memory_mask,
memory_key_padding_mask, pos, query_pos)
return self.forward_post(tgt, memory, memory_mask,
memory_key_padding_mask, pos, query_pos)
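
# Usage sketch (illustrative): queries attend over a flattened memory, e.g.
# per-pixel features of shape (H * W, batch, d_model):
#   layer = CrossAttentionLayer(d_model=256, nhead=8)
#   tgt = torch.randn(100, 2, 256)       # hypothetical queries
#   memory = torch.randn(4096, 2, 256)   # hypothetical 64x64 feature map
#   out = layer(tgt, memory)             # -> (100, 2, 256)
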
class FFNLayer(nn.Module):
"""Feedforward layer."""
def __init__(
self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu", normalize_before=False
):
"""Init."""
super().__init__()
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm = nn.LayerNorm(d_model)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self._reset_parameters()
def _reset_parameters(self):
"""Reset parameters."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
"""Add positional embedding."""
return tensor if pos is None else tensor + pos
def forward_post(self, tgt):
"""Forward post norm."""
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout(tgt2)
tgt = self.norm(tgt)
return tgt
def forward_pre(self, tgt):
"""Forward pre norm."""
tgt2 = self.norm(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout(tgt2)
return tgt
def forward(self, tgt):
"""Forward."""
if self.normalize_before:
return self.forward_pre(tgt)
return self.forward_post(tgt)
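
# Note on the two residual orderings implemented above: post-norm computes
# norm(x + FFN(x)), while pre-norm computes x + FFN(norm(x)). The attention
# layers in this file follow the same `normalize_before` convention.
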
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
    raise NotImplementedError(f"activation should be relu/gelu/glu, not {activation}.")

class MLP(nn.Module):
""" Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
"""Init."""
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
"""Forward pass."""
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
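
# Example (illustrative): a 3-layer prediction head mapping 256-d query
# embeddings to 4-d outputs:
#   mlp = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
#   y = mlp(torch.randn(2, 100, 256))   # -> (2, 100, 4)
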
# 3D blocks
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, sparse=False):
    """3x3 convolution with padding.

    Returns a MinkowskiConvolution when ``sparse`` is True (note that the
    sparse path ignores ``groups``), otherwise a dense ``nn.Conv3d``.
    """
    if sparse:
        return Me.MinkowskiConvolution(
            in_planes, out_planes, kernel_size=3,
            stride=stride, dilation=dilation,
            bias=False, dimension=3
        )
    return nn.Conv3d(
        in_planes, out_planes, kernel_size=3,
        stride=stride, padding=dilation,
        groups=groups, bias=False,
        dilation=dilation
    )
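
# Sketch (shapes are assumptions): the dense path keeps the spatial size at
# stride 1 because padding equals dilation:
#   conv = conv3x3(32, 64)
#   y = conv(torch.randn(1, 32, 16, 16, 16))   # -> (1, 64, 16, 16, 16)
# The sparse path instead consumes Me.SparseTensor inputs.
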
class BasicBlock3D(nn.Module):
"""Basic block for 3D."""
def __init__(
self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None, sparse=False
):
"""Init."""
super().__init__()
if norm_layer is None:
norm_layer = nn.InstanceNorm3d if not sparse else Me.MinkowskiInstanceNorm
if groups != 1 or base_width != 64:
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
self.conv1 = conv3x3(inplanes, planes, stride, sparse=sparse)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True) if not sparse else Me.MinkowskiReLU(inplace=True)
self.conv2 = conv3x3(planes, planes, sparse=sparse)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
"""Forward."""
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
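
# Usage sketch (illustrative, dense variant): a strided block needs a matching
# downsample so the residual shapes agree:
#   down = nn.Sequential(
#       nn.Conv3d(32, 64, kernel_size=1, stride=2, bias=False),
#       nn.InstanceNorm3d(64),
#   )
#   block = BasicBlock3D(32, 64, stride=2, downsample=down)
#   out = block(torch.randn(1, 32, 16, 16, 16))   # -> (1, 64, 8, 8, 8)
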
class SparseBasicBlock3D(BasicBlock3D):
"""Sparse basic block for 3D."""
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
"""Init."""
super().__init__(inplanes, planes,
stride=stride, downsample=downsample, groups=groups,
base_width=base_width, dilation=dilation,
norm_layer=norm_layer, sparse=True)
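
# Sparse usage sketch (assumes MinkowskiEngine is installed; coordinates are
# (batch_idx, x, y, z) integers):
#   coords = torch.IntTensor([[0, 0, 0, 0], [0, 2, 0, 0]])
#   feats = torch.randn(2, 32)
#   x = Me.SparseTensor(feats, coordinates=coords)
#   out = SparseBasicBlock3D(32, 32)(x)   # same coordinates, 32-d features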