File size: 7,744 Bytes
c20d7cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
"""Contains definition of RGB-only gaussian predictor.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
import logging
import torch
from torch import nn
from sharp.models.monodepth import MonodepthWithEncodingAdaptor
from sharp.utils.gaussians import Gaussians3D
from .composer import GaussianComposer
# Module-level logger, named after this module via ``__name__``.
LOGGER = logging.getLogger(__name__)
class DepthAlignment(nn.Module):
    """Depth alignment in a dedicated nn.Module.

    Wraps ``scale_map_estimator`` so that the conditional "align or pass through"
    logic lives in its own torch module, outside the forward of
    RGBGaussianPredictor. This module can then be excluded during symbolic
    tracing.
    """

    def __init__(self, scale_map_estimator: nn.Module | None) -> None:
        """Initialize DepthAlignment.

        Args:
            scale_map_estimator: Module to align monodepth to ground truth depth,
                or None to disable alignment altogether.
        """
        super().__init__()
        self.scale_map_estimator = scale_map_estimator

    def forward(
        self,
        monodepth: torch.Tensor,
        depth: torch.Tensor | None,
        depth_decoder_features: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Optionally align monodepth to ground truth with a local scale map.

        Alignment is only performed when both a ground truth ``depth`` and a
        ``scale_map_estimator`` are available; otherwise ``monodepth`` is
        returned unchanged.

        Args:
            monodepth: The predicted monodepth tensor to align.
            depth: Ground truth depth to align predicted depth to, or None when
                no ground truth is available.
            depth_decoder_features: The (optional) monodepth decoder features.

        Returns:
            A tuple ``(monodepth, depth_alignment_map)``: the (possibly aligned)
            monodepth and the scale map that was applied. Without alignment the
            map is all ones with the same shape as ``monodepth``.
        """
        if depth is not None and self.scale_map_estimator is not None:
            # Only the first channel of monodepth is handed to the estimator.
            depth_alignment_map = self.scale_map_estimator(
                monodepth[:, 0:1], depth, depth_decoder_features
            )
            # The estimated map is applied multiplicatively; this relies on the
            # map being broadcastable against monodepth.
            monodepth = depth_alignment_map * monodepth
        else:
            # Some losses rely on the presence of an alignment map.
            # We ensure that they can be computed by creating a fake alignment map.
            depth_alignment_map = torch.ones_like(monodepth)

        return monodepth, depth_alignment_map
class RGBGaussianPredictor(nn.Module):
    """Predicts 3D Gaussians from images."""

    feature_model: nn.Module

    def __init__(
        self,
        init_model: nn.Module,
        monodepth_model: MonodepthWithEncodingAdaptor,
        feature_model: nn.Module,
        prediction_head: nn.Module,
        gaussian_composer: GaussianComposer,
        scale_map_estimator: nn.Module | None,
    ) -> None:
        """Initialize RGBGaussianPredictor.

        Args:
            init_model: A model mapping image and depth to base values.
            monodepth_model: The monodepth model with intermediate features to use.
            feature_model: The image2image model to predict Gaussians from.
            prediction_head: Head to decode image features.
            gaussian_composer: Module to compose final prediction from deltas and
                base values.
            scale_map_estimator: Module to align monodepth to ground truth depth.

        Note:
        ----
            When monodepth_model is trainable, using local depth alignment can
            result in the monodepth model losing its ability to predict shapes.
            It is hence recommended to deactivate the corresponding flag.

        """
        super().__init__()
        # Submodules are registered in this exact order; keep it stable.
        self.init_model = init_model
        self.feature_model = feature_model
        self.monodepth_model = monodepth_model
        self.prediction_head = prediction_head
        self.gaussian_composer = gaussian_composer
        # The estimator is wrapped so that the "align or not" branch stays
        # outside of this module's own forward.
        self.depth_alignment = DepthAlignment(scale_map_estimator)

    def forward(
        self,
        image: torch.Tensor,
        disparity_factor: torch.Tensor,
        depth: torch.Tensor | None = None,
    ) -> Gaussians3D:
        """Predict 3D Gaussians.

        Args:
            image: The image to process.
            disparity_factor: Factor to convert depth to disparities.
            depth: Ground truth depth to align predicted depth to.

        Returns:
            The predicted 3D Gaussians.

        Note:
        ----
            During training, it is recommended to feed an additional ground truth
            depth map to the network to align the predicted depth to. During
            inference, it is recommended to use depth=None and use the monodepth
            disparity output from the model instead to compute depth.

        """
        # Estimate disparity, then turn it into depth.
        encoder_output = self.monodepth_model(image)
        predicted_disparity = encoder_output.disparity
        # Append three singleton axes so the factor broadcasts against the
        # disparity map (assumes disparity_factor is 1-D, one value per sample
        # -- TODO confirm against callers).
        broadcast_factor = disparity_factor[:, None, None, None]
        # Clamp the disparity to [1e-4, 1e4] before dividing by it.
        bounded_disparity = predicted_disparity.clamp(min=1e-4, max=1e4)
        metric_depth = broadcast_factor / bounded_disparity

        # In the model we apply additional alignment to provided ground truth depth
        # as well as additional normalization.
        #
        # The overall graph looks as follows:
        #
        #  monodepth         depth      # Both monodepth and depth are metric here.
        #      |               |
        #      +-------+-------+
        #              |
        #      +-------+-------+        # Optionally align monodepth to ground truth
        #      |depth_alignment|        # with a local scale map.
        #      +-------+-------+
        #              |
        #              v
        #      monodepth (aligned)      # Monodepth is now aligned to ground truth.
        #              |
        #        +-----+----+           # Normalize depth and compute base gaussians
        #        |init_model|           # in these normalized coordinates.
        #        +-----+----+
        #              |
        #              v
        #   +---- init_output           # init_output consists of features, base
        #   |          |                # gaussians and a global scale.
        #   |   +------+-----+
        #   |   |main network|          # Compute delta values to base gaussians.
        #   |   +------+-----+
        #   |          |
        #   |          v
        #   |     delta_values          # The delta values are computed with
        #   |          |                # normalized depth.
        #   |  +-------+---------+
        #   +->|gaussian_composer|      # Add delta to base values and unscale
        #      +-------+---------+      # gaussians.
        #              |
        #              v
        #          gaussians            # The final Gaussians are metric again.
        #
        # The logic to decide whether to align monodepth to the ground truth is
        # wrapped in a submodule 'DepthAlignment' to facilitate the symbolic tracing
        # of the predictor. This way, the depth alignment submodule containing the
        # conditional logic can be excluded during the tracing and the graph of the
        # predictor is static.
        aligned_depth, _ = self.depth_alignment(
            metric_depth,
            depth,
            encoder_output.decoder_features,
        )

        init_output = self.init_model(image, aligned_depth)

        # Main network: image features first, then the deltas decoded from them.
        features = self.feature_model(
            init_output.feature_input,
            encodings=encoder_output.output_features,
        )
        deltas = self.prediction_head(features)

        # Combine the deltas with the base values to obtain the final Gaussians.
        return self.gaussian_composer(
            delta=deltas,
            base_values=init_output.gaussian_base_values,
            global_scale=init_output.global_scale,
        )

    def internal_resolution(self) -> int:
        """Internal resolution, as reported by the monodepth model."""
        resolution = self.monodepth_model.internal_resolution()
        return resolution

    @property
    def output_resolution(self) -> int:
        """Output resolution of Gaussians: half the internal resolution."""
        full_resolution = self.internal_resolution()
        return full_resolution // 2
|