# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Modified by Zexin He in 2023-2024.
# The modifications are subject to the same license as the original.
"""
The ray sampler is a module that takes in camera matrices and resolution and returns batches of rays.
Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
"""
import torch
class RaySampler(torch.nn.Module):
    """Generate world-space rays for a square patch of pixels in each view.

    For every batch element, ``forward`` builds a ``region_size`` x
    ``region_size`` grid of pixels offset by a per-view anchor, lifts the
    pixels to depth 1 in camera space using the intrinsics, transforms them
    with the camera-to-world matrix and returns the ray origins together
    with the unit-length ray directions.
    """

    def __init__(self):
        super().__init__()
        # Placeholder attributes; ``forward`` neither reads nor writes them.
        self.ray_origins_h = None
        self.ray_directions = None
        self.depths = None
        self.image_coords = None
        self.rendering_options = None

    @torch.compile
    def forward(
        self,
        cam2world_matrix: torch.Tensor,
        intrinsics: torch.Tensor,
        resolutions: torch.Tensor,
        anchors: torch.Tensor,
        region_size: int,
    ):
        """
        Create batches of rays and return origins and directions.

        cam2world_matrix: (N, 4, 4)
        intrinsics: (N, 3, 3)
        resolutions: (N, 1)
        anchors: (N, 2), given as (row, col) of the patch origin
        region_size: int, side length of the square patch in pixels

        Returns a tuple ``(ray_origins, ray_dirs)``:
        ray_origins: (N, M, 3), the camera position repeated for every ray
        ray_dirs: (N, M, 3), unit-length directions

        where ``M = region_size ** 2``. Rays are ordered row-major over the
        patch (the column index varies fastest).

        All intermediate tensors are created with the dtype and device of
        ``cam2world_matrix``, so the sampler is not tied to float32.

        NOTE(review): pixel coordinates are divided by ``resolutions`` before
        the intrinsics are applied, so the intrinsics are presumably
        expressed in resolution-normalized units -- confirm with the caller.
        NOTE(review): lifted points are multiplied by a y/z axis flip
        (``opencv2blender``) before ``cam2world_matrix`` is applied, which
        suggests the matrices follow the Blender/OpenGL camera convention --
        confirm with the caller.
        """
        device = cam2world_matrix.device
        dtype = cam2world_matrix.dtype
        N, M = cam2world_matrix.shape[0], region_size**2

        # Camera centres are the translation column of the *input* matrix.
        cam_locs_world = cam2world_matrix[:, :3, 3]

        fx = intrinsics[:, 0, 0].unsqueeze(-1)
        fy = intrinsics[:, 1, 1].unsqueeze(-1)
        cx = intrinsics[:, 0, 2].unsqueeze(-1)
        cy = intrinsics[:, 1, 2].unsqueeze(-1)
        sk = intrinsics[:, 0, 1].unsqueeze(-1)

        # Integer pixel offsets inside the patch, flattened row-major.
        pixel_idx = torch.arange(region_size, dtype=dtype, device=device)
        rows, cols = torch.meshgrid(pixel_idx, pixel_idx, indexing="ij")
        u = cols.reshape(1, M).expand(N, M)  # x (column) offsets
        v = rows.reshape(1, M).expand(N, M)  # y (row) offsets

        # anchors are indexed as normal (row, col) but uv is indexed as (x, y)
        # The 0.5 / resolutions term moves each sample to its pixel centre.
        x_cam = (u + anchors[:, 1].unsqueeze(-1)) * (1.0 / resolutions) + (
            0.5 / resolutions
        )
        y_cam = (v + anchors[:, 0].unsqueeze(-1)) * (1.0 / resolutions) + (
            0.5 / resolutions
        )
        z_cam = torch.ones((N, M), dtype=dtype, device=device)

        # Invert the pinhole projection (including skew) at depth z_cam.
        x_lift = (x_cam - cx + cy * sk / fy - sk * y_cam / fy) / fx * z_cam
        y_lift = (y_cam - cy) / fy * z_cam

        # Homogeneous camera-space points, cast so that bmm sees one dtype
        # even when intrinsics / resolutions use a different precision.
        cam_rel_points = torch.stack(
            (
                x_lift.to(dtype),
                y_lift.to(dtype),
                z_cam,
                torch.ones_like(z_cam),
            ),
            dim=-1,
        )

        # diag(1, -1, -1, 1): negates the camera-space y and z axes.
        opencv2blender = (
            torch.tensor(
                [
                    [1, 0, 0, 0],
                    [0, -1, 0, 0],
                    [0, 0, -1, 0],
                    [0, 0, 0, 1],
                ],
                dtype=dtype,
                device=device,
            )
            .unsqueeze(0)
            .expand(N, 4, 4)
        )
        cam2world_matrix = torch.bmm(cam2world_matrix, opencv2blender)

        world_rel_points = torch.bmm(
            cam2world_matrix, cam_rel_points.permute(0, 2, 1)
        ).permute(0, 2, 1)[:, :, :3]

        ray_dirs = world_rel_points - cam_locs_world[:, None, :]
        ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
        ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
        return ray_origins, ray_dirs
|