"""Defines module to compose final Gaussians from base values and delta values. For licensing see accompanying LICENSE file. Copyright (C) 2025 Apple Inc. All Rights Reserved. """ from __future__ import annotations import torch from torch import nn from torch.nn import functional as F from sharp.models.initializer import GaussianBaseValues from sharp.utils import math as math_utils from sharp.utils.color_space import ColorSpace, sRGB2linearRGB from sharp.utils.gaussians import Gaussians3D from .params import DeltaFactor def _get_scale_activation_constant(max_scale: float, min_scale: float) -> tuple[float, float]: """Return constants for scale activation function.""" # To ensure for delta = 0, the value of scale_factor is 1 and the gradient is 1. constant_a = (max_scale - min_scale) / (1 - min_scale) / (max_scale - 1) constant_b = math_utils.inverse_sigmoid( torch.tensor((1.0 - min_scale) / (max_scale - min_scale)) ).item() return constant_a, constant_b class GaussianComposer(nn.Module): """Converts base values and deltas into Gaussians.""" color_activation_type: math_utils.ActivationType opacity_activation_type: math_utils.ActivationType def __init__( self, delta_factor: DeltaFactor, min_scale: float, max_scale: float, color_activation_type: math_utils.ActivationType, opacity_activation_type: math_utils.ActivationType, color_space: ColorSpace, base_scale_on_predicted_mean: bool, scale_factor: int = 1, ) -> None: """Initialize GaussianComposer. Args: delta_factor: Multiply delta offsets by this factor. min_scale: The minimal scale factor for gaussian scale activation. max_scale: The maximal scale factor for gaussian scale activation. color_activation_type: Which activation function to use for colors. opacity_activation_type: Which activation function to use for opacities. color_space: Which color space is used in training. scale_factor: The scale factor to upsample the delta_values before composition. base_scale_on_predicted_mean: Whether to account z offsets for estimating base scale. """ super().__init__() self.delta_factor = delta_factor self.max_scale = max_scale self.min_scale = min_scale self.color_activation_type = color_activation_type self.opacity_activation_type = opacity_activation_type self.color_space = color_space self.scale_factor = scale_factor self.base_scale_on_predicted_mean = base_scale_on_predicted_mean def upsample_delta_value(self, delta: torch.Tensor, scale_factor: int = 1): """Upsample the delta value. Args: delta: The delta values predicted by gaussian predictor. scale_factor: The scale factor to upsample the delta_values. """ ( batch_size, num_channels, num_layers, image_height, image_width, ) = delta.shape new_height = image_height * scale_factor new_width = image_width * scale_factor upsampled_delta = F.interpolate( delta.view(batch_size, num_channels * num_layers, image_height, image_width), scale_factor=scale_factor, ).view(batch_size, num_channels, num_layers, new_height, new_width) return upsampled_delta def forward( self, delta: torch.Tensor, base_values: GaussianBaseValues, global_scale: torch.Tensor | None = None, flatten_output: bool = True, ) -> Gaussians3D: """Combine predicted delta values with base gaussian values and apply activation function. Args: delta: The delta values predicted by gaussian predictor. base_values: The gaussian base values. global_scale: Global scale of Gaussians. flatten_output: Flatten the gaussian parameters. Returns: The computed 3D Gaussians. """ # Upsample the delta if delta and base_values have different strides. scale_factor = self.scale_factor # For triplane head, the delta has already been upsampled. actual_scale_factor = base_values.mean_x_ndc.shape[-1] // delta.shape[-1] if scale_factor != 1 and actual_scale_factor != 1: delta = self.upsample_delta_value(delta, scale_factor) mean_vectors = self._forward_mean(base_values, delta) # Account for the change in base scale due to z offsets. base_scales = ( (base_values.scales * base_values.mean_inverse_z_ndc * mean_vectors[:, 2:3, ...]) if self.base_scale_on_predicted_mean else base_values.scales ) singular_values = self._scale_activation( base_scales, delta[:, 3:6], self.min_scale, self.max_scale, ) quaternions = self._quaternion_activation(base_values.quaternions, delta[:, 6:10]) colors = self._color_activation(base_values.colors, delta[:, 10:13]) opacities = self._opacity_activation(base_values.opacities, delta[:, 13]) if flatten_output: # [B, C, N, H, W] -> [B, N, H, W, C]. # NOTE: opacities is [B, N, H, W] so it doesn't need to permute. mean_vectors = mean_vectors.permute(0, 2, 3, 4, 1).flatten(1, 3) singular_values = singular_values.permute(0, 2, 3, 4, 1).flatten(1, 3) quaternions = quaternions.permute(0, 2, 3, 4, 1).flatten(1, 3) colors = colors.permute(0, 2, 3, 4, 1).flatten(1, 3) opacities = opacities.flatten(1, 3) # Apply global scaling to convert Gaussians to metric space. if global_scale is not None: mean_vectors = global_scale[:, None, None] * mean_vectors singular_values = global_scale[:, None, None] * singular_values return Gaussians3D( mean_vectors=mean_vectors, singular_values=singular_values, quaternions=quaternions, colors=colors, opacities=opacities, ) def _forward_mean(self, base_values: GaussianBaseValues, delta: torch.Tensor) -> torch.Tensor: # Concatenate base vectors and apply mean activation. delta_factor = torch.tensor( [self.delta_factor.xy, self.delta_factor.xy, self.delta_factor.z], device=delta.device, )[None, :, None, None, None] dtype = base_values.mean_x_ndc.dtype device = base_values.mean_x_ndc.device target_shape = (1, 3, 1, 1, 1) mean_x_mask = torch.tensor([1.0, 0.0, 0.0], dtype=dtype, device=device).reshape( target_shape ) mean_y_mask = torch.tensor([0.0, 1.0, 0.0], dtype=dtype, device=device).reshape( target_shape ) mean_z_mask = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device).reshape( target_shape ) mean_vectors_ndc = ( base_values.mean_x_ndc.repeat(target_shape) * mean_x_mask + base_values.mean_y_ndc.repeat(target_shape) * mean_y_mask + base_values.mean_inverse_z_ndc.repeat(target_shape) * mean_z_mask ) mean_vectors = self._mean_activation(mean_vectors_ndc, delta_factor * delta[:, :3]) return mean_vectors def _mean_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor: """Mean activation function. Args: base: Tensor of shape [B, 3, H, W], where first two feature dimensions (x,y) are in normalized device coordinates (NDC) where (-1, -1) is the top, while the third dimension is inverse depth. learned_delta: Tensor of shape [B, 3, H, W] with predicted delta values. Returns: Returns: The final mean vector after combining base and delta and applying nonlinearies. """ xx = base[:, 0:1] + learned_delta[:, 0:1] yy = base[:, 1:2] + learned_delta[:, 1:2] a = base[:, 2:3] b = learned_delta[:, 2:3] # Original formula: inverse_zz = F.softplus(math_utils.inverse_softplus(a) + b) zz = 1.0 / (inverse_zz + 1e-3) mean_vectors = torch.cat([zz * xx, zz * yy, zz], dim=1) return mean_vectors def _scale_activation( self, base: torch.Tensor, learned_delta: torch.Tensor, min_scale: float, max_scale: float, ) -> torch.Tensor: constant_a, constant_b = _get_scale_activation_constant(max_scale, min_scale) scale_factor = (max_scale - min_scale) * torch.sigmoid( constant_a * self.delta_factor.scale * learned_delta + constant_b ) + min_scale return base * scale_factor def _quaternion_activation( self, base: torch.Tensor, learned_delta: torch.Tensor ) -> torch.Tensor: # No need to normalize the quaternions, since this is also done in rendering. return base + self.delta_factor.quaternion * learned_delta def _color_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor: # For certain activation functions we need to clamp the base value to # a supported range. if self.color_activation_type == "sigmoid": base = torch.clamp(base, min=0.01, max=0.99) elif self.color_activation_type in ("exp", "softplus"): base = torch.clamp(base, min=0.01) activation = math_utils.create_activation_pair(self.color_activation_type) colors: torch.Tensor = activation.forward( activation.inverse(base) + self.delta_factor.color * learned_delta ) # Convert gaussian color to linear if linearRGB colorspace is specified. if self.color_space == "linearRGB": colors = sRGB2linearRGB(colors) return colors def _opacity_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor: activation = math_utils.create_activation_pair(self.opacity_activation_type) return activation.forward( activation.inverse(base) + self.delta_factor.opacity * learned_delta )