File size: 4,534 Bytes
f4a0919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Back projection module."""

import torch
from torch import nn
from ..utils.frustum import (
    generate_frustum,
    generate_frustum_volume,
    compute_camera2frustum_transform
)


class BackProjection(nn.Module):
    """Back-project voxels of a 3D frustum grid onto the 2D image plane.

    For every active voxel of a frustum mask the module computes the pixel
    coordinate ``(u, v)`` and the depth it projects to under the sample's
    camera matrix. It keeps only voxels whose depth lies in
    ``[depth_min, depth_max]`` and whose pixel falls inside the image (and,
    optionally, inside a room mask).
    """

    def __init__(self, cfg):
        """Initialize the back projection module.

        Args:
            cfg: Configuration object. Reads ``cfg.dataset.depth_size``,
                ``cfg.dataset.depth_min``, ``cfg.dataset.depth_max``,
                ``cfg.model.projection.voxel_size`` and
                ``cfg.model.frustum3d.frustum_dims``. The latter is
                replicated three times to form the grid dimensions, so it
                is presumably a single edge length.
        """
        super().__init__()
        self.image_size = cfg.dataset.depth_size
        self.depth_min = cfg.dataset.depth_min
        self.depth_max = cfg.dataset.depth_max
        self.voxel_size = cfg.model.projection.voxel_size
        # Same extent along all three grid axes.
        self.frustum_dimensions = torch.tensor([cfg.model.frustum3d.frustum_dims] * 3)

    def forward(
        self, shp, intrinsics, frustum_masks=None, room_masks=None
    ):
        """Map frustum voxels to image pixels and depths.

        Args:
            shp: Image size indexed as ``(height, width)``. A projected ``u``
                must be below ``shp[1]`` and ``v`` below ``shp[0]``.
            intrinsics: Per-sample 4x4 camera matrices, either batched
                ``(B, 4, 4)`` or a single ``(4, 4)`` matrix.
            frustum_masks: Optional boolean voxel masks, ``(B, X, Y, Z)`` or
                ``(X, Y, Z)``, selecting the voxels to project. Defaults to
                every voxel of a ``frustum_dimensions`` grid.
            room_masks: Optional per-sample image masks indexed as
                ``room_masks[b, 0, v, u]``. A voxel is dropped when the mask
                is false at its projected pixel.

        Returns:
            Tuple ``(kept, mapping)``.

            ``kept`` is a boolean grid marking the voxels that survived all
            checks.

            ``mapping`` stores ``[batch_index, u, v, z_index, depth]`` per
            voxel. Entries of filtered-out voxels are not meaningful (mostly
            ``-1``), so use ``kept`` to select valid ones.

            Batched inputs give ``(B, X, Y, Z)`` / ``(B, X, Y, Z, 5)``.
            Unbatched inputs give ``(X, Y, Z)`` / ``(X, Y, Z, 4)``, with the
            batch-index column dropped.
        """
        device = intrinsics.device
        if frustum_masks is None:
            # A lone (4, 4) matrix is an unbatched call. Build a single mask
            # instead of one mask per matrix row.
            batch_shape = [] if intrinsics.dim() == 2 else [len(intrinsics)]
            frustum_masks = torch.ones(
                [*batch_shape, *self.frustum_dimensions.tolist()],
                dtype=torch.bool, device=device
            )
        unbatched = frustum_masks.dim() == 3
        if unbatched:
            frustum_masks = frustum_masks[None]
            if intrinsics.dim() == 2:
                intrinsics = intrinsics[None]

        kepts, mappings = [], []
        for bi, (intrinsic, frustum_mask) in enumerate(zip(intrinsics, frustum_masks)):
            camera2frustum = compute_camera2frustum_transform(
                intrinsic.cpu(), self.image_size, self.depth_min,
                self.depth_max, self.voxel_size
            ).to(device)
            frustum2camera = torch.inverse(camera2frustum)
            intrinsic_inverse = torch.inverse(intrinsic)

            grid_size = frustum_mask.shape
            coordinates = torch.nonzero(frustum_mask)

            # Flip the first two grid axes (relative to the grid extent), then
            # shift by the frustum padding and one voxel. The result is then
            # taken back to camera space with the inverse of camera2frustum.
            grid_coordinates = coordinates.clone()
            grid_coordinates[:, :2] = (
                torch.tensor(grid_size[:2], device=device) - grid_coordinates[:, :2]
            )
            padding_offsets = self.compute_frustum_padding(intrinsic_inverse)
            grid_coordinates = (
                grid_coordinates - padding_offsets - torch.tensor([1., 1., 1.], device=device)
            )
            # Homogeneous coordinates so the 4x4 transforms can be applied.
            grid_coordinates = torch.cat([
                grid_coordinates, torch.ones(len(grid_coordinates), 1, device=device)], 1
            )
            pointcloud = torch.mm(frustum2camera, grid_coordinates.t())
            depth_pixels = torch.mm(intrinsic, pointcloud)

            depth = depth_pixels[2]
            # Perspective divide gives (u, v); the voxel's z index is kept as
            # the third column.
            coord = depth_pixels[0:2] / depth
            coord = torch.cat(
                [coord, coordinates[:, 2][None].to(coord.dtype)], 0
            ).t()

            # Negative u/v are not rejected here. Such voxels fail the
            # ``mapping >= 0`` test below.
            kept = (
                (depth <= self.depth_max)
                & (depth >= self.depth_min)
                & (coord[:, 0] < shp[1])
                & (coord[:, 1] < shp[0])
            )
            coordinates = coordinates[kept]
            depth = depth[kept, None]

            # Single allocation; the grid matches the mask that `coordinates`
            # indexes into.
            mapping = torch.full((*grid_size, 5), -1.0, device=device)
            mapping[coordinates[:, 0], coordinates[:, 1], coordinates[:, 2]] = torch.cat(
                [torch.full_like(depth, bi), coord[kept], depth], -1
            )

            kept = (mapping >= 0).all(-1)

            if room_masks is not None:
                # Columns 1 and 2 are (u, v); the room mask is indexed [v, u].
                mapping_kept = mapping[kept].long()
                kept[kept.clone()] = room_masks[bi, 0, mapping_kept[:, 2], mapping_kept[:, 1]]

            kepts.append(kept)
            mappings.append(mapping)

        if unbatched:
            kepts = kepts[0]
            # Drop the batch-index column for unbatched callers.
            mappings = mappings[0][..., 1:]
        else:
            kepts = torch.stack(kepts, 0)
            mappings = torch.stack(mappings, 0)

        return kepts, mappings

    def compute_frustum_padding(self, intrinsic_inverse: torch.Tensor) -> torch.Tensor:
        """Compute the per-axis padding between the grid and the frustum volume.

        Args:
            intrinsic_inverse: Inverse of the sample's camera matrix.

        Returns:
            Float tensor equal to half of ``frustum_dimensions`` minus the
            volume dimensions reported by ``generate_frustum_volume``, placed
            on the device of ``intrinsic_inverse``.
        """
        frustum = generate_frustum(
            self.image_size, intrinsic_inverse.cpu(), self.depth_min, self.depth_max
        )
        dimensions, _ = generate_frustum_volume(frustum, self.voxel_size)
        difference = (
            self.frustum_dimensions - torch.as_tensor(dimensions)
        ).float().to(intrinsic_inverse.device)
        return difference / 2