amael-apple's picture
Initial commit
c20d7cc
raw
history blame
2.07 kB
"""Contains decoder head for direct prediction of delta values.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
import torch
from torch import nn
from .gaussian_decoder import ImageFeatures
class DirectPredictionHead(nn.Module):
    """Decodes features into delta values using 1x1 convolutions.

    Two independent 1x1 convolutions are used: one applied to the geometry
    features and one applied to the texture features. Both are created with
    all-zero weights and biases, so a freshly constructed head outputs zeros
    for any finite input.
    """

    # 14 is 3 means, 3 scales, 4 quaternions, 3 colors and 1 opacity.
    NUM_GAUSSIAN_ATTRIBUTES = 14
    # Per-layer channel count produced by the geometry head; the texture head
    # produces all remaining attributes.
    NUM_GEOMETRY_ATTRIBUTES = 3
    NUM_TEXTURE_ATTRIBUTES = NUM_GAUSSIAN_ATTRIBUTES - NUM_GEOMETRY_ATTRIBUTES

    def __init__(self, feature_dim: int, num_layers: int) -> None:
        """Initialize DirectPredictionHead.

        Args:
            feature_dim: Number of input features.
            num_layers: The number of layers of Gaussians to predict.
        """
        super().__init__()
        self.num_layers = num_layers

        # Registration order (geometry first, then texture) is kept stable so
        # that state-dict key order does not change.
        self.geometry_prediction_head = self._zero_initialized_conv(
            feature_dim, self.NUM_GEOMETRY_ATTRIBUTES * num_layers
        )
        self.texture_prediction_head = self._zero_initialized_conv(
            feature_dim, self.NUM_TEXTURE_ATTRIBUTES * num_layers
        )

    @staticmethod
    def _zero_initialized_conv(in_channels: int, out_channels: int) -> nn.Conv2d:
        """Build a 1x1 convolution whose weight and bias are all zeros."""
        conv = nn.Conv2d(in_channels, out_channels, 1)
        # Iterating over the parameters covers both weight and bias without
        # reaching through `.data`.
        for parameter in conv.parameters():
            nn.init.zeros_(parameter)
        return conv

    def forward(self, image_features: ImageFeatures) -> torch.Tensor:
        """Predict deltas for 3D Gaussians.

        Args:
            image_features: Image features from decoder. Its
                ``geometry_features`` and ``texture_features`` attributes are
                fed to the geometry and texture heads respectively.
                NOTE(review): a batched ``(N, feature_dim, H, W)`` layout is
                assumed, since dim 1 is treated as the channel axis below.

        Returns:
            The predicted deltas for Gaussian attributes. The channel axis
            (dim 1) of each head's output is split into
            ``(attributes, num_layers)`` and the two results are concatenated
            along the attribute axis, geometry first, so dim 1 has size
            ``NUM_GAUSSIAN_ATTRIBUTES`` and dim 2 has size ``num_layers``.
        """
        delta_values_geometry = self.geometry_prediction_head(image_features.geometry_features)
        delta_values_texture = self.texture_prediction_head(image_features.texture_features)

        # Split the flat channel axis into (attribute, layer); the attribute
        # index varies slowest, matching the convolution's channel ordering.
        delta_values_geometry = delta_values_geometry.unflatten(
            1, (self.NUM_GEOMETRY_ATTRIBUTES, self.num_layers)
        )
        delta_values_texture = delta_values_texture.unflatten(
            1, (self.NUM_TEXTURE_ATTRIBUTES, self.num_layers)
        )

        return torch.cat([delta_values_geometry, delta_values_texture], dim=1)