Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Literal, Optional | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from optgs.experimental.initializers_utils import knn, points_to_gaussians | |
| from optgs.model.types import Gaussians | |
| from optgs.scene_trainer.common.gaussian_adapter import build_covariance | |
| from optgs.scene_trainer.initializer.initializer import NonlearnedInitializer, InitializerOutput, NonlearnedInitializerCfg | |
@dataclass
class InitializerPointcloudCfg(NonlearnedInitializerCfg):
    """Configuration for the point-cloud based Gaussian initializer.

    The fields are declared as dataclass fields so that they exist as instance
    attributes; the initializer reads them as ``cfg.path``, ``cfg.sh_degree``
    and so on.
    """

    name: Literal["pointcloud"]
    # Directory containing one ``<scene_id>.ply`` file per scene.
    path: Path
    # Multiplier applied to the KNN-derived per-point scales.
    scaling_factor: float
    # Opacity assigned to every point loaded from the point cloud.
    init_opacity: float
    # Spherical-harmonics degree; (sh_degree + 1) ** 2 coefficients per channel.
    sh_degree: int
    # When true, points whose RGB channels sum to zero are discarded.
    filter_zero_rgb: bool
    # 4x4 world transform applied to point cloud positions.
    # Needed when the PLY is in a different coordinate system than the camera poses.
    # For ScanNet++/NeRFstudio: the PLY is in COLMAP space while cameras are in
    # NeRFstudio space. The transform is (x,y,z) -> (y,x,-z), i.e.:
    # [[0,1,0,0],[1,0,0,0],[0,0,-1,0],[0,0,0,1]]
    # Set to null to skip the transform.
    world_transform: Optional[list]

    def get_gaussian_param_num(self):
        """Return the number of scalar parameters stored per Gaussian.

        The total is ``3 + 4 + 3 * sh_d + 2 + 1`` where ``sh_d`` is the value
        returned by :meth:`get_sh_d`.
        NOTE(review): presumably position (3), rotation quaternion (4), SH
        coefficients (3 * sh_d) and opacity (1); the meaning of the ``2`` term
        is not visible in this file - confirm against the consumer.
        """
        return 3 + 4 + 3 * self.get_sh_d() + 2 + 1

    def get_sh_d(self):
        """Return the number of SH coefficients per color channel."""
        return (self.sh_degree + 1) ** 2
class InitializerPointcloud(NonlearnedInitializer[InitializerPointcloudCfg]):
    """Non-learned initializer that builds Gaussians from a per-scene PLY file.

    The point cloud ``<cfg.path>/<scene>.ply`` is loaded, optionally transformed
    and filtered, randomly subsampled, and converted to a ``Gaussians`` object
    whose scales come from nearest-neighbour distances.
    """

    def __init__(self, cfg: InitializerPointcloudCfg) -> None:
        super().__init__(cfg)

    @staticmethod
    def _load_ply(ply_path: Path) -> tuple[np.ndarray, np.ndarray]:
        """Load an Open3D-style binary PLY file.

        The vertex records are assumed to be exactly three little-endian
        doubles (x, y, z) followed by three unsigned bytes (r, g, b). The
        header is checked against that assumption so that a file with a
        different layout fails loudly instead of being decoded as garbage.

        Args:
            ply_path: Path of the PLY file to read.

        Returns:
            xyz: [N, 3] float32 array of 3D positions
            rgb: [N, 3] uint8 array of colors

        Raises:
            ValueError: If the header has no ``end_header`` line, declares a
                format other than ``binary_little_endian``, declares a vertex
                layout other than the one described above, or if the file holds
                fewer vertex bytes than the header announces.
        """
        vertex_dtype = np.dtype([
            ("x", "<f8"), ("y", "<f8"), ("z", "<f8"),
            ("r", "u1"), ("g", "u1"), ("b", "u1"),
        ])
        # Both PLY spellings of the two scalar types used by the fixed layout.
        type_codes = {"double": "f8", "float64": "f8", "uchar": "u1", "uint8": "u1"}
        expected_layout = ["f8"] * 3 + ["u1"] * 3

        num_vertices = 0
        declared_layout: list[str] = []
        in_vertex_element = False
        with open(ply_path, "rb") as f:
            while True:
                raw_line = f.readline()
                if not raw_line:
                    # readline() returns b"" forever at EOF; without this check
                    # a file lacking "end_header" would loop endlessly.
                    raise ValueError(f"PLY file {ply_path} has no 'end_header' line.")
                tokens = raw_line.decode("ascii", errors="replace").split()
                if not tokens:
                    continue
                if tokens == ["end_header"]:
                    break
                keyword = tokens[0]
                if keyword == "format":
                    if len(tokens) > 1 and tokens[1] != "binary_little_endian":
                        raise ValueError(
                            f"PLY file {ply_path} has unsupported format '{tokens[1]}'; "
                            "only binary_little_endian is supported."
                        )
                elif keyword == "element":
                    in_vertex_element = len(tokens) > 2 and tokens[1] == "vertex"
                    if in_vertex_element:
                        num_vertices = int(tokens[-1])
                elif keyword == "property" and in_vertex_element and len(tokens) > 1:
                    declared_layout.append(type_codes.get(tokens[1], tokens[1]))

            if declared_layout and declared_layout != expected_layout:
                raise ValueError(
                    f"PLY file {ply_path} has an unsupported vertex layout "
                    f"{declared_layout}; expected double x/y/z followed by uchar r/g/b."
                )

            expected_bytes = num_vertices * vertex_dtype.itemsize
            payload = f.read(expected_bytes)

        if len(payload) != expected_bytes:
            raise ValueError(
                f"PLY file {ply_path} is truncated: expected {expected_bytes} "
                f"bytes of vertex data, found {len(payload)}."
            )

        data = np.frombuffer(payload, dtype=vertex_dtype)
        xyz = np.stack([data["x"], data["y"], data["z"]], axis=1).astype(np.float32)
        rgb = np.stack([data["r"], data["g"], data["b"]], axis=1)
        return xyz, rgb

    def _resolve_ply_path(self, kwargs: dict) -> Path:
        """Return the PLY path for the scene named in ``kwargs["scene"]``.

        ``kwargs["scene"]`` may be a plain string or a length-1 sequence of
        names.

        Raises:
            ValueError: If no scene is given, more than one scene is given, or
                the resulting file does not exist.
        """
        if "scene" not in kwargs:
            raise ValueError("Scene name is required for pointcloud initializer.")
        scene_name = kwargs["scene"]
        if not isinstance(scene_name, str):
            # Explicit raise instead of assert: asserts vanish under ``python -O``.
            if len(scene_name) != 1:
                raise ValueError(f"Only single scene initialization supported. {scene_name}")
            scene_name = scene_name[0]
        ply_path = Path(self.cfg.path) / f"{scene_name}.ply"
        if not ply_path.exists():
            raise ValueError(f"PLY file {ply_path} does not exist.")
        return ply_path

    @staticmethod
    def _random_subset(
        xyz: torch.Tensor, rgbs: torch.Tensor, count: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Keep ``count`` randomly chosen points; no-op if there are not more than ``count``."""
        if xyz.shape[0] <= count:
            return xyz, rgbs
        keep = torch.randperm(xyz.shape[0], device=xyz.device)[:count]
        return xyz[keep], rgbs[keep]

    def forward(
        self,
        context,
        visualization_dump: Optional[dict] = None,
        device: Optional[torch.device] = None,
        **kwargs
    ) -> InitializerOutput:
        """Build the initial Gaussians for one scene from its PLY point cloud.

        Args:
            context: Not used by this initializer.
            visualization_dump: Not used by this initializer.
            device: Device on which the tensors are created.
            **kwargs: Must contain ``scene``: the scene name, either as a string
                or as a length-1 sequence of strings.

        Returns:
            An ``InitializerOutput`` whose ``gaussians`` carry a leading batch
            dimension of size 1; ``features`` and ``depths`` are ``None``.

        Raises:
            ValueError: If the scene name is missing or not unique, the PLY file
                is missing or malformed, or ``cfg.world_transform`` is not 4x4.
        """
        ply_path = self._resolve_ply_path(kwargs)
        print(f"Loading point cloud from {ply_path}")

        # Load PLY
        points_xyz, points_rgb = self._load_ply(ply_path)
        print(f"Loaded {points_xyz.shape[0]} points.")
        xyz = torch.from_numpy(points_xyz).float().to(device)
        rgbs = torch.from_numpy(points_rgb / 255.0).float().to(device)

        # Apply world transform to align point cloud with camera coordinate system
        if self.cfg.world_transform is not None:
            T = torch.tensor(self.cfg.world_transform, dtype=torch.float32, device=device)
            if T.shape != (4, 4):
                raise ValueError(
                    f"world_transform must be a 4x4 matrix, got shape {tuple(T.shape)}."
                )
            # Transform: new_xyz = (T @ [xyz, 1])[:3]
            xyz_h = torch.cat([xyz, torch.ones_like(xyz[:, :1])], dim=-1)  # [N, 4]
            xyz = (T @ xyz_h.T)[:3].T  # [N, 3]

        # Filter zero-RGB points
        if self.cfg.filter_zero_rgb:
            valid_mask = rgbs.sum(dim=-1) > 0
            xyz = xyz[valid_mask]
            rgbs = rgbs[valid_mask]

        # -- Step 1: subsampling augmentation --
        if self.training:
            min_sub = self.cfg.train_min_gaussians_subsample
            max_sub = self.cfg.train_max_gaussians_subsample
            fixed_num = self.cfg.train_fixed_gaussians_num
        else:
            min_sub = self.cfg.eval_min_gaussians_subsample
            max_sub = self.cfg.eval_max_gaussians_subsample
            fixed_num = self.cfg.eval_fixed_gaussians_num
        if min_sub is not None or max_sub is not None:
            target_count = self._sample_num_gaussians(xyz.shape[0], min_sub, max_sub)
            xyz, rgbs = self._random_subset(xyz, rgbs, target_count)

        # -- Step 2: subsample to fixed count (for DDP consistency) --
        if fixed_num is not None:
            xyz, rgbs = self._random_subset(xyz, rgbs, fixed_num)

        # KNN -> scales: RMS distance over the neighbour columns returned by knn
        # (column 0 is skipped).
        dist2_avg = (knn(xyz, 4)[:, 1:] ** 2).mean(dim=-1)  # [N,]
        dist_avg = torch.sqrt(dist2_avg)
        scales = dist_avg.unsqueeze(-1).repeat(1, 3)  # [N, 3]
        # Created on the same device/dtype as the points so that the padding and
        # conversion below never mix CPU and accelerator tensors.
        opacities = torch.full(
            (xyz.shape[0],), self.cfg.init_opacity, dtype=xyz.dtype, device=xyz.device
        )

        # Pad to fixed size for distributed training
        if self.training and fixed_num is not None:
            current_num = xyz.shape[0]
            if current_num < fixed_num:
                pad_size = fixed_num - current_num
                xyz = F.pad(xyz, (0, 0, 0, pad_size), mode='constant', value=0.0)
                rgbs = F.pad(rgbs, (0, 0, 0, pad_size), mode='constant', value=0.0)
                # Padded entries get near-zero scale and opacity.
                scales = F.pad(scales, (0, 0, 0, pad_size), mode='constant', value=1e-10)
                opacities = F.pad(opacities, (0, pad_size), mode='constant', value=1e-10)

        points_dict = {
            "xyz": xyz,
            "rgb": rgbs,
            "scales": scales * self.cfg.scaling_factor,
            "opacities": opacities,
        }

        # Convert to Gaussian representation
        gaussians_dict = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device)
        means = gaussians_dict["xyz"]
        harmonics = torch.cat([gaussians_dict["sh0"], gaussians_dict["shN"]], dim=1)  # [N, sh_d, 3]
        harmonics = harmonics.permute(0, 2, 1)  # [N, 3, sh_d]
        rotations_unnorm = gaussians_dict["rotations_unnorm"]
        opacities = torch.sigmoid(gaussians_dict["opacities_raw"])
        scales = torch.exp(gaussians_dict["scales_raw"])
        rotations = F.normalize(rotations_unnorm, dim=-1)
        covariances = build_covariance(scale=scales, rotation_xyzw=rotations)

        # Every tensor gets a leading batch dimension of size 1.
        gaussians = Gaussians(
            means=means.unsqueeze(0),
            covariances=covariances.unsqueeze(0),
            harmonics=harmonics.unsqueeze(0),
            opacities=opacities.unsqueeze(0),
            scales=scales.unsqueeze(0),
            rotations=rotations.unsqueeze(0),
            rotations_unnorm=rotations_unnorm.unsqueeze(0),
        )
        return InitializerOutput(
            gaussians=gaussians,
            features=None,
            depths=None,
        )

    @staticmethod
    def _sample_num_gaussians(available: int, min_sub: int | float | None, max_sub: int | float | None) -> int:
        """Sample a target Gaussian count from the [min_sub, max_sub] range.

        Args:
            available: Number of points currently available.
            min_sub: Lower bound; an ``int`` is an absolute count, a ``float``
                is a ratio of ``available``. ``None`` means "same as max_sub".
            max_sub: Upper bound, interpreted like ``min_sub``. ``None`` means
                "same as min_sub".

        Returns:
            The sampled count, never larger than ``available``. When both
            bounds are ``None`` no subsampling is requested and ``available``
            is returned.
        """
        if min_sub is None and max_sub is None:
            return available
        if min_sub is None:
            min_sub = max_sub
        if max_sub is None:
            max_sub = min_sub
        if isinstance(min_sub, int):
            # randint's upper bound is exclusive, hence the + 1.
            target = torch.randint(min_sub, max_sub + 1, (1,)).item()
        else:  # float -> ratio of available
            ratio = torch.empty(1).uniform_(min_sub, max_sub).item()
            target = int(available * ratio)
        return min(target, available)