File size: 4,838 Bytes
94dc344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


# TODO: all this potentially goes to PyTorch3D

import math
from typing import Tuple

import pytorch3d as pt3d
import torch
from pytorch3d.renderer.cameras import CamerasBase


def jitter_extrinsics(
    R: torch.Tensor,
    T: torch.Tensor,
    max_angle: float = (math.pi * 2.0),
    translation_std: float = 1.0,
    scale_std: float = 0.3,
):
    """
    Jitter the extrinsic camera parameters `R` and `T` with a random similarity
    transformation shared by all N cameras in the batch.

    The transformation:
      - rotates by a random angle in [0, max_angle]: a random rotation from
        `random_rotations` is converted to axis-angle (angle in [0, pi]) and
        its angle is rescaled linearly by max_angle / pi;
      - scales by a random factor exp(N(0, scale_std)), where N(0, scale_std)
        is a sample from a normal distribution with zero mean and standard
        deviation scale_std;
      - translates by a 3D offset sampled from N(0, translation_std).

    Args:
        R: Camera rotation matrices of shape (N, 3, 3).
        T: Camera translations of shape (N, 3).
        max_angle: Upper bound (in radians) of the random rotation angle.
        translation_std: Standard deviation of the random 3D translation.
        scale_std: Standard deviation of the log of the random scale factor.

    Returns:
        The jittered `(R, T)` pair produced by `apply_camera_alignment`.
    """
    assert all(x >= 0.0 for x in (max_angle, translation_std, scale_std))
    N = R.shape[0]
    # One random rotation, later broadcast to every camera.
    R_jit = pt3d.transforms.random_rotations(1, device=R.device)
    # so3_log_map yields an axis-angle vector whose norm (the angle) lies in
    # [0, pi]; rescaling by max_angle / pi maps it to [0, max_angle] as
    # documented (multiplying by max_angle alone overshot by a factor of pi).
    # `so3_exp_map` supersedes the deprecated `so3_exponential_map`.
    R_jit = pt3d.transforms.so3_exp_map(
        pt3d.transforms.so3_log_map(R_jit) * (max_angle / math.pi)
    )
    T_jit = torch.randn_like(R_jit[:1, :, 0]) * translation_std
    # 4x4 transforms with the rotation in the top-left block and the
    # translation in the last row, the layout read by apply_camera_alignment.
    rigid_transform = pt3d.ops.eyes(dim=4, N=N, device=R.device)
    rigid_transform[:, :3, :3] = R_jit.expand(N, 3, 3)
    rigid_transform[:, 3, :3] = T_jit.expand(N, 3)
    # A single log-normal scale factor shared by all cameras.
    scale_jit = torch.exp(torch.randn_like(T_jit[:, 0]) * scale_std).expand(N)
    return apply_camera_alignment(R, T, rigid_transform, scale_jit)


def apply_camera_alignment(
    R: torch.Tensor,
    T: torch.Tensor,
    rigid_transform: torch.Tensor,
    scale: torch.Tensor,
):
    """
    Re-express a batch of cameras in an aligned coordinate frame.

    `rigid_transform` stores its rotation in the top-left 3x3 block and its
    translation in the first three entries of the last row. The aligned
    rotation is `R_rigid^T @ R`; the aligned translation is
    `scale * (T - t_rigid @ R_aligned)`, where `t_rigid` is that row
    translation.

    Args:
        R: Camera rotation matrices of shape (N, 3, 3).
        T: Camera translations of shape (N, 3).
        rigid_transform: A tensor of shape (N, 4, 4) holding N transforms that
            map the scene point cloud from the misaligned coordinates to the
            aligned space.
        scale: A tensor of shape (N,) with one scaling factor per camera,
            applied to the aligned translations.

    Returns:
        R_aligned: The aligned rotations, shape (N, 3, 3).
        T_aligned: The aligned translations, shape (N, 3).
    """
    rot_rigid = rigid_transform[:, :3, :3]
    # Keep the row translation as (N, 1, 3) so it can be batch-multiplied.
    trans_rigid = rigid_transform[:, 3:, :3]
    R_aligned = torch.bmm(rot_rigid.transpose(1, 2), R)
    shift = torch.matmul(trans_rigid, R_aligned).squeeze(1)
    T_aligned = scale[:, None] * (T - shift)
    return R_aligned, T_aligned


def get_min_max_depth_bounds(cameras, scene_center, scene_extent):
    """
    Estimate per-camera near/far depth bounds for a sphere of radius
    `scene_extent` around `scene_center`:

        near = dist(cam_center, scene_center) - scene_extent
        far  = dist(cam_center, scene_center) + scene_extent

    The camera-to-center distance is clamped from below to
    `scene_extent + 1e-3`, so `near` is at least 1e-3 even when a camera lies
    inside the sphere.

    Args:
        cameras: A camera batch exposing `get_camera_center()` and `R`.
        scene_center: Tensor holding the 3D scene center; it is cast to the
            dtype/device of `cameras.R`.
        scene_extent: Radius of the scene's bounding sphere.

    Returns:
        (min_depth, max_depth): one near and one far bound per camera.
    """
    centers = cameras.get_camera_center()
    offsets = centers - scene_center.to(cameras.R)[None]
    sq_dist = torch.sum(offsets**2, dim=-1)
    # Clamp the squared distance before the square root.
    dist = torch.sqrt(torch.clamp(sq_dist, min=0.001))
    dist = torch.clamp(dist, min=scene_extent + 1e-3)
    return dist - scene_extent, dist + scene_extent


def volumetric_camera_overlaps(
    cameras: CamerasBase,
    scene_extent: float = 8.0,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    resol: int = 16,
    weigh_by_ray_angle: bool = True,
):
    """
    Compute IoU-style overlaps between the viewing frustums of all pairs of
    cameras in `cameras`.

    The frustums are approximated by voxelizing a cube of side
    `2 * scene_extent` with `resol` voxels per side and testing which voxels
    each camera sees.

    Args:
        cameras: A batch of cameras; the batch size is taken from `cameras.R`.
        scene_extent: Half the side length of the voxelized cube.
        scene_center: Intended center of the voxelized cube; it is passed to
            `Volumes` negated, as `volume_translation`.
        resol: Number of voxels along each side of the cube.
        weigh_by_ray_angle: If True, a pair's overlap also accumulates the dot
            products between the two cameras' unit rays towards co-visible
            voxels. If False, the overlap is the plain count of voxels visible
            in both cameras.

    Returns:
        A (ba, ba) tensor whose (i, j) entry is
        overlap_ij / (overlap_ii + overlap_jj - overlap_ij). The denominator is
        clamped from below at 0.1.
    """
    device = cameras.device
    n_cams = cameras.R.shape[0]
    n_vox = int(resol**3)

    # World-space voxel centers, flattened and tiled once per camera.
    voxel_grid = pt3d.structures.Volumes(
        densities=torch.zeros([1, 1, resol, resol, resol], device=device),
        volume_translation=-torch.FloatTensor(scene_center)[None].to(device),
        voxel_size=2.0 * scene_extent / resol,
    )
    points = voxel_grid.get_coord_grid(world_coordinates=True)
    points = points.view(1, n_vox, 3).expand(n_cams, n_vox, 3)

    # Per-camera visibility mask of shape (n_cams, n_vox). A voxel counts when
    # its projected x and y both lie in [-1, 1] and its projected z is positive.
    projected = cameras.transform_points(points, eps=1e-2)
    in_window = (projected[..., :2].abs() <= 1.0).all(dim=-1)
    in_front = projected[..., 2] > 0.0
    visible = (in_window & in_front).float()

    if not weigh_by_ray_angle:
        # Number of voxels seen by both cameras of each pair.
        overlap = visible @ visible.t()
    else:
        ray_dirs = torch.nn.functional.normalize(
            points - cameras.get_camera_center()[:, None], dim=-1
        )
        # Zero out rays towards voxels a camera cannot see, then flatten. A
        # single matmul then gives, for every pair (i, j), the sum over voxels
        # of dot(ray_i, ray_j). Adding n_vox is equivalent to adding 1 per voxel.
        # NOTE(review): that +1 is added for every voxel, visible or not, so
        # cameras with disjoint frustums still get a nonzero overlap. Confirm
        # this is intended.
        flat_rays = (ray_dirs * visible[..., None]).view(n_cams, n_vox * 3)
        overlap = n_vox + flat_rays @ flat_rays.t()

    self_overlap = torch.diag(overlap)
    # Inclusion-exclusion "union", clamped to avoid dividing by ~0.
    union = self_overlap[:, None] + self_overlap[None, :] - overlap
    return overlap / union.clamp(min=0.1)