File size: 7,175 Bytes
94dc344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Callable, Optional

import torch

import torch.nn.functional as F
from pytorch3d.common.compat import prod
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase


def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
    """
    Append a per-batch `global_code` of shape (minibatch, dim) to the last
    dimension of `embeds` of shape (minibatch, ..., dim2).

    The code is first reshaped to (minibatch, 1, ..., dim) and broadcast over
    all intermediate dimensions of `embeds` before concatenation.
    """
    n_mid_dims = embeds.ndim - 2
    code_view = (embeds.shape[0],) + (1,) * n_mid_dims + (-1,)
    target_shape = tuple(embeds.shape[:-1]) + (global_code.shape[-1],)
    broadcast_code = global_code.view(code_view).expand(target_shape)
    return torch.cat((embeds, broadcast_code), dim=-1)


def create_embeddings_for_implicit_function(
    xyz_world: torch.Tensor,
    xyz_in_camera_coords: bool,
    global_code: Optional[torch.Tensor],
    camera: Optional[CamerasBase],
    fun_viewpool: Optional[Callable],
    xyz_embedding_function: Optional[Callable],
    diag_cov: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Create the joint point embedding that is fed to an implicit function.

    The result concatenates, along the last dimension:
      - the embedding of the (optionally camera-space) points produced by
        `xyz_embedding_function` (a zero-width slot if it is None),
      - view-pooled features from `fun_viewpool`, if given,
      - the broadcast `global_code`, if given.

    Args:
        xyz_world: Ray points in world coordinates of shape
            (minibatch, ..., pts_per_ray, 3).
        xyz_in_camera_coords: If True, points are transformed to the camera
            view coordinates before embedding; requires `camera`.
        global_code: Optional per-batch code of shape (minibatch, dim).
        camera: Cameras providing the world-to-view transform; must be given
            when `xyz_in_camera_coords` is True.
        fun_viewpool: Optional callable taking points of shape
            (minibatch, n_points, 3); its output is reshaped to
            (minibatch, n_src, n_rays, pts_per_ray, -1).
            # assumes the callable returns (minibatch, n_src, n_points, dim)
            # — TODO confirm against the viewpooler implementation.
        xyz_embedding_function: Optional callable embedding the points; it is
            passed `diag_cov` as a keyword argument.
        diag_cov: Optional per-point diagonal covariances, forwarded to
            `xyz_embedding_function`.

    Returns:
        Tensor of shape (minibatch, n_src, n_rays, pts_per_ray, embed_dim).

    Raises:
        ValueError: If `xyz_in_camera_coords` is True but `camera` is None.
    """
    bs, *spatial_size, pts_per_ray, _ = xyz_world.shape

    if xyz_in_camera_coords:
        if camera is None:
            raise ValueError("Camera must be given if xyz_in_camera_coords")

        ray_points_for_embed = (
            camera.get_world_to_view_transform()
            .transform_points(xyz_world.view(bs, -1, 3))
            .view(xyz_world.shape)
        )
    else:
        ray_points_for_embed = xyz_world

    if xyz_embedding_function is None:
        # Zero-width placeholder. Use new_empty so the placeholder inherits
        # the device and dtype of xyz_world; a plain torch.empty would be a
        # CPU/float32 tensor and the torch.cat calls below would fail (or
        # silently promote) for CUDA / non-float32 inputs.
        embeds = xyz_world.new_empty(bs, 1, prod(spatial_size), pts_per_ray, 0)
    else:
        embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov)
        embeds = embeds.reshape(
            bs,
            1,
            prod(spatial_size),
            pts_per_ray,
            -1,
        )  # flatten spatial, add n_src dim

    if fun_viewpool is not None:
        # viewpooling
        embeds_viewpooled = fun_viewpool(xyz_world.reshape(bs, -1, 3))
        embed_shape = (
            bs,
            embeds_viewpooled.shape[1],
            prod(spatial_size),
            pts_per_ray,
            -1,
        )
        embeds_viewpooled = embeds_viewpooled.reshape(*embed_shape)
        # embeds is always a tensor here (possibly zero-width), so broadcast
        # it over the n_src dimension and concatenate the pooled features.
        embeds = torch.cat([embeds.expand(*embed_shape), embeds_viewpooled], dim=-1)

    if global_code is not None:
        # append the broadcasted global code to embeds
        embeds = broadcast_global_code(embeds, global_code)

    return embeds


def interpolate_line(
    points: torch.Tensor,
    source: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    """
    Linearly interpolates values of source grids. The first dimension of points represents
    number of points and the second coordinate, for example ([[x0], [x1], ...]). The first
    dimension of argument source represents feature and ones after that the spatial
    dimension.

    Arguments:
        points: shape (n_grids, n_points, 1),
        source: tensor of shape (n_grids, features, width),
    Returns:
        interpolated tensor of shape (n_grids, n_points, features)
    """
    # To enable sampling of the source using the torch.functional.grid_sample
    # points need to have 2 coordinates.
    expansion = points.new_zeros(points.shape)
    points = torch.cat((points, expansion), dim=-1)

    source = source[:, :, None, :]
    points = points[:, :, None, :]

    out = F.grid_sample(
        grid=points,
        input=source,
        **kwargs,
    )
    return out[:, :, :, 0].permute(0, 2, 1)


def interpolate_plane(
    points: torch.Tensor,
    source: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    """
    Bilinearly interpolates values of source grids. The first dimension of points represents
    number of points and the second coordinates, for example ([[x0, y0], [x1, y1], ...]).
    The first dimension of argument source represents feature and ones after that the
    spatial dimension.

    Arguments:
        points: shape (n_grids, n_points, 2),
        source: tensor of shape (n_grids, features, width, height),
    Returns:
        interpolated tensor of shape (n_grids, n_points, features)
    """
    # permuting because torch.nn.functional.grid_sample works with
    # (features, height, width) and not
    # (features, width, height)
    source = source.permute(0, 1, 3, 2)
    points = points[:, :, None, :]

    out = F.grid_sample(
        grid=points,
        input=source,
        **kwargs,
    )
    return out[:, :, :, 0].permute(0, 2, 1)


def interpolate_volume(
    points: torch.Tensor, source: torch.Tensor, **kwargs
) -> torch.Tensor:
    """
    Interpolates values of source grids. The first dimension of points represents
    number of points and the second coordinates, for example
    [[x0, y0, z0], [x1, y1, z1], ...]. The first dimension of a source represents features
    and ones after that the spatial dimension.

    Arguments:
        points: shape (n_grids, n_points, 3),
        source: tensor of shape (n_grids, features, width, height, depth),
    Returns:
        interpolated tensor of shape (n_grids, n_points, features)
    """
    if "mode" in kwargs and kwargs["mode"] == "trilinear":
        kwargs = kwargs.copy()
        kwargs["mode"] = "bilinear"
    # permuting because torch.nn.functional.grid_sample works with
    # (features, depth, height, width) and not (features, width, height, depth)
    source = source.permute(0, 1, 4, 3, 2)
    grid = points[:, :, None, None, :]

    out = F.grid_sample(
        grid=grid,
        input=source,
        **kwargs,
    )
    return out[:, :, :, 0, 0].permute(0, 2, 1)


def get_rays_points_world(
    ray_bundle: Optional[ImplicitronRayBundle] = None,
    rays_points_world: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Return world-space ray points, computing them from `ray_bundle` when
    `rays_points_world` is not supplied. Exactly one of the two arguments
    must be provided.

    Args:
        ray_bundle: An ImplicitronRayBundle object or None
        rays_points_world: A torch.Tensor representing ray points converted to
            world coordinates
    Returns:
        A torch.Tensor representing ray points converted to world coordinates
            of shape [minibatch x ... x pts_per_ray x 3].
    Raises:
        ValueError: if both or neither of the arguments are given.
    """
    have_points = rays_points_world is not None
    have_bundle = ray_bundle is not None
    if have_points and have_bundle:
        raise ValueError(
            "Cannot define both rays_points_world and ray_bundle,"
            + " one has to be None."
        )
    if have_points:
        return rays_points_world
    if not have_bundle:
        raise ValueError("ray_bundle and rays_points_world cannot both be None")
    # pyre-ignore[6]
    return ray_bundle_to_ray_points(ray_bundle)