from dataclasses import dataclass
from typing import Literal, Optional
from pathlib import Path

import torch
import torch.nn.functional as F

from optgs.dataset.data_types import BatchedViews
from optgs.scene_trainer.common.gaussian_adapter import build_covariance
from optgs.model.types import Gaussians
from optgs.experimental.initializers_utils import knn, points_to_gaussians
from optgs.scene_trainer.initializer.initializer import (
    NonlearnedInitializer,
    InitializerOutput,
    InitializerCfg,
    NonlearnedInitializerCfg,
)
from optgs.dataset.camera_datasets.camera import get_scene_scale


@dataclass
class InitializerRandomCfg(NonlearnedInitializerCfg):
    """Configuration for :class:`InitializerRandom`."""

    name: Literal["random"]
    # Number of points sampled for initialization.
    init_num_pts: int
    # Half-width of the origin-centered sampling cube, as a multiple of the
    # scene scale computed from the context cameras.
    init_extent: float
    # Multiplier applied to the kNN-derived per-point scales.
    scaling_factor: float
    # Opacity given to every sampled point (passed to points_to_gaussians).
    init_opacity: float
    # Spherical-harmonics degree of the emitted Gaussians.
    sh_degree: int

    def get_gaussian_param_num(self) -> int:
        """Return the number of parameters counted per Gaussian.

        The count is 3 + 4 + 3 * sh_d + 2 + 1, where sh_d = (sh_degree + 1)^2.
        """
        # Presumably 3 position + 4 quaternion + 3 * sh_d SH coefficients.
        # NOTE(review): confirm what the trailing "2 + 1" terms count
        # (standard 3D-GS would use 3 scale + 1 opacity).
        return 3 + 4 + 3 * self.get_sh_d() + 2 + 1

    def get_sh_d(self) -> int:
        """Return the number of SH coefficients per channel, (sh_degree + 1)^2."""
        return (self.sh_degree + 1) ** 2


class InitializerRandom(NonlearnedInitializer[InitializerRandomCfg]):
    """Non-learned initializer that places Gaussians at random positions.

    Positions are drawn uniformly from an origin-centered axis-aligned cube of
    half-width ``init_extent * scene_scale``, colors are uniform random, and
    each Gaussian gets the same scale on all three axes, derived from the
    distances to its 3 nearest neighbors.
    """

    def __init__(self, cfg: InitializerRandomCfg) -> None:
        super().__init__(cfg)

    def forward(self, context: BatchedViews, **kwargs) -> InitializerOutput:
        """Build a randomly initialized set of Gaussians for one scene.

        Args:
            context: Batched context views; only ``context["extrinsics"]`` is
                read, and its leading (batch) dimension must be 1.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            An ``InitializerOutput`` whose ``gaussians`` carry a leading batch
            dimension of 1; ``features`` and ``depths`` are ``None``.

        Raises:
            ValueError: If the context batch size is not 1.
        """
        device = context["extrinsics"].device
        scene_scale = self._compute_scene_scale(context)
        points_dict = self._sample_points(scene_scale)
        gaussians = self._points_to_batched_gaussians(points_dict, device)
        return InitializerOutput(gaussians=gaussians, features=None, depths=None)

    @staticmethod
    def _compute_scene_scale(context: BatchedViews):
        """Return the scene scale derived from the single-batch context cameras."""
        # Detach so .numpy() also works if the extrinsics track gradients.
        camtoworlds = context["extrinsics"].detach().cpu().numpy()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if camtoworlds.shape[0] != 1:
            raise ValueError("Batch size > 1 not supported in random initializer")
        return get_scene_scale(camtoworlds.squeeze(0))

    def _sample_points(self, scene_scale) -> dict:
        """Sample random positions/colors and derive per-point scales and opacities."""
        num_pts = self.cfg.init_num_pts
        half_width = self.cfg.init_extent * scene_scale
        # Uniform in [-half_width, half_width] along each axis. Positions are
        # drawn before colors to keep the RNG consumption order stable.
        xyz = half_width * (torch.rand((num_pts, 3)) * 2 - 1)
        rgbs = torch.rand((num_pts, 3))

        # Size each Gaussian by the root-mean-square distance to its 3 nearest
        # neighbors. Column 0 of the kNN result is skipped (presumably the
        # point itself).
        dist2_avg = (knn(xyz, 4)[:, 1:] ** 2).mean(dim=-1)  # [N]
        scales = torch.sqrt(dist2_avg).unsqueeze(-1).repeat(1, 3)  # [N, 3]
        scales = scales * self.cfg.scaling_factor

        return {
            "xyz": xyz,
            "rgb": rgbs,
            "scales": scales,
            "opacities": torch.full((num_pts,), self.cfg.init_opacity),
        }

    def _points_to_batched_gaussians(self, points_dict: dict, device) -> Gaussians:
        """Convert point attributes into a ``Gaussians`` with batch dimension 1."""
        # Pre-activation values, placed on `device` by points_to_gaussians.
        raw = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device)
        means = raw["xyz"]
        harmonics = torch.cat([raw["sh0"], raw["shN"]], dim=1)  # [N, sh_d, 3]
        harmonics = harmonics.permute(0, 2, 1)  # [N, 3, sh_d]
        rotations_unnorm = raw["rotations_unnorm"]

        # Post-activation values.
        opacities = torch.sigmoid(raw["opacities_raw"])
        scales = torch.exp(raw["scales_raw"])
        rotations = F.normalize(rotations_unnorm, dim=-1)
        covariances = build_covariance(scale=scales, rotation_xyzw=rotations)

        return Gaussians(
            means=means.unsqueeze(0),
            covariances=covariances.unsqueeze(0),
            harmonics=harmonics.unsqueeze(0),  # [1, N, C, sh_d]
            opacities=opacities.unsqueeze(0),
            scales=scales.unsqueeze(0),
            rotations=rotations.unsqueeze(0),
            rotations_unnorm=rotations_unnorm.unsqueeze(0),
        )