from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Tuple
import os
import tempfile

import numpy as np
import torch
import torch.nn.functional as F
from plyfile import PlyData

from optgs.dataset.colmap.utils import Parser
from optgs.dataset.data_types import BatchedViews
from optgs.experimental.initializers_utils import knn, points_to_gaussians
from optgs.misc.general_utils import SkipBatchException
from optgs.model.types import Gaussians
from optgs.scene_trainer.common.gaussian_adapter import build_covariance
from optgs.scene_trainer.initializer.initializer import NonlearnedInitializer, InitializerOutput, NonlearnedInitializerCfg


@dataclass
class InitializerColmapCfg(NonlearnedInitializerCfg):
    """Configuration for initializing Gaussians from a COLMAP (or PLY) point cloud."""

    name: Literal["colmap"]
    path: Path  # COLMAP root; when a scene name is passed to forward, the scene dir is path / scene
    normalize_world_space: bool  # forwarded to Parser(normalize=...); also part of the cache file name
    scaling_factor: float  # multiplies the knn-derived initial scales
    init_opacity: float  # constant initial opacity; mean (gaussian) / upper bound (uniform) when randomized
    sh_degree: int  # spherical-harmonics degree; sh_d = (sh_degree + 1) ** 2
    dl3dv_settings: bool  # strip "dl3dv_" from scene names and map COLMAP points/poses into the dataset frame
    filter_zero_rgb: bool  # drop points whose RGB is exactly (0, 0, 0)
    randomize_opacity: bool
    randomize_opacity_distribution: Literal["uniform", "gaussian"]
    randomize_opacity_std: float  # Standard deviation for gaussian distribution
    randomize_opacity_min: float  # Minimum value for uniform distribution
    points3d_subdir: Optional[str]  # if set, overrides dl3dv_settings/default subdir logic
    points3d_ply_filename: Optional[str]  # if set, loads points from this PLY file (relative to scene dir) instead of COLMAP binary
    override_dataset_poses: bool  # if true, overrides the dataset poses with the COLMAP poses (after applying T_world transform)

    def get_gaussian_param_num(self):
        """Return the number of scalar parameters per Gaussian: ``3 + 4 + 3 * sh_d + 2 + 1``.

        NOTE(review): presumably 3 = mean, 4 = quaternion, 3 * sh_d = SH colour
        coefficients; the meaning of the trailing ``2 + 1`` is not visible here — confirm
        against the consumer of this count.
        """
        sh_d = self.get_sh_d()
        init_gaussian_param_num = 3 + 4 + 3 * sh_d + 2 + 1
        return init_gaussian_param_num

    def get_sh_d(self):
        """Return the number of SH coefficients per colour channel, ``(sh_degree + 1) ** 2``."""
        sh_d = (self.sh_degree + 1) ** 2
        return sh_d


class InitializerColmap(NonlearnedInitializer[InitializerColmapCfg]):
    """Non-learned initializer that turns a scene's COLMAP (or PLY) points into Gaussians.

    Points, colours and COLMAP camera poses are parsed once per scene and cached as a
    compact ``.npz`` inside the scene directory. ``forward`` then optionally aligns them
    to the dataset frame (DL3DV), subsamples / pads them, initializes knn-based scales
    and opacities, and returns a batch-of-one ``Gaussians``.
    """

    def __init__(self, cfg: InitializerColmapCfg) -> None:
        super().__init__(cfg)

    # ------------------------------------------------------------------ caching

    def _npz_path(self, datadir: Path) -> Path:
        """Return the cache file path for ``datadir``.

        The file name encodes every option that changes the cached content: world-space
        normalization, the PLY source file, and the points3d subdirectory.
        """
        suffix = "_norm" if self.cfg.normalize_world_space else ""
        if self.cfg.points3d_subdir is not None:
            # Different subdirs hold different reconstructions; without this tag a cache
            # written for one subdir would be silently reused for another.
            subdir_tag = str(self.cfg.points3d_subdir).replace("\\", "/").strip("/").replace("/", "_")
            suffix = f"_sub_{subdir_tag}{suffix}"
        if self.cfg.points3d_ply_filename is not None:
            ply_stem = Path(self.cfg.points3d_ply_filename).stem
            return datadir / f"colmap_points_cache_ply_{ply_stem}{suffix}.npz"
        return datadir / f"colmap_points_cache{suffix}.npz"

    def _load_colmap(self, datadir: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Load COLMAP points/colors/poses.

        On first access, parses the raw COLMAP binary files (or a PLY file when
        ``points3d_ply_filename`` is set) and saves a compact .npz next to the scene
        folder. On subsequent calls only the tiny .npz is loaded.

        Returns:
            ``(points, points_rgb, camtoworlds)`` as stored in / parsed for the cache.

        Raises:
            IOError: if ``points3d_ply_filename`` is set but the file does not exist.
        """
        npz_path = self._npz_path(datadir)
        if npz_path.exists():
            try:
                # np.load on an .npz returns an NpzFile that holds the file open; the
                # context manager closes it. Indexing reads each array fully into memory,
                # so the returned arrays stay valid after the file is closed.
                with np.load(npz_path) as data:
                    return data["points"], data["points_rgb"], data["camtoworlds"]
            except PermissionError:
                print(f"Warning: No read permission for cache {npz_path}. Attempting to delete and regenerate.")
                try:
                    os.unlink(npz_path)
                except Exception as del_e:
                    print(f"Warning: Could not delete {npz_path} ({del_e}). Will re-parse but cannot cache.")
            except Exception as e:
                print(f"Warning: Failed to load cache {npz_path} ({e}). Re-parsing COLMAP data.")

        # Always parse COLMAP cameras/images for the poses.
        parser = Parser(
            data_dir=str(datadir),
            factor=1,
            normalize=self.cfg.normalize_world_space,
            load_images=False,
            dl3dv_settings=False,
            points3d_subdir=self.cfg.points3d_subdir,
            verbose=False,
        )
        camtoworlds = parser.camtoworlds  # (M, 4, 4) float64

        if self.cfg.points3d_ply_filename is not None:
            # Load 3-D points from a PLY file located directly in the scene dir.
            points, points_rgb = self._read_ply_points(datadir / self.cfg.points3d_ply_filename)
        else:
            points = parser.points  # (N, 3) float32
            points_rgb = parser.points_rgb  # (N, 3) uint8

        self._save_cache(npz_path, points, points_rgb, camtoworlds)
        return points, points_rgb, camtoworlds

    @staticmethod
    def _read_ply_points(ply_path: Path) -> Tuple[np.ndarray, np.ndarray]:
        """Read vertex positions (float32) and colours (uint8) from a PLY file.

        Raises:
            IOError: if ``ply_path`` does not exist.
        """
        if not ply_path.exists():
            raise IOError(f"PLY file not found: {ply_path}")
        vertex = PlyData.read(str(ply_path))["vertex"]
        points = np.stack([np.asarray(vertex[k]) for k in ("x", "y", "z")], axis=1).astype(np.float32)
        # NOTE(review): a plain uint8 cast assumes 0-255 colours in the PLY; float colours
        # in [0, 1] would truncate to 0 — confirm for the PLY sources in use.
        points_rgb = np.stack([np.asarray(vertex[k]) for k in ("red", "green", "blue")], axis=1).astype(np.uint8)
        return points, points_rgb

    @staticmethod
    def _save_cache(npz_path: Path, points: np.ndarray, points_rgb: np.ndarray, camtoworlds: np.ndarray) -> None:
        """Best-effort atomic write of the cache; any failure only prints a warning."""
        # TODO(Patricia): fix permission denied
        tmp_path = ""
        try:
            # The temp file already ends in .npz (so savez does not append a suffix) and
            # lives in the destination dir, so os.replace stays on one filesystem.
            tmp_fd, tmp_path = tempfile.mkstemp(dir=npz_path.parent, suffix=".npz")
            os.close(tmp_fd)
            np.savez_compressed(tmp_path, points=points, points_rgb=points_rgb, camtoworlds=camtoworlds)
            os.chmod(tmp_path, 0o664)  # group-readable so other users can use this cache
            os.replace(tmp_path, npz_path)  # atomic on POSIX
        except Exception as e:
            if tmp_path and os.path.exists(tmp_path):
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass  # leftover temp file is harmless; don't let cleanup abort loading
            print(f"Warning: Failed to save COLMAP cache to {npz_path} ({e}). This may cause slow loading in the future.")

    # ------------------------------------------------------------------ forward

    def _align_to_dl3dv_frame(
        self,
        context: BatchedViews,
        target,
        camtoworlds: np.ndarray,
        xyz: torch.Tensor,
    ) -> torch.Tensor:
        """Map COLMAP points (and optionally poses) into the DL3DV dataset frame.

        Applies a fixed world transform (and an x-flip when the batch is flipped) to
        ``xyz``. When ``override_dataset_poses`` is set, also overwrites
        ``context['extrinsics']`` and ``target['extrinsics']`` in place with the
        transformed COLMAP poses selected by each batch's ``index``.
        Only the first scene of the batch is considered.
        """
        # In some configurations the batch is moved to the device later, so poses stay
        # on the batch's current device.
        batch_device = target["extrinsics"].device
        pose_dtype = context["extrinsics"].dtype

        # Hard coded for DL3DV COLMAP reconstructions. It was derived as
        # c2w_dataset[0] @ inverse(c2w_colmap[0]) with entries snapped to {-1, 0, 1}.
        T_world = torch.tensor(
            [[0., 1., 0., 0.],
             [1., 0., 0., 0.],
             [0., 0., -1., 0.],
             [0., 0., 0., 1.]],
            device=batch_device,
            dtype=pose_dtype,
        )

        # Assume only one scene in the batch.
        context_x_flipped = context["x_flipped"][0]
        target_x_flipped = target["x_flipped"][0]
        assert context_x_flipped == target_x_flipped, "Context and target x_flipped values must match."
        x_flipped = context_x_flipped

        # Overriding the dataset poses with the COLMAP to ensure consistency
        if self.cfg.override_dataset_poses:
            c2w_colmap = torch.from_numpy(camtoworlds).to(device=batch_device, dtype=pose_dtype)  # (M, 4, 4)
            c2w_predicted = T_world @ c2w_colmap
            if x_flipped:
                flip_transform = torch.eye(4, device=batch_device, dtype=pose_dtype)
                flip_transform[0, 0] = -1.0
                c2w_predicted = flip_transform @ c2w_predicted @ flip_transform
            context["extrinsics"] = c2w_predicted[context["index"][0]][None, ...]  # (1, V, 4, 4)
            target["extrinsics"] = c2w_predicted[target["index"][0]][None, ...]

        # Match the points' device and dtype: the pose dtype may differ from the float32
        # points, and the batch device may differ from the points' device.
        T_points = T_world.to(device=xyz.device, dtype=xyz.dtype)
        homogeneous = torch.cat([xyz, torch.ones_like(xyz[:, :1])], dim=-1)  # (N, 4)
        transformed = T_points @ homogeneous.T  # (4, N)
        if x_flipped:
            transformed[0] *= -1.0
        return transformed[:3, :].T

    def _init_opacities(self, num: int, device: torch.device) -> torch.Tensor:
        """Return ``num`` initial opacities on ``device``: constant, or randomized per cfg.

        Raises:
            ValueError: if randomization is enabled with an unknown distribution.
        """
        cfg = self.cfg
        if not cfg.randomize_opacity:
            return torch.full((num,), cfg.init_opacity, device=device)
        if cfg.randomize_opacity_distribution == "uniform":
            # Uniform in [randomize_opacity_min, init_opacity)
            spread = cfg.init_opacity - cfg.randomize_opacity_min
            return torch.rand(num, device=device) * spread + cfg.randomize_opacity_min
        if cfg.randomize_opacity_distribution == "gaussian":
            opacities = torch.normal(cfg.init_opacity, cfg.randomize_opacity_std, size=(num,), device=device)
            return opacities.clamp(0, 1)  # Clamp to ensure valid values
        raise ValueError(f"Unknown randomize_opacity_distribution: {cfg.randomize_opacity_distribution}")

    def forward(
        self,
        context: BatchedViews,
        visualization_dump: Optional[dict] = None,
        device: Optional[torch.device] = None,
        **kwargs
    ) -> InitializerOutput:
        """Build initial Gaussians for a single scene from its COLMAP points.

        Args:
            context: context views; only used when ``dl3dv_settings`` is set (its
                ``extrinsics`` may then be overwritten, see ``override_dataset_poses``).
            visualization_dump: unused.
            device: device for the point tensors and the output Gaussians.
            **kwargs: ``scene`` — single-element sequence holding the scene name
                (the scene dir is ``cfg.path / scene``); ``target`` — target views,
                required when ``dl3dv_settings`` is set.

        Returns:
            ``InitializerOutput`` holding a batch-of-one ``Gaussians``; ``nr_valid`` is
            the number of real (non-padding) Gaussians.

        Raises:
            ValueError: if the COLMAP root or scene dir is missing, or the opacity
                distribution is unknown.
            SkipBatchException: if no points remain after filtering/subsampling.
        """
        verbose = False

        if not self.cfg.path.exists():
            raise ValueError(f"COLMAP dir {self.cfg.path} does not exist.")

        if "scene" in kwargs:
            scene_name = kwargs["scene"]
            assert len(scene_name) == 1, f"Only single scene initialization supported. {scene_name}"
            scene_name = scene_name[0]
            if self.cfg.dl3dv_settings:
                scene_name = scene_name.replace("dl3dv_", "")
            if verbose:
                print(f"Initializing scene '{scene_name}' from COLMAP at {self.cfg.path}.")
            datadir = self.cfg.path / scene_name
            if not datadir.exists():
                raise ValueError(f"COLMAP scene dir {datadir} does not exist.")
        else:
            datadir = self.cfg.path

        # run COLMAP parser (cached after first load)
        points_xyz, points_rgb, camtoworlds = self._load_colmap(datadir)
        if verbose:
            print(f"Loaded {points_xyz.shape[0]} points from COLMAP.")

        xyz = torch.from_numpy(points_xyz).float().to(device)
        rgbs = torch.from_numpy(points_rgb / 255.0).float().to(device)

        if self.cfg.filter_zero_rgb:
            # Points with 0,0,0 RGB are often outliers in COLMAP reconstructions.
            valid_mask = rgbs.sum(dim=-1) > 0
            xyz = xyz[valid_mask]
            rgbs = rgbs[valid_mask]

        if self.cfg.dl3dv_settings:
            assert "target" in kwargs, "Target key is required in kwargs for COLMAP initializer with dl3dv format."
            xyz = self._align_to_dl3dv_frame(context, kwargs["target"], camtoworlds, xyz)

        # ── Step 1: subsampling augmentation ─────────────────────────────────────
        min_sub = self.cfg.train_min_gaussians_subsample if self.training else self.cfg.eval_min_gaussians_subsample
        max_sub = self.cfg.train_max_gaussians_subsample if self.training else self.cfg.eval_max_gaussians_subsample
        if min_sub is not None or max_sub is not None:
            target_count = self._sample_num_gaussians(xyz.shape[0], min_sub, max_sub)
            if xyz.shape[0] > target_count:
                keep = torch.randperm(xyz.shape[0], device=xyz.device)[:target_count]
                xyz = xyz[keep]
                rgbs = rgbs[keep]

        # ── Step 2: subsample to fixed count before knn (so distances are correct)
        # If current number of points exceeds the fixed count, we subsample to the
        # fixed count (for DDP consistency).
        fixed_num = self.cfg.train_fixed_gaussians_num if self.training else self.cfg.eval_fixed_gaussians_num
        if fixed_num is not None and xyz.shape[0] > fixed_num:
            keep = torch.randperm(xyz.shape[0], device=xyz.device)[:fixed_num]
            xyz = xyz[keep]
            rgbs = rgbs[keep]

        if xyz.shape[0] == 0:
            black_gaussians_num = (points_rgb == 0).all(axis=-1).sum()
            raise SkipBatchException(f"No valid points found in COLMAP data for scene {datadir}. Skipping batch. "
                                     f"Originally {points_xyz.shape[0]} points. Black gaussian num {black_gaussians_num}.")

        # ── Step 3: knn-based scale initialisation ───────────────────────────────
        # Mean squared distance to the 3 nearest neighbours (column 0 is the point itself).
        dist2_avg = (knn(xyz, 4)[:, 1:] ** 2).mean(dim=-1)  # [N,]
        dist_avg = torch.sqrt(dist2_avg)
        scales = dist_avg.unsqueeze(-1).repeat(1, 3)  # [N, 3]

        opacities = self._init_opacities(xyz.shape[0], xyz.device)

        nr_valid = xyz.shape[0]

        # ── Step 4: pad to fixed count for DDP consistency ───────────────────────
        if fixed_num is not None and nr_valid < fixed_num:
            pad = fixed_num - nr_valid
            xyz = F.pad(xyz, (0, 0, 0, pad), value=0.0)
            rgbs = F.pad(rgbs, (0, 0, 0, pad), value=0.0)
            scales = F.pad(scales, (0, 0, 0, pad), value=1e-10)
            opacities = F.pad(opacities, (0, pad), value=1e-10)
            # TODO(Naama): might be a problem if we don't freeze zero-grad gaussians

        points_dict = {
            "xyz": xyz,
            "rgb": rgbs,
            "scales": scales * self.cfg.scaling_factor,
            "opacities": opacities,
        }

        # pre-activation values on device
        gaussians_dict = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device)
        means = gaussians_dict["xyz"]
        sh0 = gaussians_dict["sh0"]
        shN = gaussians_dict["shN"]
        harmonics = sh0 if shN is None else torch.cat([sh0, shN], dim=1)  # [N, sh_d, 3]
        harmonics = harmonics.permute(0, 2, 1)  # [N, 3, sh_d]
        rotations_unnorm = gaussians_dict["rotations_unnorm"]

        # post-activation values
        opacities = torch.sigmoid(gaussians_dict["opacities_raw"])
        scales = torch.exp(gaussians_dict["scales_raw"])
        rotations = F.normalize(rotations_unnorm, dim=-1)
        covariances = build_covariance(scale=scales, rotation_xyzw=rotations)

        gaussians = Gaussians(
            means=means.unsqueeze(0),
            covariances=covariances.unsqueeze(0),
            harmonics=harmonics.unsqueeze(0),  # [1, N, C, sh_d]
            opacities=opacities.unsqueeze(0),
            scales=scales.unsqueeze(0),
            rotations=rotations.unsqueeze(0),
            rotations_unnorm=rotations_unnorm.unsqueeze(0),
            nr_valid=nr_valid,
        )
        return InitializerOutput(gaussians=gaussians, features=None, depths=None)

    @staticmethod
    def _sample_num_gaussians(
        available: int,
        min_val: int | float | None,
        max_val: int | float | None,
    ) -> int:
        """Draw how many of ``available`` points to keep.

        Int bounds sample an absolute count in ``[min_val, max_val]``; float bounds
        sample a ratio in ``[min_val, max_val)`` of ``available``. The result never
        exceeds ``available``. Both bounds ``None`` returns ``available``.
        """
        if min_val is None and max_val is None:
            return available
        assert min_val is not None and max_val is not None, \
            "Both min and max must be set together for Gaussian subsampling."
        assert type(min_val) == type(max_val), \
            "min and max must be the same type (both int or both float)."
        if isinstance(min_val, int):
            count = torch.randint(min_val, max_val + 1, (1,)).item()
        else:
            assert 0.0 < min_val <= 1.0 and 0.0 < max_val <= 1.0, \
                "Float subsampling ratios must be in (0, 1]."
            ratio = torch.empty(1).uniform_(min_val, max_val).item()
            count = int(available * ratio)
        return min(count, available)