# Source: ml-sharp / src/sharp/utils/gsplat.py (initial commit c20d7cc by amael-apple)
"""Contains utility code for gsplat renderer.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
from pathlib import Path
from typing import NamedTuple
import gsplat
import torch
from torch import nn
from sharp.utils import color_space as cs_utils
from sharp.utils import io, vis
from sharp.utils.gaussians import BackgroundColor, Gaussians3D
class RenderingOutputs(NamedTuple):
    """Outputs of 3D Gaussians renderer."""

    # Rendered RGB image, shape (B, 3, C->H, W) — channels-first, see GSplatRenderer.forward.
    color: torch.Tensor
    # Alpha-normalized depth map, shape (B, 1, H, W).
    depth: torch.Tensor
    # Accumulated per-pixel opacity (coverage), shape (B, 1, H, W).
    alpha: torch.Tensor
def write_renderings(rendering: RenderingOutputs, output_folder: Path, filename: str):
    """Write rendered color/depth/alpha to files.

    Three images are written: ``<filename>.color.png``, ``<filename>.depth.png``
    and ``<filename>.alpha.png`` inside ``output_folder``.

    Args:
        rendering: Renderer outputs; the batch dimension must be exactly 1.
        output_folder: Directory the images are written into.
        filename: Base filename used for all three outputs.

    Raises:
        RuntimeError: If the rendering batch size is not 1.
    """
    batch_size = len(rendering.color)
    if batch_size != 1:
        raise RuntimeError("We only support saving rendering of batch size = 1")

    def _save_image_tensor(tensor: torch.Tensor, suffix: str):
        # CHW -> HWC, the layout expected by the image writer.
        np_array = tensor.permute(1, 2, 0).numpy()
        io.save_image(np_array, (output_folder / filename).with_suffix(suffix))

    # BUGFIX: clamp before quantizing so out-of-range values cannot wrap around in
    # uint8 (e.g. 1.01 * 255 -> 257 -> 1), and round instead of truncating.
    color = (rendering.color[0].cpu().clamp(0.0, 1.0) * 255.0).round().to(dtype=torch.uint8)
    # val_max caps the depth range used for colorization; larger depths saturate.
    colorized_depth = vis.colorize_depth(rendering.depth[0], val_max=100.0)
    colorized_alpha = vis.colorize_alpha(rendering.alpha[0])

    _save_image_tensor(color, ".color.png")
    _save_image_tensor(colorized_depth, ".depth.png")
    _save_image_tensor(colorized_alpha, ".alpha.png")
class GSplatRenderer(nn.Module):
    """Module to render 3D Gaussians to images using gsplat."""

    color_space: cs_utils.ColorSpace
    background_color: BackgroundColor

    def __init__(
        self,
        color_space: cs_utils.ColorSpace = "sRGB",
        background_color: BackgroundColor = "black",
        low_pass_filter_eps: float = 0.0,
    ) -> None:
        """Initialize gsplat renderer.

        Args:
            color_space: The color space to use for rendering.
            background_color: The background color to use for rendering.
            low_pass_filter_eps: The epsilon value for the low pass filter
                (forwarded to gsplat as ``eps2d``).
        """
        super().__init__()
        self.color_space = color_space
        self.background_color = background_color
        self.low_pass_filter_eps = low_pass_filter_eps

    def forward(
        self,
        gaussians: Gaussians3D,
        extrinsics: torch.Tensor,
        intrinsics: torch.Tensor,
        image_width: int,
        image_height: int,
    ) -> RenderingOutputs:
        """Predict images from gaussians.

        Args:
            gaussians: The Gaussians to render.
            extrinsics: The extrinsics of the camera to render to in OpenCV format.
            intrinsics: The intrinsics of the camera to render to in OpenCV format.
            image_width: The desired output image width.
            image_height: The desired output image height.

        Raises:
            ValueError: If ``self.color_space`` is not a supported value.
        """
        batch_size = len(gaussians.mean_vectors)
        outputs_list: list[RenderingOutputs] = []
        for ib in range(batch_size):
            colors, alphas, meta = gsplat.rendering.rasterization(
                means=gaussians.mean_vectors[ib],
                quats=gaussians.quaternions[ib],
                scales=gaussians.singular_values[ib],
                opacities=gaussians.opacities[ib],
                colors=gaussians.colors[ib],
                viewmats=extrinsics[ib : ib + 1],
                Ks=intrinsics[ib : ib + 1, :3, :3],
                width=image_width,
                height=image_height,
                render_mode="RGB+D",  # last channel carries accumulated depth
                rasterize_mode="classic",
                absgrad=False,
                packed=False,
                eps2d=self.low_pass_filter_eps,
            )
            # (1, H, W, C) -> (1, C, H, W).
            rendered_color = colors[..., 0:3].permute([0, 3, 1, 2])
            rendered_depth_unnormalized = colors[..., 3:4].permute([0, 3, 1, 2])
            rendered_alpha = alphas.permute([0, 3, 1, 2])

            # Compose with background color.
            rendered_color = self.compose_with_background(
                rendered_color, rendered_alpha, self.background_color
            )

            # Colorspace conversion.
            if self.color_space == "sRGB":
                pass
            elif self.color_space == "linearRGB":
                rendered_color = cs_utils.linearRGB2sRGB(rendered_color)
            else:
                # BUGFIX: the exception was previously constructed but never raised,
                # so unsupported color spaces passed through silently.
                raise ValueError("Unsupported ColorSpace type.")

            # splats: (B, N, 10)
            cov2d = self._conics_to_covars2d(meta["conics"])
            # Set the cov2d of invisible splats to identity to avoid nan in condition
            # number calculation.
            # BUGFIX: chained indexing `cov2d[~mask][..., 0, 0] = 1` wrote into a
            # temporary copy, so the masked assignment was a no-op. Index directly.
            splats_invisible_mask = ~(meta["depths"] > 1e-2)
            cov2d[splats_invisible_mask, 0, 0] = 1
            cov2d[splats_invisible_mask, 1, 1] = 1
            cov2d[splats_invisible_mask, 0, 1] = 0
            # NOTE(review): cov2d is not used further in this method — presumably kept
            # for debugging or a downstream hook; confirm before removing.

            # Normalize the depth by alpha (clip avoids division by zero where
            # nothing was rendered).
            rendered_depth = rendered_depth_unnormalized / torch.clip(rendered_alpha, min=1e-8)

            outputs_list.append(
                RenderingOutputs(
                    color=rendered_color,
                    depth=rendered_depth,
                    alpha=rendered_alpha,
                )
            )
        # Re-assemble the per-sample renderings into batched tensors.
        return RenderingOutputs(
            color=torch.cat([item.color for item in outputs_list], dim=0).contiguous(),
            depth=torch.cat([item.depth for item in outputs_list], dim=0).contiguous(),
            alpha=torch.cat([item.alpha for item in outputs_list], dim=0).contiguous(),
        )

    @staticmethod
    def compose_with_background(
        rendered_rgb: torch.Tensor,
        rendered_alpha: torch.Tensor,
        background_color: BackgroundColor,
    ) -> torch.Tensor:
        """Compose rendered RGB with background color.

        Assumes ``rendered_rgb`` is premultiplied by alpha, so compositing is
        ``rgb + (1 - alpha) * background``.

        Raises:
            ValueError: If ``background_color`` is not a supported value.
        """
        if background_color == "black":
            # Premultiplied RGB over black is the identity.
            return rendered_rgb
        elif background_color == "white":
            return rendered_rgb + (1.0 - rendered_alpha)
        elif background_color == "random_color":
            # One random color per call, broadcast over batch and spatial dims.
            return (
                rendered_rgb
                + (1.0 - rendered_alpha)
                * torch.rand(3, dtype=rendered_rgb.dtype, device=rendered_rgb.device)[
                    None, :, None, None
                ]
            )
        elif background_color == "random_pixel":
            # Independent random background per pixel.
            return rendered_rgb + (1.0 - rendered_alpha) * torch.rand_like(rendered_rgb)
        else:
            raise ValueError("Unsupported BackgroundColor type.")

    @staticmethod
    def _conics_to_covars2d(conics: torch.Tensor, eps=1e-8) -> torch.Tensor:
        """Convert conics (a, b, c) — the inverse 2D covariances — back to 2x2 covariance matrices.

        Args:
            conics: Tensor of shape (..., 3) holding the upper-triangular entries
                (a, b, c) of the symmetric inverse covariance [[a, b], [b, c]].
            eps: Stabilizer against division by zero / negative determinants.

        Returns:
            Tensor of shape (..., 2, 2) with the reconstructed covariances;
            non-finite entries are zeroed.
        """
        a = conics[..., 0]
        b = conics[..., 1]
        c = conics[..., 2]
        # Reconstruct the inverse determinant of the conic (= determinant of covar).
        det = 1 / (a * c - b**2 + eps)
        det = det.clamp(min=eps)
        # Reconstruct covars2d via the 2x2 matrix-inverse adjugate formula.
        # Use conics' dtype so high-precision inputs are not silently downcast.
        covars2d = torch.zeros(*conics.shape[:-1], 2, 2, dtype=conics.dtype, device=conics.device)
        covars2d[..., 1, 1] = a * det
        covars2d[..., 0, 0] = c * det
        covars2d[..., 0, 1] = -b * det
        covars2d[..., 1, 0] = -b * det
        covars2d = torch.nan_to_num(covars2d, nan=0.0, posinf=0.0, neginf=0.0)
        return covars2d