# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Callable, Optional
import torch
import torch.nn.functional as F
from pytorch3d.common.compat import prod
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
    """
    Expands the `global_code` of shape (minibatch, dim)
    so that it can be appended to `embeds` of shape (minibatch, ..., dim2),
    and appends to the last dimension of `embeds`.
    """
    batch = embeds.shape[0]
    # Insert singleton dims so the code lines up with embeds' middle axes.
    n_middle_dims = embeds.ndim - 2
    code = global_code.view(batch, *((1,) * n_middle_dims), -1)
    target_shape = (*embeds.shape[:-1], global_code.shape[-1])
    return torch.cat([embeds, code.expand(target_shape)], dim=-1)
def create_embeddings_for_implicit_function(
    xyz_world: torch.Tensor,
    xyz_in_camera_coords: bool,
    global_code: Optional[torch.Tensor],
    camera: Optional[CamerasBase],
    fun_viewpool: Optional[Callable],
    xyz_embedding_function: Optional[Callable],
    diag_cov: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Builds the per-point feature tensor consumed by an implicit function.

    Args:
        xyz_world: 3D points of shape (minibatch, ..., pts_per_ray, 3)
            in world coordinates.
        xyz_in_camera_coords: If True, points are converted to each camera's
            view coordinates before being embedded.
        global_code: Optional per-batch code of shape (minibatch, dim),
            broadcast and appended to the last dimension of the result.
        camera: Cameras providing the world-to-view transform; required iff
            `xyz_in_camera_coords` is True.
        fun_viewpool: Optional callable mapping flattened world points
            (minibatch, n_points, 3) to view-pooled features, concatenated
            to the positional embedding along the last dimension.
        xyz_embedding_function: Optional callable producing the positional
            embedding of the (possibly camera-space) points; if None, a
            zero-width embedding is used.
        diag_cov: Optional covariance tensor forwarded to
            `xyz_embedding_function`.

    Returns:
        Tensor of shape (minibatch, n_src, prod(spatial), pts_per_ray, dim).

    Raises:
        ValueError: If `xyz_in_camera_coords` is True and `camera` is None.
    """
    bs, *spatial_size, pts_per_ray, _ = xyz_world.shape

    if xyz_in_camera_coords:
        if camera is None:
            raise ValueError("Camera must be given if xyz_in_camera_coords")
        ray_points_for_embed = (
            camera.get_world_to_view_transform()
            .transform_points(xyz_world.view(bs, -1, 3))
            .view(xyz_world.shape)
        )
    else:
        ray_points_for_embed = xyz_world

    if xyz_embedding_function is None:
        # Zero-width placeholder so the viewpooled features (if any) can be
        # concatenated uniformly below. Match the input's device/dtype so
        # the later torch.cat does not fail for non-CPU inputs.
        embeds = torch.empty(
            bs,
            1,
            prod(spatial_size),
            pts_per_ray,
            0,
            dtype=xyz_world.dtype,
            device=xyz_world.device,
        )
    else:
        embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov)
        embeds = embeds.reshape(
            bs,
            1,
            prod(spatial_size),
            pts_per_ray,
            -1,
        )  # flatten spatial, add n_src dim

    if fun_viewpool is not None:
        # viewpooling; note viewpooled features always sample world coords.
        embeds_viewpooled = fun_viewpool(xyz_world.reshape(bs, -1, 3))
        embed_shape = (
            bs,
            embeds_viewpooled.shape[1],
            prod(spatial_size),
            pts_per_ray,
            -1,
        )
        embeds_viewpooled = embeds_viewpooled.reshape(*embed_shape)
        # `embeds` is always a tensor at this point (assigned in both
        # branches above), so the dead `embeds is not None` check was removed.
        embeds = torch.cat([embeds.expand(*embed_shape), embeds_viewpooled], dim=-1)

    if global_code is not None:
        # append the broadcasted global code to embeds
        embeds = broadcast_global_code(embeds, global_code)

    return embeds
def interpolate_line(
    points: torch.Tensor,
    source: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    """
    Linearly interpolates values of source grids. The first dimension of points represents
    number of points and the second coordinate, for example ([[x0], [x1], ...]). The first
    dimension of argument source represents feature and ones after that the spatial
    dimension.
    Arguments:
        points: shape (n_grids, n_points, 1),
        source: tensor of shape (n_grids, features, width),
    Returns:
        interpolated tensor of shape (n_grids, n_points, features)
    """
    # grid_sample has no 1D variant, so treat the line as a degenerate 2D
    # grid: pad a constant-zero second coordinate and a singleton spatial dim.
    padded_points = torch.cat((points, torch.zeros_like(points)), dim=-1)
    sampled = F.grid_sample(
        grid=padded_points[:, :, None, :],
        input=source[:, :, None, :],
        **kwargs,
    )
    # (n_grids, features, n_points, 1) -> (n_grids, n_points, features)
    return sampled[:, :, :, 0].permute(0, 2, 1)
def interpolate_plane(
    points: torch.Tensor,
    source: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    """
    Bilinearly interpolates values of source grids. The first dimension of points represents
    number of points and the second coordinates, for example ([[x0, y0], [x1, y1], ...]).
    The first dimension of argument source represents feature and ones after that the
    spatial dimension.
    Arguments:
        points: shape (n_grids, n_points, 2),
        source: tensor of shape (n_grids, features, width, height),
    Returns:
        interpolated tensor of shape (n_grids, n_points, features)
    """
    # torch.nn.functional.grid_sample expects (features, height, width),
    # while source stores (features, width, height): swap the spatial axes.
    swapped = source.transpose(2, 3)
    sampled = F.grid_sample(
        grid=points.unsqueeze(2),
        input=swapped,
        **kwargs,
    )
    # (n_grids, features, n_points, 1) -> (n_grids, n_points, features)
    return sampled.squeeze(3).permute(0, 2, 1)
def interpolate_volume(
    points: torch.Tensor, source: torch.Tensor, **kwargs
) -> torch.Tensor:
    """
    Interpolates values of source grids. The first dimension of points represents
    number of points and the second coordinates, for example
    [[x0, y0, z0], [x1, y1, z1], ...]. The first dimension of a source represents features
    and ones after that the spatial dimension.
    Arguments:
        points: shape (n_grids, n_points, 3),
        source: tensor of shape (n_grids, features, width, height, depth),
    Returns:
        interpolated tensor of shape (n_grids, n_points, features)
    """
    # grid_sample names trilinear interpolation "bilinear" for 5D inputs.
    if kwargs.get("mode") == "trilinear":
        kwargs = dict(kwargs, mode="bilinear")
    # grid_sample expects (features, depth, height, width), while source
    # stores (features, width, height, depth): reverse the spatial axes.
    volume = source.permute(0, 1, 4, 3, 2)
    sampled = F.grid_sample(
        grid=points[:, :, None, None, :],
        input=volume,
        **kwargs,
    )
    # (n_grids, features, n_points, 1, 1) -> (n_grids, n_points, features)
    return sampled[:, :, :, 0, 0].permute(0, 2, 1)
def get_rays_points_world(
    ray_bundle: Optional[ImplicitronRayBundle] = None,
    rays_points_world: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Converts the ray_bundle to rays_points_world if rays_points_world is not defined
    and raises error if both are defined.

    Args:
        ray_bundle: An ImplicitronRayBundle object or None
        rays_points_world: A torch.Tensor representing ray points converted to
            world coordinates
    Returns:
        A torch.Tensor representing ray points converted to world coordinates
            of shape [minibatch x ... x pts_per_ray x 3].
    """
    if ray_bundle is not None:
        if rays_points_world is not None:
            raise ValueError(
                "Cannot define both rays_points_world and ray_bundle,"
                + " one has to be None."
            )
        # pyre-ignore[6]
        return ray_bundle_to_ray_points(ray_bundle)
    if rays_points_world is None:
        raise ValueError("ray_bundle and rays_points_world cannot both be None")
    return rays_points_world
|