from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional

import numpy as np
import torch
import torch.nn.functional as F

from optgs.experimental.initializers_utils import knn, points_to_gaussians
from optgs.model.types import Gaussians
from optgs.scene_trainer.common.gaussian_adapter import build_covariance
from optgs.scene_trainer.initializer.initializer import NonlearnedInitializer, InitializerOutput, NonlearnedInitializerCfg


@dataclass
class InitializerPointcloudCfg(NonlearnedInitializerCfg):
    """Configuration for initializing Gaussians from a per-scene .ply point cloud."""

    name: Literal["pointcloud"]
    path: Path  # Directory containing .ply files (one "<scene>.ply" file per scene)
    # Multiplier applied to the KNN-derived per-point scales.
    scaling_factor: float
    # Value every point's opacity is filled with before conversion by
    # points_to_gaussians (presumably a probability in (0, 1) — verify).
    init_opacity: float
    # Spherical-harmonics degree; (sh_degree + 1) ** 2 coefficients per channel.
    sh_degree: int
    # If true, points whose colour is exactly black (R + G + B == 0) are dropped.
    filter_zero_rgb: bool
    # 4x4 world transform applied to point cloud positions.
    # Needed when the PLY is in a different coordinate system than the camera poses.
    # For ScanNet++/NeRFstudio: the PLY is in COLMAP space while cameras are in
    # NeRFstudio space. The transform is (x,y,z) -> (y,x,-z), i.e.:
    #   [[0,1,0,0],[1,0,0,0],[0,0,-1,0],[0,0,0,1]]
    # Set to null to skip.
    world_transform: Optional[list]

    def get_gaussian_param_num(self) -> int:
        """Return the number of scalar parameters per Gaussian.

        Computed as 3 + 4 + 3 * sh_d + 2 + 1 with sh_d = (sh_degree + 1) ** 2.
        """
        sh_d = (self.sh_degree + 1) ** 2
        # The grouping suggests 3 (mean) + 4 (quaternion) + 3 * sh_d (SH colour);
        # the meaning of the trailing "2 + 1" is not established in this file —
        # TODO confirm (note scales elsewhere in this file are 3-D).
        return 3 + 4 + 3 * sh_d + 2 + 1

    def get_sh_d(self) -> int:
        """Return the number of SH coefficients per colour channel."""
        return (self.sh_degree + 1) ** 2


class InitializerPointcloud(NonlearnedInitializer[InitializerPointcloudCfg]):
    """Non-learned initializer that builds Gaussians from a stored point cloud.

    Positions and colours come from ``cfg.path / f"{scene}.ply"``. Each point gets
    an isotropic scale derived from the distances to its nearest neighbours.
    """

    def __init__(self, cfg: InitializerPointcloudCfg) -> None:
        super().__init__(cfg)

    @staticmethod
    def _load_ply(ply_path: Path) -> tuple[np.ndarray, np.ndarray]:
        """Load Open3D binary PLY file.

        Returns:
            xyz: [N, 3] float32 array of 3D positions
            rgb: [N, 3] uint8 array of colors
        """
        with open(ply_path, "rb") as f:
            num_vertices = 0
            # Scan the ASCII header line by line, remembering the vertex count,
            # until the "end_header" marker; binary vertex data follows it.
            # NOTE(review): at EOF readline() returns b"", so a file without
            # "end_header" makes this loop spin forever; a non-ASCII header byte
            # raises UnicodeDecodeError. Consider breaking on an empty read.
            while True:
                line = f.readline().decode("ascii").strip()
                if line.startswith("element vertex"):
                    num_vertices = int(line.split()[-1])
                if line == "end_header":
                    break
            dtype = np.dtype([
                # NOTE(review): SOURCE IS CORRUPTED HERE. The text jumps from the
                # opening quote of the first dtype entry straight to
                # "InitializerOutput:". The rest of the vertex dtype, the vertex
                # read and return of _load_ply, and the `def` line of the method
                # that returns InitializerOutput are all missing. It looks as if an
                # angle-bracket span was removed, from the less-than sign of a
                # little-endian dtype code up to the greater-than sign of the
                # method's return-annotation arrow. Restore from version control.
                ("x", " InitializerOutput:
        # ---- Body of the method that returns InitializerOutput (its signature is
        # lost above). It uses `device`, `kwargs` and `self.training`. It loads the
        # scene's PLY and optionally transforms, filters, subsamples and pads the
        # points. Scales come from KNN distances, and the result is returned as a
        # batch-of-one Gaussians object.

        # Resolve PLY path
        # kwargs["scene"] is length-checked and then indexed with [0], so it is
        # presumably a one-element sequence of scene names — verify with caller.
        if "scene" in kwargs:
            scene_name = kwargs["scene"]
            # NOTE(review): `assert` is stripped under `python -O`. A multi-scene
            # batch would then silently use only the first scene; consider
            # raising ValueError instead.
            assert len(scene_name) == 1, f"Only single scene initialization supported. {scene_name}"
            scene_name = scene_name[0]
            ply_path = self.cfg.path / f"{scene_name}.ply"
        else:
            raise ValueError("Scene name is required for pointcloud initializer.")

        if not ply_path.exists():
            raise ValueError(f"PLY file {ply_path} does not exist.")

        print(f"Loading point cloud from {ply_path}")

        # Load PLY
        points_xyz, points_rgb = self._load_ply(ply_path)
        print(f"Loaded {points_xyz.shape[0]} points.")
        # Colours are rescaled from uint8 [0, 255] to [0, 1] floats on `device`.
        xyz = torch.from_numpy(points_xyz).float().to(device)
        rgbs = torch.from_numpy(points_rgb / 255.0).float().to(device)

        # Apply world transform to align point cloud with camera coordinate system
        if self.cfg.world_transform is not None:
            T = torch.tensor(self.cfg.world_transform, dtype=torch.float32, device=device)
            # Transform: new_xyz = (T @ [xyz, 1])[:3]
            # The w row is discarded without perspective division, so only
            # affine transforms are handled correctly.
            xyz_h = torch.cat([xyz, torch.ones(xyz.shape[0], 1, device=device)], dim=-1)  # [N, 4]
            xyz = (T @ xyz_h.T)[:3].T  # [N, 3]

        # Filter zero-RGB points
        # (keeps only points whose R + G + B is strictly positive)
        if self.cfg.filter_zero_rgb:
            valid_mask = rgbs.sum(dim=-1) > 0
            xyz = xyz[valid_mask]
            rgbs = rgbs[valid_mask]

        # ── Step 1: subsampling augmentation ─────────────────────────────────────
        # Train and eval read separate bounds. When either bound is set, a random
        # target count is drawn and a uniformly random subset of that size is kept.
        min_sub = self.cfg.train_min_gaussians_subsample if self.training else self.cfg.eval_min_gaussians_subsample
        max_sub = self.cfg.train_max_gaussians_subsample if self.training else self.cfg.eval_max_gaussians_subsample
        if min_sub is not None or max_sub is not None:
            target_count = self._sample_num_gaussians(xyz.shape[0], min_sub, max_sub)
            if xyz.shape[0] > target_count:
                indices = torch.randperm(xyz.shape[0], device=xyz.device)[:target_count]
                xyz = xyz[indices]
                rgbs = rgbs[indices]

        # ── Step 2: subsample to fixed count (for DDP consistency) ────────────
        # This only caps the count. Clouds smaller than fixed_num are padded
        # further below, and only in training mode.
        fixed_num = self.cfg.train_fixed_gaussians_num if self.training else self.cfg.eval_fixed_gaussians_num
        if fixed_num is not None and xyz.shape[0] > fixed_num:
            indices = torch.randperm(xyz.shape[0], device=xyz.device)[:fixed_num]
            xyz = xyz[indices]
            rgbs = rgbs[indices]

        # KNN → scales
        # Each point's isotropic scale is the RMS distance to neighbours 1..3 of
        # knn(xyz, 4). This assumes column 0 is the point itself (hence [:, 1:])
        # — TODO confirm against knn.
        # NOTE(review): duplicate points give a zero distance and hence a zero
        # scale, which becomes -inf if points_to_gaussians takes a log (the
        # reference 3DGS init clamps dist2 to a small minimum). Clouds with
        # fewer than 4 points, possible after filtering or subsampling, are not
        # handled here either.
        dist2_avg = (knn(xyz, 4)[:, 1:] ** 2).mean(dim=-1)  # [N,]
        dist_avg = torch.sqrt(dist2_avg)
        scales = dist_avg.unsqueeze(-1).repeat(1, 3)  # [N, 3]
        # NOTE(review): no device= is passed, so this tensor is allocated on the
        # default device while xyz/rgbs/scales live on `device`. This presumably
        # relies on points_to_gaussians(..., device=device) to move it — verify.
        opacities = torch.full((xyz.shape[0],), self.cfg.init_opacity)

        # Pad to fixed size for distributed training
        # Padding happens after KNN, so padded entries do not affect the scales.
        # Padded points sit at the origin with zero colour and a scale and
        # opacity of 1e-10. Eval mode never pads.
        if self.training and fixed_num is not None:
            current_num = xyz.shape[0]
            if current_num < fixed_num:
                pad_size = fixed_num - current_num
                xyz = F.pad(xyz, (0, 0, 0, pad_size), mode='constant', value=0.0)
                rgbs = F.pad(rgbs, (0, 0, 0, pad_size), mode='constant', value=0.0)
                scales = F.pad(scales, (0, 0, 0, pad_size), mode='constant', value=1e-10)
                opacities = F.pad(opacities, (0, pad_size), mode='constant', value=1e-10)

        points_dict = {
            "xyz": xyz,
            "rgb": rgbs,
            "scales": scales * self.cfg.scaling_factor,
            "opacities": opacities,
        }

        # Convert to Gaussian representation
        gaussians_dict = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device)
        means = gaussians_dict["xyz"]
        sh0 = gaussians_dict["sh0"]
        shN = gaussians_dict["shN"]
        harmonics = torch.cat([sh0, shN], dim=1)  # [N, sh_d, 3]
        harmonics = harmonics.permute(0, 2, 1)  # [N, 3, sh_d]
        rotations_unnorm = gaussians_dict["rotations_unnorm"]
        # The "*_raw" entries are activated here with sigmoid (opacity) and exp
        # (scale), so they are presumably logits and log-scales.
        opacities = torch.sigmoid(gaussians_dict["opacities_raw"])
        scales = torch.exp(gaussians_dict["scales_raw"])
        rotations = F.normalize(gaussians_dict["rotations_unnorm"], dim=-1)
        # NOTE(review): build_covariance is called with an xyzw-ordered keyword;
        # confirm that points_to_gaussians emits quaternions in that order.
        covariances = build_covariance(scale=scales, rotation_xyzw=rotations)

        # Every tensor gets a leading batch dimension of size 1.
        gaussians = Gaussians(
            means=means.unsqueeze(0),
            covariances=covariances.unsqueeze(0),
            harmonics=harmonics.unsqueeze(0),
            opacities=opacities.unsqueeze(0),
            scales=scales.unsqueeze(0),
            rotations=rotations.unsqueeze(0),
            rotations_unnorm=rotations_unnorm.unsqueeze(0),
        )

        return InitializerOutput(
            gaussians=gaussians,
            features=None,
            depths=None,
        )

    @staticmethod
    def _sample_num_gaussians(available: int, min_sub: int | float | None, max_sub: int | float | None) -> int:
        """Sample a target Gaussian count from the [min_sub, max_sub] range.

        If only one bound is given, it is used for both.

        - Integer bounds are absolute counts drawn uniformly from min_sub..max_sub
          inclusive.
        - Float bounds are fractions of ``available``, drawn uniformly between
          min_sub and max_sub and truncated to an int.

        The result is capped at ``available``.

        NOTE(review):
        - The int/float choice looks only at ``min_sub``. A mixed pair such as
          (100, 0.5) or (0.1, 100) hands incompatible types to torch.randint or
          to uniform_.
        - If both bounds are None, uniform_ receives None. Callers are expected
          to check first, as the caller in this class does.
        - A small ratio can yield 0, i.e. an empty cloud downstream.
        """
        if min_sub is None:
            min_sub = max_sub
        if max_sub is None:
            max_sub = min_sub
        if isinstance(min_sub, int):
            # torch.randint's upper bound is exclusive, hence the +1.
            target = torch.randint(min_sub, max_sub + 1, (1,)).item()
        else:
            # float → ratio of available
            ratio = torch.empty(1).uniform_(min_sub, max_sub).item()
            target = int(available * ratio)
        return min(target, available)