File size: 4,534 Bytes
f4a0919 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Back projection module."""
import torch
from torch import nn
from ..utils.frustum import (
generate_frustum,
generate_frustum_volume,
compute_camera2frustum_transform
)
class BackProjection(nn.Module):
    """Back-project frustum voxel coordinates onto the image plane.

    For each batch element this builds a dense per-voxel mapping from the
    3D camera-frustum volume to the pixel it projects to (plus its depth),
    and a boolean mask of voxels whose projection is valid: depth within
    ``[depth_min, depth_max]`` and pixel inside the image.
    """

    def __init__(self, cfg):
        """Initialize the back projection module.

        Args:
            cfg: project config providing ``dataset.depth_size``,
                ``dataset.depth_min``, ``dataset.depth_max``,
                ``model.projection.voxel_size`` and
                ``model.frustum3d.frustum_dims``.
        """
        super().__init__()
        self.image_size = cfg.dataset.depth_size
        self.depth_min = cfg.dataset.depth_min
        self.depth_max = cfg.dataset.depth_max
        self.voxel_size = cfg.model.projection.voxel_size
        # Cubic frustum volume: same extent along all three axes. Every
        # volume-sized allocation below derives from this (no hard-coded 256).
        self.frustum_dimensions = torch.tensor([cfg.model.frustum3d.frustum_dims] * 3)

    def forward(
        self, shp, intrinsics, frustum_masks=None, room_masks=None
    ):
        """Project occupied frustum voxels into image space.

        Args:
            shp: image shape; only ``shp[0]`` (height) and ``shp[1]`` (width)
                are read for the in-image bounds test.
            intrinsics: camera intrinsics, batched ``(B, 4, 4)`` or single
                ``(4, 4)``. NOTE(review): assumed homogeneous 4x4 since it is
                matrix-multiplied with a 4-row homogeneous point cloud below
                — confirm against callers.
            frustum_masks: optional boolean occupancy over the frustum volume,
                ``(B, D, D, D)`` or ``(D, D, D)``; defaults to all-ones.
            room_masks: optional per-pixel masks indexed as
                ``room_masks[b, 0, y, x]``; when given, voxels projecting
                outside the room are dropped from the keep mask.

        Returns:
            Tuple ``(kepts, mappings)``. Batched input: ``(B, D, D, D)`` bool
            keep masks and ``(B, D, D, D, 5)`` mappings with channels
            ``[batch_index, pixel_x, pixel_y, voxel_z, depth]`` (-1 where no
            valid projection). Unbatched input: batch dim and the
            batch-index channel are dropped, giving ``(D, D, D, 4)``.
        """
        device = intrinsics.device
        # Configured volume extent as plain ints (cubic: all entries equal).
        dims = [int(d) for d in self.frustum_dimensions]
        if frustum_masks is None:
            frustum_masks = torch.ones(
                [len(intrinsics), *self.frustum_dimensions],
                dtype=torch.bool, device=device
            )
        # Promote unbatched inputs to a batch of one; remember to squeeze back.
        len_shp = len(frustum_masks.shape)
        if len_shp == 3:
            frustum_masks = frustum_masks[None]
            intrinsics = intrinsics[None]
        kepts, mappings = [], []
        for bi, (intrinsic, frustum_mask) in enumerate(zip(intrinsics, frustum_masks)):
            # Camera-to-voxel-grid transform (helper runs on CPU; move back).
            camera2frustum = compute_camera2frustum_transform(
                intrinsic.cpu(), self.image_size, self.depth_min,
                self.depth_max, self.voxel_size
            ).to(device)
            intrinsic_inverse = torch.inverse(intrinsic)
            # Indices of occupied voxels, shape (N, 3).
            coordinates = torch.nonzero(frustum_mask)
            grid_coordinates = coordinates.clone()
            # Flip the first two axes within the volume extent (previously a
            # hard-coded 256; now consistent with cfg frustum_dims).
            grid_coordinates[:, :2] = dims[0] - grid_coordinates[:, :2]
            padding_offsets = self.compute_frustum_padding(intrinsic_inverse)
            grid_coordinates = grid_coordinates - padding_offsets - torch.tensor([1., 1., 1.], device=device)
            # Append homogeneous coordinate -> (N, 4).
            grid_coordinates = torch.cat([
                grid_coordinates, torch.ones(len(grid_coordinates), 1, device=device)], 1
            )
            # Voxel grid -> camera space -> image plane.
            pointcloud = torch.mm(torch.inverse(camera2frustum), grid_coordinates.t())
            depth_pixels = torch.mm(intrinsic, pointcloud)
            depth = depth_pixels[2]
            coord = depth_pixels[0:2] / depth
            # (N, 3) rows of [pixel_x, pixel_y, voxel_z].
            coord = torch.cat([coord, coordinates[:, 2][None]], 0).permute(1, 0)
            # Keep voxels with in-range depth whose projection lands inside
            # the image (shp is (height, width)).
            kept = (depth <= self.depth_max) * \
                (depth >= self.depth_min) * \
                (coord[:, 0] < shp[1]) * (coord[:, 1] < shp[0])
            coordinates = coordinates[kept]
            depth = depth[kept, None]
            # Dense mapping volume; -1 marks "no valid projection".
            mapping = torch.full((*dims, 5), -1., device=depth.device)
            mapping[coordinates[:, 0], coordinates[:, 1], coordinates[:, 2]] = \
                torch.cat([torch.ones_like(depth) * bi, coord[kept], depth], -1)
            kept = (mapping >= 0).all(-1)
            if room_masks is not None:
                # Further restrict to voxels whose pixel lies in the room mask
                # (channels 1/2 of the mapping are pixel_x/pixel_y).
                mapping_kept = mapping[kept].long()
                kept[kept.clone()] = room_masks[bi, 0, mapping_kept[:, 2], mapping_kept[:, 1]]
            kepts.append(kept)
            mappings.append(mapping)
        if len_shp == 3:
            # Unbatched: drop the batch dim and the batch-index channel.
            kepts = kepts[0]
            mappings = mappings[0][..., 1:]
        else:
            kepts = torch.stack(kepts, 0)
            mappings = torch.stack(mappings, 0)
        return kepts, mappings

    def compute_frustum_padding(self, intrinsic_inverse: torch.Tensor) -> torch.Tensor:
        """Compute per-axis padding that centers the camera frustum volume
        inside the configured cubic volume.

        Args:
            intrinsic_inverse: inverse camera intrinsics (moved to CPU for
                the frustum helper; result returned on its original device).

        Returns:
            Float tensor of half the difference between the configured
            dimensions and the actual frustum volume dimensions.
        """
        frustum = generate_frustum(
            self.image_size, intrinsic_inverse.cpu(), self.depth_min, self.depth_max
        )
        dimensions, _ = generate_frustum_volume(frustum, self.voxel_size)
        difference = (
            self.frustum_dimensions - torch.tensor(dimensions)
        ).float().to(intrinsic_inverse.device)
        padding_offsets = difference / 2
        return padding_offsets
|