Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from jaxtyping import Float | |
| from torch import Tensor, nn | |
| from optgs.model.types import Gaussians | |
| from optgs.scene_trainer.common.gaussians import build_covariance | |
class GaussiansModule(nn.Module):
    """Trainable container for a single (unbatched) set of 3D Gaussians.

    The constructor receives post-activation values and stores the optimizable
    parameters in an unconstrained ("raw") space:

    * ``opacities_raw`` = logit(opacities)
    * ``scales_raw`` = log(scales)
    * ``means`` as given
    * ``rotations_unnorm`` = rotation quaternions as given (normalized on read)
    * ``sh0`` / ``shN`` = first / remaining spherical-harmonic coefficients
      along the last axis (``shN`` is only registered when ``d_sh > 1``)

    The accessor methods ``scales()``, ``opacities()``, ``rotations()``,
    ``harmonics()`` and ``covariances()`` map the raw parameters back to
    post-activation values.
    """

    def __init__(
        self,
        means: Float[Tensor, "gaussian 3"],
        harmonics: Float[Tensor, "gaussian 3 d_sh"],
        opacities: Float[Tensor, "gaussian"],
        scales: Float[Tensor, "gaussian 3"],
        rotations_unnorm: Float[Tensor, "gaussian 4"]
    ):
        """Build the module from post-activation Gaussian attributes.

        Inputs are detached and copied, so the caller's tensors are never
        aliased by (or updated through) the registered parameters.
        """
        super().__init__()

        # Activations and their inverses (raw <-> post-activation).
        self.scaling_activation = torch.exp
        self.scaling_inverse_activation = torch.log
        self.covariance_activation = build_covariance
        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = torch.logit
        self.rotation_activation = F.normalize

        means = means.detach().clone()

        harmonics = harmonics.detach().clone()  # [G, 3, d_sh]
        d_sh = harmonics.shape[-1]
        # Make sh0 / shN independent, contiguous tensors; plain slices of the
        # same tensor would make the two Parameters share storage.
        sh0 = harmonics[..., :1].contiguous()  # [G, 3, 1]
        # d_sh == 1 means sh_degree 0: no higher-order coefficients.
        shN = harmonics[..., 1:].contiguous() if d_sh > 1 else None  # [G, 3, d_sh-1]

        # Optimize opacity in logit space; eps clamps the input away from 0/1
        # so the logit stays finite. logit returns a new tensor (no aliasing).
        opacities_raw = self.inverse_opacity_activation(opacities.detach(), eps=1e-6)
        # Optimize scale in log space (log returns a new tensor).
        scales_raw = self.scaling_inverse_activation(scales.detach())
        # Rotations in xyzw (scalar last)
        # remember to convert to wxyz (scalar first) before rendering and saving to ply
        rotations_unnorm = rotations_unnorm.detach().clone()

        # Registration order determines named_parameters()/state_dict() order.
        raw_params = {
            "opacities_raw": opacities_raw,
            "scales_raw": scales_raw,
            "means": means,
            "rotations_unnorm": rotations_unnorm,
            "sh0": sh0,
        }
        if shN is not None:
            raw_params["shN"] = shN
        for name, value in raw_params.items():
            # nn.Parameter defaults to requires_grad=True.
            setattr(self, name, nn.Parameter(value))

        for name, param in self.named_parameters():
            # min()/max() raise on empty tensors (e.g. zero Gaussians).
            if param.numel() > 0:
                value_range = f"min: {param.min()}, max: {param.max()}"
            else:
                value_range = "min: n/a, max: n/a"
            print(
                f"Registered parameter: {name}, shape: {param.shape}, dtype: {param.dtype}, "
                f"{value_range}, requires_grad: {param.requires_grad}"
            )

    def scales(self) -> Tensor:
        """Return post-activation scales, ``exp(scales_raw)``."""
        return self.scaling_activation(self.scales_raw)

    def opacities(self) -> Tensor:
        """Return post-activation opacities, ``sigmoid(opacities_raw)``."""
        return self.opacity_activation(self.opacities_raw)

    def rotations(self) -> Tensor:
        """Return ``rotations_unnorm`` L2-normalized along the last axis."""
        return self.rotation_activation(self.rotations_unnorm, dim=-1)

    def harmonics(self) -> Tensor:
        """Return SH coefficients, ``sh0`` concatenated with ``shN`` when present.

        Result is [G, 3, d_sh].
        """
        shN = getattr(self, "shN", None)
        if shN is None:
            return self.sh0
        return torch.cat([self.sh0, shN], dim=-1)

    def covariances(self) -> Tensor:
        """Return covariances built from the current scales and rotations."""
        # scales/rotations are methods: call them so tensors (not bound
        # methods) reach build_covariance.
        rotation_xyzw = self.rotations()
        return self.covariance_activation(self.scales(), rotation_xyzw)  # [G, 3, 3]

    def reset_opacity(self, optimizer=None):
        """Clamp every opacity to at most 0.01, in place.

        ``opacities_raw`` is overwritten in place, so the Parameter object (and
        therefore its membership in the optimizer's param groups) is kept.

        Args:
            optimizer: optional torch optimizer. If given, every per-parameter
                state tensor it holds for ``opacities_raw`` with the same shape
                as the parameter (e.g. Adam moment estimates) is zeroed, because
                it describes the pre-reset values.
        """
        with torch.no_grad():
            opacities_old = self.opacity_activation(self.opacities_raw)
            opacities_raw_new = self.inverse_opacity_activation(
                torch.clamp(opacities_old, max=0.01), eps=1e-6
            )
            self.opacities_raw.copy_(opacities_raw_new)

        if optimizer is None:
            return
        # .get avoids inserting an empty entry into the optimizer's defaultdict.
        state = optimizer.state.get(self.opacities_raw)
        if not state:
            return
        for value in state.values():
            # Shape check skips scalar bookkeeping such as Adam's "step".
            if torch.is_tensor(value) and value.shape == self.opacities_raw.shape:
                value.zero_()
def gaussians2module(gaussians: Gaussians, device: torch.device) -> GaussiansModule:
    """Convert a batch-size-1 ``Gaussians`` into a trainable ``GaussiansModule``.

    The single batch element of each attribute is moved to ``device`` and passed
    to ``GaussiansModule``, which copies it into its own parameters.

    Raises:
        ValueError: if the leading (batch) dimension of ``gaussians.means`` is
            not 1.
    """
    batch_size = gaussians.means.shape[0]
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if batch_size != 1:
        raise ValueError(
            f"Batch size > 1 not supported for post-processing (got {batch_size})"
        )

    return GaussiansModule(
        means=gaussians.means[0].to(device),
        harmonics=gaussians.harmonics[0].to(device),
        opacities=gaussians.opacities[0].to(device),
        scales=gaussians.scales[0].to(device),
        rotations_unnorm=gaussians.rotations_unnorm[0].to(device),
    )
def module2gaussians(gaussian_module: GaussiansModule) -> Gaussians:
    """Convert a ``GaussiansModule`` back into a ``Gaussians`` with batch size 1.

    Post-activation values come from the module's accessor methods. Every
    tensor gains a leading batch dimension of 1. Tensors are not detached, so
    they stay connected to the module's parameters for autograd.
    """
    return Gaussians(
        means=gaussian_module.means.unsqueeze(0),  # [1, G, 3]
        # The accessors are methods, so they must be called to get tensors.
        covariances=gaussian_module.covariances().unsqueeze(0),  # [1, G, 3, 3]
        harmonics=gaussian_module.harmonics().unsqueeze(0),  # [1, G, 3, d_sh]
        opacities=gaussian_module.opacities().unsqueeze(0),  # [1, G]
        scales=gaussian_module.scales().unsqueeze(0),  # [1, G, 3]
        rotations=gaussian_module.rotations().unsqueeze(0),  # [1, G, 4]
        # Pass the raw (unnormalized) parameter, not the normalized rotations.
        rotations_unnorm=gaussian_module.rotations_unnorm.unsqueeze(0),  # [1, G, 4]
    )