File size: 4,037 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
from jaxtyping import Float, Shaped
from torch import Tensor

try:
    from ..model.decoder.cuda_splatting import render_cuda_orthographic
except:
    pass
from ..model.types import Gaussians
from ..visualization.annotation import add_label
from ..visualization.drawing.cameras import draw_cameras
from .drawing.cameras import compute_equal_aabb_with_margin


def pad(images: list[Shaped[Tensor, "..."]]) -> list[Shaped[Tensor, "..."]]:
    """Pad tensors of equal rank to a common shape, filling new cells with ones.

    Along every dimension the output size is the largest size found among the
    inputs. Each input occupies the leading corner of its output (index 0
    upward along every dimension); all remaining cells keep the fill value 1.

    Args:
        images: Tensors that all have the same number of dimensions.

    Returns:
        A new list with one padded tensor per input, in the same order, each
        created with the dtype and device of its input. An empty input yields
        an empty list. The inputs themselves are not modified.

    Raises:
        RuntimeError: If the inputs do not all have the same number of
            dimensions (raised by ``torch.stack`` on the shape vectors).
    """
    if not images:
        # torch.stack rejects an empty list, and there is nothing to pad.
        return []

    # One row per image, one column per dimension; the column-wise maximum is
    # the common target shape. Converted to a plain list once, not per image.
    shapes = torch.stack([torch.tensor(x.shape) for x in images])
    padded_shape = shapes.max(dim=0).values.tolist()

    results = []
    for image in images:
        result = torch.ones(padded_shape, dtype=image.dtype, device=image.device)
        # Index with a tuple of slices: a *list* of slices is a non-tuple
        # sequence index, which newer PyTorch releases deprecate.
        region = tuple(slice(0, size) for size in image.shape)
        result[region] = image
        results.append(result)
    return results


def render_projections(
    gaussians: Gaussians,
    resolution: int,
    margin: float = 0.1,
    draw_label: bool = True,
    extra_label: str = "",
) -> Float[Tensor, "batch 3 3 height width"]:
    """Render the Gaussians once along each of the three coordinate axes.

    For every axis a camera pose and view volume are derived from the scene's
    bounding box (as returned by ``compute_equal_aabb_with_margin``) and handed
    to ``render_cuda_orthographic``. The three results are padded to a common
    shape and stacked along dimension 1.

    Args:
        gaussians: Gaussians to render; ``means`` must be a 3-dimensional
            tensor whose first dimension is the batch.
        resolution: Used for both entries of the image-size pair passed to the
            renderer.
        margin: Forwarded to ``compute_equal_aabb_with_margin``.
        draw_label: If true, each rendered image is passed through
            ``add_label`` with a caption naming the two in-plane axes.
        extra_label: Text appended to that caption.

    Returns:
        The three projections stacked along dimension 1.
    """
    axis_names = "XYZ"
    means = gaussians.means
    device = means.device
    batch_size, _, _ = means.shape

    # Bounding box of the scene, derived from the per-batch extremes of the
    # Gaussian means.
    scene_minima, scene_maxima = compute_equal_aabb_with_margin(
        means.min(dim=1).values,
        means.max(dim=1).values,
        margin=margin,
    )
    centers = 0.5 * (scene_minima + scene_maxima)

    projections = []
    for look_axis in range(3):
        right_axis, down_axis = (look_axis + 1) % 3, (look_axis + 2) % 3

        # Camera pose: columns 0..2 are unit vectors along the right, down and
        # look axes; column 3 centres the camera in-plane and places it on the
        # minimum face of the box along the look axis.
        pose = torch.zeros((batch_size, 4, 4), dtype=torch.float32, device=device)
        for column, axis in enumerate((right_axis, down_axis, look_axis)):
            pose[:, axis, column] = 1
        pose[:, right_axis, 3] = centers[:, right_axis]
        pose[:, down_axis, 3] = centers[:, down_axis]
        pose[:, look_axis, 3] = scene_minima[:, look_axis]
        pose[:, 3, 3] = 1

        # View volume: spans the full box extent along every axis, starting at
        # depth zero.
        extents = scene_maxima - scene_minima
        far = extents[:, look_axis]
        near = torch.zeros_like(far)
        background = torch.zeros((batch_size, 3), dtype=torch.float32, device=device)

        rendered = render_cuda_orthographic(
            pose,
            extents[:, right_axis],
            extents[:, down_axis],
            near,
            far,
            (resolution, resolution),
            background,
            gaussians.means,
            gaussians.covariances,
            gaussians.harmonics,
            gaussians.opacities,
            fov_degrees=10.0,
        )

        if draw_label:
            label = (
                f"{axis_names[right_axis]}{axis_names[down_axis]}"
                f" Projection {extra_label}"
            )
            rendered = torch.stack([add_label(image, label) for image in rendered])

        projections.append(rendered)

    # Labelled images may differ in size, so pad before stacking.
    padded = pad(projections)
    return torch.stack(padded, dim=1)


def render_cameras(batch: dict, resolution: int) -> Float[Tensor, "3 3 height width"]:
    """Draw the context and target cameras of the first batch element.

    Context and target camera parameters (extrinsics, intrinsics, near, far)
    are taken at batch index 0, concatenated context-first, and passed to
    ``draw_cameras`` together with one RGB colour per view: (1, 1, 1) for
    context views and (1, 0, 0) for target views.

    Args:
        batch: Dict with ``"context"`` and ``"target"`` entries, each holding
            ``"extrinsics"``, ``"intrinsics"``, ``"near"`` and ``"far"``.
        resolution: Forwarded to ``draw_cameras`` as its first argument.

    Returns:
        Whatever ``draw_cameras`` returns for these inputs.
    """
    context = batch["context"]
    num_context_views = context["extrinsics"].shape[1]
    target = batch["target"]
    num_target_views = target["extrinsics"].shape[1]

    # Start with all-ones rows for every view, then clear the last two
    # channels of the target rows, which follow the context rows.
    total_views = num_target_views + num_context_views
    color = torch.ones(
        (total_views, 3),
        dtype=torch.float32,
        device=target["extrinsics"].device,
    )
    color[num_context_views:, 1:] = 0

    def first_sample(key: str) -> Tensor:
        # Context views first, then target views, both from batch index 0.
        return torch.cat((context[key][0], target[key][0]))

    extrinsics = first_sample("extrinsics")
    intrinsics = first_sample("intrinsics")
    near = first_sample("near")
    far = first_sample("far")
    return draw_cameras(resolution, extrinsics, intrinsics, color, near, far)