Spaces:
Sleeping
Sleeping
"""Convention bridges between inria 3DGS objects and optgs types.

* Gaussians: exchanged via the original-3DGS PLY schema. optgs's
  ``optgs/model/ply_export.py`` (``load_gaussians_ply`` /
  ``save_gaussian_ply``) and inria's ``GaussianModel.save_ply`` /
  ``load_ply`` write and read the *same* schema (scales log<->exp, opacity
  logit<->sigmoid, quaternion wxyz<->xyzw, SH dc/rest), so a PLY round trip
  is convention-correct by construction.
* Cameras: inria stores world->camera ``R, T`` plus FoV (COLMAP convention);
  optgs expects camera->world extrinsics and image-size-normalized
  intrinsics. We reuse inria's own ``getWorld2View2`` / ``fov2focal`` for
  the forward direction so the round trip through
  ``optgs.geometry.projection.get_fov`` is exact.

All inria imports are deferred into functions (the inria repo is only on
``sys.path`` when the caller runs from it).
"""
| from __future__ import annotations | |
| import tempfile | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING, Sequence | |
| if TYPE_CHECKING: # pragma: no cover - typing only | |
| import torch | |
| from optgs.model.types import Gaussians | |
def _import_inria_graphics():
    """Import inria's camera helpers on demand.

    Returns:
        ``(getWorld2View2, fov2focal)`` taken from inria's
        ``utils.graphics_utils``.

    Raises:
        OptGSError: if the import fails for any reason (a missing module, or
            an error raised while the module itself is being imported). The
            original exception is chained as ``__cause__``.
    """
    try:
        from utils.graphics_utils import fov2focal, getWorld2View2  # type: ignore
    except Exception as exc:  # ImportError or something deeper in the import
        from optgs.experimental.api.integration.scene_protocol import OptGSError

        guidance = (
            "could not import inria graphics utils (utils.graphics_utils). "
            "Run from your inria gaussian-splatting checkout (so it is on "
            "sys.path), or use OptGS.initialize_from_ply / "
            "initialize_from_tensors. "
        )
        raise OptGSError(
            f"{guidance}Original error: {type(exc).__name__}: {exc}"
        ) from exc
    return getWorld2View2, fov2focal
| # --------------------------------------------------------------------------- | |
| # Gaussians | |
| # --------------------------------------------------------------------------- | |
def optgs_gaussians_from_ply(
    ply_path: str | Path,
    *,
    sh_degree: int,
    device: "torch.device",
    dtype: "torch.dtype",
) -> "Gaussians":
    """Load a 3DGS PLY into an optgs ``Gaussians`` (batch=1, post-activation).

    Args:
        ply_path: Path to a PLY in the original-3DGS schema.
        sh_degree: Forwarded to ``load_gaussians_ply`` as ``max_sh_degree``.
        device: Target device for the loaded tensors.
        dtype: Target floating dtype for the loaded tensors.

    Returns:
        The loaded ``Gaussians``. Its ``covariances`` field is built from the
        first batch element's scales/rotations. If ``build_covariance``
        fails, ``covariances`` is set to ``None`` and a ``RuntimeWarning``
        naming the cause is emitted.
    """
    import warnings

    from optgs.model.ply_export import load_gaussians_ply
    from optgs.scene_trainer.common.gaussians import build_covariance

    g = load_gaussians_ply(str(ply_path), max_sh_degree=sh_degree)
    g = g.to(device=device, dtype=dtype)
    # Populate covariances so any depth / use_covariances path is safe (the
    # default color path recomputes from scales+rotations anyway). This stays
    # best-effort, but a failure is reported instead of silently swallowed;
    # otherwise a later covariance-based path would see ``None`` with no hint
    # of why.
    try:
        g.covariances = build_covariance(g.scales[0], g.rotations[0]).unsqueeze(0)
    except Exception as exc:
        warnings.warn(
            f"could not build covariances for {ply_path!s} "
            f"({type(exc).__name__}: {exc}); setting covariances=None.",
            RuntimeWarning,
            stacklevel=2,
        )
        g.covariances = None
    return g
def optgs_gaussians_from_inria_model(
    gm: object,
    *,
    device: "torch.device",
    dtype: "torch.dtype",
) -> "Gaussians":
    """Convert an inria ``GaussianModel`` into an optgs ``Gaussians``.

    The model is written with its own ``save_ply`` (original-3DGS schema)
    into a scratch directory. It is then read back through
    ``optgs_gaussians_from_ply``, and the scratch directory is removed
    before returning. The SH degree comes from ``gm.max_sh_degree``, or 3
    when that attribute is absent.
    """
    degree = int(getattr(gm, "max_sh_degree", 3))
    with tempfile.TemporaryDirectory(prefix="optgs_ingest_") as scratch:
        ply_file = Path(scratch).joinpath("gaussians.ply")
        gm.save_ply(str(ply_file))  # inria's own writer
        loaded = optgs_gaussians_from_ply(
            ply_file, sh_degree=degree, device=device, dtype=dtype
        )
    return loaded
def write_back_to_inria_model(gm: object, final: "Gaussians") -> None:
    """Overwrite an inria ``GaussianModel`` in place with refined Gaussians.

    Full-replacement semantics: the learned ADC may change the point count.
    ``final`` is therefore written to a temporary original-3DGS PLY and loaded
    back with the model's own ``load_ply``, which reallocates
    ``_xyz`` / ``_features_*`` / ``_opacity`` / ``_scaling`` / ``_rotation``.
    Afterwards the inria optimizer is dropped (``gm.optimizer = None``).
    ``xyz_gradient_accum``, ``denom`` and ``max_radii2D`` are re-created as
    zeros sized to the new point count, on the device of ``gm._xyz``.

    The caller must call ``gaussians.training_setup(...)`` again before
    resuming inria Adam.
    """
    import torch
    from optgs.model.ply_export import save_gaussian_ply

    with tempfile.TemporaryDirectory(prefix="optgs_writeback_") as scratch:
        ply_file = Path(scratch).joinpath("refined.ply")
        # save_gaussian_ply expects B == 1, converts quaternions xyzw -> wxyz
        # and re-inverts the activations into the stored PLY form.
        save_gaussian_ply(final, save_path=ply_file)
        gm.load_ply(str(ply_file))

    n_pts = gm._xyz.shape[0]
    where = gm._xyz.device
    gm.optimizer = None
    gm.xyz_gradient_accum = torch.zeros((n_pts, 1), device=where)
    gm.denom = torch.zeros((n_pts, 1), device=where)
    gm.max_radii2D = torch.zeros((n_pts,), device=where)
| # --------------------------------------------------------------------------- | |
| # Cameras | |
| # --------------------------------------------------------------------------- | |
def batched_views_from_cameras(
    cameras: Sequence[object],
    *,
    scene_scale: float,
    device: "torch.device",
    dtype: "torch.dtype",
    near: float = 0.01,
    far: float = 100.0,
):
    """Build an optgs ``BatchedViews`` (B=1) from inria-style cameras.

    Each camera must expose ``image_width``, ``image_height``, ``R``, ``T``,
    ``FoVx``, ``FoVy`` and ``original_image``. ``cx`` / ``cy`` (pixels) are
    optional and default to the image centre. All cameras must share one
    (H, W), because the decoder takes a single image shape.

    Args:
        cameras: inria-style cameras. Any iterable is accepted; it is
            materialized once.
        scene_scale: Stored as a 1-element ``scene_scale`` tensor.
        device: Device for every output tensor.
        dtype: Floating dtype for every output tensor except ``index``, which
            keeps torch's default integer dtype.
        near: Near plane for every view. Defaults to inria's hardcoded
            ``znear=0.01``, also the optgs colmap-dataset constant.
        far: Far plane for every view. Defaults to inria's hardcoded
            ``zfar=100.0``, also the optgs colmap-dataset constant.

    Returns:
        ``BatchedViews`` with:

        * ``extrinsics``: ``[1, V, 4, 4]``, camera->world.
        * ``intrinsics``: ``[1, V, 3, 3]``, normalized by image width/height.
        * ``image``: the stacked ``original_image`` tensors clamped to
          [0, 1], with a leading batch axis.
        * ``near``, ``far`` and ``index``: ``[1, V]``.
        * ``scene_scale``: ``[1]``.

    Raises:
        OptGSError: if no cameras are given, the cameras disagree on (H, W),
            an image's spatial size disagrees with its camera's (H, W), or
            the inria graphics utils cannot be imported.
    """
    import torch
    from optgs.dataset.data_types import BatchedViews
    from optgs.experimental.api.integration.scene_protocol import OptGSError

    # Materialize once so generators work and len() is well-defined.
    cameras = list(cameras)
    # Validate before touching inria, so an empty input reports the real
    # problem even when the inria checkout is not importable.
    if not cameras:
        raise OptGSError("no cameras provided.")

    getWorld2View2, fov2focal = _import_inria_graphics()

    exts, intrs, imgs = [], [], []
    H0 = W0 = None
    for cam in cameras:
        W = int(cam.image_width)
        H = int(cam.image_height)
        if H0 is None:
            H0, W0 = H, W
        elif (H, W) != (H0, W0):
            raise OptGSError(
                f"all train cameras must share one (H, W); got {(H, W)} vs "
                f"{(H0, W0)}. Render to a single resolution before optimizing."
            )

        # inria's [4,4] world->camera matrix. Invert in float64 so the
        # camera->world extrinsics lose as little precision as possible
        # before the final cast to ``dtype``.
        w2c = torch.as_tensor(getWorld2View2(cam.R, cam.T), dtype=torch.float64)
        c2w = torch.inverse(w2c)  # optgs extrinsics convention

        fx = fov2focal(cam.FoVx, W)
        fy = fov2focal(cam.FoVy, H)
        cx = float(getattr(cam, "cx", W / 2.0))
        cy = float(getattr(cam, "cy", H / 2.0))
        K = torch.eye(3, dtype=torch.float64)
        K[0, 0] = fx / W  # normalized focal
        K[1, 1] = fy / H
        K[0, 2] = cx / W  # normalized principal point
        K[1, 2] = cy / H

        img = cam.original_image
        if not torch.is_tensor(img):
            img = torch.as_tensor(img)
        img = img.float().clamp(0.0, 1.0)  # expected [3, H, W]
        # The intrinsics above are normalized by (W, H); an image of another
        # size would silently pair with the wrong intrinsics.
        if tuple(img.shape[-2:]) != (H, W):
            raise OptGSError(
                f"camera image has spatial size {tuple(img.shape[-2:])}, but "
                f"image_height/image_width give {(H, W)}."
            )

        exts.append(c2w)
        intrs.append(K)
        imgs.append(img)

    V = len(cameras)
    extrinsics = torch.stack(exts).unsqueeze(0).to(device=device, dtype=dtype)
    intrinsics = torch.stack(intrs).unsqueeze(0).to(device=device, dtype=dtype)
    image = torch.stack(imgs).unsqueeze(0).to(device=device, dtype=dtype)
    return BatchedViews.from_dict(
        {
            "extrinsics": extrinsics,
            "intrinsics": intrinsics,
            "image": image,
            "near": torch.full((1, V), near, device=device, dtype=dtype),
            "far": torch.full((1, V), far, device=device, dtype=dtype),
            "index": torch.arange(V, device=device).unsqueeze(0),
            "scene_scale": torch.tensor(
                [float(scene_scale)], device=device, dtype=dtype
            ),
        }
    )