File size: 4,111 Bytes
434b0b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Modified by Zexin He in 2023-2024.
# The modifications are subject to the same license as the original.


"""
The ray sampler is a module that takes in camera matrices and resolutions and produces batches of rays.
Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
"""

import torch


class RaySampler(torch.nn.Module):
    """Generate world-space rays for a square patch of pixels per camera.

    The module has no learnable parameters; all work happens in ``forward``.
    """

    def __init__(self):
        super().__init__()
        # Placeholder attributes. Nothing in this class reads or updates them
        # after construction; they are kept so external code that accesses
        # them keeps working.
        self.ray_origins_h = None
        self.ray_directions = None
        self.depths = None
        self.image_coords = None
        self.rendering_options = None

    @torch.compile
    def forward(self, cam2world_matrix, intrinsics, resolutions, anchors, region_size):
        """
        Create batches of rays and return origins and directions.

        For every camera, a ``region_size x region_size`` patch of pixels is
        selected, starting at ``anchors``. Each pixel centre is unprojected to
        depth 1 in camera space, transformed to world space, and turned into a
        unit direction from the camera position.

        Args:
            cam2world_matrix: (N, 4, 4) camera-to-world transforms. Its dtype
                and device determine the dtype and device of the outputs.
            intrinsics: (N, 3, 3) camera intrinsics; fx, fy, cx, cy and the
                skew term are read from it. Presumably expressed in
                image-normalised units, since pixel coordinates are divided by
                ``resolutions`` before the intrinsics are applied -- confirm
                with callers.
            resolutions: (N, 1) full image resolution per camera, used to
                normalise pixel coordinates.
            anchors: (N, 2) patch origin per camera, in (row, col) order.
            region_size: int, side length of the sampled patch in pixels.

        Returns:
            ray_origins: (N, M, 3) camera position repeated for every ray.
            ray_dirs: (N, M, 3) unit-length ray directions in world space.
            ``M = region_size ** 2``; rays are ordered row-major over the patch.

        NOTE(review): points are passed through a diag(1, -1, -1, 1) axis flip
        (called OpenCV-to-Blender in the original code) before
        ``cam2world_matrix`` is applied, which suggests the matrices are given
        for a Blender-style camera frame -- confirm against the callers.
        """
        device = cam2world_matrix.device
        # Work in the dtype of the poses so torch.bmm never sees mixed dtypes
        # (a hard-coded float32 here made non-float32 poses raise).
        dtype = cam2world_matrix.dtype
        N = cam2world_matrix.shape[0]

        cam_locs_world = cam2world_matrix[:, :3, 3]

        # Keep a trailing singleton axis so every term broadcasts against (N, M).
        fx = intrinsics[:, 0, 0].unsqueeze(-1)
        fy = intrinsics[:, 1, 1].unsqueeze(-1)
        cx = intrinsics[:, 0, 2].unsqueeze(-1)
        cy = intrinsics[:, 1, 2].unsqueeze(-1)
        sk = intrinsics[:, 0, 1].unsqueeze(-1)

        # Integer pixel offsets inside the patch, flattened row-major.
        pix = torch.arange(region_size, dtype=torch.float32, device=device)
        rows, cols = torch.meshgrid(pix, pix, indexing="ij")
        col_offsets = cols.reshape(1, -1)
        row_offsets = rows.reshape(1, -1)

        # anchors are indexed as normal (row, col) but image x runs along
        # columns and image y along rows. The 0.5 / resolutions term moves the
        # sample to the pixel centre.
        x_cam = (col_offsets + anchors[:, 1].unsqueeze(-1)) * (1.0 / resolutions) + (
            0.5 / resolutions
        )
        y_cam = (row_offsets + anchors[:, 0].unsqueeze(-1)) * (1.0 / resolutions) + (
            0.5 / resolutions
        )

        # Unproject through the inverse intrinsics (including skew). Depth is
        # fixed at 1, so no extra scaling by z is needed.
        x_lift = (x_cam - cx + cy * sk / fy - sk * y_cam / fy) / fx
        y_lift = (y_cam - cy) / fy
        z_cam = torch.ones_like(x_lift)

        # Homogeneous camera-space points, (N, M, 4), in the pose dtype.
        cam_rel_points = torch.stack(
            (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1
        ).to(dtype)

        # Axis flip (y and z negated) applied on the camera side of the pose.
        _opencv2blender = torch.tensor(
            [
                [1, 0, 0, 0],
                [0, -1, 0, 0],
                [0, 0, -1, 0],
                [0, 0, 0, 1],
            ],
            dtype=dtype,
            device=device,
        )
        cam2world_matrix = torch.bmm(
            cam2world_matrix, _opencv2blender.unsqueeze(0).expand(N, -1, -1)
        )

        world_rel_points = torch.bmm(
            cam2world_matrix, cam_rel_points.permute(0, 2, 1)
        ).permute(0, 2, 1)[:, :, :3]

        ray_dirs = world_rel_points - cam_locs_world[:, None, :]
        ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)

        # Every ray of a camera starts at that camera's position.
        ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)

        return ray_origins, ray_dirs