|
|
"""Defines module to compose final Gaussians from base values and delta values. |
|
|
|
|
|
For licensing see accompanying LICENSE file. |
|
|
Copyright (C) 2025 Apple Inc. All Rights Reserved. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
from sharp.models.initializer import GaussianBaseValues |
|
|
from sharp.utils import math as math_utils |
|
|
from sharp.utils.color_space import ColorSpace, sRGB2linearRGB |
|
|
from sharp.utils.gaussians import Gaussians3D |
|
|
|
|
|
from .params import DeltaFactor |
|
|
|
|
|
|
|
|
def _get_scale_activation_constant(max_scale: float, min_scale: float) -> tuple[float, float]:
    """Compute the slope and offset constants of the sigmoid scale activation.

    The scale activation has the form
    ``(max_scale - min_scale) * sigmoid(slope * x + offset) + min_scale``.

    Args:
        max_scale: Upper bound of the scale multiplier.
        min_scale: Lower bound of the scale multiplier.

    Returns:
        A ``(slope, offset)`` tuple of Python floats.

    Note:
        Assuming ``math_utils.inverse_sigmoid`` is the logit function, the
        offset makes the activation equal to 1 at ``x == 0`` and the slope
        gives it unit derivative there. The expressions divide by
        ``1 - min_scale``, ``max_scale - 1`` and ``max_scale - min_scale``, so
        those differences must be non-zero.
    """
    scale_range = max_scale - min_scale

    # Slope of the sigmoid argument.
    slope = scale_range / (1 - min_scale) / (max_scale - 1)

    # Offset: inverse sigmoid of the relative position of 1 inside the range.
    unit_position = torch.tensor((1.0 - min_scale) / scale_range)
    offset = math_utils.inverse_sigmoid(unit_position).item()

    return slope, offset
|
|
|
|
|
|
|
|
class GaussianComposer(nn.Module):
    """Converts base values and deltas into Gaussians.

    The predicted ``delta`` tensor is sliced along dimension 1 as follows:

    * ``0:3``   -- mean offsets (x, y, inverse depth),
    * ``3:6``   -- scale deltas,
    * ``6:10``  -- quaternion deltas,
    * ``10:13`` -- color deltas,
    * ``13``    -- opacity delta.
    """

    color_activation_type: math_utils.ActivationType
    opacity_activation_type: math_utils.ActivationType

    def __init__(
        self,
        delta_factor: DeltaFactor,
        min_scale: float,
        max_scale: float,
        color_activation_type: math_utils.ActivationType,
        opacity_activation_type: math_utils.ActivationType,
        color_space: ColorSpace,
        base_scale_on_predicted_mean: bool,
        scale_factor: int = 1,
    ) -> None:
        """Initialize GaussianComposer.

        Args:
            delta_factor: Multiply delta offsets by this factor.
            min_scale: The minimal scale factor for gaussian scale activation.
            max_scale: The maximal scale factor for gaussian scale activation.
            color_activation_type: Which activation function to use for colors.
            opacity_activation_type: Which activation function to use for opacities.
            color_space: Which color space is used in training.
            base_scale_on_predicted_mean: Whether to account z offsets for estimating
                base scale.
            scale_factor: The scale factor to upsample the delta_values before
                composition.
        """
        super().__init__()
        self.delta_factor = delta_factor
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.color_activation_type = color_activation_type
        self.opacity_activation_type = opacity_activation_type
        self.color_space = color_space
        self.scale_factor = scale_factor
        self.base_scale_on_predicted_mean = base_scale_on_predicted_mean

    def upsample_delta_value(self, delta: torch.Tensor, scale_factor: int = 1) -> torch.Tensor:
        """Upsample the delta value.

        The channel and layer dimensions are merged so that ``F.interpolate``
        (called with its default mode) sees a 4-D tensor, then split again.

        Args:
            delta: The delta values predicted by gaussian predictor, a 5-D tensor
                unpacked as (batch, channels, layers, height, width).
            scale_factor: The scale factor to upsample the delta_values.

        Returns:
            The upsampled delta with height and width multiplied by ``scale_factor``.
        """
        batch_size, num_channels, num_layers, image_height, image_width = delta.shape
        new_height = image_height * scale_factor
        new_width = image_width * scale_factor

        # ``reshape`` (rather than ``view``) also accepts non-contiguous inputs,
        # e.g. a delta that was permuted or sliced upstream.
        merged = delta.reshape(batch_size, num_channels * num_layers, image_height, image_width)
        upsampled = F.interpolate(merged, scale_factor=scale_factor)
        return upsampled.reshape(batch_size, num_channels, num_layers, new_height, new_width)

    def forward(
        self,
        delta: torch.Tensor,
        base_values: GaussianBaseValues,
        global_scale: torch.Tensor | None = None,
        flatten_output: bool = True,
    ) -> Gaussians3D:
        """Combine predicted delta values with base gaussian values and apply activation function.

        Args:
            delta: The delta values predicted by gaussian predictor.
            base_values: The gaussian base values.
            global_scale: Global scale of Gaussians. Indexed along its first
                dimension only, so it is assumed to be 1-D with one entry per
                batch element -- TODO confirm against callers.
            flatten_output: Flatten the gaussian parameters. When True the layer,
                height and width dimensions are merged into one and the channel
                dimension is moved last.

        Returns:
            The computed 3D Gaussians.
        """
        scale_factor = self.scale_factor

        # Only upsample when upsampling is configured AND the base values are
        # actually wider than the delta.
        actual_scale_factor = base_values.mean_x_ndc.shape[-1] // delta.shape[-1]
        if scale_factor != 1 and actual_scale_factor != 1:
            delta = self.upsample_delta_value(delta, scale_factor)

        mean_vectors = self._forward_mean(base_values, delta)

        # Optionally rescale the base scales by the ratio between predicted depth
        # (channel 2 of the mean) and base depth (1 / base inverse depth).
        base_scales = (
            (base_values.scales * base_values.mean_inverse_z_ndc * mean_vectors[:, 2:3, ...])
            if self.base_scale_on_predicted_mean
            else base_values.scales
        )
        singular_values = self._scale_activation(
            base_scales,
            delta[:, 3:6],
            self.min_scale,
            self.max_scale,
        )
        quaternions = self._quaternion_activation(base_values.quaternions, delta[:, 6:10])
        colors = self._color_activation(base_values.colors, delta[:, 10:13])
        opacities = self._opacity_activation(base_values.opacities, delta[:, 13])

        if flatten_output:
            # Move channels last and merge dims 1..3 into a single dimension.
            mean_vectors = mean_vectors.permute(0, 2, 3, 4, 1).flatten(1, 3)
            singular_values = singular_values.permute(0, 2, 3, 4, 1).flatten(1, 3)
            quaternions = quaternions.permute(0, 2, 3, 4, 1).flatten(1, 3)
            colors = colors.permute(0, 2, 3, 4, 1).flatten(1, 3)
            opacities = opacities.flatten(1, 3)

        if global_scale is not None:
            # Rank-aware broadcasting so that both flattened and non-flattened
            # outputs are scaled per batch element.
            mean_vectors = self._apply_global_scale(mean_vectors, global_scale)
            singular_values = self._apply_global_scale(singular_values, global_scale)

        return Gaussians3D(
            mean_vectors=mean_vectors,
            singular_values=singular_values,
            quaternions=quaternions,
            colors=colors,
            opacities=opacities,
        )

    @staticmethod
    def _apply_global_scale(values: torch.Tensor, global_scale: torch.Tensor) -> torch.Tensor:
        """Multiply ``values`` by a per-batch scale, whatever the rank of ``values``.

        Args:
            values: Tensor whose first dimension is the batch dimension.
            global_scale: Scale tensor indexed along its first dimension.

        Returns:
            ``global_scale`` expanded with trailing singleton dimensions to the
            rank of ``values`` and multiplied with it. For 3-D ``values`` this is
            exactly ``global_scale[:, None, None] * values``.
        """
        index = (slice(None),) + (None,) * (values.dim() - 1)
        return global_scale[index] * values

    def _forward_mean(self, base_values: GaussianBaseValues, delta: torch.Tensor) -> torch.Tensor:
        """Compose the mean vectors from the base x/y/inverse-z values and the delta.

        Args:
            base_values: Provides ``mean_x_ndc``, ``mean_y_ndc`` and
                ``mean_inverse_z_ndc``.
            delta: The predicted delta; only its first three channels are used.

        Returns:
            The result of ``_mean_activation`` on the stacked base mean and the
            delta scaled by the xy / z delta factors.
        """
        delta_factor = torch.tensor(
            [self.delta_factor.xy, self.delta_factor.xy, self.delta_factor.z],
            device=delta.device,
        )[None, :, None, None, None]

        dtype = base_values.mean_x_ndc.dtype
        device = base_values.mean_x_ndc.device
        target_shape = (1, 3, 1, 1, 1)

        # One-hot channel masks: each base component is repeated three times
        # along dim 1 and then masked down to its own channel.
        mean_x_mask = torch.tensor([1.0, 0.0, 0.0], dtype=dtype, device=device).reshape(
            target_shape
        )
        mean_y_mask = torch.tensor([0.0, 1.0, 0.0], dtype=dtype, device=device).reshape(
            target_shape
        )
        mean_z_mask = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device).reshape(
            target_shape
        )

        mean_vectors_ndc = (
            base_values.mean_x_ndc.repeat(target_shape) * mean_x_mask
            + base_values.mean_y_ndc.repeat(target_shape) * mean_y_mask
            + base_values.mean_inverse_z_ndc.repeat(target_shape) * mean_z_mask
        )

        mean_vectors = self._mean_activation(mean_vectors_ndc, delta_factor * delta[:, :3])
        return mean_vectors

    def _mean_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
        """Mean activation function.

        Args:
            base: Tensor with three channels along dim 1 (called from
                ``_forward_mean`` with a 5-D tensor). The first two channels
                (x, y) are in normalized device coordinates (NDC), the third is
                inverse depth.
                NOTE(review): the earlier docs state "(-1, -1) is the top";
                presumably the top-left corner -- confirm.
            learned_delta: Tensor with the same layout holding the predicted
                delta values.

        Returns:
            The final mean vector after combining base and delta and applying
            nonlinearities, with channels ``(z * x, z * y, z)``.
        """
        xx = base[:, 0:1] + learned_delta[:, 0:1]
        yy = base[:, 1:2] + learned_delta[:, 1:2]

        a = base[:, 2:3]
        b = learned_delta[:, 2:3]

        # The depth delta is added in pre-softplus space, so the resulting
        # inverse depth is a softplus output.
        inverse_zz = F.softplus(math_utils.inverse_softplus(a) + b)
        # The 1e-3 offset bounds the depth when the inverse depth approaches zero.
        zz = 1.0 / (inverse_zz + 1e-3)

        mean_vectors = torch.cat([zz * xx, zz * yy, zz], dim=1)
        return mean_vectors

    def _scale_activation(
        self,
        base: torch.Tensor,
        learned_delta: torch.Tensor,
        min_scale: float,
        max_scale: float,
    ) -> torch.Tensor:
        """Scale the base by a sigmoid-bounded multiplier derived from the delta.

        Args:
            base: The base scales.
            learned_delta: The predicted scale deltas.
            min_scale: Lower bound of the multiplier.
            max_scale: Upper bound of the multiplier.

        Returns:
            ``base`` multiplied by
            ``(max_scale - min_scale) * sigmoid(...) + min_scale``.
        """
        constant_a, constant_b = _get_scale_activation_constant(max_scale, min_scale)
        scale_factor = (max_scale - min_scale) * torch.sigmoid(
            constant_a * self.delta_factor.scale * learned_delta + constant_b
        ) + min_scale
        return base * scale_factor

    def _quaternion_activation(
        self, base: torch.Tensor, learned_delta: torch.Tensor
    ) -> torch.Tensor:
        """Add the factor-scaled quaternion delta to the base quaternion.

        No normalization is applied here.
        """
        return base + self.delta_factor.quaternion * learned_delta

    def _color_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
        """Combine base colors and color deltas in the activation's input space.

        Args:
            base: The base colors.
            learned_delta: The predicted color deltas.

        Returns:
            The activated colors, converted with ``sRGB2linearRGB`` when the
            configured color space is ``"linearRGB"``.
        """
        # Clamp the base before inverting the activation (presumably to keep
        # the inverse finite at the ends of its domain).
        if self.color_activation_type == "sigmoid":
            base = torch.clamp(base, min=0.01, max=0.99)
        elif self.color_activation_type in ("exp", "softplus"):
            base = torch.clamp(base, min=0.01)

        activation = math_utils.create_activation_pair(self.color_activation_type)
        colors: torch.Tensor = activation.forward(
            activation.inverse(base) + self.delta_factor.color * learned_delta
        )

        if self.color_space == "linearRGB":
            colors = sRGB2linearRGB(colors)
        return colors

    def _opacity_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
        """Combine base opacities and opacity deltas in the activation's input space.

        Args:
            base: The base opacities.
            learned_delta: The predicted opacity deltas.

        Returns:
            The activated opacities.
        """
        activation = math_utils.create_activation_pair(self.opacity_activation_type)
        return activation.forward(
            activation.inverse(base) + self.delta_factor.opacity * learned_delta
        )
|
|
|