Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from typing import Literal, Optional | |
| from pathlib import Path | |
| import torch | |
| import torch.nn.functional as F | |
| from optgs.dataset.data_types import BatchedViews | |
| from optgs.scene_trainer.common.gaussian_adapter import build_covariance | |
| from optgs.model.types import Gaussians | |
| from optgs.experimental.initializers_utils import knn, points_to_gaussians | |
| from optgs.scene_trainer.initializer.initializer import NonlearnedInitializer, InitializerOutput, InitializerCfg, NonlearnedInitializerCfg | |
| from optgs.dataset.camera_datasets.camera import get_scene_scale | |
class InitializerRandomCfg(NonlearnedInitializerCfg):
    """Configuration for the random (non-learned) Gaussian initializer.

    NOTE(review): this file imports ``dataclass`` but no decorator is applied
    to this class; confirm whether the base class / config loader relies on
    one being present here.
    """

    # Discriminator tag selecting this initializer.
    name: Literal["random"]
    # Number of random points (and therefore Gaussians) to create.
    init_num_pts: int
    # Multiplier on the scene scale that sets the half-width of the sampling cube.
    init_extent: float
    # Multiplier applied to the neighbour-distance based initial scales.
    scaling_factor: float
    # Opacity value assigned to every point before conversion to Gaussians.
    init_opacity: float
    # Spherical-harmonics degree forwarded to the point-to-Gaussian conversion.
    sh_degree: int

    def get_sh_d(self):
        """Return the number of SH coefficients per channel: (sh_degree + 1)^2."""
        return (self.sh_degree + 1) ** 2

    def get_gaussian_param_num(self):
        """Return the number of parameters stored per Gaussian.

        The total is ``3 + 4 + 3 * sh_d + 2 + 1``.
        NOTE(review): presumably 3 = position, 4 = rotation quaternion and
        3 * sh_d = SH coefficients for three colour channels; what the
        trailing ``2`` and ``1`` stand for is not visible here -- confirm.
        """
        coeffs_per_channel = self.get_sh_d()
        return 3 + 4 + 3 * coeffs_per_channel + 2 + 1
class InitializerRandom(NonlearnedInitializer[InitializerRandomCfg]):
    """Non-learned initializer that seeds a scene with uniformly random Gaussians."""

    def __init__(self, cfg: InitializerRandomCfg) -> None:
        super().__init__(cfg)

    def forward(
        self,
        context: BatchedViews,
        **kwargs
    ) -> InitializerOutput:
        """Build a random set of Gaussians sized relative to the scene scale.

        Positions are drawn uniformly from a cube of half-width
        ``init_extent * scene_scale`` centred on the origin, colours are
        drawn uniformly from [0, 1), and every Gaussian starts isotropic with
        a size derived from the distances to its nearest neighbours.

        Args:
            context: Batched views; only ``context["extrinsics"]`` is read.
                Its leading (batch) dimension must be 1.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            An ``InitializerOutput`` holding the Gaussians with a leading
            batch dimension of 1, and ``features`` / ``depths`` set to None.

        Raises:
            AssertionError: If the batch dimension of the extrinsics is not 1.
        """
        cfg = self.cfg
        target_device = context["extrinsics"].device
        num_points = cfg.init_num_pts

        # Scene scale is derived from the camera poses of the single batch
        # element (the original note gives the pose array as [B, 4, 4]).
        cam_poses = context["extrinsics"].cpu().numpy()
        assert cam_poses.shape[0] == 1, "Batch size > 1 not supported in random initializer"
        scene_scale = get_scene_scale(cam_poses.squeeze(0))

        # Draw positions first and colours second so the random stream is
        # consumed in a fixed order.
        half_width = cfg.init_extent * scene_scale
        xyz = half_width * (torch.rand((num_points, 3)) * 2 - 1)
        rgb = torch.rand((num_points, 3))

        # Initial size: root-mean-square of the distances in columns 1..3 of
        # the 4-NN result (column 0 is dropped), replicated over three axes.
        neighbour_dists = knn(xyz, 4)[:, 1:]
        rms_dist = torch.sqrt((neighbour_dists ** 2).mean(dim=-1))  # [N,]
        init_scales = rms_dist.unsqueeze(-1).repeat(1, 3)  # [N, 3]
        init_scales *= cfg.scaling_factor

        points = {
            "xyz": xyz,
            "rgb": rgb,
            "scales": init_scales,
            "opacities": torch.full((xyz.shape[0],), cfg.init_opacity),
        }

        # Pre-activation parameters, moved onto the context's device.
        raw = points_to_gaussians(points, sh_degree=cfg.sh_degree, device=target_device)
        rotations_unnorm = raw["rotations_unnorm"]

        # Concatenate the sh0 and shN coefficients along dim 1, then swap the
        # last two axes: [N, sh_d, 3] -> [N, 3, sh_d].
        harmonics = torch.cat([raw["sh0"], raw["shN"]], dim=1).permute(0, 2, 1)

        # Post-activation values.
        opacities = torch.sigmoid(raw["opacities_raw"])
        scales = torch.exp(raw["scales_raw"])
        rotations = F.normalize(rotations_unnorm, dim=-1)
        covariances = build_covariance(scale=scales, rotation_xyzw=rotations)

        # Every field gets a leading batch dimension of size 1.
        gaussians = Gaussians(
            means=raw["xyz"].unsqueeze(0),
            covariances=covariances.unsqueeze(0),
            harmonics=harmonics.unsqueeze(0),  # [1, N, C, sh_d]
            opacities=opacities.unsqueeze(0),
            scales=scales.unsqueeze(0),
            rotations=rotations.unsqueeze(0),
            rotations_unnorm=rotations_unnorm.unsqueeze(0),
        )

        return InitializerOutput(gaussians=gaussians, features=None, depths=None)