# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Back projection module."""
import torch
from torch import nn
from ..utils.frustum import (
generate_frustum,
generate_frustum_volume,
compute_camera2frustum_transform
)


class BackProjection(nn.Module):
    """Back-project frustum voxels onto the image plane.

    For every active voxel in a fixed-size frustum grid, the forward pass
    computes the matching image pixel location and depth under the given
    camera intrinsics, yielding a voxel-to-pixel mapping used for back
    projection.
    """

    def __init__(self, cfg):
"""Initialize the back projection module."""
super().__init__()
self.image_size = cfg.dataset.depth_size
self.depth_min = cfg.dataset.depth_min
self.depth_max = cfg.dataset.depth_max
self.voxel_size = cfg.model.projection.voxel_size
        # Fixed cubic frustum grid size (one value repeated for all three axes).
        self.frustum_dimensions = torch.tensor([cfg.model.frustum3d.frustum_dims] * 3)

    def forward(
self, shp, intrinsics, frustum_masks=None, room_masks=None
):
"""Forward pass."""
device = intrinsics.device
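        # Default to an all-ones mask so every frustum voxel is considered.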
if frustum_masks is None:
frustum_masks = torch.ones(
[len(intrinsics), *self.frustum_dimensions],
dtype=torch.bool, device=device
)
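        # Unbatched inputs (a single 3D mask) get a temporary batch dimension.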
len_shp = len(frustum_masks.shape)
if len_shp == 3:
frustum_masks = frustum_masks[None]
intrinsics = intrinsics[None]
kepts, mappings = [], []
for bi, (intrinsic, frustum_mask) in enumerate(zip(intrinsics, frustum_masks)):
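            # Transform that maps camera-space points into frustum voxel coordinates.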
camera2frustum = compute_camera2frustum_transform(
intrinsic.cpu(), self.image_size, self.depth_min,
self.depth_max, self.voxel_size
).to(device)
intrinsic_inverse = torch.inverse(intrinsic)
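            # Grid indices of every active voxel in this sample's frustum mask.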
coordinates = torch.nonzero(frustum_mask)
grid_coordinates = coordinates.clone()
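            # Mirror the first two grid axes (hard-coded for a 256-voxel grid).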
grid_coordinates[:, :2] = 256 - grid_coordinates[:, :2]
padding_offsets = self.compute_frustum_padding(intrinsic_inverse)
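            # Shift by the centering padding and a one-voxel offset, then make the
            # coordinates homogeneous for the frustum-to-camera transform below.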
grid_coordinates = grid_coordinates - padding_offsets - torch.tensor([1., 1., 1.], device=device)
grid_coordinates = torch.cat([
grid_coordinates, torch.ones(len(grid_coordinates), 1, device=device)], 1
)
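            # Map voxel coordinates back to camera space, then project with the intrinsics.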
pointcloud = torch.mm(torch.inverse(camera2frustum), grid_coordinates.t())
depth_pixels = torch.mm(intrinsic, pointcloud)
depth = depth_pixels[2]
coord = depth_pixels[0:2] / depth
coord = torch.cat([coord, coordinates[:, 2][None]], 0).permute(1, 0)
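            # Keep voxels with depth in [depth_min, depth_max] whose projection does not exceed the image size.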
            kept = (
                (depth <= self.depth_max) & (depth >= self.depth_min)
                & (coord[:, 0] < shp[1]) & (coord[:, 1] < shp[0])
            )
coordinates = coordinates[kept]
depth = depth[kept, None]
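            # Dense lookup over the 256^3 grid holding (batch index, pixel x, pixel y,
            # voxel grid index, depth); entries left at -1 mark invalid projections.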
mapping = torch.zeros(256, 256, 256, 5, device=depth.device) - 1.
mapping[coordinates[:, 0], coordinates[:, 1], coordinates[:, 2]] = \
torch.cat([torch.ones_like(depth) * bi, coord[kept], depth], -1)
kept = (mapping >= 0).all(-1)
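            # Optionally keep only voxels whose projected pixel falls inside the room mask.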
if room_masks is not None:
mapping_kept = mapping[kept].long()
kept[kept.clone()] = room_masks[bi, 0, mapping_kept[:, 2], mapping_kept[:, 1]]
kepts.append(kept)
mappings.append(mapping)
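        # Restore the unbatched layout (dropping the batch-index channel) or stack the batch.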
if len_shp == 3:
kepts = kepts[0]
mappings = mappings[0][..., 1:]
else:
kepts = torch.stack(kepts, 0)
mappings = torch.stack(mappings, 0)
        return kepts, mappings

    def compute_frustum_padding(self, intrinsic_inverse: torch.Tensor) -> torch.Tensor:
"""Compute frustum padding."""
frustum = generate_frustum(
self.image_size, intrinsic_inverse.cpu(), self.depth_min, self.depth_max
)
dimensions, _ = generate_frustum_volume(frustum, self.voxel_size)
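        # Half the size difference centers the frustum volume inside the fixed grid.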
difference = (
self.frustum_dimensions - torch.tensor(dimensions)
).float().to(intrinsic_inverse.device)
padding_offsets = difference / 2
return padding_offsets