"""Contains basic data structures and functionality for 3D Gaussians. For licensing see accompanying LICENSE file. Copyright (C) 2025 Apple Inc. All Rights Reserved. """ from __future__ import annotations import logging from pathlib import Path from typing import Any, Literal, NamedTuple import numpy as np import torch from plyfile import PlyData, PlyElement from sharp.utils import color_space as cs_utils from sharp.utils import linalg LOGGER = logging.getLogger(__name__) BackgroundColor = Literal["black", "white", "random_color", "random_pixel"] class Gaussians3D(NamedTuple): """Represents a collection of 3D Gaussians.""" mean_vectors: torch.Tensor singular_values: torch.Tensor quaternions: torch.Tensor colors: torch.Tensor opacities: torch.Tensor def to(self, device: torch.device) -> Gaussians3D: """Move Gaussians to device.""" return Gaussians3D( mean_vectors=self.mean_vectors.to(device), singular_values=self.singular_values.to(device), quaternions=self.quaternions.to(device), colors=self.colors.to(device), opacities=self.opacities.to(device), ) class SceneMetaData(NamedTuple): """Meta data about Gaussian scene.""" focal_length_px: float resolution_px: tuple[int, int] color_space: cs_utils.ColorSpace def get_unprojection_matrix( extrinsics: torch.Tensor, intrinsics: torch.Tensor, image_shape: tuple[int, int], ) -> torch.Tensor: """Compute unprojection matrix to transform Gaussians to Euclidean space. Args: extrinsics: The 4x4 extrinsics matrix of the camera view. intrinsics: The 4x4 intrinsics matrix of the camera view. image_shape: The (width, height) of the input image. Returns: A 4x4 matrix to transform Gaussians from NDC space to Euclidean space. """ device = intrinsics.device image_width, image_height = image_shape # This matrix converts OpenCV pixel coordinates to NDC coordinates where # (-1, 1) denotes the top left and (1, 1) the bottom right of the image. # # Note that premultiplying the intrinsics with ndc_matrix typically yields a matrix # that simply scales the x-axis by 2 * focal_length / image_width and the y-axis by # 2 * focal_length / image_height. ndc_matrix = torch.tensor( [ [2.0 / image_width, 0.0, -1.0, 0.0], [0.0, 2.0 / image_height, -1.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0], ], device=device, ) return torch.linalg.inv(ndc_matrix @ intrinsics @ extrinsics) def unproject_gaussians( gaussians_ndc: Gaussians3D, extrinsics: torch.Tensor, intrinsics: torch.Tensor, image_shape: tuple[int, int], ) -> Gaussians3D: """Unproject Gaussians from NDC space to world coordinates.""" unprojection_matrix = get_unprojection_matrix(extrinsics, intrinsics, image_shape) gaussians = apply_transform(gaussians_ndc, unprojection_matrix[:3]) return gaussians def apply_transform(gaussians: Gaussians3D, transform: torch.Tensor) -> Gaussians3D: """Apply an affine transformation to 3D Gaussians. Args: gaussians: The Gaussians to transform. transform: An affine transform with shape 3x4. Returns: The transformed Gaussians. Note: This operation is not differentiable. 
""" transform_linear = transform[..., :3, :3] transform_offset = transform[..., :3, 3] mean_vectors = gaussians.mean_vectors @ transform_linear.T + transform_offset covariance_matrices = compose_covariance_matrices( gaussians.quaternions, gaussians.singular_values ) covariance_matrices = ( transform_linear @ covariance_matrices @ transform_linear.transpose(-1, -2) ) quaternions, singular_values = decompose_covariance_matrices(covariance_matrices) return Gaussians3D( mean_vectors=mean_vectors, singular_values=singular_values, quaternions=quaternions, colors=gaussians.colors, opacities=gaussians.opacities, ) def decompose_covariance_matrices( covariance_matrices: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """Decompose 3D covariance matrices into quaternions and singular values. Args: covariance_matrices: The covariance matrices to decompose. Returns: Quaternion and singular values corresponding to the orientation and scales of the diagonalized matrix. Note: This operation is not differentiable. """ device = covariance_matrices.device dtype = covariance_matrices.dtype # We convert to fp64 to avoid numerical errors. covariance_matrices = covariance_matrices.detach().cpu().to(torch.float64) rotations, singular_values_2, _ = torch.linalg.svd(covariance_matrices) # NOTE: in SVD, it is possible that U and VT are both reflections. # We need to correct them. batch_idx, gaussian_idx = torch.where(torch.linalg.det(rotations) < 0) num_reflections = len(gaussian_idx) if num_reflections > 0: LOGGER.warning( "Received %d reflection matrices from SVD. Flipping them to rotations.", num_reflections, ) # Flip the last column of reflection and make it a rotation. rotations[batch_idx, gaussian_idx, :, -1] *= -1 quaternions = linalg.quaternions_from_rotation_matrices(rotations) quaternions = quaternions.to(dtype=dtype, device=device) singular_values = singular_values_2.sqrt().to(dtype=dtype, device=device) return quaternions, singular_values def compose_covariance_matrices( quaternions: torch.Tensor, singular_values: torch.Tensor ) -> torch.Tensor: """Compose 3D covariance matrices into quaternions and singular values. Args: quaternions: The quaternions describing the principal basis. singular_values: The scales of the diagonalized matrix. Returns: The 3x3 covariances matrices. """ device = quaternions.device rotations = linalg.rotation_matrices_from_quaternions(quaternions) diagonal_matrix = torch.eye(3, device=device) * singular_values[..., :, None] return rotations @ diagonal_matrix.square() @ rotations.transpose(-1, -2) def convert_spherical_harmonics_to_rgb(sh0: torch.Tensor) -> torch.Tensor: """Convert degree-0 spherical harmonics to RGB. Reference: https://en.wikipedia.org/wiki/Table_of_spherical_harmonics """ coeff_degree0 = np.sqrt(1.0 / (4.0 * np.pi)) return sh0 * coeff_degree0 + 0.5 def convert_rgb_to_spherical_harmonics(rgb: torch.Tensor) -> torch.Tensor: """Convert RGB to degree-0 spherical harmonics. 
    Reference: https://en.wikipedia.org/wiki/Table_of_spherical_harmonics
    """
    coeff_degree0 = np.sqrt(1.0 / (4.0 * np.pi))
    return (rgb - 0.5) / coeff_degree0


def load_ply(path: Path) -> tuple[Gaussians3D, SceneMetaData]:
    """Load 3D Gaussians and scene metadata from a ply file."""
    plydata = PlyData.read(path)
    vertices = next(filter(lambda x: x.name == "vertex", plydata.elements))

    properties = ["x", "y", "z"]
    properties.extend([f"f_dc_{i}" for i in range(3)])
    properties.extend([f"scale_{i}" for i in range(3)])
    properties.extend([f"rot_{i}" for i in range(4)])
    for prop in properties:
        if prop not in vertices:
            raise KeyError(
                f"Incompatible ply file: property {prop} not found in ply elements."
            )

    mean_vectors = np.stack(
        (
            np.asarray(vertices["x"]),
            np.asarray(vertices["y"]),
            np.asarray(vertices["z"]),
        ),
        axis=1,
    )
    scale_logits = np.stack(
        (
            np.asarray(vertices["scale_0"]),
            np.asarray(vertices["scale_1"]),
            np.asarray(vertices["scale_2"]),
        ),
        axis=1,
    )
    quaternions = np.stack(
        (
            np.asarray(vertices["rot_0"]),
            np.asarray(vertices["rot_1"]),
            np.asarray(vertices["rot_2"]),
            np.asarray(vertices["rot_3"]),
        ),
        axis=1,
    )
    spherical_harmonics_deg0 = np.stack(
        (
            np.asarray(vertices["f_dc_0"]),
            np.asarray(vertices["f_dc_1"]),
            np.asarray(vertices["f_dc_2"]),
        ),
        axis=1,
    )
    colors = convert_spherical_harmonics_to_rgb(spherical_harmonics_deg0)
    opacity_logits = np.asarray(vertices["opacity"])[..., None]

    supplement_elements = [
        element for element in plydata.elements if element.name != "vertex"
    ]
    supplement_data: dict[str, Any] = {}
    supplement_keys = ["extrinsic", "intrinsic", "color_space", "image_size"]
    for element in supplement_elements:
        for key in supplement_keys:
            if key not in supplement_data and key in element:
                supplement_data[key] = np.asarray(element[key])

    # Parse intrinsics and image_size.
    if "intrinsic" in supplement_data:
        intrinsics_data = supplement_data["intrinsic"]
        # Legacy: image_size is contained in the intrinsic element.
        if "image_size" not in supplement_data:
            if len(intrinsics_data) != 4:
                raise ValueError(
                    "Expect legacy intrinsics with len=4 containing image size, "
                    f"but received len={len(intrinsics_data)}"
                )
            focal_length_px = (intrinsics_data[0], intrinsics_data[1])
            width = int(intrinsics_data[2])
            height = int(intrinsics_data[3])
        else:
            if len(intrinsics_data) != 9:
                raise ValueError(
                    "Expect 9 elements in intrinsics, "
                    f"but received {len(intrinsics_data)}."
                )
            intrinsics_matrix = intrinsics_data.reshape((3, 3))
            focal_length_px = (intrinsics_matrix[0, 0], intrinsics_matrix[1, 1])
            image_size_data = supplement_data["image_size"]
            width = int(image_size_data[0])
            height = int(image_size_data[1])
    # Default to VGA resolution: focal length = 512, image size = (640, 480).
    else:
        focal_length_px = (512, 512)
        width = 640
        height = 480

    # Parse extrinsics.
    extrinsics_data = supplement_data.get("extrinsic", np.eye(4).flatten())
    extrinsics_matrix = np.eye(4)
    # Legacy: extrinsics store 12 elements.
    if len(extrinsics_data) == 12:
        extrinsics_matrix[:3] = extrinsics_data.reshape((3, 4))
        extrinsics_matrix[:3, :3] = extrinsics_matrix[:3, :3].copy().T
    elif len(extrinsics_data) == 16:
        extrinsics_matrix[:] = extrinsics_data.reshape((4, 4))
    else:
        raise ValueError(f"Unrecognized extrinsics matrix shape {len(extrinsics_data)}")

    # Parse color space.
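    # The color space is stored as a single integer code in an optional
    # "color_space" element (save_ply below writes the code for "sRGB").
    # sRGB colors are converted to linear RGB here so that downstream code can
    # work with linear colors.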
    color_space_index = supplement_data.get("color_space", 1)
    color_space = cs_utils.decode_color_space(color_space_index)
    if color_space == "sRGB":
        colors = cs_utils.sRGB2linearRGB(colors)

    mean_vectors = torch.from_numpy(mean_vectors).view(1, -1, 3).float()
    quaternions = torch.from_numpy(quaternions).view(1, -1, 4).float()
    singular_values = torch.exp(torch.from_numpy(scale_logits).view(1, -1, 3)).float()
    opacities = torch.sigmoid(torch.from_numpy(opacity_logits).view(1, -1)).float()
    colors = torch.from_numpy(colors).view(1, -1, 3).float()

    gaussians = Gaussians3D(
        mean_vectors=mean_vectors,
        quaternions=quaternions,
        singular_values=singular_values,
        opacities=opacities,
        colors=colors,
    )
    metadata = SceneMetaData(focal_length_px[0], (width, height), color_space)
    return gaussians, metadata


@torch.no_grad()
def save_ply(
    gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
) -> PlyData:
    """Save predicted 3D Gaussians to a ply file.

    Args:
        gaussians: The Gaussians to save.
        f_px: The focal length in pixels (used for both fx and fy).
        image_shape: The (height, width) of the image the Gaussians were predicted for.
        path: The output path of the ply file.

    Returns:
        The written PlyData object.
    """

    def _inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
        return torch.log(tensor / (1.0 - tensor))

    xyz = gaussians.mean_vectors.flatten(0, 1)
    scale_logits = torch.log(gaussians.singular_values).flatten(0, 1)
    quaternions = gaussians.quaternions.flatten(0, 1)

    # SHARP takes an image, converts it to the sRGB color space as input, and predicts
    # linearRGB Gaussians as output.
    # The SHARP renderer blends linearRGB Gaussians and converts the rendered images
    # and videos back to sRGB for the best display quality.
    #
    # However, public renderers do not apply such a linear-to-sRGB conversion after
    # rendering. If they rendered linearRGB Gaussians as-is, the output would appear
    # dark due to the missing gamma correction.
    #
    # To stay compatible with public renderers, we force-convert linearRGB to sRGB
    # during export:
    # - The SHARP renderer still handles the conversions properly.
    # - Public renderers will mostly work fine when treating the sRGB colors as
    #   linearRGB, although for the best results it is recommended to apply the
    #   conversions.
    colors = convert_rgb_to_spherical_harmonics(
        cs_utils.linearRGB2sRGB(gaussians.colors.flatten(0, 1))
    )
    color_space_index = cs_utils.encode_color_space("sRGB")

    # Store opacity logits.
    opacity_logits = _inverse_sigmoid(gaussians.opacities).flatten(0, 1).unsqueeze(-1)

    attributes = torch.cat(
        (
            xyz,
            colors,
            opacity_logits,
            scale_logits,
            quaternions,
        ),
        dim=1,
    )
    dtype_full = [
        (attribute, "f4")
        for attribute in ["x", "y", "z"]
        + [f"f_dc_{i}" for i in range(3)]
        + ["opacity"]
        + [f"scale_{i}" for i in range(3)]
        + [f"rot_{i}" for i in range(4)]
    ]
    num_gaussians = len(xyz)
    elements = np.empty(num_gaussians, dtype=dtype_full)
    elements[:] = list(map(tuple, attributes.detach().cpu().numpy()))
    vertex_elements = PlyElement.describe(elements, "vertex")

    # Unpack image-wise metadata.
    image_height, image_width = image_shape

    # Export image size.
    dtype_image_size = [("image_size", "u4")]
    image_size_array = np.empty(2, dtype=dtype_image_size)
    image_size_array[:] = np.array([image_width, image_height])
    image_size_element = PlyElement.describe(image_size_array, "image_size")

    # Export intrinsics.
    dtype_intrinsic = [("intrinsic", "f4")]
    intrinsic_array = np.empty(9, dtype=dtype_intrinsic)
    intrinsic = np.array(
        [
            f_px,
            0,
            image_width * 0.5,
            0,
            f_px,
            image_height * 0.5,
            0,
            0,
            1,
        ]
    )
    intrinsic_array[:] = intrinsic.flatten()
    intrinsic_element = PlyElement.describe(intrinsic_array, "intrinsic")

    # Export dummy extrinsics.
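    # The extrinsics are written as a flattened 4x4 identity matrix; load_ply accepts
    # either this 16-element layout or the legacy 12-element (3x4) layout.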
    dtype_extrinsic = [("extrinsic", "f4")]
    extrinsic_array = np.empty(16, dtype=dtype_extrinsic)
    extrinsic_array[:] = np.eye(4).flatten()
    extrinsic_element = PlyElement.describe(extrinsic_array, "extrinsic")

    # Export number of frames and particles per frame.
    dtype_frames = [("frame", "i4")]
    frame_array = np.empty(2, dtype=dtype_frames)
    frame_array[:] = np.array([1, num_gaussians], dtype=np.int32)
    frame_element = PlyElement.describe(frame_array, "frame")

    # Export disparity ranges for transform.
    dtype_disparity = [("disparity", "f4")]
    disparity_array = np.empty(2, dtype=dtype_disparity)
    disparity = 1.0 / gaussians.mean_vectors[0, ..., -1]
    quantiles = (
        torch.quantile(disparity, q=torch.tensor([0.1, 0.9], device=disparity.device))
        .float()
        .cpu()
        .numpy()
    )
    disparity_array[:] = quantiles
    disparity_element = PlyElement.describe(disparity_array, "disparity")

    # Export color space.
    dtype_color_space = [("color_space", "u1")]
    color_space_array = np.empty(1, dtype=dtype_color_space)
    color_space_array[:] = np.array([color_space_index]).flatten()
    color_space_element = PlyElement.describe(color_space_array, "color_space")

    # Export file format version.
    dtype_version = [("version", "u1")]
    version_array = np.empty(3, dtype=dtype_version)
    version_array[:] = np.array([1, 5, 0], dtype=np.uint8).flatten()
    version_element = PlyElement.describe(version_array, "version")

    plydata = PlyData(
        [
            vertex_elements,
            extrinsic_element,
            intrinsic_element,
            image_size_element,
            frame_element,
            disparity_element,
            color_space_element,
            version_element,
        ]
    )
    plydata.write(path)
    return plydata
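

# Example round trip (illustrative sketch; the file names are placeholders):
#
#     gaussians, metadata = load_ply(Path("scene.ply"))
#     save_ply(
#         gaussians,
#         f_px=metadata.focal_length_px,
#         image_shape=(metadata.resolution_px[1], metadata.resolution_px[0]),
#         path=Path("scene_roundtrip.ply"),
#     )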