from dataclasses import dataclass
from typing import Literal, Optional
from optgs.dataset.data_types import BatchedViews
import numpy as np
import torch
import math
import torch.nn.functional as F
from pathlib import Path

from optgs.experimental.edgs.init import init_gaussians_with_corr
from optgs.experimental.initializers_utils import knn, points_to_gaussians
from optgs.model.types import Gaussians
from optgs.scene_trainer.common.gaussian_adapter import build_covariance
from optgs.scene_trainer.initializer.initializer import InitializerOutput, NonlearnedInitializer, NonlearnedInitializerCfg


@dataclass
class InitializerEdgsCfg(NonlearnedInitializerCfg):
    """Configuration of the EDGS (correspondence-based) Gaussian initializer."""

    name: Literal["edgs"]
    sh_degree: int
    init_opacity: float
    scaling_factor: float
    roma_model_type: str
    sample_init_gaussians: int  # if >0, randomly sample this many gaussians from the initialized set

    def get_gaussian_param_num(self):
        """Return the number of parameters per Gaussian: 3 + 4 + 3 * sh_d + 2 + 1."""
        # calculate the number of parameters per Gaussian
        sh_d = self.get_sh_d()
        # TODO Naama: check where this is used, and if it is needed
        init_gaussian_param_num = 3 + 4 + 3 * sh_d + 2 + 1
        return init_gaussian_param_num

    def get_sh_d(self):
        """Return the number of spherical-harmonics coefficients, (sh_degree + 1) ** 2."""
        sh_d = (self.sh_degree + 1) ** 2
        return sh_d


def _focal2fov(focal, pixels):
    """Convert a focal length (in pixels) and an image extent into a field of view in radians.

    `focal` may be a Python number or a single-element tensor (math.atan
    converts it to float).
    """
    return 2 * math.atan(pixels / (2 * focal))


def _get_projection_matrix(znear, zfar, fovX, fovY):
    """Build a 4x4 perspective projection matrix (CPU, float32) from clip planes and FoVs."""
    tan_half_fov_y = math.tan((fovY / 2))
    tan_half_fov_x = math.tan((fovX / 2))

    top = tan_half_fov_y * znear
    bottom = -top
    right = tan_half_fov_x * znear
    left = -right

    P = torch.zeros(4, 4)
    z_sign = 1.0

    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P


def _build_projection_matrices(intrinsics, h, w, device, dtype, znear=0.01, zfar=100.0):
    """Stack one transposed projection matrix per view into an [N, 4, 4] tensor.

    `intrinsics` holds pixel-space (un-normalized) 3x3 matrices; only the focal
    lengths at [0, 0] and [1, 1] are read. The result lives on `device` with
    `dtype` so it can be batch-multiplied with the world-to-camera matrices.
    """
    projections = []
    for intrinsic in intrinsics:
        fov_x = _focal2fov(intrinsic[0, 0], w)
        fov_y = _focal2fov(intrinsic[1, 1], h)
        proj = _get_projection_matrix(znear=znear, zfar=zfar, fovX=fov_x, fovY=fov_y).transpose(0, 1)
        # Follow the input device instead of a hard-coded `.cuda()`: the latter
        # breaks on CPU and on any GPU other than the current default one.
        projections.append(proj.to(device=device, dtype=dtype))
    return torch.stack(projections, dim=0)


class InitializerEdgs(NonlearnedInitializer[InitializerEdgsCfg]):
    """Non-learned initializer that builds Gaussians from `init_gaussians_with_corr`."""

    def __init__(self, cfg: InitializerEdgsCfg) -> None:
        super().__init__(cfg)

    def forward(
            self,
            context: BatchedViews,
            visualization_dump: Optional[dict] = None,
            cached_data_path: Optional[Path] = None,
            **kwargs
    ) -> InitializerOutput:
        """Initialize Gaussians for a single scene from its context views.

        Args:
            context: batched views with a leading batch dimension of 1; the
                "image", "extrinsics" (camera-to-world) and "intrinsics"
                (normalized by image size) entries are read.
            visualization_dump: unused here.
            cached_data_path: optional directory; when it contains
                "points_dict.pt" that file is loaded instead of recomputing the
                points, otherwise freshly computed points are saved there.
            **kwargs: ignored.

        Returns:
            InitializerOutput holding the Gaussians (batch dimension 1); features
            and depths are None.
        """
        device = context["extrinsics"].device

        # unpack context (batch_dim = 1)
        viewpoints_img = context["image"].squeeze(0)  # [N, 3, H, W]
        h, w = viewpoints_img.shape[2], viewpoints_img.shape[3]

        # poses
        viewpoints_c2w = context["extrinsics"].squeeze(0).clone()  # [N, 4, 4]
        camera_centers = viewpoints_c2w[..., :3, 3]
        viewpoints_w2c = torch.inverse(viewpoints_c2w)  # [N, 4, 4]
        # convert to column-major
        viewpoints_w2c = viewpoints_w2c.permute(0, 2, 1)

        # intrinsics
        viewpoints_intrinsics = context["intrinsics"].squeeze(0).clone()  # [N, 3, 3]
        # un-normalize intrinsics by multiplying by image size
        viewpoints_intrinsics[:, 0, :] *= w
        viewpoints_intrinsics[:, 1, :] *= h

        viewpoints_proj = _build_projection_matrices(
            viewpoints_intrinsics, h, w, device=device, dtype=viewpoints_w2c.dtype
        )  # [N, 4, 4]

        # compute full projection matrices
        viewpoints_full_proj = viewpoints_w2c.bmm(viewpoints_proj)  # [N, 4, 4]

        # check if points_dict is stored on disk already (cached)
        found_cached = False
        if cached_data_path is not None:
            print("Checking for cached points_dict at:", str(cached_data_path))
            cache_path = cached_data_path / "points_dict.pt"
            if cache_path.exists():
                points_dict = torch.load(cache_path)
                print("Loaded cached points_dict from:", str(cache_path))
                found_cached = True
            else:
                print("No cached points_dict found at:", str(cache_path))

        if not found_cached:
            # recompute points_dict
            _, _, points_dict = init_gaussians_with_corr(
                viewpoints_img=viewpoints_img,  # [N, 3, H, W]
                viewpoints_w2c=viewpoints_w2c,  # [N, 4, 4]
                viewpoints_proj=viewpoints_full_proj,  # [N, 4, 4]
                camera_centers=camera_centers,  # [N, 3]
                init_opacity=self.cfg.init_opacity,
                roma_model_type=self.cfg.roma_model_type,
                verbose=False
            )
            if cached_data_path is not None:
                print("Saving points_dict to cache at:", str(cache_path))
                cached_data_path.mkdir(parents=True, exist_ok=True)
                torch.save(points_dict, cache_path)

        # The cache stores unscaled values; the scaling factor is applied after
        # loading/saving so that it can change between runs.
        points_dict["scales"] *= self.cfg.scaling_factor

        # printing some stats
        for k, v in points_dict.items():
            print(f"points_dict[{k}]: shape={v.shape}, dtype={v.dtype}, min={v.min().item()}, max={v.max().item()}")

        # downsample if needed
        if self.cfg.sample_init_gaussians > 0:
            # randomly sample a subset of gaussians
            total_points = points_dict["xyz"].shape[0]
            sample_num = min(self.cfg.sample_init_gaussians, total_points)
            sampled_indices = torch.randperm(total_points)[:sample_num]
            points_dict = {k: v[sampled_indices] for k, v in points_dict.items()}
            print("Nr points after sampling:", points_dict["xyz"].shape[0])

        # pre-activation values on device
        gaussians_dict = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device)
        means = gaussians_dict["xyz"]
        sh0 = gaussians_dict["sh0"]
        shN = gaussians_dict["shN"]
        harmonics = torch.cat([sh0, shN], dim=1)  # [N, sh_d, 3]
        harmonics = harmonics.permute(0, 2, 1)  # [N, 3, sh_d]
        rotations_unnorm = gaussians_dict["rotations_unnorm"]

        # post-activation values
        opacities = torch.sigmoid(gaussians_dict["opacities_raw"])
        scales = torch.exp(gaussians_dict["scales_raw"])
        rotations = F.normalize(rotations_unnorm, dim=-1)
        covariances = build_covariance(scale=scales, rotation_xyzw=rotations)

        print("Nr gaussians initialized:", means.shape[0])

        # add the batch dimension (1) expected downstream
        gaussians = Gaussians(
            means=means.unsqueeze(0),
            covariances=covariances.unsqueeze(0),
            harmonics=harmonics.unsqueeze(0),  # [1, N, 3, sh_d]
            opacities=opacities.unsqueeze(0),
            scales=scales.unsqueeze(0),
            rotations=rotations.unsqueeze(0),
            rotations_unnorm=rotations_unnorm.unsqueeze(0),
        )

        return InitializerOutput(
            gaussians=gaussians,
            features=None,
            depths=None
        )