# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
# TODO: all this potentially goes to PyTorch3D
import math
from typing import Tuple
import pytorch3d as pt3d
import torch
from pytorch3d.renderer.cameras import CamerasBase
def jitter_extrinsics(
    R: torch.Tensor,
    T: torch.Tensor,
    max_angle: float = (math.pi * 2.0),
    translation_std: float = 1.0,
    scale_std: float = 0.3,
):
    """
    Apply a random similarity transformation to the extrinsics `R`, `T`.

    The jitter consists of a random rotation (obtained by scaling the
    axis-angle representation of a random rotation by `max_angle`), a random
    translation offset sampled from N(0, translation_std), and a random
    isotropic scale exp(N(0, scale_std)). The same jitter transform is
    broadcast to all N cameras in the batch.
    """
    assert all(v >= 0.0 for v in (max_angle, translation_std, scale_std))
    batch = R.shape[0]
    # Single random rotation; rescale its log-map magnitude by max_angle.
    rot_jit = pt3d.transforms.random_rotations(1, device=R.device)
    rot_jit = pt3d.transforms.so3_exponential_map(
        pt3d.transforms.so3_log_map(rot_jit) * max_angle
    )
    trans_jit = torch.randn_like(rot_jit[:1, :, 0]) * translation_std
    # Pack rotation and translation into a batch of 4x4 rigid transforms.
    rigid = pt3d.ops.eyes(dim=4, N=batch, device=R.device)
    rigid[:, :3, :3] = rot_jit.expand(batch, 3, 3)
    rigid[:, 3, :3] = trans_jit.expand(batch, 3)
    scale_jit = torch.exp(torch.randn_like(trans_jit[:, 0]) * scale_std)
    return apply_camera_alignment(R, T, rigid, scale_jit.expand(batch))
def apply_camera_alignment(
    R: torch.Tensor,
    T: torch.Tensor,
    rigid_transform: torch.Tensor,
    scale: torch.Tensor,
):
    """
    Align camera extrinsics with a batch of similarity transformations.

    Args:
        R: Camera rotation matrices of shape (N, 3, 3).
        T: Camera translations of shape (N, 3).
        rigid_transform: Tensor of shape (N, 4, 4); each 4x4 matrix maps
            the scene pointcloud from misaligned to aligned coordinates.
        scale: Per-camera scaling factors, a tensor of shape (N,).

    Returns:
        R_aligned: The aligned rotations, shape (N, 3, 3).
        T_aligned: The aligned translations, shape (N, 3).
    """
    align_R = rigid_transform[:, :3, :3]
    align_T = rigid_transform[:, 3:, :3]  # (N, 1, 3) translation row
    # Compose the transposed alignment rotation with the camera rotation.
    R_aligned = torch.bmm(align_R.transpose(1, 2), R)
    # Subtract the rotated alignment translation, then apply the scale.
    T_aligned = scale[:, None] * (T - torch.bmm(align_T, R_aligned)[:, 0])
    return R_aligned, T_aligned
def get_min_max_depth_bounds(cameras, scene_center, scene_extent):
    """
    Estimate per-camera near/far depth planes from the scene bounds:
        near = dist(camera_center, scene_center) - scene_extent
        far  = dist(camera_center, scene_center) + scene_extent
    """
    cam_centers = cameras.get_camera_center()
    offsets = cam_centers - scene_center.to(cameras.R)[None]
    # Clamp the squared distance away from zero before the sqrt.
    center_dist = offsets.pow(2).sum(dim=-1).clamp(0.001).sqrt()
    # Keep the resulting near plane strictly positive.
    center_dist = center_dist.clamp(scene_extent + 1e-3)
    return center_dist - scene_extent, center_dist + scene_extent
def volumetric_camera_overlaps(
    cameras: CamerasBase,
    scene_extent: float = 8.0,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    resol: int = 16,
    weigh_by_ray_angle: bool = True,
):
    """
    Compute pairwise IoU-style overlaps between the viewing frustums of all
    cameras in `cameras`, estimated on a `resol`^3 voxel grid covering a
    cube of half-side `scene_extent` centered at `scene_center`.
    """
    device = cameras.device
    n_cameras = cameras.R.shape[0]
    n_vox = int(resol**3)
    # World-space coordinates of the voxel grid centers.
    grid = pt3d.structures.Volumes(
        densities=torch.zeros([1, 1, resol, resol, resol], device=device),
        volume_translation=-torch.FloatTensor(scene_center)[None].to(device),
        voxel_size=2.0 * scene_extent / resol,
    ).get_coord_grid(world_coordinates=True)
    grid = grid.view(1, n_vox, 3).expand(n_cameras, n_vox, 3)
    # Project the voxel centers into each camera's NDC space.
    grid_ndc = cameras.transform_points(grid, eps=1e-2)
    # 1.0 where a voxel projects inside the frustum (|x|,|y| <= 1, z > 0).
    in_frustum = (
        torch.prod((grid_ndc[..., :2].abs() <= 1.0), dim=-1)
        * (grid_ndc[..., 2] > 0.0).float()
    )  # n_cameras x n_vox
    if weigh_by_ray_angle:
        # Weigh each shared voxel by the alignment of the two cameras'
        # viewing rays; this equals sum_v (1 + <ray_i(v), ray_j(v)>) with
        # rays zeroed outside the frustum, i.e. n_vox + masked dot product.
        rays = torch.nn.functional.normalize(
            grid - cameras.get_camera_center()[:, None], dim=-1
        )
        rays_masked = (rays * in_frustum[..., None]).view(n_cameras, n_vox * 3)
        inter = n_vox + rays_masked @ rays_masked.t()
    else:
        inter = in_frustum @ in_frustum.t()
    # Intersection-over-union; the diagonal holds each camera's own mass.
    mass = torch.diag(inter)
    return inter / (mass[:, None] + mass[None, :] - inter).clamp(0.1)