File size: 2,068 Bytes
c20d7cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
"""Contains decoder head for direct prediction of delta values.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
import torch
from torch import nn
from .gaussian_decoder import ImageFeatures
class DirectPredictionHead(nn.Module):
    """Decodes image features into per-layer delta values using 1x1 convolutions.

    Two independent 1x1 convolutions are applied: one to the geometry features
    and one to the texture features. Their outputs are reshaped so that the
    channel dimension is split into (attribute, layer) and then concatenated
    along the attribute dimension, geometry channels first.
    """

    # Total delta channels per Gaussian layer:
    # 3 means, 3 scales, 4 quaternions, 3 colors and 1 opacity.
    NUM_DELTA_CHANNELS = 14
    # Leading channels predicted from the geometry features.
    NUM_GEOMETRY_CHANNELS = 3
    # Remaining channels, predicted from the texture features.
    NUM_TEXTURE_CHANNELS = NUM_DELTA_CHANNELS - NUM_GEOMETRY_CHANNELS

    def __init__(self, feature_dim: int, num_layers: int) -> None:
        """Initialize DirectPredictionHead.

        Args:
            feature_dim: Number of input channels of both the geometry and the
                texture feature maps.
            num_layers: The number of layers of Gaussians to predict.
        """
        super().__init__()
        self.num_layers = num_layers
        # Both heads start with all-zero weights and biases, so the module outputs
        # all-zero deltas at initialization (presumably so training starts from the
        # unmodified base values -- confirm against the training setup).
        self.geometry_prediction_head = self._make_zero_conv(
            feature_dim, self.NUM_GEOMETRY_CHANNELS * num_layers
        )
        self.texture_prediction_head = self._make_zero_conv(
            feature_dim, self.NUM_TEXTURE_CHANNELS * num_layers
        )

    @staticmethod
    def _make_zero_conv(in_channels: int, out_channels: int) -> nn.Conv2d:
        """Create a 1x1 convolution whose weight and bias are all zeros."""
        conv = nn.Conv2d(in_channels, out_channels, 1)
        nn.init.zeros_(conv.weight)
        # Conv2d has a bias by default; the assert only narrows the Optional type.
        assert conv.bias is not None
        nn.init.zeros_(conv.bias)
        return conv

    def forward(self, image_features: ImageFeatures) -> torch.Tensor:
        """Predict deltas for 3D Gaussians.

        Args:
            image_features: Image features from the decoder; only its
                ``geometry_features`` and ``texture_features`` are read.

        Returns:
            The predicted deltas with dim 1 split into
            ``(NUM_DELTA_CHANNELS, num_layers)``. Assuming batched inputs of shape
            (B, feature_dim, H, W), the result is (B, 14, num_layers, H, W).
        """
        # Conv output channels are attribute-major: channel ``c * num_layers + l``
        # holds attribute ``c`` of layer ``l``; unflatten exposes that split.
        delta_values_geometry = self.geometry_prediction_head(
            image_features.geometry_features
        ).unflatten(1, (self.NUM_GEOMETRY_CHANNELS, self.num_layers))
        delta_values_texture = self.texture_prediction_head(
            image_features.texture_features
        ).unflatten(1, (self.NUM_TEXTURE_CHANNELS, self.num_layers))
        # Geometry channels come first, followed by the texture channels.
        return torch.cat([delta_values_geometry, delta_values_texture], dim=1)
|