"""Convention bridges between inria 3DGS objects and optgs types.

* Gaussians: via the original-3DGS PLY schema. ``optgs/model/ply_export.py``
  (``load_gaussians_ply`` / ``save_gaussian_ply``) and inria
  ``GaussianModel.save_ply`` / ``load_ply`` write/read the *same* schema
  (scales log<->exp, opacity logit<->sigmoid, quat wxyz<->xyzw, SH dc/rest),
  so a PLY round-trip is convention-correct by construction.
* Cameras: inria stores world->camera ``R,T`` + FoV (COLMAP convention);
  optgs wants camera->world extrinsics and image-size-normalized intrinsics.
  We reuse inria's own ``getWorld2View2`` / ``fov2focal`` for the forward
  direction so the round trip through ``optgs.geometry.projection.get_fov``
  is exact.

All inria imports are deferred into functions (the inria repo is only on
``sys.path`` when the caller runs from it).
"""

from __future__ import annotations

import logging
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Sequence

if TYPE_CHECKING:  # pragma: no cover - typing only
    import torch

    from optgs.model.types import Gaussians

logger = logging.getLogger(__name__)


def _optgs_error(message: str) -> Exception:
    """Build (but do not raise) an ``OptGSError`` carrying *message*.

    The import is deferred so this module stays importable without pulling in
    the optgs API package at import time.
    """
    from optgs.experimental.api.integration.scene_protocol import OptGSError

    return OptGSError(message)


def _import_inria_graphics():
    """Return inria's ``(getWorld2View2, fov2focal)`` helpers.

    Raises:
        OptGSError: if ``utils.graphics_utils`` cannot be imported; the
            original exception is chained as ``__cause__``.
    """
    try:
        from utils.graphics_utils import fov2focal, getWorld2View2  # type: ignore
    except Exception as e:  # ImportError, or any error raised while importing
        raise _optgs_error(
            "could not import inria graphics utils (utils.graphics_utils). "
            "Run from your inria gaussian-splatting checkout (so it is on "
            "sys.path), or use OptGS.initialize_from_ply / "
            "initialize_from_tensors. "
            f"Original error: {type(e).__name__}: {e}"
        ) from e
    return getWorld2View2, fov2focal


# ---------------------------------------------------------------------------
# Gaussians
# ---------------------------------------------------------------------------


def optgs_gaussians_from_ply(
    ply_path: str | Path,
    *,
    sh_degree: int,
    device: "torch.device",
    dtype: "torch.dtype",
) -> "Gaussians":
    """Load a 3DGS PLY into an optgs ``Gaussians`` (batch=1, post-activation).

    Args:
        ply_path: path to a PLY in the original-3DGS schema.
        sh_degree: maximum SH degree stored in the file.
        device: target device for every tensor.
        dtype: target floating dtype for every tensor.

    Returns:
        The loaded ``Gaussians``. ``covariances`` is populated on a
        best-effort basis and left as ``None`` (with a logged warning) if it
        cannot be built.
    """
    from optgs.model.ply_export import load_gaussians_ply
    from optgs.scene_trainer.common.gaussians import build_covariance

    g = load_gaussians_ply(str(ply_path), max_sh_degree=sh_degree)
    g = g.to(device=device, dtype=dtype)
    # Populate covariances so any depth / use_covariances path is safe (the
    # default color path recomputes from scales+rotations anyway). This is
    # best-effort: a failure here should not block loading, but it must not
    # be silent either.
    try:
        g.covariances = build_covariance(g.scales[0], g.rotations[0]).unsqueeze(0)
    except Exception as e:
        logger.warning(
            "could not precompute covariances for %s (%s: %s); "
            "leaving covariances unset.",
            ply_path,
            type(e).__name__,
            e,
        )
        g.covariances = None
    return g


def optgs_gaussians_from_inria_model(
    gm: object,
    *,
    device: "torch.device",
    dtype: "torch.dtype",
) -> "Gaussians":
    """Ingest an inria ``GaussianModel`` via a temp PLY round-trip.

    The SH degree is taken from ``gm.max_sh_degree`` (3 if absent). The
    temporary file is removed before returning.
    """
    sh_degree = int(getattr(gm, "max_sh_degree", 3))
    with tempfile.TemporaryDirectory(prefix="optgs_ingest_") as d:
        tmp = Path(d) / "gaussians.ply"
        gm.save_ply(str(tmp))  # inria writer (original-3DGS schema)
        return optgs_gaussians_from_ply(
            tmp, sh_degree=sh_degree, device=device, dtype=dtype
        )


def write_back_to_inria_model(gm: object, final: "Gaussians") -> None:
    """Replace an inria ``GaussianModel``'s params with refined Gaussians.

    Full-replacement semantics: the learned ADC may change the point count,
    so we reallocate every parameter (via inria ``load_ply``) and reset the
    inria Adam optimizer + densification accumulators. The caller must call
    ``gaussians.training_setup(...)`` again before resuming inria Adam.
    """
    import torch

    from optgs.model.ply_export import save_gaussian_ply

    with tempfile.TemporaryDirectory(prefix="optgs_writeback_") as d:
        tmp = Path(d) / "refined.ply"
        # save_gaussian_ply: B==1, xyzw->wxyz, re-inverts activations.
        save_gaussian_ply(final, save_path=tmp)
        gm.load_ply(str(tmp))  # reallocs _xyz/_features_*/_opacity/_scaling/_rotation

    # The point count may have changed, so every per-point buffer is rebuilt
    # at the new size on the same device as the parameters.
    n = gm._xyz.shape[0]
    dev = gm._xyz.device
    gm.optimizer = None
    gm.xyz_gradient_accum = torch.zeros((n, 1), device=dev)
    gm.denom = torch.zeros((n, 1), device=dev)
    gm.max_radii2D = torch.zeros((n,), device=dev)


# ---------------------------------------------------------------------------
# Cameras
# ---------------------------------------------------------------------------


def batched_views_from_cameras(
    cameras: Sequence[object],
    *,
    scene_scale: float,
    device: "torch.device",
    dtype: "torch.dtype",
    near: float = 0.01,
    far: float = 100.0,
):
    """Build an optgs ``BatchedViews`` (B=1) from inria-style cameras.

    ``near``/``far`` default to inria's hardcoded ``znear=0.01``/``zfar=100.0``
    (also the optgs colmap-dataset constants). All cameras must share one
    (H, W) — the decoder takes a single image shape.

    Each camera is read for ``image_width``, ``image_height``, ``R``, ``T``,
    ``FoVx``, ``FoVy`` and ``original_image``; optional ``trans``/``scale``
    are forwarded to ``getWorld2View2`` and optional ``cx``/``cy`` (pixels)
    override the image-centre principal point.

    Raises:
        OptGSError: if ``cameras`` is empty, the cameras disagree on (H, W),
            or inria's graphics utils cannot be imported.
    """
    # Validate cheap input first so an empty list is reported as such rather
    # than being masked by a missing-inria import error.
    if len(cameras) == 0:
        raise _optgs_error("no cameras provided.")

    import torch

    from optgs.dataset.data_types import BatchedViews

    getWorld2View2, fov2focal = _import_inria_graphics()

    exts, intrs, imgs = [], [], []
    H0 = W0 = None
    for cam in cameras:
        W = int(cam.image_width)
        H = int(cam.image_height)
        if H0 is None:
            H0, W0 = H, W
        elif (H, W) != (H0, W0):
            raise _optgs_error(
                f"all train cameras must share one (H, W); got {(H, W)} vs "
                f"{(H0, W0)}. Render to a single resolution before optimizing."
            )

        # inria's Camera builds its view matrix as
        # getWorld2View2(R, T, trans, scale); forward the extras when the
        # camera carries them so the extrinsics match what inria renders.
        view_kwargs = {}
        trans = getattr(cam, "trans", None)
        if trans is not None:
            view_kwargs["translate"] = trans
        cam_scale = getattr(cam, "scale", None)
        if cam_scale is not None:
            view_kwargs["scale"] = cam_scale

        # Invert in float64 to avoid accumulating float32 inversion error;
        # everything is cast to ``dtype`` once at the end.
        w2c = torch.tensor(
            getWorld2View2(cam.R, cam.T, **view_kwargs), dtype=torch.float64
        )  # [4,4] world->camera
        c2w = torch.inverse(w2c)  # optgs extrinsics convention

        fx = fov2focal(cam.FoVx, W)
        fy = fov2focal(cam.FoVy, H)
        # Some camera classes expose cx/cy but leave them as None.
        cx = getattr(cam, "cx", None)
        cy = getattr(cam, "cy", None)
        cx = W / 2.0 if cx is None else float(cx)
        cy = H / 2.0 if cy is None else float(cy)

        K = torch.eye(3, dtype=torch.float64)
        K[0, 0] = fx / W  # normalized focal
        K[1, 1] = fy / H
        K[0, 2] = cx / W  # normalized principal point
        K[1, 2] = cy / H

        img = cam.original_image
        if not torch.is_tensor(img):
            img = torch.as_tensor(img)
        img = img.float().clamp(0.0, 1.0)  # [3, H, W]

        exts.append(c2w)
        intrs.append(K)
        imgs.append(img)

    V = len(cameras)
    extrinsics = torch.stack(exts).unsqueeze(0).to(device=device, dtype=dtype)
    intrinsics = torch.stack(intrs).unsqueeze(0).to(device=device, dtype=dtype)
    image = torch.stack(imgs).unsqueeze(0).to(device=device, dtype=dtype)

    return BatchedViews.from_dict(
        {
            "extrinsics": extrinsics,
            "intrinsics": intrinsics,
            "image": image,
            "near": torch.full((1, V), near, device=device, dtype=dtype),
            "far": torch.full((1, V), far, device=device, dtype=dtype),
            "index": torch.arange(V, device=device).unsqueeze(0),
            "scene_scale": torch.tensor(
                [float(scene_scale)], device=device, dtype=dtype
            ),
        }
    )