# Source: nvpanoptix-3d / nvpanoptix_3d / mp_occ / multiplane_occupancy.py
# Last change by vpraveen-nv: "Update model inference code and environment setup instructions (#4)"
# Commit: f4a0919 (verified)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-plane occupancy head."""
import torch
from torch import nn
from torch.nn import functional as F
class _UpProjection(nn.Module):
    """Resize-then-refine block with a projection shortcut.

    The input is bilinearly resized to a requested spatial size and then fed through two
    parallel branches whose outputs are summed and rectified:

    * main path:     5x5 conv -> BatchNorm -> ReLU -> 3x3 conv -> BatchNorm
    * shortcut path: 5x5 conv -> BatchNorm

    Args:
        num_input_features: Channel count of the incoming feature map.
        num_output_features: Channel count produced by both branches (and the block).
    """

    def __init__(self, num_input_features, num_output_features):
        """Build the convolution and batch-norm layers of both branches."""
        super().__init__()
        out_ch = num_output_features

        def _conv(in_ch, kernel):
            # "Same" padding for odd kernels; no bias because a BatchNorm follows every conv.
            return nn.Conv2d(
                in_ch, out_ch, kernel_size=kernel, stride=1, padding=kernel // 2, bias=False
            )

        # Keep this registration order: it fixes the state_dict key order and the
        # sequence in which parameters are initialised.
        self.conv1 = _conv(num_input_features, 5)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.conv1_2 = _conv(out_ch, 3)
        self.bn1_2 = nn.BatchNorm2d(out_ch)
        self.conv2 = _conv(num_input_features, 5)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x, size):
        """Resize ``x`` to ``size`` and apply both branches.

        Args:
            x: 4-D feature map (bilinear interpolation requires a 4-D input).
            size: Target spatial size ``(height, width)``.

        Returns:
            ``ReLU(main(resized) + shortcut(resized))`` with ``num_output_features`` channels.
        """
        resized = F.interpolate(x, size=size, mode="bilinear", align_corners=True)
        main = self.bn1_2(self.conv1_2(self.relu(self.bn1(self.conv1(resized)))))
        shortcut = self.bn2(self.conv2(resized))
        return self.relu(main + shortcut)


class D(nn.Module):
    """Decoder that progressively upsamples the deepest feature map.

    ``x_block4`` is reduced with a 1x1 conv + BatchNorm + ReLU and then up-projected, in
    turn, to the spatial size of ``x_block3``, ``x_block2``, ``x_block1`` and finally to
    twice the size of ``x_block1``.

    Args:
        block_channel: Channel counts listed deepest first. ``block_channel[0]`` must match
            ``x_block4``. The decoder output has ``block_channel[3] // 4`` channels.
    """

    def __init__(self, block_channel):
        """Build the 1x1 reduction and the four up-projection stages."""
        super().__init__()
        in_ch = block_channel[0]
        ch1, ch2, ch3 = block_channel[1], block_channel[2], block_channel[3]
        half = ch3 // 2
        quarter = half // 2

        self.conv = nn.Conv2d(in_ch, ch1, kernel_size=1, stride=1, bias=False)
        self.bn = nn.BatchNorm2d(ch1)
        self.up1 = _UpProjection(num_input_features=ch1, num_output_features=ch2)
        self.up2 = _UpProjection(num_input_features=ch2, num_output_features=ch3)
        self.up3 = _UpProjection(num_input_features=ch3, num_output_features=half)
        self.up4 = _UpProjection(num_input_features=half, num_output_features=quarter)

    def forward(self, x_block1, x_block2, x_block3, x_block4):
        """Decode ``x_block4`` up to twice the spatial size of ``x_block1``.

        Only the spatial sizes (dims 2 and 3) of ``x_block1``..``x_block3`` are used.
        Their values are not.
        """
        height1, width1 = x_block1.size(2), x_block1.size(3)
        stages = (
            (self.up1, [x_block3.size(2), x_block3.size(3)]),
            (self.up2, [x_block2.size(2), x_block2.size(3)]),
            (self.up3, [height1, width1]),
            (self.up4, [height1 * 2, width1 * 2]),
        )
        out = F.relu(self.bn(self.conv(x_block4)))
        for stage, target_size in stages:
            out = stage(out, target_size)
        return out


class MFF(nn.Module):
    """Multi-feature fusion module.

    Up-projects the four backbone feature maps to one common spatial size, concatenates
    them along the channel axis and fuses the result with a 5x5 conv + BatchNorm + ReLU.

    Args:
        block_channel: Channel counts of the inputs, listed deepest first:
            ``block_channel[3]`` for ``x_block1``, ``block_channel[2]`` for ``x_block2``,
            ``block_channel[1]`` for ``x_block3`` and ``block_channel[0]`` for ``x_block4``.
        num_features: Number of fused output channels. Each of the four branches
            contributes ``num_features // 4`` channels, so this must be a positive
            multiple of 4. The default of 64 gives 16 channels per branch.

    Raises:
        ValueError: If ``num_features`` is not a positive multiple of 4.
    """

    def __init__(self, block_channel, num_features=64):
        """Initialize the multi-feature fusion module."""
        super().__init__()
        # The fusion conv consumes the concatenation of four equally wide branches, so the
        # branch width must be derived from num_features. A fixed width would only match
        # a single value of num_features.
        if num_features <= 0 or num_features % 4 != 0:
            raise ValueError(
                f"num_features must be a positive multiple of 4, got {num_features}"
            )
        branch_features = num_features // 4
        self.up1 = _UpProjection(num_input_features=block_channel[3],
                                 num_output_features=branch_features)
        self.up2 = _UpProjection(num_input_features=block_channel[2],
                                 num_output_features=branch_features)
        self.up3 = _UpProjection(num_input_features=block_channel[1],
                                 num_output_features=branch_features)
        self.up4 = _UpProjection(num_input_features=block_channel[0],
                                 num_output_features=branch_features)
        self.conv = nn.Conv2d(
            num_features, num_features, kernel_size=5, stride=1, padding=2, bias=False
        )
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x_block1, x_block2, x_block3, x_block4, size):
        """Fuse the four feature maps at spatial size ``size``.

        Args:
            x_block1, x_block2, x_block3, x_block4: Feature maps whose channel counts match
                ``block_channel[3]``, ``[2]``, ``[1]`` and ``[0]`` respectively.
            size: Common output spatial size ``(height, width)``.

        Returns:
            Fused feature map with ``num_features`` channels at ``size``, after ReLU.
        """
        x_m1 = self.up1(x_block1, size)
        x_m2 = self.up2(x_block2, size)
        x_m3 = self.up3(x_block3, size)
        x_m4 = self.up4(x_block4, size)
        x = self.bn(self.conv(torch.cat((x_m1, x_m2, x_m3, x_m4), 1)))
        return F.relu(x)


class R(nn.Module):
    """Occupancy prediction head.

    Resizes the incoming feature map to ``target_size`` with an up-projection block,
    refines it with two 5x5 conv + BatchNorm + ReLU stages and maps it to ``num_class``
    output channels with a final 5x5 conv. No activation is applied to the output.

    Args:
        channel: Number of input (and intermediate) channels.
        num_class: Number of output channels.
        target_size: Output spatial size ``(height, width)``. Defaults to ``(120, 160)``.
    """

    def __init__(self, channel, num_class=1, target_size=(120, 160)):
        """Initialize the occupancy head module."""
        super().__init__()
        self.target_size = tuple(target_size)
        self.resize = _UpProjection(num_input_features=channel, num_output_features=channel)
        self.conv0 = nn.Conv2d(channel, channel, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn0 = nn.BatchNorm2d(channel)
        self.conv1 = nn.Conv2d(channel, channel, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(channel)
        self.conv2 = nn.Conv2d(channel, num_class, kernel_size=5, stride=1, padding=2, bias=True)

    def forward(self, x):
        """Predict per-pixel outputs at ``self.target_size``.

        Args:
            x: 4-D feature map with ``channel`` channels.

        Returns:
            Raw conv outputs (no activation applied) with ``num_class`` channels and
            spatial size ``self.target_size``.
        """
        x0 = self.resize(x, self.target_size)
        x0 = F.relu(self.bn0(self.conv0(x0)))
        x1 = F.relu(self.bn1(self.conv1(x0)))
        return self.conv2(x1)


class MultiPlaneOccupancyHead(nn.Module):
    """Multi-plane occupancy head.

    Runs a decoder (``D``) and a multi-feature fusion branch (``MFF``) over four backbone
    feature maps, concatenates their outputs and predicts ``num_classes`` channels with
    the prediction head ``R``.

    Args:
        num_classes: Number of output channels (planes) predicted per pixel. Defaults to 100.
    """

    def __init__(self, num_classes=100):
        """Initialize the multi-plane occupancy head."""
        super().__init__()
        # Channel counts listed deepest first, i.e. for res5, res4, res3, res2.
        block_channel = [2048, 1024, 512, 256]
        self.feature_key = ['res2', 'res3', 'res4', 'res5']
        feature_channels = 64
        self.D = D(block_channel)
        self.MFF = MFF(block_channel, feature_channels)
        # D ends with block_channel[-1] // 4 channels; MFF contributes feature_channels.
        head_channels = block_channel[-1] // 4 + feature_channels
        self.num_classes = num_classes
        self.prediction = R(head_channels, self.num_classes)

    def forward(self, x):
        """Predict multi-plane occupancy from a dict of backbone features.

        Args:
            x: Mapping with keys ``'res2'``, ``'res3'``, ``'res4'`` and ``'res5'``. Given the
                channel layout above, these maps must have 256, 512, 1024 and 2048
                channels respectively.

        Returns:
            Raw outputs of the prediction head with ``num_classes`` channels at the head's
            target size (120x160 with ``R``'s default).
        """
        x_block1 = x[self.feature_key[0]]
        x_block2 = x[self.feature_key[1]]
        x_block3 = x[self.feature_key[2]]
        x_block4 = x[self.feature_key[3]]
        x_decoder = self.D(x_block1, x_block2, x_block3, x_block4)
        # Fuse all scales at the decoder's output resolution (twice the res2 map size).
        x_mff = self.MFF(
            x_block1, x_block2, x_block3, x_block4, [x_decoder.size(2), x_decoder.size(3)]
        )
        x_feat = torch.cat((x_decoder, x_mff), 1)
        return self.prediction(x_feat)