Spaces:
Sleeping
Sleeping
File size: 4,588 Bytes
78d2329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | from math import isqrt
from typing import Literal
import torch
from diff_gs import (
GaussianRasterizationSettings,
GaussianRasterizer,
)
from einops import einsum, rearrange, repeat
from jaxtyping import Float
from torch import Tensor
from ...geometry.projection import get_fov, homogenize_points
def get_projection_matrix(
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    fov_x: Float[Tensor, " batch"],
    fov_y: Float[Tensor, " batch"],
) -> Float[Tensor, "batch 4 4"]:
    """Build one perspective projection matrix per batch element.

    Points inside the viewing frustum are mapped to (-1, 1) on the X and Y axes and
    to (0, 1) on the Z axis. This differs from the OpenGL matrix in two ways:
    - Z is not remapped to (-1, 1).
    - The homogeneous ``w`` coordinate is ``+z`` rather than ``-z``.

    Args:
        near: Near-plane distance per batch element.
        far: Far-plane distance per batch element.
        fov_x: Full horizontal field of view in radians (halved before ``tan``).
        fov_y: Full vertical field of view in radians (halved before ``tan``).

    Returns:
        A ``float32`` tensor of shape ``(batch, 4, 4)`` on ``near``'s device.
    """
    # Symmetric frustum extents on the near plane.
    top = (0.5 * fov_y).tan() * near
    right = (0.5 * fov_x).tan() * near
    left, bottom = -right, -top
    width = right - left
    height = top - bottom
    depth_span = far - near

    # Unpacking also enforces that ``near`` is one-dimensional.
    (batch,) = near.shape
    proj = near.new_zeros((batch, 4, 4), dtype=torch.float32)

    # X/Y scale terms and principal-point offsets (zero for a symmetric frustum).
    proj[:, 0, 0] = 2 * near / width
    proj[:, 0, 2] = (right + left) / width
    proj[:, 1, 1] = 2 * near / height
    proj[:, 1, 2] = (top + bottom) / height

    # Depth row: after the perspective divide, z = near -> 0 and z = far -> 1.
    proj[:, 2, 2] = far / depth_span
    proj[:, 2, 3] = -(far * near) / depth_span

    # w = +z, so the perspective divide uses the positive camera-space depth.
    proj[:, 3, 2] = 1
    return proj
def render_cuda(
    extrinsics: Float[Tensor, "batch 4 4"],
    intrinsics: Float[Tensor, "batch 3 3"],
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    image_shape: tuple[int, int],
    background_color: Float[Tensor, "batch 3"],
    gaussian_means: Float[Tensor, "batch gaussian 3"],
    gaussian_covariances: Float[Tensor, "batch gaussian 3 3"],
    gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"],
    gaussian_opacities: Float[Tensor, "batch gaussian"],
    scale_invariant: bool = False,
    use_sh: bool = True,
):
    """Rasterize a batch of Gaussian sets with the ``diff_gs`` rasterizer.

    Each batch element is rendered independently: camera ``i`` (``extrinsics[i]``,
    ``intrinsics[i]``, ``near[i]``, ``far[i]``) views Gaussian set ``i``.

    Args:
        extrinsics: Treated as camera-to-world transforms. They are inverted to
            obtain the view matrix, and their translation column is used as the
            camera position.
        intrinsics: Camera intrinsics, converted to fields of view via ``get_fov``.
        near, far: Clipping-plane distances used to build the projection matrix.
        image_shape: Output ``(height, width)`` in pixels.
        background_color: Per-batch RGB background passed to the rasterizer.
        gaussian_means: Gaussian centers.
        gaussian_covariances: Full 3x3 covariances. Only the 6 upper-triangular
            entries are passed to the rasterizer.
        gaussian_sh_coefficients: Spherical-harmonic coefficients per color
            channel. The SH degree is inferred as ``isqrt(d_sh) - 1``, so ``d_sh``
            is expected to be ``(degree + 1) ** 2``.
        gaussian_opacities: Per-Gaussian opacity.
        scale_invariant: Must be ``False``; this is enforced by an assertion.
        use_sh: If ``False``, ``d_sh`` must be 1. That single coefficient is then
            passed to the rasterizer as a precomputed color instead of as SH.

    Returns:
        A dict with keys ``'image'``, ``'depth'`` and ``'radii'``. Each value is the
        per-batch rasterizer output stacked along a new leading dimension.
    """
    assert use_sh or gaussian_sh_coefficients.shape[-1] == 1
    # NOTE(review): this assertion makes the scale-invariant branch below
    # unreachable. The branch is kept so the option can be re-enabled. If it is,
    # the returned depth would presumably be in the rescaled (1 / near) units.
    assert scale_invariant is False

    # Make sure everything is in a range where numerical issues don't appear.
    if scale_invariant:
        scale = 1 / near
        extrinsics = extrinsics.clone()
        extrinsics[..., :3, 3] = extrinsics[..., :3, 3] * scale[:, None]
        gaussian_covariances = gaussian_covariances * (scale[:, None, None, None] ** 2)
        gaussian_means = gaussian_means * scale[:, None, None]
        near = near * scale
        far = far * scale

    # SH degree from the coefficient count: d_sh == (degree + 1) ** 2.
    _, _, _, n = gaussian_sh_coefficients.shape
    degree = isqrt(n) - 1
    # Reorder to (batch, gaussian, d_sh, 3) for the rasterizer.
    shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()

    b, _, _ = extrinsics.shape
    h, w = image_shape

    fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1)
    # The rasterization settings take plain Python floats. Convert the whole batch
    # at once (one device-to-host sync) instead of calling .item() per iteration.
    tan_fov_x = (0.5 * fov_x).tan().tolist()
    tan_fov_y = (0.5 * fov_y).tan().tolist()

    projection_matrix = get_projection_matrix(near, far, fov_x, fov_y)
    # Both matrices are transposed before being handed to the rasterizer. It
    # presumably expects the transposed layout of the reference 3DGS code.
    projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
    view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i")
    full_projection = view_matrix @ projection_matrix

    # Loop-invariant: gather the 6 upper-triangular covariance entries for the
    # whole batch once, instead of rebuilding the indices on every iteration.
    row, col = torch.triu_indices(3, 3)
    cov3d_upper = gaussian_covariances[:, :, row, col]

    all_images = []
    all_radii = []
    all_depths = []
    for i in range(b):
        # Leaf tensor passed as means2D. The rasterizer presumably uses it to
        # expose gradients of the screen-space means via its .grad.
        mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True)
        try:
            # retain_grad() is a no-op on a leaf tensor. It is kept as a
            # deliberate best-effort call.
            mean_gradients.retain_grad()
        except Exception:
            pass
        settings = GaussianRasterizationSettings(
            image_height=h,
            image_width=w,
            tanfovx=tan_fov_x[i],
            tanfovy=tan_fov_y[i],
            bg=background_color[i],
            scale_modifier=1.0,
            viewmatrix=view_matrix[i],
            projmatrix=full_projection[i],
            sh_degree=degree,
            campos=extrinsics[i, :3, 3],
            prefiltered=False,
            debug=False,
            antialiasing=False,
        )
        rasterizer = GaussianRasterizer(settings)
        image, radii, depth = rasterizer(
            means3D=gaussian_means[i],
            means2D=mean_gradients,
            shs=shs[i] if use_sh else None,
            colors_precomp=None if use_sh else shs[i, :, 0, :],
            opacities=gaussian_opacities[i, ..., None],
            cov3D_precomp=cov3d_upper[i],
        )
        all_images.append(image)
        all_radii.append(radii)
        all_depths.append(depth)

    return {
        'image': torch.stack(all_images),
        'depth': torch.stack(all_depths),
        'radii': torch.stack(all_radii),
    }
|