File size: 5,318 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor, nn
from optgs.model.types import Gaussians
from optgs.scene_trainer.common.gaussians import build_covariance


class GaussiansModule(nn.Module):
    """Trainable container for one (unbatched) set of G Gaussians.

    The constructor receives post-activation values and registers the
    optimizable parameters in unconstrained ("raw") space:

    - ``opacities_raw``: logit of the opacities (``opacities`` applies sigmoid),
    - ``scales_raw``: log of the scales (``scales`` applies exp),
    - ``rotations_unnorm``: quaternions as given (``rotations`` L2-normalizes
      them along the last dim),
    - ``means``: stored as given,
    - ``sh0`` / ``shN``: the first and the remaining spherical-harmonics
      coefficients, split along the last dim of ``harmonics``; ``shN`` is only
      registered when there is more than one coefficient.
    """

    def __init__(
        self,
        means: Float[Tensor, "gaussian 3"],
        harmonics: Float[Tensor, "gaussian 3 d_sh"],
        opacities: Float[Tensor, "gaussian"],
        scales: Float[Tensor, "gaussian 3"],
        rotations_unnorm: Float[Tensor, "gaussian 4"]
    ):
        """Register the Gaussian parameters from post-activation values.

        Inputs are detached before use, so the module never backpropagates
        into (or shares storage with) the caller's tensors.

        Args:
            means: Gaussian centers, ``[G, 3]``.
            harmonics: SH coefficients, ``[G, 3, d_sh]``.
            opacities: opacities, ``[G]``; ``torch.logit(..., eps=1e-6)``
                clamps them to ``[1e-6, 1 - 1e-6]`` before inversion.
            scales: scales, ``[G, 3]``; inverted with ``log``, so they are
                expected to be positive.
            rotations_unnorm: quaternions (not necessarily unit), ``[G, 4]``.
        """
        super().__init__()

        def _register_param(name, value):
            # nn.Parameter defaults to requires_grad=True.
            if value is None:
                setattr(self, name, None)
            else:
                setattr(self, name, nn.Parameter(value))

        self.scaling_activation = torch.exp
        self.scaling_inverse_activation = torch.log
        self.covariance_activation = build_covariance
        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = torch.logit
        self.rotation_activation = F.normalize

        means = means.detach().clone()

        harmonics = harmonics.detach().clone()  # [G, 3, d_sh]
        d_sh = harmonics.shape[-1]
        # Split the first coefficient from the rest. `.contiguous()` gives each
        # parameter its own dense storage instead of two strided views into the
        # same cloned tensor.
        sh0 = harmonics[..., 0:1].contiguous()  # [G, 3, 1]
        # sh_degree == 0 -> no higher-order coefficients to optimize.
        shN = None if d_sh == 1 else harmonics[..., 1:].contiguous()  # [G, 3, d_sh-1]

        # Optimize opacity and scale in unconstrained space; logit/log allocate
        # new tensors, so detaching is enough (no extra clone needed).
        opacities_raw = self.inverse_opacity_activation(opacities.detach(), eps=1e-6)
        scales_raw = self.scaling_inverse_activation(scales.detach())

        # Rotations in xyzw (scalar last)
        # remember to convert to wxyz (scalar first) before rendering and saving to ply
        rotations_unnorm = rotations_unnorm.detach().clone()

        # Registration order fixes the order of named_parameters()/state_dict().
        _register_param("opacities_raw", opacities_raw)
        _register_param("scales_raw", scales_raw)
        _register_param("means", means)
        _register_param("rotations_unnorm", rotations_unnorm)
        _register_param("sh0", sh0)
        if shN is not None:
            _register_param("shN", shN)

        for name, param in self.named_parameters():
            # min()/max() raise on empty tensors, which would make a module
            # with zero Gaussians impossible to construct.
            if param.numel() > 0:
                p_min, p_max = param.min(), param.max()
            else:
                p_min = p_max = None
            print(f"Registered parameter: {name}, shape: {param.shape}, dtype: {param.dtype}, min: {p_min}, max: {p_max}, requires_grad: {param.requires_grad}")

    @property
    def scales(self):
        """Post-activation scales, ``exp(scales_raw)``."""
        return self.scaling_activation(self.scales_raw)

    @property
    def opacities(self):
        """Post-activation opacities, ``sigmoid(opacities_raw)``."""
        return self.opacity_activation(self.opacities_raw)

    @property
    def rotations(self):
        """Unit quaternions: ``rotations_unnorm`` normalized along the last dim."""
        return self.rotation_activation(self.rotations_unnorm, dim=-1)

    @property
    def harmonics(self):
        """SH coefficients ``[G, 3, d_sh]``: ``sh0`` concatenated with ``shN`` if present."""
        shN = getattr(self, "shN", None)
        if shN is None:
            return self.sh0
        return torch.cat([self.sh0, shN], dim=-1)

    @property
    def covariances(self):
        """Covariances ``[G, 3, 3]`` built from the activated scales and rotations."""
        rotation_xyzw = self.rotations
        return self.covariance_activation(self.scales, rotation_xyzw)  # [G, 3, 3]

    def reset_opacity(self, optimizer=None):
        """Cap every opacity at 0.01.

        ``opacities_raw`` is overwritten in place, so the ``nn.Parameter``
        object (and any optimizer param group that holds it) is unchanged.

        Args:
            optimizer: optional optimizer that owns ``opacities_raw``. If it
                has per-parameter state for it, every state tensor with the
                parameter's shape (e.g. Adam's first/second moments) is
                zeroed so stale moments do not carry over the old opacities.
        """
        with torch.no_grad():
            opacities_old = self.opacity_activation(self.opacities_raw)
            opacities_new = torch.clamp(opacities_old, max=0.01)
            self.opacities_raw.copy_(
                self.inverse_opacity_activation(opacities_new, eps=1e-6)
            )

        if optimizer is None:
            return
        # `.get` avoids creating an entry in the optimizer's defaultdict state.
        param_state = optimizer.state.get(self.opacities_raw)
        if not param_state:
            return
        for value in param_state.values():
            # Shape filter leaves scalar bookkeeping (e.g. a step counter) alone.
            if torch.is_tensor(value) and value.shape == self.opacities_raw.shape:
                value.zero_()
    


def gaussians2module(gaussians: Gaussians, device: torch.device) -> GaussiansModule:
    """Build a trainable ``GaussiansModule`` from a batched ``Gaussians``.

    Only batch size 1 is supported: element ``[0]`` of each field used here is
    moved to ``device`` and handed to the ``GaussiansModule`` constructor.

    Args:
        gaussians: batched Gaussians; the batch size is read from
            ``gaussians.means.shape[0]``.
        device: device on which the module's parameters are created.

    Returns:
        A new ``GaussiansModule`` for the single batch element.

    Raises:
        ValueError: if the batch size is not 1.
    """
    batch_size = gaussians.means.shape[0]
    if batch_size != 1:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        raise ValueError(
            f"Expected batch size 1 for post-processing, got {batch_size}"
        )

    return GaussiansModule(
        means=gaussians.means[0].to(device),
        harmonics=gaussians.harmonics[0].to(device),
        opacities=gaussians.opacities[0].to(device),
        scales=gaussians.scales[0].to(device),
        rotations_unnorm=gaussians.rotations_unnorm[0].to(device),
    )


def module2gaussians(gaussian_module: GaussiansModule) -> Gaussians:
    """Wrap a ``GaussiansModule`` as a batch-of-one ``Gaussians``.

    Every field is read from the module's live parameters or activation
    properties and given a leading batch dimension of 1, so the result stays
    connected to the module's autograd graph.

    Args:
        gaussian_module: the module to export.

    Returns:
        A ``Gaussians`` whose fields all have batch size 1.
    """
    return Gaussians(
        means=gaussian_module.means.unsqueeze(0),  # [1, G, 3]
        covariances=gaussian_module.covariances.unsqueeze(0),  # [1, G, 3, 3]
        harmonics=gaussian_module.harmonics.unsqueeze(0),  # [1, G, 3, d_sh]
        opacities=gaussian_module.opacities.unsqueeze(0),  # [1, G]
        scales=gaussian_module.scales.unsqueeze(0),  # [1, G, 3]
        rotations=gaussian_module.rotations.unsqueeze(0),  # [1, G, 4]
        # Export the raw (unnormalized) parameter so the field matches its name;
        # the normalized quaternions are already exported as `rotations`.
        rotations_unnorm=gaussian_module.rotations_unnorm.unsqueeze(0),  # [1, G, 4]
    )